Adding generic loader

This commit is contained in:
michael-corey 2026-04-18 09:34:48 -05:00
commit f681f1012a
6 changed files with 1315 additions and 0 deletions

View File

@ -0,0 +1,5 @@
PGHOST=localhost
PGPORT=5432
PGUSER=
PGPASSWORD=
PGDATABASE=

3
generic_loader/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
/.venv
/samples
/.env

View File

@ -0,0 +1,380 @@
"""Generate a kitchen-sink SAS XPORT file plus an expected-types manifest.
Running this script produces two files under samples/:
- sample_kitchensink.xpt the SAS XPORT test fixture
- sample_kitchensink.expected.json ground-truth Postgres types for the loader
Tune behavior via the top-level constants below.
"""
from __future__ import annotations
import datetime as dt
import json
import string
from pathlib import Path
import numpy as np
import pandas as pd
import pyreadstat
SEED = 42
N_ROWS = 1000
NULL_FRACTION = 0.20
OUT_DIR = Path("samples")
OUT_PATH = OUT_DIR / "sample_kitchensink.xpt"
MANIFEST_PATH = OUT_DIR / "sample_kitchensink.expected.json"
POSITIVE_CONTROLS = {"ID", "INTCOL", "STRCOL", "DATECOL", "CONST"}
ALL_NULL_COLS = {"ALLNULL", "ALLNULLC"}
def _missing_mask(rng: np.random.Generator, n: int, frac: float) -> np.ndarray:
"""Return a boolean array of length n with exactly round(frac * n) True positions.
Using an exact count (rather than per-row Bernoulli draws) keeps the observed
missing fraction tight so the round-trip assertion can use a small tolerance.
"""
mask = np.zeros(n, dtype=bool)
k = int(round(frac * n))
if k > 0:
idx = rng.choice(n, size=k, replace=False)
mask[idx] = True
return mask
def _random_word(rng: np.random.Generator, min_len: int = 3, max_len: int = 10) -> str:
length = int(rng.integers(min_len, max_len + 1))
letters = np.array(list(string.ascii_lowercase))
return "".join(rng.choice(letters, size=length))
def _random_sentence(rng: np.random.Generator, min_words: int = 8, max_words: int = 20) -> str:
n_words = int(rng.integers(min_words, max_words + 1))
return " ".join(_random_word(rng) for _ in range(n_words))
def build_dataframe(rng: np.random.Generator) -> pd.DataFrame:
n = N_ROWS
ids = np.arange(1, n + 1, dtype=np.int64)
int_vals = rng.integers(0, 1000, size=n).astype(np.float64)
bigint_vals = rng.integers(10_000_000_000, 20_000_000_000, size=n).astype(np.float64)
bigint_vals[_missing_mask(rng, n, NULL_FRACTION)] = np.nan
float_vals = rng.normal(loc=100.0, scale=15.0, size=n)
float_vals[_missing_mask(rng, n, NULL_FRACTION)] = np.nan
bool_vals = rng.integers(0, 2, size=n).astype(np.float64)
bool_vals[_missing_mask(rng, n, NULL_FRACTION)] = np.nan
str_vals = [_random_word(rng, 3, 8) for _ in range(n)]
long_str_vals: list[str] = []
long_mask = _missing_mask(rng, n, NULL_FRACTION)
for i in range(n):
long_str_vals.append("" if long_mask[i] else _random_sentence(rng))
base_date = dt.date(2020, 1, 1)
date_vals = [base_date + dt.timedelta(days=int(rng.integers(0, 2000))) for _ in range(n)]
dt_vals_mask = _missing_mask(rng, n, NULL_FRACTION)
dt_vals: list = []
base_dt = dt.datetime(2020, 1, 1, 0, 0, 0)
for i in range(n):
if dt_vals_mask[i]:
dt_vals.append(pd.NaT)
else:
offset_seconds = int(rng.integers(0, 2000 * 24 * 3600))
dt_vals.append(base_dt + dt.timedelta(seconds=offset_seconds))
dt_series = pd.to_datetime(dt_vals)
time_mask = _missing_mask(rng, n, NULL_FRACTION)
time_vals: list = []
for i in range(n):
if time_mask[i]:
time_vals.append(None)
else:
seconds_into_day = int(rng.integers(0, 24 * 3600))
h, rem = divmod(seconds_into_day, 3600)
m, s = divmod(rem, 60)
time_vals.append(dt.time(h, m, s))
numasstr_mask = _missing_mask(rng, n, NULL_FRACTION)
numasstr_vals: list[str] = []
for i in range(n):
if numasstr_mask[i]:
numasstr_vals.append("")
elif rng.random() < 0.5:
numasstr_vals.append(str(int(rng.integers(-500, 500))))
else:
numasstr_vals.append(f"{rng.normal(0, 50):.2f}")
dateasstr_mask = _missing_mask(rng, n, NULL_FRACTION)
dateasstr_vals: list[str] = []
for i in range(n):
if dateasstr_mask[i]:
dateasstr_vals.append("")
else:
d = base_date + dt.timedelta(days=int(rng.integers(0, 2000)))
dateasstr_vals.append(d.isoformat())
mixed_mask = _missing_mask(rng, n, NULL_FRACTION)
mixed_vals: list[str] = []
choices = ["number", "date", "text", "text"]
for i in range(n):
if mixed_mask[i]:
mixed_vals.append("")
continue
kind = choices[int(rng.integers(0, len(choices)))]
if kind == "number":
mixed_vals.append(str(int(rng.integers(0, 1000))))
elif kind == "date":
d = base_date + dt.timedelta(days=int(rng.integers(0, 2000)))
mixed_vals.append(d.isoformat())
else:
mixed_vals.append(_random_word(rng, 4, 12))
const_vals = ["CONSTANT"] * n
allnull_vals = np.full(n, np.nan, dtype=np.float64)
allnullc_vals = [""] * n
df = pd.DataFrame(
{
"ID": ids,
"INTCOL": int_vals,
"BIGINT": bigint_vals,
"FLOATCOL": float_vals,
"BOOLCOL": bool_vals,
"STRCOL": str_vals,
"LONGSTR": long_str_vals,
"DATECOL": date_vals,
"DTCOL": dt_series,
"TIMECOL": time_vals,
"NUMASSTR": numasstr_vals,
"DATEASTR": dateasstr_vals,
"MIXED": mixed_vals,
"CONST": const_vals,
"ALLNULL": allnull_vals,
"ALLNULLC": allnullc_vals,
}
)
return df
COLUMN_LABELS: dict[str, str] = {
"ID": "Row identifier",
"INTCOL": "Integer positive control",
"BIGINT": "Big integer beyond int32 range",
"FLOATCOL": "Floating point with decimals",
"BOOLCOL": "Nullable boolean 0/1/NaN",
"STRCOL": "Short string positive control",
"LONGSTR": "Longer free-text string",
"DATECOL": "Date positive control",
"DTCOL": "Datetime with missing values",
"TIMECOL": "Time of day with missing values",
"NUMASSTR": "Numeric-looking strings in a char column",
"DATEASTR": "Date-looking strings in a char column",
"MIXED": "Heterogeneous strings: fallback to text",
"CONST": "Constant repeated value",
"ALLNULL": "Entirely missing numeric column",
"ALLNULLC": "Entirely missing character column",
}
VARIABLE_FORMATS: dict[str, str] = {
"DATECOL": "DATE9.",
"DTCOL": "DATETIME20.",
"TIMECOL": "TIME8.",
}
EXPECTED_MANIFEST: dict[str, dict] = {
"ID": {"postgres_type": "INTEGER", "nullable": False},
"INTCOL": {"postgres_type": "INTEGER", "nullable": False, "note": "positive control"},
"BIGINT": {"postgres_type": "BIGINT", "nullable": True, "note": "values beyond int32 range"},
"FLOATCOL": {"acceptable_types": ["DOUBLE PRECISION", "NUMERIC"], "nullable": True},
"BOOLCOL": {
"acceptable_types": ["BOOLEAN", "SMALLINT", "INTEGER"],
"nullable": True,
"note": "{0,1,NaN} is genuinely ambiguous; loader's choice is a design decision",
},
"STRCOL": {"acceptable_types": ["TEXT", "VARCHAR"], "nullable": False, "note": "positive control"},
"LONGSTR": {"acceptable_types": ["TEXT", "VARCHAR"], "nullable": True},
"DATECOL": {"postgres_type": "DATE", "nullable": False, "note": "positive control"},
"DTCOL": {"acceptable_types": ["TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE"], "nullable": True},
"TIMECOL": {"postgres_type": "TIME", "nullable": True},
"NUMASSTR": {
"acceptable_types": ["NUMERIC", "DOUBLE PRECISION"],
"nullable": True,
"note": "stored as char in SAS; loader should coerce numeric-looking strings",
},
"DATEASTR": {
"postgres_type": "DATE",
"nullable": True,
"note": "stored as char in SAS; loader should coerce ISO-date strings",
},
"MIXED": {
"acceptable_types": ["TEXT", "VARCHAR"],
"nullable": True,
"note": "heterogeneous content; loader should fall back to text",
},
"CONST": {"acceptable_types": ["TEXT", "VARCHAR"], "nullable": False},
"ALLNULL": {
"acceptable_types": ["TEXT", "VARCHAR"],
"nullable": True,
"note": "entirely null numeric; loader must pick a default type, typically TEXT",
},
"ALLNULLC": {
"acceptable_types": ["TEXT", "VARCHAR"],
"nullable": True,
"note": "entirely null character",
},
}
def write_manifest(df: pd.DataFrame) -> None:
manifest_cols = set(EXPECTED_MANIFEST.keys())
df_cols = set(df.columns)
missing = df_cols - manifest_cols
extra = manifest_cols - df_cols
if missing or extra:
raise AssertionError(
f"Manifest/DataFrame column mismatch. Missing from manifest: {missing}. "
f"Extra in manifest: {extra}."
)
with MANIFEST_PATH.open("w", encoding="utf-8") as f:
json.dump(EXPECTED_MANIFEST, f, indent=2, sort_keys=True)
f.write("\n")
def _char_missing_fraction(series: pd.Series) -> float:
return float((series.fillna("").astype(str) == "").mean())
def _numeric_missing_fraction(series: pd.Series) -> float:
return float(series.isna().mean())
def verify_roundtrip(source_df: pd.DataFrame) -> pd.DataFrame:
# Use pyreadstat (the writer) to verify the writer's output. pyreadstat preserves
# SAS format metadata on readback, so we can confirm the date/datetime/time
# variable_format mappings actually took effect.
readback, _meta = pyreadstat.read_xport(str(OUT_PATH))
assert len(readback.columns) == len(source_df.columns), (
f"Column count mismatch: wrote {len(source_df.columns)}, read back {len(readback.columns)}"
)
assert set(readback.columns) == set(source_df.columns), (
f"Column name mismatch. Only in source: {set(source_df.columns) - set(readback.columns)}. "
f"Only in readback: {set(readback.columns) - set(source_df.columns)}."
)
assert len(readback) == len(source_df), (
f"Row count mismatch: wrote {len(source_df)}, read back {len(readback)}"
)
for col in ("DATECOL", "DTCOL"):
dtype = readback[col].dtype
is_datetime = pd.api.types.is_datetime64_any_dtype(dtype)
is_object_of_dates = pd.api.types.is_object_dtype(dtype) and readback[col].dropna().map(
lambda v: isinstance(v, (dt.date, dt.datetime, pd.Timestamp))
).all()
assert is_datetime or is_object_of_dates, (
f"{col} came back as {dtype}; expected datetime-like. "
f"variable_format mapping may not have taken effect."
)
time_dtype = readback["TIMECOL"].dtype
time_ok = (
pd.api.types.is_datetime64_any_dtype(time_dtype)
or pd.api.types.is_numeric_dtype(time_dtype)
or (
pd.api.types.is_object_dtype(time_dtype)
and readback["TIMECOL"].dropna().map(
lambda v: isinstance(v, (dt.time, dt.datetime, pd.Timestamp, int, float))
).all()
)
)
assert time_ok, f"TIMECOL came back as {time_dtype}; expected datetime/numeric/time-object"
tol = 0.10
for col in source_df.columns:
if col in POSITIVE_CONTROLS:
series = readback[col]
if pd.api.types.is_numeric_dtype(series) or pd.api.types.is_datetime64_any_dtype(series):
observed = _numeric_missing_fraction(series)
else:
observed = _char_missing_fraction(series)
assert observed == 0.0, (
f"Positive control {col!r} has {observed:.2%} missing; expected 0%."
)
continue
if col in ALL_NULL_COLS:
series = readback[col]
if pd.api.types.is_numeric_dtype(series):
observed = _numeric_missing_fraction(series)
else:
observed = _char_missing_fraction(series)
assert observed == 1.0, (
f"All-null column {col!r} has {observed:.2%} missing; expected 100%."
)
continue
series = readback[col]
if pd.api.types.is_numeric_dtype(series) or pd.api.types.is_datetime64_any_dtype(series):
observed = _numeric_missing_fraction(series)
else:
observed = _char_missing_fraction(series)
assert abs(observed - NULL_FRACTION) <= tol, (
f"Column {col!r}: observed missing fraction {observed:.2%} not within "
f"±{tol:.0%} of NULL_FRACTION={NULL_FRACTION:.2%}."
)
assert MANIFEST_PATH.exists(), f"Manifest file {MANIFEST_PATH} missing."
with MANIFEST_PATH.open("r", encoding="utf-8") as f:
manifest = json.load(f)
assert set(manifest.keys()) == set(readback.columns), (
f"Manifest/readback column set mismatch. "
f"Only in manifest: {set(manifest.keys()) - set(readback.columns)}. "
f"Only in readback: {set(readback.columns) - set(manifest.keys())}."
)
return readback
def main() -> None:
OUT_DIR.mkdir(parents=True, exist_ok=True)
rng = np.random.default_rng(SEED)
df = build_dataframe(rng)
pyreadstat.write_xport(
df,
str(OUT_PATH),
file_format_version=5,
table_name="SAMPLE",
file_label="Kitchen sink sample for loader testing",
column_labels=COLUMN_LABELS,
variable_format=VARIABLE_FORMATS,
)
write_manifest(df)
readback = verify_roundtrip(df)
print(f"Wrote {OUT_PATH} ({N_ROWS} rows x {len(df.columns)} cols)")
print(f"Wrote {MANIFEST_PATH}")
print()
print("Readback via pyreadstat.read_xport (same reader the loader will use):")
print(readback.dtypes.to_string())
print()
print("Readback head:")
print(readback.head().to_string())
if __name__ == "__main__":
main()

904
generic_loader/load_sas.py Normal file
View File

@ -0,0 +1,904 @@
"""Per-file SAS-to-Postgres loader.
Library-style functions plus a thin CLI wrapper. Designed so an orchestrator
can wrap the library for directory/batch mode; orchestration is out of scope
here.
Python 3.9 compatible (target is an air-gapped host that currently only has
3.9). ``from __future__ import annotations`` lets us use PEP 585 generics
as annotations; runtime-resolved type uses (dataclass defaults, etc.) stick
to ``typing``.
"""
from __future__ import annotations
import argparse
import datetime as dt
import io
import json
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import pandas as pd
import psycopg2
import psycopg2.extensions
import pyreadstat
import yaml
from dotenv import load_dotenv
# ---------------------------------------------------------------------------
# Top-level tunables
# ---------------------------------------------------------------------------
COERCE_CHAR_COLUMNS = True
"""If True, try to promote object (string) columns to numeric/date/timestamp
when every non-empty value parses cleanly."""
CHAR_INFERENCE_MIN_VALUES = 3
"""Don't attempt character-column coercion with fewer than this many non-empty
values; too small a sample is easy to mis-infer."""
NUMERIC_INT_RANGE = (-2_147_483_648, 2_147_483_647)
"""INTEGER bounds; anything outside becomes BIGINT."""
VALID_IF_EXISTS = ("fail", "replace", "append")
# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------
@dataclass
class LoaderConfig:
filename: Path
schemaname: str
tablename: str
if_exists: str = "fail"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
@dataclass
class ColumnSpec:
name: str
postgres_type: str
nullable: bool
sas_format: Optional[str] = None
source_dtype: Optional[str] = None
notes: List[str] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Custom exceptions
# ---------------------------------------------------------------------------
class TableExistsError(RuntimeError):
"""Raised when if_exists=fail and the target table already exists."""
class SchemaCompatibilityError(RuntimeError):
"""Raised when if_exists=append and the incoming schema is not
compatible with the existing table."""
class ValidationError(RuntimeError):
"""Raised when --validate detects a mismatch against the manifest."""
# ---------------------------------------------------------------------------
# Connection
# ---------------------------------------------------------------------------
def connect() -> psycopg2.extensions.connection:
"""Open a psycopg2 connection using standard libpq env vars.
Assumes `.env` has already been loaded (the CLI does this before calling).
Orchestrators that wrap this module should either call ``load_dotenv()``
themselves or ensure the env vars are set.
"""
conn = psycopg2.connect(
host=os.environ.get("PGHOST"),
port=os.environ.get("PGPORT"),
user=os.environ.get("PGUSER"),
password=os.environ.get("PGPASSWORD"),
dbname=os.environ.get("PGDATABASE"),
)
return conn
# ---------------------------------------------------------------------------
# Config loading
# ---------------------------------------------------------------------------
def load_config(path: Path) -> LoaderConfig:
"""Parse and validate the YAML config at ``path``."""
path = Path(path)
with path.open("r", encoding="utf-8") as f:
raw = yaml.safe_load(f)
if not isinstance(raw, dict):
raise ValueError(f"Config at {path} must be a YAML mapping at the top level.")
missing = [k for k in ("filename", "schemaname", "tablename") if k not in raw]
if missing:
raise ValueError(f"Config {path} missing required keys: {', '.join(missing)}")
filename = Path(raw["filename"])
if not filename.is_absolute():
filename = (path.parent / filename).resolve() if (path.parent / filename).exists() else Path(raw["filename"])
schemaname = str(raw["schemaname"])
tablename = str(raw["tablename"])
if_exists = str(raw.get("if_exists", "fail")).lower()
if if_exists not in VALID_IF_EXISTS:
raise ValueError(
f"Config {path}: if_exists={if_exists!r} is not one of {VALID_IF_EXISTS}"
)
include = raw.get("include")
exclude = raw.get("exclude")
if include is not None and exclude is not None:
raise ValueError(
f"Config {path}: 'include' and 'exclude' are mutually exclusive; set at most one."
)
if include is not None and not isinstance(include, list):
raise ValueError(f"Config {path}: 'include' must be a list of column names.")
if exclude is not None and not isinstance(exclude, list):
raise ValueError(f"Config {path}: 'exclude' must be a list of column names.")
return LoaderConfig(
filename=filename,
schemaname=schemaname,
tablename=tablename,
if_exists=if_exists,
include=[str(c) for c in include] if include is not None else None,
exclude=[str(c) for c in exclude] if exclude is not None else None,
)
# ---------------------------------------------------------------------------
# Reader
# ---------------------------------------------------------------------------
def read_sas(path: Path) -> Tuple[pd.DataFrame, Any]:
"""Dispatch to the right pyreadstat reader by extension.
Invariants (learned the hard way while building the sample generator):
* ``.xpt`` / ``.xport`` - no encoding arg; pyreadstat is flaky about
encoding on XPORT files it wrote itself.
* ``.sas7bdat`` - explicit ``encoding="latin-1"`` per colleague guidance.
"""
path = Path(path)
suffix = path.suffix.lower()
if suffix in (".xpt", ".xport"):
return pyreadstat.read_xport(str(path))
if suffix == ".sas7bdat":
return pyreadstat.read_sas7bdat(str(path), encoding="latin-1")
raise ValueError(f"Unsupported SAS file extension: {suffix}")
# ---------------------------------------------------------------------------
# Column filtering
# ---------------------------------------------------------------------------
def apply_column_filter(
df: pd.DataFrame,
include: Optional[List[str]],
exclude: Optional[List[str]],
) -> pd.DataFrame:
"""Restrict ``df`` to the requested columns. Names missing from the frame
raise a clear error rather than silently dropping."""
if include is not None and exclude is not None:
raise ValueError("include and exclude are mutually exclusive.")
if include is not None:
missing = [c for c in include if c not in df.columns]
if missing:
raise ValueError(f"include references unknown columns: {missing}")
return df.loc[:, list(include)].copy()
if exclude is not None:
missing = [c for c in exclude if c not in df.columns]
if missing:
raise ValueError(f"exclude references unknown columns: {missing}")
return df.drop(columns=list(exclude)).copy()
return df.copy()
# ---------------------------------------------------------------------------
# Type inference
# ---------------------------------------------------------------------------
_DATE_FORMAT_PREFIXES = ("DATE", "YYMMDD", "MMDDYY", "DDMMYY", "JULIAN")
def _format_driven_type(sas_format: Optional[str]) -> Optional[str]:
"""Return a Postgres type inferred from the SAS format string, or None
if the format doesn't pin it down."""
if not sas_format:
return None
fmt = sas_format.upper().lstrip()
# DATETIME must be checked before DATE since "DATETIME20." starts with "DATE".
if fmt.startswith("DATETIME"):
return "TIMESTAMP"
if fmt.startswith("TIME"):
return "TIME"
for prefix in _DATE_FORMAT_PREFIXES:
if fmt.startswith(prefix):
return "DATE"
return None
def _all_null(series: pd.Series) -> bool:
if pd.api.types.is_object_dtype(series):
return bool(series.map(lambda v: v is None or (isinstance(v, str) and v == "") or (isinstance(v, float) and pd.isna(v))).all())
return bool(series.isna().all())
def _char_missing_mask(series: pd.Series) -> pd.Series:
return series.map(lambda v: v is None or (isinstance(v, float) and pd.isna(v)) or (isinstance(v, str) and v == ""))
def _is_nullable(series: pd.Series) -> bool:
"""True if the column has at least one missing value."""
if pd.api.types.is_object_dtype(series):
return bool(_char_missing_mask(series).any())
return bool(series.isna().any())
def _numeric_int_target(series: pd.Series) -> Optional[str]:
"""Given a numeric (float64) series, if every non-null value is a whole
number, return INTEGER or BIGINT depending on range; else None."""
nonnull = series.dropna()
if nonnull.empty:
return None
# Whole-number test. Guard against inf.
try:
whole = ((nonnull % 1) == 0).all()
except TypeError:
return None
if not whole:
return None
lo, hi = NUMERIC_INT_RANGE
vmin = nonnull.min()
vmax = nonnull.max()
if lo <= vmin and vmax <= hi:
return "INTEGER"
return "BIGINT"
def _object_is_dates(series: pd.Series) -> Tuple[bool, bool]:
"""Return (all-date-like, any-datetime). If every non-null value is a
``datetime.date`` / ``datetime.datetime`` / ``pd.Timestamp``, return True
plus whether at least one carries a time component."""
nonnull = series.dropna()
if nonnull.empty:
return False, False
any_datetime = False
for v in nonnull:
if isinstance(v, dt.datetime) or isinstance(v, pd.Timestamp):
any_datetime = True
continue
if isinstance(v, dt.date):
continue
return False, False
return True, any_datetime
def _try_int_coerce(values: List[str]) -> Optional[str]:
"""If every value parses as an int, return INTEGER/BIGINT, else None."""
ints: List[int] = []
for v in values:
s = v.strip()
try:
ints.append(int(s))
except ValueError:
return None
if not ints:
return None
lo, hi = NUMERIC_INT_RANGE
if all(lo <= i <= hi for i in ints):
return "INTEGER"
return "BIGINT"
def _try_float_coerce(values: List[str]) -> bool:
for v in values:
try:
float(v)
except ValueError:
return False
return True
def _try_date_coerce(values: List[str]) -> bool:
for v in values:
try:
dt.date.fromisoformat(v)
except (ValueError, TypeError):
return False
return True
def _try_datetime_coerce(values: List[str]) -> bool:
for v in values:
try:
dt.datetime.fromisoformat(v)
except (ValueError, TypeError):
return False
return True
def _infer_char_type(series: pd.Series) -> str:
"""Object/string column inference. Returns a Postgres type string."""
mask = _char_missing_mask(series)
nonempty = [str(v) for v in series[~mask].tolist()]
if not COERCE_CHAR_COLUMNS or len(nonempty) < CHAR_INFERENCE_MIN_VALUES:
return "TEXT"
int_guess = _try_int_coerce(nonempty)
if int_guess is not None:
return int_guess
if _try_float_coerce(nonempty):
return "DOUBLE PRECISION"
if _try_date_coerce(nonempty):
return "DATE"
if _try_datetime_coerce(nonempty):
return "TIMESTAMP"
return "TEXT"
def infer_schema(
df: pd.DataFrame,
meta: Any,
*,
coerce_chars: bool = COERCE_CHAR_COLUMNS,
) -> Dict[str, ColumnSpec]:
"""Infer a Postgres column spec for each column in ``df``.
``meta`` is the pyreadstat metadata object; we read
``meta.original_variable_types`` (a dict keyed by column name) for
format-driven date/time/timestamp inference.
The ``coerce_chars`` kwarg lets callers override the module-level
``COERCE_CHAR_COLUMNS`` without mutating global state. Internally the
char-inference helpers still read the constant - a full override would
thread the flag through, but the one-knob story here is intentional.
"""
original_formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {})
# Temporarily flip the module-level flag if the caller asked us to.
global COERCE_CHAR_COLUMNS
saved = COERCE_CHAR_COLUMNS
COERCE_CHAR_COLUMNS = coerce_chars
try:
out: Dict[str, ColumnSpec] = {}
for col in df.columns:
series = df[col]
sas_format = original_formats.get(col)
notes: List[str] = []
pg_type = _format_driven_type(sas_format)
if pg_type is None:
if _all_null(series):
pg_type = "TEXT"
notes.append("all-null column; defaulting to TEXT")
elif pd.api.types.is_datetime64_any_dtype(series):
pg_type = "TIMESTAMP"
elif pd.api.types.is_object_dtype(series):
is_dates, any_dt = _object_is_dates(series)
if is_dates:
pg_type = "TIMESTAMP" if any_dt else "DATE"
else:
pg_type = _infer_char_type(series)
elif pd.api.types.is_numeric_dtype(series):
int_target = _numeric_int_target(series)
if int_target is not None:
pg_type = int_target
else:
pg_type = "DOUBLE PRECISION"
else:
pg_type = "TEXT"
notes.append(f"unhandled dtype {series.dtype}; defaulting to TEXT")
nullable = _is_nullable(series)
out[col] = ColumnSpec(
name=col,
postgres_type=pg_type,
nullable=nullable,
sas_format=sas_format,
source_dtype=str(series.dtype),
notes=notes,
)
return out
finally:
COERCE_CHAR_COLUMNS = saved
# ---------------------------------------------------------------------------
# Table management
# ---------------------------------------------------------------------------
def _quote_ident(ident: str) -> str:
"""Quote a Postgres identifier. psycopg2 doesn't expose this directly
until 2.8+ with sql.Identifier; we do it by hand to stay driver-simple."""
return '"' + ident.replace('"', '""') + '"'
def _qualified(schema: str, table: str) -> str:
return f"{_quote_ident(schema)}.{_quote_ident(table)}"
def _table_exists(conn, schema: str, table: str) -> bool:
with conn.cursor() as cur:
cur.execute(
"SELECT 1 FROM information_schema.tables "
"WHERE table_schema = %s AND table_name = %s",
(schema, table),
)
return cur.fetchone() is not None
def render_create_table(schema: str, table: str, columns: Dict[str, ColumnSpec]) -> str:
lines = []
for spec in columns.values():
null_clause = "" if spec.nullable else " NOT NULL"
lines.append(f" {_quote_ident(spec.name)} {spec.postgres_type}{null_clause}")
body = ",\n".join(lines)
return f"CREATE TABLE {_qualified(schema, table)} (\n{body}\n);"
def _create_table_sql(conn, schema: str, table: str, columns: Dict[str, ColumnSpec]) -> None:
sql = render_create_table(schema, table, columns)
with conn.cursor() as cur:
cur.execute(sql)
def _drop_table(conn, schema: str, table: str) -> None:
with conn.cursor() as cur:
cur.execute(f"DROP TABLE {_qualified(schema, table)}")
# Normalization table: map both loader-emitted and Postgres-reported type
# strings to a single canonical family name. Ignore length/precision
# modifiers like VARCHAR(n) and NUMERIC(p,s).
_TYPE_NORMALIZATION: Dict[str, str] = {
"INTEGER": "integer",
"INT": "integer",
"INT4": "integer",
"BIGINT": "bigint",
"INT8": "bigint",
"SMALLINT": "smallint",
"INT2": "smallint",
"DOUBLE PRECISION": "double precision",
"FLOAT8": "double precision",
"REAL": "real",
"FLOAT4": "real",
"NUMERIC": "numeric",
"DECIMAL": "numeric",
"TEXT": "text",
"VARCHAR": "character varying",
"CHARACTER VARYING": "character varying",
"CHAR": "character",
"CHARACTER": "character",
"BPCHAR": "character",
"BOOLEAN": "boolean",
"BOOL": "boolean",
"DATE": "date",
"TIMESTAMP": "timestamp without time zone",
"TIMESTAMP WITHOUT TIME ZONE": "timestamp without time zone",
"TIMESTAMPTZ": "timestamp with time zone",
"TIMESTAMP WITH TIME ZONE": "timestamp with time zone",
"TIME": "time without time zone",
"TIME WITHOUT TIME ZONE": "time without time zone",
"TIMETZ": "time with time zone",
"TIME WITH TIME ZONE": "time with time zone",
}
def _normalize_type(pg_type: str) -> str:
"""Strip length/precision modifiers and map to canonical family."""
stripped = pg_type.strip().upper()
# Remove trailing (n) / (p,s) before the space-separated tail.
# Examples: "VARCHAR(10)" -> "VARCHAR"; "TIMESTAMP(6) WITHOUT TIME ZONE" -> "TIMESTAMP WITHOUT TIME ZONE"
import re
stripped = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", stripped).strip()
# Collapse doubled whitespace after paren removal.
stripped = re.sub(r"\s+", " ", stripped)
return _TYPE_NORMALIZATION.get(stripped, stripped.lower())
def _assert_schema_compatible(
conn, schema: str, table: str, columns: Dict[str, ColumnSpec]
) -> None:
"""Pre-flight check for if_exists=append. See plan section on option B."""
with conn.cursor() as cur:
cur.execute(
"SELECT column_name, data_type, is_nullable "
"FROM information_schema.columns "
"WHERE table_schema = %s AND table_name = %s",
(schema, table),
)
existing = {row[0]: (row[1], row[2]) for row in cur.fetchall()}
mismatches: List[str] = []
warnings: List[str] = []
for name, spec in columns.items():
if name not in existing:
mismatches.append(
f"column {name!r} not present in target {schema}.{table}"
)
continue
target_type, target_nullable = existing[name]
inferred_norm = _normalize_type(spec.postgres_type)
target_norm = _normalize_type(target_type)
if inferred_norm != target_norm:
mismatches.append(
f"column {name!r}: inferred {spec.postgres_type} "
f"(normalized {inferred_norm!r}) but target is {target_type} "
f"(normalized {target_norm!r})"
)
target_is_notnull = (target_nullable == "NO")
if spec.nullable and target_is_notnull:
warnings.append(
f"column {name!r}: incoming allows NULLs but target is NOT NULL; "
"COPY will fail if any NULLs appear"
)
for w in warnings:
print(f"[warn] {w}", file=sys.stderr)
if mismatches:
raise SchemaCompatibilityError(
"append-mode schema compatibility check failed:\n - "
+ "\n - ".join(mismatches)
)
def create_table(
conn,
schema_name: str,
table_name: str,
columns: Dict[str, ColumnSpec],
if_exists: str,
) -> None:
"""Create (or verify) the target table according to ``if_exists``."""
if if_exists not in VALID_IF_EXISTS:
raise ValueError(f"if_exists must be one of {VALID_IF_EXISTS}, got {if_exists!r}")
exists = _table_exists(conn, schema_name, table_name)
if exists:
if if_exists == "fail":
raise TableExistsError(
f"Table {schema_name}.{table_name} already exists and if_exists=fail"
)
if if_exists == "replace":
_drop_table(conn, schema_name, table_name)
_create_table_sql(conn, schema_name, table_name, columns)
return
if if_exists == "append":
_assert_schema_compatible(conn, schema_name, table_name, columns)
return
else:
_create_table_sql(conn, schema_name, table_name, columns)
# ---------------------------------------------------------------------------
# COPY loading
# ---------------------------------------------------------------------------
def _seconds_to_time(v: Any) -> Optional[dt.time]:
if v is None:
return None
if isinstance(v, float) and pd.isna(v):
return None
if isinstance(v, dt.time):
return v
if isinstance(v, (dt.datetime, pd.Timestamp)):
return v.time() if not pd.isna(v) else None
try:
total = int(round(float(v)))
except (TypeError, ValueError):
return None
h, rem = divmod(total, 3600)
m, s = divmod(rem, 60)
# Clamp; TIME8. is always within a day.
h = max(0, min(h, 23))
return dt.time(h, m, s)
def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame:
"""Materialize a copy of ``df`` with each column in the right shape for
``to_csv`` so the CSV lands as valid input for the target Postgres type.
"""
out = pd.DataFrame(index=df.index)
for name, spec in columns.items():
series = df[name]
pg = spec.postgres_type.upper()
if pg in ("INTEGER", "BIGINT", "SMALLINT"):
if pd.api.types.is_object_dtype(series):
series = pd.to_numeric(
series.replace({"": None}), errors="coerce"
)
out[name] = series.astype("Int64")
elif pg in ("DOUBLE PRECISION", "REAL", "NUMERIC"):
if pd.api.types.is_object_dtype(series):
series = pd.to_numeric(
series.replace({"": None}), errors="coerce"
)
out[name] = series.astype("float64")
elif pg == "DATE":
if pd.api.types.is_datetime64_any_dtype(series):
out[name] = series.dt.date
elif pd.api.types.is_object_dtype(series):
def _to_date(v: Any) -> Optional[dt.date]:
if v is None or (isinstance(v, float) and pd.isna(v)):
return None
if isinstance(v, dt.datetime):
return v.date()
if isinstance(v, dt.date):
return v
if isinstance(v, str):
if v == "":
return None
try:
return dt.date.fromisoformat(v)
except ValueError:
return None
return None
out[name] = series.map(_to_date)
else:
out[name] = series
elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"):
if pd.api.types.is_datetime64_any_dtype(series):
out[name] = series
elif pd.api.types.is_object_dtype(series):
def _to_dt(v: Any) -> Optional[dt.datetime]:
if v is None or (isinstance(v, float) and pd.isna(v)):
return None
if isinstance(v, dt.datetime):
return v
if isinstance(v, dt.date):
return dt.datetime(v.year, v.month, v.day)
if isinstance(v, pd.Timestamp):
return v.to_pydatetime() if not pd.isna(v) else None
if isinstance(v, str):
if v == "":
return None
try:
return dt.datetime.fromisoformat(v)
except ValueError:
return None
return None
out[name] = series.map(_to_dt)
else:
out[name] = series
elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"):
out[name] = series.map(_seconds_to_time)
elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"):
# Leave empty strings as "" so `NULL ''` in COPY turns them into NULL.
def _to_str(v: Any) -> Any:
if v is None:
return ""
if isinstance(v, float) and pd.isna(v):
return ""
return str(v)
out[name] = series.map(_to_str)
elif pg == "BOOLEAN":
out[name] = series.astype("boolean") if series.dtype != object else series
else:
out[name] = series
return out
def copy_dataframe(
conn,
schema_name: str,
table_name: str,
df: pd.DataFrame,
columns: Dict[str, ColumnSpec],
) -> int:
"""Stream ``df`` into Postgres via ``COPY ... FROM STDIN``.
Returns the number of rows inserted.
"""
prepared = _prepare_for_copy(df, columns)
buf = io.StringIO()
prepared.to_csv(
buf,
index=False,
header=False,
na_rep="",
date_format="%Y-%m-%d %H:%M:%S",
)
buf.seek(0)
col_list = ", ".join(_quote_ident(name) for name in columns.keys())
sql = (
f"COPY {_qualified(schema_name, table_name)} ({col_list}) "
f"FROM STDIN WITH (FORMAT csv, NULL '')"
)
with conn.cursor() as cur:
cur.copy_expert(sql, buf)
rowcount = cur.rowcount
return int(rowcount) if rowcount is not None else len(prepared)
# ---------------------------------------------------------------------------
# Manifest validation
# ---------------------------------------------------------------------------
def _match_manifest_type(inferred: str, manifest_entry: Dict[str, Any]) -> bool:
inferred_norm = _normalize_type(inferred)
if "postgres_type" in manifest_entry:
return inferred_norm == _normalize_type(manifest_entry["postgres_type"])
if "acceptable_types" in manifest_entry:
return any(
inferred_norm == _normalize_type(t)
for t in manifest_entry["acceptable_types"]
)
return False
def validate_against_manifest(
inferred: Dict[str, ColumnSpec],
manifest_path: Path,
) -> List[str]:
"""Compare the inferred schema against the expected-types manifest.
Returns a list of human-readable problem strings; empty list means OK.
"""
manifest_path = Path(manifest_path)
if not manifest_path.exists():
return [f"manifest not found: {manifest_path}"]
with manifest_path.open("r", encoding="utf-8") as f:
manifest = json.load(f)
problems: List[str] = []
only_in_inferred = set(inferred) - set(manifest)
only_in_manifest = set(manifest) - set(inferred)
if only_in_inferred:
problems.append(
f"columns in inferred but not manifest: {sorted(only_in_inferred)}"
)
if only_in_manifest:
problems.append(
f"columns in manifest but not inferred: {sorted(only_in_manifest)}"
)
for name, spec in inferred.items():
entry = manifest.get(name)
if entry is None:
continue
if not _match_manifest_type(spec.postgres_type, entry):
expected = entry.get("postgres_type") or entry.get("acceptable_types")
problems.append(
f"column {name!r}: inferred {spec.postgres_type!r}, "
f"manifest expected {expected!r}"
)
manifest_nullable = bool(entry.get("nullable", True))
if spec.nullable and not manifest_nullable:
problems.append(
f"column {name!r}: inferred nullable, manifest expects NOT NULL "
f"(loosening nullability is never allowed)"
)
return problems
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _build_argparser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
description="Load a single SAS file (XPT or sas7bdat) into Postgres.",
)
p.add_argument("--config", required=True, type=Path, help="Path to YAML config")
p.add_argument(
"--validate",
action="store_true",
help=(
"Compare inferred schema against <filename-stem>.expected.json "
"next to the SAS file; exits nonzero on mismatch."
),
)
p.add_argument(
"--dry-run",
action="store_true",
help="Print inferred CREATE TABLE and stop; don't touch Postgres.",
)
return p
def _format_columns_summary(columns: Dict[str, ColumnSpec]) -> str:
lines = []
for spec in columns.values():
null = "" if spec.nullable else " NOT NULL"
lines.append(f" {spec.name}: {spec.postgres_type}{null}")
return "\n".join(lines)
def main(argv: Optional[List[str]] = None) -> int:
args = _build_argparser().parse_args(argv)
load_dotenv()
cfg = load_config(args.config)
if not cfg.filename.exists():
print(f"error: SAS file not found: {cfg.filename}", file=sys.stderr)
return 2
df, meta = read_sas(cfg.filename)
df = apply_column_filter(df, cfg.include, cfg.exclude)
columns = infer_schema(df, meta)
if args.validate:
manifest_path = cfg.filename.with_suffix("").with_suffix(".expected.json")
# The above strips .xpt then appends .expected.json, e.g.
# "sample_kitchensink.xpt" -> "sample_kitchensink.expected.json".
problems = validate_against_manifest(columns, manifest_path)
if problems:
print("validation failed:", file=sys.stderr)
for p in problems:
print(f" - {p}", file=sys.stderr)
return 1
print(f"validation OK ({len(columns)} columns match {manifest_path.name})")
if args.dry_run:
print(render_create_table(cfg.schemaname, cfg.tablename, columns))
return 0
conn = connect()
conn.autocommit = False
try:
create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists)
inserted = copy_dataframe(conn, cfg.schemaname, cfg.tablename, df, columns)
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
print(
f"loaded {inserted} rows into {cfg.schemaname}.{cfg.tablename} "
f"({len(columns)} columns)"
)
print("final schema:")
print(_format_columns_summary(columns))
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@ -0,0 +1,6 @@
pandas>=2.0,<2.3
pyreadstat>=1.2,<1.3
numpy>=1.24,<2.1
pyyaml>=6.0,<7.0
psycopg2-binary>=2.9,<3.0
python-dotenv>=1.0,<2.0

View File

@ -0,0 +1,17 @@
filename: samples/sample_kitchensink.xpt
schemaname: public
tablename: kitchensink
# Optional. If set, only these columns are loaded. Mutually exclusive with exclude.
# include:
# - ID
# - INTCOL
# - DATECOL
# Optional. Columns to drop.
# exclude:
# - ALLNULL
# What to do if the target table already exists: fail | replace | append
# Defaults to fail.
if_exists: append