Update load_sas.py to support streaming data loads with iter_sas_chunks and copy_dataframes. Enhance documentation for schema inference and type detection, clarifying the use of read_sas_preview and the implications of sampling. Add __pycache__ to .gitignore.

This commit is contained in:
David Peterson 2026-04-18 10:44:32 -05:00
parent 3a0537270c
commit 5645ff5597
2 changed files with 192 additions and 77 deletions

View File

@ -1,3 +1,4 @@
/.venv /.venv
/samples /samples
/.env /.env
/__pycache__

View File

@ -107,26 +107,29 @@ loosening is not checked here because the append-mode check already covers it.
5. Library usage 5. Library usage
---------------- ----------------
The CLI is a thin wrapper around composable functions. A typical orchestrator The CLI is a thin wrapper around composable functions. The preferred pattern
looks like:: infers the schema from a bounded preview and then streams the rest of the
file chunk-by-chunk into ``COPY`` - crucial for SAS files with hundreds of
millions of rows::
from dotenv import load_dotenv from dotenv import load_dotenv
from load_sas import ( from load_sas import (
load_config, read_sas, apply_column_filter, infer_schema, load_config, read_sas_preview, iter_sas_chunks, apply_column_filter,
validate_against_manifest, render_create_table, infer_schema, validate_against_manifest, render_create_table,
connect, create_table, copy_dataframe, connect, create_table, copy_dataframes,
) )
load_dotenv() load_dotenv()
cfg = load_config("config.yaml") cfg = load_config("config.yaml")
df, meta = read_sas(cfg.filename)
df = apply_column_filter(df, cfg.include, cfg.exclude)
columns = infer_schema(df, meta)
# Optional: preview # Schema from a preview slice (bounded by TYPE_INFERENCE_SAMPLE_ROWS).
preview_df, meta = read_sas_preview(cfg.filename)
preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
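# NB: per the note in main(), ``meta.number_rows`` on a row-limited preview read reflects rows returned, not the file's total - substitute the real row count here if you know it.
total_rows = getattr(meta, "number_rows", None)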
columns = infer_schema(preview_df, meta, total_rows=total_rows)
# Optional: preview DDL / validate against a manifest.
print(render_create_table(cfg.schemaname, cfg.tablename, columns)) print(render_create_table(cfg.schemaname, cfg.tablename, columns))
# Optional: validate against a manifest
problems = validate_against_manifest(columns, Path("expected.json")) problems = validate_against_manifest(columns, Path("expected.json"))
assert not problems, problems assert not problems, problems
@ -134,15 +137,23 @@ looks like::
conn.autocommit = False conn.autocommit = False
try: try:
create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists) create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists)
rows = copy_dataframe(conn, cfg.schemaname, cfg.tablename, df, columns) chunks = (
apply_column_filter(df, cfg.include, cfg.exclude)
for df, _ in iter_sas_chunks(cfg.filename)
)
rows = copy_dataframes(conn, cfg.schemaname, cfg.tablename, chunks, columns)
conn.commit() conn.commit()
finally: finally:
conn.close() conn.close()
For small files (or tests) the legacy one-shot API still works:
:func:`read_sas` returns the whole frame and :func:`copy_dataframe` copies it
in one round trip.
All functions are side-effect free except :func:`connect`, :func:`create_table`, All functions are side-effect free except :func:`connect`, :func:`create_table`,
and :func:`copy_dataframe`; schema inference (:func:`infer_schema`) accepts a :func:`copy_dataframe`, and :func:`copy_dataframes`; schema inference
``coerce_chars`` kwarg to override the module-level ``COERCE_CHAR_COLUMNS`` (:func:`infer_schema`) accepts a ``coerce_chars`` kwarg to override the
without mutating global state. module-level ``COERCE_CHAR_COLUMNS`` without mutating global state.
6. Type inference summary 6. Type inference summary
------------------------- -------------------------
@ -164,25 +175,32 @@ Priority order used by :func:`infer_schema`:
``DOUBLE PRECISION``. ``DOUBLE PRECISION``.
Type inference scans only the first ``TYPE_INFERENCE_SAMPLE_ROWS`` rows for Type inference scans only the first ``TYPE_INFERENCE_SAMPLE_ROWS`` rows for
performance on large files. Nullability and all-null detection still run over performance on large files. The CLI enforces this at read time via
the full column (they're vectorized and fast) so a ``NOT NULL`` constraint is :func:`read_sas_preview`, so the whole file is never materialized just to pick
never declared for a column that has a null anywhere in the file. Tradeoff: types. Sampled specs carry a ``sampled`` marker and the usual
if the first N rows fit ``INTEGER`` but a later row exceeds int32, COPY will tradeoffs: if the first N rows fit ``INTEGER`` but a later row exceeds int32,
fail; bump the sample size or set ``TYPE_INFERENCE_SAMPLE_ROWS = None`` to or a column had no nulls in the preview but does later in the file, ``COPY``
scan the whole column. will fail mid-stream and the whole transaction rolls back. Set
``TYPE_INFERENCE_SAMPLE_ROWS = None`` to scan every row when exact typing
matters more than speed.
Streaming loads use :func:`iter_sas_chunks` + :func:`copy_dataframes`, which
share one cursor and transaction so a failure mid-file rolls back the whole
load.
7. Tunables 7. Tunables
----------- -----------
Module-level knobs at the top of this file: Module-level knobs at the top of this file:
* ``COERCE_CHAR_COLUMNS`` - whether to promote stringly-typed numerics/ * ``COERCE_CHAR_COLUMNS`` - promote stringly-typed numerics / dates
dates (default True). (default True).
* ``CHAR_INFERENCE_MIN_VALUES`` - minimum non-empty sample size before * ``CHAR_INFERENCE_MIN_VALUES`` - minimum non-empty sample size before
char-column coercion is attempted. char-column coercion is attempted.
* ``NUMERIC_INT_RANGE`` - INTEGER bounds; values outside become * ``NUMERIC_INT_RANGE`` - INTEGER bounds; values outside become
``BIGINT``. ``BIGINT``.
* ``TYPE_INFERENCE_SAMPLE_ROWS`` - cap on rows used for type inference * ``TYPE_INFERENCE_SAMPLE_ROWS`` - cap on rows read for type inference
(``None`` = scan the whole column). (``None`` = scan the whole column).
* ``DEFAULT_CHUNK_ROWS`` - rows per streaming COPY chunk.
""" """
from __future__ import annotations from __future__ import annotations
@ -195,7 +213,7 @@ import os
import sys import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple
import pandas as pd import pandas as pd
import psycopg2 import psycopg2
@ -221,13 +239,16 @@ NUMERIC_INT_RANGE = (-2_147_483_648, 2_147_483_647)
"""INTEGER bounds; anything outside becomes BIGINT.""" """INTEGER bounds; anything outside becomes BIGINT."""
TYPE_INFERENCE_SAMPLE_ROWS: Optional[int] = 10_000 TYPE_INFERENCE_SAMPLE_ROWS: Optional[int] = 10_000
"""Cap on rows inspected during per-column type inference. The row-walking """Cap on rows inspected during per-column type inference. Also governs how
helpers (date detection on object columns, string-coercion probes, whole-number many rows :func:`read_sas_preview` pulls from the file for dry-run / validate /
check on numeric columns) operate on ``df.head(TYPE_INFERENCE_SAMPLE_ROWS)`` schema-inference flows. Set to ``None`` to scan every row (and read the whole
instead of the full frame, which matters on SAS files with hundreds of millions file into memory for the preview step - don't do this on multi-hundred-million
of rows. Nullability is still evaluated across the whole column (cheap, row files)."""
vectorized) so ``NOT NULL`` declarations remain safe. Set to ``None`` to scan
every row.""" DEFAULT_CHUNK_ROWS = 100_000
"""Rows per chunk when streaming a SAS file into ``COPY``. Larger values mean
fewer COPY round-trips but more peak memory per chunk; smaller values are
gentler on memory."""
VALID_IF_EXISTS = ("fail", "replace", "append") VALID_IF_EXISTS = ("fail", "replace", "append")
@ -256,6 +277,12 @@ class ColumnSpec:
sas_format: Optional[str] = None sas_format: Optional[str] = None
source_dtype: Optional[str] = None source_dtype: Optional[str] = None
notes: List[str] = field(default_factory=list) notes: List[str] = field(default_factory=list)
sampled: bool = False
"""True when the type was inferred from a bounded preview rather than the
full file. A sampled spec carries the usual sampling risks: a later chunk
could contain a value that exceeds the inferred integer range, doesn't
parse as the inferred type, or is null in a column the preview showed as
non-null - all of which surface as mid-``COPY`` failures."""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -355,8 +382,8 @@ def load_config(path: Path) -> LoaderConfig:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def read_sas(path: Path) -> Tuple[pd.DataFrame, Any]: def _sas_reader(path: Path) -> Tuple[Any, Dict[str, Any]]:
"""Dispatch to the right pyreadstat reader by extension. """Return ``(pyreadstat_reader, extra_kwargs)`` for ``path``.
Invariants (learned the hard way while building the sample generator): Invariants (learned the hard way while building the sample generator):
@ -364,15 +391,58 @@ def read_sas(path: Path) -> Tuple[pd.DataFrame, Any]:
encoding on XPORT files it wrote itself. encoding on XPORT files it wrote itself.
* ``.sas7bdat`` - explicit ``encoding="latin-1"`` per colleague guidance. * ``.sas7bdat`` - explicit ``encoding="latin-1"`` per colleague guidance.
""" """
path = Path(path) suffix = Path(path).suffix.lower()
suffix = path.suffix.lower()
if suffix in (".xpt", ".xport"): if suffix in (".xpt", ".xport"):
return pyreadstat.read_xport(str(path)) return pyreadstat.read_xport, {}
if suffix == ".sas7bdat": if suffix == ".sas7bdat":
return pyreadstat.read_sas7bdat(str(path), encoding="latin-1") return pyreadstat.read_sas7bdat, {"encoding": "latin-1"}
raise ValueError(f"Unsupported SAS file extension: {suffix}") raise ValueError(f"Unsupported SAS file extension: {suffix}")
def read_sas(path: Path) -> Tuple[pd.DataFrame, Any]:
    """Load every row of a SAS file in a single call.

    The full frame is held in memory, so this is only appropriate for small
    inputs. It is retained for backward compatibility and tests; the CLI
    pairs :func:`read_sas_preview` with :func:`iter_sas_chunks` so it never
    materializes the whole file at once.

    Returns whatever the extension-selected pyreadstat reader returns (the
    frame and its metadata object).
    """
    read_fn, reader_options = _sas_reader(path)
    filename = str(Path(path))
    return read_fn(filename, **reader_options)
def read_sas_preview(
    path: Path,
    *,
    rows: Optional[int] = None,
) -> Tuple[pd.DataFrame, Any]:
    """Read the first ``rows`` records from ``path`` plus its metadata.

    ``rows`` defaults to ``TYPE_INFERENCE_SAMPLE_ROWS`` when not given. If
    both are ``None`` the whole file is read (pyreadstat treats
    ``row_limit=0`` as unlimited).

    A limit of ``0`` is honoured literally: the file is opened with
    ``metadataonly=True`` so no data rows are loaded. Forwarding ``0`` as
    ``row_limit`` would instead mean "unlimited" and pull the entire file
    into memory - the opposite of what a bounded preview is for.

    Raises:
        ValueError: if the effective row limit is negative, or (via
            ``_sas_reader``) if the file extension is unsupported.
    """
    effective = rows if rows is not None else TYPE_INFERENCE_SAMPLE_ROWS
    if effective is not None:
        effective = int(effective)
        if effective < 0:
            raise ValueError(f"rows must be >= 0, got {effective}")

    reader, kwargs = _sas_reader(path)
    filename = str(Path(path))

    if effective is None:
        # No cap requested anywhere: row_limit=0 is pyreadstat's "unlimited".
        return reader(filename, row_limit=0, **kwargs)
    if effective == 0:
        # Zero rows requested: metadata only, never the whole file.
        return reader(filename, metadataonly=True, **kwargs)
    return reader(filename, row_limit=effective, **kwargs)
def iter_sas_chunks(
    path: Path,
    *,
    chunksize: int = DEFAULT_CHUNK_ROWS,
) -> Iterable[Tuple[pd.DataFrame, Any]]:
    """Yield ``(df_chunk, meta)`` tuples for streaming loads.

    Thin wrapper over ``pyreadstat.read_file_in_chunks`` that picks the right
    underlying reader by extension and threads through our encoding defaults.

    ``chunksize`` is the maximum number of rows per yielded frame and must be
    a positive integer. Because this is a generator, validation (and the
    extension check in ``_sas_reader``) runs when iteration starts, not when
    the function is called.

    Raises:
        ValueError: if ``chunksize`` is less than 1, or (via ``_sas_reader``)
            if the file extension is unsupported.
    """
    # pyreadstat treats a row limit of 0 as "unlimited", so a non-positive
    # chunk size cannot bound memory per chunk - reject it up front.
    if chunksize < 1:
        raise ValueError(
            f"chunksize must be a positive integer, got {chunksize!r}"
        )
    reader, kwargs = _sas_reader(path)
    yield from pyreadstat.read_file_in_chunks(
        reader, str(Path(path)), chunksize=chunksize, **kwargs
    )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Column filtering # Column filtering
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -552,6 +622,7 @@ def infer_schema(
meta: Any, meta: Any,
*, *,
coerce_chars: bool = COERCE_CHAR_COLUMNS, coerce_chars: bool = COERCE_CHAR_COLUMNS,
total_rows: Optional[int] = None,
) -> Dict[str, ColumnSpec]: ) -> Dict[str, ColumnSpec]:
"""Infer a Postgres column spec for each column in ``df``. """Infer a Postgres column spec for each column in ``df``.
@ -563,18 +634,24 @@ def infer_schema(
``COERCE_CHAR_COLUMNS`` without mutating global state. Internally the ``COERCE_CHAR_COLUMNS`` without mutating global state. Internally the
char-inference helpers still read the constant - a full override would char-inference helpers still read the constant - a full override would
thread the flag through, but the one-knob story here is intentional. thread the flag through, but the one-knob story here is intentional.
``total_rows`` lets callers who already sampled the frame (e.g. via
:func:`read_sas_preview`) report the real file size in the per-column
"inferred from first N of M rows" note. Falls back to ``len(df)``.
""" """
original_formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {}) original_formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {})
# Row-walking type probes run on a bounded head slice; nullability and the # Row-walking type probes run on a bounded head slice; nullability and the
# all-null check still see every row so NOT NULL declarations stay honest. # all-null check still see every row so NOT NULL declarations stay honest.
total_rows = len(df) df_rows = len(df)
if TYPE_INFERENCE_SAMPLE_ROWS is not None and total_rows > TYPE_INFERENCE_SAMPLE_ROWS: effective_total = total_rows if total_rows is not None else df_rows
if TYPE_INFERENCE_SAMPLE_ROWS is not None and df_rows > TYPE_INFERENCE_SAMPLE_ROWS:
sample_df = df.head(TYPE_INFERENCE_SAMPLE_ROWS) sample_df = df.head(TYPE_INFERENCE_SAMPLE_ROWS)
sampled = True sample_size = TYPE_INFERENCE_SAMPLE_ROWS
else: else:
sample_df = df sample_df = df
sampled = False sample_size = df_rows
sampled = sample_size < effective_total
# Temporarily flip the module-level flag if the caller asked us to. # Temporarily flip the module-level flag if the caller asked us to.
global COERCE_CHAR_COLUMNS global COERCE_CHAR_COLUMNS
@ -614,8 +691,8 @@ def infer_schema(
if sampled: if sampled:
notes.append( notes.append(
f"type inferred from first {TYPE_INFERENCE_SAMPLE_ROWS:,} of " f"type inferred from first {sample_size:,} of "
f"{total_rows:,} rows" f"{effective_total:,} rows"
) )
nullable = _is_nullable(series) nullable = _is_nullable(series)
@ -627,6 +704,7 @@ def infer_schema(
sas_format=sas_format, sas_format=sas_format,
source_dtype=str(series.dtype), source_dtype=str(series.dtype),
notes=notes, notes=notes,
sampled=sampled,
) )
return out return out
finally: finally:
@ -914,6 +992,45 @@ def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.Da
return out return out
def copy_dataframes(
    conn,
    schema_name: str,
    table_name: str,
    dfs: Iterable[pd.DataFrame],
    columns: Dict[str, ColumnSpec],
) -> int:
    """Feed a sequence of DataFrames to Postgres through ``COPY ... FROM STDIN``.

    One cursor is opened on ``conn`` and reused for every frame in ``dfs``;
    nothing is committed here, so an error part-way through leaves the
    caller free to roll back the entire load. Frames for which ``df.empty``
    is true are skipped without issuing a ``COPY``.

    The target column list is taken from the keys of ``columns`` in their
    existing order. Each frame is passed through ``_prepare_for_copy`` and
    serialized as header-less CSV with empty strings standing in for NULL.

    Returns the sum of the prepared frames' lengths.
    """
    quoted_columns = ", ".join(_quote_ident(name) for name in columns)
    target = _qualified(schema_name, table_name)
    copy_sql = (
        f"COPY {target} ({quoted_columns}) "
        f"FROM STDIN WITH (FORMAT csv, NULL '')"
    )
    csv_options = {
        "index": False,
        "header": False,
        "na_rep": "",
        "date_format": "%Y-%m-%d %H:%M:%S",
    }

    inserted = 0
    with conn.cursor() as cursor:
        for frame in dfs:
            if frame.empty:
                continue
            ready = _prepare_for_copy(frame, columns)
            payload = io.StringIO()
            ready.to_csv(payload, **csv_options)
            # Rewind so copy_expert reads from the start of the CSV text.
            payload.seek(0)
            cursor.copy_expert(copy_sql, payload)
            inserted += len(ready)
    return inserted
def copy_dataframe( def copy_dataframe(
conn, conn,
schema_name: str, schema_name: str,
@ -923,31 +1040,10 @@ def copy_dataframe(
) -> int: ) -> int:
"""Stream ``df`` into Postgres via ``COPY ... FROM STDIN``. """Stream ``df`` into Postgres via ``COPY ... FROM STDIN``.
Returns the number of rows inserted. Convenience wrapper around :func:`copy_dataframes` for single-frame
callers. Returns the number of rows inserted.
""" """
prepared = _prepare_for_copy(df, columns) return copy_dataframes(conn, schema_name, table_name, [df], columns)
buf = io.StringIO()
prepared.to_csv(
buf,
index=False,
header=False,
na_rep="",
date_format="%Y-%m-%d %H:%M:%S",
)
buf.seek(0)
col_list = ", ".join(_quote_ident(name) for name in columns.keys())
sql = (
f"COPY {_qualified(schema_name, table_name)} ({col_list}) "
f"FROM STDIN WITH (FORMAT csv, NULL '')"
)
with conn.cursor() as cur:
cur.copy_expert(sql, buf)
rowcount = cur.rowcount
return int(rowcount) if rowcount is not None else len(prepared)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -1060,9 +1156,13 @@ def main(argv: Optional[List[str]] = None) -> int:
print(f"error: SAS file not found: {cfg.filename}", file=sys.stderr) print(f"error: SAS file not found: {cfg.filename}", file=sys.stderr)
return 2 return 2
df, meta = read_sas(cfg.filename) # Schema inference uses a bounded preview read so we never load a
df = apply_column_filter(df, cfg.include, cfg.exclude) # hundreds-of-millions-of-rows file into memory just to pick types.
columns = infer_schema(df, meta) # NB: ``meta.number_rows`` on a ``row_limit``-ed read reflects rows
# returned, not the file's total, so we don't trust it here.
preview_df, meta = read_sas_preview(cfg.filename)
preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
columns = infer_schema(preview_df, meta)
if args.validate: if args.validate:
manifest_path = cfg.filename.with_suffix("").with_suffix(".expected.json") manifest_path = cfg.filename.with_suffix("").with_suffix(".expected.json")
@ -1080,11 +1180,25 @@ def main(argv: Optional[List[str]] = None) -> int:
print(render_create_table(cfg.schemaname, cfg.tablename, columns)) print(render_create_table(cfg.schemaname, cfg.tablename, columns))
return 0 return 0
# Release the preview frame before opening the stream - lets the GC reclaim
# it while we're holding a Postgres transaction open.
del preview_df
def _filtered_chunks():
seen = 0
for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename):
chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude)
seen += len(chunk_df)
print(f" streaming... {seen:,} rows", file=sys.stderr)
yield chunk_df
conn = connect() conn = connect()
conn.autocommit = False conn.autocommit = False
try: try:
create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists) create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists)
inserted = copy_dataframe(conn, cfg.schemaname, cfg.tablename, df, columns) inserted = copy_dataframes(
conn, cfg.schemaname, cfg.tablename, _filtered_chunks(), columns
)
conn.commit() conn.commit()
except Exception: except Exception:
conn.rollback() conn.rollback()