"""Per-file SAS-to-Postgres loader.

Library-style functions plus a thin CLI wrapper. Designed so an orchestrator
can wrap the library for directory/batch mode; orchestration is out of scope
here.

Python 3.9 compatible (target is an air-gapped host that currently only has
3.9). ``from __future__ import annotations`` lets us use PEP 585 generics as
annotations; runtime-resolved type uses (dataclass defaults, etc.) stick to
``typing``.

-------------------------------------------------------------------------------
USAGE
-------------------------------------------------------------------------------

Supported inputs:

* ``.sas7bdat`` (read with ``encoding="latin-1"``)
* ``.xpt`` / ``.xport`` (SAS transport files)

1. YAML config
--------------

Every invocation is driven by a YAML file describing one SAS file to load::

    filename: samples/sample_kitchensink.xpt   # required; relative paths are
                                               # resolved against the config
                                               # file's directory when possible
    schemaname: public                         # required
    tablename: kitchensink                     # required

    # Optional. One of: fail | replace | append. Default: fail.
    #   fail    - error out if the target table already exists
    #   replace - DROP and recreate the table from the inferred schema
    #   append  - keep the existing table; pre-flight a schema-compat check,
    #             then COPY the new rows in
    if_exists: append

    # Optional, mutually exclusive. Restrict which columns are loaded.
    # include:
    #   - ID
    #   - INTCOL
    # exclude:
    #   - ALLNULL

2. Database connection
----------------------

The loader uses standard libpq environment variables (read via
``os.environ``)::

    PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE

The CLI calls ``python-dotenv``'s ``load_dotenv()`` at startup, so a local
``.env`` file is picked up automatically. Library callers are responsible for
populating the environment themselves (either call ``load_dotenv()`` or export
the vars) before calling :func:`connect`.

3. Command-line interface
-------------------------

::

    python load_sas.py --config path/to/config.yaml [--validate] [--dry-run]

Flags:

    --config PATH   Required. Path to the YAML config above.
    --validate      Compare the inferred schema against the ``.expected.json``
                    manifest next to the SAS file (see section 4). Exits
                    nonzero on mismatch. Safe to combine with ``--dry-run``.
    --dry-run       Print the inferred ``CREATE TABLE`` SQL and stop. The
                    database is never touched (no connection is opened).

Exit codes:

    0             - success (load completed, or dry-run/validate passed)
    1             - validation failure
    2             - config references a SAS file that does not exist
    other nonzero - uncaught exception (traceback printed); the transaction
                    is rolled back before exit

Typical invocations::

    # Preview the inferred schema without connecting to Postgres.
    python load_sas.py --config sample_config.yaml --dry-run

    # Check the inferred schema against an expected-types manifest.
    python load_sas.py --config sample_config.yaml --validate --dry-run

    # Actually load the data.
    python load_sas.py --config sample_config.yaml

4. Expected-types manifest (``--validate``)
-------------------------------------------

``--validate`` looks for a JSON manifest next to the SAS file, named after the
SAS file's stem with an ``.expected.json`` suffix; e.g.
``samples/sample_kitchensink.xpt`` pairs with
``samples/sample_kitchensink.expected.json``. Each top-level key is a column
name; the value is an object with any of the following keys (the ``#``
annotations are for illustration only; JSON itself does not allow comments)::

    {
      "postgres_type": "BIGINT",          # exact expected type, OR
      "acceptable_types": ["TEXT",        # any-of list of acceptable types
                           "VARCHAR"],
      "nullable": true,                   # default true; false = must be NOT NULL
      "note": "free-form comment"         # ignored by the loader
    }

Type comparison ignores length/precision modifiers and normalizes synonyms
(e.g. ``INT`` == ``INTEGER`` == ``INT4``; ``VARCHAR(10)`` == ``VARCHAR``).
A column that the manifest marks NOT NULL but the inferred schema leaves
nullable is a hard failure; the reverse case (inferred NOT NULL, manifest
nullable) is not flagged here.

5. Library usage
----------------

The CLI is a thin wrapper around composable functions. The preferred pattern
infers the schema from a bounded preview and then streams the whole file
chunk by chunk into ``COPY``; this matters for SAS files with hundreds of
millions of rows::

    from pathlib import Path

    from dotenv import load_dotenv
    from load_sas import (
        load_config, read_sas_preview, iter_sas_chunks, apply_column_filter,
        infer_schema, validate_against_manifest, render_create_table,
        connect, create_table, copy_dataframes,
    )

    load_dotenv()
    cfg = load_config("config.yaml")

    # Schema from a preview slice (bounded by TYPE_INFERENCE_SAMPLE_ROWS).
    preview_df, meta = read_sas_preview(cfg.filename)
    preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
    # Caveat: on a row_limit-ed read, meta.number_rows may report only the
    # rows returned rather than the file's total (main() ignores it for that
    # reason), in which case the "first N of M rows" note is skipped.
    total_rows = getattr(meta, "number_rows", None)
    columns = infer_schema(preview_df, meta, total_rows=total_rows)

    # Optional: preview DDL / validate against a manifest.
    print(render_create_table(cfg.schemaname, cfg.tablename, columns))
    problems = validate_against_manifest(columns, Path("expected.json"))
    assert not problems, problems

    conn = connect()
    conn.autocommit = False
    try:
        create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists)
        chunks = (
            apply_column_filter(df, cfg.include, cfg.exclude)
            for df, _ in iter_sas_chunks(cfg.filename)
        )
        rows = copy_dataframes(conn, cfg.schemaname, cfg.tablename, chunks, columns)
        conn.commit()
    finally:
        conn.close()

For small files (or tests) the legacy one-shot API still works:
:func:`read_sas` returns the whole frame and :func:`copy_dataframe` copies it
in one round trip.

All functions are side-effect free except :func:`connect`,
:func:`create_table`, :func:`copy_dataframe`, and :func:`copy_dataframes`.
Schema inference (:func:`infer_schema`) accepts a ``coerce_chars`` kwarg to
override the module-level ``COERCE_CHAR_COLUMNS`` without mutating global
state.

6. Type inference summary
-------------------------

Priority order used by :func:`infer_schema`:

1. SAS format string (via ``meta.original_variable_types``):
   ``DATETIME*`` -> ``TIMESTAMP``, ``TIME*`` -> ``TIME``,
   ``DATE*`` / ``YYMMDD*`` / ``MMDDYY*`` / ``DDMMYY*`` / ``JULIAN*`` -> ``DATE``.
2. All-null column -> ``TEXT`` (with a note).
3. pandas datetime dtype -> ``TIMESTAMP``.
4. Object columns containing only ``datetime.date`` / ``datetime.datetime``
   -> ``DATE`` or ``TIMESTAMP``.
5. Object columns of strings: if ``COERCE_CHAR_COLUMNS`` is True, the sample
   holds at least ``CHAR_INFERENCE_MIN_VALUES`` non-empty values, and every
   one of them parses cleanly, the column is promoted to ``INTEGER`` /
   ``BIGINT`` / ``DOUBLE PRECISION`` / ``DATE`` / ``TIMESTAMP`` (tried in that
   order); otherwise ``TEXT``.
6. Numeric columns of whole numbers -> ``INTEGER`` (or ``BIGINT`` if any value
   exceeds the int32 range ``NUMERIC_INT_RANGE``); otherwise
   ``DOUBLE PRECISION``.

Type inference scans only the first ``TYPE_INFERENCE_SAMPLE_ROWS`` rows for
performance on large files. The CLI enforces this at read time via
:func:`read_sas_preview`, so the whole file is never materialized just to pick
types. When the sample is known to be smaller than the file (see the
``total_rows`` argument of :func:`infer_schema`), specs carry
``sampled=True`` plus a "type inferred from first N of M rows" note, and the
usual tradeoffs apply: if the first N rows fit ``INTEGER`` but a later row
exceeds int32, or a column had no nulls in the preview but does later in the
file, ``COPY`` will fail mid-stream and the whole transaction rolls back. Set
``TYPE_INFERENCE_SAMPLE_ROWS = None`` to scan every row when exact typing
matters more than speed.

Streaming loads use :func:`iter_sas_chunks` + :func:`copy_dataframes`, which
share one cursor and transaction so a failure mid-file rolls back the whole
load.

7. Tunables
-----------

Module-level knobs at the top of this file:

* ``COERCE_CHAR_COLUMNS`` - promote stringly-typed numerics / dates
  (default True).
* ``CHAR_INFERENCE_MIN_VALUES`` - minimum non-empty sample size before
  char-column coercion is attempted.
* ``NUMERIC_INT_RANGE`` - INTEGER bounds; values outside become ``BIGINT``.
* ``TYPE_INFERENCE_SAMPLE_ROWS`` - cap on rows read for type inference
  (``None`` = scan the whole column).
* ``DEFAULT_CHUNK_ROWS`` - rows per streaming COPY chunk.
"""

from __future__ import annotations

import argparse
import datetime as dt
import io
import json
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import pandas as pd
import psycopg2
import psycopg2.extensions
import pyreadstat
import yaml
from dotenv import load_dotenv

# ---------------------------------------------------------------------------
# Top-level tunables
# ---------------------------------------------------------------------------

COERCE_CHAR_COLUMNS = True
"""If True, try to promote object (string) columns to numeric/date/timestamp
when every non-empty value parses cleanly."""

CHAR_INFERENCE_MIN_VALUES = 3
"""Don't attempt character-column coercion with fewer than this many non-empty
values; too small a sample is easy to mis-infer."""

NUMERIC_INT_RANGE = (-2_147_483_648, 2_147_483_647)
"""INTEGER bounds; anything outside becomes BIGINT."""

TYPE_INFERENCE_SAMPLE_ROWS: Optional[int] = 10_000
"""Cap on rows inspected during per-column type inference.

Also governs how many rows :func:`read_sas_preview` pulls from the file for
dry-run / validate / schema-inference flows. Set to ``None`` to scan every row
(and read the whole file into memory for the preview step - don't do this on
multi-hundred-million row files)."""

DEFAULT_CHUNK_ROWS = 100_000
"""Rows per chunk when streaming a SAS file into ``COPY``.

Larger values mean fewer COPY round-trips but more peak memory per chunk;
smaller values are gentler on memory."""

VALID_IF_EXISTS = ("fail", "replace", "append")


# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------


@dataclass
class LoaderConfig:
    filename: Path
    schemaname: str
    tablename: str
    if_exists: str = "fail"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


@dataclass
class ColumnSpec:
    name: str
    postgres_type: str
    nullable: bool
    sas_format: Optional[str] = None
    source_dtype: Optional[str] = None
    notes: List[str] = field(default_factory=list)
    sampled: bool = False
    """True when the type was inferred from a bounded preview rather than the
    full file.

    A sampled spec carries the usual sampling risks. A later chunk could hold a
    value that exceeds the inferred integer range, or a null in a column the
    preview showed as non-null; both surface as mid-``COPY`` failures. Values
    in a coerced character column that later fail to parse (e.g. ``N/A`` in a
    column inferred as ``INTEGER`` or ``DATE``) are converted to NULL by
    :func:`_prepare_for_copy` rather than rejected, so they only fail the load
    if the column is NOT NULL."""


# ---------------------------------------------------------------------------
# Custom exceptions
# ---------------------------------------------------------------------------


class TableExistsError(RuntimeError):
    """Raised when if_exists=fail and the target table already exists."""


class SchemaCompatibilityError(RuntimeError):
    """Raised when if_exists=append and the incoming schema is not compatible
    with the existing table."""


class ValidationError(RuntimeError):
    """Raised when --validate detects a mismatch against the manifest."""


# ---------------------------------------------------------------------------
# Connection
# ---------------------------------------------------------------------------


def connect() -> psycopg2.extensions.connection:
    """Open a psycopg2 connection using standard libpq env vars.

    Assumes ``.env`` has already been loaded (the CLI does this before calling).
    Orchestrators that wrap this module should either call ``load_dotenv()``
    themselves or ensure the env vars are set.
    """
    conn = psycopg2.connect(
        host=os.environ.get("PGHOST"),
        port=os.environ.get("PGPORT"),
        user=os.environ.get("PGUSER"),
        password=os.environ.get("PGPASSWORD"),
        dbname=os.environ.get("PGDATABASE"),
    )
    return conn


# ---------------------------------------------------------------------------
# Config loading
# ---------------------------------------------------------------------------


def load_config(path: Path) -> LoaderConfig:
    """Parse and validate the YAML config at ``path``."""
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    if not isinstance(raw, dict):
        raise ValueError(f"Config at {path} must be a YAML mapping at the top level.")

    missing = [k for k in ("filename", "schemaname", "tablename") if k not in raw]
    if missing:
        raise ValueError(f"Config {path} missing required keys: {', '.join(missing)}")

    filename = Path(raw["filename"])
    if not filename.is_absolute():
        candidate = path.parent / filename
        # Resolve against the config's directory when that path exists;
        # otherwise leave it relative to the current working directory.
        if candidate.exists():
            filename = candidate.resolve()

    schemaname = str(raw["schemaname"])
    tablename = str(raw["tablename"])
    if_exists = str(raw.get("if_exists", "fail")).lower()
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(
            f"Config {path}: if_exists={if_exists!r} is not one of {VALID_IF_EXISTS}"
        )

    include = raw.get("include")
    exclude = raw.get("exclude")
    if include is not None and exclude is not None:
        raise ValueError(
            f"Config {path}: 'include' and 'exclude' are mutually exclusive; set at most one."
        )
    if include is not None and not isinstance(include, list):
        raise ValueError(f"Config {path}: 'include' must be a list of column names.")
    if exclude is not None and not isinstance(exclude, list):
        raise ValueError(f"Config {path}: 'exclude' must be a list of column names.")

    return LoaderConfig(
        filename=filename,
        schemaname=schemaname,
        tablename=tablename,
        if_exists=if_exists,
        include=[str(c) for c in include] if include is not None else None,
        exclude=[str(c) for c in exclude] if exclude is not None else None,
    )


# ---------------------------------------------------------------------------
# Reader
# ---------------------------------------------------------------------------


def _sas_reader(path: Path) -> Tuple[Any, Dict[str, Any]]:
    """Return ``(pyreadstat_reader, extra_kwargs)`` for ``path``.

    Invariants (learned the hard way while building the sample generator):

    * ``.xpt`` / ``.xport`` - no encoding arg; pyreadstat is flaky about
      encoding on XPORT files it wrote itself.
    * ``.sas7bdat`` - explicit ``encoding="latin-1"`` per colleague guidance.
    """
    suffix = Path(path).suffix.lower()
    if suffix in (".xpt", ".xport"):
        return pyreadstat.read_xport, {}
    if suffix == ".sas7bdat":
        return pyreadstat.read_sas7bdat, {"encoding": "latin-1"}
    raise ValueError(f"Unsupported SAS file extension: {suffix}")


def read_sas(path: Path) -> Tuple[pd.DataFrame, Any]:
    """Read an entire SAS file into memory. Only safe for small files.

    Kept for backward compatibility and tests; the CLI now uses
    :func:`read_sas_preview` + :func:`iter_sas_chunks` so it never
    materializes the whole frame at once.
    """
    reader, kwargs = _sas_reader(path)
    return reader(str(Path(path)), **kwargs)


def read_sas_preview(
    path: Path,
    *,
    rows: Optional[int] = None,
) -> Tuple[pd.DataFrame, Any]:
    """Read the first ``rows`` records from ``path`` plus its metadata.

    Defaults to ``TYPE_INFERENCE_SAMPLE_ROWS`` when ``rows`` is not given.
    Passing ``rows=None`` with ``TYPE_INFERENCE_SAMPLE_ROWS=None`` reads the
    whole file (pyreadstat treats ``row_limit=0`` as unlimited).
    """
    reader, kwargs = _sas_reader(path)
    effective = rows if rows is not None else TYPE_INFERENCE_SAMPLE_ROWS
    row_limit = int(effective) if effective else 0
    return reader(str(Path(path)), row_limit=row_limit, **kwargs)


def iter_sas_chunks(
    path: Path,
    *,
    chunksize: int = DEFAULT_CHUNK_ROWS,
):
    """Yield ``(df_chunk, meta)`` tuples for streaming loads.

    Thin wrapper over ``pyreadstat.read_file_in_chunks`` that picks the right
    underlying reader by extension and threads through our encoding defaults.
    """
    reader, kwargs = _sas_reader(path)
    yield from pyreadstat.read_file_in_chunks(
        reader, str(Path(path)), chunksize=chunksize, **kwargs
    )


# ---------------------------------------------------------------------------
# Column filtering
# ---------------------------------------------------------------------------


def apply_column_filter(
    df: pd.DataFrame,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
) -> pd.DataFrame:
    """Restrict ``df`` to the requested columns.

    Names missing from the frame raise a clear error rather than being silently
    dropped.
    """
    if include is not None and exclude is not None:
        raise ValueError("include and exclude are mutually exclusive.")
    if include is not None:
        missing = [c for c in include if c not in df.columns]
        if missing:
            raise ValueError(f"include references unknown columns: {missing}")
        return df.loc[:, list(include)].copy()
    if exclude is not None:
        missing = [c for c in exclude if c not in df.columns]
        if missing:
            raise ValueError(f"exclude references unknown columns: {missing}")
        return df.drop(columns=list(exclude)).copy()
    return df.copy()


# ---------------------------------------------------------------------------
# Type inference
# ---------------------------------------------------------------------------

_DATE_FORMAT_PREFIXES = ("DATE", "YYMMDD", "MMDDYY", "DDMMYY", "JULIAN")


def _format_driven_type(sas_format: Optional[str]) -> Optional[str]:
    """Return a Postgres type inferred from the SAS format string, or None if
    the format doesn't pin it down."""
    if not sas_format:
        return None
    fmt = sas_format.upper().lstrip()
    # DATETIME must be checked before DATE since "DATETIME20." starts with "DATE".
    if fmt.startswith("DATETIME"):
        return "TIMESTAMP"
    if fmt.startswith("TIME"):
        return "TIME"
    for prefix in _DATE_FORMAT_PREFIXES:
        if fmt.startswith(prefix):
            return "DATE"
    return None


def _all_null(series: pd.Series) -> bool:
    if pd.api.types.is_object_dtype(series):
        # Object columns: None, NaN and empty strings all count as missing.
        return bool(_char_missing_mask(series).all())
    return bool(series.isna().all())


def _char_missing_mask(series: pd.Series) -> pd.Series:
    """Boolean mask of missing values in an object column (None, NaN, or an
    empty string)."""
    return series.map(
        lambda v: v is None
        or (isinstance(v, float) and pd.isna(v))
        or (isinstance(v, str) and v == "")
    )


def _is_nullable(series: pd.Series) -> bool:
    """True if the column has at least one missing value."""
    if pd.api.types.is_object_dtype(series):
        return bool(_char_missing_mask(series).any())
    return bool(series.isna().any())


def _numeric_int_target(series: pd.Series) -> Optional[str]:
    """Given a numeric (float64) series, if every non-null value is a whole
    number, return INTEGER or BIGINT depending on range; else None."""
    nonnull = series.dropna()
    if nonnull.empty:
        return None
    # Whole-number test. ``inf % 1`` evaluates to NaN, so infinities fail it
    # and the column falls back to DOUBLE PRECISION.
    try:
        whole = ((nonnull % 1) == 0).all()
    except TypeError:
        return None
    if not whole:
        return None
    lo, hi = NUMERIC_INT_RANGE
    vmin = nonnull.min()
    vmax = nonnull.max()
    if lo <= vmin and vmax <= hi:
        return "INTEGER"
    return "BIGINT"


def _object_is_dates(series: pd.Series) -> Tuple[bool, bool]:
    """Return (all-date-like, any-datetime).

    If every non-null value is a ``datetime.date`` / ``datetime.datetime`` /
    ``pd.Timestamp``, return True plus whether at least one carries a time
    component."""
    nonnull = series.dropna()
    if nonnull.empty:
        return False, False
    any_datetime = False
    for v in nonnull:
        if isinstance(v, (dt.datetime, pd.Timestamp)):
            any_datetime = True
            continue
        if isinstance(v, dt.date):
            continue
        return False, False
    return True, any_datetime


def _try_int_coerce(values: List[str]) -> Optional[str]:
    """If every value parses as an int, return INTEGER/BIGINT, else None."""
    ints: List[int] = []
    for v in values:
        s = v.strip()
        try:
            ints.append(int(s))
        except ValueError:
            return None
    if not ints:
        return None
    lo, hi = NUMERIC_INT_RANGE
    if all(lo <= i <= hi for i in ints):
        return "INTEGER"
    return "BIGINT"


def _try_float_coerce(values: List[str]) -> bool:
    for v in values:
        try:
            float(v)
        except ValueError:
            return False
    return True


def _try_date_coerce(values: List[str]) -> bool:
    for v in values:
        try:
            dt.date.fromisoformat(v)
        except (ValueError, TypeError):
            return False
    return True


def _try_datetime_coerce(values: List[str]) -> bool:
    for v in values:
        try:
            dt.datetime.fromisoformat(v)
        except (ValueError, TypeError):
            return False
    return True


def _infer_char_type(series: pd.Series, coerce: Optional[bool] = None) -> str:
    """Object/string column inference. Returns a Postgres type string.

    ``coerce`` overrides ``COERCE_CHAR_COLUMNS`` for this call; ``None`` means
    use the module-level value.
    """
    if coerce is None:
        coerce = COERCE_CHAR_COLUMNS
    mask = _char_missing_mask(series)
    nonempty = [str(v) for v in series[~mask].tolist()]
    if not coerce or len(nonempty) < CHAR_INFERENCE_MIN_VALUES:
        return "TEXT"
    int_guess = _try_int_coerce(nonempty)
    if int_guess is not None:
        return int_guess
    if _try_float_coerce(nonempty):
        return "DOUBLE PRECISION"
    if _try_date_coerce(nonempty):
        return "DATE"
    if _try_datetime_coerce(nonempty):
        return "TIMESTAMP"
    return "TEXT"


def infer_schema(
    df: pd.DataFrame,
    meta: Any,
    *,
    coerce_chars: Optional[bool] = None,
    total_rows: Optional[int] = None,
) -> Dict[str, ColumnSpec]:
    """Infer a Postgres column spec for each column in ``df``.

    ``meta`` is the pyreadstat metadata object; we read
    ``meta.original_variable_types`` (a dict keyed by column name) for
    format-driven date/time/timestamp inference.

    ``coerce_chars`` overrides the module-level ``COERCE_CHAR_COLUMNS`` for
    this call only; ``None`` (the default) uses the module-level value at call
    time. The flag is passed straight to the char-inference helper, so no
    global state is mutated.

    ``total_rows`` lets callers who already sampled the frame (e.g. via
    :func:`read_sas_preview`) report the real file size in the per-column
    "inferred from first N of M rows" note. Falls back to ``len(df)``.
    """
    if coerce_chars is None:
        coerce_chars = COERCE_CHAR_COLUMNS
    original_formats: Dict[str, str] = dict(
        getattr(meta, "original_variable_types", {}) or {}
    )

    # Row-walking type probes run on a bounded head slice; nullability and the
    # all-null check still see every row so NOT NULL declarations stay honest.
    df_rows = len(df)
    effective_total = total_rows if total_rows is not None else df_rows
    if TYPE_INFERENCE_SAMPLE_ROWS is not None and df_rows > TYPE_INFERENCE_SAMPLE_ROWS:
        sample_df = df.head(TYPE_INFERENCE_SAMPLE_ROWS)
        sample_size = TYPE_INFERENCE_SAMPLE_ROWS
    else:
        sample_df = df
        sample_size = df_rows
    sampled = sample_size < effective_total

    out: Dict[str, ColumnSpec] = {}
    for col in df.columns:
        series = df[col]
        sample_series = sample_df[col]
        sas_format = original_formats.get(col)
        notes: List[str] = []

        pg_type = _format_driven_type(sas_format)
        if pg_type is None:
            if _all_null(series):
                pg_type = "TEXT"
                notes.append("all-null column; defaulting to TEXT")
            elif pd.api.types.is_datetime64_any_dtype(series):
                pg_type = "TIMESTAMP"
            elif pd.api.types.is_object_dtype(series):
                is_dates, any_dt = _object_is_dates(sample_series)
                if is_dates:
                    pg_type = "TIMESTAMP" if any_dt else "DATE"
                else:
                    pg_type = _infer_char_type(sample_series, coerce=coerce_chars)
            elif pd.api.types.is_numeric_dtype(series):
                int_target = _numeric_int_target(sample_series)
                if int_target is not None:
                    pg_type = int_target
                else:
                    pg_type = "DOUBLE PRECISION"
            else:
                pg_type = "TEXT"
                notes.append(f"unhandled dtype {series.dtype}; defaulting to TEXT")

        if sampled:
            notes.append(
                f"type inferred from first {sample_size:,} of "
                f"{effective_total:,} rows"
            )

        nullable = _is_nullable(series)
        out[col] = ColumnSpec(
            name=col,
            postgres_type=pg_type,
            nullable=nullable,
            sas_format=sas_format,
            source_dtype=str(series.dtype),
            notes=notes,
            sampled=sampled,
        )
    return out


# ---------------------------------------------------------------------------
# Table management
# ---------------------------------------------------------------------------


def _quote_ident(ident: str) -> str:
    """Quote a Postgres identifier.

    ``psycopg2.sql.Identifier`` only arrived in psycopg2 2.7 (and multi-part,
    schema-qualified names in 2.8); we quote by hand to stay driver-simple.
    """
    return '"' + ident.replace('"', '""') + '"'


def _qualified(schema: str, table: str) -> str:
    return f"{_quote_ident(schema)}.{_quote_ident(table)}"


def _table_exists(conn, schema: str, table: str) -> bool:
    with conn.cursor() as cur:
        cur.execute(
            "SELECT 1 FROM information_schema.tables "
            "WHERE table_schema = %s AND table_name = %s",
            (schema, table),
        )
        return cur.fetchone() is not None


def render_create_table(schema: str, table: str, columns: Dict[str, ColumnSpec]) -> str:
    """Render the ``CREATE TABLE`` statement for ``columns``."""
    lines = []
    for spec in columns.values():
        null_clause = "" if spec.nullable else " NOT NULL"
        lines.append(f"    {_quote_ident(spec.name)} {spec.postgres_type}{null_clause}")
    body = ",\n".join(lines)
    return f"CREATE TABLE {_qualified(schema, table)} (\n{body}\n);"


def _create_table_sql(conn, schema: str, table: str, columns: Dict[str, ColumnSpec]) -> None:
    sql = render_create_table(schema, table, columns)
    with conn.cursor() as cur:
        cur.execute(sql)


def _drop_table(conn, schema: str, table: str) -> None:
    with conn.cursor() as cur:
        cur.execute(f"DROP TABLE {_qualified(schema, table)}")


# Normalization table: map both loader-emitted and Postgres-reported type
# strings to a single canonical family name. Length/precision modifiers such
# as VARCHAR(n) and NUMERIC(p,s) are stripped by _normalize_type before lookup.
_TYPE_NORMALIZATION: Dict[str, str] = {
    "INTEGER": "integer",
    "INT": "integer",
    "INT4": "integer",
    "BIGINT": "bigint",
    "INT8": "bigint",
    "SMALLINT": "smallint",
    "INT2": "smallint",
    "DOUBLE PRECISION": "double precision",
    "FLOAT8": "double precision",
    "REAL": "real",
    "FLOAT4": "real",
    "NUMERIC": "numeric",
    "DECIMAL": "numeric",
    "TEXT": "text",
    "VARCHAR": "character varying",
    "CHARACTER VARYING": "character varying",
    "CHAR": "character",
    "CHARACTER": "character",
    "BPCHAR": "character",
    "BOOLEAN": "boolean",
    "BOOL": "boolean",
    "DATE": "date",
    "TIMESTAMP": "timestamp without time zone",
    "TIMESTAMP WITHOUT TIME ZONE": "timestamp without time zone",
    "TIMESTAMPTZ": "timestamp with time zone",
    "TIMESTAMP WITH TIME ZONE": "timestamp with time zone",
    "TIME": "time without time zone",
    "TIME WITHOUT TIME ZONE": "time without time zone",
    "TIMETZ": "time with time zone",
    "TIME WITH TIME ZONE": "time with time zone",
}


def _normalize_type(pg_type: str) -> str:
    """Strip length/precision modifiers and map to canonical family."""
    import re

    stripped = pg_type.strip().upper()
    # Remove (n) / (p,s) modifiers, including ones before a space-separated tail.
    # Examples: "VARCHAR(10)" -> "VARCHAR";
    #           "TIMESTAMP(6) WITHOUT TIME ZONE" -> "TIMESTAMP WITHOUT TIME ZONE"
    stripped = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", stripped).strip()
    # Collapse doubled whitespace after paren removal.
    stripped = re.sub(r"\s+", " ", stripped)
    return _TYPE_NORMALIZATION.get(stripped, stripped.lower())


def _assert_schema_compatible(
    conn, schema: str, table: str, columns: Dict[str, ColumnSpec]
) -> None:
    """Pre-flight check for if_exists=append.

    See plan section on option B.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT column_name, data_type, is_nullable "
            "FROM information_schema.columns "
            "WHERE table_schema = %s AND table_name = %s",
            (schema, table),
        )
        existing = {row[0]: (row[1], row[2]) for row in cur.fetchall()}

    mismatches: List[str] = []
    warnings: List[str] = []
    for name, spec in columns.items():
        if name not in existing:
            mismatches.append(
                f"column {name!r} not present in target {schema}.{table}"
            )
            continue
        target_type, target_nullable = existing[name]
        inferred_norm = _normalize_type(spec.postgres_type)
        target_norm = _normalize_type(target_type)
        if inferred_norm != target_norm:
            mismatches.append(
                f"column {name!r}: inferred {spec.postgres_type} "
                f"(normalized {inferred_norm!r}) but target is {target_type} "
                f"(normalized {target_norm!r})"
            )
        target_is_notnull = (target_nullable == "NO")
        if spec.nullable and target_is_notnull:
            warnings.append(
                f"column {name!r}: incoming allows NULLs but target is NOT NULL; "
                "COPY will fail if any NULLs appear"
            )

    for w in warnings:
        print(f"[warn] {w}", file=sys.stderr)
    if mismatches:
        raise SchemaCompatibilityError(
            "append-mode schema compatibility check failed:\n - "
            + "\n - ".join(mismatches)
        )


def assert_schema_compatible(
    conn,
    schema_name: str,
    table_name: str,
    columns: Dict[str, ColumnSpec],
) -> None:
    """Public wrapper around :func:`_assert_schema_compatible`.

    Intended for orchestrators (e.g. the folder loader) that append multiple
    files into one table and need to re-run the same compatibility check that
    ``if_exists=append`` performs internally. Raises
    :class:`SchemaCompatibilityError` on mismatch.
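
    The type comparison this wrapper delegates to can be illustrated without a
    database. The sketch below is a simplified stand-in, not this module's
    code: ``SYNONYMS``, ``canon`` and ``compat_problems`` are hypothetical
    names, the synonym table is trimmed, and the nullability warning is left
    out. ``existing`` mimics the ``(data_type, is_nullable)`` pairs the real
    check reads from ``information_schema.columns``.

```python
from typing import Dict, List, Tuple

# Hypothetical, trimmed-down synonym table (illustration only).
SYNONYMS = {
    "INT": "integer",
    "INT4": "integer",
    "INTEGER": "integer",
    "VARCHAR": "character varying",
    "CHARACTER VARYING": "character varying",
    "TEXT": "text",
}


def canon(pg_type: str) -> str:
    # Drop one "(n)" / "(p,s)" modifier, collapse whitespace, map synonyms.
    head, _, tail = pg_type.strip().upper().partition("(")
    rest = tail.partition(")")[2]
    name = " ".join((head + rest).split())
    return SYNONYMS.get(name, name.lower())


def compat_problems(
    inferred: Dict[str, str],
    existing: Dict[str, Tuple[str, str]],  # name -> (data_type, is_nullable)
) -> List[str]:
    problems: List[str] = []
    for name, pg_type in inferred.items():
        if name not in existing:
            problems.append(f"{name}: not present in target")
        elif canon(pg_type) != canon(existing[name][0]):
            problems.append(f"{name}: {pg_type} vs {existing[name][0]}")
    return problems


existing = {"ID": ("integer", "NO"), "NAME": ("character varying", "YES")}
print(compat_problems({"ID": "INT4", "NAME": "VARCHAR(10)", "EXTRA": "TEXT"}, existing))
# ['EXTRA: not present in target']
```

    A length modifier such as ``VARCHAR(10)`` does not cause a mismatch, while a
    different type family (``BIGINT`` vs ``integer``) or a missing target
    column does.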
    """
    _assert_schema_compatible(conn, schema_name, table_name, columns)


def create_table(
    conn,
    schema_name: str,
    table_name: str,
    columns: Dict[str, ColumnSpec],
    if_exists: str,
) -> None:
    """Create (or verify) the target table according to ``if_exists``."""
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(f"if_exists must be one of {VALID_IF_EXISTS}, got {if_exists!r}")

    exists = _table_exists(conn, schema_name, table_name)
    if exists:
        if if_exists == "fail":
            raise TableExistsError(
                f"Table {schema_name}.{table_name} already exists and if_exists=fail"
            )
        if if_exists == "replace":
            _drop_table(conn, schema_name, table_name)
            _create_table_sql(conn, schema_name, table_name, columns)
            return
        if if_exists == "append":
            _assert_schema_compatible(conn, schema_name, table_name, columns)
            return
    else:
        _create_table_sql(conn, schema_name, table_name, columns)


# ---------------------------------------------------------------------------
# COPY loading
# ---------------------------------------------------------------------------


def _seconds_to_time(v: Any) -> Optional[dt.time]:
    """Convert a SAS TIME value (seconds since midnight) to ``datetime.time``.

    Values that are already ``time`` / ``datetime`` pass through; missing or
    unparseable values become None.
    """
    if v is None:
        return None
    if isinstance(v, float) and pd.isna(v):
        return None
    if isinstance(v, dt.time):
        return v
    if isinstance(v, (dt.datetime, pd.Timestamp)):
        return v.time() if not pd.isna(v) else None
    try:
        total = int(round(float(v)))
    except (TypeError, ValueError):
        return None
    h, rem = divmod(total, 3600)
    m, s = divmod(rem, 60)
    # Clamp; TIME8. is always within a day.
    h = max(0, min(h, 23))
    return dt.time(h, m, s)


def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame:
    """Materialize a copy of ``df`` with each column in the right shape for
    ``to_csv`` so the CSV lands as valid input for the target Postgres type.
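
    Integer targets go through pandas' nullable ``Int64`` dtype because a
    float64 column that merely holds whole numbers serializes as ``1.0``,
    which Postgres rejects as ``INTEGER`` input, whereas ``Int64`` writes
    ``1`` and still renders missing values as an empty field (which the
    ``NULL ''`` option in ``COPY`` loads as SQL NULL). A minimal pandas-only
    sketch; the frame and the ``to_copy_csv`` helper are illustrative.

```python
import io

import pandas as pd

# A SAS numeric column arrives as float64, with NaN for missing values.
df = pd.DataFrame({"id": [1.0, None, 3.0], "label": ["a", "b", "c"]})


def to_copy_csv(frame: pd.DataFrame) -> list:
    # Same to_csv options the loader uses: no index, no header, missing -> empty.
    buf = io.StringIO()
    frame.to_csv(buf, index=False, header=False, na_rep="")
    return buf.getvalue().splitlines()


as_float = to_copy_csv(df)                                    # ['1.0,a', ',b', '3.0,c']
as_int = to_copy_csv(df.assign(id=df["id"].astype("Int64")))  # ['1,a', ',b', '3,c']
```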
    """
    out = pd.DataFrame(index=df.index)
    for name, spec in columns.items():
        series = df[name]
        pg = spec.postgres_type.upper()

        if pg in ("INTEGER", "BIGINT", "SMALLINT"):
            if pd.api.types.is_object_dtype(series):
                series = pd.to_numeric(
                    series.replace({"": None}), errors="coerce"
                )
            out[name] = series.astype("Int64")

        elif pg in ("DOUBLE PRECISION", "REAL", "NUMERIC"):
            if pd.api.types.is_object_dtype(series):
                series = pd.to_numeric(
                    series.replace({"": None}), errors="coerce"
                )
            out[name] = series.astype("float64")

        elif pg == "DATE":
            if pd.api.types.is_datetime64_any_dtype(series):
                out[name] = series.dt.date
            elif pd.api.types.is_object_dtype(series):
                def _to_date(v: Any) -> Optional[dt.date]:
                    if v is None or (isinstance(v, float) and pd.isna(v)):
                        return None
                    if isinstance(v, dt.datetime):
                        return v.date()
                    if isinstance(v, dt.date):
                        return v
                    if isinstance(v, str):
                        if v == "":
                            return None
                        try:
                            return dt.date.fromisoformat(v)
                        except ValueError:
                            return None
                    return None

                out[name] = series.map(_to_date)
            else:
                out[name] = series

        elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"):
            if pd.api.types.is_datetime64_any_dtype(series):
                out[name] = series
            elif pd.api.types.is_object_dtype(series):
                def _to_dt(v: Any) -> Optional[dt.datetime]:
                    if v is None or (isinstance(v, float) and pd.isna(v)):
                        return None
                    if isinstance(v, dt.datetime):
                        return v
                    if isinstance(v, dt.date):
                        return dt.datetime(v.year, v.month, v.day)
                    if isinstance(v, pd.Timestamp):
                        return v.to_pydatetime() if not pd.isna(v) else None
                    if isinstance(v, str):
                        if v == "":
                            return None
                        try:
                            return dt.datetime.fromisoformat(v)
                        except ValueError:
                            return None
                    return None

                out[name] = series.map(_to_dt)
            else:
                out[name] = series

        elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"):
            out[name] = series.map(_seconds_to_time)

        elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"):
            # Leave empty strings as "" so `NULL ''` in COPY turns them into NULL.
            def _to_str(v: Any) -> Any:
                if v is None:
                    return ""
                if isinstance(v, float) and pd.isna(v):
                    return ""
                return str(v)

            out[name] = series.map(_to_str)

        elif pg == "BOOLEAN":
            out[name] = series.astype("boolean") if series.dtype != object else series

        else:
            out[name] = series
    return out


def copy_dataframes(
    conn,
    schema_name: str,
    table_name: str,
    dfs: Iterable[pd.DataFrame],
    columns: Dict[str, ColumnSpec],
) -> int:
    """Stream an iterable of DataFrames into the target table, one ``COPY``
    statement per chunk.

    All chunks share a cursor and transaction, so a failure mid-stream leaves
    the transaction aborted and the caller's rollback (or closing the
    connection without committing) discards the whole load. Empty chunks are
    skipped. Returns the total rows inserted.
    """
    col_list = ", ".join(_quote_ident(name) for name in columns.keys())
    sql = (
        f"COPY {_qualified(schema_name, table_name)} ({col_list}) "
        f"FROM STDIN WITH (FORMAT csv, NULL '')"
    )
    total = 0
    with conn.cursor() as cur:
        for df in dfs:
            if df.empty:
                continue
            prepared = _prepare_for_copy(df, columns)
            buf = io.StringIO()
            prepared.to_csv(
                buf,
                index=False,
                header=False,
                na_rep="",
                date_format="%Y-%m-%d %H:%M:%S",
            )
            buf.seek(0)
            cur.copy_expert(sql, buf)
            total += len(prepared)
    return total


def copy_dataframe(
    conn,
    schema_name: str,
    table_name: str,
    df: pd.DataFrame,
    columns: Dict[str, ColumnSpec],
) -> int:
    """Stream ``df`` into Postgres via ``COPY ... FROM STDIN``.

    Convenience wrapper around :func:`copy_dataframes` for single-frame
    callers. Returns the number of rows inserted.
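
    Single-frame callers simply take the chunked path with a one-element list,
    so the loop below shows what happens underneath. It is an illustrative
    simplification of :func:`copy_dataframes` (no ``_prepare_for_copy``, no
    ``date_format``), and ``RecordingCursor`` is a hypothetical stand-in for a
    psycopg2 cursor so the sketch runs without a database.

```python
import io
from typing import Iterable, List

import pandas as pd


class RecordingCursor:
    # Hypothetical stand-in for a psycopg2 cursor: keeps each COPY payload.
    def __init__(self) -> None:
        self.payloads: List[str] = []

    def copy_expert(self, sql: str, buf: io.StringIO) -> None:
        self.payloads.append(buf.read())


def stream(cur: RecordingCursor, sql: str, dfs: Iterable[pd.DataFrame]) -> int:
    # Same shape as the copy_dataframes loop: skip empty chunks, issue one
    # COPY per chunk on a shared cursor, and return the summed row count.
    total = 0
    for df in dfs:
        if df.empty:
            continue
        buf = io.StringIO()
        df.to_csv(buf, index=False, header=False, na_rep="")
        buf.seek(0)
        cur.copy_expert(sql, buf)
        total += len(df)
    return total


cur = RecordingCursor()
chunks = [pd.DataFrame({"x": [1, 2]}), pd.DataFrame({"x": []}), pd.DataFrame({"x": [3]})]
rows = stream(cur, "COPY t (x) FROM STDIN WITH (FORMAT csv, NULL '')", chunks)
```

    One ``COPY`` is issued per non-empty chunk, all on the same cursor, and the
    per-chunk row counts are summed.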
    """
    return copy_dataframes(conn, schema_name, table_name, [df], columns)


# ---------------------------------------------------------------------------
# Manifest validation
# ---------------------------------------------------------------------------


def _match_manifest_type(
    inferred: str, manifest_entry: Dict[str, Any]
) -> bool:
    """Return True if ``inferred`` satisfies the entry's type constraint.

    ``postgres_type`` (exact match) takes precedence over
    ``acceptable_types`` (any-of). An entry with neither key never matches.
    """
    inferred_norm = _normalize_type(inferred)
    if "postgres_type" in manifest_entry:
        return inferred_norm == _normalize_type(manifest_entry["postgres_type"])
    if "acceptable_types" in manifest_entry:
        return any(
            inferred_norm == _normalize_type(t)
            for t in manifest_entry["acceptable_types"]
        )
    return False


def validate_against_manifest(
    inferred: Dict[str, ColumnSpec],
    manifest_path: Path,
) -> List[str]:
    """Compare the inferred schema against the expected-types manifest.

    Returns a list of human-readable problem strings; an empty list means OK.
    """
    manifest_path = Path(manifest_path)
    if not manifest_path.exists():
        return [f"manifest not found: {manifest_path}"]
    with manifest_path.open("r", encoding="utf-8") as f:
        manifest = json.load(f)

    problems: List[str] = []
    only_in_inferred = set(inferred) - set(manifest)
    only_in_manifest = set(manifest) - set(inferred)
    if only_in_inferred:
        problems.append(
            f"columns in inferred but not manifest: {sorted(only_in_inferred)}"
        )
    if only_in_manifest:
        problems.append(
            f"columns in manifest but not inferred: {sorted(only_in_manifest)}"
        )

    for name, spec in inferred.items():
        entry = manifest.get(name)
        if entry is None:
            continue
        if not _match_manifest_type(spec.postgres_type, entry):
            expected = entry.get("postgres_type") or entry.get("acceptable_types")
            problems.append(
                f"column {name!r}: inferred {spec.postgres_type!r}, "
                f"manifest expected {expected!r}"
            )
        manifest_nullable = bool(entry.get("nullable", True))
        if spec.nullable and not manifest_nullable:
            problems.append(
                f"column {name!r}: inferred nullable, manifest expects NOT NULL"
            )
    return problems
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def _build_argparser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(
        description="Load a single SAS file (XPT or sas7bdat) into Postgres.",
    )
    p.add_argument(
        "--config", required=True, type=Path, help="Path to YAML config"
    )
    p.add_argument(
        "--validate",
        action="store_true",
        help=(
            "Compare inferred schema against <stem>.expected.json "
            "next to the SAS file; exits nonzero on mismatch."
        ),
    )
    p.add_argument(
        "--dry-run",
        action="store_true",
        help="Print inferred CREATE TABLE and stop; don't touch Postgres.",
    )
    return p


def _format_columns_summary(columns: Dict[str, ColumnSpec]) -> str:
    lines = []
    for spec in columns.values():
        null = "" if spec.nullable else " NOT NULL"
        lines.append(f" {spec.name}: {spec.postgres_type}{null}")
    return "\n".join(lines)


def main(argv: Optional[List[str]] = None) -> int:
    args = _build_argparser().parse_args(argv)
    load_dotenv()
    cfg = load_config(args.config)

    if not cfg.filename.exists():
        print(f"error: SAS file not found: {cfg.filename}", file=sys.stderr)
        return 2

    # Schema inference uses a bounded preview read so we never load a file
    # with hundreds of millions of rows into memory just to pick types.
    # NB: ``meta.number_rows`` on a ``row_limit``-ed read reflects the rows
    # returned, not the file's total, so we don't trust it here.
    preview_df, meta = read_sas_preview(cfg.filename)
    preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
    columns = infer_schema(preview_df, meta)

    if args.validate:
        # Swap only the final suffix for ".expected.json", e.g.
        # "sample_kitchensink.xpt" -> "sample_kitchensink.expected.json".
        # ``with_suffix("").with_suffix(...)`` would also replace a dot in
        # the stem ("data.v2.xpt" -> "data.expected.json").
        manifest_path = cfg.filename.with_name(
            cfg.filename.stem + ".expected.json"
        )
        problems = validate_against_manifest(columns, manifest_path)
        if problems:
            print("validation failed:", file=sys.stderr)
            for p in problems:
                print(f" - {p}", file=sys.stderr)
            return 1
        print(
            f"validation OK ({len(columns)} columns match "
            f"{manifest_path.name})"
        )

    if args.dry_run:
        print(render_create_table(cfg.schemaname, cfg.tablename, columns))
        return 0

    # Release the preview frame before opening the stream so the GC can
    # reclaim it while we hold a Postgres transaction open.
    del preview_df

    def _filtered_chunks():
        seen = 0
        for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename):
            chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude)
            seen += len(chunk_df)
            print(f" streaming... {seen:,} rows", file=sys.stderr)
            yield chunk_df

    conn = connect()
    conn.autocommit = False
    try:
        create_table(
            conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists
        )
        inserted = copy_dataframes(
            conn, cfg.schemaname, cfg.tablename, _filtered_chunks(), columns
        )
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

    print(
        f"loaded {inserted} rows into {cfg.schemaname}.{cfg.tablename} "
        f"({len(columns)} columns)"
    )
    print("final schema:")
    print(_format_columns_summary(columns))
    return 0


if __name__ == "__main__":
    sys.exit(main())