Added memory management optimizations in the _worker_load_append_file function to release unused memory from pyarrow's pool and trigger Python's garbage collection. Implemented explicit memory trimming using glibc's malloc_trim to ensure efficient memory usage during long-running processes. Updated the copy_dataframes function in load_sas.py to release pyarrow's memory pool between chunks, preventing high memory usage in long-lived workers. These changes aim to reduce memory footprint and improve overall performance during large dataset processing.
2217 lines
82 KiB
Python
2217 lines
82 KiB
Python
"""Per-file SAS-to-Postgres loader.
|
||
|
||
Library-style functions plus a thin CLI wrapper. Designed so an orchestrator
|
||
can wrap the library for directory/batch mode; orchestration is out of scope
|
||
here.
|
||
|
||
Python 3.14 compatible (target is an air-gapped host that currently only has
|
||
3.14). ``from __future__ import annotations`` lets us use PEP 585 generics
|
||
as annotations; runtime-resolved type uses (dataclass defaults, etc.) stick
|
||
to ``typing``.
|
||
|
||
-------------------------------------------------------------------------------
|
||
USAGE
|
||
-------------------------------------------------------------------------------
|
||
|
||
Supported inputs:
|
||
* ``.sas7bdat`` (read with ``encoding="latin-1"``)
|
||
* ``.xpt`` / ``.xport`` (SAS transport files)
|
||
|
||
1. YAML config
|
||
--------------
|
||
Every invocation is driven by a YAML file describing one SAS file to load::
|
||
|
||
filename: samples/sample_kitchensink.xpt # required; relative paths are
|
||
# resolved against the config
|
||
# file's directory when possible
|
||
schemaname: public # required
|
||
tablename: kitchensink # required
|
||
|
||
# Optional. One of: fail | replace | append. Default: fail.
|
||
# fail - error out if the target table already exists
|
||
# replace - DROP and recreate the table from the inferred schema
|
||
# append - keep the existing table; pre-flight a schema-compat check,
|
||
# then COPY the new rows in
|
||
if_exists: append
|
||
|
||
# Optional, mutually exclusive. Restrict which columns are loaded.
|
||
# include:
|
||
# - ID
|
||
# - INTCOL
|
||
# exclude:
|
||
# - ALLNULL
|
||
|
||
2. Database connection
|
||
----------------------
|
||
The loader uses standard libpq environment variables (read via ``os.environ``)::
|
||
|
||
PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE
|
||
|
||
The CLI calls ``python-dotenv``'s ``load_dotenv()`` at startup, so a local
|
||
``.env`` file is picked up automatically. Library callers are responsible for
|
||
populating the environment themselves (either call ``load_dotenv()`` or export
|
||
the vars) before calling :func:`connect`.
|
||
|
||
3. Command-line interface
|
||
-------------------------
|
||
::
|
||
|
||
python load_sas.py --config path/to/config.yaml [--validate] [--dry-run]
|
||
[--dbcreds]
|
||
|
||
Flags:
|
||
--config PATH Required. Path to the YAML config above.
|
||
--validate Compare the inferred schema against
|
||
``<sas-file-stem>.expected.json`` sitting next to the SAS
|
||
file. Exits nonzero on mismatch. Safe to combine with
|
||
``--dry-run``.
|
||
--dry-run Print the inferred ``CREATE TABLE`` SQL and stop. The
|
||
database is never touched (no connection is opened).
|
||
--dbcreds Prompt interactively for the database username and
|
||
password instead of reading ``PGUSER`` / ``PGPASSWORD``
|
||
from the environment or ``.env`` file. The password
|
||
prompt does not echo. Has no effect with ``--dry-run``
|
||
(no connection is opened).
|
||
|
||
Exit codes:
|
||
0 - success (load completed, or dry-run/validate passed)
|
||
1 - validation failure
|
||
2 - config references a SAS file that does not exist
|
||
Other nonzero - uncaught exception (traceback printed); the transaction
|
||
is rolled back before exit.
|
||
|
||
Typical invocations::
|
||
|
||
# Preview the inferred schema without connecting to Postgres.
|
||
python load_sas.py --config sample_config.yaml --dry-run
|
||
|
||
# Check the inferred schema against an expected-types manifest.
|
||
python load_sas.py --config sample_config.yaml --validate --dry-run
|
||
|
||
# Actually load the data.
|
||
python load_sas.py --config sample_config.yaml
|
||
|
||
# Load the data, prompting for credentials instead of using .env.
|
||
python load_sas.py --config sample_config.yaml --dbcreds
|
||
|
||
4. Expected-types manifest (``--validate``)
|
||
-------------------------------------------
|
||
``--validate`` looks for a JSON file named ``<sas-stem>.expected.json`` next
|
||
to the SAS file, e.g. ``samples/sample_kitchensink.xpt`` pairs with
|
||
``samples/sample_kitchensink.expected.json``. Each top-level key is a column
|
||
name; the value is an object with any of::
|
||
|
||
{
|
||
"postgres_type": "BIGINT", # exact expected type, OR
|
||
"acceptable_types": ["TEXT", # any-of list of acceptable types
|
||
"VARCHAR"],
|
||
"nullable": true, # default true; false = must be NOT NULL
|
||
"note": "free-form comment" # ignored by the loader
|
||
}
|
||
|
||
Type comparison ignores length/precision modifiers and normalizes synonyms
|
||
(e.g. ``INT`` == ``INTEGER`` == ``INT4``; ``VARCHAR(10)`` == ``VARCHAR``).
|
||
Nullability tightening (inferred NULL, manifest NOT NULL) is a hard failure;
|
||
loosening is not checked here because the append-mode check already covers it.
|
||
|
||
5. Library usage
|
||
----------------
|
||
The CLI is a thin wrapper around composable functions. The preferred pattern
|
||
infers the schema from a bounded preview and then streams the rest of the
|
||
file chunk-by-chunk into ``COPY`` - crucial for SAS files with hundreds of
|
||
millions of rows::
|
||
|
||
from pathlib import Path
from dotenv import load_dotenv
|
||
from load_sas import (
|
||
load_config, read_sas_preview, iter_sas_chunks, apply_column_filter,
|
||
infer_schema, validate_against_manifest, render_create_table,
|
||
connect, create_table, copy_dataframes,
|
||
)
|
||
|
||
load_dotenv()
|
||
cfg = load_config("config.yaml")
|
||
|
||
# Schema from a preview slice (bounded by TYPE_INFERENCE_SAMPLE_ROWS; the default of None reads every row).
|
||
preview_df, meta = read_sas_preview(cfg.filename)
|
||
preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
|
||
total_rows = getattr(meta, "number_rows", None)
|
||
columns = infer_schema(preview_df, meta, total_rows=total_rows)
|
||
|
||
# Optional: preview DDL / validate against a manifest.
|
||
print(render_create_table(cfg.schemaname, cfg.tablename, columns))
|
||
problems = validate_against_manifest(columns, Path("expected.json"))
|
||
assert not problems, problems
|
||
|
||
conn = connect()
|
||
conn.autocommit = False
|
||
try:
|
||
create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists)
|
||
chunks = (
|
||
apply_column_filter(df, cfg.include, cfg.exclude)
|
||
for df, _ in iter_sas_chunks(cfg.filename)
|
||
)
|
||
rows = copy_dataframes(conn, cfg.schemaname, cfg.tablename, chunks, columns)
|
||
conn.commit()
|
||
finally:
|
||
conn.close()
|
||
|
||
For small files (or tests) the legacy one-shot API still works:
|
||
:func:`read_sas` returns the whole frame and :func:`copy_dataframe` copies it
|
||
in one round trip.
|
||
|
||
All functions are side-effect free except :func:`connect`, :func:`create_table`,
|
||
:func:`copy_dataframe`, and :func:`copy_dataframes`; schema inference
|
||
(:func:`infer_schema`) accepts a ``coerce_chars`` kwarg to override the
|
||
module-level ``COERCE_CHAR_COLUMNS`` without mutating global state.
|
||
|
||
6. Type inference summary
|
||
-------------------------
|
||
Priority order used by :func:`infer_schema`:
|
||
|
||
1. SAS format string (via ``meta.original_variable_types``):
|
||
``DATETIME*`` -> ``TIMESTAMP``, ``TIME*`` -> ``TIME``,
|
||
``DATE*`` / ``YYMMDD*`` / ``MMDDYY*`` / ``DDMMYY*`` / ``JULIAN*`` -> ``DATE``.
|
||
2. All-null column -> ``TEXT`` (with a note).
|
||
3. pandas datetime dtype -> ``TIMESTAMP``.
|
||
4. Object columns containing only ``datetime.date`` / ``datetime.datetime``
|
||
-> ``DATE`` or ``TIMESTAMP``.
|
||
5. Object columns of strings: if ``COERCE_CHAR_COLUMNS`` is True and at
|
||
least ``CHAR_INFERENCE_MIN_VALUES`` non-empty values parse cleanly, they
|
||
are promoted to ``INTEGER`` / ``BIGINT`` / ``DOUBLE PRECISION`` /
|
||
``DATE`` / ``TIMESTAMP``; otherwise ``TEXT``.
|
||
6. Numeric columns of whole numbers -> ``INTEGER`` (or ``BIGINT`` if any
|
||
value exceeds the int32 range ``NUMERIC_INT_RANGE``); otherwise
|
||
``DOUBLE PRECISION``.
|
||
|
||
Type inference scans the whole file by default (``TYPE_INFERENCE_SAMPLE_ROWS
|
||
= None``) so type + nullability are both computed against every row. The CLI
|
||
materializes the file once for schema inference, then re-streams it chunk by
|
||
chunk into ``COPY``; peak memory is roughly one full dataframe. Override
|
||
``TYPE_INFERENCE_SAMPLE_ROWS`` to an integer cap if you're on a host that
|
||
can't hold the file in memory - but know that sampled specs carry the usual
|
||
risks: a later row may exceed the inferred integer range, or a column that
|
||
had no nulls in the preview may carry nulls later in the file (which then
|
||
detonates ``COPY`` because the sampled spec stamped it ``NOT NULL``). Seen
|
||
in production on a 2.5M-row file with ~6k null MAFIDs past the 10k-row
|
||
preview - the entire load aborted mid-stream.
|
||
|
||
Streaming loads use :func:`iter_sas_chunks` + :func:`copy_dataframes`, which
|
||
commit each chunk as it is copied so an interrupted load retains the rows
|
||
that were already written.
|
||
|
||
7. Tunables
|
||
-----------
|
||
Module-level knobs at the top of this file:
|
||
|
||
* ``COERCE_CHAR_COLUMNS`` - promote stringly-typed numerics / dates
|
||
(default True).
|
||
* ``CHAR_INFERENCE_MIN_VALUES`` - minimum non-empty sample size before
|
||
char-column coercion is attempted.
|
||
* ``NUMERIC_INT_RANGE`` - INTEGER bounds; values outside become
|
||
``BIGINT``.
|
||
* ``TYPE_INFERENCE_SAMPLE_ROWS`` - cap on rows read for type inference
|
||
(``None`` = scan the whole column).
|
||
* ``DEFAULT_CHUNK_ROWS`` - rows per streaming COPY chunk.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import datetime as dt
|
||
import getpass
|
||
import hashlib
|
||
import io
|
||
import json
|
||
import logging
|
||
import math
|
||
import os
|
||
import re
|
||
import sys
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||
|
||
import pandas as pd
|
||
import psycopg2
|
||
import psycopg2.extensions
|
||
import pyarrow as pa
|
||
import pyarrow.csv as pa_csv
|
||
import pyreadstat
|
||
import yaml
|
||
from dotenv import load_dotenv
|
||
from tqdm import tqdm
|
||
|
||
|
||
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Top-level tunables
# ---------------------------------------------------------------------------

#: If True, object (string) columns are promoted to numeric/date/timestamp
#: when every non-empty value parses cleanly.
COERCE_CHAR_COLUMNS = True

#: Minimum number of non-empty values required before character-column
#: coercion is attempted; very small samples are easy to mis-infer.
CHAR_INFERENCE_MIN_VALUES = 3

#: Inclusive INTEGER (int32) bounds; whole numbers outside them become BIGINT.
NUMERIC_INT_RANGE = (-(2**31), 2**31 - 1)

#: Cap on rows inspected during per-column type inference. It also governs
#: how many rows :func:`read_sas_preview` pulls for dry-run / validate /
#: schema-inference flows.
#:
#: ``None`` (the default) scans every row, so the whole file is read into
#: memory for schema inference. That is the only honest setting for
#: nullability: with an integer cap a column can look ``NOT NULL`` across the
#: first N rows while rare nulls sit past the window, and ``COPY`` then fails
#: mid-stream. This happened in production on a 2.5M-row file where ~6k
#: MAFIDs were null past the 10k-row preview. Only set an integer cap for
#: files too large to read fully, and accept that sampled specs cannot be
#: trusted for ``NOT NULL``.
TYPE_INFERENCE_SAMPLE_ROWS: Optional[int] = None

#: Rows per chunk when streaming a SAS file into ``COPY``. Larger chunks mean
#: fewer COPY round-trips and lower per-row overhead but more peak memory per
#: chunk; smaller chunks are gentler on memory.
#:
#: :func:`iter_sas_chunks` lets the ``GENERIC_LOADER_CHUNK_ROWS`` environment
#: variable override this at runtime, so ``.env``-driven overrides need no
#: code change. An explicit ``chunksize=`` kwarg wins over both.
DEFAULT_CHUNK_ROWS = 2_000_000


#: Accepted values for the config's ``if_exists`` key.
VALID_IF_EXISTS = ("fail", "replace", "append")

#: PostgreSQL maximum identifier length in bytes (characters for ASCII).
_PG_IDENT_MAX_LEN = 63
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Dataclasses
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@dataclass
class LoaderConfig:
    """Validated settings for loading one SAS file (built by ``load_config``)."""

    # SAS source file (.sas7bdat / .xpt / .xport).
    filename: Path
    # Target schema and table in Postgres.
    schemaname: str
    tablename: str
    # One of "fail" | "replace" | "append".
    if_exists: str = "fail"
    # Mutually exclusive column allow-/deny-lists; None means "not set".
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
    # Column names from the config's ``partition_by`` key.
    partition_by: List[str] = field(default_factory=list)
    # Positive integer from the config's ``max_partitions`` key.
    max_partitions: int = 10_000
    # Column names from the config's ``indexes`` key.
    indexes: List[str] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
class ColumnSpec:
    """Inferred Postgres definition for one source column."""

    # Column name as it appears in the source frame.
    name: str
    # Postgres type string, e.g. "INTEGER", "TEXT", "TIMESTAMP".
    postgres_type: str
    # Whether a missing value was seen during inference; a False value is what
    # stamps the column NOT NULL.
    nullable: bool
    # SAS format string for the column, when one was available.
    sas_format: Optional[str] = None
    # Description of the source dtype, when the caller records one.
    source_dtype: Optional[str] = None
    # Free-form inference notes (e.g. all-null or sampling notes).
    notes: List[str] = field(default_factory=list)
    #: True when the type was inferred from a bounded preview rather than the
    #: full file. A sampled spec carries the usual sampling risks: a later
    #: chunk could contain a value that exceeds the inferred integer range,
    #: doesn't parse as the inferred type, or is null in a column the preview
    #: showed as non-null - all of which surface as mid-``COPY`` failures.
    sampled: bool = False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Custom exceptions
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TableExistsError(RuntimeError):
    """Signals that the target table already exists while ``if_exists`` is ``"fail"``."""
|
||
|
||
|
||
class SchemaCompatibilityError(RuntimeError):
    """Signals that, with ``if_exists="append"``, the incoming schema cannot
    be appended to the existing table."""
|
||
|
||
|
||
class ValidationError(RuntimeError):
    """Signals a mismatch found by ``--validate`` against the expected-types manifest."""
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Connection
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def connect(
    *,
    user: Optional[str] = None,
    password: Optional[str] = None,
) -> psycopg2.extensions.connection:
    """Open a psycopg2 connection configured from the standard libpq env vars.

    Reads ``PGHOST``, ``PGPORT``, ``PGUSER``, ``PGPASSWORD`` and
    ``PGDATABASE`` from ``os.environ``. Nothing here loads ``.env``: the CLI
    does that before calling, and orchestrators wrapping this module must
    call ``load_dotenv()`` themselves or export the variables.

    Args:
        user: Overrides ``PGUSER`` when truthy (used by the ``--dbcreds``
            flag for interactive input).
        password: Overrides ``PGPASSWORD`` when truthy.

    Returns:
        The open psycopg2 connection.
    """
    env = os.environ
    return psycopg2.connect(
        host=env.get("PGHOST"),
        port=env.get("PGPORT"),
        # ``or`` (not ``is None``) on purpose: an empty string also defers to
        # the environment.
        user=user or env.get("PGUSER"),
        password=password or env.get("PGPASSWORD"),
        dbname=env.get("PGDATABASE"),
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Config loading
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _parse_column_list(raw_value: Any, key: str, path: Path) -> List[str]:
    """Normalize the ``key`` config entry into a list of column names.

    ``None`` and an empty list both mean "not configured" and yield ``[]``.
    A single string becomes a one-element list. Every name is stripped of
    surrounding whitespace.

    Raises:
        ValueError: the value is neither a string nor a list, a string or
            list item is blank or not a string, or the list repeats a name.
    """
    if raw_value is None or (isinstance(raw_value, list) and not raw_value):
        return []
    if isinstance(raw_value, str):
        if not raw_value.strip():
            raise ValueError(f"Config {path}: '{key}' string must be non-empty.")
        return [raw_value.strip()]
    if not isinstance(raw_value, list):
        raise ValueError(
            f"Config {path}: '{key}' must be a string or list of strings."
        )

    names: List[str] = []
    for i, item in enumerate(raw_value):
        if not isinstance(item, str) or not item.strip():
            raise ValueError(
                f"Config {path}: '{key}[{i}]' must be a non-empty string."
            )
        names.append(item.strip())
    if len(names) != len(set(names)):
        raise ValueError(
            f"Config {path}: '{key}' contains duplicate column names."
        )
    return names


def _check_against_filter(
    columns: List[str],
    label: str,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
    path: Path,
) -> None:
    """Reject an include/exclude filter that would drop any of ``columns``.

    ``label`` names the feature that needs the columns (``"partition_by"``,
    ``"index"``) and only appears in the error message.

    Raises:
        ValueError: ``include`` omits, or ``exclude`` removes, one of
            ``columns``.
    """
    if include is not None:
        missing = [c for c in columns if c not in include]
        if missing:
            raise ValueError(
                f"Config {path}: 'include' omits {label} columns: {missing}"
            )
    if exclude is not None:
        removed = [c for c in columns if c in exclude]
        if removed:
            raise ValueError(
                f"Config {path}: 'exclude' removes {label} columns: {removed}"
            )


def _parse_max_partitions(raw_mp: Any, path: Path) -> int:
    """Validate the ``max_partitions`` config value as a positive integer.

    Integer-valued strings ("12") and integral floats (4.0) are accepted.
    Booleans are rejected: YAML ``true`` would otherwise become 1 because
    ``bool`` subclasses ``int``. Non-integral or infinite floats are also
    rejected, since ``int(2.5)`` silently truncates and ``int(inf)`` raises
    ``OverflowError``.

    Raises:
        ValueError: the value is not a positive integer.
    """
    if isinstance(raw_mp, bool) or (
        isinstance(raw_mp, float) and not raw_mp.is_integer()
    ):
        raise ValueError(
            f"Config {path}: 'max_partitions' must be a positive integer, "
            f"got {raw_mp!r}"
        )
    try:
        value = int(raw_mp)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(
            f"Config {path}: 'max_partitions' must be a positive integer, "
            f"got {raw_mp!r}"
        ) from exc
    if value <= 0:
        raise ValueError(
            f"Config {path}: 'max_partitions' must be a positive integer, "
            f"got {value}"
        )
    return value


def load_config(path: Path) -> LoaderConfig:
    """Parse and validate the YAML config at ``path``.

    A relative ``filename`` is resolved against the config file's directory
    when that location exists. Otherwise the raw relative path is kept, so it
    resolves against the current working directory when the file is opened.

    Args:
        path: Path to the YAML config file.

    Returns:
        A populated :class:`LoaderConfig`.

    Raises:
        ValueError: the YAML is not a mapping, required keys are missing,
            ``if_exists`` is invalid, ``include``/``exclude`` are both set or
            not lists, ``partition_by``/``indexes`` are malformed or filtered
            out by ``include``/``exclude``, or ``max_partitions`` is not a
            positive integer.
    """
    path = Path(path)
    # safe_load: the config is plain data; never construct arbitrary objects.
    with path.open("r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)

    if not isinstance(raw, dict):
        raise ValueError(f"Config at {path} must be a YAML mapping at the top level.")

    missing = [k for k in ("filename", "schemaname", "tablename") if k not in raw]
    if missing:
        raise ValueError(f"Config {path} missing required keys: {', '.join(missing)}")

    filename = Path(raw["filename"])
    if not filename.is_absolute():
        candidate = path.parent / filename
        if candidate.exists():
            filename = candidate.resolve()

    schemaname = str(raw["schemaname"])
    tablename = str(raw["tablename"])

    if_exists = str(raw.get("if_exists", "fail")).lower()
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(
            f"Config {path}: if_exists={if_exists!r} is not one of {VALID_IF_EXISTS}"
        )

    include = raw.get("include")
    exclude = raw.get("exclude")
    if include is not None and exclude is not None:
        raise ValueError(
            f"Config {path}: 'include' and 'exclude' are mutually exclusive; set at most one."
        )
    if include is not None and not isinstance(include, list):
        raise ValueError(f"Config {path}: 'include' must be a list of column names.")
    if exclude is not None and not isinstance(exclude, list):
        raise ValueError(f"Config {path}: 'exclude' must be a list of column names.")
    include_names = [str(c) for c in include] if include is not None else None
    exclude_names = [str(c) for c in exclude] if exclude is not None else None

    partition_by = _parse_column_list(raw.get("partition_by"), "partition_by", path)
    _check_against_filter(partition_by, "partition_by", include_names, exclude_names, path)

    max_partitions = _parse_max_partitions(raw.get("max_partitions", 10_000), path)

    indexes = _parse_column_list(raw.get("indexes"), "indexes", path)
    _check_against_filter(indexes, "index", include_names, exclude_names, path)

    return LoaderConfig(
        filename=filename,
        schemaname=schemaname,
        tablename=tablename,
        if_exists=if_exists,
        include=include_names,
        exclude=exclude_names,
        partition_by=partition_by,
        max_partitions=max_partitions,
        indexes=indexes,
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Reader
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _sas_reader(path: Path) -> Tuple[Any, Dict[str, Any]]:
    """Return ``(pyreadstat_reader, extra_kwargs)`` for ``path``.

    Every reader entry point in this module (``read_sas``,
    ``read_sas_preview``, ``read_sas_metadata``, ``iter_sas_chunks``) splats
    ``extra_kwargs`` into the reader call. This function is therefore the
    single place that holds per-format reader options.

    Invariants (learned the hard way while building the sample generator):

    * ``.xpt`` / ``.xport`` - no encoding arg; pyreadstat is flaky about
      encoding on XPORT files it wrote itself.
    * ``.sas7bdat`` - explicit ``encoding="latin-1"`` per colleague guidance.

    The suffix match is case-insensitive.

    Raises:
        ValueError: ``path`` has any other extension.
    """
    suffix = Path(path).suffix.lower()
    if suffix in (".xpt", ".xport"):
        return pyreadstat.read_xport, {}
    if suffix == ".sas7bdat":
        # A fresh dict per call, so a caller mutating its kwargs can't leak
        # options into later reads.
        return pyreadstat.read_sas7bdat, {"encoding": "latin-1"}
    raise ValueError(f"Unsupported SAS file extension: {suffix}")
|
||
|
||
|
||
def read_sas(path: Path) -> Tuple[pd.DataFrame, Any]:
    """Load every row of ``path`` in one call and return ``(df, meta)``.

    Unlike :func:`read_sas_preview`, this ignores
    ``TYPE_INFERENCE_SAMPLE_ROWS`` and never bounds the read, so the whole
    file is materialized. That makes it suitable only for small files.
    It is kept for backward compatibility and tests; streaming callers
    should combine :func:`read_sas_preview` with :func:`iter_sas_chunks`.
    """
    read_fn, extra = _sas_reader(path)
    source = str(Path(path))
    return read_fn(source, **extra)
|
||
|
||
|
||
def read_sas_preview(
    path: Path,
    *,
    rows: Optional[int] = None,
) -> Tuple[pd.DataFrame, Any]:
    """Return ``(df, meta)`` for at most the first ``rows`` records of ``path``.

    When ``rows`` is omitted, ``TYPE_INFERENCE_SAMPLE_ROWS`` supplies the cap.
    A missing or falsy cap becomes pyreadstat's ``row_limit=0``, which means
    "no limit", so the whole file is read. With the module default of
    ``None``, that is what happens.
    """
    read_fn, extra = _sas_reader(path)
    limit = TYPE_INFERENCE_SAMPLE_ROWS if rows is None else rows
    row_limit = 0
    if limit:
        row_limit = int(limit)
    return read_fn(str(Path(path)), row_limit=row_limit, **extra)
|
||
|
||
|
||
def read_sas_metadata(path: Path) -> Any:
    """Return only the pyreadstat metadata object for ``path`` (no data rows).

    Calls the format-appropriate reader with ``metadataonly=True``, pyreadstat's
    header-only fast path. This yields column names, formats and the total
    row count (``meta.number_rows``) far more cheaply than
    :func:`read_sas_preview`. It is cheap enough to pre-scan a whole folder,
    e.g. to size a global progress bar.
    """
    read_fn, extra = _sas_reader(path)
    _no_rows, meta = read_fn(str(Path(path)), metadataonly=True, **extra)
    return meta
|
||
|
||
|
||
def _chunk_rows_from_env(default: int) -> int:
    """Return the streaming chunk size from ``GENERIC_LOADER_CHUNK_ROWS``.

    Returns ``default`` when the variable is unset. Unparseable and
    non-positive values also fall back to ``default``, and those two cases
    log a warning so a typo in ``.env`` is visible. A chunk size of 0 is
    dangerous because pyreadstat treats ``row_limit=0`` as "unlimited", so
    the chunks would not be bounded at all. A negative size is meaningless.
    """
    raw = os.environ.get("GENERIC_LOADER_CHUNK_ROWS")
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        logging.getLogger(__name__).warning(
            "Ignoring unparseable GENERIC_LOADER_CHUNK_ROWS=%r; using %d",
            raw,
            default,
        )
        return default
    if value <= 0:
        logging.getLogger(__name__).warning(
            "Ignoring non-positive GENERIC_LOADER_CHUNK_ROWS=%r; using %d",
            raw,
            default,
        )
        return default
    return value


def iter_sas_chunks(
    path: Path,
    *,
    chunksize: Optional[int] = None,
):
    """Yield ``(df_chunk, meta)`` tuples for streaming loads.

    Thin wrapper over ``pyreadstat.read_file_in_chunks`` that picks the
    reader by extension and threads through the per-format reader kwargs
    from ``_sas_reader``.

    When ``chunksize`` is ``None`` (the default), the value comes from the
    ``GENERIC_LOADER_CHUNK_ROWS`` environment variable. It is used only if
    set to a positive integer; otherwise :data:`DEFAULT_CHUNK_ROWS` is used.
    An explicit ``chunksize`` always wins. Because this is a generator,
    nothing (env lookup included) happens until the first chunk is requested.
    """
    if chunksize is None:
        chunksize = _chunk_rows_from_env(DEFAULT_CHUNK_ROWS)
    reader, kwargs = _sas_reader(path)
    yield from pyreadstat.read_file_in_chunks(
        reader, str(Path(path)), chunksize=chunksize, **kwargs
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Column filtering
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def apply_column_filter(
    df: pd.DataFrame,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
) -> pd.DataFrame:
    """Keep only the ``include`` columns, or drop the ``exclude`` columns.

    At most one of ``include`` / ``exclude`` may be given. Naming a column
    that is not in ``df`` raises a ``ValueError`` instead of being silently
    ignored. With neither filter, ``df`` itself is returned. No defensive
    ``.copy()`` is made: the downstream COPY preparation builds a fresh
    output without mutating its input, so a copy per streamed chunk would be
    pure overhead.
    """
    if include is not None and exclude is not None:
        raise ValueError("include and exclude are mutually exclusive.")
    if include is None and exclude is None:
        return df

    keeping = include is not None
    label, names = ("include", include) if keeping else ("exclude", exclude)
    unknown = [col for col in names if col not in df.columns]
    if unknown:
        raise ValueError(f"{label} references unknown columns: {unknown}")

    if keeping:
        return df.loc[:, list(names)]
    return df.drop(columns=list(names))
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Type inference
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
_DATE_FORMAT_PREFIXES = ("DATE", "YYMMDD", "MMDDYY", "DDMMYY", "JULIAN")

# SAS formats that render *datetime* values even though their names begin
# with "DATE". They must be tested before _DATE_FORMAT_PREFIXES, or e.g.
# DATEAMPM22.2 would be typed DATE and lose its time component.
_DATETIME_FORMAT_PREFIXES = ("DATETIME", "DATEAMPM")


def _format_driven_type(sas_format: Optional[str]) -> Optional[str]:
    """Map a SAS format string to a Postgres type, or None if it doesn't decide one.

    Matching is case-insensitive and ignores leading whitespace. The
    precedence is datetime formats (``DATETIME*``, ``DATEAMPM*``) ->
    ``TIMESTAMP``, then ``TIME*`` -> ``TIME``, then the date prefixes in
    ``_DATE_FORMAT_PREFIXES`` -> ``DATE``. Any other format returns None.
    """
    if not sas_format:
        return None
    fmt = sas_format.upper().lstrip()
    if fmt.startswith(_DATETIME_FORMAT_PREFIXES):
        return "TIMESTAMP"
    if fmt.startswith("TIME"):
        return "TIME"
    if fmt.startswith(_DATE_FORMAT_PREFIXES):
        return "DATE"
    return None
|
||
|
||
|
||
def _all_null(series: pd.Series) -> bool:
    """Return True when no entry of ``series`` carries a value.

    For non-object dtypes this is pandas' ``isna``. For object (character)
    columns, ``None``, empty strings and float NaN all count as empty. An
    empty series counts as all-null.
    """
    if not pd.api.types.is_object_dtype(series):
        return bool(series.isna().all())

    def _blank(value: Any) -> bool:
        if value is None:
            return True
        if isinstance(value, str):
            return value == ""
        return isinstance(value, float) and pd.isna(value)

    return bool(series.map(_blank).all())
|
||
|
||
|
||
def _char_missing_mask(series: pd.Series) -> pd.Series:
    """Return a boolean Series flagging missing entries of a character column.

    ``None``, float NaN and the empty string are treated as missing; every
    other value (including non-string objects) is present.
    """
    def _missing(value: Any) -> bool:
        if value is None:
            return True
        if isinstance(value, float):
            return bool(pd.isna(value))
        return isinstance(value, str) and value == ""

    return series.map(_missing)
|
||
|
||
|
||
def _is_nullable(series: pd.Series) -> bool:
    """Return True if ``series`` contains at least one missing value.

    Object columns use the character-column notion of missing (``None``,
    NaN or ``""``); every other dtype uses pandas' ``isna``.
    """
    if pd.api.types.is_object_dtype(series):
        missing = _char_missing_mask(series)
    else:
        missing = series.isna()
    return bool(missing.any())
|
||
|
||
|
||
def _numeric_int_target(series: pd.Series) -> Optional[str]:
    """Pick an integer Postgres type for a numeric series, if one fits.

    Returns ``"INTEGER"`` when every non-null value is a whole number inside
    ``NUMERIC_INT_RANGE``, and ``"BIGINT"`` when they are whole numbers that
    fit the signed 64-bit BIGINT range. Otherwise it returns None, leaving
    the caller to choose a non-integer type. None is returned for an empty or
    all-null series, any fractional value, infinities, non-numeric data, and
    whole numbers beyond BIGINT (e.g. ``1e20``), which would otherwise fail
    at COPY time.
    """
    bigint_min, bigint_max = -(2**63), 2**63 - 1

    nonnull = series.dropna()
    if nonnull.empty:
        return None
    # Whole-number test. ``inf % 1`` is NaN, so infinities fail it;
    # non-numeric data raises TypeError.
    try:
        whole = bool(((nonnull % 1) == 0).all())
    except TypeError:
        return None
    if not whole:
        return None

    # Every value is now finite and integral, so int() is exact. Comparing
    # Python ints avoids float rounding near 2**63.
    vmin = int(nonnull.min())
    vmax = int(nonnull.max())
    if vmin < bigint_min or vmax > bigint_max:
        return None

    lo, hi = NUMERIC_INT_RANGE
    if lo <= vmin and vmax <= hi:
        return "INTEGER"
    return "BIGINT"
|
||
|
||
|
||
def _object_is_dates(series: pd.Series) -> Tuple[bool, bool]:
    """Classify an object column as date-like.

    Returns ``(all_date_like, any_has_time)``. The first flag is True only
    when the series has at least one non-null value and every non-null value
    is a ``datetime.date``, ``datetime.datetime`` or ``pd.Timestamp``. The
    second flag is True when at least one of them carries a time component.
    Any other value short-circuits to ``(False, False)``.
    """
    values = series.dropna()
    if values.empty:
        return False, False

    saw_time = False
    for value in values:
        # datetime subclasses date, so the datetime test must come first.
        if isinstance(value, (dt.datetime, pd.Timestamp)):
            saw_time = True
        elif not isinstance(value, dt.date):
            return False, False
    return True, saw_time
|
||
|
||
|
||
def _try_int_coerce(values: List[str]) -> Optional[str]:
    """Return ``"INTEGER"``/``"BIGINT"`` if every string parses as an int, else None.

    Each value is stripped before ``int()`` is applied. The first
    unparseable value returns None, and so does an empty list.
    ``"INTEGER"`` is chosen when all parsed values fall inside
    ``NUMERIC_INT_RANGE``; otherwise ``"BIGINT"``.
    """
    parsed: List[int] = []
    for text in values:
        try:
            parsed.append(int(text.strip()))
        except ValueError:
            return None
    if not parsed:
        return None

    lo, hi = NUMERIC_INT_RANGE
    fits_int32 = all(lo <= number <= hi for number in parsed)
    return "INTEGER" if fits_int32 else "BIGINT"
|
||
|
||
|
||
def _try_float_coerce(values: List[str]) -> bool:
    """Return True if every string in ``values`` parses with ``float()``.

    An empty list is vacuously True. Only ``ValueError`` is treated as "does
    not parse".
    """
    for text in values:
        try:
            float(text)
        except ValueError:
            break
    else:
        return True
    return False
|
||
|
||
|
||
def _try_date_coerce(values: List[str]) -> bool:
    """Return True if every value is accepted by ``datetime.date.fromisoformat``.

    An empty list is vacuously True. Both ``ValueError`` and ``TypeError``
    count as "not a date".
    """
    def _parses(text: str) -> bool:
        try:
            dt.date.fromisoformat(text)
        except (ValueError, TypeError):
            return False
        return True

    return all(_parses(text) for text in values)
|
||
|
||
|
||
def _try_datetime_coerce(values: List[str]) -> bool:
    """Return True if every value is accepted by ``datetime.datetime.fromisoformat``.

    An empty list is vacuously True. Both ``ValueError`` and ``TypeError``
    count as "not a datetime".
    """
    def _parses(text: str) -> bool:
        try:
            dt.datetime.fromisoformat(text)
        except (ValueError, TypeError):
            return False
        return True

    return all(_parses(text) for text in values)
|
||
|
||
|
||
def _infer_char_type(series: pd.Series) -> str:
    """Infer a Postgres type for an object/string column.

    Missing entries (see ``_char_missing_mask``) are ignored and the rest are
    stringified. The result is ``"TEXT"`` if ``COERCE_CHAR_COLUMNS`` is off or
    fewer than ``CHAR_INFERENCE_MIN_VALUES`` values remain. Otherwise the
    probes run in order: integer (``INTEGER``/``BIGINT``), then
    ``DOUBLE PRECISION``, ``DATE`` and ``TIMESTAMP``. The first probe that
    accepts every value wins, and ``"TEXT"`` is the fallback.
    """
    missing = _char_missing_mask(series)
    candidates = [str(value) for value in series[~missing].tolist()]
    if not COERCE_CHAR_COLUMNS or len(candidates) < CHAR_INFERENCE_MIN_VALUES:
        return "TEXT"

    integer_type = _try_int_coerce(candidates)
    if integer_type is not None:
        return integer_type

    probes = (
        (_try_float_coerce, "DOUBLE PRECISION"),
        (_try_date_coerce, "DATE"),
        (_try_datetime_coerce, "TIMESTAMP"),
    )
    for accepts, pg_type in probes:
        if accepts(candidates):
            return pg_type
    return "TEXT"
|
||
|
||
|
||
def infer_schema(
    df: pd.DataFrame,
    meta: Any,
    *,
    coerce_chars: Optional[bool] = None,
    total_rows: Optional[int] = None,
) -> Dict[str, ColumnSpec]:
    """Infer a Postgres column spec for each column in ``df``.

    ``meta`` is the pyreadstat metadata object. We read
    ``meta.original_variable_types`` (a dict keyed by column name) for
    format-driven date/time/timestamp inference. A missing or ``None``
    attribute is treated as "no formats".

    ``coerce_chars`` overrides the module-level ``COERCE_CHAR_COLUMNS`` for
    the duration of this call. ``None`` (the default) means "use the flag's
    current value". The default is deliberately not ``COERCE_CHAR_COLUMNS``
    itself: a default expression is evaluated once, at import time. A
    caller that reassigned the module flag afterwards would otherwise have
    it silently replaced by the stale import-time value for the whole call.

    The override works by temporarily reassigning the module global, because
    the char-inference helpers read it directly. The global is restored in
    ``finally``. This is a process-wide side effect, so concurrent calls from
    several threads with different ``coerce_chars`` values can race.

    ``total_rows`` lets callers who already sampled the frame (e.g. via
    :func:`read_sas_preview`) report the real file size in the per-column
    "inferred from first N of M rows" note. Falls back to ``len(df)``.

    Returns a dict mapping each column name, in ``df`` column order, to its
    :class:`ColumnSpec`.
    """
    # Must precede every use of the name in this function body.
    global COERCE_CHAR_COLUMNS

    if coerce_chars is None:
        coerce_chars = COERCE_CHAR_COLUMNS

    original_formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {})

    # When ``TYPE_INFERENCE_SAMPLE_ROWS`` is an integer cap, the row-walking
    # type probes run on the head slice for speed. Nullability and the
    # all-null check still walk every row of ``df``. That is only honest when
    # the caller handed us the full file; with the default cap of ``None``,
    # the CLI does exactly that. Callers who pass a partial preview together
    # with a tight integer cap accept that ``NOT NULL`` can be wrong for
    # rare-null columns.
    df_rows = len(df)
    effective_total = total_rows if total_rows is not None else df_rows
    if TYPE_INFERENCE_SAMPLE_ROWS is not None and df_rows > TYPE_INFERENCE_SAMPLE_ROWS:
        sample_df = df.head(TYPE_INFERENCE_SAMPLE_ROWS)
        sample_size = TYPE_INFERENCE_SAMPLE_ROWS
    else:
        sample_df = df
        sample_size = df_rows
    sampled = sample_size < effective_total

    # Temporarily install the requested flag; restored in ``finally``.
    saved = COERCE_CHAR_COLUMNS
    COERCE_CHAR_COLUMNS = coerce_chars
    try:
        out: Dict[str, ColumnSpec] = {}
        for col in df.columns:
            series = df[col]
            sample_series = sample_df[col]
            sas_format = original_formats.get(col)
            notes: List[str] = []

            # A SAS display format (e.g. a date format) wins over dtype probing.
            pg_type = _format_driven_type(sas_format)

            if pg_type is None:
                if _all_null(series):
                    pg_type = "TEXT"
                    notes.append("all-null column; defaulting to TEXT")
                elif pd.api.types.is_datetime64_any_dtype(series):
                    pg_type = "TIMESTAMP"
                elif pd.api.types.is_object_dtype(series):
                    is_dates, any_dt = _object_is_dates(sample_series)
                    if is_dates:
                        pg_type = "TIMESTAMP" if any_dt else "DATE"
                    else:
                        pg_type = _infer_char_type(sample_series)
                elif pd.api.types.is_numeric_dtype(series):
                    int_target = _numeric_int_target(sample_series)
                    if int_target is not None:
                        pg_type = int_target
                    else:
                        pg_type = "DOUBLE PRECISION"
                else:
                    pg_type = "TEXT"
                    notes.append(f"unhandled dtype {series.dtype}; defaulting to TEXT")

            if sampled:
                notes.append(
                    f"type inferred from first {sample_size:,} of "
                    f"{effective_total:,} rows"
                )

            # Nullability is computed on the full series, not the sample.
            nullable = _is_nullable(series)

            out[col] = ColumnSpec(
                name=col,
                postgres_type=pg_type,
                nullable=nullable,
                sas_format=sas_format,
                source_dtype=str(series.dtype),
                notes=notes,
                sampled=sampled,
            )
        return out
    finally:
        COERCE_CHAR_COLUMNS = saved
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Table management
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _quote_ident(ident: str) -> str:
|
||
"""Quote a Postgres identifier. psycopg2 doesn't expose this directly
|
||
until 2.8+ with sql.Identifier; we do it by hand to stay driver-simple."""
|
||
return '"' + ident.replace('"', '""') + '"'
|
||
|
||
|
||
def _qualified(schema: str, table: str) -> str:
    """Return the fully quoted ``"schema"."table"`` reference for SQL text."""
    return ".".join((_quote_ident(schema), _quote_ident(table)))
|
||
|
||
|
||
def _table_exists(conn, schema: str, table: str) -> bool:
    """Report whether ``schema.table`` is listed in ``information_schema.tables``.

    Schema and table names are passed as bound parameters and matched
    exactly (case-sensitive). Per the Postgres docs, that view only shows
    relations the current role has some privilege on, and it also lists
    views.
    """
    query = (
        "SELECT 1 FROM information_schema.tables "
        "WHERE table_schema = %s AND table_name = %s"
    )
    with conn.cursor() as cur:
        cur.execute(query, (schema, table))
        row = cur.fetchone()
    return row is not None
|
||
|
||
|
||
def render_create_table(
    schema: str,
    table: str,
    columns: Dict[str, ColumnSpec],
    *,
    partition_by: Optional[List[str]] = None,
) -> str:
    """Build the ``CREATE TABLE`` statement for ``schema.table``.

    Each ``ColumnSpec`` becomes one quoted column line carrying its Postgres
    type. Non-nullable columns get a ``NOT NULL`` suffix. When
    ``partition_by`` is non-empty, only its first field is used, in a
    trailing ``PARTITION BY LIST ("field")`` clause.
    """
    column_lines = [
        f" {_quote_ident(spec.name)} {spec.postgres_type}"
        + ("" if spec.nullable else " NOT NULL")
        for spec in columns.values()
    ]
    body = ",\n".join(column_lines)
    partition_clause = (
        f"\nPARTITION BY LIST ({_quote_ident(partition_by[0])})" if partition_by else ""
    )
    return f"CREATE TABLE {_qualified(schema, table)} (\n{body}\n){partition_clause};"
|
||
|
||
|
||
def _create_table_sql(
    conn,
    schema: str,
    table: str,
    columns: Dict[str, ColumnSpec],
    *,
    partition_by: Optional[List[str]] = None,
) -> None:
    """Execute the DDL produced by :func:`render_create_table`.

    ``partition_by`` is forwarded unchanged. No commit is issued; the caller
    owns the transaction.
    """
    ddl = render_create_table(schema, table, columns, partition_by=partition_by)
    with conn.cursor() as cursor:
        cursor.execute(ddl)
|
||
|
||
|
||
def _drop_table(conn, schema: str, table: str, *, cascade: bool = False) -> None:
    """Issue ``DROP TABLE`` for ``schema.table``, adding ``CASCADE`` on request.

    No ``IF EXISTS`` is used, so the statement errors if the table is
    missing. No commit is issued.
    """
    statement = f"DROP TABLE {_qualified(schema, table)}"
    if cascade:
        statement += " CASCADE"
    with conn.cursor() as cursor:
        cursor.execute(statement)
|
||
|
||
|
||
# Normalization table: map both loader-emitted and Postgres-reported type
|
||
# strings to a single canonical family name. Ignore length/precision
|
||
# modifiers like VARCHAR(n) and NUMERIC(p,s).
|
||
_TYPE_NORMALIZATION: Dict[str, str] = {
|
||
"INTEGER": "integer",
|
||
"INT": "integer",
|
||
"INT4": "integer",
|
||
"BIGINT": "bigint",
|
||
"INT8": "bigint",
|
||
"SMALLINT": "smallint",
|
||
"INT2": "smallint",
|
||
"DOUBLE PRECISION": "double precision",
|
||
"FLOAT8": "double precision",
|
||
"REAL": "real",
|
||
"FLOAT4": "real",
|
||
"NUMERIC": "numeric",
|
||
"DECIMAL": "numeric",
|
||
"TEXT": "text",
|
||
"VARCHAR": "character varying",
|
||
"CHARACTER VARYING": "character varying",
|
||
"CHAR": "character",
|
||
"CHARACTER": "character",
|
||
"BPCHAR": "character",
|
||
"BOOLEAN": "boolean",
|
||
"BOOL": "boolean",
|
||
"DATE": "date",
|
||
"TIMESTAMP": "timestamp without time zone",
|
||
"TIMESTAMP WITHOUT TIME ZONE": "timestamp without time zone",
|
||
"TIMESTAMPTZ": "timestamp with time zone",
|
||
"TIMESTAMP WITH TIME ZONE": "timestamp with time zone",
|
||
"TIME": "time without time zone",
|
||
"TIME WITHOUT TIME ZONE": "time without time zone",
|
||
"TIMETZ": "time with time zone",
|
||
"TIME WITH TIME ZONE": "time with time zone",
|
||
}
|
||
|
||
|
||
def _normalize_type(pg_type: str) -> str:
    """Reduce a Postgres type string to its canonical family name.

    The input is stripped and upper-cased. Any ``(n)`` / ``(p,s)`` modifier
    is removed and internal whitespace is collapsed. The result is then
    looked up in ``_TYPE_NORMALIZATION``. Unknown types come back as the
    cleaned string in lower case.

    Examples: ``"VARCHAR(10)"`` -> ``"character varying"``;
    ``"TIMESTAMP(6) WITHOUT TIME ZONE"`` -> ``"timestamp without time zone"``.
    """
    upper = pg_type.strip().upper()
    bare = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", upper).strip()
    key = re.sub(r"\s+", " ", bare)
    family = _TYPE_NORMALIZATION.get(key)
    return family if family is not None else key.lower()
|
||
|
||
|
||
# Widening pairs: (inferred_from_source, existing_in_target). When the
|
||
# incoming spec is narrower than the target we accept it - the value is
|
||
# guaranteed to fit, and ``_prepare_for_copy`` already emits ``COPY``
|
||
# payloads that Postgres silently promotes to the wider column type. The
|
||
# INVERSE direction stays a hard failure: a BIGINT value does not fit in
|
||
# an INTEGER column, so we must not let a cluster whose first file had
|
||
# only small ints accept a later file with a value past int32. Comes up
|
||
# most often on cluster loads where file 1 pushed the target to BIGINT
|
||
# (a single value > 2_147_483_647) and file N happens to sit entirely
|
||
# within int32 range - strict equality would reject file N even though
|
||
# the copy is trivially safe.
|
||
_WIDENING_COMPATIBLE: set = {
|
||
("smallint", "integer"),
|
||
("smallint", "bigint"),
|
||
("integer", "bigint"),
|
||
("real", "double precision"),
|
||
# INTEGER / BIGINT into DOUBLE PRECISION is lossless for int32 and
|
||
# exact up to 2**53 for int64, which covers every value pandas could
|
||
# have carried through as Int64 without wrapping anyway.
|
||
("integer", "double precision"),
|
||
("bigint", "double precision"),
|
||
}
|
||
|
||
|
||
def _assert_schema_compatible(
    conn, schema: str, table: str, columns: Dict[str, ColumnSpec]
) -> None:
    """Pre-flight check run before appending into an existing table.

    Reads ``information_schema.columns`` for ``schema.table`` and compares
    each incoming ``ColumnSpec`` with the target column of the same name:

    * column absent from the target -> mismatch
    * normalized types differ:
      - pair listed in ``_WIDENING_COMPATIBLE`` -> warning only
      - any other pair -> mismatch
    * incoming column nullable but target ``NOT NULL`` -> warning only

    All warnings are printed to stderr before any error is raised. Target
    columns that are not in ``columns`` are not examined.

    Raises:
        SchemaCompatibilityError: at least one mismatch was found. The
            message lists every mismatch.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT column_name, data_type, is_nullable "
            "FROM information_schema.columns "
            "WHERE table_schema = %s AND table_name = %s",
            (schema, table),
        )
        rows = cur.fetchall()
    # column name -> (data_type, is_nullable) as reported by Postgres.
    target_columns = {row[0]: (row[1], row[2]) for row in rows}

    mismatches: List[str] = []
    advisories: List[str] = []

    for name, spec in columns.items():
        if name not in target_columns:
            mismatches.append(
                f"column {name!r} not present in target {schema}.{table}"
            )
            continue

        target_type, target_nullable = target_columns[name]
        inferred_norm = _normalize_type(spec.postgres_type)
        target_norm = _normalize_type(target_type)

        types_differ = inferred_norm != target_norm
        if types_differ and (inferred_norm, target_norm) in _WIDENING_COMPATIBLE:
            # A narrower inferred type fits inside the wider target. Accept it,
            # but tell the operator the file's range is smaller than the target's.
            advisories.append(
                f"column {name!r}: inferred {spec.postgres_type} "
                f"(narrower than target {target_type}); accepting - "
                f"values fit in the wider target type"
            )
        elif types_differ:
            mismatches.append(
                f"column {name!r}: inferred {spec.postgres_type} "
                f"(normalized {inferred_norm!r}) but target is {target_type} "
                f"(normalized {target_norm!r})"
            )

        if spec.nullable and target_nullable == "NO":
            advisories.append(
                f"column {name!r}: incoming allows NULLs but target is NOT NULL; "
                "COPY will fail if any NULLs appear"
            )

    for note in advisories:
        print(f"[warn] {note}", file=sys.stderr)

    if mismatches:
        raise SchemaCompatibilityError(
            "append-mode schema compatibility check failed:\n - "
            + "\n - ".join(mismatches)
        )
|
||
|
||
|
||
def assert_schema_compatible(
    conn,
    schema_name: str,
    table_name: str,
    columns: Dict[str, ColumnSpec],
) -> None:
    """Public entry point for the append-mode schema compatibility check.

    Orchestrators that append several files into one table (e.g. the folder
    loader) call this to repeat the check that ``if_exists=append`` performs
    internally. It delegates to :func:`_assert_schema_compatible`, printing
    its warnings and raising :class:`SchemaCompatibilityError` on mismatch.
    """
    _assert_schema_compatible(
        conn,
        schema=schema_name,
        table=table_name,
        columns=columns,
    )
|
||
|
||
|
||
def create_table(
    conn,
    schema_name: str,
    table_name: str,
    columns: Dict[str, ColumnSpec],
    if_exists: str,
    *,
    partition_by: Optional[List[str]] = None,
    partition_values: Optional[dict] = None,
    max_partitions: int = 10_000,
) -> None:
    """Create (or verify) the target table according to ``if_exists``.

    If ``schema_name.table_name`` does not exist yet, it is created whatever
    ``if_exists`` says. If it already exists:

    * ``fail`` - raise :class:`TableExistsError`
    * ``replace`` - drop it (with ``CASCADE`` when ``partition_by`` is
      non-empty) and create it again
    * ``append`` - keep it and run :func:`_assert_schema_compatible`. No DDL
      is issued, so existing partitions are assumed to be in place.

    Creating the table means running the parent ``CREATE TABLE`` (with
    ``PARTITION BY LIST`` when ``partition_by`` is non-empty). If the table
    is partitioned and ``partition_values`` is also given, every child
    statement from :func:`render_partition_ddl` is executed afterwards. No
    commit is issued here.

    Raises:
        ValueError: ``if_exists`` is not in ``VALID_IF_EXISTS``.
        TableExistsError: the table exists and ``if_exists == "fail"``.
        SchemaCompatibilityError: propagated from the append-mode check.
    """
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(f"if_exists must be one of {VALID_IF_EXISTS}, got {if_exists!r}")

    is_partitioned = bool(partition_by)

    def _build() -> None:
        # Parent first, then (if known) every child partition in render order.
        _create_table_sql(
            conn, schema_name, table_name, columns,
            partition_by=partition_by,
        )
        if not (is_partitioned and partition_values is not None):
            return
        child_ddl = render_partition_ddl(
            schema_name, table_name, partition_by, partition_values,
            columns, max_partitions=max_partitions,
        )
        with conn.cursor() as cur:
            for statement in child_ddl:
                cur.execute(statement)

    if not _table_exists(conn, schema_name, table_name):
        _build()
        return

    if if_exists == "fail":
        raise TableExistsError(
            f"Table {schema_name}.{table_name} already exists and if_exists=fail"
        )
    if if_exists == "replace":
        _drop_table(conn, schema_name, table_name, cascade=is_partitioned)
        _build()
        return
    if if_exists == "append":
        _assert_schema_compatible(conn, schema_name, table_name, columns)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Partition support
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _sanitize_partition_value(value: Any, parent_table: str = "") -> str:
|
||
"""Convert a partition value into a safe, deterministic table-name suffix.
|
||
|
||
Rules:
|
||
- Convert to string, lowercase
|
||
- Replace non-alphanumeric runs with ``_``
|
||
- Collapse consecutive underscores, strip leading/trailing ``_``
|
||
- None/NaN → ``null``; empty string → ``empty``
|
||
- Truncate to fit within PostgreSQL's 63-character identifier limit
|
||
accounting for ``parent_table`` + ``_`` separator
|
||
"""
|
||
if value is None or (isinstance(value, float) and (pd.isna(value) or math.isnan(value))):
|
||
token = "null"
|
||
elif isinstance(value, dt.date) or isinstance(value, dt.datetime):
|
||
token = value.isoformat()
|
||
elif isinstance(value, dt.time):
|
||
token = value.isoformat()
|
||
else:
|
||
token = str(value)
|
||
|
||
token = token.lower()
|
||
token = re.sub(r"[^a-z0-9]+", "_", token)
|
||
token = re.sub(r"_+", "_", token)
|
||
token = token.strip("_")
|
||
|
||
if not token:
|
||
if value is None or (isinstance(value, float) and pd.isna(value)):
|
||
token = "null"
|
||
elif isinstance(value, str) and value == "":
|
||
token = "empty"
|
||
else:
|
||
token = "value"
|
||
|
||
# Truncate to keep total table name within PG's 63-char limit.
|
||
if parent_table:
|
||
# Reserve room for parent + underscore separator.
|
||
max_token_len = _PG_IDENT_MAX_LEN - len(parent_table) - 1
|
||
if max_token_len < 1:
|
||
raise ValueError(
|
||
f"Parent table name {parent_table!r} is too long "
|
||
f"({len(parent_table)} chars) to create child partitions."
|
||
)
|
||
if len(token) > max_token_len:
|
||
token = token[:max_token_len].rstrip("_")
|
||
|
||
return token
|
||
|
||
|
||
def _render_partition_value_literal(value: Any, pg_type: str) -> str:
|
||
"""Render a Python value as a SQL literal for ``FOR VALUES IN (...)``.
|
||
|
||
- None/NaN → ``NULL``
|
||
- Strings → single-quoted with ``'`` escaped to ``''``
|
||
- Numbers → plain numeric literal
|
||
- Booleans → ``TRUE`` / ``FALSE``
|
||
- Dates → ``DATE 'YYYY-MM-DD'``
|
||
- Timestamps → ``TIMESTAMP 'YYYY-MM-DD HH:MM:SS'``
|
||
- Times → ``TIME 'HH:MM:SS'``
|
||
"""
|
||
if value is None or (isinstance(value, float) and pd.isna(value)):
|
||
return "NULL"
|
||
|
||
pg_upper = pg_type.upper()
|
||
|
||
if pg_upper in ("BOOLEAN", "BOOL"):
|
||
return "TRUE" if value else "FALSE"
|
||
|
||
if pg_upper in ("INTEGER", "BIGINT", "SMALLINT", "INT", "INT4", "INT8", "INT2"):
|
||
return str(int(value))
|
||
|
||
if pg_upper in ("DOUBLE PRECISION", "REAL", "NUMERIC", "DECIMAL",
|
||
"FLOAT4", "FLOAT8"):
|
||
return str(value)
|
||
|
||
if pg_upper == "DATE":
|
||
if isinstance(value, (dt.date, dt.datetime)):
|
||
return f"DATE '{value.isoformat()}'"
|
||
return f"DATE '{value}'"
|
||
|
||
if pg_upper in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE",
|
||
"TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ"):
|
||
if isinstance(value, (dt.datetime, pd.Timestamp)):
|
||
return f"TIMESTAMP '{value.isoformat()}'"
|
||
if isinstance(value, dt.date):
|
||
return f"TIMESTAMP '{dt.datetime(value.year, value.month, value.day).isoformat()}'"
|
||
return f"TIMESTAMP '{value}'"
|
||
|
||
if pg_upper in ("TIME", "TIME WITHOUT TIME ZONE",
|
||
"TIME WITH TIME ZONE", "TIMETZ"):
|
||
if isinstance(value, dt.time):
|
||
return f"TIME '{value.isoformat()}'"
|
||
return f"TIME '{value}'"
|
||
|
||
# Default: treat as text — single-quote with escaping.
|
||
escaped = str(value).replace("'", "''")
|
||
return f"'{escaped}'"
|
||
|
||
|
||
def _normalize_partition_value(value: Any, pg_type: str) -> Any:
|
||
"""Normalize a raw partition value to its Python-native form.
|
||
|
||
Applies the same semantic normalization that :func:`_prepare_for_copy`
|
||
uses, so partition discovery deduplicates on the routed value rather
|
||
than the raw source representation.
|
||
"""
|
||
# Handle pandas null types
|
||
if value is None:
|
||
return None
|
||
if isinstance(value, float) and (pd.isna(value) or math.isnan(value)):
|
||
return None
|
||
try:
|
||
if pd.isna(value):
|
||
return None
|
||
except (TypeError, ValueError):
|
||
pass
|
||
|
||
pg_upper = pg_type.upper()
|
||
|
||
if pg_upper in ("INTEGER", "BIGINT", "SMALLINT", "INT", "INT4", "INT8", "INT2"):
|
||
if isinstance(value, str):
|
||
value = value.strip()
|
||
if value == "":
|
||
return None
|
||
try:
|
||
return int(float(value))
|
||
except (TypeError, ValueError):
|
||
return None
|
||
|
||
if pg_upper in ("DOUBLE PRECISION", "REAL", "NUMERIC", "DECIMAL",
|
||
"FLOAT4", "FLOAT8"):
|
||
if isinstance(value, str):
|
||
value = value.strip()
|
||
if value == "":
|
||
return None
|
||
try:
|
||
result = float(value)
|
||
return None if math.isnan(result) else result
|
||
except (TypeError, ValueError):
|
||
return None
|
||
|
||
if pg_upper == "DATE":
|
||
if isinstance(value, dt.datetime):
|
||
return value.date()
|
||
if isinstance(value, dt.date):
|
||
return value
|
||
if isinstance(value, str):
|
||
if value.strip() == "":
|
||
return None
|
||
try:
|
||
return dt.date.fromisoformat(value.strip())
|
||
except (ValueError, TypeError):
|
||
return None
|
||
return None
|
||
|
||
if pg_upper in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE",
|
||
"TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ"):
|
||
if isinstance(value, dt.datetime):
|
||
return value
|
||
if isinstance(value, pd.Timestamp):
|
||
return value.to_pydatetime() if not pd.isna(value) else None
|
||
if isinstance(value, dt.date):
|
||
return dt.datetime(value.year, value.month, value.day)
|
||
if isinstance(value, str):
|
||
if value.strip() == "":
|
||
return None
|
||
try:
|
||
return dt.datetime.fromisoformat(value.strip())
|
||
except (ValueError, TypeError):
|
||
return None
|
||
return None
|
||
|
||
if pg_upper in ("TIME", "TIME WITHOUT TIME ZONE",
|
||
"TIME WITH TIME ZONE", "TIMETZ"):
|
||
return _seconds_to_time(value)
|
||
|
||
if pg_upper in ("BOOLEAN", "BOOL"):
|
||
if isinstance(value, bool):
|
||
return value
|
||
if isinstance(value, (int, float)):
|
||
return bool(value)
|
||
if isinstance(value, str):
|
||
return value.strip().lower() in ("true", "1", "t", "yes")
|
||
return None
|
||
|
||
# Text-like types: None, pandas nulls, and '' all become None
|
||
# because copy_dataframes() sends empty strings with NULL ''.
|
||
if pg_upper in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER", "BPCHAR"):
|
||
if isinstance(value, str):
|
||
if value == "":
|
||
return None
|
||
return value
|
||
return str(value)
|
||
|
||
# Fallback: return as-is converted to native Python type
|
||
if hasattr(value, "item"):
|
||
return value.item()
|
||
return value
|
||
|
||
|
||
def discover_partition_values(
    df: pd.DataFrame,
    partition_by: list[str],
    columns: Optional[Dict[str, ColumnSpec]] = None,
) -> dict:
    """Build a nested structure of unique partition values from a DataFrame.

    For ``partition_by = ['state', 'zip']`` returns::

        {
            'MO': {'63101': {}, '63102': {}},
            'IL': {'62001': {}, '62002': {}}
        }

    When ``columns`` is provided, values are normalized through
    :func:`_normalize_partition_value` to match the routed values Postgres
    will see during ``COPY``. Rows are grouped by that normalized value at
    every level. As a result, every row whose value normalizes to ``None``
    contributes its lower-level values to the ``None`` subtree. That covers
    true nulls, and also e.g. blank strings or unparseable text in a
    numeric or date column.

    None/NaN values are included as a distinct partition value (``None`` key).
    Values are converted to Python native types (not numpy types). Keys at
    each level appear in order of first occurrence in ``df``.
    """
    if not partition_by:
        return {}

    def _to_native(val: Any) -> Any:
        """Convert numpy scalars to Python native types."""
        if val is None:
            return None
        if isinstance(val, float) and pd.isna(val):
            return None
        if hasattr(val, "item"):
            return val.item()
        return val

    def _routed_value(field: str, raw: Any) -> Any:
        """Return the native (and, when typed, normalized) key for ``raw``."""
        native = _to_native(raw)
        if columns and field in columns:
            native = _normalize_partition_value(native, columns[field].postgres_type)
        return native

    def _bucket_rows(column: pd.Series, field: str) -> dict:
        """Map each routed key to the positions of the rows carrying it."""
        buckets: dict = {}
        # Normalizing is the costly step, so cache it per distinct raw value.
        # The type is part of the cache key so that equal-hashing values such
        # as 1 and 1.0 are still normalized separately.
        cache: dict = {}
        for position, raw in enumerate(column):
            cache_key = (type(raw), raw)
            try:
                key = cache[cache_key]
            except KeyError:
                key = cache[cache_key] = _routed_value(field, raw)
            except TypeError:
                # Unhashable raw value: normalize it without caching.
                key = _routed_value(field, raw)
            buckets.setdefault(key, []).append(position)
        return buckets

    def _build_level(sub_df: pd.DataFrame, fields: list[str]) -> dict:
        if not fields or sub_df.empty:
            return {}

        field, remaining = fields[0], fields[1:]
        column = sub_df[field]

        if not remaining:
            # Leaf level: only the distinct keys matter.
            return {_routed_value(field, raw): {} for raw in column.drop_duplicates()}

        # Single pass over the rows instead of one full rescan per distinct value.
        buckets = _bucket_rows(column, field)
        return {
            key: _build_level(sub_df.iloc[positions], remaining)
            for key, positions in buckets.items()
        }

    return _build_level(df, list(partition_by))
|
||
|
||
|
||
def discover_partition_values_chunked(
    chunk_iter: Iterable[pd.DataFrame],
    partition_by: list[str],
    columns: Optional[Dict[str, ColumnSpec]] = None,
) -> dict:
    """Discover partition values across an iterable of DataFrame chunks.

    Each non-empty chunk is reduced to its partition columns, turned into a
    partition tree with :func:`discover_partition_values`, and merged into
    one accumulated tree. Only one chunk at a time is held, so the full file
    is never materialized.

    Raises:
        ValueError: a non-empty chunk lacks one of the ``partition_by``
            columns.
    """
    if not partition_by:
        return {}

    merged: dict = {}
    for chunk_df in chunk_iter:
        if chunk_df.empty:
            continue
        absent = [name for name in partition_by if name not in chunk_df.columns]
        if absent:
            raise ValueError(
                f"Partition columns not found in data: {absent}"
            )
        # Keep only the partition columns to minimize memory.
        subset = chunk_df[list(partition_by)]
        _merge_partition_trees(merged, discover_partition_values(subset, partition_by, columns))
    return merged
|
||
|
||
|
||
def _merge_partition_trees(target: dict, source: dict) -> None:
|
||
"""Merge ``source`` partition tree into ``target`` in place.
|
||
|
||
Both trees are nested dicts where keys are partition values and values
|
||
are either empty dicts (leaf) or nested dicts (intermediate levels).
|
||
"""
|
||
for key, subtree in source.items():
|
||
if key not in target:
|
||
target[key] = subtree
|
||
else:
|
||
# Merge children recursively
|
||
if subtree and target[key]:
|
||
_merge_partition_trees(target[key], subtree)
|
||
elif subtree:
|
||
target[key] = subtree
|
||
|
||
|
||
def _count_partitions(tree: dict) -> int:
|
||
"""Count total partition tables in a nested partition tree."""
|
||
count = 0
|
||
for _key, children in tree.items():
|
||
count += 1
|
||
if children:
|
||
count += _count_partitions(children)
|
||
return count
|
||
|
||
|
||
def render_partition_ddl(
    schema: str,
    parent_table: str,
    partition_by: list[str],
    partition_values: dict,
    column_specs: Dict[str, ColumnSpec],
    *,
    max_partitions: int = 10_000,
) -> list[str]:
    """Generate every child partition DDL statement for a partition tree.

    Statements are returned depth-first, so a parent always precedes its
    children. The parent ``CREATE TABLE`` itself is NOT included; see
    :func:`render_create_table`. Returns ``[]`` when ``partition_by`` or
    ``partition_values`` is empty.

    If the total partition count exceeds ``max_partitions``, a warning is
    both logged and printed to stderr, and rendering still proceeds.
    """
    if not partition_by or not partition_values:
        return []

    planned = _count_partitions(partition_values)
    if planned > max_partitions:
        logger.warning(
            "Partition count (%d) exceeds threshold (%d). "
            "This may impact database performance.",
            planned, max_partitions,
        )
        print(
            f"[warn] partition plan for {schema}.{parent_table} will create "
            f"{planned:,} partition tables, exceeding max_partitions={max_partitions:,}",
            file=sys.stderr,
        )

    ddl: list[str] = []
    _render_partition_ddl_recursive(
        schema, parent_table, partition_by, partition_values,
        column_specs, 0, ddl,
    )
    return ddl
|
||
|
||
|
||
def _render_partition_ddl_recursive(
    schema: str,
    parent_table: str,
    partition_by: list[str],
    values: dict,
    column_specs: Dict[str, ColumnSpec],
    depth: int,
    statements: list[str],
) -> None:
    """Recursively generate partition DDL statements (depth-first).

    Emits one ``CREATE TABLE ... PARTITION OF`` statement per key of
    ``values`` and appends it to ``statements`` in place; nothing is
    returned. ``values`` is the tree level for ``partition_by[depth]``.
    If a further partition field exists below this level, the key becomes
    an intermediate partition that is itself ``PARTITION BY LIST`` on the
    next field. Its own children's DDL follows immediately, so a parent
    always precedes its children.

    Child names are ``<parent_table>_<sanitized value>``. If a sibling's
    name is already taken, the later value (in sort order) gets an
    8-hex-char SHA-256 suffix of ``repr(value)``.

    Args:
        schema: Target schema for every child table.
        parent_table: Unquoted name of the table the children attach to.
        partition_by: Full list of partition fields; ``depth`` indexes it.
        values: One partition-tree level, ``{value: subtree_dict}``.
        column_specs: Used to look up the field's Postgres type for literal
            rendering. A field with no spec is rendered as TEXT.
        depth: Index into ``partition_by`` of this level's field.
        statements: Output list that DDL strings are appended to.

    Raises:
        ValueError: propagated from ``_sanitize_partition_value`` when
            ``parent_table`` is too long to leave room for a suffix.
    """
    field_name = partition_by[depth]
    # Field for the next level down, or None when this level is the leaf.
    next_field = partition_by[depth + 1] if depth + 1 < len(partition_by) else None
    field_spec = column_specs.get(field_name)
    pg_type = field_spec.postgres_type if field_spec else "TEXT"

    # Child name -> value that claimed it. This dict is local to one call,
    # so only siblings under the same parent are compared. Children of
    # different parents already differ by their parent-name prefix.
    used_names: Dict[str, Any] = {}

    # Sort values deterministically: None first, then by string
    # representation. Python's sort is stable, so values with equal str()
    # keep their dict order.
    def _sort_key(val: Any) -> Tuple[int, str]:
        if val is None:
            return (0, "")
        return (1, str(val))

    sorted_values = sorted(values.keys(), key=_sort_key)

    for val in sorted_values:
        children = values[val]
        token = _sanitize_partition_value(val, parent_table)
        child_name = f"{parent_table}_{token}"

        # Handle collisions. Keys of ``values`` are distinct, so a name that
        # is already in ``used_names`` belongs to a different value; the
        # ``is not`` identity test is effectively always true here.
        if child_name in used_names and used_names[child_name] is not val:
            # Append a short hash of the value to disambiguate.
            val_hash = hashlib.sha256(repr(val).encode()).hexdigest()[:8]
            # Re-truncate the token to make room for "_" + 8 hash chars.
            max_token_len = _PG_IDENT_MAX_LEN - len(parent_table) - 1 - 9  # _hash8
            if max_token_len < 1:
                max_token_len = 1
            truncated_token = token[:max_token_len].rstrip("_")
            child_name = f"{parent_table}_{truncated_token}_{val_hash}"

            # Final length check.
            # NOTE(review): when parent_table is long, the clamp above can
            # leave this name over the limit. The slice can then cut into
            # the hash suffix. The hashed name is also not re-checked against
            # used_names. Either could in principle reintroduce a collision.
            # NOTE(review): the limit is compared in characters, but Postgres
            # limits identifiers in bytes. Confirm whether a non-ASCII
            # parent_table can reach this point.
            if len(child_name) > _PG_IDENT_MAX_LEN:
                child_name = child_name[:_PG_IDENT_MAX_LEN]

        used_names[child_name] = val

        literal = _render_partition_value_literal(val, pg_type)

        if next_field is not None:
            # Intermediate partition: itself partitioned by the next field.
            stmt = (
                f"CREATE TABLE {_qualified(schema, child_name)} "
                f"PARTITION OF {_qualified(schema, parent_table)} "
                f"FOR VALUES IN ({literal}) "
                f"PARTITION BY LIST ({_quote_ident(next_field)});"
            )
            statements.append(stmt)
            # Recurse into children (DDL for them follows this parent).
            if children:
                _render_partition_ddl_recursive(
                    schema, child_name, partition_by, children,
                    column_specs, depth + 1, statements,
                )
        else:
            # Leaf partition.
            stmt = (
                f"CREATE TABLE {_qualified(schema, child_name)} "
                f"PARTITION OF {_qualified(schema, parent_table)} "
                f"FOR VALUES IN ({literal});"
            )
            statements.append(stmt)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Index support
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def render_create_indexes(
    schema: str,
    tablename: str,
    indexes: List[str],
) -> List[str]:
    """Build one ``CREATE INDEX IF NOT EXISTS`` statement per column in ``indexes``.

    Each index is a single-column index on ``schema.tablename``, named
    ``ix_{tablename}_{column}`` from the raw, unsanitized names. Every
    identifier is quoted with :func:`_quote_ident`. A name longer than
    ``_PG_IDENT_MAX_LEN`` is cut to ``_PG_IDENT_MAX_LEN - 9`` characters,
    has trailing ``_`` removed, and gets ``_`` plus the first 8 hex chars of
    the SHA-256 of the full name appended.

    Caveat: ``IF NOT EXISTS`` matches on the index name alone. An existing
    index of the same name in the schema, even one on a different table or
    column, makes the statement a no-op.

    Returns the SQL strings in the same order as ``indexes``.
    """
    statements: List[str] = []
    for col in indexes:
        name = f"ix_{tablename}_{col}"
        if len(name) > _PG_IDENT_MAX_LEN:
            digest = hashlib.sha256(name.encode()).hexdigest()[:8]
            # 9 = one "_" separator + 8 hash characters.
            name = name[: _PG_IDENT_MAX_LEN - 9].rstrip("_") + "_" + digest
        statements.append(
            f"CREATE INDEX IF NOT EXISTS {_quote_ident(name)} "
            f"ON {_qualified(schema, tablename)} ({_quote_ident(col)});"
        )
    return statements
|
||
|
||
|
||
def create_indexes(
    conn,
    schema: str,
    tablename: str,
    indexes: List[str],
) -> None:
    """Execute the statements from :func:`render_create_indexes`, one by one.

    After each statement, the function calls ``conn.commit()`` and prints an
    ``[info]`` line to stderr. If a statement raises, it calls
    ``conn.rollback()``, prints a ``[warn]`` line and carries on with the
    next column, so one failure does not stop the rest.

    Transaction caveats:

    * ``rollback()`` discards the whole open transaction, not just the
      failed statement. That includes any uncommitted work the caller issued
      before calling this function.
    * The first successful ``commit()`` likewise commits that earlier work.

    The log lines show the untruncated name ``ix_{tablename}_{col}``, even
    when :func:`render_create_indexes` shortened it with a hash suffix.
    """
    ddl = render_create_indexes(schema, tablename, indexes)
    with conn.cursor() as cur:
        for col, statement in zip(indexes, ddl):
            try:
                cur.execute(statement)
                conn.commit()
                print(
                    f"[info] created index ix_{tablename}_{col} "
                    f"on {schema}.{tablename}({col})",
                    file=sys.stderr,
                )
            except Exception as exc:
                conn.rollback()
                print(
                    f"[warn] failed to create index ix_{tablename}_{col} "
                    f"on {schema}.{tablename}({col}): {exc}",
                    file=sys.stderr,
                )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# COPY loading
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _seconds_to_time(v: Any) -> Optional[dt.time]:
    """Convert a SAS TIME value (seconds since midnight) to :class:`datetime.time`.

    - ``None`` / NaN / NaT / unparseable values / ``±inf`` -> ``None``
      (serialized as an empty CSV field, i.e. SQL NULL).
    - ``datetime.time`` values pass through unchanged.
    - ``datetime`` / ``pd.Timestamp`` values yield their time-of-day part.
    - Anything else is coerced with ``float()``, rounded to whole seconds,
      and clamped into ``[00:00:00, 23:59:59]``. ``datetime.time`` cannot
      represent values outside a single day.
    """
    if v is None:
        return None
    if isinstance(v, float) and pd.isna(v):
        return None
    if isinstance(v, dt.time):
        return v
    if isinstance(v, (dt.datetime, pd.Timestamp)):
        return v.time() if not pd.isna(v) else None
    try:
        total = int(round(float(v)))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: round(±inf) cannot produce an int.
        return None
    # Clamp the *total* rather than only the hour. Clamping just ``h`` after
    # divmod produced nonsense for out-of-range input (-5 s -> 00:59:55,
    # 86400 s -> 23:00:00).
    total = max(0, min(total, 24 * 3600 - 1))
    h, rem = divmod(total, 3600)
    m, s = divmod(rem, 60)
    return dt.time(h, m, s)
|
||
|
||
|
||
def _pg_base_type_for_copy(postgres_type: str) -> str:
    """Upper-case *postgres_type* and drop ``(n)`` / ``(p, s)`` modifiers.

    ``"varchar(255)"`` -> ``"VARCHAR"``; ``"numeric(10, 2)"`` -> ``"NUMERIC"``;
    ``"timestamp(3) with time zone"`` -> ``"TIMESTAMP WITH TIME ZONE"``.
    Whitespace runs are collapsed so the result compares equal to the
    canonical spellings used in :func:`_prepare_for_copy`.
    """
    import re

    without_modifiers = re.sub(r"\([^)]*\)", " ", postgres_type)
    return " ".join(without_modifiers.upper().split())


def _is_text_like_series(series: pd.Series) -> bool:
    """True for ``object`` columns and pandas' dedicated string dtypes.

    pandas 3 infers ``str`` (a :class:`pandas.StringDtype`) for string
    columns instead of ``object``. An ``is_object_dtype`` test alone would
    send such columns down the passthrough paths.
    """
    return pd.api.types.is_object_dtype(series) or isinstance(
        series.dtype, pd.StringDtype
    )


def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame:
    """Materialize a copy of ``df`` with each column in the right shape for
    CSV serialization, so the CSV lands as valid input for the target
    Postgres type.

    Type names are matched after stripping length/precision modifiers, so
    ``VARCHAR(20)``, ``NUMERIC(10,2)`` and ``TIMESTAMP(3)`` get the same
    treatment as their bare forms. Unrecognized types pass through unchanged.

    Per-column conversions are vectorized (``.astype`` / ``pd.to_datetime`` /
    ``.mask``) rather than element-wise ``.map(func)`` loops. TIME columns
    still use ``.map`` because SAS TIME8 is stored as seconds and the
    clamp-to-a-day logic doesn't fit cleanly in vector form; they're also
    rare in practice.

    The result frame is built in one constructor call. Inserting columns one
    at a time fragments the block manager; past ~100 columns pandas emits a
    PerformanceWarning and each insert gets slower. ``copy=False`` keeps the
    converted arrays as-is, so peak memory matches per-column assignment.
    """
    converted: Dict[str, Any] = {}
    for name, spec in columns.items():
        series = df[name]
        pg = _pg_base_type_for_copy(spec.postgres_type)

        if pg in ("INTEGER", "BIGINT", "SMALLINT"):
            if _is_text_like_series(series):
                series = pd.to_numeric(
                    series.replace({"": None}), errors="coerce"
                )
            converted[name] = series.astype("Int64")
        elif pg in ("DOUBLE PRECISION", "REAL", "NUMERIC"):
            if _is_text_like_series(series):
                series = pd.to_numeric(
                    series.replace({"": None}), errors="coerce"
                )
            converted[name] = series.astype("float64")
        elif pg == "DATE":
            if pd.api.types.is_datetime64_any_dtype(series):
                converted[name] = series.dt.date
            elif _is_text_like_series(series):
                # Empty strings / None / unparseable -> NaT; .dt.date yields
                # date objects or NaT. NaT serializes as an empty CSV field
                # (matching ``NULL ''`` in COPY).
                parsed = pd.to_datetime(
                    series.replace({"": None}), errors="coerce"
                )
                converted[name] = parsed.dt.date
            elif pd.api.types.is_numeric_dtype(series):
                # pyreadstat couldn't decode the SAS format (some
                # ``DATEw.``/``YYMMDDw.`` variants and all custom formats
                # slip through), so the column came back as float64 days
                # since 1960-01-01, the SAS epoch. Without this Postgres
                # rejects the raw number with ``invalid input syntax for
                # type date``.
                parsed = pd.to_datetime(
                    series, unit="D", origin="1960-01-01", errors="coerce",
                )
                converted[name] = parsed.dt.date
            else:
                converted[name] = series
        elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"):
            if pd.api.types.is_datetime64_any_dtype(series):
                converted[name] = series
            elif _is_text_like_series(series):
                converted[name] = pd.to_datetime(
                    series.replace({"": None}), errors="coerce"
                )
            elif pd.api.types.is_numeric_dtype(series):
                # Same as the DATE branch, but SAS datetimes are *seconds*
                # since 1960-01-01 (fractional for ``DATETIMEw.d``), e.g.
                # ``1915465463.615`` -> 2020-09-13 05:44:23.615.
                converted[name] = pd.to_datetime(
                    series, unit="s", origin="1960-01-01", errors="coerce",
                )
            else:
                converted[name] = series
        elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"):
            converted[name] = series.map(_seconds_to_time)
        elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"):
            # Render every cell as a string and blank out nulls; ``NULL ''``
            # in the COPY statement turns the blanks back into SQL NULL.
            # astype(str) stringifies NaN/None to "nan"/"None", so those are
            # masked afterwards rather than branching per cell.
            na_mask = series.isna()
            converted[name] = series.astype(str).mask(na_mask, "")
        elif pg == "BOOLEAN":
            converted[name] = series.astype("boolean") if series.dtype != object else series
        else:
            converted[name] = series

    return pd.DataFrame(converted, index=df.index, copy=False)
|
||
|
||
|
||
def _serialize_chunk_csv(prepared: pd.DataFrame) -> io.BytesIO:
    """Render *prepared* as header-less CSV bytes for ``COPY ... FROM STDIN``.

    The frame becomes an Arrow table (index dropped) and is written with
    ``pyarrow.csv.write_csv``, which is considerably faster than pandas'
    ``to_csv`` on wide or text-heavy frames. Null cells become empty fields
    and date/timestamp values are written in ISO 8601 form. Postgres accepts
    both under ``FORMAT csv, NULL ''``. The buffer is rewound to offset 0
    before being returned.
    """
    arrow_table = pa.Table.from_pandas(prepared, preserve_index=False)
    sink = io.BytesIO()
    options = pa_csv.WriteOptions(include_header=False)
    pa_csv.write_csv(arrow_table, sink, write_options=options)
    sink.seek(0)
    return sink
|
||
|
||
|
||
def copy_dataframes(
    conn,
    schema_name: str,
    table_name: str,
    dfs: Iterable[pd.DataFrame],
    columns: Dict[str, ColumnSpec],
) -> int:
    """Stream DataFrames into Postgres with ``COPY``, one commit per chunk.

    Every non-empty frame from *dfs* is prepared, serialized to CSV, copied
    via ``COPY ... FROM STDIN`` and committed before the next frame is pulled.
    A failure part-way through therefore keeps all rows from earlier chunks.
    The first commit also flushes any DDL still pending on *conn* (e.g. a
    preceding ``CREATE TABLE``). Empty frames are skipped.

    Returns the number of rows copied across all chunks.
    """
    quoted_cols = ", ".join(map(_quote_ident, columns))
    copy_sql = (
        f"COPY {_qualified(schema_name, table_name)} ({quoted_cols}) "
        "FROM STDIN WITH (FORMAT csv, NULL '')"
    )

    rows_copied = 0
    # Memory discipline: frames are pulled one at a time and every large
    # intermediate is dropped as soon as its successor exists. The raw
    # chunk goes once prepared, the prepared frame once the CSV buffer
    # exists, and the buffer once COPY has consumed it. No reference to the
    # previous chunk survives into the next pyreadstat read. That holds
    # per-worker peak near ~2x chunk size instead of 5-6x, which matters
    # when dozens of workers share one host.
    chunk_source = iter(dfs)
    exhausted = object()
    with conn.cursor() as cur:
        while (df := next(chunk_source, exhausted)) is not exhausted:
            if df.empty:
                del df
                continue

            prepared = _prepare_for_copy(df, columns)
            del df
            chunk_rows = len(prepared)

            buf = _serialize_chunk_csv(prepared)
            del prepared

            cur.copy_expert(copy_sql, buf)
            del buf

            conn.commit()
            rows_copied += chunk_rows

            # Return pyarrow's pooled buffers between chunks. Otherwise the
            # pool keeps its high-water allocation for the worker's lifetime
            # and RSS climbs despite the ``del``s above. Best-effort: any
            # failure here is ignored.
            try:
                pa.default_memory_pool().release_unused()
            except Exception:
                pass

    return rows_copied
|
||
|
||
|
||
def copy_dataframe(
    conn,
    schema_name: str,
    table_name: str,
    df: pd.DataFrame,
    columns: Dict[str, ColumnSpec],
) -> int:
    """Load a single DataFrame into Postgres via ``COPY ... FROM STDIN``.

    Adapter for callers holding one frame: *df* is handed to
    :func:`copy_dataframes` as a one-element sequence, so empty-frame
    skipping, serialization and the commit all follow the batch path.
    Returns the number of rows inserted.
    """
    single_chunk = (df,)
    return copy_dataframes(
        conn,
        schema_name,
        table_name,
        single_chunk,
        columns,
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Manifest validation
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _match_manifest_type(inferred: str, manifest_entry: Dict[str, Any]) -> bool:
    """Report whether *inferred* satisfies a manifest entry's type rule.

    A single ``postgres_type`` takes precedence over an ``acceptable_types``
    list. Both sides are compared after :func:`_normalize_type`. An entry
    with neither key never matches.
    """
    wanted = _normalize_type(inferred)
    if "postgres_type" in manifest_entry:
        candidates = (manifest_entry["postgres_type"],)
    elif "acceptable_types" in manifest_entry:
        candidates = manifest_entry["acceptable_types"]
    else:
        return False
    return any(wanted == _normalize_type(candidate) for candidate in candidates)
|
||
|
||
|
||
def validate_against_manifest(
    inferred: Dict[str, ColumnSpec],
    manifest_path: Path,
) -> List[str]:
    """Compare the inferred schema against the expected-types manifest.

    The manifest is a JSON object mapping column name -> entry. Each entry
    is a JSON object with either ``postgres_type`` or ``acceptable_types``
    and an optional ``nullable`` flag (default ``True``).

    Returns a list of human-readable problem strings; an empty list means OK.
    A missing, unparseable, or wrongly-shaped manifest is reported as a
    problem instead of raising, consistent with the missing-file case.
    """
    manifest_path = Path(manifest_path)
    if not manifest_path.exists():
        return [f"manifest not found: {manifest_path}"]

    try:
        with manifest_path.open("r", encoding="utf-8") as f:
            manifest = json.load(f)
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError.
        return [f"manifest could not be parsed as JSON: {manifest_path}: {exc}"]

    if not isinstance(manifest, dict):
        return [
            f"manifest must be a JSON object mapping column names to entries, "
            f"got {type(manifest).__name__}: {manifest_path}"
        ]

    problems: List[str] = []

    only_in_inferred = set(inferred) - set(manifest)
    only_in_manifest = set(manifest) - set(inferred)
    if only_in_inferred:
        problems.append(
            f"columns in inferred but not manifest: {sorted(only_in_inferred)}"
        )
    if only_in_manifest:
        problems.append(
            f"columns in manifest but not inferred: {sorted(only_in_manifest)}"
        )

    for name, spec in inferred.items():
        entry = manifest.get(name)
        if entry is None:
            continue
        if not isinstance(entry, dict):
            problems.append(
                f"column {name!r}: manifest entry must be a JSON object, "
                f"got {type(entry).__name__}"
            )
            continue
        if not _match_manifest_type(spec.postgres_type, entry):
            expected = entry.get("postgres_type") or entry.get("acceptable_types")
            problems.append(
                f"column {name!r}: inferred {spec.postgres_type!r}, "
                f"manifest expected {expected!r}"
            )
        manifest_nullable = bool(entry.get("nullable", True))
        if spec.nullable and not manifest_nullable:
            problems.append(
                f"column {name!r}: inferred nullable, manifest expects NOT NULL "
                f"(loosening nullability is never allowed)"
            )

    return problems
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# CLI
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _build_argparser() -> argparse.ArgumentParser:
    """Build the CLI parser: a required ``--config`` path plus three boolean flags."""
    parser = argparse.ArgumentParser(
        description="Load a single SAS file (XPT or sas7bdat) into Postgres.",
    )
    parser.add_argument("--config", required=True, type=Path, help="Path to YAML config")

    validate_help = (
        "Compare inferred schema against <filename-stem>.expected.json "
        "next to the SAS file; exits nonzero on mismatch."
    )
    dry_run_help = "Print inferred CREATE TABLE and stop; don't touch Postgres."
    dbcreds_help = (
        "Prompt for database username and password instead of reading "
        "PGUSER / PGPASSWORD from the environment or .env file."
    )
    # Registration order determines --help order; keep it stable.
    for flag, help_text in (
        ("--validate", validate_help),
        ("--dry-run", dry_run_help),
        ("--dbcreds", dbcreds_help),
    ):
        parser.add_argument(flag, action="store_true", help=help_text)
    return parser
|
||
|
||
|
||
def _format_columns_summary(columns: Dict[str, ColumnSpec]) -> str:
    """Render one ``name: TYPE`` line per column (suffixed `` NOT NULL`` when
    the column is not nullable), joined with newlines."""
    return "\n".join(
        f" {spec.name}: {spec.postgres_type}{'' if spec.nullable else ' NOT NULL'}"
        for spec in columns.values()
    )
|
||
|
||
|
||
def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point: load the single SAS file described by ``--config``.

    Parses arguments, loads ``.env``, reads the YAML config, and infers the
    schema from the column-filtered SAS data. It then optionally validates
    the schema against ``<filename-stem>.expected.json``, discovers partition
    values, and either prints the DDL (``--dry-run``) or creates the table,
    streams the rows in with COPY and builds any configured indexes.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Exit code: ``0`` on success (including ``--dry-run``), ``1`` when
        ``--validate`` reports mismatches, ``2`` when the configured SAS
        file does not exist.

    Raises:
        ValueError: ``partition_by`` or ``indexes`` reference columns that
            are absent after include/exclude filtering.
    """
    args = _build_argparser().parse_args(argv)

    load_dotenv()

    cfg = load_config(args.config)

    if not cfg.filename.exists():
        print(f"error: SAS file not found: {cfg.filename}", file=sys.stderr)
        return 2

    # Schema inference reads the whole file so type + nullability are
    # computed against every row. That's what the target host has the
    # resources for and is the only way to honestly emit ``NOT NULL`` -
    # a bounded preview routinely missed the ~0.2% of rows with nulls on
    # otherwise-dense keys (e.g. MAFID). If you're on a box that can't
    # fit the file in memory, override ``TYPE_INFERENCE_SAMPLE_ROWS`` to
    # an integer cap and know that sampled specs may stamp ``NOT NULL``
    # on columns whose nulls live past the window.
    preview_df, meta = read_sas_preview(cfg.filename)
    preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
    columns = infer_schema(preview_df, meta)

    # Validate partition columns exist in the schema after filtering.
    if cfg.partition_by:
        missing_pcols = [c for c in cfg.partition_by if c not in columns]
        if missing_pcols:
            raise ValueError(
                f"partition_by references columns not present in the "
                f"(filtered) schema: {missing_pcols}"
            )

    # Validate index columns exist in the schema after filtering.
    if cfg.indexes:
        missing_icols = [c for c in cfg.indexes if c not in columns]
        if missing_icols:
            raise ValueError(
                f"indexes references columns not present in the "
                f"(filtered) schema: {missing_icols}"
            )

    if args.validate:
        # Replace only the final suffix with ".expected.json", e.g.
        # "sample_kitchensink.xpt" -> "sample_kitchensink.expected.json".
        # Building from ``stem`` keeps dotted stems intact
        # ("data.v2.xpt" -> "data.v2.expected.json"). Chaining
        # ``with_suffix("")`` + ``with_suffix(...)`` would drop the ".v2".
        manifest_path = cfg.filename.with_name(f"{cfg.filename.stem}.expected.json")
        problems = validate_against_manifest(columns, manifest_path)
        if problems:
            print("validation failed:", file=sys.stderr)
            for p in problems:
                print(f" - {p}", file=sys.stderr)
            return 1
        print(f"validation OK ({len(columns)} columns match {manifest_path.name})")

    # -- Partition value discovery ------------------------------------------
    # If partitioned, scan the ENTIRE file to discover all unique partition
    # values. The preview is only the first N rows and may miss values.
    # In append mode the partitions already exist, so skip the costly scan.
    partition_values: Optional[dict] = None
    if cfg.partition_by and cfg.if_exists != "append":
        print(" discovering partition values (full file scan)...", file=sys.stderr)

        def _discovery_chunks():
            for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename):
                yield apply_column_filter(chunk_df, cfg.include, cfg.exclude)

        partition_values = discover_partition_values_chunked(
            _discovery_chunks(), cfg.partition_by, columns,
        )
        total_parts = _count_partitions(partition_values)
        print(
            f" discovered {total_parts:,} partition tables "
            f"across {len(cfg.partition_by)} level(s)",
            file=sys.stderr,
        )
    elif cfg.partition_by and cfg.if_exists == "append":
        print(
            " [info] append mode: skipping partition discovery "
            "(partitions assumed to exist)",
            file=sys.stderr,
        )

    if args.dry_run:
        # Print the parent CREATE TABLE (with PARTITION BY if applicable).
        parent_ddl = render_create_table(
            cfg.schemaname, cfg.tablename, columns,
            partition_by=cfg.partition_by or None,
        )
        print(parent_ddl)
        # Print child partition DDL if partitioned.
        if cfg.partition_by and partition_values:
            child_stmts = render_partition_ddl(
                cfg.schemaname, cfg.tablename, cfg.partition_by,
                partition_values, columns,
                max_partitions=cfg.max_partitions,
            )
            for stmt in child_stmts:
                print()
                print(stmt)
        # Print CREATE INDEX DDL if indexes are configured.
        if cfg.indexes:
            idx_stmts = render_create_indexes(
                cfg.schemaname, cfg.tablename, cfg.indexes,
            )
            for stmt in idx_stmts:
                print()
                print(stmt)
        return 0

    # Release the preview frame before streaming the file so its memory can
    # be reclaimed during the load.
    del preview_df

    total_rows = getattr(meta, "number_rows", None)

    def _filtered_chunks():
        # Re-read the file in chunks, apply the column filter, and advance a
        # row-count progress bar on stderr.
        pbar = tqdm(
            total=total_rows,
            unit="row",
            unit_scale=True,
            desc=f" {cfg.filename.name}",
            file=sys.stderr,
            dynamic_ncols=True,
        )
        try:
            for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename):
                chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude)
                pbar.update(len(chunk_df))
                yield chunk_df
        finally:
            pbar.close()

    db_user = db_password = None
    if args.dbcreds:
        db_user = input("Database username: ")
        db_password = getpass.getpass("Database password: ")

    conn = connect(user=db_user, password=db_password)
    conn.autocommit = False
    try:
        create_table(
            conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists,
            partition_by=cfg.partition_by or None,
            partition_values=partition_values,
            max_partitions=cfg.max_partitions,
        )
        inserted = copy_dataframes(
            conn, cfg.schemaname, cfg.tablename, _filtered_chunks(), columns
        )
        conn.commit()
        if cfg.indexes:
            create_indexes(conn, cfg.schemaname, cfg.tablename, cfg.indexes)
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

    print(
        f"loaded {inserted} rows into {cfg.schemaname}.{cfg.tablename} "
        f"({len(columns)} columns)"
    )
    if cfg.partition_by and partition_values:
        total_parts = _count_partitions(partition_values)
        print(f"partitioned by {cfg.partition_by} ({total_parts:,} partition tables)")
    print("final schema:")
    print(_format_columns_summary(columns))
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    raise SystemExit(main())
|