advanced_analyzer #8

Merged
dp merged 23 commits from advanced_analyzer into main 2026-04-21 22:32:18 +00:00
2 changed files with 354 additions and 16 deletions
Showing only changes of commit fe7dc4d5a1 - Show all commits

View File

@ -132,8 +132,13 @@ from __future__ import annotations
import argparse import argparse
import getpass import getpass
import multiprocessing as mp
import os
import queue as _queue_mod
import re import re
import sys import sys
import threading
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@ -155,6 +160,7 @@ from load_sas import (
discover_partition_values_chunked, discover_partition_values_chunked,
infer_schema, infer_schema,
iter_sas_chunks, iter_sas_chunks,
read_sas_metadata,
read_sas_preview, read_sas_preview,
render_create_indexes, render_create_indexes,
render_create_table, render_create_table,
@ -702,7 +708,15 @@ def _discover_cluster_partitions(
return merged return merged
def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: def load_cluster(
conn,
cluster: ClusterSpec,
schemaname: str,
*,
workers: int = 1,
progress_queue: Any = None,
db_overrides: Optional[Dict[str, Optional[str]]] = None,
) -> int:
"""Load every file in ``cluster`` into one table. Returns total rows loaded. """Load every file in ``cluster`` into one table. Returns total rows loaded.
When ``cluster.partition_by`` is non-empty, partition values are When ``cluster.partition_by`` is non-empty, partition values are
@ -713,6 +727,20 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
file mid-cluster fails, earlier chunks - including chunks from earlier file mid-cluster fails, earlier chunks - including chunks from earlier
files in the cluster - stay committed; only the in-flight chunk is files in the cluster - stay committed; only the in-flight chunk is
rolled back by :func:`main`. rolled back by :func:`main`.
``workers`` controls parallelism for the *append* phase. The first file
always runs serially on ``conn`` (to create the table and, when
partitioned, pre-create partitions). When ``workers > 1`` the remaining
files dispatch to a ``ProcessPoolExecutor``; each worker opens its own
psycopg2 connection, re-infers the per-file schema, runs the same
:func:`load_sas.assert_schema_compatible` check the serial path uses,
and streams chunks via COPY. Workers report per-chunk row counts to
``progress_queue`` so the caller can drive a single aggregated tqdm
bar regardless of how many workers are in flight.
``db_overrides`` carries ``{"user", "password"}`` into workers when the
caller prompted for credentials interactively; leave ``None`` to let
workers read the standard libpq environment variables on their own.
""" """
if not cluster.files: if not cluster.files:
return 0 return 0
@ -782,21 +810,44 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
conn, schemaname, cluster.tablename, first, first_columns, conn, schemaname, cluster.tablename, first, first_columns,
cluster.include, cluster.exclude, cluster.include, cluster.exclude,
total_rows=first_total_rows, total_rows=first_total_rows,
progress_queue=progress_queue,
) )
# Commit the first file (and the CREATE TABLE) before spawning workers
# so their ``assert_schema_compatible`` probes actually see the new
# table. Without this, worker connections started mid-transaction on
# the main connection would see nothing in information_schema.
conn.commit()
for path in rest: if rest:
columns, path_total_rows = _infer_cluster_schema( if workers > 1:
path, cluster.include, cluster.exclude total += _load_remaining_files_parallel(
) rest,
# Uses the same check that if_exists=append runs. A type mismatch or schemaname,
# missing column aborts the cluster; because chunks commit as they cluster.tablename,
# load, earlier chunks in the cluster remain in the table. cluster.include,
assert_schema_compatible(conn, schemaname, cluster.tablename, columns) cluster.exclude,
total += _stream_file( workers=workers,
conn, schemaname, cluster.tablename, path, columns, progress_queue=progress_queue,
cluster.include, cluster.exclude, db_overrides=db_overrides,
total_rows=path_total_rows, )
) else:
for path in rest:
columns, path_total_rows = _infer_cluster_schema(
path, cluster.include, cluster.exclude
)
# Uses the same check that if_exists=append runs. A type
# mismatch or missing column aborts the cluster; because
# chunks commit as they load, earlier chunks in the
# cluster remain in the table.
assert_schema_compatible(
conn, schemaname, cluster.tablename, columns
)
total += _stream_file(
conn, schemaname, cluster.tablename, path, columns,
cluster.include, cluster.exclude,
total_rows=path_total_rows,
progress_queue=progress_queue,
)
# -- Index support ------------------------------------------------------ # -- Index support ------------------------------------------------------
if cluster.indexes: if cluster.indexes:
@ -815,8 +866,25 @@ def _stream_file(
exclude, exclude,
*, *,
total_rows: Optional[int] = None, total_rows: Optional[int] = None,
progress_queue: Any = None,
) -> int: ) -> int:
"""Stream ``path`` into an existing table chunk by chunk.
When ``progress_queue`` is provided, each chunk's row count is published
to the queue as ``("rows", n)`` tuples instead of being rendered to a
per-file tqdm bar. That lets :func:`main` drive a single folder-wide
progress bar from a background drainer thread, which is the only way
to keep a coherent progress view when the folder loader is running
files in parallel workers.
"""
def _chunks(): def _chunks():
if progress_queue is not None:
for chunk_df, _chunk_meta in iter_sas_chunks(path):
chunk_df = apply_column_filter(chunk_df, include, exclude)
progress_queue.put(("rows", len(chunk_df)))
yield chunk_df
return
pbar = tqdm( pbar = tqdm(
total=total_rows, total=total_rows,
unit="row", unit="row",
@ -836,6 +904,134 @@ def _stream_file(
return copy_dataframes(conn, schemaname, tablename, _chunks(), columns) return copy_dataframes(conn, schemaname, tablename, _chunks(), columns)
# ---------------------------------------------------------------------------
# Parallel append workers
# ---------------------------------------------------------------------------
def _worker_load_append_file(
    path_str: str,
    schemaname: str,
    tablename: str,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
    progress_queue: Any,
    db_overrides: Optional[Dict[str, Optional[str]]],
) -> Tuple[str, int, Optional[str]]:
    """Worker process: load one SAS file into an existing table (append mode).

    Submitted to the process pool by :func:`_load_remaining_files_parallel`.
    The worker re-infers the schema from a preview of the file, opens its own
    database connection, runs ``assert_schema_compatible`` against the target
    table, and then streams the file's chunks through ``copy_dataframes``.

    Args:
        path_str: Path of the SAS file, passed as ``str`` so it crosses the
            process boundary as a plain picklable value.
        schemaname: Schema of the target table.
        tablename: Name of the target table (must already exist).
        include: Column filter forwarded to ``apply_column_filter``.
        exclude: Column filter forwarded to ``apply_column_filter``.
        progress_queue: Optional queue; when not ``None`` a ``("rows", n)``
            tuple is put on it for every chunk read.
        db_overrides: Optional ``{"user", "password"}`` mapping forwarded to
            ``connect``; ``None`` (or an empty mapping) passes ``None`` for
            both.

    Returns:
        ``(path_str, rows_loaded, None)`` on success, or
        ``(path_str, 0, "<ExceptionType>: <message>")`` on any failure.
        Failures are returned rather than raised so the parent can aggregate
        results across workers.
    """
    try:
        # The imports and the .env load live inside the ``try`` on purpose:
        # a broken worker environment (missing module, unreadable .env) must
        # be reported through the error slot of the result tuple like any
        # other failure, instead of escaping the worker as a raw exception.
        from pathlib import Path as _Path

        from dotenv import load_dotenv as _load_dotenv

        from load_sas import (
            apply_column_filter as _apply_column_filter,
            assert_schema_compatible as _assert_schema_compatible,
            connect as _connect,
            copy_dataframes as _copy_dataframes,
            infer_schema as _infer_schema,
            iter_sas_chunks as _iter_sas_chunks,
            read_sas_preview as _read_sas_preview,
        )

        _load_dotenv()
        path = _Path(path_str)

        # Infer this file's own schema from a filtered preview so per-file
        # type drift is caught by the compatibility check below.
        preview_df, meta = _read_sas_preview(path)
        preview_df = _apply_column_filter(preview_df, include, exclude)
        total_rows = getattr(meta, "number_rows", None)
        columns = _infer_schema(preview_df, meta, total_rows=total_rows)

        user = db_overrides.get("user") if db_overrides else None
        password = db_overrides.get("password") if db_overrides else None
        conn = _connect(user=user, password=password)
        conn.autocommit = False
        try:
            _assert_schema_compatible(conn, schemaname, tablename, columns)

            def _chunks():
                # Progress is published as each chunk is read, i.e. before
                # the chunk has been handed to the COPY step.
                for chunk_df, _chunk_meta in _iter_sas_chunks(path):
                    chunk_df = _apply_column_filter(chunk_df, include, exclude)
                    if progress_queue is not None:
                        progress_queue.put(("rows", len(chunk_df)))
                    yield chunk_df

            rows = _copy_dataframes(
                conn, schemaname, tablename, _chunks(), columns
            )
            conn.commit()
            return (path_str, rows, None)
        finally:
            # Always release the connection, on success and on failure.
            conn.close()
    except Exception as e:
        # NOTE(review): the row count is reported as 0 on failure; if
        # copy_dataframes commits chunk by chunk, some rows from this file
        # may already be committed - confirm against load_sas.
        return (path_str, 0, f"{type(e).__name__}: {e}")
def _load_remaining_files_parallel(
    files: List[Path],
    schemaname: str,
    tablename: str,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
    *,
    workers: int,
    progress_queue: Any,
    db_overrides: Optional[Dict[str, Optional[str]]],
) -> int:
    """Run append-mode loads for ``files`` across a process pool.

    Each file is submitted to a ``ProcessPoolExecutor`` as one call to
    :func:`_worker_load_append_file`. Results are consumed in completion
    order; successful row counts are summed and every failure is collected.

    Args:
        files: Files to append to the target table.
        schemaname: Schema of the target table.
        tablename: Name of the target table.
        include: Column filter forwarded to each worker.
        exclude: Column filter forwarded to each worker.
        workers: ``max_workers`` for the process pool.
        progress_queue: Queue forwarded to each worker for progress events.
        db_overrides: Credential overrides forwarded to each worker.

    Returns:
        Total number of rows reported by the workers that succeeded. ``0``
        when ``files`` is empty.

    Raises:
        RuntimeError: After all futures have completed, if one or more files
            failed. The message lists every failing path with its error.
    """
    if not files:
        # Nothing to append; do not stand up a process pool for no work.
        return 0

    total = 0
    errors: List[Tuple[str, str]] = []
    with ProcessPoolExecutor(max_workers=workers) as pool:
        # Remember which path each future belongs to, so a future that
        # raises (instead of returning an error tuple) can still be
        # attributed to its file.
        future_to_path = {
            pool.submit(
                _worker_load_append_file,
                str(p),
                schemaname,
                tablename,
                include,
                exclude,
                progress_queue,
                db_overrides,
            ): str(p)
            for p in files
        }
        for fut in as_completed(future_to_path):
            try:
                path_str, rows, err = fut.result()
            except Exception as exc:
                # The worker died or its result could not be transferred
                # (e.g. a broken pool). Record it like any other failure so
                # the remaining workers' results are still aggregated.
                errors.append(
                    (future_to_path[fut], f"{type(exc).__name__}: {exc}")
                )
                continue
            if err is not None:
                errors.append((path_str, err))
            else:
                total += rows

    if errors:
        joined = "\n".join(f"  {p}: {e}" for p, e in errors)
        raise RuntimeError(
            f"{len(errors)} worker(s) failed while appending to "
            f"{schemaname}.{tablename}:\n{joined}"
        )
    return total
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# CLI # CLI
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -875,6 +1071,22 @@ def _build_argparser() -> argparse.ArgumentParser:
"PGUSER / PGPASSWORD from the environment or .env file." "PGUSER / PGPASSWORD from the environment or .env file."
), ),
) )
p.add_argument(
"--workers",
type=int,
default=1,
metavar="N",
help=(
"Number of worker processes for the append phase. With N=1 "
"(default) files load serially on the main connection. With "
"N>1 the first file of each cluster still runs serially (to "
"create the table), then the remaining files load in parallel "
"across N processes, each with its own psycopg2 connection. "
"On a big box try N close to your core count. When N>1 the "
"per-chunk row target drops to 500,000 unless you've pinned "
"GENERIC_LOADER_CHUNK_ROWS, so peak memory stays bounded."
),
)
return p return p
@ -993,6 +1205,101 @@ def main(argv: Optional[List[str]] = None) -> int:
if args.dbcreds: if args.dbcreds:
db_user = input("Database username: ") db_user = input("Database username: ")
db_password = getpass.getpass("Database password: ") db_password = getpass.getpass("Database password: ")
db_overrides: Optional[Dict[str, Optional[str]]] = (
{"user": db_user, "password": db_password} if args.dbcreds else None
)
workers = max(1, int(args.workers))
# When running parallel workers, bound peak memory: each worker buffers a
# chunk (read + prepared + serialized) so total memory scales with
# workers × chunk_rows × avg_row_bytes. Drop the default chunk target to
# 500k unless the operator has explicitly pinned it. Setting the env var
# before workers spawn means they inherit it through forkserver / spawn.
if (
workers > 1
and "GENERIC_LOADER_CHUNK_ROWS" not in os.environ
):
os.environ["GENERIC_LOADER_CHUNK_ROWS"] = "500000"
print(
"[info] parallel mode: bounding per-chunk rows to 500,000. "
"Pin GENERIC_LOADER_CHUNK_ROWS to override.",
file=sys.stderr,
)
# -- Metadata pre-scan -----------------------------------------------------
# Sum ``number_rows`` across every file so the tqdm bar has a real
# denominator. ``read_sas_metadata`` uses pyreadstat's ``metadataonly=True``
# fast path; a few ms per sas7bdat even on large files.
print(
f"pre-scanning row counts for {sum(len(c.files) for c in loadable)} "
f"file(s)...",
file=sys.stderr,
)
grand_total = 0
unknown_total_files: List[str] = []
for c in loadable:
for p in c.files:
try:
meta = read_sas_metadata(p)
n = getattr(meta, "number_rows", None)
if n is None:
unknown_total_files.append(p.name)
else:
grand_total += int(n)
except Exception as e:
unknown_total_files.append(f"{p.name} ({e})")
if unknown_total_files:
print(
f"[warn] could not read row count from "
f"{len(unknown_total_files)} file(s); progress bar ETA will "
f"be approximate.",
file=sys.stderr,
)
print(
f" total rows across folder: {grand_total:,}",
file=sys.stderr,
)
# -- Shared progress plumbing ---------------------------------------------
# The queue crosses process boundaries when workers > 1 (managed proxy)
# and is a plain in-process queue otherwise; the put/get contract is
# identical either way. A daemon thread drains it and advances the one
# tqdm bar that spans the whole folder load.
manager: Optional[Any] = None
progress_queue: Any
if workers > 1:
manager = mp.Manager()
progress_queue = manager.Queue()
else:
progress_queue = _queue_mod.Queue()
pbar = tqdm(
total=grand_total or None,
unit="row",
unit_scale=True,
desc=f"{cfg.folder.name}",
file=sys.stderr,
dynamic_ncols=True,
)
stop_drainer = threading.Event()
def _drainer() -> None:
while not stop_drainer.is_set():
try:
event = progress_queue.get(timeout=0.1)
except _queue_mod.Empty:
continue
except (EOFError, OSError):
return
if not event:
continue
kind = event[0]
if kind == "rows":
pbar.update(event[1])
drainer_thread = threading.Thread(target=_drainer, daemon=True)
drainer_thread.start()
conn = connect(user=db_user, password=db_password) conn = connect(user=db_user, password=db_password)
conn.autocommit = False conn.autocommit = False
@ -1002,10 +1309,18 @@ def main(argv: Optional[List[str]] = None) -> int:
for cluster in loadable: for cluster in loadable:
print( print(
f"\n>>> loading cluster {cluster.tablename!r} " f"\n>>> loading cluster {cluster.tablename!r} "
f"({len(cluster.files)} file(s))" f"({len(cluster.files)} file(s)) "
f"[workers={workers}]"
) )
try: try:
rows = load_cluster(conn, cluster, cfg.schemaname) rows = load_cluster(
conn,
cluster,
cfg.schemaname,
workers=workers,
progress_queue=progress_queue,
db_overrides=db_overrides,
)
conn.commit() conn.commit()
totals.append((cluster.tablename, len(cluster.files), rows)) totals.append((cluster.tablename, len(cluster.files), rows))
print( print(
@ -1022,7 +1337,14 @@ def main(argv: Optional[List[str]] = None) -> int:
if args.fail_fast: if args.fail_fast:
break break
finally: finally:
# Drain any pending progress events before shutting the bar down so
# the final rendered total matches what actually landed.
stop_drainer.set()
drainer_thread.join(timeout=2.0)
pbar.close()
conn.close() conn.close()
if manager is not None:
manager.shutdown()
print("\n=== summary ===") print("\n=== summary ===")
for name, fcount, rows in totals: for name, fcount, rows in totals:

View File

@ -580,6 +580,22 @@ def read_sas_preview(
return reader(str(Path(path)), row_limit=row_limit, **kwargs) return reader(str(Path(path)), row_limit=row_limit, **kwargs)
def read_sas_metadata(path: Path) -> Any:
    """Return the metadata object for a SAS file without loading its rows.

    The reader callable and its keyword arguments are chosen by
    ``_sas_reader`` for ``path``; the reader is then invoked with
    ``metadataonly=True`` and only the metadata half of its
    ``(dataframe, metadata)`` result is returned.

    Intended as a cheap way to obtain ``meta.number_rows`` for every file
    in a folder, e.g. to size a global progress bar, without paying for a
    row preview as :func:`read_sas_preview` does.
    """
    read_fn, reader_kwargs = _sas_reader(path)
    file_name = str(Path(path))
    # The first element is the (row-less) frame, which callers do not need.
    _unused_frame, file_meta = read_fn(
        file_name, metadataonly=True, **reader_kwargs
    )
    return file_meta
def iter_sas_chunks( def iter_sas_chunks(
path: Path, path: Path,
*, *,