Compare commits

...

3 Commits

Author SHA1 Message Date
michael-corey
b3d7a9d440 adding index field 2026-04-20 10:18:09 -05:00
michael-corey
0d955eeab1 adding partition flag 2026-04-20 09:56:00 -05:00
michael-corey
e39eb47a90 altering such that commit is by batch 2026-04-20 08:38:38 -05:00
4 changed files with 1309 additions and 36 deletions

View File

@ -32,9 +32,19 @@ USAGE
# include: [ID, INTCOL] # include: [ID, INTCOL]
# exclude: [ALLNULL] # exclude: [ALLNULL]
# Optional folder default for LIST partitioning. Omit or set [] for no
# partitioning. Accepts a single string or a list of column names.
# partition_by:
# - state
# - zip
# Optional folder default threshold. Default: 10000.
# max_partitions: 10000
# Optional explicit cluster patterns. Each pattern is matched against the # Optional explicit cluster patterns. Each pattern is matched against the
# file *basename*. Matched files are pulled out of the auto-detect pool. # file *basename*. Matched files are pulled out of the auto-detect pool.
# Per-cluster if_exists/include/exclude override the folder-level defaults. # Per-cluster if_exists/include/exclude/partition_by/max_partitions
# override the folder-level defaults.
clusters: clusters:
- pattern: '^group_a\\d+\\.sas7bdat$' - pattern: '^group_a\\d+\\.sas7bdat$'
tablename: group_a tablename: group_a
@ -51,9 +61,10 @@ USAGE
Flags: Flags:
--config PATH Required. Path to the YAML config above. --config PATH Required. Path to the YAML config above.
--dry-run Print the discovered clusters and the inferred CREATE --dry-run Print the discovered clusters and the inferred DDL for
TABLE for each (schema from the first file of the each (CREATE TABLE plus partition DDL when applicable).
cluster). The database is never touched. For partitioned clusters all files are scanned to
discover partition values. The database is never touched.
--fail-fast Abort the whole run on the first cluster failure. --fail-fast Abort the whole run on the first cluster failure.
Default is to log the failure, roll that cluster back, Default is to log the failure, roll that cluster back,
and keep going. and keep going.
@ -113,15 +124,21 @@ from dotenv import load_dotenv
from load_sas import ( from load_sas import (
VALID_IF_EXISTS, VALID_IF_EXISTS,
_count_partitions,
_merge_partition_trees,
apply_column_filter, apply_column_filter,
assert_schema_compatible, assert_schema_compatible,
connect, connect,
copy_dataframes, copy_dataframes,
create_indexes,
create_table, create_table,
discover_partition_values_chunked,
infer_schema, infer_schema,
iter_sas_chunks, iter_sas_chunks,
read_sas_preview, read_sas_preview,
render_create_indexes,
render_create_table, render_create_table,
render_partition_ddl,
) )
@ -135,6 +152,13 @@ SAS_EXTENSIONS = (".sas7bdat", ".xpt", ".xport")
@dataclass @dataclass
class ClusterSpec: class ClusterSpec:
"""Resolved per-cluster load settings.
``partition_by``, ``max_partitions``, and ``indexes`` are resolved from
the folder defaults and any per-cluster overrides during
:func:`discover_clusters`.
"""
tablename: str tablename: str
files: List[Path] files: List[Path]
if_exists: str if_exists: str
@ -142,11 +166,20 @@ class ClusterSpec:
exclude: Optional[List[str]] exclude: Optional[List[str]]
source: str # "explicit" or "auto" source: str # "explicit" or "auto"
pattern: Optional[str] = None pattern: Optional[str] = None
partition_by: List[str] = field(default_factory=list)
max_partitions: int = 10_000
indexes: List[str] = field(default_factory=list)
@dataclass @dataclass
class _ExplicitPattern: class _ExplicitPattern:
"""Parsed form of a single ``clusters[*]`` YAML entry.""" """Parsed form of a single ``clusters[*]`` YAML entry.
``partition_by`` defaults to ``None`` meaning "inherit from folder level".
An explicit empty list ``[]`` means "disable partitioning for this cluster".
``max_partitions`` defaults to ``None`` meaning "inherit from folder level".
``indexes`` defaults to ``None`` meaning "inherit from folder level".
"""
pattern: re.Pattern pattern: re.Pattern
raw_pattern: str raw_pattern: str
@ -154,10 +187,19 @@ class _ExplicitPattern:
if_exists: Optional[str] = None if_exists: Optional[str] = None
include: Optional[List[str]] = None include: Optional[List[str]] = None
exclude: Optional[List[str]] = None exclude: Optional[List[str]] = None
partition_by: Optional[List[str]] = None
max_partitions: Optional[int] = None
indexes: Optional[List[str]] = None
@dataclass @dataclass
class FolderConfig: class FolderConfig:
"""Folder-level configuration parsed from YAML.
``partition_by``, ``max_partitions``, and ``indexes`` serve as defaults
for every cluster unless overridden at the cluster level.
"""
folder: Path folder: Path
schemaname: str schemaname: str
if_exists: str = "fail" if_exists: str = "fail"
@ -165,6 +207,9 @@ class FolderConfig:
include: Optional[List[str]] = None include: Optional[List[str]] = None
exclude: Optional[List[str]] = None exclude: Optional[List[str]] = None
explicit: List[_ExplicitPattern] = field(default_factory=list) explicit: List[_ExplicitPattern] = field(default_factory=list)
partition_by: List[str] = field(default_factory=list)
max_partitions: int = 10_000
indexes: List[str] = field(default_factory=list)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -197,8 +242,141 @@ def _parse_columns_filter(
return include_out, exclude_out return include_out, exclude_out
def _parse_partition_by(
raw_value: Any, where: str, *, allow_none: bool = False
) -> Optional[List[str]]:
"""Parse a ``partition_by`` value from YAML.
Returns a list of non-empty, unique column name strings. When
``allow_none`` is True (used for per-cluster entries), an omitted key
returns ``None`` to signal "inherit from folder level". An explicit
empty list ``[]`` always returns ``[]``.
"""
if raw_value is None:
return None if allow_none else []
if isinstance(raw_value, str):
if not raw_value.strip():
raise ValueError(f"{where}: 'partition_by' string must be non-empty.")
return [raw_value.strip()]
if isinstance(raw_value, list):
if len(raw_value) == 0:
return []
result: List[str] = []
for i, item in enumerate(raw_value):
if not isinstance(item, str) or not item.strip():
raise ValueError(
f"{where}: 'partition_by[{i}]' must be a non-empty string."
)
result.append(str(item).strip())
if len(result) != len(set(result)):
raise ValueError(
f"{where}: 'partition_by' contains duplicate column names."
)
return result
raise ValueError(
f"{where}: 'partition_by' must be a string or list of strings."
)
def _parse_max_partitions(
raw_value: Any, where: str, *, allow_none: bool = False
) -> Optional[int]:
"""Parse a ``max_partitions`` value from YAML.
Returns a positive integer. When ``allow_none`` is True (used for
per-cluster entries), an omitted key returns ``None`` to signal
"inherit from folder level".
"""
if raw_value is None:
return None if allow_none else 10_000
try:
value = int(raw_value)
except (TypeError, ValueError):
raise ValueError(
f"{where}: 'max_partitions' must be a positive integer, "
f"got {raw_value!r}"
)
if value <= 0:
raise ValueError(
f"{where}: 'max_partitions' must be a positive integer, "
f"got {value}"
)
return value
def _validate_partition_vs_columns(
partition_by: List[str],
exclude: Optional[List[str]],
where: str,
) -> None:
"""Raise if any ``partition_by`` column is in the ``exclude`` list."""
if not partition_by or exclude is None:
return
excluded_parts = [c for c in partition_by if c in exclude]
if excluded_parts:
raise ValueError(
f"{where}: 'exclude' removes partition_by columns: {excluded_parts}"
)
def _parse_indexes(
raw_value: Any, where: str, *, allow_none: bool = False
) -> Optional[List[str]]:
"""Parse an ``indexes`` value from YAML.
Returns a list of non-empty, unique column name strings. When
``allow_none`` is True (used for per-cluster entries), an omitted key
returns ``None`` to signal "inherit from folder level". An explicit
empty list ``[]`` always returns ``[]``.
"""
if raw_value is None:
return None if allow_none else []
if isinstance(raw_value, str):
if not raw_value.strip():
raise ValueError(f"{where}: 'indexes' string must be non-empty.")
return [raw_value.strip()]
if isinstance(raw_value, list):
if len(raw_value) == 0:
return []
result: List[str] = []
for i, item in enumerate(raw_value):
if not isinstance(item, str) or not item.strip():
raise ValueError(
f"{where}: 'indexes[{i}]' must be a non-empty string."
)
result.append(str(item).strip())
if len(result) != len(set(result)):
raise ValueError(
f"{where}: 'indexes' contains duplicate column names."
)
return result
raise ValueError(
f"{where}: 'indexes' must be a string or list of strings."
)
def _validate_indexes_vs_columns(
indexes: List[str],
exclude: Optional[List[str]],
where: str,
) -> None:
"""Raise if any ``indexes`` column is in the ``exclude`` list."""
if not indexes or exclude is None:
return
excluded_idx = [c for c in indexes if c in exclude]
if excluded_idx:
raise ValueError(
f"{where}: 'exclude' removes index columns: {excluded_idx}"
)
def load_folder_config(path: Path) -> FolderConfig: def load_folder_config(path: Path) -> FolderConfig:
"""Parse and validate the folder-level YAML config at ``path``.""" """Parse and validate the folder-level YAML config at ``path``.
Supports optional ``partition_by`` and ``max_partitions`` at both the
folder level (defaults for all clusters) and per explicit cluster entry
(overrides the folder default).
"""
path = Path(path) path = Path(path)
with path.open("r", encoding="utf-8") as f: with path.open("r", encoding="utf-8") as f:
raw = yaml.safe_load(f) raw = yaml.safe_load(f)
@ -221,6 +399,19 @@ def load_folder_config(path: Path) -> FolderConfig:
include, exclude = _parse_columns_filter(raw, f"Config {path}") include, exclude = _parse_columns_filter(raw, f"Config {path}")
# -- folder-level partition settings ------------------------------------
partition_by = _parse_partition_by(
raw.get("partition_by"), f"Config {path}"
)
max_partitions = _parse_max_partitions(
raw.get("max_partitions"), f"Config {path}"
)
_validate_partition_vs_columns(partition_by, exclude, f"Config {path}")
# -- folder-level index settings ----------------------------------------
indexes = _parse_indexes(raw.get("indexes"), f"Config {path}")
_validate_indexes_vs_columns(indexes, exclude, f"Config {path}")
explicit: List[_ExplicitPattern] = [] explicit: List[_ExplicitPattern] = []
clusters_raw = raw.get("clusters") or [] clusters_raw = raw.get("clusters") or []
if not isinstance(clusters_raw, list): if not isinstance(clusters_raw, list):
@ -242,6 +433,26 @@ def load_folder_config(path: Path) -> FolderConfig:
else None else None
) )
c_include, c_exclude = _parse_columns_filter(entry, where) c_include, c_exclude = _parse_columns_filter(entry, where)
# -- per-cluster partition settings ---------------------------------
c_partition_by = _parse_partition_by(
entry.get("partition_by"), where, allow_none=True
)
c_max_partitions = _parse_max_partitions(
entry.get("max_partitions"), where, allow_none=True
)
# Validate partition_by vs the effective exclude for this cluster.
effective_exclude = c_exclude if c_exclude is not None else exclude
effective_pb = c_partition_by if c_partition_by is not None else partition_by
_validate_partition_vs_columns(effective_pb, effective_exclude, where)
# -- per-cluster index settings -------------------------------------
c_indexes = _parse_indexes(
entry.get("indexes"), where, allow_none=True
)
effective_idx = c_indexes if c_indexes is not None else indexes
_validate_indexes_vs_columns(effective_idx, effective_exclude, where)
explicit.append( explicit.append(
_ExplicitPattern( _ExplicitPattern(
pattern=compiled, pattern=compiled,
@ -250,6 +461,9 @@ def load_folder_config(path: Path) -> FolderConfig:
if_exists=c_if_exists, if_exists=c_if_exists,
include=c_include, include=c_include,
exclude=c_exclude, exclude=c_exclude,
partition_by=c_partition_by,
max_partitions=c_max_partitions,
indexes=c_indexes,
) )
) )
@ -261,6 +475,9 @@ def load_folder_config(path: Path) -> FolderConfig:
include=include, include=include,
exclude=exclude, exclude=exclude,
explicit=explicit, explicit=explicit,
partition_by=partition_by,
max_partitions=max_partitions,
indexes=indexes,
) )
@ -300,6 +517,13 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
order; files matched by an earlier pattern are removed from the pool order; files matched by an earlier pattern are removed from the pool
before the next pattern runs. A file matching two patterns triggers a before the next pattern runs. A file matching two patterns triggers a
hard error (that's almost always a config bug). hard error (that's almost always a config bug).
Partition settings are resolved per cluster:
* For explicit clusters, ``partition_by`` / ``max_partitions`` from the
cluster entry override the folder defaults when present. ``None``
means "inherit"; an explicit ``[]`` disables partitioning.
* For auto-detected clusters, folder defaults are inherited directly.
""" """
if not cfg.folder.exists() or not cfg.folder.is_dir(): if not cfg.folder.exists() or not cfg.folder.is_dir():
raise FileNotFoundError(f"Folder not found or not a directory: {cfg.folder}") raise FileNotFoundError(f"Folder not found or not a directory: {cfg.folder}")
@ -320,6 +544,21 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
remaining = list(pool) remaining = list(pool)
for patt in cfg.explicit: for patt in cfg.explicit:
# Resolve partition_by: None = inherit folder, [] = disable, list = override
resolved_pb = (
patt.partition_by if patt.partition_by is not None
else cfg.partition_by
)
resolved_mp = (
patt.max_partitions if patt.max_partitions is not None
else cfg.max_partitions
)
# Resolve indexes: None = inherit folder, [] = disable, list = override
resolved_idx = (
patt.indexes if patt.indexes is not None
else cfg.indexes
)
matched = [f for f in remaining if patt.pattern.search(f.name)] matched = [f for f in remaining if patt.pattern.search(f.name)]
if not matched: if not matched:
# Not an error - the folder might legitimately not contain files # Not an error - the folder might legitimately not contain files
@ -333,6 +572,9 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
exclude=patt.exclude if patt.exclude is not None else cfg.exclude, exclude=patt.exclude if patt.exclude is not None else cfg.exclude,
source="explicit", source="explicit",
pattern=patt.raw_pattern, pattern=patt.raw_pattern,
partition_by=resolved_pb,
max_partitions=resolved_mp,
indexes=resolved_idx,
) )
) )
continue continue
@ -346,6 +588,9 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
exclude=patt.exclude if patt.exclude is not None else cfg.exclude, exclude=patt.exclude if patt.exclude is not None else cfg.exclude,
source="explicit", source="explicit",
pattern=patt.raw_pattern, pattern=patt.raw_pattern,
partition_by=resolved_pb,
max_partitions=resolved_mp,
indexes=resolved_idx,
) )
) )
@ -363,6 +608,9 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
include=cfg.include, include=cfg.include,
exclude=cfg.exclude, exclude=cfg.exclude,
source="auto", source="auto",
partition_by=cfg.partition_by,
max_partitions=cfg.max_partitions,
indexes=cfg.indexes,
) )
) )
@ -375,6 +623,7 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
def _infer_cluster_schema(path: Path, include, exclude): def _infer_cluster_schema(path: Path, include, exclude):
"""Infer the Postgres column schema from a SAS file preview."""
preview_df, meta = read_sas_preview(path) preview_df, meta = read_sas_preview(path)
preview_df = apply_column_filter(preview_df, include, exclude) preview_df = apply_column_filter(preview_df, include, exclude)
total_rows = getattr(meta, "number_rows", None) total_rows = getattr(meta, "number_rows", None)
@ -382,20 +631,103 @@ def _infer_cluster_schema(path: Path, include, exclude):
return columns return columns
def _discover_cluster_partitions(
    cluster: ClusterSpec,
    columns: Dict,
) -> dict:
    """Collect partition values from every file in ``cluster``.

    For each path in ``cluster.files`` the chunks produced by
    :func:`iter_sas_chunks` are passed through :func:`apply_column_filter`
    (using the cluster's ``include`` / ``exclude``) and handed, as a lazy
    generator, to :func:`discover_partition_values_chunked` together with
    ``cluster.partition_by`` and ``columns``. Each per-file result is folded
    into one accumulator via :func:`_merge_partition_trees`.

    Returns:
        The accumulator dict. It starts empty and is passed to
        ``_merge_partition_trees`` as the first argument; the merge's return
        value is ignored, so the helper presumably mutates it in place -
        confirm against ``load_sas``.
    """

    def _filtered_frames(sas_path):
        # Generator: nothing is read until the consumer starts iterating.
        for frame, _meta in iter_sas_chunks(sas_path):
            yield apply_column_filter(frame, cluster.include, cluster.exclude)

    combined: dict = {}
    for sas_path in cluster.files:
        per_file_tree = discover_partition_values_chunked(
            _filtered_frames(sas_path), cluster.partition_by, columns,
        )
        _merge_partition_trees(combined, per_file_tree)
    return combined
def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
"""Load every file in ``cluster`` into one table. Returns total rows loaded. """Load every file in ``cluster`` into one table. Returns total rows loaded.
The caller owns transaction boundaries. This function does NOT commit or When ``cluster.partition_by`` is non-empty, partition values are
roll back - :func:`main` does that per cluster so one bad cluster discovered across ALL files before table creation so the full partition
doesn't poison the rest of the run. tree exists before any data is copied.
Commits happen per chunk inside :func:`load_sas.copy_dataframes`. If a
file mid-cluster fails, earlier chunks - including chunks from earlier
files in the cluster - stay committed; only the in-flight chunk is
rolled back by :func:`main`.
""" """
if not cluster.files: if not cluster.files:
return 0 return 0
first, *rest = cluster.files first, *rest = cluster.files
first_columns = _infer_cluster_schema(first, cluster.include, cluster.exclude) first_columns = _infer_cluster_schema(first, cluster.include, cluster.exclude)
# -- Validate index columns early ---------------------------------------
if cluster.indexes:
missing_icols = [
c for c in cluster.indexes if c not in first_columns
]
if missing_icols:
raise ValueError(
f"cluster {cluster.tablename!r}: indexes references "
f"columns not present in the inferred schema: {missing_icols}"
)
# -- Partition support --------------------------------------------------
partition_values: Optional[dict] = None
if cluster.partition_by:
# Validate that all partition_by columns exist in the inferred schema.
missing_pcols = [
c for c in cluster.partition_by if c not in first_columns
]
if missing_pcols:
raise ValueError(
f"cluster {cluster.tablename!r}: partition_by references "
f"columns not present in the inferred schema: {missing_pcols}"
)
# Discover partition values across ALL files in the cluster.
# In append mode the partitions already exist, so skip the scan.
if cluster.if_exists == "append":
print(
" [info] append mode: skipping partition discovery "
"(partitions assumed to exist)",
file=sys.stderr,
)
else:
print(
f" discovering partition values across "
f"{len(cluster.files)} file(s)...",
file=sys.stderr,
)
partition_values = _discover_cluster_partitions(
cluster, first_columns,
)
total_parts = _count_partitions(partition_values)
print(
f" discovered {total_parts:,} partition table(s) "
f"across {len(cluster.partition_by)} level(s)",
file=sys.stderr,
)
create_table( create_table(
conn, schemaname, cluster.tablename, first_columns, cluster.if_exists conn, schemaname, cluster.tablename, first_columns, cluster.if_exists,
partition_by=cluster.partition_by or None,
partition_values=partition_values,
max_partitions=cluster.max_partitions,
) )
total = 0 total = 0
@ -407,14 +739,18 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
for path in rest: for path in rest:
columns = _infer_cluster_schema(path, cluster.include, cluster.exclude) columns = _infer_cluster_schema(path, cluster.include, cluster.exclude)
# Uses the same check that if_exists=append runs. A type mismatch or # Uses the same check that if_exists=append runs. A type mismatch or
# missing column aborts the cluster; the transaction rollback in # missing column aborts the cluster; because chunks commit as they
# main() keeps the table from ending up half-loaded. # load, earlier chunks in the cluster remain in the table.
assert_schema_compatible(conn, schemaname, cluster.tablename, columns) assert_schema_compatible(conn, schemaname, cluster.tablename, columns)
total += _stream_file( total += _stream_file(
conn, schemaname, cluster.tablename, path, columns, conn, schemaname, cluster.tablename, path, columns,
cluster.include, cluster.exclude, cluster.include, cluster.exclude,
) )
# -- Index support ------------------------------------------------------
if cluster.indexes:
create_indexes(conn, schemaname, cluster.tablename, cluster.indexes)
return total return total
@ -458,8 +794,10 @@ def _build_argparser() -> argparse.ArgumentParser:
"--dry-run", "--dry-run",
action="store_true", action="store_true",
help=( help=(
"Print discovered clusters and the inferred CREATE TABLE for " "Print discovered clusters and the inferred DDL for each "
"each; don't touch Postgres." "(CREATE TABLE plus partition DDL when applicable). For "
"partitioned clusters all files are scanned to discover "
"partition values. The database is never touched."
), ),
) )
p.add_argument( p.add_argument(
@ -486,9 +824,15 @@ def _describe_cluster(cluster: ClusterSpec) -> str:
if cluster.pattern: if cluster.pattern:
src += f" pattern={cluster.pattern!r}" src += f" pattern={cluster.pattern!r}"
files = ", ".join(f.name for f in cluster.files) or "(no matching files)" files = ", ".join(f.name for f in cluster.files) or "(no matching files)"
parts = ""
if cluster.partition_by:
parts = f"\n partition_by: {cluster.partition_by}"
idx = ""
if cluster.indexes:
idx = f"\n indexes: {cluster.indexes}"
return ( return (
f"cluster {cluster.tablename!r} [{src}] if_exists={cluster.if_exists}\n" f"cluster {cluster.tablename!r} [{src}] if_exists={cluster.if_exists}\n"
f" files: {files}" f" files: {files}{parts}{idx}"
) )
@ -521,9 +865,68 @@ def main(argv: Optional[List[str]] = None) -> int:
if args.dry_run: if args.dry_run:
print() print()
for c in loadable: for c in loadable:
print(f"--- CREATE TABLE for cluster {c.tablename!r} ---") print(f"--- DDL for cluster {c.tablename!r} ---")
columns = _infer_cluster_schema(c.files[0], c.include, c.exclude) columns = _infer_cluster_schema(c.files[0], c.include, c.exclude)
print(render_create_table(cfg.schemaname, c.tablename, columns)) # Print parent CREATE TABLE (with PARTITION BY if applicable).
print(
render_create_table(
cfg.schemaname, c.tablename, columns,
partition_by=c.partition_by or None,
)
)
# Print child partition DDL when the cluster is partitioned.
if c.partition_by:
# Validate partition columns exist in the schema.
missing_pcols = [
col for col in c.partition_by if col not in columns
]
if missing_pcols:
print(
f" [error] partition_by references columns not in "
f"schema: {missing_pcols}",
file=sys.stderr,
)
else:
print(
f" discovering partition values across "
f"{len(c.files)} file(s)...",
file=sys.stderr,
)
partition_values = _discover_cluster_partitions(
c, columns,
)
total_parts = _count_partitions(partition_values)
print(
f" discovered {total_parts:,} partition table(s) "
f"across {len(c.partition_by)} level(s)",
file=sys.stderr,
)
child_stmts = render_partition_ddl(
cfg.schemaname, c.tablename, c.partition_by,
partition_values, columns,
max_partitions=c.max_partitions,
)
for stmt in child_stmts:
print()
print(stmt)
# Print CREATE INDEX DDL when the cluster has indexes.
if c.indexes:
missing_icols = [
col for col in c.indexes if col not in columns
]
if missing_icols:
print(
f" [error] indexes references columns not in "
f"schema: {missing_icols}",
file=sys.stderr,
)
else:
idx_stmts = render_create_indexes(
cfg.schemaname, c.tablename, c.indexes,
)
for stmt in idx_stmts:
print()
print(stmt)
print() print()
return 0 return 0

View File

@ -194,8 +194,8 @@ will fail mid-stream and the whole transaction rolls back. Set
matters more than speed. matters more than speed.
Streaming loads use :func:`iter_sas_chunks` + :func:`copy_dataframes`, which Streaming loads use :func:`iter_sas_chunks` + :func:`copy_dataframes`, which
share one cursor and transaction so a failure mid-file rolls back the whole commit each chunk as it is copied so an interrupted load retains the rows
load. that were already written.
7. Tunables 7. Tunables
----------- -----------
@ -217,9 +217,13 @@ from __future__ import annotations
import argparse import argparse
import datetime as dt import datetime as dt
import getpass import getpass
import hashlib
import io import io
import json import json
import logging
import math
import os import os
import re
import sys import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@ -233,6 +237,9 @@ import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Top-level tunables # Top-level tunables
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -263,6 +270,9 @@ gentler on memory."""
VALID_IF_EXISTS = ("fail", "replace", "append") VALID_IF_EXISTS = ("fail", "replace", "append")
_PG_IDENT_MAX_LEN = 63
"""PostgreSQL maximum identifier length in bytes (characters for ASCII)."""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Dataclasses # Dataclasses
@ -277,6 +287,9 @@ class LoaderConfig:
if_exists: str = "fail" if_exists: str = "fail"
include: Optional[List[str]] = None include: Optional[List[str]] = None
exclude: Optional[List[str]] = None exclude: Optional[List[str]] = None
partition_by: List[str] = field(default_factory=list)
max_partitions: int = 10_000
indexes: List[str] = field(default_factory=list)
@dataclass @dataclass
@ -384,6 +397,109 @@ def load_config(path: Path) -> LoaderConfig:
if exclude is not None and not isinstance(exclude, list): if exclude is not None and not isinstance(exclude, list):
raise ValueError(f"Config {path}: 'exclude' must be a list of column names.") raise ValueError(f"Config {path}: 'exclude' must be a list of column names.")
# -- partition_by -------------------------------------------------------
raw_pb = raw.get("partition_by")
if raw_pb is None or (isinstance(raw_pb, list) and len(raw_pb) == 0):
partition_by: List[str] = []
elif isinstance(raw_pb, str):
if not raw_pb.strip():
raise ValueError(f"Config {path}: 'partition_by' string must be non-empty.")
partition_by = [raw_pb.strip()]
elif isinstance(raw_pb, list):
partition_by = []
for i, item in enumerate(raw_pb):
if not isinstance(item, str) or not item.strip():
raise ValueError(
f"Config {path}: 'partition_by[{i}]' must be a non-empty string."
)
partition_by.append(str(item).strip())
if len(partition_by) != len(set(partition_by)):
raise ValueError(
f"Config {path}: 'partition_by' contains duplicate column names."
)
else:
raise ValueError(
f"Config {path}: 'partition_by' must be a string or list of strings."
)
# Validate partition_by vs include/exclude
if partition_by:
inc_list = [str(c) for c in include] if include is not None else None
exc_list = [str(c) for c in exclude] if exclude is not None else None
if inc_list is not None:
missing_in_include = [c for c in partition_by if c not in inc_list]
if missing_in_include:
raise ValueError(
f"Config {path}: 'include' omits partition_by columns: "
f"{missing_in_include}"
)
if exc_list is not None:
excluded_parts = [c for c in partition_by if c in exc_list]
if excluded_parts:
raise ValueError(
f"Config {path}: 'exclude' removes partition_by columns: "
f"{excluded_parts}"
)
# -- max_partitions -----------------------------------------------------
raw_mp = raw.get("max_partitions", 10_000)
try:
max_partitions = int(raw_mp)
except (TypeError, ValueError):
raise ValueError(
f"Config {path}: 'max_partitions' must be a positive integer, "
f"got {raw_mp!r}"
)
if max_partitions <= 0:
raise ValueError(
f"Config {path}: 'max_partitions' must be a positive integer, "
f"got {max_partitions}"
)
# -- indexes ------------------------------------------------------------
raw_idx = raw.get("indexes")
if raw_idx is None or (isinstance(raw_idx, list) and len(raw_idx) == 0):
indexes: List[str] = []
elif isinstance(raw_idx, str):
if not raw_idx.strip():
raise ValueError(f"Config {path}: 'indexes' string must be non-empty.")
indexes = [raw_idx.strip()]
elif isinstance(raw_idx, list):
indexes = []
for i, item in enumerate(raw_idx):
if not isinstance(item, str) or not item.strip():
raise ValueError(
f"Config {path}: 'indexes[{i}]' must be a non-empty string."
)
indexes.append(str(item).strip())
if len(indexes) != len(set(indexes)):
raise ValueError(
f"Config {path}: 'indexes' contains duplicate column names."
)
else:
raise ValueError(
f"Config {path}: 'indexes' must be a string or list of strings."
)
# Validate indexes vs include/exclude
if indexes:
inc_list = [str(c) for c in include] if include is not None else None
exc_list = [str(c) for c in exclude] if exclude is not None else None
if exc_list is not None:
excluded_idx = [c for c in indexes if c in exc_list]
if excluded_idx:
raise ValueError(
f"Config {path}: 'exclude' removes index columns: "
f"{excluded_idx}"
)
if inc_list is not None:
missing_in_include = [c for c in indexes if c not in inc_list]
if missing_in_include:
raise ValueError(
f"Config {path}: 'include' omits index columns: "
f"{missing_in_include}"
)
return LoaderConfig( return LoaderConfig(
filename=filename, filename=filename,
schemaname=schemaname, schemaname=schemaname,
@ -391,6 +507,9 @@ def load_config(path: Path) -> LoaderConfig:
if_exists=if_exists, if_exists=if_exists,
include=[str(c) for c in include] if include is not None else None, include=[str(c) for c in include] if include is not None else None,
exclude=[str(c) for c in exclude] if exclude is not None else None, exclude=[str(c) for c in exclude] if exclude is not None else None,
partition_by=partition_by,
max_partitions=max_partitions,
indexes=indexes,
) )
@ -753,24 +872,48 @@ def _table_exists(conn, schema: str, table: str) -> bool:
return cur.fetchone() is not None return cur.fetchone() is not None
def render_create_table(
    schema: str,
    table: str,
    columns: Dict[str, ColumnSpec],
    *,
    partition_by: Optional[List[str]] = None,
) -> str:
    """Build the ``CREATE TABLE`` statement text for ``schema.table``.

    One column definition is emitted per value of ``columns`` (quoted name,
    ``postgres_type``, and ``NOT NULL`` when the spec is not nullable).
    When ``partition_by`` is non-empty, a ``PARTITION BY LIST`` clause on
    its first entry is appended; later entries are not used here.
    """
    # NOTE(review): the whitespace inside the column-line literal is copied
    # as it appears in the reviewed text - confirm against the real file.
    column_defs = [
        f" {_quote_ident(spec.name)} {spec.postgres_type}"
        f"{'' if spec.nullable else ' NOT NULL'}"
        for spec in columns.values()
    ]
    body = ",\n".join(column_defs)
    suffix = (
        f"\nPARTITION BY LIST ({_quote_ident(partition_by[0])})"
        if partition_by
        else ""
    )
    return f"CREATE TABLE {_qualified(schema, table)} (\n{body}\n){suffix};"
def _create_table_sql(
    conn,
    schema: str,
    table: str,
    columns: Dict[str, ColumnSpec],
    *,
    partition_by: Optional[List[str]] = None,
) -> None:
    """Render the ``CREATE TABLE`` statement and run it on ``conn``.

    The statement text comes from :func:`render_create_table`, with
    ``partition_by`` forwarded unchanged. It is executed on a cursor opened
    from ``conn``; this function itself does not commit.
    """
    statement = render_create_table(
        schema, table, columns, partition_by=partition_by
    )
    with conn.cursor() as cursor:
        cursor.execute(statement)
def _drop_table(conn, schema: str, table: str, *, cascade: bool = False) -> None:
    """Issue ``DROP TABLE`` for ``schema.table`` on a fresh cursor.

    When ``cascade`` is true the statement carries a trailing ``CASCADE``.
    No commit is issued here.
    """
    with conn.cursor() as cursor:
        if cascade:
            cursor.execute(f"DROP TABLE {_qualified(schema, table)} CASCADE")
        else:
            cursor.execute(f"DROP TABLE {_qualified(schema, table)}")
# Normalization table: map both loader-emitted and Postgres-reported type # Normalization table: map both loader-emitted and Postgres-reported type
@ -815,8 +958,6 @@ def _normalize_type(pg_type: str) -> str:
stripped = pg_type.strip().upper() stripped = pg_type.strip().upper()
# Remove trailing (n) / (p,s) before the space-separated tail. # Remove trailing (n) / (p,s) before the space-separated tail.
# Examples: "VARCHAR(10)" -> "VARCHAR"; "TIMESTAMP(6) WITHOUT TIME ZONE" -> "TIMESTAMP WITHOUT TIME ZONE" # Examples: "VARCHAR(10)" -> "VARCHAR"; "TIMESTAMP(6) WITHOUT TIME ZONE" -> "TIMESTAMP WITHOUT TIME ZONE"
import re
stripped = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", stripped).strip() stripped = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", stripped).strip()
# Collapse doubled whitespace after paren removal. # Collapse doubled whitespace after paren removal.
stripped = re.sub(r"\s+", " ", stripped) stripped = re.sub(r"\s+", " ", stripped)
@ -893,11 +1034,28 @@ def create_table(
table_name: str, table_name: str,
columns: Dict[str, ColumnSpec], columns: Dict[str, ColumnSpec],
if_exists: str, if_exists: str,
*,
partition_by: Optional[List[str]] = None,
partition_values: Optional[dict] = None,
max_partitions: int = 10_000,
) -> None: ) -> None:
"""Create (or verify) the target table according to ``if_exists``.""" """Create (or verify) the target table according to ``if_exists``.
When ``partition_by`` is provided and non-empty, the parent table is
created with ``PARTITION BY LIST`` and all child partition DDL from
:func:`render_partition_ddl` is executed immediately after.
For ``replace`` mode the existing table is dropped with ``CASCADE`` so
all child partitions are removed automatically.
For ``append`` mode partition creation is skipped entirely the
partitions are assumed to already exist from the original creation.
"""
if if_exists not in VALID_IF_EXISTS: if if_exists not in VALID_IF_EXISTS:
raise ValueError(f"if_exists must be one of {VALID_IF_EXISTS}, got {if_exists!r}") raise ValueError(f"if_exists must be one of {VALID_IF_EXISTS}, got {if_exists!r}")
is_partitioned = bool(partition_by)
exists = _table_exists(conn, schema_name, table_name) exists = _table_exists(conn, schema_name, table_name)
if exists: if exists:
if if_exists == "fail": if if_exists == "fail":
@ -905,14 +1063,577 @@ def create_table(
f"Table {schema_name}.{table_name} already exists and if_exists=fail" f"Table {schema_name}.{table_name} already exists and if_exists=fail"
) )
if if_exists == "replace": if if_exists == "replace":
_drop_table(conn, schema_name, table_name) _drop_table(conn, schema_name, table_name, cascade=is_partitioned)
_create_table_sql(conn, schema_name, table_name, columns) _create_table_sql(
conn, schema_name, table_name, columns,
partition_by=partition_by,
)
if is_partitioned and partition_values is not None:
ddl_stmts = render_partition_ddl(
schema_name, table_name, partition_by, partition_values,
columns, max_partitions=max_partitions,
)
with conn.cursor() as cur:
for stmt in ddl_stmts:
cur.execute(stmt)
return return
if if_exists == "append": if if_exists == "append":
_assert_schema_compatible(conn, schema_name, table_name, columns) _assert_schema_compatible(conn, schema_name, table_name, columns)
return return
else: else:
_create_table_sql(conn, schema_name, table_name, columns) _create_table_sql(
conn, schema_name, table_name, columns,
partition_by=partition_by,
)
if is_partitioned and partition_values is not None:
ddl_stmts = render_partition_ddl(
schema_name, table_name, partition_by, partition_values,
columns, max_partitions=max_partitions,
)
with conn.cursor() as cur:
for stmt in ddl_stmts:
cur.execute(stmt)
# ---------------------------------------------------------------------------
# Partition support
# ---------------------------------------------------------------------------
def _sanitize_partition_value(value: Any, parent_table: str = "") -> str:
"""Convert a partition value into a safe, deterministic table-name suffix.
Rules:
- Convert to string, lowercase
- Replace non-alphanumeric runs with ``_``
- Collapse consecutive underscores, strip leading/trailing ``_``
- None/NaN ``null``; empty string ``empty``
- Truncate to fit within PostgreSQL's 63-character identifier limit
accounting for ``parent_table`` + ``_`` separator
"""
if value is None or (isinstance(value, float) and (pd.isna(value) or math.isnan(value))):
token = "null"
elif isinstance(value, dt.date) or isinstance(value, dt.datetime):
token = value.isoformat()
elif isinstance(value, dt.time):
token = value.isoformat()
else:
token = str(value)
token = token.lower()
token = re.sub(r"[^a-z0-9]+", "_", token)
token = re.sub(r"_+", "_", token)
token = token.strip("_")
if not token:
if value is None or (isinstance(value, float) and pd.isna(value)):
token = "null"
elif isinstance(value, str) and value == "":
token = "empty"
else:
token = "value"
# Truncate to keep total table name within PG's 63-char limit.
if parent_table:
# Reserve room for parent + underscore separator.
max_token_len = _PG_IDENT_MAX_LEN - len(parent_table) - 1
if max_token_len < 1:
raise ValueError(
f"Parent table name {parent_table!r} is too long "
f"({len(parent_table)} chars) to create child partitions."
)
if len(token) > max_token_len:
token = token[:max_token_len].rstrip("_")
return token
def _render_partition_value_literal(value: Any, pg_type: str) -> str:
"""Render a Python value as a SQL literal for ``FOR VALUES IN (...)``.
- None/NaN ``NULL``
- Strings single-quoted with ``'`` escaped to ``''``
- Numbers plain numeric literal
- Booleans ``TRUE`` / ``FALSE``
- Dates ``DATE 'YYYY-MM-DD'``
- Timestamps ``TIMESTAMP 'YYYY-MM-DD HH:MM:SS'``
- Times ``TIME 'HH:MM:SS'``
"""
if value is None or (isinstance(value, float) and pd.isna(value)):
return "NULL"
pg_upper = pg_type.upper()
if pg_upper in ("BOOLEAN", "BOOL"):
return "TRUE" if value else "FALSE"
if pg_upper in ("INTEGER", "BIGINT", "SMALLINT", "INT", "INT4", "INT8", "INT2"):
return str(int(value))
if pg_upper in ("DOUBLE PRECISION", "REAL", "NUMERIC", "DECIMAL",
"FLOAT4", "FLOAT8"):
return str(value)
if pg_upper == "DATE":
if isinstance(value, (dt.date, dt.datetime)):
return f"DATE '{value.isoformat()}'"
return f"DATE '{value}'"
if pg_upper in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE",
"TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ"):
if isinstance(value, (dt.datetime, pd.Timestamp)):
return f"TIMESTAMP '{value.isoformat()}'"
if isinstance(value, dt.date):
return f"TIMESTAMP '{dt.datetime(value.year, value.month, value.day).isoformat()}'"
return f"TIMESTAMP '{value}'"
if pg_upper in ("TIME", "TIME WITHOUT TIME ZONE",
"TIME WITH TIME ZONE", "TIMETZ"):
if isinstance(value, dt.time):
return f"TIME '{value.isoformat()}'"
return f"TIME '{value}'"
# Default: treat as text — single-quote with escaping.
escaped = str(value).replace("'", "''")
return f"'{escaped}'"
def _normalize_partition_value(value: Any, pg_type: str) -> Any:
"""Normalize a raw partition value to its Python-native form.
Applies the same semantic normalization that :func:`_prepare_for_copy`
uses, so partition discovery deduplicates on the routed value rather
than the raw source representation.
"""
# Handle pandas null types
if value is None:
return None
if isinstance(value, float) and (pd.isna(value) or math.isnan(value)):
return None
try:
if pd.isna(value):
return None
except (TypeError, ValueError):
pass
pg_upper = pg_type.upper()
if pg_upper in ("INTEGER", "BIGINT", "SMALLINT", "INT", "INT4", "INT8", "INT2"):
if isinstance(value, str):
value = value.strip()
if value == "":
return None
try:
return int(float(value))
except (TypeError, ValueError):
return None
if pg_upper in ("DOUBLE PRECISION", "REAL", "NUMERIC", "DECIMAL",
"FLOAT4", "FLOAT8"):
if isinstance(value, str):
value = value.strip()
if value == "":
return None
try:
result = float(value)
return None if math.isnan(result) else result
except (TypeError, ValueError):
return None
if pg_upper == "DATE":
if isinstance(value, dt.datetime):
return value.date()
if isinstance(value, dt.date):
return value
if isinstance(value, str):
if value.strip() == "":
return None
try:
return dt.date.fromisoformat(value.strip())
except (ValueError, TypeError):
return None
return None
if pg_upper in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE",
"TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ"):
if isinstance(value, dt.datetime):
return value
if isinstance(value, pd.Timestamp):
return value.to_pydatetime() if not pd.isna(value) else None
if isinstance(value, dt.date):
return dt.datetime(value.year, value.month, value.day)
if isinstance(value, str):
if value.strip() == "":
return None
try:
return dt.datetime.fromisoformat(value.strip())
except (ValueError, TypeError):
return None
return None
if pg_upper in ("TIME", "TIME WITHOUT TIME ZONE",
"TIME WITH TIME ZONE", "TIMETZ"):
return _seconds_to_time(value)
if pg_upper in ("BOOLEAN", "BOOL"):
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return bool(value)
if isinstance(value, str):
return value.strip().lower() in ("true", "1", "t", "yes")
return None
# Text-like types: None, pandas nulls, and '' all become None
# because copy_dataframes() sends empty strings with NULL ''.
if pg_upper in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER", "BPCHAR"):
if isinstance(value, str):
if value == "":
return None
return value
return str(value)
# Fallback: return as-is converted to native Python type
if hasattr(value, "item"):
return value.item()
return value
def discover_partition_values(
    df: pd.DataFrame,
    partition_by: list[str],
    columns: Optional[Dict[str, ColumnSpec]] = None,
) -> dict:
    """Build a nested structure of unique partition values from a DataFrame.

    For ``partition_by = ['state', 'zip']`` returns::

        {
            'MO': {'63101': {}, '63102': {}},
            'IL': {'62001': {}, '62002': {}}
        }

    When ``columns`` is provided, each value of a field present in it is
    passed through :func:`_normalize_partition_value` with that column's
    ``postgres_type``, so raw values that normalize to the same key (for
    example ``""`` and ``None`` in a text column) share one branch and all
    of their child values are kept.

    None/NaN values are included as a distinct partition value (``None`` key).
    Numeric values come back as Python scalars; datetime64 columns yield
    ``pd.Timestamp`` keys rather than raw integer nanoseconds.

    Args:
        df: Frame holding at least the ``partition_by`` columns.
        partition_by: Column names, outermost partition level first.
        columns: Optional column specs used to normalize values.

    Returns:
        Nested dict keyed by partition value; leaves are empty dicts.  An
        empty ``partition_by`` or an empty frame yields ``{}``.

    Raises:
        KeyError: if a ``partition_by`` column is missing from ``df``.
    """
    if not partition_by:
        return {}

    fields = list(partition_by)

    def _to_native(val: Any) -> Any:
        """Convert NaN to None and numpy scalars to Python native types."""
        if val is None:
            return None
        if isinstance(val, float) and pd.isna(val):
            return None
        if hasattr(val, "item"):
            return val.item()
        return val

    def _routed(raw_val: Any, field: str) -> Any:
        """Return the tree key for ``raw_val`` in column ``field``."""
        val = _to_native(raw_val)
        if columns and field in columns:
            val = _normalize_partition_value(val, columns[field].postgres_type)
        return val

    tree: dict = {}
    # Reduce the frame to its distinct raw combinations in one vectorized
    # step (NaNs compare equal here), then normalize only those rows.  This
    # replaces a per-distinct-value Python scan over every row at each level.
    distinct = df[fields].drop_duplicates()
    # itertuples iterates each column as a Series, which boxes datetime64
    # values as pd.Timestamp and numeric values as Python scalars.
    for combo in distinct.itertuples(index=False, name=None):
        node = tree
        for field, raw_val in zip(fields, combo):
            node = node.setdefault(_routed(raw_val, field), {})
    return tree
def discover_partition_values_chunked(
    chunk_iter: Iterable[pd.DataFrame],
    partition_by: list[str],
    columns: Optional[Dict[str, ColumnSpec]] = None,
) -> dict:
    """Collect the partition tree for data delivered as DataFrame chunks.

    Each non-empty chunk is narrowed to the partition columns, turned into a
    tree by :func:`discover_partition_values`, and folded into one combined
    tree with :func:`_merge_partition_trees`, so only one chunk is processed
    at a time.

    Returns ``{}`` without consuming ``chunk_iter`` when ``partition_by`` is
    empty.

    Raises:
        ValueError: if a non-empty chunk lacks any ``partition_by`` column.
    """
    combined: dict = {}
    if not partition_by:
        return combined

    for frame in chunk_iter:
        if frame.empty:
            continue
        absent = [name for name in partition_by if name not in frame.columns]
        if absent:
            raise ValueError(
                f"Partition columns not found in data: {absent}"
            )
        # Carry only the partition columns forward to minimize memory.
        narrowed = frame[list(partition_by)]
        _merge_partition_trees(
            combined,
            discover_partition_values(narrowed, partition_by, columns),
        )
    return combined
def _merge_partition_trees(target: dict, source: dict) -> None:
"""Merge ``source`` partition tree into ``target`` in place.
Both trees are nested dicts where keys are partition values and values
are either empty dicts (leaf) or nested dicts (intermediate levels).
"""
for key, subtree in source.items():
if key not in target:
target[key] = subtree
else:
# Merge children recursively
if subtree and target[key]:
_merge_partition_trees(target[key], subtree)
elif subtree:
target[key] = subtree
def _count_partitions(tree: dict) -> int:
"""Count total partition tables in a nested partition tree."""
count = 0
for _key, children in tree.items():
count += 1
if children:
count += _count_partitions(children)
return count
def render_partition_ddl(
    schema: str,
    parent_table: str,
    partition_by: list[str],
    partition_values: dict,
    column_specs: Dict[str, ColumnSpec],
    *,
    max_partitions: int = 10_000,
) -> list[str]:
    """Produce the child-partition DDL for a discovered partition tree.

    The statements come back in depth-first order, ready to execute after
    the parent ``CREATE TABLE`` (which is rendered separately by
    :func:`render_create_table` and is not part of the result).

    ``max_partitions`` is advisory only: when the tree holds more tables
    than that, a warning is logged and echoed to stderr, and rendering
    proceeds regardless.

    Returns an empty list when ``partition_by`` or ``partition_values`` is
    empty.
    """
    statements: list[str] = []
    if partition_by and partition_values:
        partition_count = _count_partitions(partition_values)
        if partition_count > max_partitions:
            logger.warning(
                "Partition count (%d) exceeds threshold (%d). "
                "This may impact database performance.",
                partition_count, max_partitions,
            )
            print(
                f"[warn] partition plan for {schema}.{parent_table} will create "
                f"{partition_count:,} partition tables, exceeding max_partitions={max_partitions:,}",
                file=sys.stderr,
            )
        # The recursive helper appends to ``statements`` in place.
        _render_partition_ddl_recursive(
            schema, parent_table, partition_by, partition_values,
            column_specs, 0, statements,
        )
    return statements
def _render_partition_ddl_recursive(
schema: str,
parent_table: str,
partition_by: list[str],
values: dict,
column_specs: Dict[str, ColumnSpec],
depth: int,
statements: list[str],
) -> None:
"""Recursively generate partition DDL statements (depth-first)."""
field_name = partition_by[depth]
next_field = partition_by[depth + 1] if depth + 1 < len(partition_by) else None
field_spec = column_specs.get(field_name)
pg_type = field_spec.postgres_type if field_spec else "TEXT"
# Track names used at this level under this parent to handle collisions
used_names: Dict[str, Any] = {}
# Sort values deterministically: None first, then by string representation
def _sort_key(val: Any) -> Tuple[int, str]:
if val is None:
return (0, "")
return (1, str(val))
sorted_values = sorted(values.keys(), key=_sort_key)
for val in sorted_values:
children = values[val]
token = _sanitize_partition_value(val, parent_table)
child_name = f"{parent_table}_{token}"
# Handle collisions
if child_name in used_names and used_names[child_name] is not val:
# Append a short hash of the value to disambiguate
val_hash = hashlib.sha256(repr(val).encode()).hexdigest()[:8]
# Re-truncate token to make room for _hash
max_token_len = _PG_IDENT_MAX_LEN - len(parent_table) - 1 - 9 # _hash8
if max_token_len < 1:
max_token_len = 1
truncated_token = token[:max_token_len].rstrip("_")
child_name = f"{parent_table}_{truncated_token}_{val_hash}"
# Final length check
if len(child_name) > _PG_IDENT_MAX_LEN:
child_name = child_name[:_PG_IDENT_MAX_LEN]
used_names[child_name] = val
literal = _render_partition_value_literal(val, pg_type)
if next_field is not None:
# Intermediate partition: itself partitioned by the next field
stmt = (
f"CREATE TABLE {_qualified(schema, child_name)} "
f"PARTITION OF {_qualified(schema, parent_table)} "
f"FOR VALUES IN ({literal}) "
f"PARTITION BY LIST ({_quote_ident(next_field)});"
)
statements.append(stmt)
# Recurse into children
if children:
_render_partition_ddl_recursive(
schema, child_name, partition_by, children,
column_specs, depth + 1, statements,
)
else:
# Leaf partition
stmt = (
f"CREATE TABLE {_qualified(schema, child_name)} "
f"PARTITION OF {_qualified(schema, parent_table)} "
f"FOR VALUES IN ({literal});"
)
statements.append(stmt)
# ---------------------------------------------------------------------------
# Index support
# ---------------------------------------------------------------------------
def render_create_indexes(
    schema: str,
    tablename: str,
    indexes: List[str],
) -> List[str]:
    """Build one ``CREATE INDEX IF NOT EXISTS`` statement per listed column.

    Each statement indexes a single column of ``schema.tablename``.  The
    index is named ``ix_{tablename}_{column}`` from the raw names and quoted
    with :func:`_quote_ident`.  A name longer than the PostgreSQL identifier
    limit is cut short and suffixed with ``_`` plus the first 8 hex digits of
    the SHA-256 of the full name, keeping the result within the limit.

    Returns the statements in the order of ``indexes``.
    """

    def _index_name(column: str) -> str:
        full_name = f"ix_{tablename}_{column}"
        if len(full_name) <= _PG_IDENT_MAX_LEN:
            return full_name
        digest = hashlib.sha256(full_name.encode()).hexdigest()[:8]
        # Reserve 9 characters: one underscore plus the 8 hash characters.
        head = full_name[: _PG_IDENT_MAX_LEN - 9].rstrip("_")
        return f"{head}_{digest}"

    return [
        f"CREATE INDEX IF NOT EXISTS {_quote_ident(_index_name(column))} "
        f"ON {_qualified(schema, tablename)} ({_quote_ident(column)});"
        for column in indexes
    ]
def create_indexes(
    conn,
    schema: str,
    tablename: str,
    indexes: List[str],
) -> None:
    """Create the configured single-column indexes, one transaction each.

    The DDL comes from :func:`render_create_indexes`.  Every statement is
    executed and committed on its own, and the outcome is reported on
    stderr.  A failing statement is rolled back and reported as a warning;
    the loop then carries on with the remaining indexes, so this function
    does not propagate per-index errors.

    The progress messages always show the untruncated
    ``ix_{tablename}_{column}`` form of the index name.
    """
    ddl_statements = render_create_indexes(schema, tablename, indexes)
    with conn.cursor() as cursor:
        for ddl, column in zip(ddl_statements, indexes):
            try:
                cursor.execute(ddl)
                conn.commit()
                print(
                    f"[info] created index ix_{tablename}_{column} "
                    f"on {schema}.{tablename}({column})",
                    file=sys.stderr,
                )
            except Exception as exc:
                # Best effort by design: undo the failed statement and move on.
                conn.rollback()
                print(
                    f"[warn] failed to create index ix_{tablename}_{column} "
                    f"on {schema}.{tablename}({column}): {exc}",
                    file=sys.stderr,
                )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -1032,10 +1753,12 @@ def copy_dataframes(
dfs: Iterable[pd.DataFrame], dfs: Iterable[pd.DataFrame],
columns: Dict[str, ColumnSpec], columns: Dict[str, ColumnSpec],
) -> int: ) -> int:
"""Stream an iterable of DataFrames into one ``COPY`` session. """Stream an iterable of DataFrames into Postgres, committing each chunk.
All chunks share a cursor and transaction, so a failure mid-stream Each non-empty chunk is copied via ``COPY ... FROM STDIN`` and committed
rolls back the whole load when the caller hasn't committed yet. before the next chunk is processed, so an interrupted or failed load
retains the rows from previously committed chunks. The first chunk's
commit also flushes any pending DDL (e.g. a preceding ``CREATE TABLE``).
Empty chunks are skipped. Returns the total rows inserted. Empty chunks are skipped. Returns the total rows inserted.
""" """
col_list = ", ".join(_quote_ident(name) for name in columns.keys()) col_list = ", ".join(_quote_ident(name) for name in columns.keys())
@ -1060,6 +1783,7 @@ def copy_dataframes(
) )
buf.seek(0) buf.seek(0)
cur.copy_expert(sql, buf) cur.copy_expert(sql, buf)
conn.commit()
total += len(prepared) total += len(prepared)
return total return total
@ -1205,6 +1929,24 @@ def main(argv: Optional[List[str]] = None) -> int:
preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude) preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
columns = infer_schema(preview_df, meta) columns = infer_schema(preview_df, meta)
# Validate partition columns exist in the schema after filtering.
if cfg.partition_by:
missing_pcols = [c for c in cfg.partition_by if c not in columns]
if missing_pcols:
raise ValueError(
f"partition_by references columns not present in the "
f"(filtered) schema: {missing_pcols}"
)
# Validate index columns exist in the schema after filtering.
if cfg.indexes:
missing_icols = [c for c in cfg.indexes if c not in columns]
if missing_icols:
raise ValueError(
f"indexes references columns not present in the "
f"(filtered) schema: {missing_icols}"
)
if args.validate: if args.validate:
manifest_path = cfg.filename.with_suffix("").with_suffix(".expected.json") manifest_path = cfg.filename.with_suffix("").with_suffix(".expected.json")
# The above strips .xpt then appends .expected.json, e.g. # The above strips .xpt then appends .expected.json, e.g.
@ -1217,8 +1959,59 @@ def main(argv: Optional[List[str]] = None) -> int:
return 1 return 1
print(f"validation OK ({len(columns)} columns match {manifest_path.name})") print(f"validation OK ({len(columns)} columns match {manifest_path.name})")
# -- Partition value discovery ------------------------------------------
# If partitioned, scan the ENTIRE file to discover all unique partition
# values. The preview is only the first N rows and may miss values.
# In append mode the partitions already exist, so skip the costly scan.
partition_values: Optional[dict] = None
if cfg.partition_by and cfg.if_exists != "append":
print(" discovering partition values (full file scan)...", file=sys.stderr)
def _discovery_chunks():
for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename):
yield apply_column_filter(chunk_df, cfg.include, cfg.exclude)
partition_values = discover_partition_values_chunked(
_discovery_chunks(), cfg.partition_by, columns,
)
total_parts = _count_partitions(partition_values)
print(
f" discovered {total_parts:,} partition tables "
f"across {len(cfg.partition_by)} level(s)",
file=sys.stderr,
)
elif cfg.partition_by and cfg.if_exists == "append":
print(
" [info] append mode: skipping partition discovery "
"(partitions assumed to exist)",
file=sys.stderr,
)
if args.dry_run: if args.dry_run:
print(render_create_table(cfg.schemaname, cfg.tablename, columns)) # Print the parent CREATE TABLE (with PARTITION BY if applicable).
parent_ddl = render_create_table(
cfg.schemaname, cfg.tablename, columns,
partition_by=cfg.partition_by or None,
)
print(parent_ddl)
# Print child partition DDL if partitioned.
if cfg.partition_by and partition_values:
child_stmts = render_partition_ddl(
cfg.schemaname, cfg.tablename, cfg.partition_by,
partition_values, columns,
max_partitions=cfg.max_partitions,
)
for stmt in child_stmts:
print()
print(stmt)
# Print CREATE INDEX DDL if indexes are configured.
if cfg.indexes:
idx_stmts = render_create_indexes(
cfg.schemaname, cfg.tablename, cfg.indexes,
)
for stmt in idx_stmts:
print()
print(stmt)
return 0 return 0
# Release the preview frame before opening the stream - lets the GC reclaim # Release the preview frame before opening the stream - lets the GC reclaim
@ -1241,11 +2034,18 @@ def main(argv: Optional[List[str]] = None) -> int:
conn = connect(user=db_user, password=db_password) conn = connect(user=db_user, password=db_password)
conn.autocommit = False conn.autocommit = False
try: try:
create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists) create_table(
conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists,
partition_by=cfg.partition_by or None,
partition_values=partition_values,
max_partitions=cfg.max_partitions,
)
inserted = copy_dataframes( inserted = copy_dataframes(
conn, cfg.schemaname, cfg.tablename, _filtered_chunks(), columns conn, cfg.schemaname, cfg.tablename, _filtered_chunks(), columns
) )
conn.commit() conn.commit()
if cfg.indexes:
create_indexes(conn, cfg.schemaname, cfg.tablename, cfg.indexes)
except Exception: except Exception:
conn.rollback() conn.rollback()
raise raise
@ -1256,6 +2056,9 @@ def main(argv: Optional[List[str]] = None) -> int:
f"loaded {inserted} rows into {cfg.schemaname}.{cfg.tablename} " f"loaded {inserted} rows into {cfg.schemaname}.{cfg.tablename} "
f"({len(columns)} columns)" f"({len(columns)} columns)"
) )
if cfg.partition_by and partition_values:
total_parts = _count_partitions(partition_values)
print(f"partitioned by {cfg.partition_by} ({total_parts:,} partition tables)")
print("final schema:") print("final schema:")
print(_format_columns_summary(columns)) print(_format_columns_summary(columns))
return 0 return 0

View File

@ -15,3 +15,26 @@ tablename: kitchensink
# What to do if the target table already exists: fail | replace | append # What to do if the target table already exists: fail | replace | append
# Defaults to fail. # Defaults to fail.
if_exists: append if_exists: append
# partition_by: Partition the table by unique values of these columns.
# Columns are applied in cascading order (first column = top-level partition).
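# Child partitions are only created together with the table, so the initial
# load must use if_exists: replace or fail. With append, partition discovery
# is skipped and the partitions are assumed to exist already.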
# Single field:
# partition_by: state
# Multiple fields (cascading):
# partition_by:
# - state
# - zip
#
# max_partitions: Warning threshold for total partition count (default: 10000).
# If the number of partitions exceeds this, a warning is logged but loading continues.
# max_partitions: 10000
# indexes: Create B-tree indexes on these columns after data loading.
# Indexes are created with IF NOT EXISTS for safe use with append mode.
# Single column:
# indexes: state
# Multiple columns (one index per column):
# indexes:
# - state
# - zip

View File

@ -31,6 +31,30 @@ auto_detect: true
# exclude: # exclude:
# - ALLNULL # - ALLNULL
# Folder-level partition_by: Partition every cluster's table by unique values
# of these columns. Inherited by all clusters unless overridden per-cluster.
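# Child partitions are only created together with the table, so the initial
# load must use if_exists: replace or fail. With append, partition discovery
# is skipped and the partitions are assumed to exist already.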
# Single field:
# partition_by: state
# Multiple fields (cascading):
# partition_by:
# - state
# - zip
#
# Folder-level max_partitions: Warning threshold for total partition count
# (default: 10000). Inherited by all clusters unless overridden per-cluster.
# max_partitions: 10000
# Folder-level indexes: Create B-tree indexes on these columns after data
# loading. Inherited by all clusters unless overridden per-cluster.
# Indexes are created with IF NOT EXISTS for safe use with append mode.
# Single column:
# indexes: state
# Multiple columns (one index per column):
# indexes:
# - state
# - zip
# Explicit cluster patterns. Each pattern is matched against the file # Explicit cluster patterns. Each pattern is matched against the file
# *basename*. Files matched by a pattern are pulled out of the auto-detect # *basename*. Files matched by a pattern are pulled out of the auto-detect
# pool, so explicit and auto clusters compose cleanly. # pool, so explicit and auto clusters compose cleanly.
@ -48,6 +72,26 @@ clusters:
# tablename: group_b # tablename: group_b
# if_exists: append # if_exists: append
# Per-cluster partition_by / max_partitions override. These take precedence
# over the folder-level defaults above.
#
# - pattern: '^group_c\d+\.xpt$'
# tablename: group_c
# partition_by:
# - region
# - year
# max_partitions: 500
# Per-cluster indexes override. Takes precedence over the folder-level
# indexes default above. An explicit empty list disables indexing for
# this cluster even when the folder default has indexes.
#
# - pattern: '^group_d\d+\.xpt$'
# tablename: group_d
# indexes:
# - region
# - year
# With only the gq pattern explicit, auto_detect: true will still bucket # With only the gq pattern explicit, auto_detect: true will still bucket
# group_b1.xpt + group_b2.xpt into a "group_b" cluster and the lone # group_b1.xpt + group_b2.xpt into a "group_b" cluster and the lone
# standalone.xpt into a "standalone" cluster. See generate_sample_folder.py # standalone.xpt into a "standalone" cluster. See generate_sample_folder.py