commit f681f1012ab14790a42c3ecb153082b245eb415f Author: michael-corey Date: Sat Apr 18 09:34:48 2026 -0500 Adding generic loader diff --git a/generic_loader/.env.example b/generic_loader/.env.example new file mode 100644 index 0000000..5be8065 --- /dev/null +++ b/generic_loader/.env.example @@ -0,0 +1,5 @@ +PGHOST=localhost +PGPORT=5432 +PGUSER= +PGPASSWORD= +PGDATABASE= diff --git a/generic_loader/.gitignore b/generic_loader/.gitignore new file mode 100644 index 0000000..c93b13d --- /dev/null +++ b/generic_loader/.gitignore @@ -0,0 +1,3 @@ +/.venv +/samples +/.env diff --git a/generic_loader/generate_sample_sas.py b/generic_loader/generate_sample_sas.py new file mode 100644 index 0000000..0a6a977 --- /dev/null +++ b/generic_loader/generate_sample_sas.py @@ -0,0 +1,380 @@ +"""Generate a kitchen-sink SAS XPORT file plus an expected-types manifest. + +Running this script produces two files under samples/: + - sample_kitchensink.xpt the SAS XPORT test fixture + - sample_kitchensink.expected.json ground-truth Postgres types for the loader + +Tune behavior via the top-level constants below. +""" + +from __future__ import annotations + +import datetime as dt +import json +import string +from pathlib import Path + +import numpy as np +import pandas as pd +import pyreadstat + +SEED = 42 +N_ROWS = 1000 +NULL_FRACTION = 0.20 +OUT_DIR = Path("samples") +OUT_PATH = OUT_DIR / "sample_kitchensink.xpt" +MANIFEST_PATH = OUT_DIR / "sample_kitchensink.expected.json" + +POSITIVE_CONTROLS = {"ID", "INTCOL", "STRCOL", "DATECOL", "CONST"} +ALL_NULL_COLS = {"ALLNULL", "ALLNULLC"} + + +def _missing_mask(rng: np.random.Generator, n: int, frac: float) -> np.ndarray: + """Return a boolean array of length n with exactly round(frac * n) True positions. + + Using an exact count (rather than per-row Bernoulli draws) keeps the observed + missing fraction tight so the round-trip assertion can use a small tolerance. + """ + mask = np.zeros(n, dtype=bool) + k = int(round(frac * n)) + if k > 0: + idx = rng.choice(n, size=k, replace=False) + mask[idx] = True + return mask + + +def _random_word(rng: np.random.Generator, min_len: int = 3, max_len: int = 10) -> str: + length = int(rng.integers(min_len, max_len + 1)) + letters = np.array(list(string.ascii_lowercase)) + return "".join(rng.choice(letters, size=length)) + + +def _random_sentence(rng: np.random.Generator, min_words: int = 8, max_words: int = 20) -> str: + n_words = int(rng.integers(min_words, max_words + 1)) + return " ".join(_random_word(rng) for _ in range(n_words)) + + +def build_dataframe(rng: np.random.Generator) -> pd.DataFrame: + n = N_ROWS + + ids = np.arange(1, n + 1, dtype=np.int64) + + int_vals = rng.integers(0, 1000, size=n).astype(np.float64) + + bigint_vals = rng.integers(10_000_000_000, 20_000_000_000, size=n).astype(np.float64) + bigint_vals[_missing_mask(rng, n, NULL_FRACTION)] = np.nan + + float_vals = rng.normal(loc=100.0, scale=15.0, size=n) + float_vals[_missing_mask(rng, n, NULL_FRACTION)] = np.nan + + bool_vals = rng.integers(0, 2, size=n).astype(np.float64) + bool_vals[_missing_mask(rng, n, NULL_FRACTION)] = np.nan + + str_vals = [_random_word(rng, 3, 8) for _ in range(n)] + + long_str_vals: list[str] = [] + long_mask = _missing_mask(rng, n, NULL_FRACTION) + for i in range(n): + long_str_vals.append("" if long_mask[i] else _random_sentence(rng)) + + base_date = dt.date(2020, 1, 1) + date_vals = [base_date + dt.timedelta(days=int(rng.integers(0, 2000))) for _ in range(n)] + + dt_vals_mask = _missing_mask(rng, n, NULL_FRACTION) + dt_vals: list = [] + base_dt = dt.datetime(2020, 1, 1, 0, 0, 0) + for i in range(n): + if dt_vals_mask[i]: + dt_vals.append(pd.NaT) + else: + offset_seconds = int(rng.integers(0, 2000 * 24 * 3600)) + dt_vals.append(base_dt + dt.timedelta(seconds=offset_seconds)) + dt_series = pd.to_datetime(dt_vals) + + time_mask = _missing_mask(rng, n, NULL_FRACTION) + time_vals: list = [] + for i in range(n): + if time_mask[i]: + time_vals.append(None) + else: + seconds_into_day = int(rng.integers(0, 24 * 3600)) + h, rem = divmod(seconds_into_day, 3600) + m, s = divmod(rem, 60) + time_vals.append(dt.time(h, m, s)) + + numasstr_mask = _missing_mask(rng, n, NULL_FRACTION) + numasstr_vals: list[str] = [] + for i in range(n): + if numasstr_mask[i]: + numasstr_vals.append("") + elif rng.random() < 0.5: + numasstr_vals.append(str(int(rng.integers(-500, 500)))) + else: + numasstr_vals.append(f"{rng.normal(0, 50):.2f}") + + dateasstr_mask = _missing_mask(rng, n, NULL_FRACTION) + dateasstr_vals: list[str] = [] + for i in range(n): + if dateasstr_mask[i]: + dateasstr_vals.append("") + else: + d = base_date + dt.timedelta(days=int(rng.integers(0, 2000))) + dateasstr_vals.append(d.isoformat()) + + mixed_mask = _missing_mask(rng, n, NULL_FRACTION) + mixed_vals: list[str] = [] + choices = ["number", "date", "text", "text"] + for i in range(n): + if mixed_mask[i]: + mixed_vals.append("") + continue + kind = choices[int(rng.integers(0, len(choices)))] + if kind == "number": + mixed_vals.append(str(int(rng.integers(0, 1000)))) + elif kind == "date": + d = base_date + dt.timedelta(days=int(rng.integers(0, 2000))) + mixed_vals.append(d.isoformat()) + else: + mixed_vals.append(_random_word(rng, 4, 12)) + + const_vals = ["CONSTANT"] * n + + allnull_vals = np.full(n, np.nan, dtype=np.float64) + allnullc_vals = [""] * n + + df = pd.DataFrame( + { + "ID": ids, + "INTCOL": int_vals, + "BIGINT": bigint_vals, + "FLOATCOL": float_vals, + "BOOLCOL": bool_vals, + "STRCOL": str_vals, + "LONGSTR": long_str_vals, + "DATECOL": date_vals, + "DTCOL": dt_series, + "TIMECOL": time_vals, + "NUMASSTR": numasstr_vals, + "DATEASTR": dateasstr_vals, + "MIXED": mixed_vals, + "CONST": const_vals, + "ALLNULL": allnull_vals, + "ALLNULLC": allnullc_vals, + } + ) + return df + + +COLUMN_LABELS: dict[str, str] = { + "ID": "Row identifier", + "INTCOL": "Integer positive control", + "BIGINT": "Big integer beyond int32 range", + "FLOATCOL": "Floating point with decimals", + "BOOLCOL": "Nullable boolean 0/1/NaN", + "STRCOL": "Short string positive control", + "LONGSTR": "Longer free-text string", + "DATECOL": "Date positive control", + "DTCOL": "Datetime with missing values", + "TIMECOL": "Time of day with missing values", + "NUMASSTR": "Numeric-looking strings in a char column", + "DATEASTR": "Date-looking strings in a char column", + "MIXED": "Heterogeneous strings: fallback to text", + "CONST": "Constant repeated value", + "ALLNULL": "Entirely missing numeric column", + "ALLNULLC": "Entirely missing character column", +} + + +VARIABLE_FORMATS: dict[str, str] = { + "DATECOL": "DATE9.", + "DTCOL": "DATETIME20.", + "TIMECOL": "TIME8.", +} + + +EXPECTED_MANIFEST: dict[str, dict] = { + "ID": {"postgres_type": "INTEGER", "nullable": False}, + "INTCOL": {"postgres_type": "INTEGER", "nullable": False, "note": "positive control"}, + "BIGINT": {"postgres_type": "BIGINT", "nullable": True, "note": "values beyond int32 range"}, + "FLOATCOL": {"acceptable_types": ["DOUBLE PRECISION", "NUMERIC"], "nullable": True}, + "BOOLCOL": { + "acceptable_types": ["BOOLEAN", "SMALLINT", "INTEGER"], + "nullable": True, + "note": "{0,1,NaN} is genuinely ambiguous; loader's choice is a design decision", + }, + "STRCOL": {"acceptable_types": ["TEXT", "VARCHAR"], "nullable": False, "note": "positive control"}, + "LONGSTR": {"acceptable_types": ["TEXT", "VARCHAR"], "nullable": True}, + "DATECOL": {"postgres_type": "DATE", "nullable": False, "note": "positive control"}, + "DTCOL": {"acceptable_types": ["TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE"], "nullable": True}, + "TIMECOL": {"postgres_type": "TIME", "nullable": True}, + "NUMASSTR": { + "acceptable_types": ["NUMERIC", "DOUBLE PRECISION"], + "nullable": True, + "note": "stored as char in SAS; loader should coerce numeric-looking strings", + }, + "DATEASTR": { + "postgres_type": "DATE", + "nullable": True, + "note": "stored as char in SAS; loader should coerce ISO-date strings", + }, + "MIXED": { + "acceptable_types": ["TEXT", "VARCHAR"], + "nullable": True, + "note": "heterogeneous content; loader should fall back to text", + }, + "CONST": {"acceptable_types": ["TEXT", "VARCHAR"], "nullable": False}, + "ALLNULL": { + "acceptable_types": ["TEXT", "VARCHAR"], + "nullable": True, + "note": "entirely null numeric; loader must pick a default type, typically TEXT", + }, + "ALLNULLC": { + "acceptable_types": ["TEXT", "VARCHAR"], + "nullable": True, + "note": "entirely null character", + }, +} + + +def write_manifest(df: pd.DataFrame) -> None: + manifest_cols = set(EXPECTED_MANIFEST.keys()) + df_cols = set(df.columns) + missing = df_cols - manifest_cols + extra = manifest_cols - df_cols + if missing or extra: + raise AssertionError( + f"Manifest/DataFrame column mismatch. Missing from manifest: {missing}. " + f"Extra in manifest: {extra}." + ) + with MANIFEST_PATH.open("w", encoding="utf-8") as f: + json.dump(EXPECTED_MANIFEST, f, indent=2, sort_keys=True) + f.write("\n") + + +def _char_missing_fraction(series: pd.Series) -> float: + return float((series.fillna("").astype(str) == "").mean()) + + +def _numeric_missing_fraction(series: pd.Series) -> float: + return float(series.isna().mean()) + + +def verify_roundtrip(source_df: pd.DataFrame) -> pd.DataFrame: + # Use pyreadstat (the writer) to verify the writer's output. pyreadstat preserves + # SAS format metadata on readback, so we can confirm the date/datetime/time + # variable_format mappings actually took effect. + readback, _meta = pyreadstat.read_xport(str(OUT_PATH)) + + assert len(readback.columns) == len(source_df.columns), ( + f"Column count mismatch: wrote {len(source_df.columns)}, read back {len(readback.columns)}" + ) + assert set(readback.columns) == set(source_df.columns), ( + f"Column name mismatch. Only in source: {set(source_df.columns) - set(readback.columns)}. " + f"Only in readback: {set(readback.columns) - set(source_df.columns)}." + ) + assert len(readback) == len(source_df), ( + f"Row count mismatch: wrote {len(source_df)}, read back {len(readback)}" + ) + + for col in ("DATECOL", "DTCOL"): + dtype = readback[col].dtype + is_datetime = pd.api.types.is_datetime64_any_dtype(dtype) + is_object_of_dates = pd.api.types.is_object_dtype(dtype) and readback[col].dropna().map( + lambda v: isinstance(v, (dt.date, dt.datetime, pd.Timestamp)) + ).all() + assert is_datetime or is_object_of_dates, ( + f"{col} came back as {dtype}; expected datetime-like. " + f"variable_format mapping may not have taken effect." + ) + + time_dtype = readback["TIMECOL"].dtype + time_ok = ( + pd.api.types.is_datetime64_any_dtype(time_dtype) + or pd.api.types.is_numeric_dtype(time_dtype) + or ( + pd.api.types.is_object_dtype(time_dtype) + and readback["TIMECOL"].dropna().map( + lambda v: isinstance(v, (dt.time, dt.datetime, pd.Timestamp, int, float)) + ).all() + ) + ) + assert time_ok, f"TIMECOL came back as {time_dtype}; expected datetime/numeric/time-object" + + tol = 0.10 + for col in source_df.columns: + if col in POSITIVE_CONTROLS: + series = readback[col] + if pd.api.types.is_numeric_dtype(series) or pd.api.types.is_datetime64_any_dtype(series): + observed = _numeric_missing_fraction(series) + else: + observed = _char_missing_fraction(series) + assert observed == 0.0, ( + f"Positive control {col!r} has {observed:.2%} missing; expected 0%." + ) + continue + + if col in ALL_NULL_COLS: + series = readback[col] + if pd.api.types.is_numeric_dtype(series): + observed = _numeric_missing_fraction(series) + else: + observed = _char_missing_fraction(series) + assert observed == 1.0, ( + f"All-null column {col!r} has {observed:.2%} missing; expected 100%." + ) + continue + + series = readback[col] + if pd.api.types.is_numeric_dtype(series) or pd.api.types.is_datetime64_any_dtype(series): + observed = _numeric_missing_fraction(series) + else: + observed = _char_missing_fraction(series) + assert abs(observed - NULL_FRACTION) <= tol, ( + f"Column {col!r}: observed missing fraction {observed:.2%} not within " + f"±{tol:.0%} of NULL_FRACTION={NULL_FRACTION:.2%}." + ) + + assert MANIFEST_PATH.exists(), f"Manifest file {MANIFEST_PATH} missing." + with MANIFEST_PATH.open("r", encoding="utf-8") as f: + manifest = json.load(f) + assert set(manifest.keys()) == set(readback.columns), ( + f"Manifest/readback column set mismatch. " + f"Only in manifest: {set(manifest.keys()) - set(readback.columns)}. " + f"Only in readback: {set(readback.columns) - set(manifest.keys())}." + ) + + return readback + + +def main() -> None: + OUT_DIR.mkdir(parents=True, exist_ok=True) + + rng = np.random.default_rng(SEED) + df = build_dataframe(rng) + + pyreadstat.write_xport( + df, + str(OUT_PATH), + file_format_version=5, + table_name="SAMPLE", + file_label="Kitchen sink sample for loader testing", + column_labels=COLUMN_LABELS, + variable_format=VARIABLE_FORMATS, + ) + + write_manifest(df) + + readback = verify_roundtrip(df) + + print(f"Wrote {OUT_PATH} ({N_ROWS} rows x {len(df.columns)} cols)") + print(f"Wrote {MANIFEST_PATH}") + print() + print("Readback via pyreadstat.read_xport (same reader the loader will use):") + print(readback.dtypes.to_string()) + print() + print("Readback head:") + print(readback.head().to_string()) + + +if __name__ == "__main__": + main() diff --git a/generic_loader/load_sas.py b/generic_loader/load_sas.py new file mode 100644 index 0000000..4f9a4ca --- /dev/null +++ b/generic_loader/load_sas.py @@ -0,0 +1,904 @@ +"""Per-file SAS-to-Postgres loader. + +Library-style functions plus a thin CLI wrapper. Designed so an orchestrator +can wrap the library for directory/batch mode; orchestration is out of scope +here. + +Python 3.9 compatible (target is an air-gapped host that currently only has +3.9). ``from __future__ import annotations`` lets us use PEP 585 generics +as annotations; runtime-resolved type uses (dataclass defaults, etc.) stick +to ``typing``. +""" + +from __future__ import annotations + +import argparse +import datetime as dt +import io +import json +import os +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd +import psycopg2 +import psycopg2.extensions +import pyreadstat +import yaml +from dotenv import load_dotenv + + +# --------------------------------------------------------------------------- +# Top-level tunables +# --------------------------------------------------------------------------- + +COERCE_CHAR_COLUMNS = True +"""If True, try to promote object (string) columns to numeric/date/timestamp +when every non-empty value parses cleanly.""" + +CHAR_INFERENCE_MIN_VALUES = 3 +"""Don't attempt character-column coercion with fewer than this many non-empty +values; too small a sample is easy to mis-infer.""" + +NUMERIC_INT_RANGE = (-2_147_483_648, 2_147_483_647) +"""INTEGER bounds; anything outside becomes BIGINT.""" + + +VALID_IF_EXISTS = ("fail", "replace", "append") + + +# --------------------------------------------------------------------------- +# Dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class LoaderConfig: + filename: Path + schemaname: str + tablename: str + if_exists: str = "fail" + include: Optional[List[str]] = None + exclude: Optional[List[str]] = None + + +@dataclass +class ColumnSpec: + name: str + postgres_type: str + nullable: bool + sas_format: Optional[str] = None + source_dtype: Optional[str] = None + notes: List[str] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Custom exceptions +# --------------------------------------------------------------------------- + + +class TableExistsError(RuntimeError): + """Raised when if_exists=fail and the target table already exists.""" + + +class SchemaCompatibilityError(RuntimeError): + """Raised when if_exists=append and the incoming schema is not + compatible with the existing table.""" + + +class ValidationError(RuntimeError): + """Raised when --validate detects a mismatch against the manifest.""" + + +# --------------------------------------------------------------------------- +# Connection +# --------------------------------------------------------------------------- + + +def connect() -> psycopg2.extensions.connection: + """Open a psycopg2 connection using standard libpq env vars. + + Assumes `.env` has already been loaded (the CLI does this before calling). + Orchestrators that wrap this module should either call ``load_dotenv()`` + themselves or ensure the env vars are set. + """ + conn = psycopg2.connect( + host=os.environ.get("PGHOST"), + port=os.environ.get("PGPORT"), + user=os.environ.get("PGUSER"), + password=os.environ.get("PGPASSWORD"), + dbname=os.environ.get("PGDATABASE"), + ) + return conn + + +# --------------------------------------------------------------------------- +# Config loading +# --------------------------------------------------------------------------- + + +def load_config(path: Path) -> LoaderConfig: + """Parse and validate the YAML config at ``path``.""" + path = Path(path) + with path.open("r", encoding="utf-8") as f: + raw = yaml.safe_load(f) + + if not isinstance(raw, dict): + raise ValueError(f"Config at {path} must be a YAML mapping at the top level.") + + missing = [k for k in ("filename", "schemaname", "tablename") if k not in raw] + if missing: + raise ValueError(f"Config {path} missing required keys: {', '.join(missing)}") + + filename = Path(raw["filename"]) + if not filename.is_absolute(): + filename = (path.parent / filename).resolve() if (path.parent / filename).exists() else Path(raw["filename"]) + + schemaname = str(raw["schemaname"]) + tablename = str(raw["tablename"]) + + if_exists = str(raw.get("if_exists", "fail")).lower() + if if_exists not in VALID_IF_EXISTS: + raise ValueError( + f"Config {path}: if_exists={if_exists!r} is not one of {VALID_IF_EXISTS}" + ) + + include = raw.get("include") + exclude = raw.get("exclude") + if include is not None and exclude is not None: + raise ValueError( + f"Config {path}: 'include' and 'exclude' are mutually exclusive; set at most one." + ) + if include is not None and not isinstance(include, list): + raise ValueError(f"Config {path}: 'include' must be a list of column names.") + if exclude is not None and not isinstance(exclude, list): + raise ValueError(f"Config {path}: 'exclude' must be a list of column names.") + + return LoaderConfig( + filename=filename, + schemaname=schemaname, + tablename=tablename, + if_exists=if_exists, + include=[str(c) for c in include] if include is not None else None, + exclude=[str(c) for c in exclude] if exclude is not None else None, + ) + + +# --------------------------------------------------------------------------- +# Reader +# --------------------------------------------------------------------------- + + +def read_sas(path: Path) -> Tuple[pd.DataFrame, Any]: + """Dispatch to the right pyreadstat reader by extension. + + Invariants (learned the hard way while building the sample generator): + + * ``.xpt`` / ``.xport`` - no encoding arg; pyreadstat is flaky about + encoding on XPORT files it wrote itself. + * ``.sas7bdat`` - explicit ``encoding="latin-1"`` per colleague guidance. + """ + path = Path(path) + suffix = path.suffix.lower() + if suffix in (".xpt", ".xport"): + return pyreadstat.read_xport(str(path)) + if suffix == ".sas7bdat": + return pyreadstat.read_sas7bdat(str(path), encoding="latin-1") + raise ValueError(f"Unsupported SAS file extension: {suffix}") + + +# --------------------------------------------------------------------------- +# Column filtering +# --------------------------------------------------------------------------- + + +def apply_column_filter( + df: pd.DataFrame, + include: Optional[List[str]], + exclude: Optional[List[str]], +) -> pd.DataFrame: + """Restrict ``df`` to the requested columns. Names missing from the frame + raise a clear error rather than silently dropping.""" + if include is not None and exclude is not None: + raise ValueError("include and exclude are mutually exclusive.") + + if include is not None: + missing = [c for c in include if c not in df.columns] + if missing: + raise ValueError(f"include references unknown columns: {missing}") + return df.loc[:, list(include)].copy() + + if exclude is not None: + missing = [c for c in exclude if c not in df.columns] + if missing: + raise ValueError(f"exclude references unknown columns: {missing}") + return df.drop(columns=list(exclude)).copy() + + return df.copy() + + +# --------------------------------------------------------------------------- +# Type inference +# --------------------------------------------------------------------------- + + +_DATE_FORMAT_PREFIXES = ("DATE", "YYMMDD", "MMDDYY", "DDMMYY", "JULIAN") + + +def _format_driven_type(sas_format: Optional[str]) -> Optional[str]: + """Return a Postgres type inferred from the SAS format string, or None + if the format doesn't pin it down.""" + if not sas_format: + return None + fmt = sas_format.upper().lstrip() + # DATETIME must be checked before DATE since "DATETIME20." starts with "DATE". + if fmt.startswith("DATETIME"): + return "TIMESTAMP" + if fmt.startswith("TIME"): + return "TIME" + for prefix in _DATE_FORMAT_PREFIXES: + if fmt.startswith(prefix): + return "DATE" + return None + + +def _all_null(series: pd.Series) -> bool: + if pd.api.types.is_object_dtype(series): + return bool(series.map(lambda v: v is None or (isinstance(v, str) and v == "") or (isinstance(v, float) and pd.isna(v))).all()) + return bool(series.isna().all()) + + +def _char_missing_mask(series: pd.Series) -> pd.Series: + return series.map(lambda v: v is None or (isinstance(v, float) and pd.isna(v)) or (isinstance(v, str) and v == "")) + + +def _is_nullable(series: pd.Series) -> bool: + """True if the column has at least one missing value.""" + if pd.api.types.is_object_dtype(series): + return bool(_char_missing_mask(series).any()) + return bool(series.isna().any()) + + +def _numeric_int_target(series: pd.Series) -> Optional[str]: + """Given a numeric (float64) series, if every non-null value is a whole + number, return INTEGER or BIGINT depending on range; else None.""" + nonnull = series.dropna() + if nonnull.empty: + return None + # Whole-number test. Guard against inf. + try: + whole = ((nonnull % 1) == 0).all() + except TypeError: + return None + if not whole: + return None + lo, hi = NUMERIC_INT_RANGE + vmin = nonnull.min() + vmax = nonnull.max() + if lo <= vmin and vmax <= hi: + return "INTEGER" + return "BIGINT" + + +def _object_is_dates(series: pd.Series) -> Tuple[bool, bool]: + """Return (all-date-like, any-datetime). If every non-null value is a + ``datetime.date`` / ``datetime.datetime`` / ``pd.Timestamp``, return True + plus whether at least one carries a time component.""" + nonnull = series.dropna() + if nonnull.empty: + return False, False + any_datetime = False + for v in nonnull: + if isinstance(v, dt.datetime) or isinstance(v, pd.Timestamp): + any_datetime = True + continue + if isinstance(v, dt.date): + continue + return False, False + return True, any_datetime + + +def _try_int_coerce(values: List[str]) -> Optional[str]: + """If every value parses as an int, return INTEGER/BIGINT, else None.""" + ints: List[int] = [] + for v in values: + s = v.strip() + try: + ints.append(int(s)) + except ValueError: + return None + if not ints: + return None + lo, hi = NUMERIC_INT_RANGE + if all(lo <= i <= hi for i in ints): + return "INTEGER" + return "BIGINT" + + +def _try_float_coerce(values: List[str]) -> bool: + for v in values: + try: + float(v) + except ValueError: + return False + return True + + +def _try_date_coerce(values: List[str]) -> bool: + for v in values: + try: + dt.date.fromisoformat(v) + except (ValueError, TypeError): + return False + return True + + +def _try_datetime_coerce(values: List[str]) -> bool: + for v in values: + try: + dt.datetime.fromisoformat(v) + except (ValueError, TypeError): + return False + return True + + +def _infer_char_type(series: pd.Series) -> str: + """Object/string column inference. Returns a Postgres type string.""" + mask = _char_missing_mask(series) + nonempty = [str(v) for v in series[~mask].tolist()] + if not COERCE_CHAR_COLUMNS or len(nonempty) < CHAR_INFERENCE_MIN_VALUES: + return "TEXT" + + int_guess = _try_int_coerce(nonempty) + if int_guess is not None: + return int_guess + if _try_float_coerce(nonempty): + return "DOUBLE PRECISION" + if _try_date_coerce(nonempty): + return "DATE" + if _try_datetime_coerce(nonempty): + return "TIMESTAMP" + return "TEXT" + + +def infer_schema( + df: pd.DataFrame, + meta: Any, + *, + coerce_chars: bool = COERCE_CHAR_COLUMNS, +) -> Dict[str, ColumnSpec]: + """Infer a Postgres column spec for each column in ``df``. + + ``meta`` is the pyreadstat metadata object; we read + ``meta.original_variable_types`` (a dict keyed by column name) for + format-driven date/time/timestamp inference. + + The ``coerce_chars`` kwarg lets callers override the module-level + ``COERCE_CHAR_COLUMNS`` without mutating global state. Internally the + char-inference helpers still read the constant - a full override would + thread the flag through, but the one-knob story here is intentional. + """ + original_formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {}) + + # Temporarily flip the module-level flag if the caller asked us to. + global COERCE_CHAR_COLUMNS + saved = COERCE_CHAR_COLUMNS + COERCE_CHAR_COLUMNS = coerce_chars + try: + out: Dict[str, ColumnSpec] = {} + for col in df.columns: + series = df[col] + sas_format = original_formats.get(col) + notes: List[str] = [] + + pg_type = _format_driven_type(sas_format) + + if pg_type is None: + if _all_null(series): + pg_type = "TEXT" + notes.append("all-null column; defaulting to TEXT") + elif pd.api.types.is_datetime64_any_dtype(series): + pg_type = "TIMESTAMP" + elif pd.api.types.is_object_dtype(series): + is_dates, any_dt = _object_is_dates(series) + if is_dates: + pg_type = "TIMESTAMP" if any_dt else "DATE" + else: + pg_type = _infer_char_type(series) + elif pd.api.types.is_numeric_dtype(series): + int_target = _numeric_int_target(series) + if int_target is not None: + pg_type = int_target + else: + pg_type = "DOUBLE PRECISION" + else: + pg_type = "TEXT" + notes.append(f"unhandled dtype {series.dtype}; defaulting to TEXT") + + nullable = _is_nullable(series) + + out[col] = ColumnSpec( + name=col, + postgres_type=pg_type, + nullable=nullable, + sas_format=sas_format, + source_dtype=str(series.dtype), + notes=notes, + ) + return out + finally: + COERCE_CHAR_COLUMNS = saved + + +# --------------------------------------------------------------------------- +# Table management +# --------------------------------------------------------------------------- + + +def _quote_ident(ident: str) -> str: + """Quote a Postgres identifier. psycopg2 doesn't expose this directly + until 2.8+ with sql.Identifier; we do it by hand to stay driver-simple.""" + return '"' + ident.replace('"', '""') + '"' + + +def _qualified(schema: str, table: str) -> str: + return f"{_quote_ident(schema)}.{_quote_ident(table)}" + + +def _table_exists(conn, schema: str, table: str) -> bool: + with conn.cursor() as cur: + cur.execute( + "SELECT 1 FROM information_schema.tables " + "WHERE table_schema = %s AND table_name = %s", + (schema, table), + ) + return cur.fetchone() is not None + + +def render_create_table(schema: str, table: str, columns: Dict[str, ColumnSpec]) -> str: + lines = [] + for spec in columns.values(): + null_clause = "" if spec.nullable else " NOT NULL" + lines.append(f" {_quote_ident(spec.name)} {spec.postgres_type}{null_clause}") + body = ",\n".join(lines) + return f"CREATE TABLE {_qualified(schema, table)} (\n{body}\n);" + + +def _create_table_sql(conn, schema: str, table: str, columns: Dict[str, ColumnSpec]) -> None: + sql = render_create_table(schema, table, columns) + with conn.cursor() as cur: + cur.execute(sql) + + +def _drop_table(conn, schema: str, table: str) -> None: + with conn.cursor() as cur: + cur.execute(f"DROP TABLE {_qualified(schema, table)}") + + +# Normalization table: map both loader-emitted and Postgres-reported type +# strings to a single canonical family name. Ignore length/precision +# modifiers like VARCHAR(n) and NUMERIC(p,s). +_TYPE_NORMALIZATION: Dict[str, str] = { + "INTEGER": "integer", + "INT": "integer", + "INT4": "integer", + "BIGINT": "bigint", + "INT8": "bigint", + "SMALLINT": "smallint", + "INT2": "smallint", + "DOUBLE PRECISION": "double precision", + "FLOAT8": "double precision", + "REAL": "real", + "FLOAT4": "real", + "NUMERIC": "numeric", + "DECIMAL": "numeric", + "TEXT": "text", + "VARCHAR": "character varying", + "CHARACTER VARYING": "character varying", + "CHAR": "character", + "CHARACTER": "character", + "BPCHAR": "character", + "BOOLEAN": "boolean", + "BOOL": "boolean", + "DATE": "date", + "TIMESTAMP": "timestamp without time zone", + "TIMESTAMP WITHOUT TIME ZONE": "timestamp without time zone", + "TIMESTAMPTZ": "timestamp with time zone", + "TIMESTAMP WITH TIME ZONE": "timestamp with time zone", + "TIME": "time without time zone", + "TIME WITHOUT TIME ZONE": "time without time zone", + "TIMETZ": "time with time zone", + "TIME WITH TIME ZONE": "time with time zone", +} + + +def _normalize_type(pg_type: str) -> str: + """Strip length/precision modifiers and map to canonical family.""" + stripped = pg_type.strip().upper() + # Remove trailing (n) / (p,s) before the space-separated tail. + # Examples: "VARCHAR(10)" -> "VARCHAR"; "TIMESTAMP(6) WITHOUT TIME ZONE" -> "TIMESTAMP WITHOUT TIME ZONE" + import re + + stripped = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", stripped).strip() + # Collapse doubled whitespace after paren removal. + stripped = re.sub(r"\s+", " ", stripped) + return _TYPE_NORMALIZATION.get(stripped, stripped.lower()) + + +def _assert_schema_compatible( + conn, schema: str, table: str, columns: Dict[str, ColumnSpec] +) -> None: + """Pre-flight check for if_exists=append. See plan section on option B.""" + with conn.cursor() as cur: + cur.execute( + "SELECT column_name, data_type, is_nullable " + "FROM information_schema.columns " + "WHERE table_schema = %s AND table_name = %s", + (schema, table), + ) + existing = {row[0]: (row[1], row[2]) for row in cur.fetchall()} + + mismatches: List[str] = [] + warnings: List[str] = [] + + for name, spec in columns.items(): + if name not in existing: + mismatches.append( + f"column {name!r} not present in target {schema}.{table}" + ) + continue + target_type, target_nullable = existing[name] + inferred_norm = _normalize_type(spec.postgres_type) + target_norm = _normalize_type(target_type) + if inferred_norm != target_norm: + mismatches.append( + f"column {name!r}: inferred {spec.postgres_type} " + f"(normalized {inferred_norm!r}) but target is {target_type} " + f"(normalized {target_norm!r})" + ) + target_is_notnull = (target_nullable == "NO") + if spec.nullable and target_is_notnull: + warnings.append( + f"column {name!r}: incoming allows NULLs but target is NOT NULL; " + "COPY will fail if any NULLs appear" + ) + + for w in warnings: + print(f"[warn] {w}", file=sys.stderr) + + if mismatches: + raise SchemaCompatibilityError( + "append-mode schema compatibility check failed:\n - " + + "\n - ".join(mismatches) + ) + + +def create_table( + conn, + schema_name: str, + table_name: str, + columns: Dict[str, ColumnSpec], + if_exists: str, +) -> None: + """Create (or verify) the target table according to ``if_exists``.""" + if if_exists not in VALID_IF_EXISTS: + raise ValueError(f"if_exists must be one of {VALID_IF_EXISTS}, got {if_exists!r}") + + exists = _table_exists(conn, schema_name, table_name) + if exists: + if if_exists == "fail": + raise TableExistsError( + f"Table {schema_name}.{table_name} already exists and if_exists=fail" + ) + if if_exists == "replace": + _drop_table(conn, schema_name, table_name) + _create_table_sql(conn, schema_name, table_name, columns) + return + if if_exists == "append": + _assert_schema_compatible(conn, schema_name, table_name, columns) + return + else: + _create_table_sql(conn, schema_name, table_name, columns) + + +# --------------------------------------------------------------------------- +# COPY loading +# --------------------------------------------------------------------------- + + +def _seconds_to_time(v: Any) -> Optional[dt.time]: + if v is None: + return None + if isinstance(v, float) and pd.isna(v): + return None + if isinstance(v, dt.time): + return v + if isinstance(v, (dt.datetime, pd.Timestamp)): + return v.time() if not pd.isna(v) else None + try: + total = int(round(float(v))) + except (TypeError, ValueError): + return None + h, rem = divmod(total, 3600) + m, s = divmod(rem, 60) + # Clamp; TIME8. is always within a day. + h = max(0, min(h, 23)) + return dt.time(h, m, s) + + +def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame: + """Materialize a copy of ``df`` with each column in the right shape for + ``to_csv`` so the CSV lands as valid input for the target Postgres type. + """ + out = pd.DataFrame(index=df.index) + for name, spec in columns.items(): + series = df[name] + pg = spec.postgres_type.upper() + + if pg in ("INTEGER", "BIGINT", "SMALLINT"): + if pd.api.types.is_object_dtype(series): + series = pd.to_numeric( + series.replace({"": None}), errors="coerce" + ) + out[name] = series.astype("Int64") + elif pg in ("DOUBLE PRECISION", "REAL", "NUMERIC"): + if pd.api.types.is_object_dtype(series): + series = pd.to_numeric( + series.replace({"": None}), errors="coerce" + ) + out[name] = series.astype("float64") + elif pg == "DATE": + if pd.api.types.is_datetime64_any_dtype(series): + out[name] = series.dt.date + elif pd.api.types.is_object_dtype(series): + def _to_date(v: Any) -> Optional[dt.date]: + if v is None or (isinstance(v, float) and pd.isna(v)): + return None + if isinstance(v, dt.datetime): + return v.date() + if isinstance(v, dt.date): + return v + if isinstance(v, str): + if v == "": + return None + try: + return dt.date.fromisoformat(v) + except ValueError: + return None + return None + out[name] = series.map(_to_date) + else: + out[name] = series + elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"): + if pd.api.types.is_datetime64_any_dtype(series): + out[name] = series + elif pd.api.types.is_object_dtype(series): + def _to_dt(v: Any) -> Optional[dt.datetime]: + if v is None or (isinstance(v, float) and pd.isna(v)): + return None + if isinstance(v, dt.datetime): + return v + if isinstance(v, dt.date): + return dt.datetime(v.year, v.month, v.day) + if isinstance(v, pd.Timestamp): + return v.to_pydatetime() if not pd.isna(v) else None + if isinstance(v, str): + if v == "": + return None + try: + return dt.datetime.fromisoformat(v) + except ValueError: + return None + return None + out[name] = series.map(_to_dt) + else: + out[name] = series + elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"): + out[name] = series.map(_seconds_to_time) + elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"): + # Leave empty strings as "" so `NULL ''` in COPY turns them into NULL. + def _to_str(v: Any) -> Any: + if v is None: + return "" + if isinstance(v, float) and pd.isna(v): + return "" + return str(v) + out[name] = series.map(_to_str) + elif pg == "BOOLEAN": + out[name] = series.astype("boolean") if series.dtype != object else series + else: + out[name] = series + return out + + +def copy_dataframe( + conn, + schema_name: str, + table_name: str, + df: pd.DataFrame, + columns: Dict[str, ColumnSpec], +) -> int: + """Stream ``df`` into Postgres via ``COPY ... FROM STDIN``. + + Returns the number of rows inserted. + """ + prepared = _prepare_for_copy(df, columns) + + buf = io.StringIO() + prepared.to_csv( + buf, + index=False, + header=False, + na_rep="", + date_format="%Y-%m-%d %H:%M:%S", + ) + buf.seek(0) + + col_list = ", ".join(_quote_ident(name) for name in columns.keys()) + sql = ( + f"COPY {_qualified(schema_name, table_name)} ({col_list}) " + f"FROM STDIN WITH (FORMAT csv, NULL '')" + ) + + with conn.cursor() as cur: + cur.copy_expert(sql, buf) + rowcount = cur.rowcount + + return int(rowcount) if rowcount is not None else len(prepared) + + +# --------------------------------------------------------------------------- +# Manifest validation +# --------------------------------------------------------------------------- + + +def _match_manifest_type(inferred: str, manifest_entry: Dict[str, Any]) -> bool: + inferred_norm = _normalize_type(inferred) + if "postgres_type" in manifest_entry: + return inferred_norm == _normalize_type(manifest_entry["postgres_type"]) + if "acceptable_types" in manifest_entry: + return any( + inferred_norm == _normalize_type(t) + for t in manifest_entry["acceptable_types"] + ) + return False + + +def validate_against_manifest( + inferred: Dict[str, ColumnSpec], + manifest_path: Path, +) -> List[str]: + """Compare the inferred schema against the expected-types manifest. + + Returns a list of human-readable problem strings; empty list means OK. + """ + manifest_path = Path(manifest_path) + if not manifest_path.exists(): + return [f"manifest not found: {manifest_path}"] + + with manifest_path.open("r", encoding="utf-8") as f: + manifest = json.load(f) + + problems: List[str] = [] + + only_in_inferred = set(inferred) - set(manifest) + only_in_manifest = set(manifest) - set(inferred) + if only_in_inferred: + problems.append( + f"columns in inferred but not manifest: {sorted(only_in_inferred)}" + ) + if only_in_manifest: + problems.append( + f"columns in manifest but not inferred: {sorted(only_in_manifest)}" + ) + + for name, spec in inferred.items(): + entry = manifest.get(name) + if entry is None: + continue + if not _match_manifest_type(spec.postgres_type, entry): + expected = entry.get("postgres_type") or entry.get("acceptable_types") + problems.append( + f"column {name!r}: inferred {spec.postgres_type!r}, " + f"manifest expected {expected!r}" + ) + manifest_nullable = bool(entry.get("nullable", True)) + if spec.nullable and not manifest_nullable: + problems.append( + f"column {name!r}: inferred nullable, manifest expects NOT NULL " + f"(loosening nullability is never allowed)" + ) + + return problems + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def _build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + description="Load a single SAS file (XPT or sas7bdat) into Postgres.", + ) + p.add_argument("--config", required=True, type=Path, help="Path to YAML config") + p.add_argument( + "--validate", + action="store_true", + help=( + "Compare inferred schema against .expected.json " + "next to the SAS file; exits nonzero on mismatch." + ), + ) + p.add_argument( + "--dry-run", + action="store_true", + help="Print inferred CREATE TABLE and stop; don't touch Postgres.", + ) + return p + + +def _format_columns_summary(columns: Dict[str, ColumnSpec]) -> str: + lines = [] + for spec in columns.values(): + null = "" if spec.nullable else " NOT NULL" + lines.append(f" {spec.name}: {spec.postgres_type}{null}") + return "\n".join(lines) + + +def main(argv: Optional[List[str]] = None) -> int: + args = _build_argparser().parse_args(argv) + + load_dotenv() + + cfg = load_config(args.config) + + if not cfg.filename.exists(): + print(f"error: SAS file not found: {cfg.filename}", file=sys.stderr) + return 2 + + df, meta = read_sas(cfg.filename) + df = apply_column_filter(df, cfg.include, cfg.exclude) + columns = infer_schema(df, meta) + + if args.validate: + manifest_path = cfg.filename.with_suffix("").with_suffix(".expected.json") + # The above strips .xpt then appends .expected.json, e.g. + # "sample_kitchensink.xpt" -> "sample_kitchensink.expected.json". + problems = validate_against_manifest(columns, manifest_path) + if problems: + print("validation failed:", file=sys.stderr) + for p in problems: + print(f" - {p}", file=sys.stderr) + return 1 + print(f"validation OK ({len(columns)} columns match {manifest_path.name})") + + if args.dry_run: + print(render_create_table(cfg.schemaname, cfg.tablename, columns)) + return 0 + + conn = connect() + conn.autocommit = False + try: + create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists) + inserted = copy_dataframe(conn, cfg.schemaname, cfg.tablename, df, columns) + conn.commit() + except Exception: + conn.rollback() + raise + finally: + conn.close() + + print( + f"loaded {inserted} rows into {cfg.schemaname}.{cfg.tablename} " + f"({len(columns)} columns)" + ) + print("final schema:") + print(_format_columns_summary(columns)) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/generic_loader/requirements.txt b/generic_loader/requirements.txt new file mode 100644 index 0000000..c481d42 --- /dev/null +++ b/generic_loader/requirements.txt @@ -0,0 +1,6 @@ +pandas>=2.0,<2.3 +pyreadstat>=1.2,<1.3 +numpy>=1.24,<2.1 +pyyaml>=6.0,<7.0 +psycopg2-binary>=2.9,<3.0 +python-dotenv>=1.0,<2.0 diff --git a/generic_loader/sample_config.yaml b/generic_loader/sample_config.yaml new file mode 100644 index 0000000..791205a --- /dev/null +++ b/generic_loader/sample_config.yaml @@ -0,0 +1,17 @@ +filename: samples/sample_kitchensink.xpt +schemaname: public +tablename: kitchensink + +# Optional. If set, only these columns are loaded. Mutually exclusive with exclude. +# include: +# - ID +# - INTCOL +# - DATECOL + +# Optional. Columns to drop. +# exclude: +# - ALLNULL + +# What to do if the target table already exists: fail | replace | append +# Defaults to fail. +if_exists: append