diff --git a/generic_loader/load_folder.py b/generic_loader/load_folder.py index be2df4a..0593c46 100644 --- a/generic_loader/load_folder.py +++ b/generic_loader/load_folder.py @@ -132,8 +132,13 @@ from __future__ import annotations import argparse import getpass +import multiprocessing as mp +import os +import queue as _queue_mod import re import sys +import threading +from concurrent.futures import ProcessPoolExecutor, as_completed from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -155,6 +160,7 @@ from load_sas import ( discover_partition_values_chunked, infer_schema, iter_sas_chunks, + read_sas_metadata, read_sas_preview, render_create_indexes, render_create_table, @@ -702,7 +708,15 @@ def _discover_cluster_partitions( return merged -def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: +def load_cluster( + conn, + cluster: ClusterSpec, + schemaname: str, + *, + workers: int = 1, + progress_queue: Any = None, + db_overrides: Optional[Dict[str, Optional[str]]] = None, +) -> int: """Load every file in ``cluster`` into one table. Returns total rows loaded. When ``cluster.partition_by`` is non-empty, partition values are @@ -713,6 +727,20 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: file mid-cluster fails, earlier chunks - including chunks from earlier files in the cluster - stay committed; only the in-flight chunk is rolled back by :func:`main`. + + ``workers`` controls parallelism for the *append* phase. The first file + always runs serially on ``conn`` (to create the table and, when + partitioned, pre-create partitions). When ``workers > 1`` the remaining + files dispatch to a ``ProcessPoolExecutor``; each worker opens its own + psycopg2 connection, re-infers the per-file schema, runs the same + :func:`load_sas.assert_schema_compatible` check the serial path uses, + and streams chunks via COPY. Workers report per-chunk row counts to + ``progress_queue`` so the caller can drive a single aggregated tqdm + bar regardless of how many workers are in flight. + + ``db_overrides`` carries ``{"user", "password"}`` into workers when the + caller prompted for credentials interactively; leave ``None`` to let + workers read the standard libpq environment variables on their own. """ if not cluster.files: return 0 @@ -782,21 +810,44 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: conn, schemaname, cluster.tablename, first, first_columns, cluster.include, cluster.exclude, total_rows=first_total_rows, + progress_queue=progress_queue, ) + # Commit the first file (and the CREATE TABLE) before spawning workers + # so their ``assert_schema_compatible`` probes actually see the new + # table. Without this, worker connections started mid-transaction on + # the main connection would see nothing in information_schema. + conn.commit() - for path in rest: - columns, path_total_rows = _infer_cluster_schema( - path, cluster.include, cluster.exclude - ) - # Uses the same check that if_exists=append runs. A type mismatch or - # missing column aborts the cluster; because chunks commit as they - # load, earlier chunks in the cluster remain in the table. - assert_schema_compatible(conn, schemaname, cluster.tablename, columns) - total += _stream_file( - conn, schemaname, cluster.tablename, path, columns, - cluster.include, cluster.exclude, - total_rows=path_total_rows, - ) + if rest: + if workers > 1: + total += _load_remaining_files_parallel( + rest, + schemaname, + cluster.tablename, + cluster.include, + cluster.exclude, + workers=workers, + progress_queue=progress_queue, + db_overrides=db_overrides, + ) + else: + for path in rest: + columns, path_total_rows = _infer_cluster_schema( + path, cluster.include, cluster.exclude + ) + # Uses the same check that if_exists=append runs. A type + # mismatch or missing column aborts the cluster; because + # chunks commit as they load, earlier chunks in the + # cluster remain in the table. + assert_schema_compatible( + conn, schemaname, cluster.tablename, columns + ) + total += _stream_file( + conn, schemaname, cluster.tablename, path, columns, + cluster.include, cluster.exclude, + total_rows=path_total_rows, + progress_queue=progress_queue, + ) # -- Index support ------------------------------------------------------ if cluster.indexes: @@ -815,8 +866,25 @@ def _stream_file( exclude, *, total_rows: Optional[int] = None, + progress_queue: Any = None, ) -> int: + """Stream ``path`` into an existing table chunk by chunk. + + When ``progress_queue`` is provided, each chunk's row count is published + to the queue as ``("rows", n)`` tuples instead of being rendered to a + per-file tqdm bar. That lets :func:`main` drive a single folder-wide + progress bar from a background drainer thread, which is the only way + to keep a coherent progress view when the folder loader is running + files in parallel workers. + """ def _chunks(): + if progress_queue is not None: + for chunk_df, _chunk_meta in iter_sas_chunks(path): + chunk_df = apply_column_filter(chunk_df, include, exclude) + progress_queue.put(("rows", len(chunk_df))) + yield chunk_df + return + pbar = tqdm( total=total_rows, unit="row", @@ -836,6 +904,134 @@ def _stream_file( return copy_dataframes(conn, schemaname, tablename, _chunks(), columns) +# --------------------------------------------------------------------------- +# Parallel append workers +# --------------------------------------------------------------------------- + + +def _worker_load_append_file( + path_str: str, + schemaname: str, + tablename: str, + include: Optional[List[str]], + exclude: Optional[List[str]], + progress_queue: Any, + db_overrides: Optional[Dict[str, Optional[str]]], +) -> Tuple[str, int, Optional[str]]: + """Worker process: load one SAS file in append mode. + + Runs in a subprocess spawned by :func:`_load_remaining_files_parallel`. + Opens its own psycopg2 connection, re-infers the per-file schema (so + per-file ``INTEGER`` vs ``BIGINT`` drift is caught by the existing + schema-compat check just like in the serial path), and streams chunks + via ``COPY``. Row counts are published to the shared queue for the + main process's global tqdm bar. + + Returns ``(path_str, rows_loaded, error_or_None)`` - failures are + returned rather than raised so the parent can aggregate results + across workers without losing partial progress. + """ + from pathlib import Path as _Path + + from dotenv import load_dotenv as _load_dotenv + + from load_sas import ( + apply_column_filter as _apply_column_filter, + assert_schema_compatible as _assert_schema_compatible, + connect as _connect, + copy_dataframes as _copy_dataframes, + infer_schema as _infer_schema, + iter_sas_chunks as _iter_sas_chunks, + read_sas_preview as _read_sas_preview, + ) + + _load_dotenv() + + path = _Path(path_str) + try: + preview_df, meta = _read_sas_preview(path) + preview_df = _apply_column_filter(preview_df, include, exclude) + total_rows = getattr(meta, "number_rows", None) + columns = _infer_schema(preview_df, meta, total_rows=total_rows) + + user = db_overrides.get("user") if db_overrides else None + password = db_overrides.get("password") if db_overrides else None + conn = _connect(user=user, password=password) + conn.autocommit = False + try: + _assert_schema_compatible(conn, schemaname, tablename, columns) + + def _chunks(): + for chunk_df, _chunk_meta in _iter_sas_chunks(path): + chunk_df = _apply_column_filter(chunk_df, include, exclude) + if progress_queue is not None: + progress_queue.put(("rows", len(chunk_df))) + yield chunk_df + + rows = _copy_dataframes( + conn, schemaname, tablename, _chunks(), columns + ) + conn.commit() + return (path_str, rows, None) + finally: + conn.close() + except Exception as e: + return (path_str, 0, f"{type(e).__name__}: {e}") + + +def _load_remaining_files_parallel( + files: List[Path], + schemaname: str, + tablename: str, + include: Optional[List[str]], + exclude: Optional[List[str]], + *, + workers: int, + progress_queue: Any, + db_overrides: Optional[Dict[str, Optional[str]]], +) -> int: + """Run append-mode loads for ``files`` across a process pool. + + Each file is an independent unit of work submitted to + ``ProcessPoolExecutor``. Workers infer schema, validate compatibility, + and stream via COPY just like the serial path. Failures are collected + and re-raised as a single ``RuntimeError`` at the end so that all + other workers' rows still count toward the committed total. + """ + total = 0 + errors: List[Tuple[str, str]] = [] + + with ProcessPoolExecutor(max_workers=workers) as pool: + futures = [ + pool.submit( + _worker_load_append_file, + str(p), + schemaname, + tablename, + include, + exclude, + progress_queue, + db_overrides, + ) + for p in files + ] + for fut in as_completed(futures): + path_str, rows, err = fut.result() + if err is not None: + errors.append((path_str, err)) + else: + total += rows + + if errors: + joined = "\n".join(f" {p}: {e}" for p, e in errors) + raise RuntimeError( + f"{len(errors)} worker(s) failed while appending to " + f"{schemaname}.{tablename}:\n{joined}" + ) + + return total + + # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- @@ -875,6 +1071,22 @@ def _build_argparser() -> argparse.ArgumentParser: "PGUSER / PGPASSWORD from the environment or .env file." ), ) + p.add_argument( + "--workers", + type=int, + default=1, + metavar="N", + help=( + "Number of worker processes for the append phase. With N=1 " + "(default) files load serially on the main connection. With " + "N>1 the first file of each cluster still runs serially (to " + "create the table), then the remaining files load in parallel " + "across N processes, each with its own psycopg2 connection. " + "On a big box try N close to your core count. When N>1 the " + "per-chunk row target drops to 500,000 unless you've pinned " + "GENERIC_LOADER_CHUNK_ROWS, so peak memory stays bounded." + ), + ) return p @@ -993,6 +1205,101 @@ def main(argv: Optional[List[str]] = None) -> int: if args.dbcreds: db_user = input("Database username: ") db_password = getpass.getpass("Database password: ") + db_overrides: Optional[Dict[str, Optional[str]]] = ( + {"user": db_user, "password": db_password} if args.dbcreds else None + ) + + workers = max(1, int(args.workers)) + + # When running parallel workers, bound peak memory: each worker buffers a + # chunk (read + prepared + serialized) so total memory scales with + # workers × chunk_rows × avg_row_bytes. Drop the default chunk target to + # 500k unless the operator has explicitly pinned it. Setting the env var + # before workers spawn means they inherit it through forkserver / spawn. + if ( + workers > 1 + and "GENERIC_LOADER_CHUNK_ROWS" not in os.environ + ): + os.environ["GENERIC_LOADER_CHUNK_ROWS"] = "500000" + print( + "[info] parallel mode: bounding per-chunk rows to 500,000. " + "Pin GENERIC_LOADER_CHUNK_ROWS to override.", + file=sys.stderr, + ) + + # -- Metadata pre-scan ----------------------------------------------------- + # Sum ``number_rows`` across every file so the tqdm bar has a real + # denominator. ``read_sas_metadata`` uses pyreadstat's ``metadataonly=True`` + # fast path; a few ms per sas7bdat even on large files. + print( + f"pre-scanning row counts for {sum(len(c.files) for c in loadable)} " + f"file(s)...", + file=sys.stderr, + ) + grand_total = 0 + unknown_total_files: List[str] = [] + for c in loadable: + for p in c.files: + try: + meta = read_sas_metadata(p) + n = getattr(meta, "number_rows", None) + if n is None: + unknown_total_files.append(p.name) + else: + grand_total += int(n) + except Exception as e: + unknown_total_files.append(f"{p.name} ({e})") + if unknown_total_files: + print( + f"[warn] could not read row count from " + f"{len(unknown_total_files)} file(s); progress bar ETA will " + f"be approximate.", + file=sys.stderr, + ) + print( + f" total rows across folder: {grand_total:,}", + file=sys.stderr, + ) + + # -- Shared progress plumbing --------------------------------------------- + # The queue crosses process boundaries when workers > 1 (managed proxy) + # and is a plain in-process queue otherwise; the put/get contract is + # identical either way. A daemon thread drains it and advances the one + # tqdm bar that spans the whole folder load. + manager: Optional[Any] = None + progress_queue: Any + if workers > 1: + manager = mp.Manager() + progress_queue = manager.Queue() + else: + progress_queue = _queue_mod.Queue() + + pbar = tqdm( + total=grand_total or None, + unit="row", + unit_scale=True, + desc=f"{cfg.folder.name}", + file=sys.stderr, + dynamic_ncols=True, + ) + stop_drainer = threading.Event() + + def _drainer() -> None: + while not stop_drainer.is_set(): + try: + event = progress_queue.get(timeout=0.1) + except _queue_mod.Empty: + continue + except (EOFError, OSError): + return + if not event: + continue + kind = event[0] + if kind == "rows": + pbar.update(event[1]) + + drainer_thread = threading.Thread(target=_drainer, daemon=True) + drainer_thread.start() conn = connect(user=db_user, password=db_password) conn.autocommit = False @@ -1002,10 +1309,18 @@ def main(argv: Optional[List[str]] = None) -> int: for cluster in loadable: print( f"\n>>> loading cluster {cluster.tablename!r} " - f"({len(cluster.files)} file(s))" + f"({len(cluster.files)} file(s)) " + f"[workers={workers}]" ) try: - rows = load_cluster(conn, cluster, cfg.schemaname) + rows = load_cluster( + conn, + cluster, + cfg.schemaname, + workers=workers, + progress_queue=progress_queue, + db_overrides=db_overrides, + ) conn.commit() totals.append((cluster.tablename, len(cluster.files), rows)) print( @@ -1022,7 +1337,14 @@ def main(argv: Optional[List[str]] = None) -> int: if args.fail_fast: break finally: + # Drain any pending progress events before shutting the bar down so + # the final rendered total matches what actually landed. + stop_drainer.set() + drainer_thread.join(timeout=2.0) + pbar.close() conn.close() + if manager is not None: + manager.shutdown() print("\n=== summary ===") for name, fcount, rows in totals: diff --git a/generic_loader/load_sas.py b/generic_loader/load_sas.py index 5b268f6..8dfe18f 100644 --- a/generic_loader/load_sas.py +++ b/generic_loader/load_sas.py @@ -580,6 +580,22 @@ def read_sas_preview( return reader(str(Path(path)), row_limit=row_limit, **kwargs) +def read_sas_metadata(path: Path) -> Any: + """Read only the metadata (no rows) from a SAS file. + + Uses pyreadstat's ``metadataonly=True`` fast path: the reader decodes + the file header (column names, formats, total row count, etc.) and + returns without touching the data pages. Orders of magnitude faster + than :func:`read_sas_preview` when all you need is + ``meta.number_rows`` - typically a few ms per sas7bdat file, which + makes it cheap to pre-scan a whole folder to populate a global + progress bar. + """ + reader, kwargs = _sas_reader(path) + _, meta = reader(str(Path(path)), metadataonly=True, **kwargs) + return meta + + def iter_sas_chunks( path: Path, *,