Adding generic loader
This commit is contained in:
commit
f681f1012a
5
generic_loader/.env.example
Normal file
5
generic_loader/.env.example
Normal file
@ -0,0 +1,5 @@
|
||||
PGHOST=localhost
|
||||
PGPORT=5432
|
||||
PGUSER=
|
||||
PGPASSWORD=
|
||||
PGDATABASE=
|
||||
3
generic_loader/.gitignore
vendored
Normal file
3
generic_loader/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
/.venv
|
||||
/samples
|
||||
/.env
|
||||
380
generic_loader/generate_sample_sas.py
Normal file
380
generic_loader/generate_sample_sas.py
Normal file
@ -0,0 +1,380 @@
|
||||
"""Generate a kitchen-sink SAS XPORT file plus an expected-types manifest.
|
||||
|
||||
Running this script produces two files under samples/:
|
||||
- sample_kitchensink.xpt the SAS XPORT test fixture
|
||||
- sample_kitchensink.expected.json ground-truth Postgres types for the loader
|
||||
|
||||
Tune behavior via the top-level constants below.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
import json
|
||||
import string
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyreadstat
|
||||
|
||||
# Seed passed to np.random.default_rng() in main() so the fixture is reproducible.
SEED = 42
# Number of rows generated for every column.
N_ROWS = 1000
# Fraction of rows blanked out in each column that carries missing values.
NULL_FRACTION = 0.20
# Output directory (created by main()) and the two files written into it.
OUT_DIR = Path("samples")
OUT_PATH = OUT_DIR / "sample_kitchensink.xpt"
MANIFEST_PATH = OUT_DIR / "sample_kitchensink.expected.json"

# Columns that verify_roundtrip() requires to read back with 0% missing values.
POSITIVE_CONTROLS = {"ID", "INTCOL", "STRCOL", "DATECOL", "CONST"}
# Columns that verify_roundtrip() requires to read back 100% missing.
ALL_NULL_COLS = {"ALLNULL", "ALLNULLC"}
|
||||
|
||||
|
||||
def _missing_mask(rng: np.random.Generator, n: int, frac: float) -> np.ndarray:
|
||||
"""Return a boolean array of length n with exactly round(frac * n) True positions.
|
||||
|
||||
Using an exact count (rather than per-row Bernoulli draws) keeps the observed
|
||||
missing fraction tight so the round-trip assertion can use a small tolerance.
|
||||
"""
|
||||
mask = np.zeros(n, dtype=bool)
|
||||
k = int(round(frac * n))
|
||||
if k > 0:
|
||||
idx = rng.choice(n, size=k, replace=False)
|
||||
mask[idx] = True
|
||||
return mask
|
||||
|
||||
|
||||
def _random_word(rng: np.random.Generator, min_len: int = 3, max_len: int = 10) -> str:
|
||||
length = int(rng.integers(min_len, max_len + 1))
|
||||
letters = np.array(list(string.ascii_lowercase))
|
||||
return "".join(rng.choice(letters, size=length))
|
||||
|
||||
|
||||
def _random_sentence(rng: np.random.Generator, min_words: int = 8, max_words: int = 20) -> str:
    """Return ``min_words``..``max_words`` (inclusive) random words joined by single spaces."""
    word_count = int(rng.integers(min_words, max_words + 1))
    words = [_random_word(rng) for _ in range(word_count)]
    return " ".join(words)
|
||||
|
||||
|
||||
def build_dataframe(rng: np.random.Generator) -> pd.DataFrame:
    """Build the N_ROWS-row kitchen-sink frame, one column per type scenario.

    All randomness comes from ``rng`` and the columns are generated in a
    fixed sequence, so reordering the statements below changes the data
    that a given seed produces.

    Numeric columns use NaN for missing values, character columns use the
    empty string, the datetime column uses NaT and the time column uses
    None. Columns that draw a missing mask blank out NULL_FRACTION of
    their rows.
    """
    n = N_ROWS

    # ID: 1..n as int64, no missing values.
    ids = np.arange(1, n + 1, dtype=np.int64)

    # INTCOL: whole numbers in [0, 1000) held as float64, no missing values.
    int_vals = rng.integers(0, 1000, size=n).astype(np.float64)

    # BIGINT: whole numbers in [1e10, 2e10), i.e. beyond the int32 range, with NaN gaps.
    bigint_vals = rng.integers(10_000_000_000, 20_000_000_000, size=n).astype(np.float64)
    bigint_vals[_missing_mask(rng, n, NULL_FRACTION)] = np.nan

    # FLOATCOL: normally distributed floats (mean 100, sd 15) with NaN gaps.
    float_vals = rng.normal(loc=100.0, scale=15.0, size=n)
    float_vals[_missing_mask(rng, n, NULL_FRACTION)] = np.nan

    # BOOLCOL: 0.0 / 1.0 with NaN gaps.
    bool_vals = rng.integers(0, 2, size=n).astype(np.float64)
    bool_vals[_missing_mask(rng, n, NULL_FRACTION)] = np.nan

    # STRCOL: short random words (3-8 letters), no missing values.
    str_vals = [_random_word(rng, 3, 8) for _ in range(n)]

    # LONGSTR: random sentences, with "" standing in for missing.
    long_str_vals: list[str] = []
    long_mask = _missing_mask(rng, n, NULL_FRACTION)
    for i in range(n):
        long_str_vals.append("" if long_mask[i] else _random_sentence(rng))

    # DATECOL: datetime.date values from 2020-01-01 up to 1999 days later, no missing values.
    base_date = dt.date(2020, 1, 1)
    date_vals = [base_date + dt.timedelta(days=int(rng.integers(0, 2000))) for _ in range(n)]

    # DTCOL: second-resolution datetimes over the same 2000-day window, NaT where masked.
    dt_vals_mask = _missing_mask(rng, n, NULL_FRACTION)
    dt_vals: list = []
    base_dt = dt.datetime(2020, 1, 1, 0, 0, 0)
    for i in range(n):
        if dt_vals_mask[i]:
            dt_vals.append(pd.NaT)
        else:
            offset_seconds = int(rng.integers(0, 2000 * 24 * 3600))
            dt_vals.append(base_dt + dt.timedelta(seconds=offset_seconds))
    dt_series = pd.to_datetime(dt_vals)

    # TIMECOL: datetime.time values at whole-second resolution, None where masked.
    time_mask = _missing_mask(rng, n, NULL_FRACTION)
    time_vals: list = []
    for i in range(n):
        if time_mask[i]:
            time_vals.append(None)
        else:
            seconds_into_day = int(rng.integers(0, 24 * 3600))
            h, rem = divmod(seconds_into_day, 3600)
            m, s = divmod(rem, 60)
            time_vals.append(dt.time(h, m, s))

    # NUMASSTR: strings that look numeric - roughly half integers in [-500, 500),
    # half two-decimal floats - with "" where masked.
    numasstr_mask = _missing_mask(rng, n, NULL_FRACTION)
    numasstr_vals: list[str] = []
    for i in range(n):
        if numasstr_mask[i]:
            numasstr_vals.append("")
        elif rng.random() < 0.5:
            numasstr_vals.append(str(int(rng.integers(-500, 500))))
        else:
            numasstr_vals.append(f"{rng.normal(0, 50):.2f}")

    # DATEASTR: ISO-format date strings (YYYY-MM-DD), with "" where masked.
    dateasstr_mask = _missing_mask(rng, n, NULL_FRACTION)
    dateasstr_vals: list[str] = []
    for i in range(n):
        if dateasstr_mask[i]:
            dateasstr_vals.append("")
        else:
            d = base_date + dt.timedelta(days=int(rng.integers(0, 2000)))
            dateasstr_vals.append(d.isoformat())

    # MIXED: a per-row pick between a number string, an ISO date string and a
    # random word; "text" is listed twice so it is chosen twice as often.
    mixed_mask = _missing_mask(rng, n, NULL_FRACTION)
    mixed_vals: list[str] = []
    choices = ["number", "date", "text", "text"]
    for i in range(n):
        if mixed_mask[i]:
            mixed_vals.append("")
            continue
        kind = choices[int(rng.integers(0, len(choices)))]
        if kind == "number":
            mixed_vals.append(str(int(rng.integers(0, 1000))))
        elif kind == "date":
            d = base_date + dt.timedelta(days=int(rng.integers(0, 2000)))
            mixed_vals.append(d.isoformat())
        else:
            mixed_vals.append(_random_word(rng, 4, 12))

    # CONST: the same string in every row.
    const_vals = ["CONSTANT"] * n

    # ALLNULL / ALLNULLC: entirely missing numeric and character columns.
    allnull_vals = np.full(n, np.nan, dtype=np.float64)
    allnullc_vals = [""] * n

    df = pd.DataFrame(
        {
            "ID": ids,
            "INTCOL": int_vals,
            "BIGINT": bigint_vals,
            "FLOATCOL": float_vals,
            "BOOLCOL": bool_vals,
            "STRCOL": str_vals,
            "LONGSTR": long_str_vals,
            "DATECOL": date_vals,
            "DTCOL": dt_series,
            "TIMECOL": time_vals,
            "NUMASSTR": numasstr_vals,
            "DATEASTR": dateasstr_vals,
            "MIXED": mixed_vals,
            "CONST": const_vals,
            "ALLNULL": allnull_vals,
            "ALLNULLC": allnullc_vals,
        }
    )
    return df
|
||||
|
||||
|
||||
# Human-readable label for each column; passed to pyreadstat.write_xport()
# as ``column_labels`` in main().
COLUMN_LABELS: dict[str, str] = {
    "ID": "Row identifier",
    "INTCOL": "Integer positive control",
    "BIGINT": "Big integer beyond int32 range",
    "FLOATCOL": "Floating point with decimals",
    "BOOLCOL": "Nullable boolean 0/1/NaN",
    "STRCOL": "Short string positive control",
    "LONGSTR": "Longer free-text string",
    "DATECOL": "Date positive control",
    "DTCOL": "Datetime with missing values",
    "TIMECOL": "Time of day with missing values",
    "NUMASSTR": "Numeric-looking strings in a char column",
    "DATEASTR": "Date-looking strings in a char column",
    "MIXED": "Heterogeneous strings: fallback to text",
    "CONST": "Constant repeated value",
    "ALLNULL": "Entirely missing numeric column",
    "ALLNULLC": "Entirely missing character column",
}
|
||||
|
||||
|
||||
# SAS display formats for the temporal columns; passed to
# pyreadstat.write_xport() as ``variable_format`` in main().
VARIABLE_FORMATS: dict[str, str] = {
    "DATECOL": "DATE9.",
    "DTCOL": "DATETIME20.",
    "TIMECOL": "TIME8.",
}
|
||||
|
||||
|
||||
# Ground-truth expectations per column, written to MANIFEST_PATH as JSON by
# write_manifest(). Each entry carries either a single "postgres_type" or a
# list of "acceptable_types", plus "nullable" and an optional free-text "note".
EXPECTED_MANIFEST: dict[str, dict] = {
    "ID": {"postgres_type": "INTEGER", "nullable": False},
    "INTCOL": {"postgres_type": "INTEGER", "nullable": False, "note": "positive control"},
    "BIGINT": {"postgres_type": "BIGINT", "nullable": True, "note": "values beyond int32 range"},
    "FLOATCOL": {"acceptable_types": ["DOUBLE PRECISION", "NUMERIC"], "nullable": True},
    "BOOLCOL": {
        "acceptable_types": ["BOOLEAN", "SMALLINT", "INTEGER"],
        "nullable": True,
        "note": "{0,1,NaN} is genuinely ambiguous; loader's choice is a design decision",
    },
    "STRCOL": {"acceptable_types": ["TEXT", "VARCHAR"], "nullable": False, "note": "positive control"},
    "LONGSTR": {"acceptable_types": ["TEXT", "VARCHAR"], "nullable": True},
    "DATECOL": {"postgres_type": "DATE", "nullable": False, "note": "positive control"},
    "DTCOL": {"acceptable_types": ["TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE"], "nullable": True},
    "TIMECOL": {"postgres_type": "TIME", "nullable": True},
    "NUMASSTR": {
        "acceptable_types": ["NUMERIC", "DOUBLE PRECISION"],
        "nullable": True,
        "note": "stored as char in SAS; loader should coerce numeric-looking strings",
    },
    "DATEASTR": {
        "postgres_type": "DATE",
        "nullable": True,
        "note": "stored as char in SAS; loader should coerce ISO-date strings",
    },
    "MIXED": {
        "acceptable_types": ["TEXT", "VARCHAR"],
        "nullable": True,
        "note": "heterogeneous content; loader should fall back to text",
    },
    "CONST": {"acceptable_types": ["TEXT", "VARCHAR"], "nullable": False},
    "ALLNULL": {
        "acceptable_types": ["TEXT", "VARCHAR"],
        "nullable": True,
        "note": "entirely null numeric; loader must pick a default type, typically TEXT",
    },
    "ALLNULLC": {
        "acceptable_types": ["TEXT", "VARCHAR"],
        "nullable": True,
        "note": "entirely null character",
    },
}
|
||||
|
||||
|
||||
def write_manifest(df: pd.DataFrame) -> None:
    """Write EXPECTED_MANIFEST to MANIFEST_PATH as indented, key-sorted JSON.

    Before writing, the manifest keys are compared with ``df``'s columns.

    Raises:
        AssertionError: if the manifest and the frame do not describe
            exactly the same set of columns.
    """
    expected_cols = set(EXPECTED_MANIFEST)
    frame_cols = set(df.columns)
    missing = frame_cols - expected_cols
    extra = expected_cols - frame_cols
    if missing or extra:
        raise AssertionError(
            f"Manifest/DataFrame column mismatch. Missing from manifest: {missing}. "
            f"Extra in manifest: {extra}."
        )

    # Serialize first, then write the text plus a trailing newline in one go.
    payload = json.dumps(EXPECTED_MANIFEST, indent=2, sort_keys=True)
    with MANIFEST_PATH.open("w", encoding="utf-8") as f:
        f.write(payload + "\n")
|
||||
|
||||
|
||||
def _char_missing_fraction(series: pd.Series) -> float:
|
||||
return float((series.fillna("").astype(str) == "").mean())
|
||||
|
||||
|
||||
def _numeric_missing_fraction(series: pd.Series) -> float:
|
||||
return float(series.isna().mean())
|
||||
|
||||
|
||||
def _require(condition: bool, message: str) -> None:
    """Raise ``AssertionError(message)`` unless ``condition`` holds.

    Used instead of the ``assert`` statement because ``python -O`` removes
    ``assert`` statements, which would turn the whole verification into a
    no-op.
    """
    if not condition:
        raise AssertionError(message)


def _missing_fraction(series: pd.Series) -> float:
    """Return the missing fraction of ``series`` using the rule that fits its dtype."""
    if pd.api.types.is_numeric_dtype(series) or pd.api.types.is_datetime64_any_dtype(series):
        return _numeric_missing_fraction(series)
    return _char_missing_fraction(series)


def verify_roundtrip(source_df: pd.DataFrame) -> pd.DataFrame:
    """Read OUT_PATH back with pyreadstat and check it against ``source_df``.

    Checks, in order: column count, column names, row count, that DATECOL
    and DTCOL came back datetime-like, that TIMECOL came back as a
    datetime/numeric/time-object column, the per-column missing fraction
    (0% for POSITIVE_CONTROLS, 100% for ALL_NULL_COLS, NULL_FRACTION within
    a tolerance for everything else), and finally that the manifest on disk
    lists exactly the columns that were read back.

    The file is read with pyreadstat (the same library that wrote it) so
    the date/datetime/time checks confirm the ``variable_format`` mappings
    took effect.

    Args:
        source_df: the frame that was written to OUT_PATH.

    Returns:
        The frame read back from OUT_PATH.

    Raises:
        AssertionError: on the first check that fails. Raised explicitly,
            so the checks still run under ``python -O``.
    """
    readback, _meta = pyreadstat.read_xport(str(OUT_PATH))

    source_cols = set(source_df.columns)
    readback_cols = set(readback.columns)

    _require(
        len(readback.columns) == len(source_df.columns),
        f"Column count mismatch: wrote {len(source_df.columns)}, read back {len(readback.columns)}",
    )
    _require(
        readback_cols == source_cols,
        f"Column name mismatch. Only in source: {source_cols - readback_cols}. "
        f"Only in readback: {readback_cols - source_cols}.",
    )
    _require(
        len(readback) == len(source_df),
        f"Row count mismatch: wrote {len(source_df)}, read back {len(readback)}",
    )

    # Date and datetime columns may come back either as datetime64 or as an
    # object column holding date/datetime values; accept both.
    for col in ("DATECOL", "DTCOL"):
        dtype = readback[col].dtype
        is_datetime = pd.api.types.is_datetime64_any_dtype(dtype)
        is_object_of_dates = pd.api.types.is_object_dtype(dtype) and readback[col].dropna().map(
            lambda v: isinstance(v, (dt.date, dt.datetime, pd.Timestamp))
        ).all()
        _require(
            bool(is_datetime or is_object_of_dates),
            f"{col} came back as {dtype}; expected datetime-like. "
            f"variable_format mapping may not have taken effect.",
        )

    time_dtype = readback["TIMECOL"].dtype
    time_ok = (
        pd.api.types.is_datetime64_any_dtype(time_dtype)
        or pd.api.types.is_numeric_dtype(time_dtype)
        or (
            pd.api.types.is_object_dtype(time_dtype)
            and readback["TIMECOL"].dropna().map(
                lambda v: isinstance(v, (dt.time, dt.datetime, pd.Timestamp, int, float))
            ).all()
        )
    )
    _require(bool(time_ok), f"TIMECOL came back as {time_dtype}; expected datetime/numeric/time-object")

    tol = 0.10
    for col in source_df.columns:
        observed = _missing_fraction(readback[col])

        if col in POSITIVE_CONTROLS:
            _require(
                observed == 0.0,
                f"Positive control {col!r} has {observed:.2%} missing; expected 0%.",
            )
        elif col in ALL_NULL_COLS:
            _require(
                observed == 1.0,
                f"All-null column {col!r} has {observed:.2%} missing; expected 100%.",
            )
        else:
            _require(
                abs(observed - NULL_FRACTION) <= tol,
                f"Column {col!r}: observed missing fraction {observed:.2%} not within "
                f"±{tol:.0%} of NULL_FRACTION={NULL_FRACTION:.2%}.",
            )

    _require(MANIFEST_PATH.exists(), f"Manifest file {MANIFEST_PATH} missing.")
    with MANIFEST_PATH.open("r", encoding="utf-8") as f:
        manifest = json.load(f)
    manifest_cols = set(manifest.keys())
    _require(
        manifest_cols == readback_cols,
        f"Manifest/readback column set mismatch. "
        f"Only in manifest: {manifest_cols - readback_cols}. "
        f"Only in readback: {readback_cols - manifest_cols}.",
    )

    return readback
|
||||
|
||||
|
||||
def main() -> None:
    """Generate the sample XPORT file and manifest, verify them, and print a summary."""
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    df = build_dataframe(np.random.default_rng(SEED))

    xport_options = {
        "file_format_version": 5,
        "table_name": "SAMPLE",
        "file_label": "Kitchen sink sample for loader testing",
        "column_labels": COLUMN_LABELS,
        "variable_format": VARIABLE_FORMATS,
    }
    pyreadstat.write_xport(df, str(OUT_PATH), **xport_options)

    # The manifest must exist before verification, which re-reads it from disk.
    write_manifest(df)
    readback = verify_roundtrip(df)

    report_lines = [
        f"Wrote {OUT_PATH} ({N_ROWS} rows x {len(df.columns)} cols)",
        f"Wrote {MANIFEST_PATH}",
        "",
        "Readback via pyreadstat.read_xport (same reader the loader will use):",
        readback.dtypes.to_string(),
        "",
        "Readback head:",
        readback.head().to_string(),
    ]
    print("\n".join(report_lines))


if __name__ == "__main__":
    main()
|
||||
904
generic_loader/load_sas.py
Normal file
904
generic_loader/load_sas.py
Normal file
@ -0,0 +1,904 @@
|
||||
"""Per-file SAS-to-Postgres loader.
|
||||
|
||||
Library-style functions plus a thin CLI wrapper. Designed so an orchestrator
|
||||
can wrap the library for directory/batch mode; orchestration is out of scope
|
||||
here.
|
||||
|
||||
Python 3.9 compatible (target is an air-gapped host that currently only has
|
||||
3.9). ``from __future__ import annotations`` lets us use PEP 585 generics
|
||||
as annotations; runtime-resolved type uses (dataclass defaults, etc.) stick
|
||||
to ``typing``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import pandas as pd
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
import pyreadstat
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level tunables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
COERCE_CHAR_COLUMNS = True
"""If True, try to promote object (string) columns to numeric/date/timestamp
when every non-empty value parses cleanly."""

CHAR_INFERENCE_MIN_VALUES = 3
"""Don't attempt character-column coercion with fewer than this many non-empty
values; too small a sample is easy to mis-infer."""

NUMERIC_INT_RANGE = (-2_147_483_648, 2_147_483_647)
"""INTEGER bounds; anything outside becomes BIGINT."""


# Accepted values for the ``if_exists`` config key; load_config() rejects
# anything else.
VALID_IF_EXISTS = ("fail", "replace", "append")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class LoaderConfig:
    """Settings for loading one SAS file, as produced by ``load_config``."""

    # Path of the SAS file to load.
    filename: Path
    # Target schema and table names.
    schemaname: str
    tablename: str
    # One of VALID_IF_EXISTS ("fail", "replace", "append").
    if_exists: str = "fail"
    # Optional column filters; load_config() refuses configs that set both.
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
class ColumnSpec:
    """Inferred Postgres definition of one column, as built by ``infer_schema``."""

    # Column name as it appears in the DataFrame.
    name: str
    # Postgres type string, e.g. "INTEGER", "TEXT", "TIMESTAMP".
    postgres_type: str
    # True when the source column contains at least one missing value.
    nullable: bool
    # SAS format reported by the reader's metadata for this column, if any.
    sas_format: Optional[str] = None
    # String form of the pandas dtype the column was read as.
    source_dtype: Optional[str] = None
    # Free-text remarks recorded during inference; a fresh list per instance.
    notes: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TableExistsError(RuntimeError):
    """The target table already exists and the config says ``if_exists=fail``."""
|
||||
|
||||
|
||||
class SchemaCompatibilityError(RuntimeError):
    """The incoming schema cannot be appended to the existing table
    (``if_exists=append``)."""
|
||||
|
||||
|
||||
class ValidationError(RuntimeError):
    """``--validate`` found a mismatch between the load and the manifest."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Connection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def connect() -> psycopg2.extensions.connection:
    """Open a psycopg2 connection configured from the libpq-style env vars.

    Reads PGHOST, PGPORT, PGUSER, PGPASSWORD and PGDATABASE from the
    process environment; variables that are not set are passed as None.

    This function does not load ``.env`` itself. The CLI does that before
    calling; orchestrators wrapping this module should either call
    ``load_dotenv()`` or make sure the variables are already set.
    """
    env = os.environ
    params = {
        "host": env.get("PGHOST"),
        "port": env.get("PGPORT"),
        "user": env.get("PGUSER"),
        "password": env.get("PGPASSWORD"),
        "dbname": env.get("PGDATABASE"),
    }
    return psycopg2.connect(**params)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_config(path: Path) -> LoaderConfig:
    """Parse and validate the YAML config at ``path``.

    Required keys are ``filename``, ``schemaname`` and ``tablename``; each
    must be present and non-blank. Optional keys are ``if_exists`` (one of
    VALID_IF_EXISTS, case-insensitive, default "fail") and at most one of
    ``include`` / ``exclude`` (lists of column names).

    A relative ``filename`` is resolved against the config file's directory
    when that file exists there; otherwise it is kept as written.

    Raises:
        ValueError: if the document is not a mapping, a required key is
            missing or blank, ``if_exists`` is not recognised, both
            ``include`` and ``exclude`` are set, or either is not a list.
    """
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)

    if not isinstance(raw, dict):
        raise ValueError(f"Config at {path} must be a YAML mapping at the top level.")

    required = ("filename", "schemaname", "tablename")
    missing = [k for k in required if k not in raw]
    if missing:
        raise ValueError(f"Config {path} missing required keys: {', '.join(missing)}")

    # A key written without a value ("tablename:") parses as None, and
    # str(None) would otherwise yield a table literally named "None".
    blank = [k for k in required if raw[k] is None or not str(raw[k]).strip()]
    if blank:
        raise ValueError(f"Config {path} has empty values for required keys: {', '.join(blank)}")

    filename = Path(raw["filename"])
    if not filename.is_absolute():
        # Prefer the file next to the config; fall back to the path as given.
        beside_config = path.parent / filename
        if beside_config.exists():
            filename = beside_config.resolve()

    schemaname = str(raw["schemaname"])
    tablename = str(raw["tablename"])

    if_exists = str(raw.get("if_exists", "fail")).lower()
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(
            f"Config {path}: if_exists={if_exists!r} is not one of {VALID_IF_EXISTS}"
        )

    include = raw.get("include")
    exclude = raw.get("exclude")
    if include is not None and exclude is not None:
        raise ValueError(
            f"Config {path}: 'include' and 'exclude' are mutually exclusive; set at most one."
        )
    if include is not None and not isinstance(include, list):
        raise ValueError(f"Config {path}: 'include' must be a list of column names.")
    if exclude is not None and not isinstance(exclude, list):
        raise ValueError(f"Config {path}: 'exclude' must be a list of column names.")

    return LoaderConfig(
        filename=filename,
        schemaname=schemaname,
        tablename=tablename,
        if_exists=if_exists,
        include=[str(c) for c in include] if include is not None else None,
        exclude=[str(c) for c in exclude] if exclude is not None else None,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_sas(path: Path) -> Tuple[pd.DataFrame, Any]:
    """Read a SAS file with the pyreadstat reader that matches its extension.

    * ``.xpt`` / ``.xport`` are read without an encoding argument;
      pyreadstat is flaky about encoding on XPORT files it wrote itself.
    * ``.sas7bdat`` is read with an explicit ``encoding="latin-1"``, per
      colleague guidance.

    Returns:
        The ``(DataFrame, metadata)`` pair returned by pyreadstat.

    Raises:
        ValueError: for any other extension (compared case-insensitively).
    """
    path = Path(path)
    extension = path.suffix.lower()

    if extension == ".sas7bdat":
        return pyreadstat.read_sas7bdat(str(path), encoding="latin-1")
    if extension in (".xpt", ".xport"):
        return pyreadstat.read_xport(str(path))

    raise ValueError(f"Unsupported SAS file extension: {extension}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Column filtering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_column_filter(
    df: pd.DataFrame,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
) -> pd.DataFrame:
    """Return a copy of ``df`` limited to the requested columns.

    With ``include`` the result holds exactly those columns, in the order
    given. With ``exclude`` those columns are dropped. With neither, the
    whole frame is copied.

    Raises:
        ValueError: if both filters are given, or if a filter names a
            column that is not in ``df`` (unknown names are reported, never
            silently ignored).
    """
    if include is not None and exclude is not None:
        raise ValueError("include and exclude are mutually exclusive.")

    if include is None and exclude is None:
        return df.copy()

    known = df.columns
    if include is not None:
        unknown = [name for name in include if name not in known]
        if unknown:
            raise ValueError(f"include references unknown columns: {unknown}")
        return df.loc[:, list(include)].copy()

    unknown = [name for name in exclude if name not in known]
    if unknown:
        raise ValueError(f"exclude references unknown columns: {unknown}")
    return df.drop(columns=list(exclude)).copy()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type inference
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_DATE_FORMAT_PREFIXES = ("DATE", "YYMMDD", "MMDDYY", "DDMMYY", "JULIAN")
|
||||
|
||||
|
||||
def _format_driven_type(sas_format: Optional[str]) -> Optional[str]:
|
||||
"""Return a Postgres type inferred from the SAS format string, or None
|
||||
if the format doesn't pin it down."""
|
||||
if not sas_format:
|
||||
return None
|
||||
fmt = sas_format.upper().lstrip()
|
||||
# DATETIME must be checked before DATE since "DATETIME20." starts with "DATE".
|
||||
if fmt.startswith("DATETIME"):
|
||||
return "TIMESTAMP"
|
||||
if fmt.startswith("TIME"):
|
||||
return "TIME"
|
||||
for prefix in _DATE_FORMAT_PREFIXES:
|
||||
if fmt.startswith(prefix):
|
||||
return "DATE"
|
||||
return None
|
||||
|
||||
|
||||
def _all_null(series: pd.Series) -> bool:
|
||||
if pd.api.types.is_object_dtype(series):
|
||||
return bool(series.map(lambda v: v is None or (isinstance(v, str) and v == "") or (isinstance(v, float) and pd.isna(v))).all())
|
||||
return bool(series.isna().all())
|
||||
|
||||
|
||||
def _char_missing_mask(series: pd.Series) -> pd.Series:
|
||||
return series.map(lambda v: v is None or (isinstance(v, float) and pd.isna(v)) or (isinstance(v, str) and v == ""))
|
||||
|
||||
|
||||
def _is_nullable(series: pd.Series) -> bool:
    """Return True when ``series`` holds at least one missing value.

    Object columns use ``_char_missing_mask`` (so empty strings count as
    missing); every other dtype uses pandas' ``isna``.
    """
    if not pd.api.types.is_object_dtype(series):
        return bool(series.isna().any())
    missing = _char_missing_mask(series)
    return bool(missing.any())
|
||||
|
||||
|
||||
def _numeric_int_target(series: pd.Series) -> Optional[str]:
    """Pick an integer Postgres type for a numeric series, if one fits.

    Returns:
        "INTEGER" when every non-null value is a whole number inside
        NUMERIC_INT_RANGE, "BIGINT" when every non-null value is a whole
        number inside the signed 64-bit range, and None otherwise (no
        non-null values, a fractional or infinite value, or a value too
        large for BIGINT). The caller then falls back to a float type.
    """
    nonnull = series.dropna()
    if nonnull.empty:
        return None

    # Whole-number test. ``inf % 1`` is NaN, so infinities fail the
    # comparison and are rejected here as well.
    try:
        whole = ((nonnull % 1) == 0).all()
    except TypeError:
        return None
    if not whole:
        return None

    lo, hi = NUMERIC_INT_RANGE
    vmin = nonnull.min()
    vmax = nonnull.max()
    if lo <= vmin and vmax <= hi:
        return "INTEGER"

    # Large floats such as 1e20 are "whole" but do not fit in BIGINT.
    # 2.0 ** 63 is exactly representable as a float, so the upper bound is
    # compared strictly against it rather than against 2**63 - 1 (which
    # would round up to 2.0 ** 63 when converted to float).
    if vmin < -(2.0 ** 63) or vmax >= 2.0 ** 63:
        return None
    return "BIGINT"
|
||||
|
||||
|
||||
def _object_is_dates(series: pd.Series) -> Tuple[bool, bool]:
|
||||
"""Return (all-date-like, any-datetime). If every non-null value is a
|
||||
``datetime.date`` / ``datetime.datetime`` / ``pd.Timestamp``, return True
|
||||
plus whether at least one carries a time component."""
|
||||
nonnull = series.dropna()
|
||||
if nonnull.empty:
|
||||
return False, False
|
||||
any_datetime = False
|
||||
for v in nonnull:
|
||||
if isinstance(v, dt.datetime) or isinstance(v, pd.Timestamp):
|
||||
any_datetime = True
|
||||
continue
|
||||
if isinstance(v, dt.date):
|
||||
continue
|
||||
return False, False
|
||||
return True, any_datetime
|
||||
|
||||
|
||||
def _try_int_coerce(values: List[str]) -> Optional[str]:
    """Return "INTEGER" or "BIGINT" when every value parses with ``int()``.

    Values are stripped of surrounding whitespace before parsing. The
    result is "INTEGER" when all parsed values lie inside
    NUMERIC_INT_RANGE and "BIGINT" otherwise. Returns None when any value
    fails to parse or when ``values`` is empty.
    """
    try:
        parsed = [int(text.strip()) for text in values]
    except ValueError:
        return None
    if not parsed:
        return None

    low, high = NUMERIC_INT_RANGE
    fits_int32 = all(low <= number <= high for number in parsed)
    return "INTEGER" if fits_int32 else "BIGINT"
|
||||
|
||||
|
||||
def _try_float_coerce(values: List[str]) -> bool:
|
||||
for v in values:
|
||||
try:
|
||||
float(v)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _try_date_coerce(values: List[str]) -> bool:
|
||||
for v in values:
|
||||
try:
|
||||
dt.date.fromisoformat(v)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _try_datetime_coerce(values: List[str]) -> bool:
|
||||
for v in values:
|
||||
try:
|
||||
dt.datetime.fromisoformat(v)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _infer_char_type(series: pd.Series) -> str:
    """Infer a Postgres type string for an object/string column.

    Missing entries (per ``_char_missing_mask``) are ignored. When coercion
    is disabled via COERCE_CHAR_COLUMNS, or fewer than
    CHAR_INFERENCE_MIN_VALUES values remain, the answer is "TEXT".
    Otherwise the first parser that accepts every remaining value wins, in
    this order: integer (INTEGER/BIGINT), float (DOUBLE PRECISION), ISO
    date (DATE), ISO datetime (TIMESTAMP); if none does, "TEXT".
    """
    present = series[~_char_missing_mask(series)]
    candidates = [str(value) for value in present.tolist()]

    too_few = len(candidates) < CHAR_INFERENCE_MIN_VALUES
    if too_few or not COERCE_CHAR_COLUMNS:
        return "TEXT"

    as_int = _try_int_coerce(candidates)
    if as_int is not None:
        return as_int

    # Remaining parsers, tried from most to least specific numeric/temporal type.
    fallbacks = (
        ("DOUBLE PRECISION", _try_float_coerce),
        ("DATE", _try_date_coerce),
        ("TIMESTAMP", _try_datetime_coerce),
    )
    for pg_type, accepts_all in fallbacks:
        if accepts_all(candidates):
            return pg_type
    return "TEXT"
|
||||
|
||||
|
||||
def infer_schema(
    df: pd.DataFrame,
    meta: Any,
    *,
    coerce_chars: bool = COERCE_CHAR_COLUMNS,
) -> Dict[str, ColumnSpec]:
    """Infer a Postgres column spec for each column in ``df``.

    ``meta`` is the pyreadstat metadata object. Its
    ``original_variable_types`` mapping (column name -> SAS format) is
    consulted first: a format that implies DATE / TIME / TIMESTAMP decides
    the type outright. Otherwise the type comes from the data: all-null
    columns become TEXT, datetime64 columns TIMESTAMP, object columns DATE /
    TIMESTAMP / a character-inferred type, numeric columns INTEGER / BIGINT /
    DOUBLE PRECISION, and anything else TEXT with a note.

    ``coerce_chars`` overrides the module-level COERCE_CHAR_COLUMNS for the
    duration of this call. The char-inference helpers read that constant
    directly, so the override is applied by setting the global and
    restoring it in a ``finally`` block.

    Returns:
        A dict of ColumnSpec keyed by column name, in ``df`` column order.
    """
    global COERCE_CHAR_COLUMNS

    formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {})

    def _type_from_data(column: pd.Series, notes: List[str]) -> str:
        """Pick a type from the values when the SAS format gave no answer."""
        if _all_null(column):
            notes.append("all-null column; defaulting to TEXT")
            return "TEXT"
        if pd.api.types.is_datetime64_any_dtype(column):
            return "TIMESTAMP"
        if pd.api.types.is_object_dtype(column):
            is_dates, any_dt = _object_is_dates(column)
            if not is_dates:
                return _infer_char_type(column)
            return "TIMESTAMP" if any_dt else "DATE"
        if pd.api.types.is_numeric_dtype(column):
            int_target = _numeric_int_target(column)
            return int_target if int_target is not None else "DOUBLE PRECISION"
        notes.append(f"unhandled dtype {column.dtype}; defaulting to TEXT")
        return "TEXT"

    # Apply the caller's override for this call only.
    previous_flag = COERCE_CHAR_COLUMNS
    COERCE_CHAR_COLUMNS = coerce_chars
    try:
        specs: Dict[str, ColumnSpec] = {}
        for name in df.columns:
            column = df[name]
            fmt = formats.get(name)
            notes: List[str] = []

            pg_type = _format_driven_type(fmt)
            if pg_type is None:
                pg_type = _type_from_data(column, notes)

            specs[name] = ColumnSpec(
                name=name,
                postgres_type=pg_type,
                nullable=_is_nullable(column),
                sas_format=fmt,
                source_dtype=str(column.dtype),
                notes=notes,
            )
        return specs
    finally:
        COERCE_CHAR_COLUMNS = previous_flag
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Table management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _quote_ident(ident: str) -> str:
|
||||
"""Quote a Postgres identifier. psycopg2 doesn't expose this directly
|
||||
until 2.8+ with sql.Identifier; we do it by hand to stay driver-simple."""
|
||||
return '"' + ident.replace('"', '""') + '"'
|
||||
|
||||
|
||||
def _qualified(schema: str, table: str) -> str:
    """Return the quoted, schema-qualified name: ``"schema"."table"``."""
    quoted_schema = _quote_ident(schema)
    quoted_table = _quote_ident(table)
    return quoted_schema + "." + quoted_table
|
||||
|
||||
|
||||
def _table_exists(conn, schema: str, table: str) -> bool:
    """Return True when ``information_schema.tables`` lists ``schema.table``.

    The names are passed as query parameters, not interpolated into the SQL.
    """
    query = (
        "SELECT 1 FROM information_schema.tables "
        "WHERE table_schema = %s AND table_name = %s"
    )
    with conn.cursor() as cur:
        cur.execute(query, (schema, table))
        row = cur.fetchone()
    return row is not None
|
||||
|
||||
|
||||
def render_create_table(schema: str, table: str, columns: Dict[str, ColumnSpec]) -> str:
    """Render the ``CREATE TABLE`` statement for ``columns``.

    One column definition per line, in the dict's iteration order, with
    quoted identifiers and ``NOT NULL`` on columns that are not nullable.
    The statement is returned as text; nothing is executed here.
    """
    definitions = []
    for spec in columns.values():
        definition = f" {_quote_ident(spec.name)} {spec.postgres_type}"
        if not spec.nullable:
            definition += " NOT NULL"
        definitions.append(definition)

    column_block = ",\n".join(definitions)
    return f"CREATE TABLE {_qualified(schema, table)} (\n{column_block}\n);"
|
||||
|
||||
|
||||
def _create_table_sql(conn, schema: str, table: str, columns: Dict[str, ColumnSpec]) -> None:
    """Render the CREATE TABLE statement for ``columns`` and execute it on ``conn``.

    The statement runs inside whatever transaction ``conn`` currently has
    open; this function does not commit.
    """
    ddl = render_create_table(schema, table, columns)
    with conn.cursor() as cursor:
        cursor.execute(ddl)
|
||||
|
||||
|
||||
def _drop_table(conn, schema: str, table: str) -> None:
    """Execute ``DROP TABLE`` for ``schema.table`` on ``conn``.

    There is no ``IF EXISTS`` clause, so the statement is sent as-is and any
    database error propagates to the caller. No commit is issued here.
    """
    with conn.cursor() as cursor:
        target = _qualified(schema, table)
        cursor.execute("DROP TABLE " + target)
|
||||
|
||||
|
||||
# Canonical-family lookup for type names. Keys are the upper-cased spellings
# that either the loader emits or Postgres reports; each maps to one family
# name used for comparisons. Keys carry no length/precision modifiers such as
# VARCHAR(n) or NUMERIC(p,s).
_TYPE_NORMALIZATION: Dict[str, str] = {
    alias: family
    for family, aliases in (
        ("integer", ("INTEGER", "INT", "INT4")),
        ("bigint", ("BIGINT", "INT8")),
        ("smallint", ("SMALLINT", "INT2")),
        ("double precision", ("DOUBLE PRECISION", "FLOAT8")),
        ("real", ("REAL", "FLOAT4")),
        ("numeric", ("NUMERIC", "DECIMAL")),
        ("text", ("TEXT",)),
        ("character varying", ("VARCHAR", "CHARACTER VARYING")),
        ("character", ("CHAR", "CHARACTER", "BPCHAR")),
        ("boolean", ("BOOLEAN", "BOOL")),
        ("date", ("DATE",)),
        (
            "timestamp without time zone",
            ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE"),
        ),
        (
            "timestamp with time zone",
            ("TIMESTAMPTZ", "TIMESTAMP WITH TIME ZONE"),
        ),
        ("time without time zone", ("TIME", "TIME WITHOUT TIME ZONE")),
        ("time with time zone", ("TIMETZ", "TIME WITH TIME ZONE")),
    )
    for alias in aliases
}
|
||||
|
||||
|
||||
def _normalize_type(pg_type: str) -> str:
    """Reduce a Postgres type string to its canonical family name.

    The input is trimmed and upper-cased, any ``(n)`` or ``(p,s)`` modifier is
    removed wherever it appears, and runs of whitespace are collapsed, e.g.
    "VARCHAR(10)" -> "VARCHAR" and
    "TIMESTAMP(6) WITHOUT TIME ZONE" -> "TIMESTAMP WITHOUT TIME ZONE".
    The result is looked up in ``_TYPE_NORMALIZATION``; spellings not in the
    table are returned lower-cased.
    """
    import re

    key = pg_type.strip().upper()
    key = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", key).strip()
    # Removing a modifier from the middle can leave two adjacent spaces.
    key = re.sub(r"\s+", " ", key)
    try:
        return _TYPE_NORMALIZATION[key]
    except KeyError:
        return key.lower()
|
||||
|
||||
|
||||
def _assert_schema_compatible(
    conn, schema: str, table: str, columns: Dict[str, ColumnSpec]
) -> None:
    """Pre-flight check run before appending to an existing table.

    Reads the target's columns from information_schema.columns and compares
    them with the inferred ``columns``:

    * an inferred column missing from the target is a mismatch;
    * a differing normalized type family is a mismatch;
    * an inferred-nullable column whose target is NOT NULL only produces a
      ``[warn]`` line on stderr.

    Raises:
        SchemaCompatibilityError: if any mismatch was found; the message
            lists every mismatch.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT column_name, data_type, is_nullable "
            "FROM information_schema.columns "
            "WHERE table_schema = %s AND table_name = %s",
            (schema, table),
        )
        target_columns = {
            col_name: (data_type, is_nullable)
            for col_name, data_type, is_nullable in cur.fetchall()
        }

    errors: List[str] = []
    cautions: List[str] = []

    for name, spec in columns.items():
        target = target_columns.get(name)
        if target is None:
            errors.append(
                f"column {name!r} not present in target {schema}.{table}"
            )
            continue

        target_type, target_nullable = target
        inferred_norm = _normalize_type(spec.postgres_type)
        target_norm = _normalize_type(target_type)
        if inferred_norm != target_norm:
            errors.append(
                f"column {name!r}: inferred {spec.postgres_type} "
                f"(normalized {inferred_norm!r}) but target is {target_type} "
                f"(normalized {target_norm!r})"
            )

        # information_schema reports nullability as the strings "YES"/"NO".
        if spec.nullable and target_nullable == "NO":
            cautions.append(
                f"column {name!r}: incoming allows NULLs but target is NOT NULL; "
                "COPY will fail if any NULLs appear"
            )

    for caution in cautions:
        print(f"[warn] {caution}", file=sys.stderr)

    if errors:
        raise SchemaCompatibilityError(
            "append-mode schema compatibility check failed:\n - "
            + "\n - ".join(errors)
        )
|
||||
|
||||
|
||||
def create_table(
    conn,
    schema_name: str,
    table_name: str,
    columns: Dict[str, ColumnSpec],
    if_exists: str,
) -> None:
    """Create the target table, or reconcile with one that already exists.

    When the table is absent it is created. When it is present the outcome
    depends on ``if_exists``:

    * ``"fail"``    - raise ``TableExistsError``;
    * ``"replace"`` - drop the table and create it again;
    * ``"append"``  - keep the table after a schema compatibility check.

    Raises:
        ValueError: if ``if_exists`` is not in ``VALID_IF_EXISTS``.
        TableExistsError: if the table exists and ``if_exists`` is ``"fail"``.
    """
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(f"if_exists must be one of {VALID_IF_EXISTS}, got {if_exists!r}")

    if not _table_exists(conn, schema_name, table_name):
        _create_table_sql(conn, schema_name, table_name, columns)
        return

    if if_exists == "fail":
        raise TableExistsError(
            f"Table {schema_name}.{table_name} already exists and if_exists=fail"
        )
    if if_exists == "replace":
        _drop_table(conn, schema_name, table_name)
        _create_table_sql(conn, schema_name, table_name, columns)
    elif if_exists == "append":
        _assert_schema_compatible(conn, schema_name, table_name, columns)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# COPY loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _seconds_to_time(v: Any) -> Optional[dt.time]:
    """Convert a seconds-since-midnight value to ``datetime.time``.

    Args:
        v: the raw cell value. ``datetime.time`` is returned unchanged and
            ``datetime``/``Timestamp`` values contribute their time part;
            anything else is parsed as a number of seconds.

    Returns:
        The time of day, or None when ``v`` is missing (None, NaN, NaT) or
        cannot be interpreted as a finite number.

    Seconds are rounded to the nearest whole second and then clamped to the
    range of a single day, 00:00:00 through 23:59:59.
    """
    if v is None:
        return None
    if isinstance(v, float) and pd.isna(v):
        return None
    if isinstance(v, dt.time):
        return v
    if isinstance(v, (dt.datetime, pd.Timestamp)):
        return None if pd.isna(v) else v.time()
    try:
        total = int(round(float(v)))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: round() of +/-infinity; ValueError: NaN or bad text.
        return None
    # Clamp the total rather than just the hour component: clamping only the
    # hour turns 86400 into 23:00:00 and -1 into 00:59:59.
    total = max(0, min(total, 24 * 3600 - 1))
    minutes, seconds = divmod(total, 60)
    hours, minutes = divmod(minutes, 60)
    return dt.time(hours, minutes, seconds)
|
||||
|
||||
|
||||
def _copy_is_missing(v: Any) -> bool:
    """Return True for scalar missing markers: None, float NaN, pd.NA, pd.NaT."""
    if v is None or v is pd.NA or v is pd.NaT:
        return True
    return isinstance(v, float) and pd.isna(v)


def _copy_numeric_series(series: pd.Series) -> pd.Series:
    """Parse an object-dtype column as numbers; "" and bad values become NaN."""
    if pd.api.types.is_object_dtype(series):
        return pd.to_numeric(series.replace({"": None}), errors="coerce")
    return series


def _copy_date_value(v: Any) -> Optional[dt.date]:
    """Convert one object-column cell to a date, or None if it is not one."""
    if _copy_is_missing(v):
        return None
    # datetime must be tested first: it is a subclass of date.
    if isinstance(v, dt.datetime):
        return v.date()
    if isinstance(v, dt.date):
        return v
    if isinstance(v, str):
        try:
            return dt.date.fromisoformat(v)
        except ValueError:
            # Covers "" as well as any non-ISO text.
            return None
    return None


def _copy_timestamp_value(v: Any) -> Optional[dt.datetime]:
    """Convert one object-column cell to a datetime, or None if it is not one."""
    if _copy_is_missing(v):
        return None
    # pd.Timestamp is a datetime subclass, so it is returned unchanged here.
    if isinstance(v, dt.datetime):
        return v
    if isinstance(v, dt.date):
        return dt.datetime(v.year, v.month, v.day)
    if isinstance(v, str):
        try:
            return dt.datetime.fromisoformat(v)
        except ValueError:
            # Covers "" as well as any non-ISO text.
            return None
    return None


def _copy_text_value(v: Any) -> str:
    """Render one cell as text; missing markers become "" (NULL under COPY)."""
    return "" if _copy_is_missing(v) else str(v)


def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame:
    """Return a new frame whose columns are shaped for ``to_csv`` and COPY.

    Only the columns named in ``columns`` are carried over, in that order,
    on the same index as ``df``. Each column is converted according to the
    upper-cased ``postgres_type`` of its spec:

    * integer types      -> nullable ``Int64``;
    * floating/numeric   -> ``float64``;
    * DATE / TIMESTAMP   -> date / datetime objects (object columns are
      parsed value by value; unparseable values become None);
    * TIME               -> ``datetime.time`` via ``_seconds_to_time``;
    * text types         -> ``str``, with missing values as "";
    * BOOLEAN            -> nullable ``boolean`` unless already object dtype;
    * anything else      -> passed through unchanged.

    Type names are matched exactly, so a name carrying a modifier such as
    ``VARCHAR(10)`` takes the pass-through branch.
    """
    out = pd.DataFrame(index=df.index)
    for name, spec in columns.items():
        series = df[name]
        pg = spec.postgres_type.upper()

        if pg in ("INTEGER", "BIGINT", "SMALLINT"):
            out[name] = _copy_numeric_series(series).astype("Int64")
        elif pg in ("DOUBLE PRECISION", "REAL", "NUMERIC"):
            out[name] = _copy_numeric_series(series).astype("float64")
        elif pg == "DATE":
            if pd.api.types.is_datetime64_any_dtype(series):
                out[name] = series.dt.date
            elif pd.api.types.is_object_dtype(series):
                out[name] = series.map(_copy_date_value)
            else:
                out[name] = series
        elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"):
            # datetime64 columns are already in the right shape, so only
            # object columns need per-value conversion.
            if pd.api.types.is_object_dtype(series) and not pd.api.types.is_datetime64_any_dtype(series):
                out[name] = series.map(_copy_timestamp_value)
            else:
                out[name] = series
        elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"):
            out[name] = series.map(_seconds_to_time)
        elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"):
            # Missing values are written as "" so `NULL ''` in COPY turns
            # them into NULL rather than the text "nan" / "<NA>" / "NaT".
            out[name] = series.map(_copy_text_value)
        elif pg == "BOOLEAN":
            out[name] = series.astype("boolean") if series.dtype != object else series
        else:
            out[name] = series
    return out
|
||||
|
||||
|
||||
def copy_dataframe(
    conn,
    schema_name: str,
    table_name: str,
    df: pd.DataFrame,
    columns: Dict[str, ColumnSpec],
) -> int:
    """Load ``df`` into ``schema_name.table_name`` with ``COPY ... FROM STDIN``.

    The frame is first reshaped by ``_prepare_for_copy``, serialized to an
    in-memory CSV without header or index, and handed to ``copy_expert``.
    Missing values are written as empty fields, which the COPY options
    (``NULL ''``) read back as NULL. No commit is issued here.

    Returns:
        The cursor's reported row count, or the number of prepared rows when
        the cursor reports None.
    """
    frame = _prepare_for_copy(df, columns)

    payload = io.StringIO()
    frame.to_csv(
        payload,
        index=False,
        header=False,
        na_rep="",
        date_format="%Y-%m-%d %H:%M:%S",
    )
    payload.seek(0)

    quoted_columns = ", ".join(map(_quote_ident, columns))
    statement = (
        f"COPY {_qualified(schema_name, table_name)} ({quoted_columns}) "
        "FROM STDIN WITH (FORMAT csv, NULL '')"
    )

    with conn.cursor() as cursor:
        cursor.copy_expert(statement, payload)
        reported = cursor.rowcount

    if reported is None:
        return len(frame)
    return int(reported)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Manifest validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _match_manifest_type(inferred: str, manifest_entry: Dict[str, Any]) -> bool:
    """Check an inferred type against one manifest entry.

    The entry may give a single ``postgres_type`` or a list of
    ``acceptable_types``; ``postgres_type`` wins when both are present.
    Comparison is done on normalized type families. An entry with neither
    key never matches.
    """
    wanted = _normalize_type(inferred)
    if "postgres_type" in manifest_entry:
        candidates = [manifest_entry["postgres_type"]]
    elif "acceptable_types" in manifest_entry:
        candidates = manifest_entry["acceptable_types"]
    else:
        return False

    for candidate in candidates:
        if wanted == _normalize_type(candidate):
            return True
    return False
|
||||
|
||||
|
||||
def validate_against_manifest(
    inferred: Dict[str, ColumnSpec],
    manifest_path: Path,
) -> List[str]:
    """Check the inferred schema against the expected-types JSON manifest.

    Args:
        inferred: column name -> inferred spec.
        manifest_path: location of the manifest, a JSON object keyed by
            column name.

    Returns:
        Human-readable problem strings; an empty list means the schema
        matches. A missing manifest file is reported as a single problem
        rather than raised.

    Reported problems: columns present on only one side, a type outside what
    the manifest allows, and a column inferred as nullable where the manifest
    says ``"nullable": false``.
    """
    manifest_path = Path(manifest_path)
    if not manifest_path.exists():
        return [f"manifest not found: {manifest_path}"]

    with manifest_path.open("r", encoding="utf-8") as handle:
        manifest = json.load(handle)

    problems: List[str] = []

    inferred_names = set(inferred)
    manifest_names = set(manifest)
    unexpected = sorted(inferred_names - manifest_names)
    absent = sorted(manifest_names - inferred_names)
    if unexpected:
        problems.append(f"columns in inferred but not manifest: {unexpected}")
    if absent:
        problems.append(f"columns in manifest but not inferred: {absent}")

    for name, spec in inferred.items():
        entry = manifest.get(name)
        if entry is None:
            # Already reported above as an inferred-only column.
            continue

        if not _match_manifest_type(spec.postgres_type, entry):
            expected = entry.get("postgres_type") or entry.get("acceptable_types")
            problems.append(
                f"column {name!r}: inferred {spec.postgres_type!r}, "
                f"manifest expected {expected!r}"
            )

        # A manifest entry without "nullable" is treated as nullable.
        manifest_nullable = bool(entry.get("nullable", True))
        if not manifest_nullable and spec.nullable:
            problems.append(
                f"column {name!r}: inferred nullable, manifest expects NOT NULL "
                f"(loosening nullability is never allowed)"
            )

    return problems
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser for the loader.

    Options: ``--config`` (required, parsed as a ``Path``) plus the boolean
    switches ``--validate`` and ``--dry-run``, both off by default.
    """
    parser = argparse.ArgumentParser(
        description="Load a single SAS file (XPT or sas7bdat) into Postgres.",
    )
    parser.add_argument("--config", required=True, type=Path, help="Path to YAML config")

    switches = (
        (
            "--validate",
            "Compare inferred schema against <filename-stem>.expected.json "
            "next to the SAS file; exits nonzero on mismatch.",
        ),
        (
            "--dry-run",
            "Print inferred CREATE TABLE and stop; don't touch Postgres.",
        ),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)
    return parser
|
||||
|
||||
|
||||
def _format_columns_summary(columns: Dict[str, ColumnSpec]) -> str:
    """Render one ``name: TYPE[ NOT NULL]`` line per column, newline-joined.

    Lines follow the iteration order of ``columns``; an empty mapping gives
    an empty string.
    """
    return "\n".join(
        f" {spec.name}: {spec.postgres_type}" + ("" if spec.nullable else " NOT NULL")
        for spec in columns.values()
    )
|
||||
|
||||
|
||||
def main(argv: Optional[List[str]] = None) -> int:
    """Run the loader from the command line.

    Steps: parse arguments, load ``.env`` and the YAML config, read the SAS
    file, apply the include/exclude column filter, infer the schema, then
    optionally validate against the manifest and/or stop after printing the
    CREATE TABLE statement. Otherwise the table is prepared and the rows are
    copied inside a single transaction that is rolled back on any error.

    Args:
        argv: argument list to parse; None lets argparse read ``sys.argv``.

    Returns:
        0 on success (including ``--dry-run``), 1 when ``--validate`` finds
        problems, 2 when the SAS file named in the config does not exist.
    """
    args = _build_argparser().parse_args(argv)

    load_dotenv()

    cfg = load_config(args.config)

    if not cfg.filename.exists():
        print(f"error: SAS file not found: {cfg.filename}", file=sys.stderr)
        return 2

    df, meta = read_sas(cfg.filename)
    df = apply_column_filter(df, cfg.include, cfg.exclude)
    columns = infer_schema(df, meta)

    if args.validate:
        # Swap only the final suffix for ".expected.json", e.g.
        # "sample_kitchensink.xpt" -> "sample_kitchensink.expected.json".
        # Chaining with_suffix() calls would also eat part of a dotted stem
        # ("data.v2.xpt" -> "data.expected.json").
        manifest_path = cfg.filename.with_name(f"{cfg.filename.stem}.expected.json")
        problems = validate_against_manifest(columns, manifest_path)
        if problems:
            print("validation failed:", file=sys.stderr)
            for p in problems:
                print(f" - {p}", file=sys.stderr)
            return 1
        print(f"validation OK ({len(columns)} columns match {manifest_path.name})")

    if args.dry_run:
        print(render_create_table(cfg.schemaname, cfg.tablename, columns))
        return 0

    conn = connect()
    # DDL and COPY share one transaction so a failed load leaves no trace.
    conn.autocommit = False
    try:
        create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists)
        inserted = copy_dataframe(conn, cfg.schemaname, cfg.tablename, df, columns)
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

    print(
        f"loaded {inserted} rows into {cfg.schemaname}.{cfg.tablename} "
        f"({len(columns)} columns)"
    )
    print("final schema:")
    print(_format_columns_summary(columns))
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Use main()'s integer result as the process exit status.
    raise SystemExit(main())
|
||||
6
generic_loader/requirements.txt
Normal file
6
generic_loader/requirements.txt
Normal file
@ -0,0 +1,6 @@
|
||||
pandas>=2.0,<2.3
|
||||
pyreadstat>=1.2,<1.3
|
||||
numpy>=1.24,<2.1
|
||||
pyyaml>=6.0,<7.0
|
||||
psycopg2-binary>=2.9,<3.0
|
||||
python-dotenv>=1.0,<2.0
|
||||
17
generic_loader/sample_config.yaml
Normal file
17
generic_loader/sample_config.yaml
Normal file
@ -0,0 +1,17 @@
|
||||
filename: samples/sample_kitchensink.xpt
|
||||
schemaname: public
|
||||
tablename: kitchensink
|
||||
|
||||
# Optional. If set, only these columns are loaded. Mutually exclusive with exclude.
|
||||
# include:
|
||||
# - ID
|
||||
# - INTCOL
|
||||
# - DATECOL
|
||||
|
||||
# Optional. Columns to drop.
|
||||
# exclude:
|
||||
# - ALLNULL
|
||||
|
||||
# What to do if the target table already exists: fail | replace | append
|
||||
# Defaults to fail.
|
||||
if_exists: append
|
||||
Loading…
Reference in New Issue
Block a user