foxtrot/generic_loader/load_sas.py

2026-04-18 14:34:48 +00:00
"""Per-file SAS-to-Postgres loader.

Library-style functions plus a thin CLI wrapper. Designed so an orchestrator
can wrap the library for directory/batch mode; orchestration is out of scope
here.

Python 3.9 compatible (target is an air-gapped host that currently only has
3.9). ``from __future__ import annotations`` lets us use PEP 585 generics
as annotations; runtime-resolved type uses (dataclass defaults, etc.) stick
to ``typing``.

-------------------------------------------------------------------------------
USAGE
-------------------------------------------------------------------------------

Supported inputs:

* ``.sas7bdat`` (read with ``encoding="latin-1"``)
* ``.xpt`` / ``.xport`` (SAS transport files)
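
The extension check reduces to a small dispatch on the lower-cased suffix.
A minimal sketch of the rule documented in this list follows. ``pick_reader``
is a hypothetical stand-in for the module's private ``_sas_reader``. It returns
pyreadstat function *names* as strings rather than the functions themselves,
so it runs without pyreadstat installed.

```python
from pathlib import Path
from typing import Dict, Tuple


def pick_reader(path: str) -> Tuple[str, Dict[str, str]]:
    # Hypothetical helper: same dispatch rule as documented above, but it
    # returns the pyreadstat function name instead of the function object.
    suffix = Path(path).suffix.lower()  # ".XPT" and ".xpt" are treated alike
    if suffix in (".xpt", ".xport"):
        return "read_xport", {}  # transport files: no explicit encoding
    if suffix == ".sas7bdat":
        return "read_sas7bdat", {"encoding": "latin-1"}
    raise ValueError(f"Unsupported SAS file extension: {suffix}")


print(pick_reader("samples/sample_kitchensink.XPT"))  # ('read_xport', {})
```

Anything else, such as a ``.csv`` file, is rejected up front rather than
being handed to pyreadstat.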

1. YAML config
--------------

Every invocation is driven by a YAML file describing one SAS file to load::

    filename: samples/sample_kitchensink.xpt   # required; relative paths are
                                               # resolved against the config
                                               # file's directory when possible
    schemaname: public                         # required
    tablename: kitchensink                     # required

    # Optional. One of: fail | replace | append. Default: fail.
    #   fail    - error out if the target table already exists
    #   replace - DROP and recreate the table from the inferred schema
    #   append  - keep the existing table; pre-flight a schema-compat check,
    #             then COPY the new rows in
    if_exists: append

    # Optional, mutually exclusive. Restrict which columns are loaded.
    # include:
    #   - ID
    #   - INTCOL
    # exclude:
    #   - ALLNULL
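
The ``filename`` lookup rule can be sketched as a standalone function.
``resolve_sas_path`` is a hypothetical name that mirrors what ``load_config``
does with the key. Absolute paths pass through unchanged. Relative paths are
tried against the config file's directory first; if the file is not there,
they stay relative and resolve against the current working directory when the
file is opened.

```python
import tempfile
from pathlib import Path


def resolve_sas_path(config_path: Path, filename: str) -> Path:
    p = Path(filename)
    if p.is_absolute():
        return p  # absolute paths are used verbatim
    candidate = Path(config_path).parent / p
    # Prefer the config's directory when the file exists there; otherwise keep
    # the relative path, which later resolves against the working directory.
    return candidate.resolve() if candidate.exists() else p


with tempfile.TemporaryDirectory() as d:
    cfg = Path(d) / "config.yaml"
    (Path(d) / "data.xpt").write_bytes(b"")  # empty placeholder file
    print(resolve_sas_path(cfg, "data.xpt") == (Path(d) / "data.xpt").resolve())  # True
    print(resolve_sas_path(cfg, "missing.xpt"))  # missing.xpt
```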

2. Database connection
----------------------

The loader uses standard libpq environment variables (read via ``os.environ``)::

    PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE

The CLI calls ``python-dotenv``'s ``load_dotenv()`` at startup, so a local
``.env`` file is picked up automatically. Library callers are responsible for
populating the environment themselves (either call ``load_dotenv()`` or export
the vars) before calling :func:`connect`.
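
The mapping that :func:`connect` performs can be sketched as a pure function.
``libpq_kwargs`` is a hypothetical name, and the host and database values below
are made up. The sketch shows that explicit ``user`` / ``password`` arguments
(as supplied by ``--dbcreds``) override the environment, while an empty string
falls back to it.

```python
from typing import Dict, Mapping, Optional


def libpq_kwargs(
    env: Mapping[str, str],
    user: Optional[str] = None,
    password: Optional[str] = None,
) -> Dict[str, Optional[str]]:
    # Hypothetical helper: the keyword arguments connect() hands to
    # psycopg2.connect(), built from an explicit mapping instead of os.environ.
    return {
        "host": env.get("PGHOST"),
        "port": env.get("PGPORT"),
        # Explicit credentials win; because of ``or``, an empty string
        # falls back to the environment value.
        "user": user or env.get("PGUSER"),
        "password": password or env.get("PGPASSWORD"),
        "dbname": env.get("PGDATABASE"),
    }


env = {"PGHOST": "db.example", "PGUSER": "loader", "PGDATABASE": "warehouse"}
print(libpq_kwargs(env, user="analyst")["user"])  # analyst
print(libpq_kwargs(env, user="")["user"])  # loader
```

Passing ``os.environ`` as ``env`` reproduces the real lookup. Unset variables
come through as ``None``, which leaves libpq to apply its own defaults.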

3. Command-line interface
-------------------------

::

    python load_sas.py --config path/to/config.yaml [--validate] [--dry-run]
                       [--dbcreds]

Flags:

    --config PATH   Required. Path to the YAML config above.
    --validate      Compare the inferred schema against
                    ``<sas-file-stem>.expected.json`` sitting next to the SAS
                    file. Exits nonzero on mismatch. Safe to combine with
                    ``--dry-run``.
    --dry-run       Print the inferred ``CREATE TABLE`` SQL and stop. The
                    database is never touched (no connection is opened).
    --dbcreds       Prompt interactively for the database username and
                    password instead of reading ``PGUSER`` / ``PGPASSWORD``
                    from the environment or ``.env`` file. The password
                    prompt does not echo. Has no effect with ``--dry-run``
                    (no connection is opened).
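
The flag set maps directly onto ``argparse``. The parser below is an
illustrative reconstruction of the documented interface, not this module's
actual parser. ``build_parser`` and ``prompt_creds`` are hypothetical names,
and the prompt wording is assumed. The sketch also shows why ``--dbcreds`` is
inert under ``--dry-run``: the credential prompt is skipped whenever no
connection will be opened.

```python
import argparse
import getpass
from typing import Optional, Tuple


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(prog="load_sas.py")
    p.add_argument("--config", required=True, help="Path to the YAML config.")
    p.add_argument("--validate", action="store_true")
    p.add_argument("--dry-run", action="store_true")  # read back as args.dry_run
    p.add_argument("--dbcreds", action="store_true")
    return p


def prompt_creds(args: argparse.Namespace) -> Tuple[Optional[str], Optional[str]]:
    # Only prompt when a connection will actually be opened.
    if not args.dbcreds or args.dry_run:
        return None, None
    user = input("Database user: ")
    password = getpass.getpass("Database password: ")  # input is not echoed
    return user, password


args = build_parser().parse_args(["--config", "cfg.yaml", "--dry-run", "--dbcreds"])
print(args.dry_run, prompt_creds(args))  # True (None, None)
```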

Exit codes:

    0             - success (load completed, or dry-run/validate passed)
    1             - validation failure
    2             - config references a SAS file that does not exist
    Other nonzero - uncaught exception (traceback printed); the transaction
                    is rolled back before exit.

Typical invocations::

    # Preview the inferred schema without connecting to Postgres.
    python load_sas.py --config sample_config.yaml --dry-run

    # Check the inferred schema against an expected-types manifest.
    python load_sas.py --config sample_config.yaml --validate --dry-run

    # Actually load the data.
    python load_sas.py --config sample_config.yaml

    # Load the data, prompting for credentials instead of using .env.
    python load_sas.py --config sample_config.yaml --dbcreds

4. Expected-types manifest (``--validate``)
-------------------------------------------

``--validate`` looks for a JSON file named ``<sas-stem>.expected.json`` next
to the SAS file, e.g. ``samples/sample_kitchensink.xpt`` pairs with
``samples/sample_kitchensink.expected.json``. Each top-level key is a column
name; the value is an object with any of the following keys (the ``#``
annotations are explanatory only; JSON itself does not allow comments)::

    {
      "postgres_type": "BIGINT",          # exact expected type, OR
      "acceptable_types": ["TEXT",        # any-of list of acceptable types
                           "VARCHAR"],
      "nullable": true,                   # default true; false = must be NOT NULL
      "note": "free-form comment"         # ignored by the loader
    }

Type comparison ignores length/precision modifiers and normalizes synonyms
(e.g. ``INT`` == ``INTEGER`` == ``INT4``; ``VARCHAR(10)`` == ``VARCHAR``).
Nullability tightening (inferred nullable, manifest NOT NULL) is a hard
failure; loosening is not checked here because the append-mode check already
covers it.
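
The comparison rules can be sketched with the standard library alone. In this
sketch ``normalize`` and ``check_column`` are hypothetical names, and the
synonym table is a small subset. Modifiers are stripped with string slicing
rather than a regex. ``acceptable_types`` takes precedence over
``postgres_type`` when both are present, which is an assumption; the format
above presents them as alternatives.

```python
from typing import Dict, List

# Small subset of the synonym table; unknown names fall back to lower case.
SYNONYMS: Dict[str, str] = {
    "INT": "integer", "INT4": "integer", "INTEGER": "integer",
    "INT8": "bigint", "BIGINT": "bigint",
    "TEXT": "text", "VARCHAR": "character varying",
    "CHARACTER VARYING": "character varying",
}


def normalize(pg_type: str) -> str:
    s = pg_type.strip().upper()
    # Drop every "(n)" / "(p,s)" modifier, e.g. "VARCHAR(10)" -> "VARCHAR".
    while "(" in s and ")" in s[s.index("("):]:
        start = s.index("(")
        end = s.index(")", start)
        s = s[:start] + s[end + 1:]
    s = " ".join(s.split())  # collapse whitespace left behind
    return SYNONYMS.get(s, s.lower())


def check_column(name: str, inferred_type: str, inferred_nullable: bool,
                 expected: dict) -> List[str]:
    problems: List[str] = []
    allowed = expected.get("acceptable_types") or [expected.get("postgres_type")]
    allowed = [normalize(t) for t in allowed if t]
    if allowed and normalize(inferred_type) not in allowed:
        problems.append(f"{name}: inferred {inferred_type}, expected one of {allowed}")
    # Only tightening fails: the manifest demands NOT NULL but inference allows NULL.
    if expected.get("nullable", True) is False and inferred_nullable:
        problems.append(f"{name}: inferred nullable but manifest requires NOT NULL")
    return problems


print(check_column("ID", "INTEGER", True, {"postgres_type": "INT4", "nullable": False}))
# ['ID: inferred nullable but manifest requires NOT NULL']
```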

5. Library usage
----------------

The CLI is a thin wrapper around composable functions. The preferred pattern
infers the schema from a bounded preview and then streams the whole file
chunk by chunk into ``COPY``, which is crucial for SAS files with hundreds of
millions of rows::

    from pathlib import Path

    from dotenv import load_dotenv
    from load_sas import (
        load_config, read_sas_preview, iter_sas_chunks, apply_column_filter,
        infer_schema, validate_against_manifest, render_create_table,
        connect, create_table, copy_dataframes,
    )

    load_dotenv()
    cfg = load_config("config.yaml")

    # Schema from a preview slice (bounded by TYPE_INFERENCE_SAMPLE_ROWS).
    preview_df, meta = read_sas_preview(cfg.filename)
    preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
    total_rows = getattr(meta, "number_rows", None)
    columns = infer_schema(preview_df, meta, total_rows=total_rows)

    # Optional: preview DDL / validate against a manifest.
    print(render_create_table(cfg.schemaname, cfg.tablename, columns))
    problems = validate_against_manifest(columns, Path("expected.json"))
    assert not problems, problems

    conn = connect()
    conn.autocommit = False  # psycopg2's default; stated explicitly for clarity
    try:
        create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists)
        # Chunks are filtered lazily, so only one chunk is in memory at a time.
        chunks = (
            apply_column_filter(df, cfg.include, cfg.exclude)
            for df, _ in iter_sas_chunks(cfg.filename)
        )
        rows = copy_dataframes(conn, cfg.schemaname, cfg.tablename, chunks, columns)
        conn.commit()
        print(f"loaded {rows:,} rows")
    finally:
        # Closing without a commit discards the open transaction.
        conn.close()

For small files (or tests) the legacy one-shot API still works:
:func:`read_sas` returns the whole frame and :func:`copy_dataframe` copies it
in one round trip.

Only :func:`connect`, :func:`create_table`, :func:`assert_schema_compatible`,
:func:`copy_dataframe`, and :func:`copy_dataframes` touch the database; the
other functions have no external side effects. Schema inference
(:func:`infer_schema`) accepts a ``coerce_chars`` kwarg that overrides the
module-level ``COERCE_CHAR_COLUMNS`` for that call only. The module flag is
rebound for the duration of the call and restored afterwards, so concurrent
calls from different threads should not rely on the override.
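
A per-call override is easiest to get right with a ``None`` sentinel. The
module-level value is then looked up when the function runs, instead of being
frozen into the parameter default when the ``def`` statement executes. Below
is a minimal sketch of the pattern; ``COERCE``, ``infer_frozen`` and ``infer``
are hypothetical stand-ins, not functions in this module.

```python
from typing import List, Optional

COERCE = True  # stand-in for a module-level tunable


def infer_frozen(values: List[str], coerce: bool = COERCE) -> str:
    # Pitfall: the default captured COERCE's value when ``def`` ran.
    return "INTEGER" if coerce and all(v.isdigit() for v in values) else "TEXT"


def infer(values: List[str], coerce: Optional[bool] = None) -> str:
    # None means "use the module-level knob as it is right now".
    effective = COERCE if coerce is None else coerce
    return "INTEGER" if effective and all(v.isdigit() for v in values) else "TEXT"


COERCE = False  # a caller turns the knob off after import
print(infer_frozen(["1", "2"]), infer(["1", "2"]), infer(["1", "2"], coerce=True))
# INTEGER TEXT INTEGER
```

Because ``effective`` is a local, this variant also leaves the module flag
untouched, so concurrent callers with different overrides cannot interfere.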

6. Type inference summary
-------------------------

Priority order used by :func:`infer_schema`:

1. SAS format string (via ``meta.original_variable_types``):
   ``DATETIME*`` -> ``TIMESTAMP``, ``TIME*`` -> ``TIME``,
   ``DATE*`` / ``YYMMDD*`` / ``MMDDYY*`` / ``DDMMYY*`` / ``JULIAN*`` -> ``DATE``.
2. All-null column -> ``TEXT`` (with a note).
3. pandas datetime dtype -> ``TIMESTAMP``.
4. Object columns containing only ``datetime.date`` / ``datetime.datetime``
   -> ``DATE`` or ``TIMESTAMP``.
5. Object columns of strings: if ``COERCE_CHAR_COLUMNS`` is True, at least
   ``CHAR_INFERENCE_MIN_VALUES`` non-empty values are present, and every one
   of them parses cleanly, the column is promoted to the first type that fits,
   tried in the order ``INTEGER`` / ``BIGINT``, ``DOUBLE PRECISION``,
   ``DATE``, ``TIMESTAMP``; otherwise ``TEXT``.
6. Numeric columns of whole numbers -> ``INTEGER`` (or ``BIGINT`` if any
   value falls outside the int32 range ``NUMERIC_INT_RANGE``); otherwise
   ``DOUBLE PRECISION``.
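
Rules 1, 2 and 6 can be sketched over plain Python values.
``type_from_format`` and ``type_from_numbers`` are hypothetical names. They
take a format string and a list of floats (``None`` marks a missing value)
instead of pyreadstat metadata and pandas series, and the remaining rules are
omitted.

```python
from typing import Optional, Sequence

INT32_RANGE = (-2_147_483_648, 2_147_483_647)
DATE_PREFIXES = ("DATE", "YYMMDD", "MMDDYY", "DDMMYY", "JULIAN")


def type_from_format(fmt: Optional[str]) -> Optional[str]:
    # Rule 1: the SAS display format decides when it is a date/time format.
    if not fmt:
        return None
    f = fmt.upper().lstrip()
    if f.startswith("DATETIME"):  # before DATE: "DATETIME20." starts with "DATE"
        return "TIMESTAMP"
    if f.startswith("TIME"):
        return "TIME"
    if f.startswith(DATE_PREFIXES):  # str.startswith accepts a tuple
        return "DATE"
    return None


def type_from_numbers(values: Sequence[Optional[float]]) -> str:
    # Rules 2 and 6 for a numeric column.
    nonnull = [v for v in values if v is not None]
    if not nonnull:
        return "TEXT"  # all-null column
    # inf % 1 is NaN, so infinities never count as whole numbers.
    if all(v % 1 == 0 for v in nonnull):
        lo, hi = INT32_RANGE
        return "INTEGER" if all(lo <= v <= hi for v in nonnull) else "BIGINT"
    return "DOUBLE PRECISION"


print(type_from_format("DATETIME20."), type_from_format("MMDDYY10."))  # TIMESTAMP DATE
print(type_from_numbers([1.0, None, 3_000_000_000.0]))  # BIGINT
```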

Type inference scans only the first ``TYPE_INFERENCE_SAMPLE_ROWS`` rows for
performance on large files. The CLI enforces this at read time via
:func:`read_sas_preview`, so the whole file is never materialized just to pick
types. Sampled specs carry ``sampled=True`` plus a "type inferred from first
N of M rows" note, and the usual tradeoffs apply: if the first N rows fit
``INTEGER`` but a later row exceeds int32, or a column had no nulls in the
preview but does later in the file, ``COPY`` will fail mid-stream and the
whole transaction rolls back. Set ``TYPE_INFERENCE_SAMPLE_ROWS = None`` to
scan every row when exact typing matters more than speed.
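
The failure mode is easy to reproduce in miniature: a column whose first
10,000 values fit int32, followed by one that does not. ``infer_int_type`` is a
hypothetical helper that applies only the integer-range rule to the first
``sample_rows`` values.

```python
from typing import List, Optional

INT32_LO, INT32_HI = -2_147_483_648, 2_147_483_647


def infer_int_type(values: List[int], sample_rows: Optional[int]) -> str:
    # sample_rows=None scans every value, like TYPE_INFERENCE_SAMPLE_ROWS = None.
    head = values if sample_rows is None else values[:sample_rows]
    return "INTEGER" if all(INT32_LO <= v <= INT32_HI for v in head) else "BIGINT"


# 10,000 small ids followed by one id beyond int32, past the preview window.
column = list(range(10_000)) + [3_000_000_000]
print(infer_int_type(column, 10_000))  # INTEGER: COPY would later hit the big id
print(infer_int_type(column, None))    # BIGINT
```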

Streaming loads use :func:`iter_sas_chunks` + :func:`copy_dataframes`, which
share one cursor and transaction so a failure mid-file rolls back the whole
load.
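
Minus the database, the control flow of a streaming load looks roughly like
this. ``stream_chunks`` is a hypothetical stand-in for ``copy_dataframes``. In
the sketch, a ``csv.writer`` over an ``io.StringIO`` plays the role of the
COPY target, and chunks are plain lists of row tuples instead of DataFrames.
Empty chunks are skipped and the total row count is returned. Because the
chunks come from a generator, only one is held in memory at a time.

```python
import csv
import io
from typing import Iterable, Iterator, List, Sequence, Tuple

Row = Tuple[object, ...]


def stream_chunks(chunks: Iterable[Sequence[Row]], sink: io.StringIO) -> int:
    writer = csv.writer(sink)  # one writer (like one cursor) for every chunk
    total = 0
    for chunk in chunks:
        if not chunk:
            continue  # empty chunks are skipped
        writer.writerows(chunk)
        total += len(chunk)
    return total


def fake_chunks() -> Iterator[List[Row]]:
    yield [(1, "a"), (2, "b")]
    yield []  # e.g. an empty trailing chunk
    yield [(3, None)]  # csv writes None as an empty field


buf = io.StringIO()
print(stream_chunks(fake_chunks(), buf))  # 3
```

Suppose the real function raises while producing or copying some later chunk.
The rows from earlier chunks are still sitting in the same uncommitted
transaction, which is why a rollback undoes the whole load.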

7. Tunables
-----------

Module-level knobs at the top of this file:

* ``COERCE_CHAR_COLUMNS`` - promote stringly-typed numerics / dates
  (default True).
* ``CHAR_INFERENCE_MIN_VALUES`` - minimum non-empty sample size before
  char-column coercion is attempted.
* ``NUMERIC_INT_RANGE`` - INTEGER bounds; values outside become
  ``BIGINT``.
* ``TYPE_INFERENCE_SAMPLE_ROWS`` - cap on rows read for type inference
  (``None`` = scan the whole column).
* ``DEFAULT_CHUNK_ROWS`` - rows per streaming COPY chunk.
"""
from __future__ import annotations
import argparse
import datetime as dt
import getpass
import io
import json
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
import pandas as pd
import psycopg2
import psycopg2.extensions
import pyreadstat
import yaml
from dotenv import load_dotenv
# ---------------------------------------------------------------------------
# Top-level tunables
# ---------------------------------------------------------------------------
COERCE_CHAR_COLUMNS = True
"""If True, try to promote object (string) columns to numeric/date/timestamp
when every non-empty value parses cleanly."""
CHAR_INFERENCE_MIN_VALUES = 3
"""Don't attempt character-column coercion with fewer than this many non-empty
values; too small a sample is easy to mis-infer."""
NUMERIC_INT_RANGE = (-2_147_483_648, 2_147_483_647)
"""INTEGER bounds; anything outside becomes BIGINT."""
TYPE_INFERENCE_SAMPLE_ROWS: Optional[int] = 10_000
"""Cap on rows inspected during per-column type inference. Also governs how
many rows :func:`read_sas_preview` pulls from the file for dry-run / validate /
schema-inference flows. Set to ``None`` to scan every row (and read the whole
file into memory for the preview step - don't do this on multi-hundred-million
row files)."""
DEFAULT_CHUNK_ROWS = 100_000
"""Rows per chunk when streaming a SAS file into ``COPY``. Larger values mean
fewer COPY round-trips but more peak memory per chunk; smaller values are
gentler on memory."""
VALID_IF_EXISTS = ("fail", "replace", "append")
# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------

@dataclass
class LoaderConfig:
    filename: Path
    schemaname: str
    tablename: str
    if_exists: str = "fail"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


@dataclass
class ColumnSpec:
    name: str
    postgres_type: str
    nullable: bool
    sas_format: Optional[str] = None
    source_dtype: Optional[str] = None
    notes: List[str] = field(default_factory=list)
    sampled: bool = False
    """True when the type was inferred from a bounded preview rather than the
    full file. A sampled spec carries the usual sampling risks: a later chunk
    could contain a value that exceeds the inferred integer range, doesn't
    parse as the inferred type, or is null in a column the preview showed as
    non-null - all of which surface as mid-``COPY`` failures."""


# ---------------------------------------------------------------------------
# Custom exceptions
# ---------------------------------------------------------------------------
class TableExistsError(RuntimeError):
    """Raised when if_exists=fail and the target table already exists."""


class SchemaCompatibilityError(RuntimeError):
    """Raised when if_exists=append and the incoming schema is not
    compatible with the existing table."""


class ValidationError(RuntimeError):
    """Raised when --validate detects a mismatch against the manifest."""


# ---------------------------------------------------------------------------
# Connection
# ---------------------------------------------------------------------------
def connect(
    *,
    user: Optional[str] = None,
    password: Optional[str] = None,
) -> psycopg2.extensions.connection:
    """Open a psycopg2 connection using standard libpq env vars.

    Assumes ``.env`` has already been loaded (the CLI does this before
    calling). Orchestrators that wrap this module should either call
    ``load_dotenv()`` themselves or ensure the env vars are set.

    ``user`` and ``password`` override the corresponding env vars when supplied
    as non-empty strings (used by the ``--dbcreds`` CLI flag to accept
    interactive input).
    """
    conn = psycopg2.connect(
        host=os.environ.get("PGHOST"),
        port=os.environ.get("PGPORT"),
        user=user or os.environ.get("PGUSER"),
        password=password or os.environ.get("PGPASSWORD"),
        dbname=os.environ.get("PGDATABASE"),
    )
    return conn


# ---------------------------------------------------------------------------
# Config loading
# ---------------------------------------------------------------------------
def load_config(path: Path) -> LoaderConfig:
    """Parse and validate the YAML config at ``path``."""
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    if not isinstance(raw, dict):
        raise ValueError(f"Config at {path} must be a YAML mapping at the top level.")

    missing = [k for k in ("filename", "schemaname", "tablename") if k not in raw]
    if missing:
        raise ValueError(f"Config {path} missing required keys: {', '.join(missing)}")

    filename = Path(raw["filename"])
    if not filename.is_absolute():
        # Prefer the config file's directory; if the file isn't there, keep
        # the relative path so it resolves against the working directory.
        candidate = path.parent / filename
        if candidate.exists():
            filename = candidate.resolve()

    schemaname = str(raw["schemaname"])
    tablename = str(raw["tablename"])

    if_exists = str(raw.get("if_exists", "fail")).lower()
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(
            f"Config {path}: if_exists={if_exists!r} is not one of {VALID_IF_EXISTS}"
        )

    include = raw.get("include")
    exclude = raw.get("exclude")
    if include is not None and exclude is not None:
        raise ValueError(
            f"Config {path}: 'include' and 'exclude' are mutually exclusive; set at most one."
        )
    if include is not None and not isinstance(include, list):
        raise ValueError(f"Config {path}: 'include' must be a list of column names.")
    if exclude is not None and not isinstance(exclude, list):
        raise ValueError(f"Config {path}: 'exclude' must be a list of column names.")

    return LoaderConfig(
        filename=filename,
        schemaname=schemaname,
        tablename=tablename,
        if_exists=if_exists,
        include=[str(c) for c in include] if include is not None else None,
        exclude=[str(c) for c in exclude] if exclude is not None else None,
    )


# ---------------------------------------------------------------------------
# Reader
# ---------------------------------------------------------------------------
def _sas_reader(path: Path) -> Tuple[Any, Dict[str, Any]]:
    """Return ``(pyreadstat_reader, extra_kwargs)`` for ``path``.

    Invariants (learned the hard way while building the sample generator):

    * ``.xpt`` / ``.xport`` - no encoding arg; pyreadstat is flaky about
      encoding on XPORT files it wrote itself.
    * ``.sas7bdat`` - explicit ``encoding="latin-1"`` per colleague guidance.
    """
    suffix = Path(path).suffix.lower()
    if suffix in (".xpt", ".xport"):
        return pyreadstat.read_xport, {}
    if suffix == ".sas7bdat":
        # Pass the encoding explicitly, as documented above; these kwargs also
        # flow through pyreadstat.read_file_in_chunks for streaming reads.
        return pyreadstat.read_sas7bdat, {"encoding": "latin-1"}
    raise ValueError(f"Unsupported SAS file extension: {suffix}")


def read_sas(path: Path) -> Tuple[pd.DataFrame, Any]:
    """Read an entire SAS file into memory. Only safe for small files.

    Kept for backward compatibility and tests; the CLI now uses
    :func:`read_sas_preview` + :func:`iter_sas_chunks` so it never materializes
    the whole frame at once.
    """
    reader, kwargs = _sas_reader(path)
    return reader(str(Path(path)), **kwargs)


def read_sas_preview(
    path: Path,
    *,
    rows: Optional[int] = None,
) -> Tuple[pd.DataFrame, Any]:
    """Read the first ``rows`` records from ``path`` plus its metadata.

    Defaults to ``TYPE_INFERENCE_SAMPLE_ROWS`` when ``rows`` is not given.
    Passing ``rows=None`` with ``TYPE_INFERENCE_SAMPLE_ROWS=None`` reads the
    whole file (pyreadstat treats ``row_limit=0`` as unlimited).
    """
    reader, kwargs = _sas_reader(path)
    effective = rows if rows is not None else TYPE_INFERENCE_SAMPLE_ROWS
    row_limit = int(effective) if effective else 0
    return reader(str(Path(path)), row_limit=row_limit, **kwargs)


def iter_sas_chunks(
    path: Path,
    *,
    chunksize: int = DEFAULT_CHUNK_ROWS,
):
    """Yield ``(df_chunk, meta)`` tuples for streaming loads.

    Thin wrapper over ``pyreadstat.read_file_in_chunks`` that picks the right
    underlying reader by extension and threads through our encoding defaults.
    """
    reader, kwargs = _sas_reader(path)
    yield from pyreadstat.read_file_in_chunks(
        reader, str(Path(path)), chunksize=chunksize, **kwargs
    )


# ---------------------------------------------------------------------------
# Column filtering
# ---------------------------------------------------------------------------
def apply_column_filter(
    df: pd.DataFrame,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
) -> pd.DataFrame:
    """Restrict ``df`` to the requested columns. Names missing from the frame
    raise a clear error rather than being silently ignored."""
    if include is not None and exclude is not None:
        raise ValueError("include and exclude are mutually exclusive.")
    if include is not None:
        missing = [c for c in include if c not in df.columns]
        if missing:
            raise ValueError(f"include references unknown columns: {missing}")
        return df.loc[:, list(include)].copy()
    if exclude is not None:
        missing = [c for c in exclude if c not in df.columns]
        if missing:
            raise ValueError(f"exclude references unknown columns: {missing}")
        return df.drop(columns=list(exclude)).copy()
    return df.copy()


# ---------------------------------------------------------------------------
# Type inference
# ---------------------------------------------------------------------------
_DATE_FORMAT_PREFIXES = ("DATE", "YYMMDD", "MMDDYY", "DDMMYY", "JULIAN")


def _format_driven_type(sas_format: Optional[str]) -> Optional[str]:
    """Return a Postgres type inferred from the SAS format string, or None
    if the format doesn't pin it down."""
    if not sas_format:
        return None
    fmt = sas_format.upper().lstrip()
    # DATETIME must be checked before DATE since "DATETIME20." starts with "DATE".
    if fmt.startswith("DATETIME"):
        return "TIMESTAMP"
    if fmt.startswith("TIME"):
        return "TIME"
    for prefix in _DATE_FORMAT_PREFIXES:
        if fmt.startswith(prefix):
            return "DATE"
    return None


def _all_null(series: pd.Series) -> bool:
    """True if every value is missing (None, NaN, or, for object columns,
    the empty string)."""
    if pd.api.types.is_object_dtype(series):
        return bool(_char_missing_mask(series).all())
    return bool(series.isna().all())


def _char_missing_mask(series: pd.Series) -> pd.Series:
    """Per-value missing mask for object columns: None, float NaN, or the
    empty string."""
    return series.map(
        lambda v: v is None
        or (isinstance(v, float) and pd.isna(v))
        or (isinstance(v, str) and v == "")
    )


def _is_nullable(series: pd.Series) -> bool:
    """True if the column has at least one missing value."""
    if pd.api.types.is_object_dtype(series):
        return bool(_char_missing_mask(series).any())
    return bool(series.isna().any())


def _numeric_int_target(series: pd.Series) -> Optional[str]:
    """Given a numeric (float64) series, if every non-null value is a whole
    number, return INTEGER or BIGINT depending on range; else None."""
    nonnull = series.dropna()
    if nonnull.empty:
        return None
    # Whole-number test. Infinities fail it because inf % 1 is NaN; the
    # TypeError guard covers dtypes that don't support modulo.
    try:
        whole = ((nonnull % 1) == 0).all()
    except TypeError:
        return None
    if not whole:
        return None
    lo, hi = NUMERIC_INT_RANGE
    vmin = nonnull.min()
    vmax = nonnull.max()
    if lo <= vmin and vmax <= hi:
        return "INTEGER"
    return "BIGINT"


def _object_is_dates(series: pd.Series) -> Tuple[bool, bool]:
    """Return (all-date-like, any-datetime). If every non-null value is a
    ``datetime.date`` / ``datetime.datetime`` / ``pd.Timestamp``, return True
    plus whether at least one carries a time component."""
    nonnull = series.dropna()
    if nonnull.empty:
        return False, False
    any_datetime = False
    for v in nonnull:
        if isinstance(v, dt.datetime) or isinstance(v, pd.Timestamp):
            any_datetime = True
            continue
        if isinstance(v, dt.date):
            continue
        return False, False
    return True, any_datetime


def _try_int_coerce(values: List[str]) -> Optional[str]:
    """If every value parses as an int, return INTEGER/BIGINT, else None."""
    ints: List[int] = []
    for v in values:
        s = v.strip()
        try:
            ints.append(int(s))
        except ValueError:
            return None
    if not ints:
        return None
    lo, hi = NUMERIC_INT_RANGE
    if all(lo <= i <= hi for i in ints):
        return "INTEGER"
    return "BIGINT"


def _try_float_coerce(values: List[str]) -> bool:
    for v in values:
        try:
            float(v)
        except ValueError:
            return False
    return True


def _try_date_coerce(values: List[str]) -> bool:
    for v in values:
        try:
            dt.date.fromisoformat(v)
        except (ValueError, TypeError):
            return False
    return True


def _try_datetime_coerce(values: List[str]) -> bool:
    for v in values:
        try:
            dt.datetime.fromisoformat(v)
        except (ValueError, TypeError):
            return False
    return True


def _infer_char_type(series: pd.Series) -> str:
    """Object/string column inference. Returns a Postgres type string."""
    mask = _char_missing_mask(series)
    nonempty = [str(v) for v in series[~mask].tolist()]
    if not COERCE_CHAR_COLUMNS or len(nonempty) < CHAR_INFERENCE_MIN_VALUES:
        return "TEXT"
    int_guess = _try_int_coerce(nonempty)
    if int_guess is not None:
        return int_guess
    if _try_float_coerce(nonempty):
        return "DOUBLE PRECISION"
    if _try_date_coerce(nonempty):
        return "DATE"
    if _try_datetime_coerce(nonempty):
        return "TIMESTAMP"
    return "TEXT"


def infer_schema(
    df: pd.DataFrame,
    meta: Any,
    *,
    coerce_chars: Optional[bool] = None,
    total_rows: Optional[int] = None,
) -> Dict[str, ColumnSpec]:
    """Infer a Postgres column spec for each column in ``df``.

    ``meta`` is the pyreadstat metadata object; we read
    ``meta.original_variable_types`` (a dict keyed by column name) for
    format-driven date/time/timestamp inference.

    ``coerce_chars`` overrides the module-level ``COERCE_CHAR_COLUMNS`` for
    this call; ``None`` (the default) means "use the module value at call
    time". The char-inference helpers read the module constant, so the
    override works by rebinding it for the duration of the call and restoring
    it in a ``finally`` block. That is not thread-safe; a full fix would
    thread the flag through the helpers, but the one-knob story here is
    intentional.

    ``total_rows`` lets callers who already sampled the frame (e.g. via
    :func:`read_sas_preview`) report the real file size in the per-column
    "inferred from first N of M rows" note. Falls back to ``len(df)``.
    """
    original_formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {})

    # Row-walking type probes run on a bounded head slice; nullability and the
    # all-null check still see every row so NOT NULL declarations stay honest.
    df_rows = len(df)
    effective_total = total_rows if total_rows is not None else df_rows
    if TYPE_INFERENCE_SAMPLE_ROWS is not None and df_rows > TYPE_INFERENCE_SAMPLE_ROWS:
        sample_df = df.head(TYPE_INFERENCE_SAMPLE_ROWS)
        sample_size = TYPE_INFERENCE_SAMPLE_ROWS
    else:
        sample_df = df
        sample_size = df_rows
    sampled = sample_size < effective_total

    # Temporarily rebind the module-level flag if the caller asked us to.
    # Resolving ``None`` here (instead of defaulting the parameter to
    # COERCE_CHAR_COLUMNS, which would freeze the value at import time)
    # honours changes made to the module flag after import.
    global COERCE_CHAR_COLUMNS
    saved = COERCE_CHAR_COLUMNS
    COERCE_CHAR_COLUMNS = saved if coerce_chars is None else coerce_chars
    try:
        out: Dict[str, ColumnSpec] = {}
        for col in df.columns:
            series = df[col]
            sample_series = sample_df[col]
            sas_format = original_formats.get(col)
            notes: List[str] = []

            pg_type = _format_driven_type(sas_format)
            if pg_type is None:
                if _all_null(series):
                    pg_type = "TEXT"
                    notes.append("all-null column; defaulting to TEXT")
                elif pd.api.types.is_datetime64_any_dtype(series):
                    pg_type = "TIMESTAMP"
                elif pd.api.types.is_object_dtype(series):
                    is_dates, any_dt = _object_is_dates(sample_series)
                    if is_dates:
                        pg_type = "TIMESTAMP" if any_dt else "DATE"
                    else:
                        pg_type = _infer_char_type(sample_series)
                elif pd.api.types.is_numeric_dtype(series):
                    int_target = _numeric_int_target(sample_series)
                    if int_target is not None:
                        pg_type = int_target
                    else:
                        pg_type = "DOUBLE PRECISION"
                else:
                    pg_type = "TEXT"
                    notes.append(f"unhandled dtype {series.dtype}; defaulting to TEXT")

            if sampled:
                notes.append(
                    f"type inferred from first {sample_size:,} of "
                    f"{effective_total:,} rows"
                )

            nullable = _is_nullable(series)
            out[col] = ColumnSpec(
                name=col,
                postgres_type=pg_type,
                nullable=nullable,
                sas_format=sas_format,
                source_dtype=str(series.dtype),
                notes=notes,
                sampled=sampled,
            )
        return out
    finally:
        COERCE_CHAR_COLUMNS = saved


# ---------------------------------------------------------------------------
# Table management
# ---------------------------------------------------------------------------
def _quote_ident(ident: str) -> str:
    """Quote a Postgres identifier by hand: double any embedded quotes and
    wrap the result in double quotes. psycopg2 2.7+ offers ``sql.Identifier``
    for this; doing it by hand keeps the driver usage simple."""
    return '"' + ident.replace('"', '""') + '"'


def _qualified(schema: str, table: str) -> str:
    return f"{_quote_ident(schema)}.{_quote_ident(table)}"


def _table_exists(conn, schema: str, table: str) -> bool:
    with conn.cursor() as cur:
        cur.execute(
            "SELECT 1 FROM information_schema.tables "
            "WHERE table_schema = %s AND table_name = %s",
            (schema, table),
        )
        return cur.fetchone() is not None


def render_create_table(schema: str, table: str, columns: Dict[str, ColumnSpec]) -> str:
    lines = []
    for spec in columns.values():
        null_clause = "" if spec.nullable else " NOT NULL"
        lines.append(f" {_quote_ident(spec.name)} {spec.postgres_type}{null_clause}")
    body = ",\n".join(lines)
    return f"CREATE TABLE {_qualified(schema, table)} (\n{body}\n);"


def _create_table_sql(conn, schema: str, table: str, columns: Dict[str, ColumnSpec]) -> None:
    sql = render_create_table(schema, table, columns)
    with conn.cursor() as cur:
        cur.execute(sql)


def _drop_table(conn, schema: str, table: str) -> None:
    with conn.cursor() as cur:
        cur.execute(f"DROP TABLE {_qualified(schema, table)}")


# Normalization table: map both loader-emitted and Postgres-reported type
# strings to a single canonical family name. Ignore length/precision
# modifiers like VARCHAR(n) and NUMERIC(p,s).
_TYPE_NORMALIZATION: Dict[str, str] = {
    "INTEGER": "integer",
    "INT": "integer",
    "INT4": "integer",
    "BIGINT": "bigint",
    "INT8": "bigint",
    "SMALLINT": "smallint",
    "INT2": "smallint",
    "DOUBLE PRECISION": "double precision",
    "FLOAT8": "double precision",
    "REAL": "real",
    "FLOAT4": "real",
    "NUMERIC": "numeric",
    "DECIMAL": "numeric",
    "TEXT": "text",
    "VARCHAR": "character varying",
    "CHARACTER VARYING": "character varying",
    "CHAR": "character",
    "CHARACTER": "character",
    "BPCHAR": "character",
    "BOOLEAN": "boolean",
    "BOOL": "boolean",
    "DATE": "date",
    "TIMESTAMP": "timestamp without time zone",
    "TIMESTAMP WITHOUT TIME ZONE": "timestamp without time zone",
    "TIMESTAMPTZ": "timestamp with time zone",
    "TIMESTAMP WITH TIME ZONE": "timestamp with time zone",
    "TIME": "time without time zone",
    "TIME WITHOUT TIME ZONE": "time without time zone",
    "TIMETZ": "time with time zone",
    "TIME WITH TIME ZONE": "time with time zone",
}


def _normalize_type(pg_type: str) -> str:
    """Strip length/precision modifiers and map to canonical family."""
    stripped = pg_type.strip().upper()
    # Remove trailing (n) / (p,s) before the space-separated tail.
    # Examples: "VARCHAR(10)" -> "VARCHAR";
    #           "TIMESTAMP(6) WITHOUT TIME ZONE" -> "TIMESTAMP WITHOUT TIME ZONE"
    import re

    stripped = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", stripped).strip()
    # Collapse doubled whitespace after paren removal.
    stripped = re.sub(r"\s+", " ", stripped)
    return _TYPE_NORMALIZATION.get(stripped, stripped.lower())


def _assert_schema_compatible(
    conn, schema: str, table: str, columns: Dict[str, ColumnSpec]
) -> None:
    """Pre-flight check for if_exists=append. See plan section on option B."""
    with conn.cursor() as cur:
        cur.execute(
            "SELECT column_name, data_type, is_nullable "
            "FROM information_schema.columns "
            "WHERE table_schema = %s AND table_name = %s",
            (schema, table),
        )
        existing = {row[0]: (row[1], row[2]) for row in cur.fetchall()}

    mismatches: List[str] = []
    warnings: List[str] = []
    for name, spec in columns.items():
        if name not in existing:
            mismatches.append(
                f"column {name!r} not present in target {schema}.{table}"
            )
            continue
        target_type, target_nullable = existing[name]
        inferred_norm = _normalize_type(spec.postgres_type)
        target_norm = _normalize_type(target_type)
        if inferred_norm != target_norm:
            mismatches.append(
                f"column {name!r}: inferred {spec.postgres_type} "
                f"(normalized {inferred_norm!r}) but target is {target_type} "
                f"(normalized {target_norm!r})"
            )
        target_is_notnull = (target_nullable == "NO")
        if spec.nullable and target_is_notnull:
            warnings.append(
                f"column {name!r}: incoming allows NULLs but target is NOT NULL; "
                "COPY will fail if any NULLs appear"
            )

    for w in warnings:
        print(f"[warn] {w}", file=sys.stderr)
    if mismatches:
        raise SchemaCompatibilityError(
            "append-mode schema compatibility check failed:\n - "
            + "\n - ".join(mismatches)
        )


def assert_schema_compatible(
    conn,
    schema_name: str,
    table_name: str,
    columns: Dict[str, ColumnSpec],
) -> None:
    """Public wrapper around :func:`_assert_schema_compatible`.

    Intended for orchestrators (e.g. the folder loader) that append multiple
    files into one table and need to re-run the same compatibility check
    that ``if_exists=append`` performs internally. Raises
    :class:`SchemaCompatibilityError` on mismatch.
    """
    _assert_schema_compatible(conn, schema_name, table_name, columns)


def create_table(
    conn,
    schema_name: str,
    table_name: str,
    columns: Dict[str, ColumnSpec],
    if_exists: str,
) -> None:
    """Create (or verify) the target table according to ``if_exists``."""
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(f"if_exists must be one of {VALID_IF_EXISTS}, got {if_exists!r}")
    exists = _table_exists(conn, schema_name, table_name)
    if exists:
        if if_exists == "fail":
            raise TableExistsError(
                f"Table {schema_name}.{table_name} already exists and if_exists=fail"
            )
        if if_exists == "replace":
            _drop_table(conn, schema_name, table_name)
            _create_table_sql(conn, schema_name, table_name, columns)
            return
        if if_exists == "append":
            _assert_schema_compatible(conn, schema_name, table_name, columns)
            return
    else:
        _create_table_sql(conn, schema_name, table_name, columns)


# ---------------------------------------------------------------------------
# COPY loading
# ---------------------------------------------------------------------------
def _seconds_to_time(v: Any) -> Optional[dt.time]:
    """Convert a SAS time value (seconds since midnight) to ``datetime.time``."""
    if v is None:
        return None
    if isinstance(v, float) and pd.isna(v):
        return None
    if isinstance(v, dt.time):
        return v
    if isinstance(v, (dt.datetime, pd.Timestamp)):
        return v.time() if not pd.isna(v) else None
    try:
        total = int(round(float(v)))
    except (TypeError, ValueError, OverflowError):  # OverflowError: +/-inf
        return None
    h, rem = divmod(total, 3600)
    m, s = divmod(rem, 60)
    # Clamp; TIME8. is always within a day.
    h = max(0, min(h, 23))
    return dt.time(h, m, s)


def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame:
    """Materialize a copy of ``df`` with each column in the right shape for
    ``to_csv`` so the CSV lands as valid input for the target Postgres type.
    """
    out = pd.DataFrame(index=df.index)
    for name, spec in columns.items():
        series = df[name]
        pg = spec.postgres_type.upper()
        if pg in ("INTEGER", "BIGINT", "SMALLINT"):
            if pd.api.types.is_object_dtype(series):
                series = pd.to_numeric(
                    series.replace({"": None}), errors="coerce"
                )
            out[name] = series.astype("Int64")
        elif pg in ("DOUBLE PRECISION", "REAL", "NUMERIC"):
            if pd.api.types.is_object_dtype(series):
                series = pd.to_numeric(
                    series.replace({"": None}), errors="coerce"
                )
            out[name] = series.astype("float64")
        elif pg == "DATE":
            if pd.api.types.is_datetime64_any_dtype(series):
                out[name] = series.dt.date
            elif pd.api.types.is_object_dtype(series):
                def _to_date(v: Any) -> Optional[dt.date]:
                    if v is None or (isinstance(v, float) and pd.isna(v)):
                        return None
                    if isinstance(v, dt.datetime):
                        return v.date()
                    if isinstance(v, dt.date):
                        return v
                    if isinstance(v, str):
                        if v == "":
                            return None
                        try:
                            return dt.date.fromisoformat(v)
                        except ValueError:
                            return None
                    return None
                out[name] = series.map(_to_date)
            else:
                out[name] = series
        elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"):
            if pd.api.types.is_datetime64_any_dtype(series):
                out[name] = series
            elif pd.api.types.is_object_dtype(series):
                def _to_dt(v: Any) -> Optional[dt.datetime]:
                    if v is None or (isinstance(v, float) and pd.isna(v)):
                        return None
                    if isinstance(v, dt.datetime):
                        return v
                    if isinstance(v, dt.date):
                        return dt.datetime(v.year, v.month, v.day)
                    if isinstance(v, pd.Timestamp):
                        return v.to_pydatetime() if not pd.isna(v) else None
                    if isinstance(v, str):
                        if v == "":
                            return None
                        try:
                            return dt.datetime.fromisoformat(v)
                        except ValueError:
                            return None
                    return None
                out[name] = series.map(_to_dt)
            else:
                out[name] = series
        elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"):
            out[name] = series.map(_seconds_to_time)
        elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"):
            # Leave empty strings as "" so `NULL ''` in COPY turns them into NULL.
            def _to_str(v: Any) -> Any:
                if v is None:
                    return ""
                if isinstance(v, float) and pd.isna(v):
                    return ""
                return str(v)
            out[name] = series.map(_to_str)
        elif pg == "BOOLEAN":
            out[name] = series.astype("boolean") if series.dtype != object else series
        else:
            out[name] = series
    return out


def copy_dataframes(
    conn,
    schema_name: str,
    table_name: str,
    dfs: Iterable[pd.DataFrame],
    columns: Dict[str, ColumnSpec],
) -> int:
    """Stream an iterable of DataFrames into one table via ``COPY``.

    Each non-empty chunk is sent as its own ``COPY ... FROM STDIN``
    statement, but all chunks share one cursor and one transaction. With
    autocommit off, a failure mid-stream leaves nothing committed, so the
    caller can roll back the whole load. Empty chunks are skipped. Returns
    the total number of rows inserted.
    """
    col_list = ", ".join(_quote_ident(name) for name in columns.keys())
    sql = (
        f"COPY {_qualified(schema_name, table_name)} ({col_list}) "
        f"FROM STDIN WITH (FORMAT csv, NULL '')"
    )
    total = 0
    with conn.cursor() as cur:
        for df in dfs:
            if df.empty:
                continue
            prepared = _prepare_for_copy(df, columns)
            buf = io.StringIO()
            prepared.to_csv(
                buf,
                index=False,
                header=False,
                na_rep="",
                date_format="%Y-%m-%d %H:%M:%S",
            )
            buf.seek(0)
            cur.copy_expert(sql, buf)
            total += len(prepared)
    return total
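

# Illustrative sketch (hypothetical helper, not used by the loader): what a
# chunk looks like on the wire for ``COPY ... WITH (FORMAT csv, NULL '')``.
# It assumes rows arrive as plain tuples and uses the stdlib ``csv`` module in
# place of ``DataFrame.to_csv``. ``None`` is written as an unquoted empty
# field, which COPY reads as NULL; an empty string in a multi-column row is
# also written unquoted, so it loads as NULL too. Only a quoted empty field
# ("") would load as an empty string.
def _example_copy_csv_payload(rows):
    import csv
    import io

    buf = io.StringIO()
    # Default QUOTE_MINIMAL: fields are quoted only when they contain the
    # delimiter, the quote character, or a line-break character.
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerows(rows)
    return buf.getvalue()


# e.g. _example_copy_csv_payload([(1, "a", None), (2, "", 3.5)]) returns
# "1,a,\n2,,3.5\n": row 1's third field and row 2's second field both load
# as NULL.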


def copy_dataframe(
    conn,
    schema_name: str,
    table_name: str,
    df: pd.DataFrame,
    columns: Dict[str, ColumnSpec],
) -> int:
    """Stream ``df`` into Postgres via ``COPY ... FROM STDIN``.

    Convenience wrapper around :func:`copy_dataframes` for single-frame
    callers. Returns the number of rows inserted.
    """
    return copy_dataframes(conn, schema_name, table_name, [df], columns)


# ---------------------------------------------------------------------------
# Manifest validation
# ---------------------------------------------------------------------------


def _match_manifest_type(inferred: str, manifest_entry: Dict[str, Any]) -> bool:
    inferred_norm = _normalize_type(inferred)
    if "postgres_type" in manifest_entry:
        return inferred_norm == _normalize_type(manifest_entry["postgres_type"])
    if "acceptable_types" in manifest_entry:
        return any(
            inferred_norm == _normalize_type(t)
            for t in manifest_entry["acceptable_types"]
        )
    return False


def validate_against_manifest(
    inferred: Dict[str, ColumnSpec],
    manifest_path: Path,
) -> List[str]:
    """Compare the inferred schema against the expected-types manifest.

    Returns a list of human-readable problem strings; empty list means OK.
    """
    manifest_path = Path(manifest_path)
    if not manifest_path.exists():
        return [f"manifest not found: {manifest_path}"]
    with manifest_path.open("r", encoding="utf-8") as f:
        manifest = json.load(f)
    problems: List[str] = []
    only_in_inferred = set(inferred) - set(manifest)
    only_in_manifest = set(manifest) - set(inferred)
    if only_in_inferred:
        problems.append(
            f"columns in inferred but not manifest: {sorted(only_in_inferred)}"
        )
    if only_in_manifest:
        problems.append(
            f"columns in manifest but not inferred: {sorted(only_in_manifest)}"
        )
    for name, spec in inferred.items():
        entry = manifest.get(name)
        if entry is None:
            continue
        if not _match_manifest_type(spec.postgres_type, entry):
            expected = entry.get("postgres_type") or entry.get("acceptable_types")
            problems.append(
                f"column {name!r}: inferred {spec.postgres_type!r}, "
                f"manifest expected {expected!r}"
            )
        manifest_nullable = bool(entry.get("nullable", True))
        if spec.nullable and not manifest_nullable:
            problems.append(
                f"column {name!r}: inferred nullable, manifest expects NOT NULL "
                f"(loosening nullability is never allowed)"
            )
    return problems
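

# Illustrative sketch (hypothetical helper, not part of the loader's API):
# building a manifest dict in the shape ``validate_against_manifest`` reads.
# Each column maps to either a single ``postgres_type`` or a list of
# ``acceptable_types``, plus an optional ``nullable`` flag that the validator
# treats as true when absent. The input shape (name -> (types, nullable)) is
# an assumption made for this example; json.dump the result to write the
# ``<stem>.expected.json`` file.
def _example_build_manifest(spec):
    manifest = {}
    for name, (types, nullable) in spec.items():
        if isinstance(types, str):
            entry = {"postgres_type": types}
        else:
            entry = {"acceptable_types": list(types)}
        # Record the flag only when it is False; a missing key already means
        # "nullable" to the validator.
        if not nullable:
            entry["nullable"] = False
        manifest[name] = entry
    return manifest


# e.g. _example_build_manifest({"ID": ("BIGINT", False),
#     "SCORE": (["DOUBLE PRECISION", "NUMERIC"], True)})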


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def _build_argparser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(
        description="Load a single SAS file (XPT or sas7bdat) into Postgres.",
    )
    p.add_argument("--config", required=True, type=Path, help="Path to YAML config")
    p.add_argument(
        "--validate",
        action="store_true",
        help=(
            "Compare inferred schema against <filename-stem>.expected.json "
            "next to the SAS file; exits nonzero on mismatch."
        ),
    )
    p.add_argument(
        "--dry-run",
        action="store_true",
        help="Print inferred CREATE TABLE and stop; don't touch Postgres.",
    )
    p.add_argument(
        "--dbcreds",
        action="store_true",
        help=(
            "Prompt for database username and password instead of reading "
            "PGUSER / PGPASSWORD from the environment or .env file."
        ),
    )
    return p


def _format_columns_summary(columns: Dict[str, ColumnSpec]) -> str:
    lines = []
    for spec in columns.values():
        null = "" if spec.nullable else " NOT NULL"
        lines.append(f" {spec.name}: {spec.postgres_type}{null}")
    return "\n".join(lines)


def main(argv: Optional[List[str]] = None) -> int:
    args = _build_argparser().parse_args(argv)
    load_dotenv()
    cfg = load_config(args.config)
    if not cfg.filename.exists():
        print(f"error: SAS file not found: {cfg.filename}", file=sys.stderr)
        return 2
    # Schema inference uses a bounded preview read so we never load a
    # hundreds-of-millions-of-rows file into memory just to pick types.
    # NB: ``meta.number_rows`` on a ``row_limit``-ed read reflects rows
    # returned, not the file's total, so we don't trust it here.
    preview_df, meta = read_sas_preview(cfg.filename)
    preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
    columns = infer_schema(preview_df, meta)

    if args.validate:
        # The manifest sits next to the SAS file and is named after its full
        # stem, e.g. "sample_kitchensink.xpt" ->
        # "sample_kitchensink.expected.json", "data.v2.xpt" ->
        # "data.v2.expected.json".
        manifest_path = cfg.filename.with_name(
            cfg.filename.stem + ".expected.json"
        )
        problems = validate_against_manifest(columns, manifest_path)
        if problems:
            print("validation failed:", file=sys.stderr)
            for p in problems:
                print(f" - {p}", file=sys.stderr)
            return 1
        print(f"validation OK ({len(columns)} columns match {manifest_path.name})")
    if args.dry_run:
        print(render_create_table(cfg.schemaname, cfg.tablename, columns))
        return 0
    # Drop the preview frame before streaming so its memory can be reclaimed
    # instead of being held for the whole load (and the open transaction).
    del preview_df

    def _filtered_chunks():
        seen = 0
        for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename):
            chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude)
            seen += len(chunk_df)
            print(f" streaming... {seen:,} rows", file=sys.stderr)
            yield chunk_df

    db_user = db_password = None
    if args.dbcreds:
        db_user = input("Database username: ")
        db_password = getpass.getpass("Database password: ")
    conn = connect(user=db_user, password=db_password)
    conn.autocommit = False
    try:
        create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists)
        inserted = copy_dataframes(
            conn, cfg.schemaname, cfg.tablename, _filtered_chunks(), columns
        )
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
    print(
        f"loaded {inserted} rows into {cfg.schemaname}.{cfg.tablename} "
        f"({len(columns)} columns)"
    )
    print("final schema:")
    print(_format_columns_summary(columns))
    return 0
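

# Illustrative sketch (hypothetical helper): deriving the manifest path that
# ``--validate`` looks for from the SAS file path. Building the name from
# ``Path.stem`` keeps every dot in the stem, whereas chaining
# ``path.with_suffix("").with_suffix(".expected.json")`` replaces the stem's
# last dotted segment when the stem itself contains a dot
# ("data.v2.xpt" -> "data.expected.json").
def _example_manifest_path_for(sas_path):
    from pathlib import Path

    p = Path(sas_path)
    # stem drops only the final suffix: "data.v2.xpt" -> "data.v2".
    return p.with_name(p.stem + ".expected.json")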


if __name__ == "__main__":
    sys.exit(main())