2069 lines
74 KiB
Python
2069 lines
74 KiB
Python
"""Per-file SAS-to-Postgres loader.
|
|
|
|
Library-style functions plus a thin CLI wrapper. Designed so an orchestrator
|
|
can wrap the library for directory/batch mode; orchestration is out of scope
|
|
here.
|
|
|
|
Python 3.14 compatible (target is an air-gapped host that currently only has
|
|
3.14). ``from __future__ import annotations`` lets us use PEP 585 generics
|
|
as annotations; runtime-resolved type uses (dataclass defaults, etc.) stick
|
|
to ``typing``.
|
|
|
|
-------------------------------------------------------------------------------
|
|
USAGE
|
|
-------------------------------------------------------------------------------
|
|
|
|
Supported inputs:
|
|
* ``.sas7bdat`` (read with ``encoding="latin-1"``)
|
|
* ``.xpt`` / ``.xport`` (SAS transport files)
|
|
|
|
1. YAML config
|
|
--------------
|
|
Every invocation is driven by a YAML file describing one SAS file to load::
|
|
|
|
filename: samples/sample_kitchensink.xpt # required; relative paths are
|
|
# resolved against the config
|
|
# file's directory when possible
|
|
schemaname: public # required
|
|
tablename: kitchensink # required
|
|
|
|
# Optional. One of: fail | replace | append. Default: fail.
|
|
# fail - error out if the target table already exists
|
|
# replace - DROP and recreate the table from the inferred schema
|
|
# append - keep the existing table; pre-flight a schema-compat check,
|
|
# then COPY the new rows in
|
|
if_exists: append
|
|
|
|
# Optional, mutually exclusive. Restrict which columns are loaded.
|
|
# include:
|
|
# - ID
|
|
# - INTCOL
|
|
# exclude:
|
|
# - ALLNULL
|
|
|
|
2. Database connection
|
|
----------------------
|
|
The loader uses standard libpq environment variables (read via ``os.environ``)::
|
|
|
|
PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE
|
|
|
|
The CLI calls ``python-dotenv``'s ``load_dotenv()`` at startup, so a local
|
|
``.env`` file is picked up automatically. Library callers are responsible for
|
|
populating the environment themselves (either call ``load_dotenv()`` or export
|
|
the vars) before calling :func:`connect`.
|
|
|
|
3. Command-line interface
|
|
-------------------------
|
|
::
|
|
|
|
python load_sas.py --config path/to/config.yaml [--validate] [--dry-run]
|
|
[--dbcreds]
|
|
|
|
Flags:
|
|
--config PATH Required. Path to the YAML config above.
|
|
--validate Compare the inferred schema against
|
|
``<sas-file-stem>.expected.json`` sitting next to the SAS
|
|
file. Exits nonzero on mismatch. Safe to combine with
|
|
``--dry-run``.
|
|
--dry-run Print the inferred ``CREATE TABLE`` SQL and stop. The
|
|
database is never touched (no connection is opened).
|
|
--dbcreds Prompt interactively for the database username and
|
|
password instead of reading ``PGUSER`` / ``PGPASSWORD``
|
|
from the environment or ``.env`` file. The password
|
|
prompt does not echo. Has no effect with ``--dry-run``
|
|
(no connection is opened).
|
|
|
|
Exit codes:
|
|
0 - success (load completed, or dry-run/validate passed)
|
|
1 - validation failure
|
|
2 - config references a SAS file that does not exist
|
|
Other nonzero - uncaught exception (traceback printed); the transaction
|
|
is rolled back before exit.
|
|
|
|
Typical invocations::
|
|
|
|
# Preview the inferred schema without connecting to Postgres.
|
|
python load_sas.py --config sample_config.yaml --dry-run
|
|
|
|
# Check the inferred schema against an expected-types manifest.
|
|
python load_sas.py --config sample_config.yaml --validate --dry-run
|
|
|
|
# Actually load the data.
|
|
python load_sas.py --config sample_config.yaml
|
|
|
|
# Load the data, prompting for credentials instead of using .env.
|
|
python load_sas.py --config sample_config.yaml --dbcreds
|
|
|
|
4. Expected-types manifest (``--validate``)
|
|
-------------------------------------------
|
|
``--validate`` looks for a JSON file named ``<sas-stem>.expected.json`` next
|
|
to the SAS file, e.g. ``samples/sample_kitchensink.xpt`` pairs with
|
|
``samples/sample_kitchensink.expected.json``. Each top-level key is a column
|
|
name; the value is an object with any of::
|
|
|
|
{
|
|
"postgres_type": "BIGINT", # exact expected type, OR
|
|
"acceptable_types": ["TEXT", # any-of list of acceptable types
|
|
"VARCHAR"],
|
|
"nullable": true, # default true; false = must be NOT NULL
|
|
"note": "free-form comment" # ignored by the loader
|
|
}
|
|
|
|
Type comparison ignores length/precision modifiers and normalizes synonyms
|
|
(e.g. ``INT`` == ``INTEGER`` == ``INT4``; ``VARCHAR(10)`` == ``VARCHAR``).
|
|
Nullability tightening (inferred NULL, manifest NOT NULL) is a hard failure;
|
|
loosening is not checked here because the append-mode check already covers it.
|
|
|
|
5. Library usage
|
|
----------------
|
|
The CLI is a thin wrapper around composable functions. The preferred pattern
|
|
infers the schema from a bounded preview and then streams the rest of the
|
|
file chunk-by-chunk into ``COPY`` - crucial for SAS files with hundreds of
|
|
millions of rows::
|
|
|
|
from dotenv import load_dotenv
|
|
from load_sas import (
|
|
load_config, read_sas_preview, iter_sas_chunks, apply_column_filter,
|
|
infer_schema, validate_against_manifest, render_create_table,
|
|
connect, create_table, copy_dataframes,
|
|
)
|
|
|
|
load_dotenv()
|
|
cfg = load_config("config.yaml")
|
|
|
|
# Schema from a preview slice (bounded by TYPE_INFERENCE_SAMPLE_ROWS).
|
|
preview_df, meta = read_sas_preview(cfg.filename)
|
|
preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
|
|
total_rows = getattr(meta, "number_rows", None)
|
|
columns = infer_schema(preview_df, meta, total_rows=total_rows)
|
|
|
|
# Optional: preview DDL / validate against a manifest.
|
|
print(render_create_table(cfg.schemaname, cfg.tablename, columns))
|
|
problems = validate_against_manifest(columns, Path("expected.json"))
|
|
assert not problems, problems
|
|
|
|
conn = connect()
|
|
conn.autocommit = False
|
|
try:
|
|
create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists)
|
|
chunks = (
|
|
apply_column_filter(df, cfg.include, cfg.exclude)
|
|
for df, _ in iter_sas_chunks(cfg.filename)
|
|
)
|
|
rows = copy_dataframes(conn, cfg.schemaname, cfg.tablename, chunks, columns)
|
|
conn.commit()
|
|
finally:
|
|
conn.close()
|
|
|
|
For small files (or tests) the legacy one-shot API still works:
|
|
:func:`read_sas` returns the whole frame and :func:`copy_dataframe` copies it
|
|
in one round trip.
|
|
|
|
All functions are side-effect free except :func:`connect`, :func:`create_table`,
|
|
:func:`copy_dataframe`, and :func:`copy_dataframes`; schema inference
|
|
(:func:`infer_schema`) accepts a ``coerce_chars`` kwarg to override the
|
|
module-level ``COERCE_CHAR_COLUMNS`` for a single call (the module flag is swapped for the duration of the call and restored on return).
|
|
|
|
6. Type inference summary
|
|
-------------------------
|
|
Priority order used by :func:`infer_schema`:
|
|
|
|
1. SAS format string (via ``meta.original_variable_types``):
|
|
``DATETIME*`` -> ``TIMESTAMP``, ``TIME*`` -> ``TIME``,
|
|
``DATE*`` / ``YYMMDD*`` / ``MMDDYY*`` / ``DDMMYY*`` / ``JULIAN*`` -> ``DATE``.
|
|
2. All-null column -> ``TEXT`` (with a note).
|
|
3. pandas datetime dtype -> ``TIMESTAMP``.
|
|
4. Object columns containing only ``datetime.date`` / ``datetime.datetime``
|
|
-> ``DATE`` or ``TIMESTAMP``.
|
|
5. Object columns of strings: if ``COERCE_CHAR_COLUMNS`` is True and at
|
|
least ``CHAR_INFERENCE_MIN_VALUES`` non-empty values parse cleanly, they
|
|
are promoted to ``INTEGER`` / ``BIGINT`` / ``DOUBLE PRECISION`` /
|
|
``DATE`` / ``TIMESTAMP``; otherwise ``TEXT``.
|
|
6. Numeric columns of whole numbers -> ``INTEGER`` (or ``BIGINT`` if any
|
|
value exceeds the int32 range ``NUMERIC_INT_RANGE``); otherwise
|
|
``DOUBLE PRECISION``.
|
|
|
|
Type inference scans only the first ``TYPE_INFERENCE_SAMPLE_ROWS`` rows for
|
|
performance on large files. The CLI enforces this at read time via
|
|
:func:`read_sas_preview`, so the whole file is never materialized just to pick
|
|
types. Sampled specs carry a ``sampled`` marker and the usual
|
|
tradeoffs: if the first N rows fit ``INTEGER`` but a later row exceeds int32,
|
|
or a column had no nulls in the preview but does later in the file, ``COPY``
|
|
will fail mid-stream and the whole transaction rolls back. Set
|
|
``TYPE_INFERENCE_SAMPLE_ROWS = None`` to scan every row when exact typing
|
|
matters more than speed.
|
|
|
|
Streaming loads use :func:`iter_sas_chunks` + :func:`copy_dataframes`, which
copy chunk by chunk inside the caller's transaction; nothing is kept until the
caller commits, so an interrupted load rolls back in full.
|
|
|
|
7. Tunables
|
|
-----------
|
|
Module-level knobs at the top of this file:
|
|
|
|
* ``COERCE_CHAR_COLUMNS`` - promote stringly-typed numerics / dates
|
|
(default True).
|
|
* ``CHAR_INFERENCE_MIN_VALUES`` - minimum non-empty sample size before
|
|
char-column coercion is attempted.
|
|
* ``NUMERIC_INT_RANGE`` - INTEGER bounds; values outside become
|
|
``BIGINT``.
|
|
* ``TYPE_INFERENCE_SAMPLE_ROWS`` - cap on rows read for type inference
|
|
(``None`` = scan the whole column).
|
|
* ``DEFAULT_CHUNK_ROWS`` - rows per streaming COPY chunk.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import datetime as dt
|
|
import getpass
|
|
import hashlib
|
|
import io
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import re
|
|
import sys
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
|
|
import pandas as pd
|
|
import psycopg2
|
|
import psycopg2.extensions
|
|
import pyreadstat
|
|
import yaml
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Top-level tunables
|
|
# ---------------------------------------------------------------------------
|
|
|
|
COERCE_CHAR_COLUMNS = True
"""If True, try to promote object (string) columns to numeric/date/timestamp
when every non-empty value parses cleanly."""

CHAR_INFERENCE_MIN_VALUES = 3
"""Don't attempt character-column coercion with fewer than this many non-empty
values; too small a sample is easy to mis-infer."""

NUMERIC_INT_RANGE = (-2_147_483_648, 2_147_483_647)
"""INTEGER bounds; anything outside becomes BIGINT."""

TYPE_INFERENCE_SAMPLE_ROWS: Optional[int] = 10_000
"""Cap on rows inspected during per-column type inference. Also governs how
many rows :func:`read_sas_preview` pulls from the file for dry-run / validate /
schema-inference flows. Set to ``None`` to scan every row (and read the whole
file into memory for the preview step - don't do this on multi-hundred-million
row files)."""

DEFAULT_CHUNK_ROWS = 100_000
"""Rows per chunk when streaming a SAS file into ``COPY``. Larger values mean
fewer COPY round-trips but more peak memory per chunk; smaller values are
gentler on memory."""


# Accepted values for the config's ``if_exists`` key; ``load_config`` rejects
# anything else after lower-casing the configured value.
VALID_IF_EXISTS = ("fail", "replace", "append")

_PG_IDENT_MAX_LEN = 63
"""PostgreSQL maximum identifier length in bytes (characters for ASCII)."""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dataclasses
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class LoaderConfig:
    """Validated settings for loading one SAS file into one Postgres table.

    Instances are normally produced by ``load_config``, which applies the
    defaults and validation rules described in the module docstring.
    """

    # Path to the SAS file to load.
    filename: Path
    # Destination schema and table names.
    schemaname: str
    tablename: str
    # One of VALID_IF_EXISTS ("fail", "replace", "append").
    if_exists: str = "fail"
    # Optional column filters; load_config treats them as mutually exclusive.
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
    # Partition columns; render_create_table uses only the first entry for
    # its PARTITION BY LIST clause.
    partition_by: List[str] = field(default_factory=list)
    # Validated by load_config as a positive integer. How the cap is enforced
    # is not visible in this part of the file.
    max_partitions: int = 10_000
    # Column names to index (presumably one index per name - confirm against
    # the index-creation code, which is not visible here).
    indexes: List[str] = field(default_factory=list)
|
|
|
|
|
|
@dataclass
class ColumnSpec:
    """Inferred Postgres definition for a single source column.

    Built by ``infer_schema`` (one per DataFrame column) and consumed by
    ``render_create_table`` to emit the column's DDL line.
    """

    # Column name as it appears in the DataFrame.
    name: str
    # Postgres type string emitted verbatim into the CREATE TABLE statement.
    postgres_type: str
    # False renders the column as NOT NULL.
    nullable: bool
    # SAS format string taken from meta.original_variable_types, if any.
    sas_format: Optional[str] = None
    # str() of the pandas dtype the column had when it was inspected.
    source_dtype: Optional[str] = None
    # Human-readable remarks recorded during inference (fallbacks, sampling).
    notes: List[str] = field(default_factory=list)
    sampled: bool = False
    """True when the type was inferred from a bounded preview rather than the
    full file. A sampled spec carries the usual sampling risks: a later chunk
    could contain a value that exceeds the inferred integer range, doesn't
    parse as the inferred type, or is null in a column the preview showed as
    non-null - all of which surface as mid-``COPY`` failures."""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Custom exceptions
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TableExistsError(RuntimeError):
    """Raised when ``if_exists=fail`` and the target table already exists."""
|
|
|
|
|
|
class SchemaCompatibilityError(RuntimeError):
    """Raised when an append would not fit the existing table.

    Applies to ``if_exists=append``: the incoming schema is not compatible
    with the table that is already there.
    """
|
|
|
|
|
|
class ValidationError(RuntimeError):
    """Raised when ``--validate`` detects a mismatch against the manifest."""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Connection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def connect(
    *,
    user: Optional[str] = None,
    password: Optional[str] = None,
) -> psycopg2.extensions.connection:
    """Open a psycopg2 connection described by the libpq environment variables.

    Host, port and database name are read from ``PGHOST``, ``PGPORT`` and
    ``PGDATABASE``. ``user`` / ``password`` take precedence over ``PGUSER`` /
    ``PGPASSWORD`` when they are non-empty (this is how the ``--dbcreds`` CLI
    flag feeds in interactively typed credentials).

    The environment is read as-is: callers are expected to have loaded any
    ``.env`` file (the CLI does so) or exported the variables beforehand.
    """
    env = os.environ
    params = {
        "host": env.get("PGHOST"),
        "port": env.get("PGPORT"),
        # Explicit arguments win; empty/None falls back to the environment.
        "user": user if user else env.get("PGUSER"),
        "password": password if password else env.get("PGPASSWORD"),
        "dbname": env.get("PGDATABASE"),
    }
    return psycopg2.connect(**params)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config loading
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _parse_column_list(raw_value: Any, key: str, path: Path) -> List[str]:
    """Normalize a string-or-list config value under ``key`` to a list of names."""
    if raw_value is None:
        return []
    if isinstance(raw_value, str):
        name = raw_value.strip()
        if not name:
            raise ValueError(f"Config {path}: '{key}' string must be non-empty.")
        return [name]
    if not isinstance(raw_value, list):
        raise ValueError(
            f"Config {path}: '{key}' must be a string or list of strings."
        )

    names: List[str] = []
    for i, item in enumerate(raw_value):
        if not isinstance(item, str) or not item.strip():
            raise ValueError(
                f"Config {path}: '{key}[{i}]' must be a non-empty string."
            )
        names.append(item.strip())
    if len(names) != len(set(names)):
        raise ValueError(
            f"Config {path}: '{key}' contains duplicate column names."
        )
    return names


def _check_columns_survive_filters(
    names: List[str],
    label: str,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
    path: Path,
) -> None:
    """Raise if ``include``/``exclude`` would drop any of the ``names`` columns."""
    if include is not None:
        kept = {str(c) for c in include}
        omitted = [c for c in names if c not in kept]
        if omitted:
            raise ValueError(
                f"Config {path}: 'include' omits {label} columns: {omitted}"
            )
    if exclude is not None:
        dropped = {str(c) for c in exclude}
        removed = [c for c in names if c in dropped]
        if removed:
            raise ValueError(
                f"Config {path}: 'exclude' removes {label} columns: {removed}"
            )


def load_config(path: Path) -> LoaderConfig:
    """Parse and validate the YAML config at ``path``.

    Required keys are ``filename``, ``schemaname`` and ``tablename``; a key
    that is present but null counts as missing. Optional keys: ``if_exists``
    (default ``"fail"``), ``include`` / ``exclude`` (mutually exclusive lists),
    ``partition_by`` and ``indexes`` (a string or list of strings), and
    ``max_partitions`` (positive integer, default 10000).

    A relative ``filename`` is resolved against the config file's directory
    when a file exists there; otherwise it is kept as written.

    Raises:
        ValueError: on any structural or validation problem in the config.
        OSError: if the config file cannot be opened.
    """
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)

    if not isinstance(raw, dict):
        raise ValueError(f"Config at {path} must be a YAML mapping at the top level.")

    # ``tablename:`` with no value parses as None; str(None) would otherwise
    # produce a table literally named "None".
    missing = [
        k for k in ("filename", "schemaname", "tablename") if raw.get(k) is None
    ]
    if missing:
        raise ValueError(f"Config {path} missing required keys: {', '.join(missing)}")

    filename = Path(str(raw["filename"]))
    if not filename.is_absolute():
        beside_config = path.parent / filename
        if beside_config.exists():
            filename = beside_config.resolve()

    schemaname = str(raw["schemaname"])
    tablename = str(raw["tablename"])

    if_exists = str(raw.get("if_exists", "fail")).lower()
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(
            f"Config {path}: if_exists={if_exists!r} is not one of {VALID_IF_EXISTS}"
        )

    include = raw.get("include")
    exclude = raw.get("exclude")
    if include is not None and exclude is not None:
        raise ValueError(
            f"Config {path}: 'include' and 'exclude' are mutually exclusive; set at most one."
        )
    if include is not None and not isinstance(include, list):
        raise ValueError(f"Config {path}: 'include' must be a list of column names.")
    if exclude is not None and not isinstance(exclude, list):
        raise ValueError(f"Config {path}: 'exclude' must be a list of column names.")

    # -- partition_by -------------------------------------------------------
    partition_by = _parse_column_list(raw.get("partition_by"), "partition_by", path)
    if partition_by:
        _check_columns_survive_filters(
            partition_by, "partition_by", include, exclude, path
        )

    # -- max_partitions -----------------------------------------------------
    raw_mp = raw.get("max_partitions", 10_000)
    # bool is an int subclass: ``max_partitions: true`` must not become 1.
    if isinstance(raw_mp, bool):
        raise ValueError(
            f"Config {path}: 'max_partitions' must be a positive integer, "
            f"got {raw_mp!r}"
        )
    try:
        max_partitions = int(raw_mp)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Config {path}: 'max_partitions' must be a positive integer, "
            f"got {raw_mp!r}"
        ) from err
    if max_partitions <= 0:
        raise ValueError(
            f"Config {path}: 'max_partitions' must be a positive integer, "
            f"got {max_partitions}"
        )

    # -- indexes ------------------------------------------------------------
    indexes = _parse_column_list(raw.get("indexes"), "indexes", path)
    if indexes:
        _check_columns_survive_filters(indexes, "index", include, exclude, path)

    return LoaderConfig(
        filename=filename,
        schemaname=schemaname,
        tablename=tablename,
        if_exists=if_exists,
        include=[str(c) for c in include] if include is not None else None,
        exclude=[str(c) for c in exclude] if exclude is not None else None,
        partition_by=partition_by,
        max_partitions=max_partitions,
        indexes=indexes,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Reader
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _sas_reader(path: Path) -> Tuple[Any, Dict[str, Any]]:
    """Return ``(pyreadstat_reader, extra_kwargs)`` for ``path``.

    The reader is chosen by file extension (case-insensitive):

    * ``.xpt`` / ``.xport`` - ``pyreadstat.read_xport`` with no encoding
      argument; pyreadstat is flaky about encoding on XPORT files it wrote
      itself.
    * ``.sas7bdat`` - ``pyreadstat.read_sas7bdat`` with an explicit
      ``encoding="latin-1"`` (per colleague guidance).

    Callers splat ``extra_kwargs`` into the reader call so the encoding choice
    lives in exactly one place.

    Raises:
        ValueError: if the extension is not one of the supported SAS formats.
    """
    suffix = Path(path).suffix.lower()
    if suffix in (".xpt", ".xport"):
        return pyreadstat.read_xport, {}
    if suffix == ".sas7bdat":
        # The documented contract for this loader; previously the kwarg was
        # dropped and pyreadstat fell back to guessing the encoding.
        return pyreadstat.read_sas7bdat, {"encoding": "latin-1"}
    raise ValueError(f"Unsupported SAS file extension: {suffix}")
|
|
|
|
|
|
def read_sas(path: Path) -> Tuple[pd.DataFrame, Any]:
    """Load a whole SAS file in one call and return ``(frame, meta)``.

    Everything is materialized in memory, so this is only appropriate for
    small files. It remains for backward compatibility and tests; the CLI
    goes through :func:`read_sas_preview` and :func:`iter_sas_chunks` instead.
    """
    read_fn, extra = _sas_reader(path)
    source = str(Path(path))
    return read_fn(source, **extra)
|
|
|
|
|
|
def read_sas_preview(
    path: Path,
    *,
    rows: Optional[int] = None,
) -> Tuple[pd.DataFrame, Any]:
    """Return the leading records of ``path`` together with its metadata.

    ``rows`` caps how many records are read; when omitted the module-level
    ``TYPE_INFERENCE_SAMPLE_ROWS`` is used. A cap of ``None`` or ``0`` (after
    that fallback) is passed to pyreadstat as ``row_limit=0``, which it treats
    as unlimited - i.e. the whole file is read.
    """
    read_fn, extra = _sas_reader(path)
    cap = TYPE_INFERENCE_SAMPLE_ROWS if rows is None else rows
    limit = 0
    if cap:
        limit = int(cap)
    return read_fn(str(Path(path)), row_limit=limit, **extra)
|
|
|
|
|
|
def iter_sas_chunks(
    path: Path,
    *,
    chunksize: int = DEFAULT_CHUNK_ROWS,
):
    """Stream ``path`` as successive ``(df_chunk, meta)`` tuples.

    Delegates to ``pyreadstat.read_file_in_chunks``, supplying the reader and
    extra keyword arguments that :func:`_sas_reader` selects for the file's
    extension. ``chunksize`` is the number of rows requested per chunk.
    """
    read_fn, extra = _sas_reader(path)
    chunk_stream = pyreadstat.read_file_in_chunks(
        read_fn,
        str(Path(path)),
        chunksize=chunksize,
        **extra,
    )
    yield from chunk_stream
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Column filtering
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def apply_column_filter(
    df: pd.DataFrame,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
) -> pd.DataFrame:
    """Return a copy of ``df`` narrowed to the requested columns.

    With ``include`` the result holds exactly those columns, in the order
    given; with ``exclude`` those columns are dropped. With neither, the frame
    is copied unchanged. The input frame is never modified.

    Raises:
        ValueError: if both filters are supplied, or if a filter names a
            column that ``df`` does not have.
    """
    wants_include = include is not None
    wants_exclude = exclude is not None

    if wants_include and wants_exclude:
        raise ValueError("include and exclude are mutually exclusive.")

    if wants_include:
        unknown = [name for name in include if name not in df.columns]
        if unknown:
            raise ValueError(f"include references unknown columns: {unknown}")
        selected = df.loc[:, list(include)]
        return selected.copy()

    if wants_exclude:
        unknown = [name for name in exclude if name not in df.columns]
        if unknown:
            raise ValueError(f"exclude references unknown columns: {unknown}")
        remaining = df.drop(columns=list(exclude))
        return remaining.copy()

    return df.copy()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Type inference
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
_DATE_FORMAT_PREFIXES = ("DATE", "YYMMDD", "MMDDYY", "DDMMYY", "JULIAN")


def _format_driven_type(sas_format: Optional[str]) -> Optional[str]:
    """Map a SAS format string to ``TIMESTAMP`` / ``TIME`` / ``DATE``.

    Matching is by prefix on the upper-cased format with leading whitespace
    removed. Returns ``None`` when the format is empty or does not identify a
    temporal type.
    """
    if not sas_format:
        return None

    normalized = sas_format.upper().lstrip()
    # Order matters: "DATETIME20." also starts with "DATE", so the DATETIME
    # rule has to be tried before the date prefixes.
    rules = (
        ("DATETIME", "TIMESTAMP"),
        ("TIME", "TIME"),
        (tuple(_DATE_FORMAT_PREFIXES), "DATE"),
    )
    for prefixes, pg_type in rules:
        if normalized.startswith(prefixes):
            return pg_type
    return None
|
|
|
|
|
|
def _all_null(series: pd.Series) -> bool:
    """Return True when ``series`` holds no usable value at all.

    For object columns a value counts as missing when it is ``None``, the
    empty string, or a float NaN; other dtypes rely on ``Series.isna``.
    """
    if not pd.api.types.is_object_dtype(series):
        return bool(series.isna().all())

    def _is_blank(value: Any) -> bool:
        if value is None:
            return True
        if isinstance(value, str):
            return value == ""
        return isinstance(value, float) and pd.isna(value)

    return bool(series.map(_is_blank).all())
|
|
|
|
|
|
def _char_missing_mask(series: pd.Series) -> pd.Series:
    """Return an element-wise mask marking the missing entries of ``series``.

    An entry is missing when it is ``None``, a float NaN, or the empty string.
    """

    def _is_missing(value: Any) -> bool:
        if value is None:
            return True
        if isinstance(value, float):
            return bool(pd.isna(value))
        return isinstance(value, str) and value == ""

    return series.map(_is_missing)
|
|
|
|
|
|
def _is_nullable(series: pd.Series) -> bool:
    """Report whether ``series`` contains at least one missing value.

    Object columns use the character-column definition of "missing"
    (``_char_missing_mask``); every other dtype uses ``Series.isna``.
    """
    if pd.api.types.is_object_dtype(series):
        missing = _char_missing_mask(series)
    else:
        missing = series.isna()
    return bool(missing.any())
|
|
|
|
|
|
def _numeric_int_target(series: pd.Series) -> Optional[str]:
    """Pick an integer Postgres type for a numeric series, if one fits.

    Returns ``"INTEGER"`` when every non-null value is a whole number inside
    ``NUMERIC_INT_RANGE``, ``"BIGINT"`` when they are whole numbers inside the
    64-bit signed range, and ``None`` otherwise (no values, any fractional or
    infinite value, or a magnitude that BIGINT cannot hold). A ``None`` result
    leaves the caller to fall back to a floating-point type.
    """
    nonnull = series.dropna()
    if nonnull.empty:
        return None

    # Whole-number test. Infinite values yield NaN under ``% 1`` and therefore
    # fail the comparison, so they are rejected here as well.
    try:
        whole = bool(((nonnull % 1) == 0).all())
    except TypeError:
        return None
    if not whole:
        return None

    # Values are whole and finite at this point, so int() is exact and lets
    # the range checks use arbitrary-precision Python integers.
    vmin = int(nonnull.min())
    vmax = int(nonnull.max())

    # A double such as 1e20 is a "whole number" but overflows BIGINT; report
    # no integer type so the column keeps a floating-point representation.
    if vmin < -(2**63) or vmax > 2**63 - 1:
        return None

    lo, hi = NUMERIC_INT_RANGE
    if lo <= vmin and vmax <= hi:
        return "INTEGER"
    return "BIGINT"
|
|
|
|
|
|
def _object_is_dates(series: pd.Series) -> Tuple[bool, bool]:
    """Classify an object column as date-like and flag datetime values.

    Returns ``(all_date_like, any_datetime)``. ``all_date_like`` is True only
    when the column has at least one non-null value and every non-null value
    is a ``datetime.date`` (which includes ``datetime.datetime`` and
    ``pd.Timestamp``). ``any_datetime`` is True when at least one of those
    values is a ``datetime.datetime`` instance. Any other value type makes the
    result ``(False, False)``.
    """
    present = series.dropna()
    if present.empty:
        return False, False

    saw_datetime = False
    for value in present:
        # datetime and Timestamp both derive from date, so one check covers
        # every accepted type.
        if not isinstance(value, dt.date):
            return False, False
        if isinstance(value, dt.datetime):
            saw_datetime = True
    return True, saw_datetime
|
|
|
|
|
|
def _try_int_coerce(values: List[str]) -> Optional[str]:
    """Return ``"INTEGER"`` / ``"BIGINT"`` if every string parses as an int.

    Each value is stripped of surrounding whitespace and parsed with
    ``int()``. ``None`` is returned when any value fails to parse or when
    ``values`` is empty. ``"INTEGER"`` is chosen when all parsed values fall
    inside ``NUMERIC_INT_RANGE``; anything else yields ``"BIGINT"``.

    NOTE(review): values beyond the 64-bit range still map to ``"BIGINT"``
    here - confirm whether such columns should stay TEXT instead.
    """
    try:
        parsed = [int(text.strip()) for text in values]
    except ValueError:
        return None
    if not parsed:
        return None

    lo, hi = NUMERIC_INT_RANGE
    fits_int32 = lo <= min(parsed) and max(parsed) <= hi
    return "INTEGER" if fits_int32 else "BIGINT"
|
|
|
|
|
|
def _try_float_coerce(values: List[str]) -> bool:
    """Return True when ``float()`` accepts every string in ``values``.

    An empty list yields True; the first value that raises ``ValueError``
    yields False.
    """
    try:
        for text in values:
            float(text)
    except ValueError:
        return False
    return True
|
|
|
|
|
|
def _try_date_coerce(values: List[str]) -> bool:
    """Return True when every value parses via ``datetime.date.fromisoformat``.

    An empty list yields True; a value that raises ``ValueError`` or
    ``TypeError`` yields False.
    """
    parse = dt.date.fromisoformat
    try:
        for text in values:
            parse(text)
    except (ValueError, TypeError):
        return False
    return True
|
|
|
|
|
|
def _try_datetime_coerce(values: List[str]) -> bool:
    """Return True when every value parses via ``datetime.datetime.fromisoformat``.

    An empty list yields True; a value that raises ``ValueError`` or
    ``TypeError`` yields False.
    """
    parse = dt.datetime.fromisoformat
    try:
        for text in values:
            parse(text)
    except (ValueError, TypeError):
        return False
    return True
|
|
|
|
|
|
def _infer_char_type(series: pd.Series) -> str:
    """Choose a Postgres type for an object/string column.

    Missing entries (per ``_char_missing_mask``) are ignored and the rest are
    converted with ``str()``. If coercion is switched off via
    ``COERCE_CHAR_COLUMNS`` or fewer than ``CHAR_INFERENCE_MIN_VALUES`` values
    remain, the answer is ``"TEXT"``. Otherwise the first probe that accepts
    every value wins, in this order: integer, float, ISO date, ISO datetime;
    ``"TEXT"`` is the fallback.
    """
    missing = _char_missing_mask(series)
    candidates = [str(value) for value in series[~missing].tolist()]

    coercion_enabled = bool(COERCE_CHAR_COLUMNS)
    if not coercion_enabled or len(candidates) < CHAR_INFERENCE_MIN_VALUES:
        return "TEXT"

    int_type = _try_int_coerce(candidates)
    if int_type is not None:
        return int_type

    # Date is probed before datetime because the datetime parser also accepts
    # plain dates.
    probes = (
        (_try_float_coerce, "DOUBLE PRECISION"),
        (_try_date_coerce, "DATE"),
        (_try_datetime_coerce, "TIMESTAMP"),
    )
    for accepts_all, pg_type in probes:
        if accepts_all(candidates):
            return pg_type
    return "TEXT"
|
|
|
|
|
|
def infer_schema(
    df: pd.DataFrame,
    meta: Any,
    *,
    coerce_chars: Optional[bool] = None,
    total_rows: Optional[int] = None,
) -> Dict[str, ColumnSpec]:
    """Infer a Postgres column spec for each column in ``df``.

    ``meta`` is the pyreadstat metadata object; its
    ``original_variable_types`` mapping (column name -> SAS format) drives
    format-based date/time/timestamp inference.

    ``coerce_chars`` overrides the module-level ``COERCE_CHAR_COLUMNS`` for
    this call only. When left as ``None`` the module flag's *current* value is
    used, so setting the tunable at runtime takes effect. The char-inference
    helpers read the module flag, so an explicit override is applied by
    swapping the flag for the duration of the call and restoring it in a
    ``finally`` block.

    ``total_rows`` lets callers who already sampled the frame (e.g. via
    :func:`read_sas_preview`) report the real file size in the per-column
    "type inferred from first N of M rows" note. Falls back to ``len(df)``.

    Returns:
        A dict mapping each column name to its ``ColumnSpec``, in the
        frame's column order.
    """
    global COERCE_CHAR_COLUMNS

    original_formats: Dict[str, str] = dict(
        getattr(meta, "original_variable_types", {}) or {}
    )

    # Row-walking type probes run on a bounded head slice; nullability and the
    # all-null check still see every row of ``df`` so NOT NULL declarations
    # stay honest for the data that was actually passed in.
    df_rows = len(df)
    effective_total = total_rows if total_rows is not None else df_rows
    if TYPE_INFERENCE_SAMPLE_ROWS is not None and df_rows > TYPE_INFERENCE_SAMPLE_ROWS:
        sample_df = df.head(TYPE_INFERENCE_SAMPLE_ROWS)
        sample_size = TYPE_INFERENCE_SAMPLE_ROWS
    else:
        sample_df = df
        sample_size = df_rows
    sampled = sample_size < effective_total

    # Only touch the module flag when the caller asked for an override; the
    # default must not clobber a value the user set on the module.
    saved = COERCE_CHAR_COLUMNS
    if coerce_chars is not None:
        COERCE_CHAR_COLUMNS = coerce_chars
    try:
        out: Dict[str, ColumnSpec] = {}
        for col in df.columns:
            series = df[col]
            sample_series = sample_df[col]
            sas_format = original_formats.get(col)
            notes: List[str] = []

            # A recognized SAS format wins over anything the data suggests.
            pg_type = _format_driven_type(sas_format)

            if pg_type is None:
                if _all_null(series):
                    pg_type = "TEXT"
                    notes.append("all-null column; defaulting to TEXT")
                elif pd.api.types.is_datetime64_any_dtype(series):
                    pg_type = "TIMESTAMP"
                elif pd.api.types.is_object_dtype(series):
                    is_dates, any_dt = _object_is_dates(sample_series)
                    if is_dates:
                        pg_type = "TIMESTAMP" if any_dt else "DATE"
                    else:
                        pg_type = _infer_char_type(sample_series)
                elif pd.api.types.is_numeric_dtype(series):
                    int_target = _numeric_int_target(sample_series)
                    if int_target is not None:
                        pg_type = int_target
                    else:
                        pg_type = "DOUBLE PRECISION"
                else:
                    pg_type = "TEXT"
                    notes.append(f"unhandled dtype {series.dtype}; defaulting to TEXT")

            if sampled:
                notes.append(
                    f"type inferred from first {sample_size:,} of "
                    f"{effective_total:,} rows"
                )

            nullable = _is_nullable(series)

            out[col] = ColumnSpec(
                name=col,
                postgres_type=pg_type,
                nullable=nullable,
                sas_format=sas_format,
                source_dtype=str(series.dtype),
                notes=notes,
                sampled=sampled,
            )
        return out
    finally:
        COERCE_CHAR_COLUMNS = saved
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Table management
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _quote_ident(ident: str) -> str:
    """Return ``ident`` as a double-quoted Postgres identifier.

    Embedded double quotes are doubled. Done by hand rather than through
    psycopg2's ``sql.Identifier`` (2.8+) to stay driver-simple.
    """
    escaped = ident.replace('"', '""')
    return f'"{escaped}"'
|
|
|
|
|
|
def _qualified(schema: str, table: str) -> str:
    """Return the quoted ``"schema"."table"`` reference for use in SQL text."""
    quoted_parts = (_quote_ident(schema), _quote_ident(table))
    return ".".join(quoted_parts)
|
|
|
|
|
|
def _table_exists(conn, schema: str, table: str) -> bool:
    """Return True if ``information_schema.tables`` lists ``schema.table``.

    The names are passed as query parameters, so they are matched exactly as
    given (no identifier quoting is involved).
    """
    query = (
        "SELECT 1 FROM information_schema.tables "
        "WHERE table_schema = %s AND table_name = %s"
    )
    with conn.cursor() as cur:
        cur.execute(query, (schema, table))
        row = cur.fetchone()
    return row is not None
|
|
|
|
|
|
def render_create_table(
    schema: str,
    table: str,
    columns: Dict[str, ColumnSpec],
    *,
    partition_by: Optional[List[str]] = None,
) -> str:
    """Build the ``CREATE TABLE`` statement for ``schema.table``.

    One column definition is emitted per entry of ``columns`` (in dict
    order); specs that are not nullable get a ``NOT NULL`` suffix. A
    non-empty ``partition_by`` appends ``PARTITION BY LIST`` on its *first*
    field only; deeper levels are declared on the child partitions.
    """
    column_defs = ",\n".join(
        f" {_quote_ident(spec.name)} {spec.postgres_type}"
        + ("" if spec.nullable else " NOT NULL")
        for spec in columns.values()
    )
    if partition_by:
        partition_clause = f"\nPARTITION BY LIST ({_quote_ident(partition_by[0])})"
    else:
        partition_clause = ""
    return (
        f"CREATE TABLE {_qualified(schema, table)} "
        f"(\n{column_defs}\n){partition_clause};"
    )
|
|
|
|
|
|
def _create_table_sql(
    conn,
    schema: str,
    table: str,
    columns: Dict[str, ColumnSpec],
    *,
    partition_by: Optional[List[str]] = None,
) -> None:
    """Render the ``CREATE TABLE`` DDL and execute it on ``conn``.

    ``partition_by`` is forwarded unchanged to :func:`render_create_table`.
    No commit is issued here.
    """
    ddl = render_create_table(schema, table, columns, partition_by=partition_by)
    with conn.cursor() as cur:
        cur.execute(ddl)
|
|
|
|
|
|
def _drop_table(conn, schema: str, table: str, *, cascade: bool = False) -> None:
    """Execute ``DROP TABLE`` for ``schema.table``; add ``CASCADE`` on request."""
    statement = f"DROP TABLE {_qualified(schema, table)}"
    if cascade:
        statement += " CASCADE"
    with conn.cursor() as cur:
        cur.execute(statement)
|
|
|
|
|
|
# Normalization table: map both loader-emitted and Postgres-reported type
|
|
# strings to a single canonical family name. Ignore length/precision
|
|
# modifiers like VARCHAR(n) and NUMERIC(p,s).
|
|
_TYPE_NORMALIZATION: Dict[str, str] = {
|
|
"INTEGER": "integer",
|
|
"INT": "integer",
|
|
"INT4": "integer",
|
|
"BIGINT": "bigint",
|
|
"INT8": "bigint",
|
|
"SMALLINT": "smallint",
|
|
"INT2": "smallint",
|
|
"DOUBLE PRECISION": "double precision",
|
|
"FLOAT8": "double precision",
|
|
"REAL": "real",
|
|
"FLOAT4": "real",
|
|
"NUMERIC": "numeric",
|
|
"DECIMAL": "numeric",
|
|
"TEXT": "text",
|
|
"VARCHAR": "character varying",
|
|
"CHARACTER VARYING": "character varying",
|
|
"CHAR": "character",
|
|
"CHARACTER": "character",
|
|
"BPCHAR": "character",
|
|
"BOOLEAN": "boolean",
|
|
"BOOL": "boolean",
|
|
"DATE": "date",
|
|
"TIMESTAMP": "timestamp without time zone",
|
|
"TIMESTAMP WITHOUT TIME ZONE": "timestamp without time zone",
|
|
"TIMESTAMPTZ": "timestamp with time zone",
|
|
"TIMESTAMP WITH TIME ZONE": "timestamp with time zone",
|
|
"TIME": "time without time zone",
|
|
"TIME WITHOUT TIME ZONE": "time without time zone",
|
|
"TIMETZ": "time with time zone",
|
|
"TIME WITH TIME ZONE": "time with time zone",
|
|
}
|
|
|
|
|
|
def _normalize_type(pg_type: str) -> str:
    """Reduce a Postgres type string to its canonical family name.

    The input is upper-cased, parenthesized length/precision modifiers are
    removed wherever they occur, and whitespace runs are collapsed. The
    result is looked up in ``_TYPE_NORMALIZATION``; unknown types fall back
    to the cleaned spelling in lower case.
    """
    candidate = pg_type.strip().upper()
    # "VARCHAR(10)" -> "VARCHAR";
    # "TIMESTAMP(6) WITHOUT TIME ZONE" -> "TIMESTAMP  WITHOUT TIME ZONE"
    candidate = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", candidate).strip()
    # Dropping a mid-string modifier leaves a doubled space; squeeze it.
    candidate = re.sub(r"\s+", " ", candidate)
    try:
        return _TYPE_NORMALIZATION[candidate]
    except KeyError:
        return candidate.lower()
|
|
|
|
|
|
def _assert_schema_compatible(
    conn, schema: str, table: str, columns: Dict[str, ColumnSpec]
) -> None:
    """Check that ``columns`` can be appended to the existing ``schema.table``.

    The target's columns are read from ``information_schema.columns``. Every
    incoming column must exist in the target and belong to the same
    normalized type family (see :func:`_normalize_type`). An incoming
    nullable column whose target is ``NOT NULL`` only produces a ``[warn]``
    line on stderr. Columns that exist only in the target are not inspected.

    Raises:
        SchemaCompatibilityError: if any column is missing from the target
            or has a different normalized type; all problems are listed.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT column_name, data_type, is_nullable "
            "FROM information_schema.columns "
            "WHERE table_schema = %s AND table_name = %s",
            (schema, table),
        )
        target_columns = {
            col_name: (data_type, is_nullable)
            for col_name, data_type, is_nullable in cur.fetchall()
        }

    problems: List[str] = []
    nullability_notes: List[str] = []

    for name, spec in columns.items():
        target = target_columns.get(name)
        if target is None:
            problems.append(
                f"column {name!r} not present in target {schema}.{table}"
            )
            continue

        target_type, target_nullable = target
        incoming_family = _normalize_type(spec.postgres_type)
        target_family = _normalize_type(target_type)
        if incoming_family != target_family:
            problems.append(
                f"column {name!r}: inferred {spec.postgres_type} "
                f"(normalized {incoming_family!r}) but target is {target_type} "
                f"(normalized {target_family!r})"
            )

        # information_schema reports nullability as the strings "YES"/"NO".
        if target_nullable == "NO" and spec.nullable:
            nullability_notes.append(
                f"column {name!r}: incoming allows NULLs but target is NOT NULL; "
                "COPY will fail if any NULLs appear"
            )

    # Warnings are emitted even when the check is about to fail.
    for note in nullability_notes:
        print(f"[warn] {note}", file=sys.stderr)

    if problems:
        raise SchemaCompatibilityError(
            "append-mode schema compatibility check failed:\n - "
            + "\n - ".join(problems)
        )
|
|
|
|
|
|
def assert_schema_compatible(
    conn,
    schema_name: str,
    table_name: str,
    columns: Dict[str, ColumnSpec],
) -> None:
    """Run the append-mode schema compatibility check on an existing table.

    Public entry point for callers (for example a folder-level orchestrator
    appending several files into one table) that need the same pre-flight
    check ``if_exists=append`` runs internally. Delegates to
    :func:`_assert_schema_compatible`.

    Raises:
        SchemaCompatibilityError: when the incoming columns do not fit the
            target table.
    """
    _assert_schema_compatible(
        conn,
        schema=schema_name,
        table=table_name,
        columns=columns,
    )
|
|
|
|
|
|
def create_table(
    conn,
    schema_name: str,
    table_name: str,
    columns: Dict[str, ColumnSpec],
    if_exists: str,
    *,
    partition_by: Optional[List[str]] = None,
    partition_values: Optional[dict] = None,
    max_partitions: int = 10_000,
) -> None:
    """Create, replace, or verify the target table according to ``if_exists``.

    Behaviour when the table already exists:

    * ``fail``    - raise :class:`TableExistsError`.
    * ``append``  - run the schema compatibility check and return; no DDL,
      so existing partitions are left as they are.
    * ``replace`` - drop the table (with ``CASCADE`` when ``partition_by``
      is non-empty), then create it as described below.

    When the table does not exist (any mode) or was just dropped, the parent
    table is created. If ``partition_by`` is non-empty and
    ``partition_values`` is not ``None``, every statement returned by
    :func:`render_partition_ddl` is executed right after.

    Raises:
        ValueError: if ``if_exists`` is not in ``VALID_IF_EXISTS``.
        TableExistsError: if the table exists and ``if_exists`` is ``fail``.
    """
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(f"if_exists must be one of {VALID_IF_EXISTS}, got {if_exists!r}")

    partitioned = bool(partition_by)

    if _table_exists(conn, schema_name, table_name):
        if if_exists == "fail":
            raise TableExistsError(
                f"Table {schema_name}.{table_name} already exists and if_exists=fail"
            )
        if if_exists == "append":
            _assert_schema_compatible(conn, schema_name, table_name, columns)
            return
        if if_exists != "replace":
            # Any other accepted mode leaves an existing table untouched.
            return
        _drop_table(conn, schema_name, table_name, cascade=partitioned)

    # Shared by the "replace" path and the "table did not exist" path.
    _create_table_sql(
        conn, schema_name, table_name, columns,
        partition_by=partition_by,
    )
    if not partitioned or partition_values is None:
        return

    child_ddl = render_partition_ddl(
        schema_name, table_name, partition_by, partition_values,
        columns, max_partitions=max_partitions,
    )
    with conn.cursor() as cur:
        for statement in child_ddl:
            cur.execute(statement)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Partition support
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _sanitize_partition_value(value: Any, parent_table: str = "") -> str:
|
|
"""Convert a partition value into a safe, deterministic table-name suffix.
|
|
|
|
Rules:
|
|
- Convert to string, lowercase
|
|
- Replace non-alphanumeric runs with ``_``
|
|
- Collapse consecutive underscores, strip leading/trailing ``_``
|
|
- None/NaN → ``null``; empty string → ``empty``
|
|
- Truncate to fit within PostgreSQL's 63-character identifier limit
|
|
accounting for ``parent_table`` + ``_`` separator
|
|
"""
|
|
if value is None or (isinstance(value, float) and (pd.isna(value) or math.isnan(value))):
|
|
token = "null"
|
|
elif isinstance(value, dt.date) or isinstance(value, dt.datetime):
|
|
token = value.isoformat()
|
|
elif isinstance(value, dt.time):
|
|
token = value.isoformat()
|
|
else:
|
|
token = str(value)
|
|
|
|
token = token.lower()
|
|
token = re.sub(r"[^a-z0-9]+", "_", token)
|
|
token = re.sub(r"_+", "_", token)
|
|
token = token.strip("_")
|
|
|
|
if not token:
|
|
if value is None or (isinstance(value, float) and pd.isna(value)):
|
|
token = "null"
|
|
elif isinstance(value, str) and value == "":
|
|
token = "empty"
|
|
else:
|
|
token = "value"
|
|
|
|
# Truncate to keep total table name within PG's 63-char limit.
|
|
if parent_table:
|
|
# Reserve room for parent + underscore separator.
|
|
max_token_len = _PG_IDENT_MAX_LEN - len(parent_table) - 1
|
|
if max_token_len < 1:
|
|
raise ValueError(
|
|
f"Parent table name {parent_table!r} is too long "
|
|
f"({len(parent_table)} chars) to create child partitions."
|
|
)
|
|
if len(token) > max_token_len:
|
|
token = token[:max_token_len].rstrip("_")
|
|
|
|
return token
|
|
|
|
|
|
def _render_partition_value_literal(value: Any, pg_type: str) -> str:
|
|
"""Render a Python value as a SQL literal for ``FOR VALUES IN (...)``.
|
|
|
|
- None/NaN → ``NULL``
|
|
- Strings → single-quoted with ``'`` escaped to ``''``
|
|
- Numbers → plain numeric literal
|
|
- Booleans → ``TRUE`` / ``FALSE``
|
|
- Dates → ``DATE 'YYYY-MM-DD'``
|
|
- Timestamps → ``TIMESTAMP 'YYYY-MM-DD HH:MM:SS'``
|
|
- Times → ``TIME 'HH:MM:SS'``
|
|
"""
|
|
if value is None or (isinstance(value, float) and pd.isna(value)):
|
|
return "NULL"
|
|
|
|
pg_upper = pg_type.upper()
|
|
|
|
if pg_upper in ("BOOLEAN", "BOOL"):
|
|
return "TRUE" if value else "FALSE"
|
|
|
|
if pg_upper in ("INTEGER", "BIGINT", "SMALLINT", "INT", "INT4", "INT8", "INT2"):
|
|
return str(int(value))
|
|
|
|
if pg_upper in ("DOUBLE PRECISION", "REAL", "NUMERIC", "DECIMAL",
|
|
"FLOAT4", "FLOAT8"):
|
|
return str(value)
|
|
|
|
if pg_upper == "DATE":
|
|
if isinstance(value, (dt.date, dt.datetime)):
|
|
return f"DATE '{value.isoformat()}'"
|
|
return f"DATE '{value}'"
|
|
|
|
if pg_upper in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE",
|
|
"TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ"):
|
|
if isinstance(value, (dt.datetime, pd.Timestamp)):
|
|
return f"TIMESTAMP '{value.isoformat()}'"
|
|
if isinstance(value, dt.date):
|
|
return f"TIMESTAMP '{dt.datetime(value.year, value.month, value.day).isoformat()}'"
|
|
return f"TIMESTAMP '{value}'"
|
|
|
|
if pg_upper in ("TIME", "TIME WITHOUT TIME ZONE",
|
|
"TIME WITH TIME ZONE", "TIMETZ"):
|
|
if isinstance(value, dt.time):
|
|
return f"TIME '{value.isoformat()}'"
|
|
return f"TIME '{value}'"
|
|
|
|
# Default: treat as text — single-quote with escaping.
|
|
escaped = str(value).replace("'", "''")
|
|
return f"'{escaped}'"
|
|
|
|
|
|
def _normalize_partition_value(value: Any, pg_type: str) -> Any:
|
|
"""Normalize a raw partition value to its Python-native form.
|
|
|
|
Applies the same semantic normalization that :func:`_prepare_for_copy`
|
|
uses, so partition discovery deduplicates on the routed value rather
|
|
than the raw source representation.
|
|
"""
|
|
# Handle pandas null types
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, float) and (pd.isna(value) or math.isnan(value)):
|
|
return None
|
|
try:
|
|
if pd.isna(value):
|
|
return None
|
|
except (TypeError, ValueError):
|
|
pass
|
|
|
|
pg_upper = pg_type.upper()
|
|
|
|
if pg_upper in ("INTEGER", "BIGINT", "SMALLINT", "INT", "INT4", "INT8", "INT2"):
|
|
if isinstance(value, str):
|
|
value = value.strip()
|
|
if value == "":
|
|
return None
|
|
try:
|
|
return int(float(value))
|
|
except (TypeError, ValueError):
|
|
return None
|
|
|
|
if pg_upper in ("DOUBLE PRECISION", "REAL", "NUMERIC", "DECIMAL",
|
|
"FLOAT4", "FLOAT8"):
|
|
if isinstance(value, str):
|
|
value = value.strip()
|
|
if value == "":
|
|
return None
|
|
try:
|
|
result = float(value)
|
|
return None if math.isnan(result) else result
|
|
except (TypeError, ValueError):
|
|
return None
|
|
|
|
if pg_upper == "DATE":
|
|
if isinstance(value, dt.datetime):
|
|
return value.date()
|
|
if isinstance(value, dt.date):
|
|
return value
|
|
if isinstance(value, str):
|
|
if value.strip() == "":
|
|
return None
|
|
try:
|
|
return dt.date.fromisoformat(value.strip())
|
|
except (ValueError, TypeError):
|
|
return None
|
|
return None
|
|
|
|
if pg_upper in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE",
|
|
"TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ"):
|
|
if isinstance(value, dt.datetime):
|
|
return value
|
|
if isinstance(value, pd.Timestamp):
|
|
return value.to_pydatetime() if not pd.isna(value) else None
|
|
if isinstance(value, dt.date):
|
|
return dt.datetime(value.year, value.month, value.day)
|
|
if isinstance(value, str):
|
|
if value.strip() == "":
|
|
return None
|
|
try:
|
|
return dt.datetime.fromisoformat(value.strip())
|
|
except (ValueError, TypeError):
|
|
return None
|
|
return None
|
|
|
|
if pg_upper in ("TIME", "TIME WITHOUT TIME ZONE",
|
|
"TIME WITH TIME ZONE", "TIMETZ"):
|
|
return _seconds_to_time(value)
|
|
|
|
if pg_upper in ("BOOLEAN", "BOOL"):
|
|
if isinstance(value, bool):
|
|
return value
|
|
if isinstance(value, (int, float)):
|
|
return bool(value)
|
|
if isinstance(value, str):
|
|
return value.strip().lower() in ("true", "1", "t", "yes")
|
|
return None
|
|
|
|
# Text-like types: None, pandas nulls, and '' all become None
|
|
# because copy_dataframes() sends empty strings with NULL ''.
|
|
if pg_upper in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER", "BPCHAR"):
|
|
if isinstance(value, str):
|
|
if value == "":
|
|
return None
|
|
return value
|
|
return str(value)
|
|
|
|
# Fallback: return as-is converted to native Python type
|
|
if hasattr(value, "item"):
|
|
return value.item()
|
|
return value
|
|
|
|
|
|
def discover_partition_values(
    df: pd.DataFrame,
    partition_by: list[str],
    columns: Optional[Dict[str, ColumnSpec]] = None,
) -> dict:
    """Build a nested structure of unique partition values from a DataFrame.

    For ``partition_by = ['state', 'zip']`` returns::

        {
            'MO': {'63101': {}, '63102': {}},
            'IL': {'62001': {}, '62002': {}}
        }

    When ``columns`` is provided, values of fields present in it are
    normalized through :func:`_normalize_partition_value` to match the
    routed values Postgres will see during ``COPY``. Raw values that
    normalize to the same key share one subtree.

    Nulls (``None``, NaN, ``pd.NaT``, ``pd.NA``) are collected under a
    single ``None`` key. Numpy scalars are converted to native Python
    values via ``.item()``.

    Cost note: each distinct value at a non-leaf level triggers one pass
    over that level's rows.
    """
    if not partition_by:
        return {}

    def _to_native(val: Any) -> Any:
        """Map nulls to None and unwrap numpy scalars."""
        if val is None:
            return None
        try:
            if pd.isna(val):
                return None
        except (TypeError, ValueError):
            pass  # non-scalar cell; keep as-is
        if hasattr(val, "item"):
            return val.item()
        return val

    def _normalized(raw_val: Any, field_name: str) -> Any:
        """Return the key ``raw_val`` is filed under for ``field_name``."""
        native = _to_native(raw_val)
        if columns and field_name in columns:
            native = _normalize_partition_value(
                native, columns[field_name].postgres_type
            )
        return native

    def _matches(raw_val: Any, target: Any, field_name: str) -> bool:
        """Check if a raw value normalizes to the target."""
        native = _normalized(raw_val, field_name)
        if target is None:
            return native is None
        return bool(native == target)

    def _build_level(sub_df: pd.DataFrame, fields: list[str]) -> dict:
        if not fields or sub_df.empty:
            return {}

        field = fields[0]
        remaining = fields[1:]
        result: dict = {}

        for raw_val in sub_df[field].unique():
            val = _normalized(raw_val, field)

            if not remaining:
                result[val] = {}
                continue

            # Select rows by their *normalized* value. This also applies to
            # the None key, so anything that normalizes to None (e.g. '' or
            # an unparsable string in a numeric/date column) keeps its
            # lower-level partitions.
            mask = sub_df[field].map(
                lambda v, target=val: _matches(v, target, field)
            ).astype(bool)
            if val is None:
                mask = mask | sub_df[field].isna()
            result[val] = _build_level(sub_df[mask], remaining)

        return result

    return _build_level(df, list(partition_by))
|
|
|
|
|
|
def discover_partition_values_chunked(
    chunk_iter: Iterable[pd.DataFrame],
    partition_by: list[str],
    columns: Optional[Dict[str, ColumnSpec]] = None,
) -> dict:
    """Collect the partition tree of a file delivered as DataFrame chunks.

    Each non-empty chunk is narrowed to the partition columns, run through
    :func:`discover_partition_values`, and merged into one nested tree, so
    only one chunk is processed at a time. With an empty ``partition_by``
    the iterator is not consumed and ``{}`` is returned.

    Raises:
        ValueError: if a non-empty chunk lacks any partition column.
    """
    if not partition_by:
        return {}

    combined: dict = {}
    for chunk in chunk_iter:
        if chunk.empty:
            continue

        absent = [name for name in partition_by if name not in chunk.columns]
        if absent:
            raise ValueError(
                f"Partition columns not found in data: {absent}"
            )

        # Project onto the partition columns only, to keep the slice small.
        key_frame = chunk[list(partition_by)]
        _merge_partition_trees(
            combined,
            discover_partition_values(key_frame, partition_by, columns),
        )

    return combined
|
|
|
|
|
|
def _merge_partition_trees(target: dict, source: dict) -> None:
|
|
"""Merge ``source`` partition tree into ``target`` in place.
|
|
|
|
Both trees are nested dicts where keys are partition values and values
|
|
are either empty dicts (leaf) or nested dicts (intermediate levels).
|
|
"""
|
|
for key, subtree in source.items():
|
|
if key not in target:
|
|
target[key] = subtree
|
|
else:
|
|
# Merge children recursively
|
|
if subtree and target[key]:
|
|
_merge_partition_trees(target[key], subtree)
|
|
elif subtree:
|
|
target[key] = subtree
|
|
|
|
|
|
def _count_partitions(tree: dict) -> int:
|
|
"""Count total partition tables in a nested partition tree."""
|
|
count = 0
|
|
for _key, children in tree.items():
|
|
count += 1
|
|
if children:
|
|
count += _count_partitions(children)
|
|
return count
|
|
|
|
|
|
def render_partition_ddl(
    schema: str,
    parent_table: str,
    partition_by: list[str],
    partition_values: dict,
    column_specs: Dict[str, ColumnSpec],
    *,
    max_partitions: int = 10_000,
) -> list[str]:
    """Return the child-partition DDL for a whole partition tree.

    The statements are ordered depth-first, ready to be executed in
    sequence. The parent ``CREATE TABLE`` is not part of the result; it is
    produced by :func:`render_create_table`. An empty ``partition_by`` or
    ``partition_values`` yields ``[]``.

    ``max_partitions`` is advisory: when the tree holds more tables than
    that, a warning goes to the logger and to stderr, and rendering
    continues.
    """
    if not (partition_by and partition_values):
        return []

    planned = _count_partitions(partition_values)
    if planned > max_partitions:
        logger.warning(
            "Partition count (%d) exceeds threshold (%d). "
            "This may impact database performance.",
            planned, max_partitions,
        )
        print(
            f"[warn] partition plan for {schema}.{parent_table} will create "
            f"{planned:,} partition tables, exceeding max_partitions={max_partitions:,}",
            file=sys.stderr,
        )

    # The recursive helper appends to this list in execution order.
    ddl: list[str] = []
    _render_partition_ddl_recursive(
        schema, parent_table, partition_by, partition_values,
        column_specs, 0, ddl,
    )
    return ddl
|
|
|
|
|
|
def _render_partition_ddl_recursive(
    schema: str,
    parent_table: str,
    partition_by: list[str],
    values: dict,
    column_specs: Dict[str, ColumnSpec],
    depth: int,
    statements: list[str],
) -> None:
    """Recursively generate partition DDL statements (depth-first).

    For every key of ``values`` (the partition values of field
    ``partition_by[depth]``) one ``CREATE TABLE ... PARTITION OF
    parent_table`` statement is appended to ``statements``. When a deeper
    partition field exists, the child is itself declared ``PARTITION BY
    LIST`` on it and its subtree is rendered immediately afterwards.

    Child tables are named ``{parent_table}_{token}``, where ``token`` comes
    from :func:`_sanitize_partition_value`. If two values under the same
    parent produce the same name, the later one (in sort order) gets an
    8-hex-digit hash of ``repr(value)`` appended; the stem is shortened as
    needed so the hash always survives the identifier length limit.

    A field missing from ``column_specs`` is rendered as ``TEXT``.
    """
    field_name = partition_by[depth]
    next_field = partition_by[depth + 1] if depth + 1 < len(partition_by) else None
    field_spec = column_specs.get(field_name)
    pg_type = field_spec.postgres_type if field_spec else "TEXT"

    # Child names already handed out under this parent.
    used_names: set[str] = set()

    # Deterministic order: None first, then by string representation.
    def _sort_key(val: Any) -> Tuple[int, str]:
        if val is None:
            return (0, "")
        return (1, str(val))

    for val in sorted(values, key=_sort_key):
        children = values[val]
        token = _sanitize_partition_value(val, parent_table)
        child_name = f"{parent_table}_{token}"[:_PG_IDENT_MAX_LEN]

        if child_name in used_names:
            # Disambiguate with a short hash of the value. Trim the *stem*
            # (not the finished name) so the hash is never cut off, which
            # would silently bring the collision back.
            suffix = "_" + hashlib.sha256(repr(val).encode()).hexdigest()[:8]
            stem = child_name[: _PG_IDENT_MAX_LEN - len(suffix)].rstrip("_")
            child_name = stem + suffix

        used_names.add(child_name)

        literal = _render_partition_value_literal(val, pg_type)
        head = (
            f"CREATE TABLE {_qualified(schema, child_name)} "
            f"PARTITION OF {_qualified(schema, parent_table)} "
            f"FOR VALUES IN ({literal})"
        )

        if next_field is None:
            # Leaf partition
            statements.append(f"{head};")
            continue

        # Intermediate partition: itself partitioned by the next field
        statements.append(
            f"{head} PARTITION BY LIST ({_quote_ident(next_field)});"
        )
        if children:
            _render_partition_ddl_recursive(
                schema, child_name, partition_by, children,
                column_specs, depth + 1, statements,
            )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Index support
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def render_create_indexes(
    schema: str,
    tablename: str,
    indexes: List[str],
) -> List[str]:
    """Return one ``CREATE INDEX IF NOT EXISTS`` statement per column.

    No index method is specified in the statement. Index names follow
    ``ix_{tablename}_{column}`` built from the raw names and quoted with
    :func:`_quote_ident`; the table is referenced schema-qualified.

    A name longer than ``_PG_IDENT_MAX_LEN`` is cut down and suffixed with
    ``_`` plus the first 8 hex digits of the SHA-256 of the full name, so
    distinct long names stay distinct.
    """

    def _index_name(col: str) -> str:
        """Return the (possibly shortened) index name for ``col``."""
        full = f"ix_{tablename}_{col}"
        if len(full) <= _PG_IDENT_MAX_LEN:
            return full
        digest = hashlib.sha256(full.encode()).hexdigest()[:8]
        # 9 = 1 underscore + 8 hash chars
        head = full[: _PG_IDENT_MAX_LEN - 9].rstrip("_")
        return f"{head}_{digest}"

    return [
        f"CREATE INDEX IF NOT EXISTS {_quote_ident(_index_name(col))} "
        f"ON {_qualified(schema, tablename)} ({_quote_ident(col)});"
        for col in indexes
    ]
|
|
|
|
|
|
def create_indexes(
    conn,
    schema: str,
    tablename: str,
    indexes: List[str],
) -> None:
    """Create the indexes from :func:`render_create_indexes`, one at a time.

    Each statement is executed and committed on its own, with an ``[info]``
    line written to stderr on success. If a statement raises, the
    transaction is rolled back, a ``[warn]`` line is written, and the loop
    moves on to the next index, so one failure does not stop the rest.

    The name shown in the log lines is always ``ix_{tablename}_{col}``,
    whatever name the rendered statement actually uses.
    """
    statements = render_create_indexes(schema, tablename, indexes)
    with conn.cursor() as cur:
        for col, ddl in zip(indexes, statements):
            label = f"ix_{tablename}_{col} on {schema}.{tablename}({col})"
            try:
                cur.execute(ddl)
                conn.commit()
                print(f"[info] created index {label}", file=sys.stderr)
            except Exception as exc:
                # Best effort by design: undo the failed statement and go on.
                conn.rollback()
                print(
                    f"[warn] failed to create index {label}: {exc}",
                    file=sys.stderr,
                )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# COPY loading
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _seconds_to_time(v: Any) -> Optional[dt.time]:
|
|
if v is None:
|
|
return None
|
|
if isinstance(v, float) and pd.isna(v):
|
|
return None
|
|
if isinstance(v, dt.time):
|
|
return v
|
|
if isinstance(v, (dt.datetime, pd.Timestamp)):
|
|
return v.time() if not pd.isna(v) else None
|
|
try:
|
|
total = int(round(float(v)))
|
|
except (TypeError, ValueError):
|
|
return None
|
|
h, rem = divmod(total, 3600)
|
|
m, s = divmod(rem, 60)
|
|
# Clamp; TIME8. is always within a day.
|
|
h = max(0, min(h, 23))
|
|
return dt.time(h, m, s)
|
|
|
|
|
|
def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame:
|
|
"""Materialize a copy of ``df`` with each column in the right shape for
|
|
``to_csv`` so the CSV lands as valid input for the target Postgres type.
|
|
"""
|
|
out = pd.DataFrame(index=df.index)
|
|
for name, spec in columns.items():
|
|
series = df[name]
|
|
pg = spec.postgres_type.upper()
|
|
|
|
if pg in ("INTEGER", "BIGINT", "SMALLINT"):
|
|
if pd.api.types.is_object_dtype(series):
|
|
series = pd.to_numeric(
|
|
series.replace({"": None}), errors="coerce"
|
|
)
|
|
out[name] = series.astype("Int64")
|
|
elif pg in ("DOUBLE PRECISION", "REAL", "NUMERIC"):
|
|
if pd.api.types.is_object_dtype(series):
|
|
series = pd.to_numeric(
|
|
series.replace({"": None}), errors="coerce"
|
|
)
|
|
out[name] = series.astype("float64")
|
|
elif pg == "DATE":
|
|
if pd.api.types.is_datetime64_any_dtype(series):
|
|
out[name] = series.dt.date
|
|
elif pd.api.types.is_object_dtype(series):
|
|
def _to_date(v: Any) -> Optional[dt.date]:
|
|
if v is None or (isinstance(v, float) and pd.isna(v)):
|
|
return None
|
|
if isinstance(v, dt.datetime):
|
|
return v.date()
|
|
if isinstance(v, dt.date):
|
|
return v
|
|
if isinstance(v, str):
|
|
if v == "":
|
|
return None
|
|
try:
|
|
return dt.date.fromisoformat(v)
|
|
except ValueError:
|
|
return None
|
|
return None
|
|
out[name] = series.map(_to_date)
|
|
else:
|
|
out[name] = series
|
|
elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"):
|
|
if pd.api.types.is_datetime64_any_dtype(series):
|
|
out[name] = series
|
|
elif pd.api.types.is_object_dtype(series):
|
|
def _to_dt(v: Any) -> Optional[dt.datetime]:
|
|
if v is None or (isinstance(v, float) and pd.isna(v)):
|
|
return None
|
|
if isinstance(v, dt.datetime):
|
|
return v
|
|
if isinstance(v, dt.date):
|
|
return dt.datetime(v.year, v.month, v.day)
|
|
if isinstance(v, pd.Timestamp):
|
|
return v.to_pydatetime() if not pd.isna(v) else None
|
|
if isinstance(v, str):
|
|
if v == "":
|
|
return None
|
|
try:
|
|
return dt.datetime.fromisoformat(v)
|
|
except ValueError:
|
|
return None
|
|
return None
|
|
out[name] = series.map(_to_dt)
|
|
else:
|
|
out[name] = series
|
|
elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"):
|
|
out[name] = series.map(_seconds_to_time)
|
|
elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"):
|
|
# Leave empty strings as "" so `NULL ''` in COPY turns them into NULL.
|
|
def _to_str(v: Any) -> Any:
|
|
if v is None:
|
|
return ""
|
|
if isinstance(v, float) and pd.isna(v):
|
|
return ""
|
|
return str(v)
|
|
out[name] = series.map(_to_str)
|
|
elif pg == "BOOLEAN":
|
|
out[name] = series.astype("boolean") if series.dtype != object else series
|
|
else:
|
|
out[name] = series
|
|
return out
|
|
|
|
|
|
def copy_dataframes(
    conn,
    schema_name: str,
    table_name: str,
    dfs: Iterable[pd.DataFrame],
    columns: Dict[str, ColumnSpec],
) -> int:
    """Load a stream of DataFrames with ``COPY ... FROM STDIN``, chunk by chunk.

    Every non-empty frame is prepared with :func:`_prepare_for_copy`,
    written as header-less CSV into an in-memory buffer, copied, and
    committed before the next frame is touched. A failure part-way through
    therefore keeps the rows of the chunks committed so far, and the first
    commit also makes any DDL pending on ``conn`` (such as the preceding
    ``CREATE TABLE``) permanent. Empty frames are skipped.

    Missing values are written as empty fields, and the COPY options declare
    the empty string as the NULL marker.

    Returns:
        Total number of rows copied across all chunks.
    """
    quoted_columns = ", ".join(map(_quote_ident, columns))
    copy_sql = (
        f"COPY {_qualified(schema_name, table_name)} ({quoted_columns}) "
        "FROM STDIN WITH (FORMAT csv, NULL '')"
    )

    rows_loaded = 0
    with conn.cursor() as cur:
        for frame in dfs:
            if frame.empty:
                continue

            ready = _prepare_for_copy(frame, columns)
            payload = io.StringIO()
            ready.to_csv(
                payload,
                index=False,
                header=False,
                na_rep="",
                date_format="%Y-%m-%d %H:%M:%S",
            )
            payload.seek(0)  # rewind so COPY reads from the start

            cur.copy_expert(copy_sql, payload)
            conn.commit()
            rows_loaded += len(ready)
    return rows_loaded
|
|
|
|
|
|
def copy_dataframe(
    conn,
    schema_name: str,
    table_name: str,
    df: pd.DataFrame,
    columns: Dict[str, ColumnSpec],
) -> int:
    """Load a single DataFrame via ``COPY ... FROM STDIN``.

    Single-frame convenience form of :func:`copy_dataframes`: ``df`` is
    passed on as a one-element list.

    Returns:
        The number of rows inserted.
    """
    single_chunk = [df]
    return copy_dataframes(conn, schema_name, table_name, single_chunk, columns)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Manifest validation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _match_manifest_type(inferred: str, manifest_entry: Dict[str, Any]) -> bool:
    """Return True when ``inferred`` satisfies a manifest entry's type rule.

    Types are compared after :func:`_normalize_type`. An entry with
    ``postgres_type`` must match that single type; otherwise an entry with
    ``acceptable_types`` matches if any listed type does. ``postgres_type``
    wins when both keys are present, and an entry with neither key never
    matches.
    """
    wanted = _normalize_type(inferred)
    if "postgres_type" in manifest_entry:
        candidates = [manifest_entry["postgres_type"]]
    elif "acceptable_types" in manifest_entry:
        candidates = manifest_entry["acceptable_types"]
    else:
        return False
    return any(wanted == _normalize_type(candidate) for candidate in candidates)
|
|
|
|
|
|
def validate_against_manifest(
    inferred: Dict[str, ColumnSpec],
    manifest_path: Path,
) -> List[str]:
    """Compare the inferred schema against the expected-types manifest.

    The manifest is a JSON object keyed by column name. Each entry is itself
    an object that may carry ``postgres_type`` or ``acceptable_types``
    (checked through :func:`_match_manifest_type`) and an optional
    ``nullable`` flag, treated as true when absent.

    A manifest that cannot be used (missing, not valid JSON, or not shaped
    as described above) is reported as a problem string instead of raising,
    so callers get one uniform "list of problems" result.

    Args:
        inferred: Inferred column specs keyed by column name. Only the
            ``postgres_type`` and ``nullable`` attributes are read.
        manifest_path: Location of the JSON manifest.

    Returns:
        A list of human-readable problem strings; empty list means OK.
    """
    manifest_path = Path(manifest_path)
    # is_file() rather than exists(): a directory at this path is not a
    # usable manifest either and would otherwise blow up in open().
    if not manifest_path.is_file():
        return [f"manifest not found: {manifest_path}"]

    try:
        with manifest_path.open("r", encoding="utf-8") as f:
            manifest = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        return [f"manifest is not valid JSON: {manifest_path} ({exc})"]

    if not isinstance(manifest, dict):
        return [
            f"manifest must be a JSON object keyed by column name: "
            f"{manifest_path} (got {type(manifest).__name__})"
        ]

    problems: List[str] = []

    # Column-set differences are reported once each, with sorted names so
    # the output is stable.
    only_in_inferred = set(inferred) - set(manifest)
    only_in_manifest = set(manifest) - set(inferred)
    if only_in_inferred:
        problems.append(
            f"columns in inferred but not manifest: {sorted(only_in_inferred)}"
        )
    if only_in_manifest:
        problems.append(
            f"columns in manifest but not inferred: {sorted(only_in_manifest)}"
        )

    for name, spec in inferred.items():
        entry = manifest.get(name)
        if entry is None:
            # Missing columns were already covered by the set difference above.
            continue
        if not isinstance(entry, dict):
            problems.append(
                f"column {name!r}: manifest entry must be a JSON object, "
                f"got {type(entry).__name__}"
            )
            continue
        if not _match_manifest_type(spec.postgres_type, entry):
            expected = entry.get("postgres_type") or entry.get("acceptable_types")
            problems.append(
                f"column {name!r}: inferred {spec.postgres_type!r}, "
                f"manifest expected {expected!r}"
            )
        manifest_nullable = bool(entry.get("nullable", True))
        # Only the loosening direction is an error: inferred NOT NULL against
        # a nullable manifest entry is accepted.
        if spec.nullable and not manifest_nullable:
            problems.append(
                f"column {name!r}: inferred nullable, manifest expects NOT NULL "
                f"(loosening nullability is never allowed)"
            )

    return problems
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CLI
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _build_argparser() -> argparse.ArgumentParser:
    """Construct the command-line parser for the single-file loader.

    ``--config`` is required and parsed into a ``Path``. ``--validate``,
    ``--dry-run`` and ``--dbcreds`` are plain on/off switches that default
    to ``False``.
    """
    parser = argparse.ArgumentParser(
        description="Load a single SAS file (XPT or sas7bdat) into Postgres.",
    )
    parser.add_argument(
        "--config", required=True, type=Path, help="Path to YAML config"
    )

    # Every remaining option is a store_true switch, so register them from a
    # table of (flag, help text) pairs; order here is the order in --help.
    switches = (
        (
            "--validate",
            "Compare inferred schema against <filename-stem>.expected.json "
            "next to the SAS file; exits nonzero on mismatch.",
        ),
        (
            "--dry-run",
            "Print inferred CREATE TABLE and stop; don't touch Postgres.",
        ),
        (
            "--dbcreds",
            "Prompt for database username and password instead of reading "
            "PGUSER / PGPASSWORD from the environment or .env file.",
        ),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    return parser
|
|
|
|
|
|
def _format_columns_summary(columns: Dict[str, ColumnSpec]) -> str:
    """Render one ``name: type`` line per column, in dict order.

    Non-nullable columns get a trailing `` NOT NULL``. Lines are joined with
    newlines; an empty mapping yields an empty string.
    """

    def _describe(col) -> str:
        suffix = "" if col.nullable else " NOT NULL"
        return f" {col.name}: {col.postgres_type}{suffix}"

    return "\n".join(_describe(col) for col in columns.values())
|
|
|
|
|
|
def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point: load the SAS file described by ``--config`` into Postgres.

    Steps, in order: parse arguments, load ``.env`` and the YAML config, infer
    the schema from a bounded preview read, check ``partition_by`` and
    ``indexes`` against the filtered schema, optionally validate against the
    ``<filename-stem>.expected.json`` manifest, discover partition values
    (skipped in append mode), then either print the DDL (``--dry-run``) or
    create the table, stream the file in via COPY and create indexes.

    Args:
        argv: Argument list to parse; ``None`` lets argparse use ``sys.argv``.

    Returns:
        Process exit status: 0 on success (including ``--dry-run``), 1 when
        ``--validate`` reports problems, 2 when the SAS file does not exist.

    Raises:
        ValueError: if ``partition_by`` or ``indexes`` reference columns that
            are not in the (filtered) schema.
    """
    args = _build_argparser().parse_args(argv)

    load_dotenv()

    cfg = load_config(args.config)

    if not cfg.filename.exists():
        print(f"error: SAS file not found: {cfg.filename}", file=sys.stderr)
        return 2

    # Schema inference uses a bounded preview read so we never load a
    # hundreds-of-millions-of-rows file into memory just to pick types.
    # NB: ``meta.number_rows`` on a ``row_limit``-ed read reflects rows
    # returned, not the file's total, so we don't trust it here.
    preview_df, meta = read_sas_preview(cfg.filename)
    preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
    columns = infer_schema(preview_df, meta)

    # Validate partition columns exist in the schema after filtering.
    if cfg.partition_by:
        missing_pcols = [c for c in cfg.partition_by if c not in columns]
        if missing_pcols:
            raise ValueError(
                f"partition_by references columns not present in the "
                f"(filtered) schema: {missing_pcols}"
            )

    # Validate index columns exist in the schema after filtering.
    if cfg.indexes:
        missing_icols = [c for c in cfg.indexes if c not in columns]
        if missing_icols:
            raise ValueError(
                f"indexes references columns not present in the "
                f"(filtered) schema: {missing_icols}"
            )

    if args.validate:
        # Replace only the final extension with ".expected.json", e.g.
        # "sample_kitchensink.xpt" -> "sample_kitchensink.expected.json".
        # Chaining with_suffix("") and with_suffix(".expected.json") would
        # strip two suffixes and mangle dotted stems
        # ("study.v2.xpt" -> "study.expected.json").
        manifest_path = cfg.filename.with_name(
            f"{cfg.filename.stem}.expected.json"
        )
        problems = validate_against_manifest(columns, manifest_path)
        if problems:
            print("validation failed:", file=sys.stderr)
            for p in problems:
                print(f" - {p}", file=sys.stderr)
            return 1
        print(f"validation OK ({len(columns)} columns match {manifest_path.name})")

    # -- Partition value discovery ------------------------------------------
    # If partitioned, scan the ENTIRE file to discover all unique partition
    # values. The preview is only the first N rows and may miss values.
    # In append mode the partitions already exist, so skip the costly scan.
    partition_values: Optional[dict] = None
    if cfg.partition_by and cfg.if_exists != "append":
        print(" discovering partition values (full file scan)...", file=sys.stderr)

        def _discovery_chunks():
            # Apply the same include/exclude filter the load itself uses.
            for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename):
                yield apply_column_filter(chunk_df, cfg.include, cfg.exclude)

        partition_values = discover_partition_values_chunked(
            _discovery_chunks(), cfg.partition_by, columns,
        )
        total_parts = _count_partitions(partition_values)
        print(
            f" discovered {total_parts:,} partition tables "
            f"across {len(cfg.partition_by)} level(s)",
            file=sys.stderr,
        )
    elif cfg.partition_by and cfg.if_exists == "append":
        print(
            " [info] append mode: skipping partition discovery "
            "(partitions assumed to exist)",
            file=sys.stderr,
        )

    if args.dry_run:
        # Print the parent CREATE TABLE (with PARTITION BY if applicable).
        parent_ddl = render_create_table(
            cfg.schemaname, cfg.tablename, columns,
            partition_by=cfg.partition_by or None,
        )
        print(parent_ddl)
        # Print child partition DDL if partitioned.
        if cfg.partition_by and partition_values:
            child_stmts = render_partition_ddl(
                cfg.schemaname, cfg.tablename, cfg.partition_by,
                partition_values, columns,
                max_partitions=cfg.max_partitions,
            )
            for stmt in child_stmts:
                print()
                print(stmt)
        # Print CREATE INDEX DDL if indexes are configured.
        if cfg.indexes:
            idx_stmts = render_create_indexes(
                cfg.schemaname, cfg.tablename, cfg.indexes,
            )
            for stmt in idx_stmts:
                print()
                print(stmt)
        return 0

    # Release the preview frame before opening the stream - lets the GC reclaim
    # it while we're holding a Postgres transaction open.
    del preview_df

    def _filtered_chunks():
        # Yield filtered chunks while reporting a running row count on stderr.
        seen = 0
        for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename):
            chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude)
            seen += len(chunk_df)
            print(f" streaming... {seen:,} rows", file=sys.stderr)
            yield chunk_df

    # With --dbcreds the credentials are prompted for; otherwise None is
    # passed through to connect() for both.
    db_user = db_password = None
    if args.dbcreds:
        db_user = input("Database username: ")
        db_password = getpass.getpass("Database password: ")

    conn = connect(user=db_user, password=db_password)
    conn.autocommit = False
    try:
        create_table(
            conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists,
            partition_by=cfg.partition_by or None,
            partition_values=partition_values,
            max_partitions=cfg.max_partitions,
        )
        inserted = copy_dataframes(
            conn, cfg.schemaname, cfg.tablename, _filtered_chunks(), columns
        )
        # Commit here as well so the table DDL is persisted even when no
        # chunk was copied (the per-chunk commit never ran in that case).
        conn.commit()
        if cfg.indexes:
            # NOTE(review): no commit follows this call - presumably
            # create_indexes commits internally; confirm.
            create_indexes(conn, cfg.schemaname, cfg.tablename, cfg.indexes)
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

    print(
        f"loaded {inserted} rows into {cfg.schemaname}.{cfg.tablename} "
        f"({len(columns)} columns)"
    )
    if cfg.partition_by and partition_values:
        total_parts = _count_partitions(partition_values)
        print(f"partitioned by {cfg.partition_by} ({total_parts:,} partition tables)")
    print("final schema:")
    print(_format_columns_summary(columns))
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the exit status.
    raise SystemExit(main())
|