advanced_analyzer #8
@ -132,8 +132,13 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import getpass
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import queue as _queue_mod
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
@ -155,6 +160,7 @@ from load_sas import (
|
||||
discover_partition_values_chunked,
|
||||
infer_schema,
|
||||
iter_sas_chunks,
|
||||
read_sas_metadata,
|
||||
read_sas_preview,
|
||||
render_create_indexes,
|
||||
render_create_table,
|
||||
@ -702,7 +708,15 @@ def _discover_cluster_partitions(
|
||||
return merged
|
||||
|
||||
|
||||
def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
|
||||
def load_cluster(
|
||||
conn,
|
||||
cluster: ClusterSpec,
|
||||
schemaname: str,
|
||||
*,
|
||||
workers: int = 1,
|
||||
progress_queue: Any = None,
|
||||
db_overrides: Optional[Dict[str, Optional[str]]] = None,
|
||||
) -> int:
|
||||
"""Load every file in ``cluster`` into one table. Returns total rows loaded.
|
||||
|
||||
When ``cluster.partition_by`` is non-empty, partition values are
|
||||
@ -713,6 +727,20 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
|
||||
file mid-cluster fails, earlier chunks - including chunks from earlier
|
||||
files in the cluster - stay committed; only the in-flight chunk is
|
||||
rolled back by :func:`main`.
|
||||
|
||||
``workers`` controls parallelism for the *append* phase. The first file
|
||||
always runs serially on ``conn`` (to create the table and, when
|
||||
partitioned, pre-create partitions). When ``workers > 1`` the remaining
|
||||
files dispatch to a ``ProcessPoolExecutor``; each worker opens its own
|
||||
psycopg2 connection, re-infers the per-file schema, runs the same
|
||||
:func:`load_sas.assert_schema_compatible` check the serial path uses,
|
||||
and streams chunks via COPY. Workers report per-chunk row counts to
|
||||
``progress_queue`` so the caller can drive a single aggregated tqdm
|
||||
bar regardless of how many workers are in flight.
|
||||
|
||||
``db_overrides`` carries ``{"user", "password"}`` into workers when the
|
||||
caller prompted for credentials interactively; leave ``None`` to let
|
||||
workers read the standard libpq environment variables on their own.
|
||||
"""
|
||||
if not cluster.files:
|
||||
return 0
|
||||
@ -782,20 +810,43 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
|
||||
conn, schemaname, cluster.tablename, first, first_columns,
|
||||
cluster.include, cluster.exclude,
|
||||
total_rows=first_total_rows,
|
||||
progress_queue=progress_queue,
|
||||
)
|
||||
# Commit the first file (and the CREATE TABLE) before spawning workers
|
||||
# so their ``assert_schema_compatible`` probes actually see the new
|
||||
# table. Without this, worker connections started mid-transaction on
|
||||
# the main connection would see nothing in information_schema.
|
||||
conn.commit()
|
||||
|
||||
if rest:
|
||||
if workers > 1:
|
||||
total += _load_remaining_files_parallel(
|
||||
rest,
|
||||
schemaname,
|
||||
cluster.tablename,
|
||||
cluster.include,
|
||||
cluster.exclude,
|
||||
workers=workers,
|
||||
progress_queue=progress_queue,
|
||||
db_overrides=db_overrides,
|
||||
)
|
||||
else:
|
||||
for path in rest:
|
||||
columns, path_total_rows = _infer_cluster_schema(
|
||||
path, cluster.include, cluster.exclude
|
||||
)
|
||||
# Uses the same check that if_exists=append runs. A type mismatch or
|
||||
# missing column aborts the cluster; because chunks commit as they
|
||||
# load, earlier chunks in the cluster remain in the table.
|
||||
assert_schema_compatible(conn, schemaname, cluster.tablename, columns)
|
||||
# Uses the same check that if_exists=append runs. A type
|
||||
# mismatch or missing column aborts the cluster; because
|
||||
# chunks commit as they load, earlier chunks in the
|
||||
# cluster remain in the table.
|
||||
assert_schema_compatible(
|
||||
conn, schemaname, cluster.tablename, columns
|
||||
)
|
||||
total += _stream_file(
|
||||
conn, schemaname, cluster.tablename, path, columns,
|
||||
cluster.include, cluster.exclude,
|
||||
total_rows=path_total_rows,
|
||||
progress_queue=progress_queue,
|
||||
)
|
||||
|
||||
# -- Index support ------------------------------------------------------
|
||||
@ -815,8 +866,25 @@ def _stream_file(
|
||||
exclude,
|
||||
*,
|
||||
total_rows: Optional[int] = None,
|
||||
progress_queue: Any = None,
|
||||
) -> int:
|
||||
"""Stream ``path`` into an existing table chunk by chunk.
|
||||
|
||||
When ``progress_queue`` is provided, each chunk's row count is published
|
||||
to the queue as ``("rows", n)`` tuples instead of being rendered to a
|
||||
per-file tqdm bar. That lets :func:`main` drive a single folder-wide
|
||||
progress bar from a background drainer thread, which is the only way
|
||||
to keep a coherent progress view when the folder loader is running
|
||||
files in parallel workers.
|
||||
"""
|
||||
def _chunks():
|
||||
if progress_queue is not None:
|
||||
for chunk_df, _chunk_meta in iter_sas_chunks(path):
|
||||
chunk_df = apply_column_filter(chunk_df, include, exclude)
|
||||
progress_queue.put(("rows", len(chunk_df)))
|
||||
yield chunk_df
|
||||
return
|
||||
|
||||
pbar = tqdm(
|
||||
total=total_rows,
|
||||
unit="row",
|
||||
@ -836,6 +904,134 @@ def _stream_file(
|
||||
return copy_dataframes(conn, schemaname, tablename, _chunks(), columns)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parallel append workers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _worker_load_append_file(
|
||||
path_str: str,
|
||||
schemaname: str,
|
||||
tablename: str,
|
||||
include: Optional[List[str]],
|
||||
exclude: Optional[List[str]],
|
||||
progress_queue: Any,
|
||||
db_overrides: Optional[Dict[str, Optional[str]]],
|
||||
) -> Tuple[str, int, Optional[str]]:
|
||||
"""Worker process: load one SAS file in append mode.
|
||||
|
||||
Runs in a subprocess spawned by :func:`_load_remaining_files_parallel`.
|
||||
Opens its own psycopg2 connection, re-infers the per-file schema (so
|
||||
per-file ``INTEGER`` vs ``BIGINT`` drift is caught by the existing
|
||||
schema-compat check just like in the serial path), and streams chunks
|
||||
via ``COPY``. Row counts are published to the shared queue for the
|
||||
main process's global tqdm bar.
|
||||
|
||||
Returns ``(path_str, rows_loaded, error_or_None)`` - failures are
|
||||
returned rather than raised so the parent can aggregate results
|
||||
across workers without losing partial progress.
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
from dotenv import load_dotenv as _load_dotenv
|
||||
|
||||
from load_sas import (
|
||||
apply_column_filter as _apply_column_filter,
|
||||
assert_schema_compatible as _assert_schema_compatible,
|
||||
connect as _connect,
|
||||
copy_dataframes as _copy_dataframes,
|
||||
infer_schema as _infer_schema,
|
||||
iter_sas_chunks as _iter_sas_chunks,
|
||||
read_sas_preview as _read_sas_preview,
|
||||
)
|
||||
|
||||
_load_dotenv()
|
||||
|
||||
path = _Path(path_str)
|
||||
try:
|
||||
preview_df, meta = _read_sas_preview(path)
|
||||
preview_df = _apply_column_filter(preview_df, include, exclude)
|
||||
total_rows = getattr(meta, "number_rows", None)
|
||||
columns = _infer_schema(preview_df, meta, total_rows=total_rows)
|
||||
|
||||
user = db_overrides.get("user") if db_overrides else None
|
||||
password = db_overrides.get("password") if db_overrides else None
|
||||
conn = _connect(user=user, password=password)
|
||||
conn.autocommit = False
|
||||
try:
|
||||
_assert_schema_compatible(conn, schemaname, tablename, columns)
|
||||
|
||||
def _chunks():
|
||||
for chunk_df, _chunk_meta in _iter_sas_chunks(path):
|
||||
chunk_df = _apply_column_filter(chunk_df, include, exclude)
|
||||
if progress_queue is not None:
|
||||
progress_queue.put(("rows", len(chunk_df)))
|
||||
yield chunk_df
|
||||
|
||||
rows = _copy_dataframes(
|
||||
conn, schemaname, tablename, _chunks(), columns
|
||||
)
|
||||
conn.commit()
|
||||
return (path_str, rows, None)
|
||||
finally:
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
return (path_str, 0, f"{type(e).__name__}: {e}")
|
||||
|
||||
|
||||
def _load_remaining_files_parallel(
|
||||
files: List[Path],
|
||||
schemaname: str,
|
||||
tablename: str,
|
||||
include: Optional[List[str]],
|
||||
exclude: Optional[List[str]],
|
||||
*,
|
||||
workers: int,
|
||||
progress_queue: Any,
|
||||
db_overrides: Optional[Dict[str, Optional[str]]],
|
||||
) -> int:
|
||||
"""Run append-mode loads for ``files`` across a process pool.
|
||||
|
||||
Each file is an independent unit of work submitted to
|
||||
``ProcessPoolExecutor``. Workers infer schema, validate compatibility,
|
||||
and stream via COPY just like the serial path. Failures are collected
|
||||
and re-raised as a single ``RuntimeError`` at the end so that all
|
||||
other workers' rows still count toward the committed total.
|
||||
"""
|
||||
total = 0
|
||||
errors: List[Tuple[str, str]] = []
|
||||
|
||||
with ProcessPoolExecutor(max_workers=workers) as pool:
|
||||
futures = [
|
||||
pool.submit(
|
||||
_worker_load_append_file,
|
||||
str(p),
|
||||
schemaname,
|
||||
tablename,
|
||||
include,
|
||||
exclude,
|
||||
progress_queue,
|
||||
db_overrides,
|
||||
)
|
||||
for p in files
|
||||
]
|
||||
for fut in as_completed(futures):
|
||||
path_str, rows, err = fut.result()
|
||||
if err is not None:
|
||||
errors.append((path_str, err))
|
||||
else:
|
||||
total += rows
|
||||
|
||||
if errors:
|
||||
joined = "\n".join(f" {p}: {e}" for p, e in errors)
|
||||
raise RuntimeError(
|
||||
f"{len(errors)} worker(s) failed while appending to "
|
||||
f"{schemaname}.{tablename}:\n{joined}"
|
||||
)
|
||||
|
||||
return total
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -875,6 +1071,22 @@ def _build_argparser() -> argparse.ArgumentParser:
|
||||
"PGUSER / PGPASSWORD from the environment or .env file."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=1,
|
||||
metavar="N",
|
||||
help=(
|
||||
"Number of worker processes for the append phase. With N=1 "
|
||||
"(default) files load serially on the main connection. With "
|
||||
"N>1 the first file of each cluster still runs serially (to "
|
||||
"create the table), then the remaining files load in parallel "
|
||||
"across N processes, each with its own psycopg2 connection. "
|
||||
"On a big box try N close to your core count. When N>1 the "
|
||||
"per-chunk row target drops to 500,000 unless you've pinned "
|
||||
"GENERIC_LOADER_CHUNK_ROWS, so peak memory stays bounded."
|
||||
),
|
||||
)
|
||||
return p
|
||||
|
||||
|
||||
@ -993,6 +1205,101 @@ def main(argv: Optional[List[str]] = None) -> int:
|
||||
if args.dbcreds:
|
||||
db_user = input("Database username: ")
|
||||
db_password = getpass.getpass("Database password: ")
|
||||
db_overrides: Optional[Dict[str, Optional[str]]] = (
|
||||
{"user": db_user, "password": db_password} if args.dbcreds else None
|
||||
)
|
||||
|
||||
workers = max(1, int(args.workers))
|
||||
|
||||
# When running parallel workers, bound peak memory: each worker buffers a
|
||||
# chunk (read + prepared + serialized) so total memory scales with
|
||||
# workers × chunk_rows × avg_row_bytes. Drop the default chunk target to
|
||||
# 500k unless the operator has explicitly pinned it. Setting the env var
|
||||
# before workers spawn means they inherit it through forkserver / spawn.
|
||||
if (
|
||||
workers > 1
|
||||
and "GENERIC_LOADER_CHUNK_ROWS" not in os.environ
|
||||
):
|
||||
os.environ["GENERIC_LOADER_CHUNK_ROWS"] = "500000"
|
||||
print(
|
||||
"[info] parallel mode: bounding per-chunk rows to 500,000. "
|
||||
"Pin GENERIC_LOADER_CHUNK_ROWS to override.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# -- Metadata pre-scan -----------------------------------------------------
|
||||
# Sum ``number_rows`` across every file so the tqdm bar has a real
|
||||
# denominator. ``read_sas_metadata`` uses pyreadstat's ``metadataonly=True``
|
||||
# fast path; a few ms per sas7bdat even on large files.
|
||||
print(
|
||||
f"pre-scanning row counts for {sum(len(c.files) for c in loadable)} "
|
||||
f"file(s)...",
|
||||
file=sys.stderr,
|
||||
)
|
||||
grand_total = 0
|
||||
unknown_total_files: List[str] = []
|
||||
for c in loadable:
|
||||
for p in c.files:
|
||||
try:
|
||||
meta = read_sas_metadata(p)
|
||||
n = getattr(meta, "number_rows", None)
|
||||
if n is None:
|
||||
unknown_total_files.append(p.name)
|
||||
else:
|
||||
grand_total += int(n)
|
||||
except Exception as e:
|
||||
unknown_total_files.append(f"{p.name} ({e})")
|
||||
if unknown_total_files:
|
||||
print(
|
||||
f"[warn] could not read row count from "
|
||||
f"{len(unknown_total_files)} file(s); progress bar ETA will "
|
||||
f"be approximate.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
f" total rows across folder: {grand_total:,}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# -- Shared progress plumbing ---------------------------------------------
|
||||
# The queue crosses process boundaries when workers > 1 (managed proxy)
|
||||
# and is a plain in-process queue otherwise; the put/get contract is
|
||||
# identical either way. A daemon thread drains it and advances the one
|
||||
# tqdm bar that spans the whole folder load.
|
||||
manager: Optional[Any] = None
|
||||
progress_queue: Any
|
||||
if workers > 1:
|
||||
manager = mp.Manager()
|
||||
progress_queue = manager.Queue()
|
||||
else:
|
||||
progress_queue = _queue_mod.Queue()
|
||||
|
||||
pbar = tqdm(
|
||||
total=grand_total or None,
|
||||
unit="row",
|
||||
unit_scale=True,
|
||||
desc=f"{cfg.folder.name}",
|
||||
file=sys.stderr,
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
stop_drainer = threading.Event()
|
||||
|
||||
def _drainer() -> None:
|
||||
while not stop_drainer.is_set():
|
||||
try:
|
||||
event = progress_queue.get(timeout=0.1)
|
||||
except _queue_mod.Empty:
|
||||
continue
|
||||
except (EOFError, OSError):
|
||||
return
|
||||
if not event:
|
||||
continue
|
||||
kind = event[0]
|
||||
if kind == "rows":
|
||||
pbar.update(event[1])
|
||||
|
||||
drainer_thread = threading.Thread(target=_drainer, daemon=True)
|
||||
drainer_thread.start()
|
||||
|
||||
conn = connect(user=db_user, password=db_password)
|
||||
conn.autocommit = False
|
||||
@ -1003,9 +1310,17 @@ def main(argv: Optional[List[str]] = None) -> int:
|
||||
print(
|
||||
f"\n>>> loading cluster {cluster.tablename!r} "
|
||||
f"({len(cluster.files)} file(s)) "
|
||||
f"[workers={workers}]"
|
||||
)
|
||||
try:
|
||||
rows = load_cluster(conn, cluster, cfg.schemaname)
|
||||
rows = load_cluster(
|
||||
conn,
|
||||
cluster,
|
||||
cfg.schemaname,
|
||||
workers=workers,
|
||||
progress_queue=progress_queue,
|
||||
db_overrides=db_overrides,
|
||||
)
|
||||
conn.commit()
|
||||
totals.append((cluster.tablename, len(cluster.files), rows))
|
||||
print(
|
||||
@ -1022,7 +1337,14 @@ def main(argv: Optional[List[str]] = None) -> int:
|
||||
if args.fail_fast:
|
||||
break
|
||||
finally:
|
||||
# Drain any pending progress events before shutting the bar down so
|
||||
# the final rendered total matches what actually landed.
|
||||
stop_drainer.set()
|
||||
drainer_thread.join(timeout=2.0)
|
||||
pbar.close()
|
||||
conn.close()
|
||||
if manager is not None:
|
||||
manager.shutdown()
|
||||
|
||||
print("\n=== summary ===")
|
||||
for name, fcount, rows in totals:
|
||||
|
||||
@ -580,6 +580,22 @@ def read_sas_preview(
|
||||
return reader(str(Path(path)), row_limit=row_limit, **kwargs)
|
||||
|
||||
|
||||
def read_sas_metadata(path: Path) -> Any:
|
||||
"""Read only the metadata (no rows) from a SAS file.
|
||||
|
||||
Uses pyreadstat's ``metadataonly=True`` fast path: the reader decodes
|
||||
the file header (column names, formats, total row count, etc.) and
|
||||
returns without touching the data pages. Orders of magnitude faster
|
||||
than :func:`read_sas_preview` when all you need is
|
||||
``meta.number_rows`` - typically a few ms per sas7bdat file, which
|
||||
makes it cheap to pre-scan a whole folder to populate a global
|
||||
progress bar.
|
||||
"""
|
||||
reader, kwargs = _sas_reader(path)
|
||||
_, meta = reader(str(Path(path)), metadataonly=True, **kwargs)
|
||||
return meta
|
||||
|
||||
|
||||
def iter_sas_chunks(
|
||||
path: Path,
|
||||
*,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user