Introduced a new mechanism to sample non-null values for determining the appropriate datetime parsing strategy, significantly reducing processing time for large datasets. This change replaces the previous full row-walk method with a more efficient sampling technique, enhancing performance while maintaining robust handling of various date formats. Updated comments for clarity on the new approach.
3197 lines
120 KiB
Python
"""Per-file data-to-Postgres loader (SAS and delimited text).

Library-style functions plus a thin CLI wrapper. Designed so an orchestrator
can wrap the library for directory/batch mode; orchestration is out of scope
here.

Python 3.14 compatible (target is an air-gapped host that currently only has
3.14). ``from __future__ import annotations`` lets us use PEP 585 generics
as annotations; runtime-resolved type uses (dataclass defaults, etc.) stick
to ``typing``.

-------------------------------------------------------------------------------
USAGE
-------------------------------------------------------------------------------

Supported inputs:
  * ``.sas7bdat`` (read with ``encoding="latin-1"``)
  * ``.xpt`` / ``.xport`` (SAS transport files)
  * ``.csv`` / ``.tsv`` / ``.txt`` (delimited text files with headers)

1. YAML config
--------------
Every invocation is driven by a YAML file describing one data file to load::

    filename: samples/sample_kitchensink.xpt   # required; relative paths are
                                               # resolved against the config
                                               # file's directory when possible
    schemaname: public                         # required
    tablename: kitchensink                     # required

    # Optional. One of: fail | replace | append. Default: fail.
    #   fail    - error out if the target table already exists
    #   replace - DROP and recreate the table from the inferred schema
    #   append  - keep the existing table; pre-flight a schema-compat check,
    #             then COPY the new rows in
    if_exists: append

    # Optional, mutually exclusive. Restrict which columns are loaded.
    # include:
    #   - ID
    #   - INTCOL
    # exclude:
    #   - ALLNULL

2. Database connection
----------------------
The loader uses standard libpq environment variables (read via ``os.environ``)::

    PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE

The CLI calls ``python-dotenv``'s ``load_dotenv()`` at startup, so a local
``.env`` file is picked up automatically. Library callers are responsible for
populating the environment themselves (either call ``load_dotenv()`` or export
the vars) before calling :func:`connect`.
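
As a rough illustration of the lookup described above, the following
stdlib-only sketch collects connection settings from the libpq variables and
lets explicit credentials take precedence. The helper name
``pg_settings_from_env`` is hypothetical and is not part of this module;
:func:`connect` remains the real entry point.

```python
import os

# Hypothetical helper for illustration only; not part of load_sas.
def pg_settings_from_env(env=None, user=None, password=None):
    # Collect libpq-style settings; explicit user/password win over the env.
    env = os.environ if env is None else env
    return {
        "host": env.get("PGHOST"),
        "port": env.get("PGPORT"),
        "user": user or env.get("PGUSER"),
        "password": password or env.get("PGPASSWORD"),
        "dbname": env.get("PGDATABASE"),
    }

# Example with a stand-in environment mapping instead of the real os.environ.
settings = pg_settings_from_env(
    env={"PGHOST": "localhost", "PGUSER": "loader"},
    user="alice",
)
print(settings["host"], settings["user"])
```

Variables that are not set simply come back as ``None`` in this sketch, which
leaves the driver free to apply its own defaults.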

3. Command-line interface
-------------------------
::

    python load_sas.py --config path/to/config.yaml [--validate] [--dry-run]
                       [--dbcreds]

Flags:
  --config PATH   Required. Path to the YAML config above.
  --validate      Compare the inferred schema against
                  ``<sas-file-stem>.expected.json`` sitting next to the SAS
                  file. Exits nonzero on mismatch. Safe to combine with
                  ``--dry-run``.
  --dry-run       Print the inferred ``CREATE TABLE`` SQL and stop. The
                  database is never touched (no connection is opened).
  --dbcreds       Prompt interactively for the database username and
                  password instead of reading ``PGUSER`` / ``PGPASSWORD``
                  from the environment or ``.env`` file. The password
                  prompt does not echo. Has no effect with ``--dry-run``
                  (no connection is opened).

Exit codes:
  0 - success (load completed, or dry-run/validate passed)
  1 - validation failure
  2 - config references a SAS file that does not exist
  Other nonzero - uncaught exception (traceback printed); the transaction
                  is rolled back before exit.

Typical invocations::

    # Preview the inferred schema without connecting to Postgres.
    python load_sas.py --config sample_config.yaml --dry-run

    # Check the inferred schema against an expected-types manifest.
    python load_sas.py --config sample_config.yaml --validate --dry-run

    # Actually load the data.
    python load_sas.py --config sample_config.yaml

    # Load the data, prompting for credentials instead of using .env.
    python load_sas.py --config sample_config.yaml --dbcreds

4. Expected-types manifest (``--validate``)
-------------------------------------------
``--validate`` looks for a JSON file named ``<sas-stem>.expected.json`` next
to the SAS file, e.g. ``samples/sample_kitchensink.xpt`` pairs with
``samples/sample_kitchensink.expected.json``. Each top-level key is a column
name; the value is an object with any of::

    {
      "postgres_type": "BIGINT",        # exact expected type, OR
      "acceptable_types": ["TEXT",      # any-of list of acceptable types
                           "VARCHAR"],
      "nullable": true,                 # default true; false = must be NOT NULL
      "note": "free-form comment"       # ignored by the loader
    }

Type comparison ignores length/precision modifiers and normalizes synonyms
(e.g. ``INT`` == ``INTEGER`` == ``INT4``; ``VARCHAR(10)`` == ``VARCHAR``).
Nullability tightening (inferred NULL, manifest NOT NULL) is a hard failure;
loosening is not checked here because the append-mode check already covers it.
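
The comparison rule above could be sketched roughly as follows. This is an
assumption-laden illustration, not the loader's actual implementation: the
``SYNONYMS`` table and the ``normalize_pg_type`` / ``types_match`` names are
hypothetical, and the modifier stripping simply cuts at the first ``(``, so
types with a trailing clause after the modifier are not handled.

```python
# Hypothetical synonym table for illustration; the loader's real table may differ.
SYNONYMS = {
    "INT": "INTEGER",
    "INT4": "INTEGER",
    "INT8": "BIGINT",
    "FLOAT8": "DOUBLE PRECISION",
    "CHARACTER VARYING": "VARCHAR",
}

def normalize_pg_type(type_name):
    # Drop any "(...)" length/precision modifier, collapse whitespace,
    # upper-case, then map known synonyms to one canonical spelling.
    base = type_name.split("(", 1)[0]
    base = " ".join(base.split()).upper()
    return SYNONYMS.get(base, base)

def types_match(inferred, expected):
    # Two type names match when their normalized forms are identical.
    return normalize_pg_type(inferred) == normalize_pg_type(expected)

print(types_match("varchar(10)", "VARCHAR"))
print(types_match("INT4", "integer"))
print(types_match("BIGINT", "INTEGER"))
```

Normalizing both sides before comparing keeps the manifest forgiving about
spelling while still catching a genuine width mismatch such as ``BIGINT``
versus ``INTEGER``.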

5. Library usage
----------------
The CLI is a thin wrapper around composable functions. The preferred pattern
infers the schema from a bounded preview and then streams the rest of the
file chunk-by-chunk into ``COPY`` - crucial for SAS files with hundreds of
millions of rows::

    from pathlib import Path

    from dotenv import load_dotenv
    from load_sas import (
        load_config, read_sas_preview, iter_sas_chunks, apply_column_filter,
        infer_schema, validate_against_manifest, render_create_table,
        connect, create_table, copy_dataframes,
    )

    load_dotenv()
    cfg = load_config("config.yaml")

    # Schema from a preview slice (bounded by TYPE_INFERENCE_SAMPLE_ROWS).
    preview_df, meta = read_sas_preview(cfg.filename)
    preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
    total_rows = getattr(meta, "number_rows", None)
    columns = infer_schema(preview_df, meta, total_rows=total_rows)

    # Optional: preview DDL / validate against a manifest.
    print(render_create_table(cfg.schemaname, cfg.tablename, columns))
    problems = validate_against_manifest(columns, Path("expected.json"))
    assert not problems, problems

    conn = connect()
    conn.autocommit = False
    try:
        create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists)
        chunks = (
            apply_column_filter(df, cfg.include, cfg.exclude)
            for df, _ in iter_sas_chunks(cfg.filename)
        )
        rows = copy_dataframes(conn, cfg.schemaname, cfg.tablename, chunks, columns)
        conn.commit()
    finally:
        conn.close()

For small files (or tests) the legacy one-shot API still works:
:func:`read_sas` returns the whole frame and :func:`copy_dataframe` copies it
in one round trip.

All functions are side-effect free except :func:`connect`, :func:`create_table`,
:func:`copy_dataframe`, and :func:`copy_dataframes`; schema inference
(:func:`infer_schema`) accepts a ``coerce_chars`` kwarg to override the
module-level ``COERCE_CHAR_COLUMNS`` without mutating global state.

6. Type inference summary
-------------------------
Priority order used by :func:`infer_schema`:

  1. SAS format string (via ``meta.original_variable_types``):
     ``DATETIME*`` -> ``TIMESTAMP``, ``TIME*`` -> ``TIME``,
     ``DATE*`` / ``YYMMDD*`` / ``MMDDYY*`` / ``DDMMYY*`` / ``JULIAN*`` -> ``DATE``.
  2. All-null column -> ``TEXT`` (with a note).
  3. pandas datetime dtype -> ``TIMESTAMP``.
  4. Object columns containing only ``datetime.date`` / ``datetime.datetime``
     -> ``DATE`` or ``TIMESTAMP``.
  5. Object columns of strings: if ``COERCE_CHAR_COLUMNS`` is True and at
     least ``CHAR_INFERENCE_MIN_VALUES`` non-empty values parse cleanly, they
     are promoted to ``INTEGER`` / ``BIGINT`` / ``DOUBLE PRECISION`` /
     ``DATE`` / ``TIMESTAMP``; otherwise ``TEXT``.
  6. Numeric columns of whole numbers -> ``INTEGER`` (or ``BIGINT`` if any
     value exceeds the int32 range ``NUMERIC_INT_RANGE``); otherwise
     ``DOUBLE PRECISION``.
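
To make rule 6 (and the all-null fallback of rule 2) concrete, here is a
minimal stdlib-only sketch. It is not the code :func:`infer_schema` runs: the
``pick_numeric_type`` name is hypothetical, and skipping infinities alongside
nulls is an assumption made for the sketch.

```python
import math

# Mirrors the module-level NUMERIC_INT_RANGE (int32 bounds).
NUMERIC_INT_RANGE = (-2_147_483_648, 2_147_483_647)

def pick_numeric_type(values):
    # Nulls (None / NaN) and infinities are skipped when choosing a type;
    # skipping infinities is an assumption of this sketch.
    finite = [
        v for v in values
        if v is not None and not math.isnan(v) and not math.isinf(v)
    ]
    if not finite:
        # Nothing usable to inspect: fall back to TEXT, as for all-null columns.
        return "TEXT"
    if any(not float(v).is_integer() for v in finite):
        return "DOUBLE PRECISION"
    lo, hi = NUMERIC_INT_RANGE
    if all(lo <= v <= hi for v in finite):
        return "INTEGER"
    return "BIGINT"

print(pick_numeric_type([1.0, 2.0, None]))
print(pick_numeric_type([1.0, 3_000_000_000.0]))
print(pick_numeric_type([1.5, 2.0]))
```

A single out-of-range value is enough to widen the whole column to
``BIGINT``, which is why the choice is only as reliable as the set of rows
inspected.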

Type inference scans the whole file by default (``TYPE_INFERENCE_SAMPLE_ROWS
= None``) so type + nullability are both computed against every row. The CLI
materializes the file once for schema inference, then re-streams it chunk by
chunk into ``COPY``; peak memory is roughly one full dataframe. Override
``TYPE_INFERENCE_SAMPLE_ROWS`` to an integer cap if you're on a host that
can't hold the file in memory - but know that sampled specs carry the usual
risks: a later row may exceed the inferred integer range, or a column that
had no nulls in the preview may carry nulls later in the file (which then
detonates ``COPY`` because the sampled spec stamped it ``NOT NULL``). Seen
in production on a 2.5M-row file with ~6k null MAFIDs past the 10k-row
preview - the entire load aborted mid-stream.
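
The nullability risk described above can be reproduced in miniature. The
sketch below uses made-up data and a hypothetical ``infer_nullable`` helper;
it only illustrates why a prefix sample can disagree with a full scan.

```python
# Hypothetical column: dense in the first rows, sparse nulls later on.
column = [100 + i for i in range(10)] + [None, 111, None]

def infer_nullable(values, sample_rows=None):
    # sample_rows=None scans every value; an integer inspects only a prefix.
    window = values if sample_rows is None else values[:sample_rows]
    return any(v is None for v in window)

sampled = infer_nullable(column, sample_rows=10)
full = infer_nullable(column)
print(sampled, full)
```

With the ten-row sample the column looks non-nullable, so a spec built from
it would be stamped ``NOT NULL``; the full scan sees the later nulls and
marks it nullable.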

Streaming loads use :func:`iter_sas_chunks` + :func:`copy_dataframes`, which
commit each chunk as it is copied so an interrupted load retains the rows
that were already written.

7. Tunables
-----------
Module-level knobs at the top of this file:

* ``COERCE_CHAR_COLUMNS``        - promote stringly-typed numerics / dates
                                   (default True).
* ``CHAR_INFERENCE_MIN_VALUES``  - minimum non-empty sample size before
                                   char-column coercion is attempted.
* ``NUMERIC_INT_RANGE``          - INTEGER bounds; values outside become
                                   ``BIGINT``.
* ``TYPE_INFERENCE_SAMPLE_ROWS`` - cap on rows read for type inference
                                   (``None`` = scan the whole column).
* ``DEFAULT_CHUNK_ROWS``         - rows per streaming COPY chunk.
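
For ``DEFAULT_CHUNK_ROWS`` in particular, that tunable's own documentation
says an explicit ``chunksize=`` argument wins over the
``GENERIC_LOADER_CHUNK_ROWS`` environment variable, which in turn wins over
the module default. A minimal sketch of that precedence, using a
hypothetical ``resolve_chunk_rows`` helper rather than the module's own code:

```python
import os

# Same default as the module-level DEFAULT_CHUNK_ROWS.
DEFAULT_CHUNK_ROWS = 2_000_000

# Hypothetical helper for illustration only.
def resolve_chunk_rows(chunksize=None, env=None):
    env = os.environ if env is None else env
    # 1. An explicit argument always wins.
    if chunksize is not None:
        return chunksize
    # 2. Otherwise consult the environment variable, if present.
    raw = env.get("GENERIC_LOADER_CHUNK_ROWS")
    if raw is None:
        return DEFAULT_CHUNK_ROWS
    # 3. A value that is not an integer falls back to the default.
    try:
        return int(raw)
    except ValueError:
        return DEFAULT_CHUNK_ROWS

print(resolve_chunk_rows(env={"GENERIC_LOADER_CHUNK_ROWS": "500"}))
print(resolve_chunk_rows(chunksize=10, env={"GENERIC_LOADER_CHUNK_ROWS": "500"}))
```

Passing the environment in as a mapping keeps the sketch easy to exercise
without touching the real process environment.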
"""

from __future__ import annotations

import argparse
import datetime as dt
import getpass
import hashlib
import io
import json
import logging
import math
import os
import re
import sys
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
import psycopg2
import psycopg2.extensions
import pyarrow as pa
import pyarrow.csv as pa_csv
import pyreadstat
import yaml
from dotenv import load_dotenv
from pandas.errors import PerformanceWarning
from tqdm import tqdm

# ``_prepare_for_copy`` builds its output frame one column at a time with
# ``out[name] = ...``. On wide SAS files (~100+ columns) pandas prints a
# ``PerformanceWarning: DataFrame is highly fragmented`` once per chunk to
# nudge callers toward ``pd.concat(axis=1, ...)``. The fragmentation only
# matters for row-oriented ops or in-place ``.copy()``; we hand the frame
# straight to ``pyarrow.Table.from_pandas`` which reads columns
# independently, so the warning is pure noise for our pipeline. Filter it
# at import time. The filter matches on category only, so it silences every
# pandas ``PerformanceWarning`` in the process, not just the fragmentation one.
warnings.filterwarnings("ignore", category=PerformanceWarning)

# Turn numpy's "raise on float overflow" (and friends) into silent inf/nan
# production, module-wide. Pandas ships with ``np.errstate(over="raise")``
# wrapped around several internal ops (most painfully, the multiply inside
# ``pd.to_datetime(unit="s")`` that converts SAS epoch -> nanoseconds).
# Our data routinely carries ``inf`` / huge sentinels, which trip that
# ``raise`` and blow up an entire worker before ``errors="coerce"`` gets
# a chance to turn them into NaT. Even with ``_safe_numeric_to_datetime``
# pre-masking the obvious cases, other code paths (pandas object-dtype
# datetime parsing, pyarrow type promotion, pyreadstat) can also trigger.
# Setting a process-wide ``seterr`` is a heavier hammer than an
# ``errstate`` block but survives library internals that don't explicitly
# rewrap it. Downside: a real overflow bug in new code would now silently
# produce inf/nan instead of raising - acceptable for a bulk loader where
# "don't crash on bad rows, null them and move on" is the whole point.
np.seterr(over="ignore", invalid="ignore", divide="ignore")


logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Top-level tunables
# ---------------------------------------------------------------------------

COERCE_CHAR_COLUMNS = True
"""If True, try to promote object (string) columns to numeric/date/timestamp
when every non-empty value parses cleanly."""

CHAR_INFERENCE_MIN_VALUES = 3
"""Don't attempt character-column coercion with fewer than this many non-empty
values; too small a sample is easy to mis-infer."""

NUMERIC_INT_RANGE = (-2_147_483_648, 2_147_483_647)
"""INTEGER bounds; anything outside becomes BIGINT."""

TYPE_INFERENCE_SAMPLE_ROWS: Optional[int] = None
"""Cap on rows inspected during per-column type inference. Also governs how
many rows :func:`read_sas_preview` pulls from the file for dry-run / validate /
schema-inference flows.

Default is ``None`` (scan every row, reading the whole file into memory for
the schema-inference step). That's the only honest setting for nullability:
any integer cap lets a column look ``NOT NULL`` across the first N rows
while the file actually holds rare nulls past the window, which then
detonates ``COPY`` mid-stream (seen in production on a 2.5M-row file where
~6k MAFIDs were null past the 10k-row preview). If you're loading a file
so large that a full read won't fit in memory, set this to an integer cap
and accept that sampled specs can't be trusted for ``NOT NULL``."""

DEFAULT_CHUNK_ROWS = 2_000_000
"""Rows per chunk when streaming a file into ``COPY``. Larger values mean
fewer COPY round-trips and lower per-row overhead but more peak memory per
chunk; smaller values are gentler on memory.

The chunk size can be overridden at runtime via the
``GENERIC_LOADER_CHUNK_ROWS`` environment variable (read inside
:func:`iter_sas_chunks` and :func:`iter_text_chunks`), so ``.env``-driven
overrides work without code changes. An explicit ``chunksize=`` kwarg still
wins over both the environment variable and this default."""


VALID_IF_EXISTS = ("fail", "replace", "append")

VALID_FILE_TYPES = ("sas", "text")
"""Supported ``file_type`` values in the YAML config."""

TEXT_EXTENSIONS = (".txt", ".csv", ".tsv")
"""File extensions recognised as delimited text files."""

_PG_IDENT_MAX_LEN = 63
"""PostgreSQL maximum identifier length in bytes (characters for ASCII)."""


# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------


@dataclass
class TextFileMetadata:
    """Minimal metadata object for text files, mimicking pyreadstat metadata.

    Provides the same attribute surface that :func:`infer_schema` reads from
    pyreadstat metadata objects: ``column_names``, ``column_labels``,
    ``original_variable_types``, and ``number_rows``.
    """
    column_names: List[str]
    column_labels: List[str]
    original_variable_types: Dict[str, str]
    number_rows: Optional[int] = None


@dataclass
class LoaderConfig:
    filename: Path
    schemaname: str
    tablename: str
    if_exists: str = "fail"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
    partition_by: List[str] = field(default_factory=list)
    max_partitions: int = 10_000
    indexes: List[str] = field(default_factory=list)
    column_types: Dict[str, str] = field(default_factory=dict)
    all_nullable: bool = False
    file_type: str = "sas"
    delimiter: str = ","
    text_encoding: str = "utf-8"
    quotechar: str = '"'


@dataclass
class ColumnSpec:
    name: str
    postgres_type: str
    nullable: bool
    sas_format: Optional[str] = None
    source_dtype: Optional[str] = None
    notes: List[str] = field(default_factory=list)
    sampled: bool = False
    """True when the type was inferred from a bounded preview rather than the
    full file. A sampled spec carries the usual sampling risks: a later chunk
    could contain a value that exceeds the inferred integer range, doesn't
    parse as the inferred type, or is null in a column the preview showed as
    non-null - all of which surface as mid-``COPY`` failures."""


# ---------------------------------------------------------------------------
# Custom exceptions
# ---------------------------------------------------------------------------


class TableExistsError(RuntimeError):
    """Raised when if_exists=fail and the target table already exists."""


class SchemaCompatibilityError(RuntimeError):
    """Raised when if_exists=append and the incoming schema is not
    compatible with the existing table."""


class ValidationError(RuntimeError):
    """Raised when --validate detects a mismatch against the manifest."""


# ---------------------------------------------------------------------------
# Connection
# ---------------------------------------------------------------------------


def connect(
    *,
    user: Optional[str] = None,
    password: Optional[str] = None,
) -> psycopg2.extensions.connection:
    """Open a psycopg2 connection using standard libpq env vars.

    Assumes ``.env`` has already been loaded (the CLI does this before calling).
    Orchestrators that wrap this module should either call ``load_dotenv()``
    themselves or ensure the env vars are set.

    ``user`` and ``password`` override the corresponding env vars when supplied
    (used by the ``--dbcreds`` CLI flag to accept interactive input).
    """
    conn = psycopg2.connect(
        host=os.environ.get("PGHOST"),
        port=os.environ.get("PGPORT"),
        user=user or os.environ.get("PGUSER"),
        password=password or os.environ.get("PGPASSWORD"),
        dbname=os.environ.get("PGDATABASE"),
    )
    return conn


# ---------------------------------------------------------------------------
# Config loading
# ---------------------------------------------------------------------------


def load_config(path: Path) -> LoaderConfig:
    """Parse and validate the YAML config at ``path``."""
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)

    if not isinstance(raw, dict):
        raise ValueError(f"Config at {path} must be a YAML mapping at the top level.")

    missing = [k for k in ("filename", "schemaname", "tablename") if k not in raw]
    if missing:
        raise ValueError(f"Config {path} missing required keys: {', '.join(missing)}")

    filename = Path(raw["filename"])
    if not filename.is_absolute():
        # Prefer the path relative to the config file's directory when it
        # exists; otherwise keep the path exactly as written.
        candidate = path.parent / filename
        if candidate.exists():
            filename = candidate.resolve()

    schemaname = str(raw["schemaname"])
    tablename = str(raw["tablename"])

    if_exists = str(raw.get("if_exists", "fail")).lower()
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(
            f"Config {path}: if_exists={if_exists!r} is not one of {VALID_IF_EXISTS}"
        )

    include = raw.get("include")
    exclude = raw.get("exclude")
    if include is not None and exclude is not None:
        raise ValueError(
            f"Config {path}: 'include' and 'exclude' are mutually exclusive; set at most one."
        )
    if include is not None and not isinstance(include, list):
        raise ValueError(f"Config {path}: 'include' must be a list of column names.")
    if exclude is not None and not isinstance(exclude, list):
        raise ValueError(f"Config {path}: 'exclude' must be a list of column names.")

    # -- partition_by -------------------------------------------------------
    raw_pb = raw.get("partition_by")
    if raw_pb is None or (isinstance(raw_pb, list) and len(raw_pb) == 0):
        partition_by: List[str] = []
    elif isinstance(raw_pb, str):
        if not raw_pb.strip():
            raise ValueError(f"Config {path}: 'partition_by' string must be non-empty.")
        partition_by = [raw_pb.strip()]
    elif isinstance(raw_pb, list):
        partition_by = []
        for i, item in enumerate(raw_pb):
            if not isinstance(item, str) or not item.strip():
                raise ValueError(
                    f"Config {path}: 'partition_by[{i}]' must be a non-empty string."
                )
            partition_by.append(str(item).strip())
        if len(partition_by) != len(set(partition_by)):
            raise ValueError(
                f"Config {path}: 'partition_by' contains duplicate column names."
            )
    else:
        raise ValueError(
            f"Config {path}: 'partition_by' must be a string or list of strings."
        )

    # Validate partition_by vs include/exclude
    if partition_by:
        inc_list = [str(c) for c in include] if include is not None else None
        exc_list = [str(c) for c in exclude] if exclude is not None else None
        if inc_list is not None:
            missing_in_include = [c for c in partition_by if c not in inc_list]
            if missing_in_include:
                raise ValueError(
                    f"Config {path}: 'include' omits partition_by columns: "
                    f"{missing_in_include}"
                )
        if exc_list is not None:
            excluded_parts = [c for c in partition_by if c in exc_list]
            if excluded_parts:
                raise ValueError(
                    f"Config {path}: 'exclude' removes partition_by columns: "
                    f"{excluded_parts}"
                )

    # -- max_partitions -----------------------------------------------------
    raw_mp = raw.get("max_partitions", 10_000)
    try:
        max_partitions = int(raw_mp)
    except (TypeError, ValueError):
        raise ValueError(
            f"Config {path}: 'max_partitions' must be a positive integer, "
            f"got {raw_mp!r}"
        )
    if max_partitions <= 0:
        raise ValueError(
            f"Config {path}: 'max_partitions' must be a positive integer, "
            f"got {max_partitions}"
        )

    # -- indexes ------------------------------------------------------------
    raw_idx = raw.get("indexes")
    if raw_idx is None or (isinstance(raw_idx, list) and len(raw_idx) == 0):
        indexes: List[str] = []
    elif isinstance(raw_idx, str):
        if not raw_idx.strip():
            raise ValueError(f"Config {path}: 'indexes' string must be non-empty.")
        indexes = [raw_idx.strip()]
    elif isinstance(raw_idx, list):
        indexes = []
        for i, item in enumerate(raw_idx):
            if not isinstance(item, str) or not item.strip():
                raise ValueError(
                    f"Config {path}: 'indexes[{i}]' must be a non-empty string."
                )
            indexes.append(str(item).strip())
        if len(indexes) != len(set(indexes)):
            raise ValueError(
                f"Config {path}: 'indexes' contains duplicate column names."
            )
    else:
        raise ValueError(
            f"Config {path}: 'indexes' must be a string or list of strings."
        )

    # Validate indexes vs include/exclude
    if indexes:
        inc_list = [str(c) for c in include] if include is not None else None
        exc_list = [str(c) for c in exclude] if exclude is not None else None
        if exc_list is not None:
            excluded_idx = [c for c in indexes if c in exc_list]
            if excluded_idx:
                raise ValueError(
                    f"Config {path}: 'exclude' removes index columns: "
                    f"{excluded_idx}"
                )
        if inc_list is not None:
            missing_in_include = [c for c in indexes if c not in inc_list]
            if missing_in_include:
                raise ValueError(
                    f"Config {path}: 'include' omits index columns: "
                    f"{missing_in_include}"
                )

    # -- column_types -------------------------------------------------------
    # Optional ``{column_name: pg_type}`` escape hatch that bypasses
    # automatic type inference for specific columns. Useful when
    # pyreadstat reports a column as NUM but the downstream consumer
    # expects TEXT (e.g. phone-id columns), or when a column has drifted
    # between CHAR and NUM across file versions and you want to pin
    # TEXT up front. See also :func:`infer_schema`.
    raw_ct = raw.get("column_types")
    column_types: Dict[str, str] = {}
    if raw_ct is not None:
        if not isinstance(raw_ct, dict):
            raise ValueError(
                f"Config {path}: 'column_types' must be a mapping of "
                f"{{column_name: postgres_type}}."
            )
        for k, v in raw_ct.items():
            key = str(k).strip()
            if not key:
                raise ValueError(
                    f"Config {path}: 'column_types' contains an empty "
                    f"column name."
                )
            if not isinstance(v, str) or not v.strip():
                raise ValueError(
                    f"Config {path}: 'column_types[{key}]' must be a "
                    f"non-empty Postgres type string (got {v!r})."
                )
            column_types[key] = v.strip()

    # -- all_nullable -------------------------------------------------------
    # When inference wrongly stamps a column NOT NULL (sampled rows happened
    # to be dense; later rows carry nulls) downstream COPYs fail mid-stream.
    # Set ``all_nullable: true`` in the YAML to stamp every column nullable
    # up front. The CLI flag ``--all-nullable`` overrides this to ``true``
    # if set.
    raw_an = raw.get("all_nullable", False)
    if not isinstance(raw_an, bool):
        raise ValueError(
            f"Config {path}: 'all_nullable' must be a boolean (got {raw_an!r})."
        )
    all_nullable = bool(raw_an)

    # -- file_type ----------------------------------------------------------
    file_type = str(raw.get("file_type", "sas")).lower()
    if file_type not in VALID_FILE_TYPES:
        raise ValueError(
            f"Config {path}: file_type={file_type!r} is not one of "
            f"{VALID_FILE_TYPES}"
        )

    # -- text-file-specific fields ------------------------------------------
    # Only validated when file_type == "text"; harmless defaults otherwise.
    raw_delim = raw.get("delimiter", ",")
    if isinstance(raw_delim, str):
        delim_lower = raw_delim.lower().strip()
        if delim_lower in ("tab", "\\t"):
            delimiter = "\t"
        elif delim_lower in ("pipe", "|"):
            delimiter = "|"
        else:
            delimiter = raw_delim
    else:
        delimiter = str(raw_delim)

    text_encoding = str(raw.get("text_encoding", "utf-8"))
    quotechar = str(raw.get("quotechar", '"'))

    return LoaderConfig(
        filename=filename,
        schemaname=schemaname,
        tablename=tablename,
        if_exists=if_exists,
        include=[str(c) for c in include] if include is not None else None,
        exclude=[str(c) for c in exclude] if exclude is not None else None,
        partition_by=partition_by,
        max_partitions=max_partitions,
        indexes=indexes,
        column_types=column_types,
        all_nullable=all_nullable,
        file_type=file_type,
        delimiter=delimiter,
        text_encoding=text_encoding,
        quotechar=quotechar,
    )


# ---------------------------------------------------------------------------
# Reader
# ---------------------------------------------------------------------------


def _is_text_file(path: Path) -> bool:
    """Return True if ``path`` has a recognised delimited-text extension."""
    return Path(path).suffix.lower() in TEXT_EXTENSIONS


def _sas_reader(path: Path) -> Tuple[Any, Dict[str, Any]]:
    """Return ``(pyreadstat_reader, extra_kwargs)`` for ``path``.

    Invariants (learned the hard way while building the sample generator):

    * ``.xpt`` / ``.xport`` - no encoding arg; pyreadstat is flaky about
      encoding on XPORT files it wrote itself.
    * ``.sas7bdat`` - explicit ``encoding="latin-1"`` per colleague guidance.
    """
    suffix = Path(path).suffix.lower()
    if suffix in (".xpt", ".xport"):
        return pyreadstat.read_xport, {}
    if suffix == ".sas7bdat":
        # Pass the encoding the docstring promises for .sas7bdat files.
        return pyreadstat.read_sas7bdat, {"encoding": "latin-1"}
    raise ValueError(f"Unsupported SAS file extension: {suffix}")


# ---------------------------------------------------------------------------
# Text file readers
# ---------------------------------------------------------------------------


def _count_text_lines(path: Path, encoding: str = "utf-8") -> int:
    """Count data rows in a text file (excludes the header line).

    Reads the file in binary chunks for speed; counts newlines, adds one for
    an unterminated final line, and subtracts one for the header.
    """
    count = 0
    last_byte = b""
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            count += chunk.count(b"\n")
            last_byte = chunk[-1:]
    # If the file doesn't end with a newline the last line is still a row.
    if last_byte and last_byte != b"\n":
        count += 1
    # The first line is the header, so subtract 1.
    # Edge case: empty file or header-only -> 0 rows.
    return max(0, count - 1)


def _build_text_metadata(
    column_names: List[str],
    number_rows: Optional[int] = None,
) -> TextFileMetadata:
    """Build a :class:`TextFileMetadata` from column names and an optional
    row count."""
    return TextFileMetadata(
        column_names=list(column_names),
        column_labels=list(column_names),
        original_variable_types={},
        number_rows=number_rows,
    )


def read_text(
    path: Path,
    delimiter: str = ",",
    encoding: str = "utf-8",
    quotechar: str = '"',
) -> Tuple[pd.DataFrame, TextFileMetadata]:
    """Read an entire delimited text file into memory.

    Returns ``(DataFrame, TextFileMetadata)``; the metadata object carries
    the same attributes that :func:`infer_schema` reads from pyreadstat
    metadata.
    """
    path = Path(path)
    df = pd.read_csv(
        path,
        delimiter=delimiter,
        encoding=encoding,
        quotechar=quotechar,
        dtype=str,
        keep_default_na=True,
        na_values=[""],
    )
    meta = _build_text_metadata(list(df.columns), number_rows=len(df))
    return df, meta


def read_text_preview(
    path: Path,
    delimiter: str = ",",
    encoding: str = "utf-8",
    quotechar: str = '"',
    rows: Optional[int] = None,
) -> Tuple[pd.DataFrame, TextFileMetadata]:
    """Read the first ``rows`` records from a delimited text file.

    When ``rows`` is ``None`` or 0, reads the entire file (matching the
    semantics of :func:`read_sas_preview`).
    """
    path = Path(path)
    nrows = int(rows) if rows else None
    df = pd.read_csv(
        path,
        delimiter=delimiter,
        encoding=encoding,
        quotechar=quotechar,
        nrows=nrows,
        dtype=str,
        keep_default_na=True,
        na_values=[""],
    )
    # For total row count, do a fast line count when we only read a preview.
    if nrows is not None and nrows > 0:
        total = _count_text_lines(path, encoding)
    else:
        total = len(df)
    meta = _build_text_metadata(list(df.columns), number_rows=total)
    return df, meta


def read_text_metadata(
    path: Path,
    delimiter: str = ",",
    encoding: str = "utf-8",
    quotechar: str = '"',
) -> TextFileMetadata:
    """Read only the header and line count from a delimited text file.

    Fast path: reads the first line for column names and counts newlines
    for the row total without materializing a DataFrame.
    """
    path = Path(path)
    # Read just the header row.
    df_header = pd.read_csv(
        path,
        delimiter=delimiter,
        encoding=encoding,
        quotechar=quotechar,
        nrows=0,
    )
    column_names = list(df_header.columns)
    total = _count_text_lines(path, encoding)
    return _build_text_metadata(column_names, number_rows=total)


def iter_text_chunks(
    path: Path,
    delimiter: str = ",",
    encoding: str = "utf-8",
    quotechar: str = '"',
    chunksize: Optional[int] = None,
    usecols: Optional[List[str]] = None,
):
    """Yield ``(df_chunk, meta)`` tuples for streaming text file loads.

    Uses ``pandas.read_csv()`` with ``chunksize`` for memory-efficient
    iteration. The metadata object is rebuilt for each chunk with the
    chunk's column names and ``number_rows`` set to the total file rows
    (computed once up front).

    When ``usecols`` is provided, only those columns are parsed - useful
    for cheap partition-value discovery scans where the rest of the row
    would be wasted I/O.
    """
    path = Path(path)
    if chunksize is None:
        raw_env = os.environ.get("GENERIC_LOADER_CHUNK_ROWS")
        if raw_env is not None:
            try:
                chunksize = int(raw_env)
            except ValueError:
                chunksize = DEFAULT_CHUNK_ROWS
        else:
            chunksize = DEFAULT_CHUNK_ROWS

    total = _count_text_lines(path, encoding)

    read_csv_kwargs: Dict[str, Any] = dict(
        delimiter=delimiter,
        encoding=encoding,
        quotechar=quotechar,
        chunksize=chunksize,
        dtype=str,
        keep_default_na=True,
        na_values=[""],
    )
    if usecols is not None:
        read_csv_kwargs["usecols"] = list(usecols)

    reader = pd.read_csv(path, **read_csv_kwargs)
    for chunk_df in reader:
        meta = _build_text_metadata(list(chunk_df.columns), number_rows=total)
        yield chunk_df, meta


# ---------------------------------------------------------------------------
# Unified reader dispatch
# ---------------------------------------------------------------------------


def read_sas(
    path: Path,
    *,
    delimiter: str = ",",
    text_encoding: str = "utf-8",
    quotechar: str = '"',
) -> Tuple[pd.DataFrame, Any]:
    """Read an entire SAS or delimited text file into memory.

    For SAS files (``.sas7bdat``, ``.xpt``, ``.xport``), delegates to
    pyreadstat. For text files (``.txt``, ``.csv``, ``.tsv``), delegates
    to :func:`read_text`. The text-specific parameters are ignored for SAS
    files.

    Kept for backward compatibility and tests; the CLI now uses
    :func:`read_sas_preview` + :func:`iter_sas_chunks` so it never materializes
    the whole frame at once.
    """
    if _is_text_file(path):
        return read_text(path, delimiter=delimiter, encoding=text_encoding, quotechar=quotechar)
    reader, kwargs = _sas_reader(path)
    return reader(str(Path(path)), **kwargs)


def read_sas_preview(
    path: Path,
    *,
    rows: Optional[int] = None,
    delimiter: str = ",",
    text_encoding: str = "utf-8",
    quotechar: str = '"',
) -> Tuple[pd.DataFrame, Any]:
    """Read the first ``rows`` records from ``path`` plus its metadata.

    Defaults to ``TYPE_INFERENCE_SAMPLE_ROWS`` when ``rows`` is not given.
    Passing ``rows=None`` with ``TYPE_INFERENCE_SAMPLE_ROWS=None`` reads the
    whole file (pyreadstat treats ``row_limit=0`` as unlimited).

    For text files, delegates to :func:`read_text_preview`.
    """
    effective = rows if rows is not None else TYPE_INFERENCE_SAMPLE_ROWS
    if _is_text_file(path):
        return read_text_preview(
            path,
            delimiter=delimiter,
            encoding=text_encoding,
            quotechar=quotechar,
            rows=effective,
        )
    reader, kwargs = _sas_reader(path)
    row_limit = int(effective) if effective else 0
    return reader(str(Path(path)), row_limit=row_limit, **kwargs)


def read_sas_metadata(
    path: Path,
    *,
    delimiter: str = ",",
    text_encoding: str = "utf-8",
    quotechar: str = '"',
) -> Any:
    """Read only the metadata (no rows) from a SAS or text file.

    Uses pyreadstat's ``metadataonly=True`` fast path for SAS files: the
    reader decodes the file header (column names, formats, total row count,
    etc.) and returns without touching the data pages. Orders of magnitude
    faster than :func:`read_sas_preview` when all you need is
    ``meta.number_rows`` - typically a few ms per sas7bdat file, which
    makes it cheap to pre-scan a whole folder to populate a global
    progress bar.

    For text files, delegates to :func:`read_text_metadata`.
    """
    if _is_text_file(path):
        return read_text_metadata(
            path, delimiter=delimiter, encoding=text_encoding, quotechar=quotechar,
        )
    reader, kwargs = _sas_reader(path)
    _, meta = reader(str(Path(path)), metadataonly=True, **kwargs)
    return meta


def iter_sas_chunks(
    path: Path,
    *,
    chunksize: Optional[int] = None,
    delimiter: str = ",",
    text_encoding: str = "utf-8",
    quotechar: str = '"',
    usecols: Optional[List[str]] = None,
):
    """Yield ``(df_chunk, meta)`` tuples for streaming loads.

    Thin wrapper over ``pyreadstat.read_file_in_chunks`` that picks the right
    underlying reader by extension and threads through our encoding defaults.

    When ``chunksize`` is ``None`` (the default), the effective value comes
    from the ``GENERIC_LOADER_CHUNK_ROWS`` environment variable if set and
    parseable, otherwise from :data:`DEFAULT_CHUNK_ROWS`. An explicit int
    always wins.

    When ``usecols`` is provided, pyreadstat only decodes the listed
    columns. For wide sas7bdat files this is dramatically cheaper than
    a full read - the C decoder skips unwanted columns instead of
    materializing them. Used by partition-value discovery to avoid
    re-reading every byte of every file just to extract a couple of
    partition keys.

    For text files, delegates to :func:`iter_text_chunks`.
    """
    if _is_text_file(path):
        yield from iter_text_chunks(
            path,
            delimiter=delimiter,
            encoding=text_encoding,
            quotechar=quotechar,
            chunksize=chunksize,
            usecols=usecols,
        )
        return
    if chunksize is None:
        raw = os.environ.get("GENERIC_LOADER_CHUNK_ROWS")
        if raw is not None:
            try:
                chunksize = int(raw)
            except ValueError:
                chunksize = DEFAULT_CHUNK_ROWS
        else:
            chunksize = DEFAULT_CHUNK_ROWS
    reader, kwargs = _sas_reader(path)
    if usecols is not None:
        kwargs = dict(kwargs)
        kwargs["usecols"] = list(usecols)
    yield from pyreadstat.read_file_in_chunks(
        reader, str(Path(path)), chunksize=chunksize, **kwargs
    )


# ---------------------------------------------------------------------------
# Column filtering
# ---------------------------------------------------------------------------


def apply_column_filter(
    df: pd.DataFrame,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
) -> pd.DataFrame:
    """Restrict ``df`` to the requested columns. Names missing from the frame
    raise a clear error rather than silently dropping.

    Returns the input frame (or a column-sliced view / drop result) without
    an extra ``.copy()``: downstream (:func:`_prepare_for_copy`) reads the
    frame into a freshly built output and never mutates its input, so the
    copies were pure overhead on every streamed chunk.
    """
    if include is not None and exclude is not None:
        raise ValueError("include and exclude are mutually exclusive.")

    if include is not None:
        missing = [c for c in include if c not in df.columns]
        if missing:
            raise ValueError(f"include references unknown columns: {missing}")
        return df.loc[:, list(include)]

    if exclude is not None:
        missing = [c for c in exclude if c not in df.columns]
        if missing:
            raise ValueError(f"exclude references unknown columns: {missing}")
        return df.drop(columns=list(exclude))

    return df


# ---------------------------------------------------------------------------
# Type inference
# ---------------------------------------------------------------------------


_DATE_FORMAT_PREFIXES = ("DATE", "YYMMDD", "MMDDYY", "DDMMYY", "JULIAN")


def _format_driven_type(sas_format: Optional[str]) -> Optional[str]:
    """Return a Postgres type inferred from the SAS format string, or None
    if the format doesn't pin it down."""
    if not sas_format:
        return None
    fmt = sas_format.upper().lstrip()
    # DATETIME must be checked before DATE since "DATETIME20." starts with "DATE".
    if fmt.startswith("DATETIME"):
        return "TIMESTAMP"
    if fmt.startswith("TIME"):
        return "TIME"
    for prefix in _DATE_FORMAT_PREFIXES:
        if fmt.startswith(prefix):
            return "DATE"
    return None


_DECIMAL_FORMAT_RE = re.compile(r"\.(\d+)")


def _format_hints_decimal(sas_format: Optional[str]) -> bool:
    """True if a numeric SAS format string explicitly carries decimal places.

    SAS numeric formats are ``NAMEw.d``; ``d > 0`` means the variable was
    intended to render with ``d`` decimal digits (COMMA10.2, F8.3, ...).
    A bare width like ``BEST12.`` or ``F8.`` has no digits after the dot
    and is treated as integer-presenting. This is a display hint only:
    :func:`union_column_types` maps every non-date numeric column to
    ``DOUBLE PRECISION`` whether or not the hint is present.
    """
    if not sas_format:
        return False
    m = _DECIMAL_FORMAT_RE.search(sas_format)
    if not m:
        return False
    try:
        return int(m.group(1)) > 0
    except ValueError:
        return False


def extract_union_metadata(
    meta: Any,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Pull the (readstat_type, sas_format) pair for every column in ``meta``.

    Returns a plain dict that's safe to pass between processes and to
    :func:`union_column_types`. ``readstat_type`` is the simplified type
    reported by pyreadstat: ``"string"`` for SAS CHAR, ``"double"`` for
    SAS NUM. ``sas_format`` comes from ``meta.original_variable_types``
    and drives date/datetime detection during union.
    """
    var_types = dict(getattr(meta, "variable_types", None) or {})
    formats = dict(getattr(meta, "original_variable_types", None) or {})
    names = list(
        getattr(meta, "column_names", None)
        or list(var_types.keys())
        or list(formats.keys())
    )
    out: Dict[str, Tuple[str, Optional[str]]] = {}
    for col in names:
        rtype = str(var_types.get(col, "")) if var_types else ""
        fmt = formats.get(col)
        out[col] = (rtype, fmt if fmt else None)
    return out


def union_column_types(
    per_file_metas: Iterable[Dict[str, Tuple[str, Optional[str]]]],
) -> Dict[str, str]:
    """Derive one Postgres type per column that's safe across every file.

    ``per_file_metas`` is an iterable (one entry per file in a cluster) of
    ``{column_name: (readstat_type, sas_format)}`` dicts as produced by
    :func:`extract_union_metadata`.

    Rules, evaluated per column:

    * **CHAR/NUM drift wins TEXT.** If any file stores the column as CHAR
      (``readstat_type != "double"``) the union is ``TEXT``. This covers
      the phone-id case where some years stored ``RESP_PH_PREFIX_ID`` as
      CHAR and others as NUM.
    * **All NUM, format hints DATETIME → TIMESTAMP.** Any file whose
      format resolves to ``TIMESTAMP`` (via :func:`_format_driven_type`)
      pins the column to ``TIMESTAMP`` even if other files left the
      format blank.
    * **All NUM, format hints DATE → DATE.** Same idea for date-only
      formats.
    * **All NUM, any decimal hint → DOUBLE PRECISION.** A ``w.d`` format
      with ``d > 0`` in any file implies fractional values somewhere.
    * **All NUM, no useful hint → DOUBLE PRECISION.** SAS numeric
      formats are *display* formats, not storage constraints - a
      ``BEST12.`` / ``F8.`` / blank-format column can still hold floats,
      and pyreadstat hands back plain ``float64`` regardless. Defaulting
      to ``DOUBLE PRECISION`` here costs the same 8 bytes as ``BIGINT``
      but can't fail on real data. For columns that truly are
      integer-only and you want ``BIGINT`` semantics in queries, pin
      them via a ``column_types`` override.

    Columns missing from a given file are simply skipped for that file;
    the union is computed over whichever files *did* supply the column.
    Columns that never appear anywhere are omitted from the result.
    """
    per_col: Dict[str, List[Tuple[str, Optional[str]]]] = {}
    for meta in per_file_metas:
        for col, pair in meta.items():
            per_col.setdefault(col, []).append(pair)

    result: Dict[str, str] = {}
    for col, entries in per_col.items():
        any_char = any(
            rtype and rtype.lower() != "double" for rtype, _ in entries
        )
        if any_char:
            result[col] = "TEXT"
            continue
        formats = [fmt for _, fmt in entries if fmt]
        driven = [_format_driven_type(f) for f in formats]
        if "TIMESTAMP" in driven:
            result[col] = "TIMESTAMP"
        elif "DATE" in driven:
            result[col] = "DATE"
        else:
            # Safe default: DOUBLE PRECISION. The BIGINT default we tried
            # first failed the moment a file contained a fractional
            # value in a column whose format didn't carry a decimal
            # hint (very common: SAS ``BEST12.`` / ``F8.`` are display
            # formats, not storage constraints, so the underlying
            # 8-byte float can hold any value). Same storage cost as
            # BIGINT, handles both integer- and float-valued data, and
            # keeps loads from failing mid-cluster. Use a
            # ``column_types`` override to pin specific columns to
            # ``BIGINT`` when you want integer semantics in queries.
            result[col] = "DOUBLE PRECISION"
    return result


def _all_null(series: pd.Series) -> bool:
    """True if every value is missing (``None``, NaN, or an empty string
    for object columns)."""
    if pd.api.types.is_object_dtype(series):
        return bool(
            series.map(
                lambda v: v is None
                or (isinstance(v, str) and v == "")
                or (isinstance(v, float) and pd.isna(v))
            ).all()
        )
    return bool(series.isna().all())


def _char_missing_mask(series: pd.Series) -> pd.Series:
    """Boolean mask marking ``None``, float NaN, and empty-string values."""
    return series.map(
        lambda v: v is None
        or (isinstance(v, float) and pd.isna(v))
        or (isinstance(v, str) and v == "")
    )


def _is_nullable(series: pd.Series) -> bool:
    """True if the column has at least one missing value."""
    if pd.api.types.is_object_dtype(series):
        return bool(_char_missing_mask(series).any())
    return bool(series.isna().any())


def _numeric_int_target(series: pd.Series) -> Optional[str]:
    """Given a numeric (float64) series, if every non-null value is a whole
    number, return INTEGER or BIGINT depending on range; else None."""
    nonnull = series.dropna()
    if nonnull.empty:
        return None
    # Whole-number test. Guard against inf.
    try:
        whole = ((nonnull % 1) == 0).all()
    except TypeError:
        return None
    if not whole:
        return None
    lo, hi = NUMERIC_INT_RANGE
    vmin = nonnull.min()
    vmax = nonnull.max()
    if lo <= vmin and vmax <= hi:
        return "INTEGER"
    return "BIGINT"


def _object_is_dates(series: pd.Series) -> Tuple[bool, bool]:
    """Return (all-date-like, any-datetime). If every non-null value is a
    ``datetime.date`` / ``datetime.datetime`` / ``pd.Timestamp``, return True
    plus whether at least one is a ``datetime`` / ``Timestamp`` rather than
    a plain ``date`` (the time-of-day itself is not inspected)."""
    nonnull = series.dropna()
    if nonnull.empty:
        return False, False
    any_datetime = False
    for v in nonnull:
        # ``datetime`` must be tested first: it is a subclass of ``date``.
        if isinstance(v, (dt.datetime, pd.Timestamp)):
            any_datetime = True
            continue
        if isinstance(v, dt.date):
            continue
        return False, False
    return True, any_datetime


def _try_int_coerce(values: List[str]) -> Optional[str]:
    """If every value parses as an int, return INTEGER/BIGINT, else None."""
    ints: List[int] = []
    for v in values:
        s = v.strip()
        try:
            ints.append(int(s))
        except ValueError:
            return None
    if not ints:
        return None
    lo, hi = NUMERIC_INT_RANGE
    if all(lo <= i <= hi for i in ints):
        return "INTEGER"
    return "BIGINT"


def _try_float_coerce(values: List[str]) -> bool:
    """True if every value parses with ``float()``."""
    for v in values:
        try:
            float(v)
        except ValueError:
            return False
    return True


# Locale-independent month lookup so ``DD-MON-YY`` / ``DDMONYYYY`` style
# strings (Oracle's default ``DD-MON-YY`` export, SAS ``DATE7.`` /
# ``DATE9.`` rendered to text, spreadsheets spitting out ``23-Mar-2020``)
# parse correctly regardless of the host's ``LC_TIME``. ``strptime("%b")``
# is locale-dependent and silently fails on non-English systems; this
# dict sidesteps that entirely.
_MONTH_LOOKUP: Dict[str, int] = {
    "JAN": 1, "FEB": 2, "MAR": 3, "APR": 4, "MAY": 5, "JUN": 6,
    "JUL": 7, "AUG": 8, "SEP": 9, "SEPT": 9, "OCT": 10, "NOV": 11, "DEC": 12,
    "JANUARY": 1, "FEBRUARY": 2, "MARCH": 3, "APRIL": 4, "JUNE": 6,
    "JULY": 7, "AUGUST": 8, "SEPTEMBER": 9, "OCTOBER": 10,
    "NOVEMBER": 11, "DECEMBER": 12,
}

# ``DD[sep]MON[sep]YY`` with an optional ``HH:MM[:SS[.ffff]] [AM|PM]``
# suffix. ``sep`` can be ``-``, ``/``, space, or empty so the same
# regex covers ``23-MAR-20``, ``23-MAR-2020``, ``23MAR2020`` (SAS
# ``DATE9.``), ``23 Mar 2020`` (Excel), and ``23-MAR-20 14:30:00``
# (Oracle ``TO_CHAR`` default with timestamp). Time portion is lenient
# on separator (``:`` or ``.``) since Oracle's default timestamp
# rendering uses dots (``02.30.45.123456``) while most others use
# colons.
_DDMONYY_RE = re.compile(
    r"""
    ^\s*
    (?P<day>\d{1,2})
    [-/\s]?
    (?P<month>[A-Za-z]{3,9})
    [-/\s]?
    (?P<year>\d{2}|\d{4})
    (?:
        [\sT:]+
        (?P<hour>\d{1,2}) [:.] (?P<minute>\d{2})
        (?:
            [:.] (?P<second>\d{2})
            (?: \. (?P<micro>\d+) )?
        )?
        \s*
        (?P<ampm>[AaPp][Mm])?
    )?
    \s*$
    """,
    re.VERBOSE,
)

# Strptime fallbacks for all-numeric shapes the regex above can't
# disambiguate. Order matters: unambiguous 4-digit-year layouts first,
# then US-style ``mm/dd`` before EU-style ``dd/mm`` (the former is
# dominant in the kinds of exports this loader sees). Columns whose
# true format is ``DD/MM/YY`` should pin the Postgres type via
# ``column_types: {col: TEXT}`` and parse themselves downstream.
_EXTRA_DATE_FORMATS: Tuple[str, ...] = (
    "%Y/%m/%d",
    "%Y%m%d",
    "%m/%d/%Y",
    "%m/%d/%y",
    "%m-%d-%Y",
    "%m-%d-%y",
    "%d/%m/%Y",
    "%d/%m/%y",
    "%d-%m-%Y",
    "%d-%m-%y",
)

_EXTRA_DATETIME_FORMATS: Tuple[str, ...] = (
    "%Y-%m-%d %H:%M:%S",
    "%Y-%m-%d %H:%M:%S.%f",
    "%Y-%m-%dT%H:%M:%S",
    "%Y-%m-%dT%H:%M:%S.%f",
    "%m/%d/%Y %H:%M:%S",
    "%m/%d/%Y %H:%M",
    "%m/%d/%y %H:%M:%S",
    "%m/%d/%y %H:%M",
    "%d/%m/%Y %H:%M:%S",
    "%d/%m/%y %H:%M:%S",
    "%Y/%m/%d %H:%M:%S",
)


def _parse_flexible_date(value: Any) -> Optional[dt.date]:
    """Parse ``value`` to ``datetime.date`` using ISO first, then the
    ``DD-MON-YY`` family, then the numeric fallbacks in
    :data:`_EXTRA_DATE_FORMATS`. Returns ``None`` if nothing matches.

    Non-string / empty / non-finite inputs return ``None`` rather than
    raising so callers can use this as a drop-in replacement for the old
    ``dt.date.fromisoformat`` + ``try``/``except`` pattern.
    """
    if value is None:
        return None
    if not isinstance(value, str):
        return None
    s = value.strip()
    if not s:
        return None
    try:
        return dt.date.fromisoformat(s)
    except (ValueError, TypeError):
        pass
    m = _DDMONYY_RE.match(s)
    # Reject inputs that carry a time component so ``_try_date_coerce``
    # doesn't silently swallow ``TIMESTAMP`` columns (``23-MAR-20 14:30:00``)
    # and misclassify them as ``DATE``.
    if m and m.group("hour") is None:
        month = _MONTH_LOOKUP.get(m.group("month").upper())
        if month is not None:
            try:
                day = int(m.group("day"))
                year = int(m.group("year"))
                if len(m.group("year")) == 2:
                    # Pivot year = 69 follows the POSIX / Python ``%y``
                    # convention, so this branch agrees with the strptime
                    # fallbacks below: ``00..68`` -> 2000s, ``69..99`` -> 1900s.
                    year = 2000 + year if year < 69 else 1900 + year
                return dt.date(year, month, day)
            except ValueError:
                return None
    for fmt in _EXTRA_DATE_FORMATS:
        try:
            return dt.datetime.strptime(s, fmt).date()
        except ValueError:
            continue
    return None


def _parse_flexible_datetime(value: Any) -> Optional[dt.datetime]:
    """Parse ``value`` to ``datetime.datetime``. Same format coverage as
    :func:`_parse_flexible_date` plus explicit datetime shapes; a
    date-only input is promoted to midnight so callers can treat a
    column that mixes ``23-MAR-20`` and ``23-MAR-20 14:30:00`` as
    ``TIMESTAMP`` end-to-end.
    """
    if value is None:
        return None
    if not isinstance(value, str):
        return None
    s = value.strip()
    if not s:
        return None
    try:
        return dt.datetime.fromisoformat(s)
    except (ValueError, TypeError):
        pass
    m = _DDMONYY_RE.match(s)
    if m:
        month = _MONTH_LOOKUP.get(m.group("month").upper())
        if month is not None:
            try:
                day = int(m.group("day"))
                year = int(m.group("year"))
                if len(m.group("year")) == 2:
                    year = 2000 + year if year < 69 else 1900 + year
                hour = int(m.group("hour")) if m.group("hour") else 0
                minute = int(m.group("minute")) if m.group("minute") else 0
                second = int(m.group("second")) if m.group("second") else 0
                micro = 0
                if m.group("micro"):
                    # ``datetime`` takes integer microseconds, so pad or
                    # truncate the fraction to exactly six digits.
                    micro_s = m.group("micro")[:6].ljust(6, "0")
                    micro = int(micro_s)
                ampm = m.group("ampm")
                if ampm:
                    ap = ampm.upper()
                    if ap == "PM" and hour < 12:
                        hour += 12
                    elif ap == "AM" and hour == 12:
                        hour = 0
                return dt.datetime(year, month, day, hour, minute, second, micro)
            except ValueError:
                return None
    for fmt in _EXTRA_DATETIME_FORMATS:
        try:
            return dt.datetime.strptime(s, fmt)
        except ValueError:
            continue
    # Final fallback: accept a date-only string and promote to midnight.
    d = _parse_flexible_date(s)
    if d is not None:
        return dt.datetime(d.year, d.month, d.day)
    return None


def _try_date_coerce(values: List[str]) -> bool:
    """True if every value parses via :func:`_parse_flexible_date`."""
    for v in values:
        if _parse_flexible_date(v) is None:
            return False
    return True


def _try_datetime_coerce(values: List[str]) -> bool:
    """True if every value parses via :func:`_parse_flexible_datetime`."""
    for v in values:
        if _parse_flexible_datetime(v) is None:
            return False
    return True


def _infer_char_type(series: pd.Series) -> str:
    """Object/string column inference. Returns a Postgres type string.

    Probes run narrowest first: integer, float, date, then timestamp;
    anything that fails all four stays ``TEXT``.
    """
    mask = _char_missing_mask(series)
    nonempty = [str(v) for v in series[~mask].tolist()]
    if not COERCE_CHAR_COLUMNS or len(nonempty) < CHAR_INFERENCE_MIN_VALUES:
        return "TEXT"

    int_guess = _try_int_coerce(nonempty)
    if int_guess is not None:
        return int_guess
    if _try_float_coerce(nonempty):
        return "DOUBLE PRECISION"
    if _try_date_coerce(nonempty):
        return "DATE"
    if _try_datetime_coerce(nonempty):
        return "TIMESTAMP"
    return "TEXT"


def infer_schema(
    df: pd.DataFrame,
    meta: Any,
    *,
    coerce_chars: bool = COERCE_CHAR_COLUMNS,
    total_rows: Optional[int] = None,
    column_types: Optional[Dict[str, str]] = None,
    force_nullable: bool = False,
) -> Dict[str, ColumnSpec]:
    """Infer a Postgres column spec for each column in ``df``.

    ``meta`` is the pyreadstat metadata object; we read
    ``meta.original_variable_types`` (a dict keyed by column name) for
    format-driven date/time/timestamp inference.

    The ``coerce_chars`` kwarg lets callers override the module-level
    ``COERCE_CHAR_COLUMNS`` for the duration of one call. The
    char-inference helpers still read the constant, so this function swaps
    the global in and restores it in a ``finally`` block; concurrent
    callers in the same process can observe the temporary value. A full
    override would thread the flag through, but the one-knob story here
    is intentional.

    ``total_rows`` lets callers who already sampled the frame (e.g. via
    :func:`read_sas_preview`) report the real file size in the per-column
    "inferred from first N of M rows" note. Falls back to ``len(df)``.

    ``column_types`` is an optional map ``{column_name: pg_type_str}``
    whose entries bypass inference entirely - the caller has already
    decided the type (e.g. via :func:`union_column_types` across a
    cluster, or a YAML ``column_types`` override). Nullability is still
    computed from the data. Columns in ``column_types`` that don't exist
    in ``df`` are ignored so a shared override dict can apply to clusters
    with different column sets.

    ``force_nullable=True`` stamps every column nullable regardless of
    what the data sample shows. Escape hatch for when inference marks a
    column ``NOT NULL`` because the sampled rows happened to be dense but
    downstream files carry nulls in that column - common with cluster
    loads where one file's preview can't speak for the rest. Cheaper than
    trying to sharpen the sampler: widen the column and move on.
    """
    original_formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {})

    # When ``TYPE_INFERENCE_SAMPLE_ROWS`` is an integer cap, row-walking type
    # probes run on the head slice for speed; nullability and the all-null
    # check still walk every row of ``df``. That's only honest when the
    # caller handed us the full file - with the default cap of ``None`` the
    # CLI does exactly that. Callers who pass a partial preview and a tight
    # integer cap accept that ``NOT NULL`` can be wrong for rare-null columns.
    df_rows = len(df)
    effective_total = total_rows if total_rows is not None else df_rows
    if TYPE_INFERENCE_SAMPLE_ROWS is not None and df_rows > TYPE_INFERENCE_SAMPLE_ROWS:
        sample_df = df.head(TYPE_INFERENCE_SAMPLE_ROWS)
        sample_size = TYPE_INFERENCE_SAMPLE_ROWS
    else:
        sample_df = df
        sample_size = df_rows
    sampled = sample_size < effective_total

    overrides: Dict[str, str] = dict(column_types or {})

    # Temporarily flip the module-level flag if the caller asked us to.
    global COERCE_CHAR_COLUMNS
    saved = COERCE_CHAR_COLUMNS
    COERCE_CHAR_COLUMNS = coerce_chars
    try:
        out: Dict[str, ColumnSpec] = {}
        for col in df.columns:
            series = df[col]
            sample_series = sample_df[col]
            sas_format = original_formats.get(col)
            notes: List[str] = []

            if col in overrides:
                pg_type = overrides[col]
                notes.append(
                    f"type forced to {pg_type} via column_types override"
                )
                if force_nullable:
                    nullable = True
                    notes.append("nullable forced via --all-nullable")
                else:
                    nullable = _is_nullable(series)
                out[col] = ColumnSpec(
                    name=col,
                    postgres_type=pg_type,
                    nullable=nullable,
                    sas_format=sas_format,
                    source_dtype=str(series.dtype),
                    notes=notes,
                    sampled=sampled,
                )
                continue

            pg_type = _format_driven_type(sas_format)

            if pg_type is None:
                if _all_null(series):
                    pg_type = "TEXT"
                    notes.append("all-null column; defaulting to TEXT")
                elif pd.api.types.is_datetime64_any_dtype(series):
                    pg_type = "TIMESTAMP"
                elif pd.api.types.is_object_dtype(series):
                    is_dates, any_dt = _object_is_dates(sample_series)
                    if is_dates:
                        pg_type = "TIMESTAMP" if any_dt else "DATE"
                    else:
                        pg_type = _infer_char_type(sample_series)
                elif pd.api.types.is_numeric_dtype(series):
                    int_target = _numeric_int_target(sample_series)
                    if int_target is not None:
                        pg_type = int_target
                    else:
                        pg_type = "DOUBLE PRECISION"
                else:
                    pg_type = "TEXT"
                    notes.append(f"unhandled dtype {series.dtype}; defaulting to TEXT")

            if sampled:
                notes.append(
                    f"type inferred from first {sample_size:,} of "
                    f"{effective_total:,} rows"
                )

            if force_nullable:
                nullable = True
                notes.append("nullable forced via --all-nullable")
            else:
                nullable = _is_nullable(series)

            out[col] = ColumnSpec(
                name=col,
                postgres_type=pg_type,
                nullable=nullable,
                sas_format=sas_format,
                source_dtype=str(series.dtype),
                notes=notes,
                sampled=sampled,
            )
        return out
    finally:
        COERCE_CHAR_COLUMNS = saved


# ---------------------------------------------------------------------------
# Table management
# ---------------------------------------------------------------------------


def _quote_ident(ident: str) -> str:
    """Quote a Postgres identifier by doubling embedded double quotes.
    psycopg2 offers ``sql.Identifier`` (2.7+) for this; we do it by hand
    to stay driver-simple."""
    return '"' + ident.replace('"', '""') + '"'


def _qualified(schema: str, table: str) -> str:
    """Return the quoted ``"schema"."table"`` name."""
    return f"{_quote_ident(schema)}.{_quote_ident(table)}"


def _table_exists(conn, schema: str, table: str) -> bool:
    """True if ``information_schema.tables`` lists ``schema.table``."""
    with conn.cursor() as cur:
        cur.execute(
            "SELECT 1 FROM information_schema.tables "
            "WHERE table_schema = %s AND table_name = %s",
            (schema, table),
        )
        return cur.fetchone() is not None


def render_create_table(
    schema: str,
    table: str,
    columns: Dict[str, ColumnSpec],
    *,
    partition_by: Optional[List[str]] = None,
) -> str:
    """Render a ``CREATE TABLE`` statement.

    When ``partition_by`` is provided and non-empty, appends a
    ``PARTITION BY LIST ("first_field")`` clause to the statement.
    """
    lines = []
    for spec in columns.values():
        null_clause = "" if spec.nullable else " NOT NULL"
        lines.append(f" {_quote_ident(spec.name)} {spec.postgres_type}{null_clause}")
    body = ",\n".join(lines)
    suffix = ""
    if partition_by:
        suffix = f"\nPARTITION BY LIST ({_quote_ident(partition_by[0])})"
    return f"CREATE TABLE {_qualified(schema, table)} (\n{body}\n){suffix};"


def _create_table_sql(
    conn,
    schema: str,
    table: str,
    columns: Dict[str, ColumnSpec],
    *,
    partition_by: Optional[List[str]] = None,
) -> None:
    """Execute a ``CREATE TABLE`` statement, optionally with partitioning."""
    sql = render_create_table(schema, table, columns, partition_by=partition_by)
    with conn.cursor() as cur:
        cur.execute(sql)


def _drop_table(conn, schema: str, table: str, *, cascade: bool = False) -> None:
    """Drop a table, optionally with CASCADE for partitioned tables."""
    tail = " CASCADE" if cascade else ""
    with conn.cursor() as cur:
        cur.execute(f"DROP TABLE {_qualified(schema, table)}{tail}")


# Normalization table: map both loader-emitted and Postgres-reported type
# strings to a single canonical family name. Ignore length/precision
# modifiers like VARCHAR(n) and NUMERIC(p,s).
_TYPE_NORMALIZATION: Dict[str, str] = {
    "INTEGER": "integer",
    "INT": "integer",
    "INT4": "integer",
    "BIGINT": "bigint",
    "INT8": "bigint",
    "SMALLINT": "smallint",
    "INT2": "smallint",
    "DOUBLE PRECISION": "double precision",
    "FLOAT8": "double precision",
    "REAL": "real",
    "FLOAT4": "real",
    "NUMERIC": "numeric",
    "DECIMAL": "numeric",
    "TEXT": "text",
    "VARCHAR": "character varying",
    "CHARACTER VARYING": "character varying",
    "CHAR": "character",
    "CHARACTER": "character",
    "BPCHAR": "character",
    "BOOLEAN": "boolean",
    "BOOL": "boolean",
    "DATE": "date",
    "TIMESTAMP": "timestamp without time zone",
    "TIMESTAMP WITHOUT TIME ZONE": "timestamp without time zone",
    "TIMESTAMPTZ": "timestamp with time zone",
    "TIMESTAMP WITH TIME ZONE": "timestamp with time zone",
    "TIME": "time without time zone",
    "TIME WITHOUT TIME ZONE": "time without time zone",
    "TIMETZ": "time with time zone",
    "TIME WITH TIME ZONE": "time with time zone",
}


def _normalize_type(pg_type: str) -> str:
    """Strip length/precision modifiers and map to canonical family."""
    stripped = pg_type.strip().upper()
    # Remove trailing (n) / (p,s) before the space-separated tail.
    # Examples: "VARCHAR(10)" -> "VARCHAR"; "TIMESTAMP(6) WITHOUT TIME ZONE" -> "TIMESTAMP WITHOUT TIME ZONE"
    stripped = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", stripped).strip()
    # Collapse doubled whitespace after paren removal.
    stripped = re.sub(r"\s+", " ", stripped)
    return _TYPE_NORMALIZATION.get(stripped, stripped.lower())


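# Illustrative sketch (not called by the loader): how stripping modifiers lets
# a loader-emitted type string and a catalog-reported one compare equal.
# ``_example_family`` and ``_EXAMPLE_FAMILIES`` are hypothetical names, and the
# lookup table is a three-entry subset picked for this example only.
import re

_EXAMPLE_FAMILIES = {
    "VARCHAR": "character varying",
    "CHARACTER VARYING": "character varying",
    "TIMESTAMP WITHOUT TIME ZONE": "timestamp without time zone",
}


def _example_family(pg_type: str) -> str:
    """Strip ``(n)`` / ``(p,s)`` modifiers, then look up the family name."""
    bare = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", pg_type.strip().upper())
    # Removing a modifier can leave stray whitespace; collapse it before lookup.
    bare = re.sub(r"\s+", " ", bare).strip()
    # Unknown types fall back to their lowercased spelling.
    return _EXAMPLE_FAMILIES.get(bare, bare.lower())

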
# Widening pairs: (inferred_from_source, existing_in_target). When the
# incoming spec is narrower than the target we accept it: every value is
# guaranteed to fit, and ``_prepare_for_copy`` already emits ``COPY``
# payloads that Postgres promotes to the wider column type. This comes up
# most often on cluster loads where file 1 pushed the target to BIGINT
# (a single value > 2_147_483_647) and file N happens to sit entirely
# within int32 range; strict equality would reject file N even though
# the copy is trivially safe. The INVERSE direction stays a hard failure:
# a BIGINT value does not fit in an INTEGER column, so a target sized from
# a first file with only small ints must not accept a later file with a
# value past int32.
_WIDENING_COMPATIBLE: set = {
    ("smallint", "integer"),
    ("smallint", "bigint"),
    ("integer", "bigint"),
    ("real", "double precision"),
    # INTEGER / BIGINT into DOUBLE PRECISION is lossless for int32 and
    # exact up to 2**53 for int64, which covers every value pandas could
    # have carried through as Int64 without wrapping anyway.
    ("integer", "double precision"),
    ("bigint", "double precision"),
}


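# Illustrative sketch (not called by the loader): the widening check is
# directional. ``_example_append_ok`` and ``_EXAMPLE_WIDENING`` are
# hypothetical names, and the pairs are a two-entry subset of the table above.
_EXAMPLE_WIDENING = {("integer", "bigint"), ("real", "double precision")}


def _example_append_ok(inferred: str, target: str) -> bool:
    """True when the families match or the source is narrower than the target."""
    return inferred == target or (inferred, target) in _EXAMPLE_WIDENING

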
def _assert_schema_compatible(
    conn, schema: str, table: str, columns: Dict[str, ColumnSpec]
) -> None:
    """Pre-flight check for ``if_exists=append``.

    Raises :class:`SchemaCompatibilityError` when an incoming column is
    missing from the target or has an incompatible type. Accepted widenings
    and nullability risks are reported as warnings on stderr. See plan
    section on option B.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT column_name, data_type, is_nullable "
            "FROM information_schema.columns "
            "WHERE table_schema = %s AND table_name = %s",
            (schema, table),
        )
        existing = {row[0]: (row[1], row[2]) for row in cur.fetchall()}

    mismatches: List[str] = []
    warnings: List[str] = []

    for name, spec in columns.items():
        if name not in existing:
            mismatches.append(
                f"column {name!r} not present in target {schema}.{table}"
            )
            continue
        target_type, target_nullable = existing[name]
        inferred_norm = _normalize_type(spec.postgres_type)
        target_norm = _normalize_type(target_type)
        if inferred_norm != target_norm:
            if (inferred_norm, target_norm) in _WIDENING_COMPATIBLE:
                # Narrower inferred type fits inside the wider target.
                # Accept it, but warn so the operator knows the file
                # came in with a smaller range than the cluster's
                # target was sized for.
                warnings.append(
                    f"column {name!r}: inferred {spec.postgres_type} "
                    f"(narrower than target {target_type}); accepting - "
                    f"values fit in the wider target type"
                )
            else:
                mismatches.append(
                    f"column {name!r}: inferred {spec.postgres_type} "
                    f"(normalized {inferred_norm!r}) but target is {target_type} "
                    f"(normalized {target_norm!r})"
                )
        target_is_notnull = (target_nullable == "NO")
        if spec.nullable and target_is_notnull:
            warnings.append(
                f"column {name!r}: incoming allows NULLs but target is NOT NULL; "
                "COPY will fail if any NULLs appear"
            )

    for w in warnings:
        print(f"[warn] {w}", file=sys.stderr)

    if mismatches:
        raise SchemaCompatibilityError(
            "append-mode schema compatibility check failed:\n - "
            + "\n - ".join(mismatches)
        )


def assert_schema_compatible(
    conn,
    schema_name: str,
    table_name: str,
    columns: Dict[str, ColumnSpec],
) -> None:
    """Public wrapper around :func:`_assert_schema_compatible`.

    Intended for orchestrators (e.g. the folder loader) that append multiple
    files into one table and need to re-run the same compatibility check
    that ``if_exists=append`` performs internally. Raises
    :class:`SchemaCompatibilityError` on mismatch.
    """
    _assert_schema_compatible(conn, schema_name, table_name, columns)


def create_table(
    conn,
    schema_name: str,
    table_name: str,
    columns: Dict[str, ColumnSpec],
    if_exists: str,
    *,
    partition_by: Optional[List[str]] = None,
    partition_values: Optional[dict] = None,
    max_partitions: int = 10_000,
) -> None:
    """Create (or verify) the target table according to ``if_exists``.

    When ``partition_by`` is provided and non-empty, the parent table is
    created with ``PARTITION BY LIST`` and all child partition DDL from
    :func:`render_partition_ddl` is executed immediately after.

    For ``replace`` mode the existing table is dropped first; when
    ``partition_by`` is set the drop uses ``CASCADE`` so all child
    partitions are removed with it.

    For ``append`` mode partition creation is skipped entirely: the
    partitions are assumed to already exist from the original creation.
    """
    if if_exists not in VALID_IF_EXISTS:
        raise ValueError(f"if_exists must be one of {VALID_IF_EXISTS}, got {if_exists!r}")

    is_partitioned = bool(partition_by)

    exists = _table_exists(conn, schema_name, table_name)
    if exists:
        if if_exists == "fail":
            raise TableExistsError(
                f"Table {schema_name}.{table_name} already exists and if_exists=fail"
            )
        if if_exists == "replace":
            _drop_table(conn, schema_name, table_name, cascade=is_partitioned)
            _create_table_sql(
                conn, schema_name, table_name, columns,
                partition_by=partition_by,
            )
            if is_partitioned and partition_values is not None:
                ddl_stmts = render_partition_ddl(
                    schema_name, table_name, partition_by, partition_values,
                    columns, max_partitions=max_partitions,
                )
                with conn.cursor() as cur:
                    for stmt in ddl_stmts:
                        cur.execute(stmt)
            return
        if if_exists == "append":
            _assert_schema_compatible(conn, schema_name, table_name, columns)
            return
    else:
        _create_table_sql(
            conn, schema_name, table_name, columns,
            partition_by=partition_by,
        )
        if is_partitioned and partition_values is not None:
            ddl_stmts = render_partition_ddl(
                schema_name, table_name, partition_by, partition_values,
                columns, max_partitions=max_partitions,
            )
            with conn.cursor() as cur:
                for stmt in ddl_stmts:
                    cur.execute(stmt)


# ---------------------------------------------------------------------------
# Partition support
# ---------------------------------------------------------------------------


def _sanitize_partition_value(value: Any, parent_table: str = "") -> str:
    """Convert a partition value into a safe, deterministic table-name suffix.

    Rules:
    - Convert to string, lowercase
    - Replace non-alphanumeric runs with ``_``
    - Collapse consecutive underscores, strip leading/trailing ``_``
    - None/NaN → ``null``; empty string → ``empty``
    - Truncate to fit within PostgreSQL's 63-character identifier limit
      accounting for ``parent_table`` + ``_`` separator
    """
    if value is None or (isinstance(value, float) and (pd.isna(value) or math.isnan(value))):
        token = "null"
    elif isinstance(value, dt.date) or isinstance(value, dt.datetime):
        token = value.isoformat()
    elif isinstance(value, dt.time):
        token = value.isoformat()
    else:
        token = str(value)

    token = token.lower()
    token = re.sub(r"[^a-z0-9]+", "_", token)
    token = re.sub(r"_+", "_", token)
    token = token.strip("_")

    if not token:
        if value is None or (isinstance(value, float) and pd.isna(value)):
            token = "null"
        elif isinstance(value, str) and value == "":
            token = "empty"
        else:
            token = "value"

    # Truncate to keep total table name within PG's 63-char limit.
    if parent_table:
        # Reserve room for parent + underscore separator.
        max_token_len = _PG_IDENT_MAX_LEN - len(parent_table) - 1
        if max_token_len < 1:
            raise ValueError(
                f"Parent table name {parent_table!r} is too long "
                f"({len(parent_table)} chars) to create child partitions."
            )
        if len(token) > max_token_len:
            token = token[:max_token_len].rstrip("_")

    return token


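# Illustrative sketch (not called by the loader): building a child partition
# name from a parent table and a raw value, then fitting it into PostgreSQL's
# default 63-character identifier limit. ``_example_child_name`` and
# ``_EXAMPLE_IDENT_MAX`` are hypothetical names; null, date, and collision
# handling are left out on purpose.
import re

_EXAMPLE_IDENT_MAX = 63  # assumed to match the module's identifier limit


def _example_child_name(parent: str, value: str) -> str:
    """Lowercase, squash non-alphanumeric runs to ``_``, truncate to fit."""
    token = re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_") or "value"
    # Room left for the token after the parent name and the "_" separator.
    room = _EXAMPLE_IDENT_MAX - len(parent) - 1
    if room < 1:
        raise ValueError(f"parent table name {parent!r} leaves no room for a suffix")
    return f"{parent}_{token[:room].rstrip('_')}"

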
def _render_partition_value_literal(value: Any, pg_type: str) -> str:
    """Render a Python value as a SQL literal for ``FOR VALUES IN (...)``.

    - None/NaN → ``NULL``
    - Strings → single-quoted with ``'`` escaped to ``''``
    - Numbers → plain numeric literal
    - Booleans → ``TRUE`` / ``FALSE``
    - Dates → ``DATE 'YYYY-MM-DD'``
    - Timestamps → ``TIMESTAMP 'YYYY-MM-DDTHH:MM:SS'`` (ISO 8601, as
      produced by ``isoformat()``)
    - Times → ``TIME 'HH:MM:SS'``
    """
    if value is None or (isinstance(value, float) and pd.isna(value)):
        return "NULL"

    pg_upper = pg_type.upper()

    if pg_upper in ("BOOLEAN", "BOOL"):
        return "TRUE" if value else "FALSE"

    if pg_upper in ("INTEGER", "BIGINT", "SMALLINT", "INT", "INT4", "INT8", "INT2"):
        return str(int(value))

    if pg_upper in ("DOUBLE PRECISION", "REAL", "NUMERIC", "DECIMAL",
                    "FLOAT4", "FLOAT8"):
        return str(value)

    if pg_upper == "DATE":
        if isinstance(value, (dt.date, dt.datetime)):
            return f"DATE '{value.isoformat()}'"
        return f"DATE '{value}'"

    if pg_upper in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE",
                    "TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ"):
        if isinstance(value, (dt.datetime, pd.Timestamp)):
            return f"TIMESTAMP '{value.isoformat()}'"
        if isinstance(value, dt.date):
            return f"TIMESTAMP '{dt.datetime(value.year, value.month, value.day).isoformat()}'"
        return f"TIMESTAMP '{value}'"

    if pg_upper in ("TIME", "TIME WITHOUT TIME ZONE",
                    "TIME WITH TIME ZONE", "TIMETZ"):
        if isinstance(value, dt.time):
            return f"TIME '{value.isoformat()}'"
        return f"TIME '{value}'"

    # Default: treat as text; single-quote with escaping.
    escaped = str(value).replace("'", "''")
    return f"'{escaped}'"


def _normalize_partition_value(value: Any, pg_type: str) -> Any:
    """Normalize a raw partition value to its Python-native form.

    Applies the same semantic normalization that :func:`_prepare_for_copy`
    uses, so partition discovery deduplicates on the routed value rather
    than the raw source representation.
    """
    # Handle pandas null types
    if value is None:
        return None
    if isinstance(value, float) and (pd.isna(value) or math.isnan(value)):
        return None
    try:
        if pd.isna(value):
            return None
    except (TypeError, ValueError):
        pass

    pg_upper = pg_type.upper()

    if pg_upper in ("INTEGER", "BIGINT", "SMALLINT", "INT", "INT4", "INT8", "INT2"):
        if isinstance(value, str):
            value = value.strip()
            if value == "":
                return None
        try:
            return int(float(value))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float("inf")) cannot be represented.
            return None

    if pg_upper in ("DOUBLE PRECISION", "REAL", "NUMERIC", "DECIMAL",
                    "FLOAT4", "FLOAT8"):
        if isinstance(value, str):
            value = value.strip()
            if value == "":
                return None
        try:
            result = float(value)
            return None if math.isnan(result) else result
        except (TypeError, ValueError):
            return None

    if pg_upper == "DATE":
        if isinstance(value, dt.datetime):
            return value.date()
        if isinstance(value, dt.date):
            return value
        if isinstance(value, str):
            if value.strip() == "":
                return None
            return _parse_flexible_date(value)
        return None

    if pg_upper in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE",
                    "TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ"):
        if isinstance(value, dt.datetime):
            return value
        if isinstance(value, pd.Timestamp):
            return value.to_pydatetime() if not pd.isna(value) else None
        if isinstance(value, dt.date):
            return dt.datetime(value.year, value.month, value.day)
        if isinstance(value, str):
            if value.strip() == "":
                return None
            return _parse_flexible_datetime(value)
        return None

    if pg_upper in ("TIME", "TIME WITHOUT TIME ZONE",
                    "TIME WITH TIME ZONE", "TIMETZ"):
        return _seconds_to_time(value)

    if pg_upper in ("BOOLEAN", "BOOL"):
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return bool(value)
        if isinstance(value, str):
            return value.strip().lower() in ("true", "1", "t", "yes")
        return None

    # Text-like types: None, pandas nulls, and '' all become None
    # because copy_dataframes() sends empty strings with NULL ''.
    if pg_upper in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER", "BPCHAR"):
        if isinstance(value, str):
            if value == "":
                return None
            return value
        return str(value)

    # Fallback: return as-is converted to native Python type
    if hasattr(value, "item"):
        return value.item()
    return value


def discover_partition_values(
    df: pd.DataFrame,
    partition_by: list[str],
    columns: Optional[Dict[str, ColumnSpec]] = None,
) -> dict:
    """Build a nested structure of unique partition values from a DataFrame.

    For ``partition_by = ['state', 'zip']`` returns::

        {
            'MO': {'63101': {}, '63102': {}},
            'IL': {'62001': {}, '62002': {}}
        }

    When ``columns`` is provided, values are normalized through
    :func:`_normalize_partition_value` to match the routed values Postgres
    will see during ``COPY``.

    None/NaN values are included as a distinct partition value (``None`` key).
    Values are converted to Python native types (not numpy types).
    """
    if not partition_by:
        return {}

    def _to_native(val: Any) -> Any:
        """Convert numpy scalars to Python native types."""
        if val is None:
            return None
        if isinstance(val, float) and pd.isna(val):
            return None
        if hasattr(val, "item"):
            return val.item()
        return val

    def _build_level(sub_df: pd.DataFrame, fields: list[str]) -> dict:
        if not fields or sub_df.empty:
            return {}

        field = fields[0]
        remaining = fields[1:]
        result: dict = {}

        # Get unique values, handling NaN
        unique_vals = sub_df[field].unique()

        for raw_val in unique_vals:
            val = _to_native(raw_val)

            # Normalize if column spec is available
            if columns and field in columns:
                val = _normalize_partition_value(val, columns[field].postgres_type)

            if remaining:
                # Filter rows matching this value
                if val is None:
                    # NULL bucket: rows that are missing outright, plus any
                    # raw value the column's normalizer maps to None (for
                    # example '' in a text column). Going through
                    # ``_matches`` keeps the mask strictly boolean.
                    mask = sub_df[field].isna() | sub_df[field].map(
                        lambda v: _matches(v, None, field)
                    )
                else:
                    mask = sub_df[field].map(lambda v, target=val: _matches(v, target, field))
                child_df = sub_df[mask]
                result[val] = _build_level(child_df, remaining)
            else:
                result[val] = {}

        return result

    def _matches(raw_val: Any, target: Any, field_name: str) -> bool:
        """Check if a raw value normalizes to the target."""
        native = _to_native(raw_val)
        if columns and field_name in columns:
            native = _normalize_partition_value(native, columns[field_name].postgres_type)
        if target is None:
            return native is None
        return native == target

    return _build_level(df, list(partition_by))


def discover_partition_values_chunked(
    chunk_iter: Iterable[pd.DataFrame],
    partition_by: list[str],
    columns: Optional[Dict[str, ColumnSpec]] = None,
) -> dict:
    """Discover partition values across an iterable of DataFrame chunks.

    Scans the entire file chunk-by-chunk, collecting unique partition
    column values and merging them into a single nested partition tree.
    This avoids materializing the full file in memory.
    """
    if not partition_by:
        return {}

    merged: dict = {}

    for chunk_df in chunk_iter:
        if chunk_df.empty:
            continue
        # Only keep partition columns to minimize memory
        part_cols = [c for c in partition_by if c in chunk_df.columns]
        if len(part_cols) != len(partition_by):
            missing = [c for c in partition_by if c not in chunk_df.columns]
            raise ValueError(
                f"Partition columns not found in data: {missing}"
            )
        sub_df = chunk_df[part_cols]
        chunk_tree = discover_partition_values(sub_df, partition_by, columns)
        _merge_partition_trees(merged, chunk_tree)

    return merged


def _merge_partition_trees(target: dict, source: dict) -> None:
    """Merge ``source`` partition tree into ``target`` in place.

    Both trees are nested dicts where keys are partition values and values
    are either empty dicts (leaf) or nested dicts (intermediate levels).
    """
    for key, subtree in source.items():
        if key not in target:
            target[key] = subtree
        else:
            # Merge children recursively
            if subtree and target[key]:
                _merge_partition_trees(target[key], subtree)
            elif subtree:
                target[key] = subtree


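# Illustrative sketch (not called by the loader): merging the partition trees
# discovered in two chunks. ``_example_merge`` is a hypothetical, simplified
# variant built on ``dict.setdefault``; unlike the function above it copies
# the structure instead of sharing subtree objects with ``source``.
def _example_merge(target: dict, source: dict) -> None:
    """Recursively fold ``source`` into ``target`` in place."""
    for key, subtree in source.items():
        _example_merge(target.setdefault(key, {}), subtree)


def _example_merge_demo() -> dict:
    """Chunk 1 saw MO/63101; chunk 2 saw MO/63102 and IL/62001."""
    merged = {"MO": {"63101": {}}}
    _example_merge(merged, {"MO": {"63102": {}}, "IL": {"62001": {}}})
    return merged

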
def _count_partitions(tree: dict) -> int:
    """Count total partition tables in a nested partition tree."""
    count = 0
    for _key, children in tree.items():
        count += 1
        if children:
            count += _count_partitions(children)
    return count


def render_partition_ddl(
    schema: str,
    parent_table: str,
    partition_by: list[str],
    partition_values: dict,
    column_specs: Dict[str, ColumnSpec],
    *,
    max_partitions: int = 10_000,
) -> list[str]:
    """Generate all child partition DDL statements for the partition tree.

    Returns a list of SQL strings to execute in order (depth-first).
    The parent ``CREATE TABLE`` is NOT included; it is rendered separately
    by :func:`render_create_table`.

    Logs a warning if the total partition count exceeds ``max_partitions``,
    but continues.
    """
    if not partition_by or not partition_values:
        return []

    total = _count_partitions(partition_values)
    if total > max_partitions:
        logger.warning(
            "Partition count (%d) exceeds threshold (%d). "
            "This may impact database performance.",
            total, max_partitions,
        )
        print(
            f"[warn] partition plan for {schema}.{parent_table} will create "
            f"{total:,} partition tables, exceeding max_partitions={max_partitions:,}",
            file=sys.stderr,
        )

    # Statements accumulate depth-first; name collisions are detected per
    # parent level inside the recursive helper.
    statements: list[str] = []
    _render_partition_ddl_recursive(
        schema, parent_table, partition_by, partition_values,
        column_specs, 0, statements,
    )
    return statements


def _render_partition_ddl_recursive(
    schema: str,
    parent_table: str,
    partition_by: list[str],
    values: dict,
    column_specs: Dict[str, ColumnSpec],
    depth: int,
    statements: list[str],
) -> None:
    """Recursively generate partition DDL statements (depth-first)."""
    field_name = partition_by[depth]
    next_field = partition_by[depth + 1] if depth + 1 < len(partition_by) else None
    field_spec = column_specs.get(field_name)
    pg_type = field_spec.postgres_type if field_spec else "TEXT"

    # Track names used at this level under this parent to handle collisions
    used_names: Dict[str, Any] = {}

    # Sort values deterministically: None first, then by string representation
    def _sort_key(val: Any) -> Tuple[int, str]:
        if val is None:
            return (0, "")
        return (1, str(val))

    sorted_values = sorted(values.keys(), key=_sort_key)

    for val in sorted_values:
        children = values[val]
        token = _sanitize_partition_value(val, parent_table)
        child_name = f"{parent_table}_{token}"

        # Handle collisions
        if child_name in used_names and used_names[child_name] is not val:
            # Append a short hash of the value to disambiguate
            val_hash = hashlib.sha256(repr(val).encode()).hexdigest()[:8]
            # Re-truncate token to make room for _hash
            max_token_len = _PG_IDENT_MAX_LEN - len(parent_table) - 1 - 9  # _hash8
            if max_token_len < 1:
                max_token_len = 1
            truncated_token = token[:max_token_len].rstrip("_")
            child_name = f"{parent_table}_{truncated_token}_{val_hash}"

        # Final length check
        if len(child_name) > _PG_IDENT_MAX_LEN:
            child_name = child_name[:_PG_IDENT_MAX_LEN]

        used_names[child_name] = val

        literal = _render_partition_value_literal(val, pg_type)

        if next_field is not None:
            # Intermediate partition: itself partitioned by the next field
            stmt = (
                f"CREATE TABLE {_qualified(schema, child_name)} "
                f"PARTITION OF {_qualified(schema, parent_table)} "
                f"FOR VALUES IN ({literal}) "
                f"PARTITION BY LIST ({_quote_ident(next_field)});"
            )
            statements.append(stmt)
            # Recurse into children
            if children:
                _render_partition_ddl_recursive(
                    schema, child_name, partition_by, children,
                    column_specs, depth + 1, statements,
                )
        else:
            # Leaf partition
            stmt = (
                f"CREATE TABLE {_qualified(schema, child_name)} "
                f"PARTITION OF {_qualified(schema, parent_table)} "
                f"FOR VALUES IN ({literal});"
            )
            statements.append(stmt)


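# Illustrative sketch (not called by the loader): the shape of the depth-first
# DDL for a two-level tree. ``_example_partition_ddl`` is a hypothetical,
# stripped-down renderer that assumes plain string keys and skips identifier
# quoting, literal escaping, NULL keys, and name-collision handling.
def _example_partition_ddl(parent: str, fields: list, tree: dict, depth: int = 0) -> list:
    """Return CREATE TABLE ... PARTITION OF statements, parents before children."""
    stmts = []
    for value in sorted(tree):
        child = f"{parent}_{value.lower()}"
        head = f"CREATE TABLE {child} PARTITION OF {parent} FOR VALUES IN ('{value}')"
        if depth + 1 < len(fields):
            # Intermediate level: the child is itself partitioned by the next field.
            stmts.append(f"{head} PARTITION BY LIST ({fields[depth + 1]});")
            stmts.extend(_example_partition_ddl(child, fields, tree[value], depth + 1))
        else:
            stmts.append(f"{head};")
    return stmts

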
# ---------------------------------------------------------------------------
# Index support
# ---------------------------------------------------------------------------


def render_create_indexes(
    schema: str,
    tablename: str,
    indexes: List[str],
) -> List[str]:
    """Generate ``CREATE INDEX IF NOT EXISTS`` DDL for each column in *indexes*.

    Each index is a simple B-tree index on a single column. The index name
    follows the pattern ``ix_{tablename}_{column}`` (raw, unsanitized names
    wrapped with :func:`_quote_ident`). The table reference is fully
    qualified as ``schema.tablename``.

    If the generated index name exceeds PostgreSQL's 63-character identifier
    limit, it is truncated and a short hash suffix is appended to preserve
    uniqueness (similar to partition name truncation).

    Returns a list of SQL strings, one per index.
    """
    stmts: List[str] = []
    for col in indexes:
        idx_name = f"ix_{tablename}_{col}"
        if len(idx_name) > _PG_IDENT_MAX_LEN:
            # Truncate and append an 8-char hash for uniqueness.
            name_hash = hashlib.sha256(idx_name.encode()).hexdigest()[:8]
            # 9 = 1 underscore + 8 hash chars
            truncated = idx_name[: _PG_IDENT_MAX_LEN - 9].rstrip("_")
            idx_name = f"{truncated}_{name_hash}"
        stmt = (
            f"CREATE INDEX IF NOT EXISTS {_quote_ident(idx_name)} "
            f"ON {_qualified(schema, tablename)} ({_quote_ident(col)});"
        )
        stmts.append(stmt)
    return stmts


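# Illustrative sketch (not called by the loader): keeping an over-long index
# name unique after truncation by appending a short digest of the full name.
# ``_example_index_name`` is a hypothetical helper; the limit of 63 is
# PostgreSQL's default identifier length.
import hashlib


def _example_index_name(table: str, column: str, limit: int = 63) -> str:
    """Return ``ix_<table>_<column>``, shortened to ``limit`` chars if needed."""
    name = f"ix_{table}_{column}"
    if len(name) <= limit:
        return name
    digest = hashlib.sha256(name.encode()).hexdigest()[:8]
    # Reserve 9 characters: one underscore plus the 8-character digest.
    return f"{name[: limit - 9].rstrip('_')}_{digest}"

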
def create_indexes(
    conn,
    schema: str,
    tablename: str,
    indexes: List[str],
) -> None:
    """Execute ``CREATE INDEX IF NOT EXISTS`` for each column in *indexes*.

    Calls :func:`render_create_indexes` to obtain the DDL, executes each
    statement, commits immediately after each successful creation, and logs
    progress to stderr. If an individual index creation fails (e.g. a name
    collision unrelated to ``IF NOT EXISTS``), the transaction is rolled back
    (affecting only the failed statement) and the remaining indexes are still
    attempted.
    """
    stmts = render_create_indexes(schema, tablename, indexes)
    with conn.cursor() as cur:
        for stmt, col in zip(stmts, indexes):
            try:
                cur.execute(stmt)
                conn.commit()
                print(
                    f"[info] created index ix_{tablename}_{col} "
                    f"on {schema}.{tablename}({col})",
                    file=sys.stderr,
                )
            except Exception as exc:
                conn.rollback()
                print(
                    f"[warn] failed to create index ix_{tablename}_{col} "
                    f"on {schema}.{tablename}({col}): {exc}",
                    file=sys.stderr,
                )


# ---------------------------------------------------------------------------
# COPY loading
# ---------------------------------------------------------------------------


def _seconds_to_time(v: Any) -> Optional[dt.time]:
    """Convert seconds since midnight (SAS ``TIME8.``) to ``dt.time``.

    Returns ``None`` for null or unparseable input; hours are clamped to
    the 0-23 range.
    """
    if v is None:
        return None
    if isinstance(v, float) and pd.isna(v):
        return None
    if isinstance(v, dt.time):
        return v
    if isinstance(v, (dt.datetime, pd.Timestamp)):
        return v.time() if not pd.isna(v) else None
    try:
        total = int(round(float(v)))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: round(float("inf")) cannot be represented as an int.
        return None
    h, rem = divmod(total, 3600)
    m, s = divmod(rem, 60)
    # Clamp; TIME8. is always within a day.
    h = max(0, min(h, 23))
    return dt.time(h, m, s)


# Safe outer bound for the numeric->datetime conversion below. The true
# ceiling is ``pd.Timestamp.max`` (2262-04-11), which in seconds since 1960
# is ~9.54e9. We pick a much tighter bound, roughly year 2197 (7.5e9
# seconds, or 87,000 days), because (a) any real SAS data past ~2100 is
# garbage anyway, and (b) staying well inside the float64 + datetime64[ns]
# windows gives pandas' internals zero room to trip the ``over="raise"``
# they wrap around the ns-multiply. ``7.5e9 * 1e9 = 7.5e18``, comfortably
# under both ``int64.max`` (~9.22e18) and float64 overflow (~1.8e308).
_SAS_DATETIME_SAFE_S = 7_500_000_000
_SAS_DATETIME_SAFE_D = 87_000

# Number of non-null values :func:`_safe_object_to_datetime` peeks at to
# decide which parse path to use for the whole chunk. Keeps format
# detection to a bounded cost so a 1M-row chunk doesn't pay for a
# full row-walk just to figure out what shape its dates are in.
_DATETIME_FORMAT_SAMPLE = 16


def _safe_numeric_to_datetime(
    series: pd.Series,
    *,
    unit: str,
    column_name: str,
    target_type: str,
) -> pd.Series:
    """Convert a numeric SAS-epoch series to ``datetime64[ns]`` without letting
    one stray cell take down the worker.

    Failure modes seen in production:

    * ``np.inf`` / ``-np.inf`` slipping through pyreadstat (SAS missing-value
      sentinels, divide-by-zero in the source, uninitialized cells).
    * Absurdly large finite floats (e.g. ``1.7e308``) where ``value * 1e9``
      overflows float64.
    * Values between ``pd.Timestamp.max`` and float64 safety (~9.5e9 to 1e308
      seconds) where the nanosecond multiply silently produces garbage or
      overflows int64.

    All of these trigger ``FloatingPointError: overflow encountered in multiply``
    inside ``pd.to_datetime`` because pandas wraps the multiply in
    ``np.errstate(over="raise")`` -- our outer ``errors="coerce"`` never
    gets a chance to turn the bad value into ``NaT``.

    Strategy, belt + suspenders + airbag:

    1. Coerce to float64 up front. Object-dtype branches hand us mixed
       int/float/str; ``pd.to_numeric(errors="coerce")`` parses what it can
       and NaNs the rest, so we hit the rest of this function with a
       pristine float series.
    2. Mask non-finite values and anything outside the safe epoch window to
       NaN *before* ``pd.to_datetime`` sees them.
    3. Run the conversion under a permissive ``errstate``.
    4. If that still raises (some pandas version internally re-enables
       ``over="raise"`` in a way ``errstate`` can't override), catch it
       and return all-NaT for the column with a loud warning. Better a
       NULL column in one chunk than a dead worker + no diagnostics.

    Emits one stderr line per chunk per affected column so silent data
    loss doesn't sneak by.
    """
    if not pd.api.types.is_float_dtype(series):
        series = pd.to_numeric(series, errors="coerce").astype("float64")

    arr = series.to_numpy(dtype="float64", copy=False, na_value=np.nan)
    if unit == "s":
        bound = _SAS_DATETIME_SAFE_S
    elif unit == "D":
        bound = _SAS_DATETIME_SAFE_D
    else:
        bound = _SAS_DATETIME_SAFE_S
    with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
        finite_mask = np.isfinite(arr)
        # ``np.abs(inf) -> inf``, ``np.abs(nan) -> nan``; both compare False
        # to ``bound``, so ``in_range_mask`` already excludes non-finite
        # values. The explicit ``finite_mask &`` below is belt-and-suspenders
        # in case a future numpy changes that semantic.
        in_range_mask = np.abs(arr) < bound
    keep_mask = finite_mask & in_range_mask
    was_present = ~np.isnan(arr)
    coerced = int(((~keep_mask) & was_present).sum())
    if coerced:
        tqdm.write(
            f"[warn] {target_type} column {column_name!r}: {coerced:,} "
            f"row(s) had non-representable values (Inf/NaN/out-of-range), "
            f"coerced to NULL",
            file=sys.stderr,
        )
    cleaned_arr = np.where(keep_mask, arr, np.nan)
    cleaned = pd.Series(cleaned_arr, index=series.index)
    try:
        with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
            return pd.to_datetime(
                cleaned, unit=unit, origin="1960-01-01", errors="coerce",
            )
    except (FloatingPointError, OverflowError, ValueError) as exc:
        tqdm.write(
            f"[error] {target_type} column {column_name!r}: "
            f"pd.to_datetime raised {type(exc).__name__}: {exc}; "
            f"returning NaT for the entire chunk. This usually means one "
            f"or more values slipped past the pre-mask (bound={bound}). "
            f"Consider setting the column to TEXT via column_types if this "
            f"recurs.",
            file=sys.stderr,
        )
        return pd.Series(pd.NaT, index=series.index, dtype="datetime64[ns]")


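# Illustrative sketch (not called by the loader): the "mask before convert"
# idea for a single value, using only the standard library. The name
# ``_example_sas_seconds_to_datetime`` and both constants are hypothetical;
# the bound is assumed to mirror ``_SAS_DATETIME_SAFE_S`` above.
import datetime as dt
import math
from typing import Optional

_EXAMPLE_SAS_EPOCH = dt.datetime(1960, 1, 1)
_EXAMPLE_SAFE_SECONDS = 7_500_000_000


def _example_sas_seconds_to_datetime(value: float) -> Optional[dt.datetime]:
    """Return None for NaN, +/-Inf, or out-of-window input; else the datetime."""
    # Reject bad values up front so the arithmetic below can never overflow.
    if not math.isfinite(value) or abs(value) >= _EXAMPLE_SAFE_SECONDS:
        return None
    return _EXAMPLE_SAS_EPOCH + dt.timedelta(seconds=value)

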
def _safe_object_to_datetime(
    series: pd.Series,
    *,
    column_name: str,
    target_type: str,
) -> pd.Series:
    """Object-dtype to datetime. Shares the safety net (errstate +
    try/except) with :func:`_safe_numeric_to_datetime`. If the column is
    actually numeric-flavored (e.g. SAS wrote numbers into an object
    column), route to the numeric path; otherwise try our explicit
    ``DD-MON-YY`` / strptime format set before falling back to the
    generic ``pd.to_datetime`` dateutil parser.

    The explicit-format pre-pass exists because:
      * ``pd.to_datetime`` on unformatted object columns emits a
        ``UserWarning`` per chunk and parses row-by-row via ``dateutil``
        -- 10-100× slower than a single vectorized strptime.
      * ``dateutil`` *will* parse ``23-MAR-20`` but its 2-digit-year pivot
        differs from SAS/Oracle convention in corner cases; applying our
        own parser keeps behavior predictable.
    """
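    coerced = series.replace({"": None})
    numeric = pd.to_numeric(coerced, errors="coerce")
    all_numeric = numeric.notna().sum() == coerced.notna().sum()
    if all_numeric and coerced.notna().any():
        # SAS stores dates as days and datetimes as seconds since 1960-01-01,
        # so pick the unit from the target type (this mirrors the numeric
        # DATE / TIMESTAMP branches of ``_prepare_for_copy``).
        unit = "D" if target_type.upper() == "DATE" else "s"
        return _safe_numeric_to_datetime(
            numeric, unit=unit, column_name=column_name, target_type=target_type,
        )
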
    # Format sniff: peek at up to ``_DATETIME_FORMAT_SAMPLE`` non-null
    # string values to decide the parse path for the WHOLE chunk. The
    # previous version ran a full row-walk + up-to-20 vectorized
    # ``pd.to_datetime(format=fmt)`` attempts per column per chunk; a
    # single fat chunk (millions of rows × a few date columns) could
    # pin a CPU for minutes. Sniffing first keeps the hot path to one
    # O(n) pass.
    non_null = coerced.dropna()
    if not non_null.empty:
        samples: List[str] = []
        for v in non_null.head(_DATETIME_FORMAT_SAMPLE):
            if isinstance(v, str):
                samples.append(v.strip())
            else:
                # Mixed object column (e.g. already-parsed Timestamps +
                # strings). Skip sniffing; let dateutil handle it.
                samples = []
                break

        if samples:
            # DD-MON-YY family: one pandas ``Series.map`` with our regex
            # parser, then a single ``pd.to_datetime`` to land on
            # ``datetime64[ns]``. ``pd.to_datetime(format=...)`` has no
            # locale-independent ``%b``, so mapping our own parser is the
            # fastest route available for this format family.
            if all(_DDMONYY_RE.match(s) for s in samples):
                with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
                    return pd.to_datetime(
                        coerced.map(
                            lambda v: _parse_flexible_datetime(v)
                            if isinstance(v, str) else None
                        ),
                        errors="coerce",
                    )

            # Numeric strptime shapes: pick the first format that
            # parses every sample, then run ONE vectorized
            # ``pd.to_datetime(format=fmt)`` over the full column.
            # Bounded to the sample pass -- no 20×O(n) blow-up.
            for fmt in _EXTRA_DATETIME_FORMATS + _EXTRA_DATE_FORMATS:
                ok = True
                for s in samples:
                    try:
                        dt.datetime.strptime(s, fmt)
                    except ValueError:
                        ok = False
                        break
                if ok:
                    try:
                        with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
                            return pd.to_datetime(
                                coerced, format=fmt, errors="coerce",
                            )
                    except (ValueError, TypeError):
                        continue

    # Fallback: ``pd.to_datetime`` / dateutil. Handles shapes the
    # sniffer missed (mixed formats within one column,
    # already-parsed Timestamp/date objects sharing space with
    # strings, ISO 8601 with offsets, etc.). Wrap in a warning
    # filter because the unformatted path emits ``UserWarning:
    # Could not infer format...`` once per chunk and we don't want
    # the progress bar drowned.
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
                return pd.to_datetime(coerced, errors="coerce")
    except (FloatingPointError, OverflowError, ValueError) as exc:
        tqdm.write(
            f"[error] {target_type} column {column_name!r}: "
            f"pd.to_datetime raised {type(exc).__name__}: {exc}; "
            f"returning NaT for the entire chunk.",
            file=sys.stderr,
        )
        return pd.Series(pd.NaT, index=series.index, dtype="datetime64[ns]")


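# Illustrative sketch only (nothing in the loader calls it): the "sniff a
# small sample, then commit to one format" idea behind
# ``_safe_object_to_datetime``, reduced to the standard library. The function
# name, the candidate formats and the sample size are assumptions made for
# this example; they are not the module's ``_EXTRA_*`` constants or
# ``_DATETIME_FORMAT_SAMPLE``.
def _example_sniff_format(
    values,
    formats=("%Y-%m-%d", "%m/%d/%Y", "%Y%m%d"),
    sample_size=3,
):
    """Return the first format that parses every sampled value, else None."""
    import datetime as _dt

    # Only non-empty strings take part in the sniff, capped at ``sample_size``.
    sample = [v.strip() for v in values if isinstance(v, str) and v.strip()]
    sample = sample[:sample_size]
    if not sample:
        return None
    for fmt in formats:
        try:
            for s in sample:
                _dt.datetime.strptime(s, fmt)
        except ValueError:
            continue  # at least one sample rejected this format
        return fmt
    return None

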
def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame:
    """Materialize a copy of ``df`` with each column in the right shape for
    ``to_csv`` so the CSV lands as valid input for the target Postgres type.

    Per-column conversions are vectorized (``.astype`` / ``pd.to_datetime`` /
    ``.mask`` / ``.fillna``) instead of the element-wise ``.map(func)``
    loops this function used to run. Those loops were the single largest
    per-chunk CPU cost on text-heavy loads - a 40-column × 100k-row chunk was
    issuing ~4M Python-level function calls just to cast strings. TIME
    columns still take the ``.map`` path because SAS TIME8 is stored as
    seconds and the clamp-to-24h logic doesn't fit cleanly in vector form;
    they're also rare in practice.
    """
    out = pd.DataFrame(index=df.index)
    for name, spec in columns.items():
        series = df[name]
        pg = spec.postgres_type.upper()

        if pg in ("INTEGER", "BIGINT", "SMALLINT"):
            if pd.api.types.is_object_dtype(series):
                series = pd.to_numeric(
                    series.replace({"": None}), errors="coerce"
                )
            out[name] = series.astype("Int64")
        elif pg in ("DOUBLE PRECISION", "REAL", "NUMERIC"):
            if pd.api.types.is_object_dtype(series):
                series = pd.to_numeric(
                    series.replace({"": None}), errors="coerce"
                )
            out[name] = series.astype("float64")
        elif pg == "DATE":
            if pd.api.types.is_datetime64_any_dtype(series):
                out[name] = series.dt.date
            elif pd.api.types.is_object_dtype(series):
                # Vectorized parse: empty strings / None / unparseable -> NaT,
                # then .dt.date yields date objects or NaT. NaT serializes as
                # an empty CSV field (matching ``NULL ''`` in COPY). Routed
                # through ``_safe_object_to_datetime`` so an object column
                # that actually contains SAS-epoch numerics (seen when one
                # file of a cluster stores the column as NUM and another as
                # CHAR + the union flipped it to TEXT-then-DATE) can't trip
                # the overflow-in-multiply bug.
                parsed = _safe_object_to_datetime(
                    series, column_name=name, target_type="DATE",
                )
                out[name] = parsed.dt.date
            elif pd.api.types.is_numeric_dtype(series):
                # pyreadstat couldn't decode the SAS format (some
                # ``DATEw.``/``YYMMDDw.`` variants and all custom formats slip
                # through) so the column came back as float64: days since
                # 1960-01-01, the SAS epoch. Without this branch the raw
                # number would hit COPY and Postgres rejects it with
                # ``invalid input syntax for type date``.
                parsed = _safe_numeric_to_datetime(
                    series, unit="D", column_name=name, target_type="DATE",
                )
                out[name] = parsed.dt.date
            else:
                out[name] = series
        elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"):
            if pd.api.types.is_datetime64_any_dtype(series):
                out[name] = series
            elif pd.api.types.is_object_dtype(series):
                # Same rationale as the DATE object branch above: route
                # through the safety net so numeric-flavored object columns
                # can't blow us up during the ns multiply.
                out[name] = _safe_object_to_datetime(
                    series, column_name=name, target_type="TIMESTAMP",
                )
            elif pd.api.types.is_numeric_dtype(series):
                # Same story as the DATE branch above, but SAS datetimes are
                # *seconds* since 1960-01-01 (fractional seconds for
                # ``DATETIMEw.d``). For example,
                # ``1915465463.615`` -> 2020-09-11 17:44:23.615.
                out[name] = _safe_numeric_to_datetime(
                    series, unit="s", column_name=name, target_type="TIMESTAMP",
                )
            else:
                out[name] = series
        elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"):
            out[name] = series.map(_seconds_to_time)
        elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"):
            # Render every cell as a string and blank out nulls. ``NULL ''``
            # in the COPY statement turns the blanks back into SQL NULL.
            # astype(str) stringifies NaN/None to the literal "nan"/"None",
            # so we mask those after the fact rather than branching per cell.
            na_mask = series.isna()
            if pd.api.types.is_numeric_dtype(series):
                # Hit when a column was auto-unioned to TEXT because at
                # least one file of the cluster stored it as CHAR but this
                # particular file stored it as NUM (typical of SAS phone/ID
                # columns). Default float formatting would emit "123.0",
                # which doesn't match the plain "123" coming from the CHAR
                # files. When the whole chunk is integer-valued, cast to
                # int before stringifying; when any fractional value is
                # present we leave float formatting alone so we don't
                # silently drop precision.
                nonnull = series.dropna()
                int_like = False
                if not nonnull.empty:
                    try:
                        int_like = bool(((nonnull % 1) == 0).all())
                    except TypeError:
                        int_like = False
                if int_like:
                    # ``Int64`` preserves NA; ``.astype(str)`` renders NA
                    # as '<NA>', which we then mask out alongside original
                    # NaNs.
                    as_str = series.astype("Int64").astype(str)
                    out[name] = as_str.mask(na_mask, "")
                else:
                    out[name] = series.astype(str).mask(na_mask, "")
            else:
                out[name] = series.astype(str).mask(na_mask, "")
        elif pg == "BOOLEAN":
            out[name] = series.astype("boolean") if series.dtype != object else series
        else:
            out[name] = series
    return out


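# Illustrative sketch only (nothing in the loader calls it): the chunk-level
# "integer-valued floats render without a trailing .0" rule from the TEXT
# branch of ``_prepare_for_copy``, written in plain Python. The helper name is
# hypothetical; the loader itself does this with pandas ``Int64`` + ``mask``.
def _example_floats_to_text(values):
    """Render numbers as text; None/NaN become '' (COPY's NULL marker)."""
    import math

    def is_null(v):
        return v is None or (isinstance(v, float) and math.isnan(v))

    present = [v for v in values if not is_null(v)]
    # Decide once per chunk: drop ".0" only when every present value is whole.
    int_like = bool(present) and all(float(v).is_integer() for v in present)
    return [
        "" if is_null(v) else (str(int(v)) if int_like else str(v))
        for v in values
    ]

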
def _serialize_chunk_csv(prepared: pd.DataFrame) -> io.BytesIO:
    """Serialize a prepared frame into a CSV buffer for ``COPY FROM STDIN``.

    Uses ``pyarrow.csv.write_csv`` (typically 5-10× faster than pandas'
    pure-Python ``to_csv`` on wide/text-heavy frames). Null cells serialize
    as empty strings and date/timestamp values land in ISO 8601 form, both
    of which Postgres accepts under ``FORMAT csv, NULL ''``.
    """
    table = pa.Table.from_pandas(prepared, preserve_index=False)
    buf = io.BytesIO()
    pa_csv.write_csv(
        table,
        buf,
        write_options=pa_csv.WriteOptions(include_header=False),
    )
    buf.seek(0)
    return buf


def copy_dataframes(
    conn,
    schema_name: str,
    table_name: str,
    dfs: Iterable[pd.DataFrame],
    columns: Dict[str, ColumnSpec],
) -> int:
    """Stream an iterable of DataFrames into Postgres, committing each chunk.

    Each non-empty chunk is copied via ``COPY ... FROM STDIN`` and committed
    before the next chunk is processed, so an interrupted or failed load
    retains the rows from previously committed chunks. The first chunk's
    commit also flushes any pending DDL (e.g. a preceding ``CREATE TABLE``).
    Empty chunks are skipped. Returns the total rows inserted.
    """
    col_list = ", ".join(_quote_ident(name) for name in columns.keys())
    sql = (
        f"COPY {_qualified(schema_name, table_name)} ({col_list}) "
        f"FROM STDIN WITH (FORMAT csv, NULL '')"
    )

    total = 0
    # Pull chunks one at a time so each ``df`` is unreferenced before the
    # generator reads the next one. Without this the loop-variable binding
    # of a ``for df in dfs:`` keeps the previous chunk alive during the
    # next pyreadstat read, pushing peak memory to 5-6× chunk size per
    # worker (old df + incoming df + prepared + pyarrow table + CSV buf).
    # With explicit drops we cap peak at ~2× chunk size: ``df`` goes away
    # once ``prepared`` exists, ``prepared`` once ``buf`` exists, ``buf``
    # once COPY has consumed it. Matters most in parallel mode where
    # 32 × per-worker peak can exhaust a 128 GB host.
    dfs_iter = iter(dfs)
    with conn.cursor() as cur:
        while True:
            try:
                df = next(dfs_iter)
            except StopIteration:
                break
            if df.empty:
                del df
                continue
            prepared = _prepare_for_copy(df, columns)
            del df
            n = len(prepared)
            buf = _serialize_chunk_csv(prepared)
            del prepared
            cur.copy_expert(sql, buf)
            del buf
            conn.commit()
            total += n
            # Hand pyarrow's pool memory back between chunks. Without this,
            # arrow's internal buffer pool keeps the high-water bytes
            # reserved across the worker's lifetime - inside long-running
            # workers this presents as steadily climbing RSS even with the
            # ``del``s above. Cheap (microseconds); call it every chunk.
            try:
                pa.default_memory_pool().release_unused()
            except Exception:
                pass
    return total


def copy_dataframe(
    conn,
    schema_name: str,
    table_name: str,
    df: pd.DataFrame,
    columns: Dict[str, ColumnSpec],
) -> int:
    """Stream ``df`` into Postgres via ``COPY ... FROM STDIN``.

    Convenience wrapper around :func:`copy_dataframes` for single-frame
    callers. Returns the number of rows inserted.
    """
    return copy_dataframes(conn, schema_name, table_name, [df], columns)


# ---------------------------------------------------------------------------
# Manifest validation
# ---------------------------------------------------------------------------


def _match_manifest_type(inferred: str, manifest_entry: Dict[str, Any]) -> bool:
    """Return True if ``inferred`` matches the manifest entry.

    An entry's ``postgres_type`` (exact match) takes precedence over its
    ``acceptable_types`` (any match). An entry that declares neither never
    matches.
    """
    inferred_norm = _normalize_type(inferred)
    if "postgres_type" in manifest_entry:
        return inferred_norm == _normalize_type(manifest_entry["postgres_type"])
    if "acceptable_types" in manifest_entry:
        return any(
            inferred_norm == _normalize_type(t)
            for t in manifest_entry["acceptable_types"]
        )
    return False


def validate_against_manifest(
    inferred: Dict[str, ColumnSpec],
    manifest_path: Path,
) -> List[str]:
    """Compare the inferred schema against the expected-types manifest.

    Returns a list of human-readable problem strings; empty list means OK.
    """
    manifest_path = Path(manifest_path)
    if not manifest_path.exists():
        return [f"manifest not found: {manifest_path}"]

    with manifest_path.open("r", encoding="utf-8") as f:
        manifest = json.load(f)

    problems: List[str] = []

    only_in_inferred = set(inferred) - set(manifest)
    only_in_manifest = set(manifest) - set(inferred)
    if only_in_inferred:
        problems.append(
            f"columns in inferred but not manifest: {sorted(only_in_inferred)}"
        )
    if only_in_manifest:
        problems.append(
            f"columns in manifest but not inferred: {sorted(only_in_manifest)}"
        )

    for name, spec in inferred.items():
        entry = manifest.get(name)
        if entry is None:
            continue
        if not _match_manifest_type(spec.postgres_type, entry):
            expected = entry.get("postgres_type") or entry.get("acceptable_types")
            problems.append(
                f"column {name!r}: inferred {spec.postgres_type!r}, "
                f"manifest expected {expected!r}"
            )
        manifest_nullable = bool(entry.get("nullable", True))
        if spec.nullable and not manifest_nullable:
            problems.append(
                f"column {name!r}: inferred nullable, manifest expects NOT NULL "
                f"(loosening nullability is never allowed)"
            )

    return problems


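# Illustrative sketch only (nothing in the loader calls it): the manifest
# shape that ``validate_against_manifest`` reads, inferred from the lookups
# above ("postgres_type" or "acceptable_types", plus an optional "nullable").
# The column names are made up, and the plain upper-casing below merely
# stands in for ``_normalize_type``, whose exact rules are not shown here.
_EXAMPLE_MANIFEST = {
    "person_id": {"postgres_type": "BIGINT", "nullable": False},
    "visit_date": {"acceptable_types": ["DATE", "TIMESTAMP"]},
}


def _example_manifest_problems(inferred_types, manifest=_EXAMPLE_MANIFEST):
    """Return problem strings for ``{column: type}`` checked against ``manifest``."""
    problems = []
    for name, pg_type in inferred_types.items():
        entry = manifest.get(name)
        if entry is None:
            problems.append(f"column {name!r} not in manifest")
            continue
        # ``postgres_type`` wins over ``acceptable_types``, as in the real check.
        if "postgres_type" in entry:
            allowed = [entry["postgres_type"]]
        else:
            allowed = list(entry.get("acceptable_types", []))
        if pg_type.upper() not in {t.upper() for t in allowed}:
            problems.append(
                f"column {name!r}: inferred {pg_type!r}, expected {allowed!r}"
            )
    return problems

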
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def _build_argparser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(
        description="Load a single data file (SAS or delimited text) into Postgres.",
    )
    p.add_argument("--config", required=True, type=Path, help="Path to YAML config")
    p.add_argument(
        "--validate",
        action="store_true",
        help=(
            "Compare inferred schema against <filename-stem>.expected.json "
            "next to the data file; exits nonzero on mismatch."
        ),
    )
    p.add_argument(
        "--dry-run",
        action="store_true",
        help="Print inferred CREATE TABLE and stop; don't touch Postgres.",
    )
    p.add_argument(
        "--dbcreds",
        action="store_true",
        help=(
            "Prompt for database username and password instead of reading "
            "PGUSER / PGPASSWORD from the environment or .env file."
        ),
    )
    p.add_argument(
        "--all-nullable",
        action="store_true",
        help=(
            "Stamp every column nullable in the generated schema, bypassing "
            "NOT NULL inference. Use when sampled rows wrongly suggest a "
            "column has no nulls. Takes effect even when the YAML config "
            "leaves ``all_nullable`` unset or false."
        ),
    )
    return p


def _format_columns_summary(columns: Dict[str, ColumnSpec]) -> str:
    """Render one indented ``name: TYPE [NOT NULL]`` line per column."""
    lines = []
    for spec in columns.values():
        null = "" if spec.nullable else " NOT NULL"
        lines.append(f"  {spec.name}: {spec.postgres_type}{null}")
    return "\n".join(lines)


def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point. Returns the process exit code (0 on success)."""
    args = _build_argparser().parse_args(argv)

    load_dotenv()

    cfg = load_config(args.config)

    if not cfg.filename.exists():
        file_label = "text file" if cfg.file_type == "text" else "SAS file"
        print(f"error: {file_label} not found: {cfg.filename}", file=sys.stderr)
        return 2

    # Build kwargs dict for text-file parameters. These are passed through
    # to the unified reader functions and silently ignored for SAS files.
    _text_kw: Dict[str, Any] = dict(
        delimiter=cfg.delimiter,
        text_encoding=cfg.text_encoding,
        quotechar=cfg.quotechar,
    )

    # Schema inference reads the whole file so type + nullability are
    # computed against every row. That's what the target host has the
    # resources for, and it is the only way to honestly emit ``NOT NULL`` -
    # a bounded preview routinely missed the ~0.2% of rows with nulls on
    # otherwise-dense keys (e.g. MAFID). If you're on a box that can't
    # fit the file in memory, set ``TYPE_INFERENCE_SAMPLE_ROWS`` to an
    # integer cap, and be aware that sampled specs may stamp ``NOT NULL``
    # on columns whose nulls live past the window.
    preview_df, meta = read_sas_preview(cfg.filename, **_text_kw)
    preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
    force_nullable = args.all_nullable or cfg.all_nullable
    columns = infer_schema(
        preview_df,
        meta,
        column_types=cfg.column_types,
        force_nullable=force_nullable,
    )

    # Validate partition columns exist in the schema after filtering.
    if cfg.partition_by:
        missing_pcols = [c for c in cfg.partition_by if c not in columns]
        if missing_pcols:
            raise ValueError(
                f"partition_by references columns not present in the "
                f"(filtered) schema: {missing_pcols}"
            )

    # Validate index columns exist in the schema after filtering.
    if cfg.indexes:
        missing_icols = [c for c in cfg.indexes if c not in columns]
        if missing_icols:
            raise ValueError(
                f"indexes references columns not present in the "
                f"(filtered) schema: {missing_icols}"
            )

    if args.validate:
        # Append ".expected.json" to the filename stem, e.g.
        # "sample_kitchensink.xpt" -> "sample_kitchensink.expected.json".
        # ``with_name`` (rather than chained ``with_suffix`` calls) keeps
        # stems that themselves contain dots intact.
        manifest_path = cfg.filename.with_name(f"{cfg.filename.stem}.expected.json")
        problems = validate_against_manifest(columns, manifest_path)
        if problems:
            print("validation failed:", file=sys.stderr)
            for p in problems:
                print(f"  - {p}", file=sys.stderr)
            return 1
        print(f"validation OK ({len(columns)} columns match {manifest_path.name})")

    # -- Partition value discovery ------------------------------------------
    # If partitioned, scan the ENTIRE file to discover all unique partition
    # values. A row-capped preview can miss values, so it is not used here.
    # In append mode the partitions already exist, so skip the costly scan.
    partition_values: Optional[dict] = None
    if cfg.partition_by and cfg.if_exists != "append":
        print("  discovering partition values (full file scan)...", file=sys.stderr)

        def _discovery_chunks():
            for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename, **_text_kw):
                yield apply_column_filter(chunk_df, cfg.include, cfg.exclude)

        partition_values = discover_partition_values_chunked(
            _discovery_chunks(), cfg.partition_by, columns,
        )
        total_parts = _count_partitions(partition_values)
        print(
            f"  discovered {total_parts:,} partition tables "
            f"across {len(cfg.partition_by)} level(s)",
            file=sys.stderr,
        )
    elif cfg.partition_by and cfg.if_exists == "append":
        print(
            "  [info] append mode: skipping partition discovery "
            "(partitions assumed to exist)",
            file=sys.stderr,
        )

    if args.dry_run:
        # Print the parent CREATE TABLE (with PARTITION BY if applicable).
        parent_ddl = render_create_table(
            cfg.schemaname, cfg.tablename, columns,
            partition_by=cfg.partition_by or None,
        )
        print(parent_ddl)
        # Print child partition DDL if partitioned.
        if cfg.partition_by and partition_values:
            child_stmts = render_partition_ddl(
                cfg.schemaname, cfg.tablename, cfg.partition_by,
                partition_values, columns,
                max_partitions=cfg.max_partitions,
            )
            for stmt in child_stmts:
                print()
                print(stmt)
        # Print CREATE INDEX DDL if indexes are configured.
        if cfg.indexes:
            idx_stmts = render_create_indexes(
                cfg.schemaname, cfg.tablename, cfg.indexes,
            )
            for stmt in idx_stmts:
                print()
                print(stmt)
        return 0

    # Release the preview frame before opening the stream - lets the GC reclaim
    # it while we're holding a Postgres transaction open.
    del preview_df

    total_rows = getattr(meta, "number_rows", None)

    def _filtered_chunks():
        pbar = tqdm(
            total=total_rows,
            unit="row",
            unit_scale=True,
            desc=f"  {cfg.filename.name}",
            file=sys.stderr,
            dynamic_ncols=True,
        )
        try:
            for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename, **_text_kw):
                chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude)
                pbar.update(len(chunk_df))
                yield chunk_df
        finally:
            pbar.close()

    db_user = db_password = None
    if args.dbcreds:
        db_user = input("Database username: ")
        db_password = getpass.getpass("Database password: ")

    conn = connect(user=db_user, password=db_password)
    conn.autocommit = False
    try:
        create_table(
            conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists,
            partition_by=cfg.partition_by or None,
            partition_values=partition_values,
            max_partitions=cfg.max_partitions,
        )
        inserted = copy_dataframes(
            conn, cfg.schemaname, cfg.tablename, _filtered_chunks(), columns
        )
        conn.commit()
        if cfg.indexes:
            create_indexes(conn, cfg.schemaname, cfg.tablename, cfg.indexes)
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

    print(
        f"loaded {inserted} rows into {cfg.schemaname}.{cfg.tablename} "
        f"({len(columns)} columns)"
    )
    if cfg.partition_by and partition_values:
        total_parts = _count_partitions(partition_values)
        print(f"partitioned by {cfg.partition_by} ({total_parts:,} partition tables)")
    print("final schema:")
    print(_format_columns_summary(columns))
    return 0


if __name__ == "__main__":
    sys.exit(main())