Update requirements and enhance SAS file processing with progress tracking

Updated the pyarrow version in requirements.txt to improve compatibility. Enhanced the _infer_cluster_schema and _stream_file functions in load_folder.py and load_sas.py to return total row counts for better progress tracking during data streaming. Integrated tqdm for visual feedback on row processing, improving user experience during large data loads.
This commit is contained in:
David Peterson 2026-04-20 21:44:49 -05:00
parent 7beb44ac4d
commit 96f2d6fe79
3 changed files with 81 additions and 26 deletions

View File

@ -140,6 +140,7 @@ from typing import Any, Dict, List, Optional, Tuple
import yaml import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
from tqdm import tqdm
from load_sas import ( from load_sas import (
VALID_IF_EXISTS, VALID_IF_EXISTS,
@ -658,13 +659,21 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _infer_cluster_schema(path: Path, include, exclude): def _infer_cluster_schema(
"""Infer the Postgres column schema from a SAS file preview.""" path: Path, include, exclude
) -> Tuple[Dict, Optional[int]]:
"""Infer the Postgres column schema from a SAS file preview.
Returns ``(columns, total_rows)``. ``total_rows`` comes from the
pyreadstat metadata (the file's declared row count) and is threaded
through to :func:`_stream_file` so the tqdm progress bar has a real
denominator instead of an indeterminate spinner.
"""
preview_df, meta = read_sas_preview(path) preview_df, meta = read_sas_preview(path)
preview_df = apply_column_filter(preview_df, include, exclude) preview_df = apply_column_filter(preview_df, include, exclude)
total_rows = getattr(meta, "number_rows", None) total_rows = getattr(meta, "number_rows", None)
columns = infer_schema(preview_df, meta, total_rows=total_rows) columns = infer_schema(preview_df, meta, total_rows=total_rows)
return columns return columns, total_rows
def _discover_cluster_partitions( def _discover_cluster_partitions(
@ -709,7 +718,9 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
return 0 return 0
first, *rest = cluster.files first, *rest = cluster.files
first_columns = _infer_cluster_schema(first, cluster.include, cluster.exclude) first_columns, first_total_rows = _infer_cluster_schema(
first, cluster.include, cluster.exclude
)
# -- Validate index columns early --------------------------------------- # -- Validate index columns early ---------------------------------------
if cluster.indexes: if cluster.indexes:
@ -770,10 +781,13 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
total += _stream_file( total += _stream_file(
conn, schemaname, cluster.tablename, first, first_columns, conn, schemaname, cluster.tablename, first, first_columns,
cluster.include, cluster.exclude, cluster.include, cluster.exclude,
total_rows=first_total_rows,
) )
for path in rest: for path in rest:
columns = _infer_cluster_schema(path, cluster.include, cluster.exclude) columns, path_total_rows = _infer_cluster_schema(
path, cluster.include, cluster.exclude
)
# Uses the same check that if_exists=append runs. A type mismatch or # Uses the same check that if_exists=append runs. A type mismatch or
# missing column aborts the cluster; because chunks commit as they # missing column aborts the cluster; because chunks commit as they
# load, earlier chunks in the cluster remain in the table. # load, earlier chunks in the cluster remain in the table.
@ -781,6 +795,7 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
total += _stream_file( total += _stream_file(
conn, schemaname, cluster.tablename, path, columns, conn, schemaname, cluster.tablename, path, columns,
cluster.include, cluster.exclude, cluster.include, cluster.exclude,
total_rows=path_total_rows,
) )
# -- Index support ------------------------------------------------------ # -- Index support ------------------------------------------------------
@ -798,17 +813,25 @@ def _stream_file(
columns, columns,
include, include,
exclude, exclude,
*,
total_rows: Optional[int] = None,
) -> int: ) -> int:
def _chunks(): def _chunks():
seen = 0 pbar = tqdm(
total=total_rows,
unit="row",
unit_scale=True,
desc=f" {path.name}",
file=sys.stderr,
dynamic_ncols=True,
)
try:
for chunk_df, _chunk_meta in iter_sas_chunks(path): for chunk_df, _chunk_meta in iter_sas_chunks(path):
chunk_df = apply_column_filter(chunk_df, include, exclude) chunk_df = apply_column_filter(chunk_df, include, exclude)
seen += len(chunk_df) pbar.update(len(chunk_df))
print(
f" {path.name}: streaming... {seen:,} rows",
file=sys.stderr,
)
yield chunk_df yield chunk_df
finally:
pbar.close()
return copy_dataframes(conn, schemaname, tablename, _chunks(), columns) return copy_dataframes(conn, schemaname, tablename, _chunks(), columns)
@ -902,7 +925,7 @@ def main(argv: Optional[List[str]] = None) -> int:
print() print()
for c in loadable: for c in loadable:
print(f"--- DDL for cluster {c.tablename!r} ---") print(f"--- DDL for cluster {c.tablename!r} ---")
columns = _infer_cluster_schema(c.files[0], c.include, c.exclude) columns, _ = _infer_cluster_schema(c.files[0], c.include, c.exclude)
# Print parent CREATE TABLE (with PARTITION BY if applicable). # Print parent CREATE TABLE (with PARTITION BY if applicable).
print( print(
render_create_table( render_create_table(

View File

@ -239,6 +239,7 @@ import pyarrow.csv as pa_csv
import pyreadstat import pyreadstat
import yaml import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
from tqdm import tqdm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -273,10 +274,15 @@ detonates ``COPY`` mid-stream (seen in production on a 2.5M-row file where
so large that a full read won't fit in memory, set this to an integer cap so large that a full read won't fit in memory, set this to an integer cap
and accept that sampled specs can't be trusted for ``NOT NULL``.""" and accept that sampled specs can't be trusted for ``NOT NULL``."""
DEFAULT_CHUNK_ROWS = 100_000 DEFAULT_CHUNK_ROWS = 2_000_000
"""Rows per chunk when streaming a SAS file into ``COPY``. Larger values mean """Rows per chunk when streaming a SAS file into ``COPY``. Larger values mean
fewer COPY round-trips but more peak memory per chunk; smaller values are fewer COPY round-trips and lower per-row overhead but more peak memory per
gentler on memory.""" chunk; smaller values are gentler on memory.
The chunk size can be overridden at runtime via the
``GENERIC_LOADER_CHUNK_ROWS`` environment variable (read inside
:func:`iter_sas_chunks`), so ``.env``-driven overrides work without code
changes. Explicit ``chunksize=`` kwargs still win over both."""
VALID_IF_EXISTS = ("fail", "replace", "append") VALID_IF_EXISTS = ("fail", "replace", "append")
@ -577,13 +583,27 @@ def read_sas_preview(
def iter_sas_chunks( def iter_sas_chunks(
path: Path, path: Path,
*, *,
chunksize: int = DEFAULT_CHUNK_ROWS, chunksize: Optional[int] = None,
): ):
"""Yield ``(df_chunk, meta)`` tuples for streaming loads. """Yield ``(df_chunk, meta)`` tuples for streaming loads.
Thin wrapper over ``pyreadstat.read_file_in_chunks`` that picks the right Thin wrapper over ``pyreadstat.read_file_in_chunks`` that picks the right
underlying reader by extension and threads through our encoding defaults. underlying reader by extension and threads through our encoding defaults.
When ``chunksize`` is ``None`` (the default), the effective value comes
from the ``GENERIC_LOADER_CHUNK_ROWS`` environment variable if set and
parseable, otherwise from :data:`DEFAULT_CHUNK_ROWS`. An explicit int
always wins.
""" """
if chunksize is None:
raw = os.environ.get("GENERIC_LOADER_CHUNK_ROWS")
if raw is not None:
try:
chunksize = int(raw)
except ValueError:
chunksize = DEFAULT_CHUNK_ROWS
else:
chunksize = DEFAULT_CHUNK_ROWS
reader, kwargs = _sas_reader(path) reader, kwargs = _sas_reader(path)
yield from pyreadstat.read_file_in_chunks( yield from pyreadstat.read_file_in_chunks(
reader, str(Path(path)), chunksize=chunksize, **kwargs reader, str(Path(path)), chunksize=chunksize, **kwargs
@ -2072,13 +2092,24 @@ def main(argv: Optional[List[str]] = None) -> int:
# it while we're holding a Postgres transaction open. # it while we're holding a Postgres transaction open.
del preview_df del preview_df
total_rows = getattr(meta, "number_rows", None)
def _filtered_chunks(): def _filtered_chunks():
seen = 0 pbar = tqdm(
total=total_rows,
unit="row",
unit_scale=True,
desc=f" {cfg.filename.name}",
file=sys.stderr,
dynamic_ncols=True,
)
try:
for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename): for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename):
chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude) chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude)
seen += len(chunk_df) pbar.update(len(chunk_df))
print(f" streaming... {seen:,} rows", file=sys.stderr)
yield chunk_df yield chunk_df
finally:
pbar.close()
db_user = db_password = None db_user = db_password = None
if args.dbcreds: if args.dbcreds:

View File

@ -1,9 +1,10 @@
pandas>=2.0,<3.0 pandas>=2.0,<3.0
pyreadstat>=1.2,<2.0 pyreadstat>=1.2,<2.0
numpy>=2.1,<3.0 numpy>=2.1,<3.0
pyarrow>=16.0,<21.0 pyarrow>=22.0,<24.0
pyyaml>=6.0,<7.0 pyyaml>=6.0,<7.0
psycopg2-binary>=2.9,<3.0 psycopg2-binary>=2.9,<3.0
python-dotenv>=1.0,<2.0 python-dotenv>=1.0,<2.0
boto3>=1.28,<2.0 boto3>=1.28,<2.0
openpyxl>=3.1,<4.0 openpyxl>=3.1,<4.0
tqdm>=4.66,<5.0