diff --git a/generic_loader/load_folder.py b/generic_loader/load_folder.py index 82fe459..be2df4a 100644 --- a/generic_loader/load_folder.py +++ b/generic_loader/load_folder.py @@ -140,6 +140,7 @@ from typing import Any, Dict, List, Optional, Tuple import yaml from dotenv import load_dotenv +from tqdm import tqdm from load_sas import ( VALID_IF_EXISTS, @@ -658,13 +659,21 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: # --------------------------------------------------------------------------- -def _infer_cluster_schema(path: Path, include, exclude): - """Infer the Postgres column schema from a SAS file preview.""" +def _infer_cluster_schema( + path: Path, include, exclude +) -> Tuple[Dict, Optional[int]]: + """Infer the Postgres column schema from a SAS file preview. + + Returns ``(columns, total_rows)``. ``total_rows`` comes from the + pyreadstat metadata (the file's declared row count) and is threaded + through to :func:`_stream_file` so the tqdm progress bar has a real + denominator instead of an indeterminate spinner. + """ preview_df, meta = read_sas_preview(path) preview_df = apply_column_filter(preview_df, include, exclude) total_rows = getattr(meta, "number_rows", None) columns = infer_schema(preview_df, meta, total_rows=total_rows) - return columns + return columns, total_rows def _discover_cluster_partitions( @@ -709,7 +718,9 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: return 0 first, *rest = cluster.files - first_columns = _infer_cluster_schema(first, cluster.include, cluster.exclude) + first_columns, first_total_rows = _infer_cluster_schema( + first, cluster.include, cluster.exclude + ) # -- Validate index columns early --------------------------------------- if cluster.indexes: @@ -770,10 +781,13 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: total += _stream_file( conn, schemaname, cluster.tablename, first, first_columns, cluster.include, cluster.exclude, + total_rows=first_total_rows, ) for path in rest: - columns = _infer_cluster_schema(path, cluster.include, cluster.exclude) + columns, path_total_rows = _infer_cluster_schema( + path, cluster.include, cluster.exclude + ) # Uses the same check that if_exists=append runs. A type mismatch or # missing column aborts the cluster; because chunks commit as they # load, earlier chunks in the cluster remain in the table. @@ -781,6 +795,7 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: total += _stream_file( conn, schemaname, cluster.tablename, path, columns, cluster.include, cluster.exclude, + total_rows=path_total_rows, ) # -- Index support ------------------------------------------------------ @@ -798,17 +813,25 @@ def _stream_file( columns, include, exclude, + *, + total_rows: Optional[int] = None, ) -> int: def _chunks(): - seen = 0 - for chunk_df, _chunk_meta in iter_sas_chunks(path): - chunk_df = apply_column_filter(chunk_df, include, exclude) - seen += len(chunk_df) - print( - f" {path.name}: streaming... {seen:,} rows", - file=sys.stderr, - ) - yield chunk_df + pbar = tqdm( + total=total_rows, + unit="row", + unit_scale=True, + desc=f" {path.name}", + file=sys.stderr, + dynamic_ncols=True, + ) + try: + for chunk_df, _chunk_meta in iter_sas_chunks(path): + chunk_df = apply_column_filter(chunk_df, include, exclude) + pbar.update(len(chunk_df)) + yield chunk_df + finally: + pbar.close() return copy_dataframes(conn, schemaname, tablename, _chunks(), columns) @@ -902,7 +925,7 @@ def main(argv: Optional[List[str]] = None) -> int: print() for c in loadable: print(f"--- DDL for cluster {c.tablename!r} ---") - columns = _infer_cluster_schema(c.files[0], c.include, c.exclude) + columns, _ = _infer_cluster_schema(c.files[0], c.include, c.exclude) # Print parent CREATE TABLE (with PARTITION BY if applicable). print( render_create_table( diff --git a/generic_loader/load_sas.py b/generic_loader/load_sas.py index cab2f76..5b268f6 100644 --- a/generic_loader/load_sas.py +++ b/generic_loader/load_sas.py @@ -239,6 +239,7 @@ import pyarrow.csv as pa_csv import pyreadstat import yaml from dotenv import load_dotenv +from tqdm import tqdm logger = logging.getLogger(__name__) @@ -273,10 +274,15 @@ detonates ``COPY`` mid-stream (seen in production on a 2.5M-row file where so large that a full read won't fit in memory, set this to an integer cap and accept that sampled specs can't be trusted for ``NOT NULL``.""" -DEFAULT_CHUNK_ROWS = 100_000 +DEFAULT_CHUNK_ROWS = 2_000_000 """Rows per chunk when streaming a SAS file into ``COPY``. Larger values mean -fewer COPY round-trips but more peak memory per chunk; smaller values are -gentler on memory.""" +fewer COPY round-trips and lower per-row overhead but more peak memory per +chunk; smaller values are gentler on memory. + +The chunk size can be overridden at runtime via the +``GENERIC_LOADER_CHUNK_ROWS`` environment variable (read inside +:func:`iter_sas_chunks`), so ``.env``-driven overrides work without code +changes. Explicit ``chunksize=`` kwargs still win over both.""" VALID_IF_EXISTS = ("fail", "replace", "append") @@ -577,13 +583,27 @@ def read_sas_preview( def iter_sas_chunks( path: Path, *, - chunksize: int = DEFAULT_CHUNK_ROWS, + chunksize: Optional[int] = None, ): """Yield ``(df_chunk, meta)`` tuples for streaming loads. Thin wrapper over ``pyreadstat.read_file_in_chunks`` that picks the right underlying reader by extension and threads through our encoding defaults. + + When ``chunksize`` is ``None`` (the default), the effective value comes + from the ``GENERIC_LOADER_CHUNK_ROWS`` environment variable if set and + parseable, otherwise from :data:`DEFAULT_CHUNK_ROWS`. An explicit int + always wins. """ + if chunksize is None: + raw = os.environ.get("GENERIC_LOADER_CHUNK_ROWS") + if raw is not None: + try: + chunksize = int(raw) + except ValueError: + chunksize = DEFAULT_CHUNK_ROWS + else: + chunksize = DEFAULT_CHUNK_ROWS reader, kwargs = _sas_reader(path) yield from pyreadstat.read_file_in_chunks( reader, str(Path(path)), chunksize=chunksize, **kwargs @@ -2072,13 +2092,24 @@ def main(argv: Optional[List[str]] = None) -> int: # it while we're holding a Postgres transaction open. del preview_df + total_rows = getattr(meta, "number_rows", None) + def _filtered_chunks(): - seen = 0 - for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename): - chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude) - seen += len(chunk_df) - print(f" streaming... {seen:,} rows", file=sys.stderr) - yield chunk_df + pbar = tqdm( + total=total_rows, + unit="row", + unit_scale=True, + desc=f" {cfg.filename.name}", + file=sys.stderr, + dynamic_ncols=True, + ) + try: + for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename): + chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude) + pbar.update(len(chunk_df)) + yield chunk_df + finally: + pbar.close() db_user = db_password = None if args.dbcreds: diff --git a/requirements.txt b/requirements.txt index 2511142..c09f442 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,10 @@ pandas>=2.0,<3.0 pyreadstat>=1.2,<2.0 numpy>=2.1,<3.0 -pyarrow>=16.0,<21.0 +pyarrow>=22.0,<24.0 pyyaml>=6.0,<7.0 psycopg2-binary>=2.9,<3.0 python-dotenv>=1.0,<2.0 boto3>=1.28,<2.0 openpyxl>=3.1,<4.0 +tqdm>=4.66,<5.0