advanced_analyzer #8

Merged
dp merged 23 commits from advanced_analyzer into main 2026-04-21 22:32:18 +00:00
2 changed files with 354 additions and 16 deletions
Showing only changes of commit fe7dc4d5a1 - Show all commits

View File

@ -132,8 +132,13 @@ from __future__ import annotations
import argparse
import getpass
import multiprocessing as mp
import os
import queue as _queue_mod
import re
import sys
import threading
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@ -155,6 +160,7 @@ from load_sas import (
discover_partition_values_chunked,
infer_schema,
iter_sas_chunks,
read_sas_metadata,
read_sas_preview,
render_create_indexes,
render_create_table,
@ -702,7 +708,15 @@ def _discover_cluster_partitions(
return merged
def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
def load_cluster(
conn,
cluster: ClusterSpec,
schemaname: str,
*,
workers: int = 1,
progress_queue: Any = None,
db_overrides: Optional[Dict[str, Optional[str]]] = None,
) -> int:
"""Load every file in ``cluster`` into one table. Returns total rows loaded.
When ``cluster.partition_by`` is non-empty, partition values are
@ -713,6 +727,20 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
file mid-cluster fails, earlier chunks - including chunks from earlier
files in the cluster - stay committed; only the in-flight chunk is
rolled back by :func:`main`.
``workers`` controls parallelism for the *append* phase. The first file
always runs serially on ``conn`` (to create the table and, when
partitioned, pre-create partitions). When ``workers > 1`` the remaining
files dispatch to a ``ProcessPoolExecutor``; each worker opens its own
psycopg2 connection, re-infers the per-file schema, runs the same
:func:`load_sas.assert_schema_compatible` check the serial path uses,
and streams chunks via COPY. Workers report per-chunk row counts to
``progress_queue`` so the caller can drive a single aggregated tqdm
bar regardless of how many workers are in flight.
``db_overrides`` carries ``{"user", "password"}`` into workers when the
caller prompted for credentials interactively; leave ``None`` to let
workers read the standard libpq environment variables on their own.
"""
if not cluster.files:
return 0
@ -782,21 +810,44 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
conn, schemaname, cluster.tablename, first, first_columns,
cluster.include, cluster.exclude,
total_rows=first_total_rows,
progress_queue=progress_queue,
)
# Commit the first file (and the CREATE TABLE) before spawning workers
# so their ``assert_schema_compatible`` probes actually see the new
# table. Without this, worker connections started mid-transaction on
# the main connection would see nothing in information_schema.
conn.commit()
for path in rest:
columns, path_total_rows = _infer_cluster_schema(
path, cluster.include, cluster.exclude
)
# Uses the same check that if_exists=append runs. A type mismatch or
# missing column aborts the cluster; because chunks commit as they
# load, earlier chunks in the cluster remain in the table.
assert_schema_compatible(conn, schemaname, cluster.tablename, columns)
total += _stream_file(
conn, schemaname, cluster.tablename, path, columns,
cluster.include, cluster.exclude,
total_rows=path_total_rows,
)
if rest:
if workers > 1:
total += _load_remaining_files_parallel(
rest,
schemaname,
cluster.tablename,
cluster.include,
cluster.exclude,
workers=workers,
progress_queue=progress_queue,
db_overrides=db_overrides,
)
else:
for path in rest:
columns, path_total_rows = _infer_cluster_schema(
path, cluster.include, cluster.exclude
)
# Uses the same check that if_exists=append runs. A type
# mismatch or missing column aborts the cluster; because
# chunks commit as they load, earlier chunks in the
# cluster remain in the table.
assert_schema_compatible(
conn, schemaname, cluster.tablename, columns
)
total += _stream_file(
conn, schemaname, cluster.tablename, path, columns,
cluster.include, cluster.exclude,
total_rows=path_total_rows,
progress_queue=progress_queue,
)
# -- Index support ------------------------------------------------------
if cluster.indexes:
@ -815,8 +866,25 @@ def _stream_file(
exclude,
*,
total_rows: Optional[int] = None,
progress_queue: Any = None,
) -> int:
"""Stream ``path`` into an existing table chunk by chunk.
When ``progress_queue`` is provided, each chunk's row count is published
to the queue as ``("rows", n)`` tuples instead of being rendered to a
per-file tqdm bar. That lets :func:`main` drive a single folder-wide
progress bar from a background drainer thread, which is the only way
to keep a coherent progress view when the folder loader is running
files in parallel workers.
"""
def _chunks():
if progress_queue is not None:
for chunk_df, _chunk_meta in iter_sas_chunks(path):
chunk_df = apply_column_filter(chunk_df, include, exclude)
progress_queue.put(("rows", len(chunk_df)))
yield chunk_df
return
pbar = tqdm(
total=total_rows,
unit="row",
@ -836,6 +904,134 @@ def _stream_file(
return copy_dataframes(conn, schemaname, tablename, _chunks(), columns)
# ---------------------------------------------------------------------------
# Parallel append workers
# ---------------------------------------------------------------------------
def _worker_load_append_file(
path_str: str,
schemaname: str,
tablename: str,
include: Optional[List[str]],
exclude: Optional[List[str]],
progress_queue: Any,
db_overrides: Optional[Dict[str, Optional[str]]],
) -> Tuple[str, int, Optional[str]]:
"""Worker process: load one SAS file in append mode.
Runs in a subprocess spawned by :func:`_load_remaining_files_parallel`.
Opens its own psycopg2 connection, re-infers the per-file schema (so
per-file ``INTEGER`` vs ``BIGINT`` drift is caught by the existing
schema-compat check just like in the serial path), and streams chunks
via ``COPY``. Row counts are published to the shared queue for the
main process's global tqdm bar.
Returns ``(path_str, rows_loaded, error_or_None)`` - failures are
returned rather than raised so the parent can aggregate results
across workers without losing partial progress.
"""
from pathlib import Path as _Path
from dotenv import load_dotenv as _load_dotenv
from load_sas import (
apply_column_filter as _apply_column_filter,
assert_schema_compatible as _assert_schema_compatible,
connect as _connect,
copy_dataframes as _copy_dataframes,
infer_schema as _infer_schema,
iter_sas_chunks as _iter_sas_chunks,
read_sas_preview as _read_sas_preview,
)
_load_dotenv()
path = _Path(path_str)
try:
preview_df, meta = _read_sas_preview(path)
preview_df = _apply_column_filter(preview_df, include, exclude)
total_rows = getattr(meta, "number_rows", None)
columns = _infer_schema(preview_df, meta, total_rows=total_rows)
user = db_overrides.get("user") if db_overrides else None
password = db_overrides.get("password") if db_overrides else None
conn = _connect(user=user, password=password)
conn.autocommit = False
try:
_assert_schema_compatible(conn, schemaname, tablename, columns)
def _chunks():
for chunk_df, _chunk_meta in _iter_sas_chunks(path):
chunk_df = _apply_column_filter(chunk_df, include, exclude)
if progress_queue is not None:
progress_queue.put(("rows", len(chunk_df)))
yield chunk_df
rows = _copy_dataframes(
conn, schemaname, tablename, _chunks(), columns
)
conn.commit()
return (path_str, rows, None)
finally:
conn.close()
except Exception as e:
return (path_str, 0, f"{type(e).__name__}: {e}")
def _load_remaining_files_parallel(
files: List[Path],
schemaname: str,
tablename: str,
include: Optional[List[str]],
exclude: Optional[List[str]],
*,
workers: int,
progress_queue: Any,
db_overrides: Optional[Dict[str, Optional[str]]],
) -> int:
"""Run append-mode loads for ``files`` across a process pool.
Each file is an independent unit of work submitted to
``ProcessPoolExecutor``. Workers infer schema, validate compatibility,
and stream via COPY just like the serial path. Failures are collected
and re-raised as a single ``RuntimeError`` at the end so that all
other workers' rows still count toward the committed total.
"""
total = 0
errors: List[Tuple[str, str]] = []
with ProcessPoolExecutor(max_workers=workers) as pool:
futures = [
pool.submit(
_worker_load_append_file,
str(p),
schemaname,
tablename,
include,
exclude,
progress_queue,
db_overrides,
)
for p in files
]
for fut in as_completed(futures):
path_str, rows, err = fut.result()
if err is not None:
errors.append((path_str, err))
else:
total += rows
if errors:
joined = "\n".join(f" {p}: {e}" for p, e in errors)
raise RuntimeError(
f"{len(errors)} worker(s) failed while appending to "
f"{schemaname}.{tablename}:\n{joined}"
)
return total
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
@ -875,6 +1071,22 @@ def _build_argparser() -> argparse.ArgumentParser:
"PGUSER / PGPASSWORD from the environment or .env file."
),
)
p.add_argument(
"--workers",
type=int,
default=1,
metavar="N",
help=(
"Number of worker processes for the append phase. With N=1 "
"(default) files load serially on the main connection. With "
"N>1 the first file of each cluster still runs serially (to "
"create the table), then the remaining files load in parallel "
"across N processes, each with its own psycopg2 connection. "
"On a big box try N close to your core count. When N>1 the "
"per-chunk row target drops to 500,000 unless you've pinned "
"GENERIC_LOADER_CHUNK_ROWS, so peak memory stays bounded."
),
)
return p
@ -993,6 +1205,101 @@ def main(argv: Optional[List[str]] = None) -> int:
if args.dbcreds:
db_user = input("Database username: ")
db_password = getpass.getpass("Database password: ")
db_overrides: Optional[Dict[str, Optional[str]]] = (
{"user": db_user, "password": db_password} if args.dbcreds else None
)
workers = max(1, int(args.workers))
# When running parallel workers, bound peak memory: each worker buffers a
# chunk (read + prepared + serialized) so total memory scales with
# workers × chunk_rows × avg_row_bytes. Drop the default chunk target to
# 500k unless the operator has explicitly pinned it. Setting the env var
# before workers spawn means they inherit it through forkserver / spawn.
if (
workers > 1
and "GENERIC_LOADER_CHUNK_ROWS" not in os.environ
):
os.environ["GENERIC_LOADER_CHUNK_ROWS"] = "500000"
print(
"[info] parallel mode: bounding per-chunk rows to 500,000. "
"Pin GENERIC_LOADER_CHUNK_ROWS to override.",
file=sys.stderr,
)
# -- Metadata pre-scan -----------------------------------------------------
# Sum ``number_rows`` across every file so the tqdm bar has a real
# denominator. ``read_sas_metadata`` uses pyreadstat's ``metadataonly=True``
# fast path; a few ms per sas7bdat even on large files.
print(
f"pre-scanning row counts for {sum(len(c.files) for c in loadable)} "
f"file(s)...",
file=sys.stderr,
)
grand_total = 0
unknown_total_files: List[str] = []
for c in loadable:
for p in c.files:
try:
meta = read_sas_metadata(p)
n = getattr(meta, "number_rows", None)
if n is None:
unknown_total_files.append(p.name)
else:
grand_total += int(n)
except Exception as e:
unknown_total_files.append(f"{p.name} ({e})")
if unknown_total_files:
print(
f"[warn] could not read row count from "
f"{len(unknown_total_files)} file(s); progress bar ETA will "
f"be approximate.",
file=sys.stderr,
)
print(
f" total rows across folder: {grand_total:,}",
file=sys.stderr,
)
# -- Shared progress plumbing ---------------------------------------------
# The queue crosses process boundaries when workers > 1 (managed proxy)
# and is a plain in-process queue otherwise; the put/get contract is
# identical either way. A daemon thread drains it and advances the one
# tqdm bar that spans the whole folder load.
manager: Optional[Any] = None
progress_queue: Any
if workers > 1:
manager = mp.Manager()
progress_queue = manager.Queue()
else:
progress_queue = _queue_mod.Queue()
pbar = tqdm(
total=grand_total or None,
unit="row",
unit_scale=True,
desc=f"{cfg.folder.name}",
file=sys.stderr,
dynamic_ncols=True,
)
stop_drainer = threading.Event()
def _drainer() -> None:
while not stop_drainer.is_set():
try:
event = progress_queue.get(timeout=0.1)
except _queue_mod.Empty:
continue
except (EOFError, OSError):
return
if not event:
continue
kind = event[0]
if kind == "rows":
pbar.update(event[1])
drainer_thread = threading.Thread(target=_drainer, daemon=True)
drainer_thread.start()
conn = connect(user=db_user, password=db_password)
conn.autocommit = False
@ -1002,10 +1309,18 @@ def main(argv: Optional[List[str]] = None) -> int:
for cluster in loadable:
print(
f"\n>>> loading cluster {cluster.tablename!r} "
f"({len(cluster.files)} file(s))"
f"({len(cluster.files)} file(s)) "
f"[workers={workers}]"
)
try:
rows = load_cluster(conn, cluster, cfg.schemaname)
rows = load_cluster(
conn,
cluster,
cfg.schemaname,
workers=workers,
progress_queue=progress_queue,
db_overrides=db_overrides,
)
conn.commit()
totals.append((cluster.tablename, len(cluster.files), rows))
print(
@ -1022,7 +1337,14 @@ def main(argv: Optional[List[str]] = None) -> int:
if args.fail_fast:
break
finally:
# Drain any pending progress events before shutting the bar down so
# the final rendered total matches what actually landed.
stop_drainer.set()
drainer_thread.join(timeout=2.0)
pbar.close()
conn.close()
if manager is not None:
manager.shutdown()
print("\n=== summary ===")
for name, fcount, rows in totals:

View File

@ -580,6 +580,22 @@ def read_sas_preview(
return reader(str(Path(path)), row_limit=row_limit, **kwargs)
def read_sas_metadata(path: Path) -> Any:
"""Read only the metadata (no rows) from a SAS file.
Uses pyreadstat's ``metadataonly=True`` fast path: the reader decodes
the file header (column names, formats, total row count, etc.) and
returns without touching the data pages. Orders of magnitude faster
than :func:`read_sas_preview` when all you need is
``meta.number_rows`` - typically a few ms per sas7bdat file, which
makes it cheap to pre-scan a whole folder to populate a global
progress bar.
"""
reader, kwargs = _sas_reader(path)
_, meta = reader(str(Path(path)), metadataonly=True, **kwargs)
return meta
def iter_sas_chunks(
path: Path,
*,