diff --git a/generic_loader/load_folder.py b/generic_loader/load_folder.py index 9301786..5136fc1 100644 --- a/generic_loader/load_folder.py +++ b/generic_loader/load_folder.py @@ -32,9 +32,19 @@ USAGE # include: [ID, INTCOL] # exclude: [ALLNULL] + # Optional folder default for LIST partitioning. Omit or set [] for no + # partitioning. Accepts a single string or a list of column names. + # partition_by: + # - state + # - zip + + # Optional folder default threshold. Default: 10000. + # max_partitions: 10000 + # Optional explicit cluster patterns. Each pattern is matched against the # file *basename*. Matched files are pulled out of the auto-detect pool. - # Per-cluster if_exists/include/exclude override the folder-level defaults. + # Per-cluster if_exists/include/exclude/partition_by/max_partitions + # override the folder-level defaults. clusters: - pattern: '^group_a\\d+\\.sas7bdat$' tablename: group_a @@ -51,9 +61,10 @@ USAGE Flags: --config PATH Required. Path to the YAML config above. - --dry-run Print the discovered clusters and the inferred CREATE - TABLE for each (schema from the first file of the - cluster). The database is never touched. + --dry-run Print the discovered clusters and the inferred DDL for + each (CREATE TABLE plus partition DDL when applicable). + For partitioned clusters all files are scanned to + discover partition values. The database is never touched. --fail-fast Abort the whole run on the first cluster failure. Default is to log the failure, roll that cluster back, and keep going. 
@@ -113,15 +124,21 @@ from dotenv import load_dotenv from load_sas import ( VALID_IF_EXISTS, + _count_partitions, + _merge_partition_trees, apply_column_filter, assert_schema_compatible, connect, copy_dataframes, + create_indexes, create_table, + discover_partition_values_chunked, infer_schema, iter_sas_chunks, read_sas_preview, + render_create_indexes, render_create_table, + render_partition_ddl, ) @@ -135,6 +152,13 @@ SAS_EXTENSIONS = (".sas7bdat", ".xpt", ".xport") @dataclass class ClusterSpec: + """Resolved per-cluster load settings. + + ``partition_by``, ``max_partitions``, and ``indexes`` are resolved from + the folder defaults and any per-cluster overrides during + :func:`discover_clusters`. + """ + tablename: str files: List[Path] if_exists: str @@ -142,11 +166,20 @@ class ClusterSpec: exclude: Optional[List[str]] source: str # "explicit" or "auto" pattern: Optional[str] = None + partition_by: List[str] = field(default_factory=list) + max_partitions: int = 10_000 + indexes: List[str] = field(default_factory=list) @dataclass class _ExplicitPattern: - """Parsed form of a single ``clusters[*]`` YAML entry.""" + """Parsed form of a single ``clusters[*]`` YAML entry. + + ``partition_by`` defaults to ``None`` meaning "inherit from folder level". + An explicit empty list ``[]`` means "disable partitioning for this cluster". + ``max_partitions`` defaults to ``None`` meaning "inherit from folder level". + ``indexes`` defaults to ``None`` meaning "inherit from folder level". + """ pattern: re.Pattern raw_pattern: str @@ -154,10 +187,19 @@ class _ExplicitPattern: if_exists: Optional[str] = None include: Optional[List[str]] = None exclude: Optional[List[str]] = None + partition_by: Optional[List[str]] = None + max_partitions: Optional[int] = None + indexes: Optional[List[str]] = None @dataclass class FolderConfig: + """Folder-level configuration parsed from YAML. 
+ + ``partition_by``, ``max_partitions``, and ``indexes`` serve as defaults + for every cluster unless overridden at the cluster level. + """ + folder: Path schemaname: str if_exists: str = "fail" @@ -165,6 +207,9 @@ class FolderConfig: include: Optional[List[str]] = None exclude: Optional[List[str]] = None explicit: List[_ExplicitPattern] = field(default_factory=list) + partition_by: List[str] = field(default_factory=list) + max_partitions: int = 10_000 + indexes: List[str] = field(default_factory=list) # --------------------------------------------------------------------------- @@ -197,8 +242,141 @@ def _parse_columns_filter( return include_out, exclude_out +def _parse_partition_by( + raw_value: Any, where: str, *, allow_none: bool = False +) -> Optional[List[str]]: + """Parse a ``partition_by`` value from YAML. + + Returns a list of non-empty, unique column name strings. When + ``allow_none`` is True (used for per-cluster entries), an omitted key + returns ``None`` to signal "inherit from folder level". An explicit + empty list ``[]`` always returns ``[]``. + """ + if raw_value is None: + return None if allow_none else [] + if isinstance(raw_value, str): + if not raw_value.strip(): + raise ValueError(f"{where}: 'partition_by' string must be non-empty.") + return [raw_value.strip()] + if isinstance(raw_value, list): + if len(raw_value) == 0: + return [] + result: List[str] = [] + for i, item in enumerate(raw_value): + if not isinstance(item, str) or not item.strip(): + raise ValueError( + f"{where}: 'partition_by[{i}]' must be a non-empty string." + ) + result.append(str(item).strip()) + if len(result) != len(set(result)): + raise ValueError( + f"{where}: 'partition_by' contains duplicate column names." + ) + return result + raise ValueError( + f"{where}: 'partition_by' must be a string or list of strings." 
+ ) + + +def _parse_max_partitions( + raw_value: Any, where: str, *, allow_none: bool = False +) -> Optional[int]: + """Parse a ``max_partitions`` value from YAML. + + Returns a positive integer. When ``allow_none`` is True (used for + per-cluster entries), an omitted key returns ``None`` to signal + "inherit from folder level". + """ + if raw_value is None: + return None if allow_none else 10_000 + try: + value = int(raw_value) + except (TypeError, ValueError): + raise ValueError( + f"{where}: 'max_partitions' must be a positive integer, " + f"got {raw_value!r}" + ) + if value <= 0: + raise ValueError( + f"{where}: 'max_partitions' must be a positive integer, " + f"got {value}" + ) + return value + + +def _validate_partition_vs_columns( + partition_by: List[str], + exclude: Optional[List[str]], + where: str, +) -> None: + """Raise if any ``partition_by`` column is in the ``exclude`` list.""" + if not partition_by or exclude is None: + return + excluded_parts = [c for c in partition_by if c in exclude] + if excluded_parts: + raise ValueError( + f"{where}: 'exclude' removes partition_by columns: {excluded_parts}" + ) + + +def _parse_indexes( + raw_value: Any, where: str, *, allow_none: bool = False +) -> Optional[List[str]]: + """Parse an ``indexes`` value from YAML. + + Returns a list of non-empty, unique column name strings. When + ``allow_none`` is True (used for per-cluster entries), an omitted key + returns ``None`` to signal "inherit from folder level". An explicit + empty list ``[]`` always returns ``[]``. 
+ """ + if raw_value is None: + return None if allow_none else [] + if isinstance(raw_value, str): + if not raw_value.strip(): + raise ValueError(f"{where}: 'indexes' string must be non-empty.") + return [raw_value.strip()] + if isinstance(raw_value, list): + if len(raw_value) == 0: + return [] + result: List[str] = [] + for i, item in enumerate(raw_value): + if not isinstance(item, str) or not item.strip(): + raise ValueError( + f"{where}: 'indexes[{i}]' must be a non-empty string." + ) + result.append(str(item).strip()) + if len(result) != len(set(result)): + raise ValueError( + f"{where}: 'indexes' contains duplicate column names." + ) + return result + raise ValueError( + f"{where}: 'indexes' must be a string or list of strings." + ) + + +def _validate_indexes_vs_columns( + indexes: List[str], + exclude: Optional[List[str]], + where: str, +) -> None: + """Raise if any ``indexes`` column is in the ``exclude`` list.""" + if not indexes or exclude is None: + return + excluded_idx = [c for c in indexes if c in exclude] + if excluded_idx: + raise ValueError( + f"{where}: 'exclude' removes index columns: {excluded_idx}" + ) + + def load_folder_config(path: Path) -> FolderConfig: - """Parse and validate the folder-level YAML config at ``path``.""" + """Parse and validate the folder-level YAML config at ``path``. + + Supports optional ``partition_by`` and ``max_partitions`` at both the + folder level (defaults for all clusters) and per explicit cluster entry + (overrides the folder default). 
+ """ path = Path(path) with path.open("r", encoding="utf-8") as f: raw = yaml.safe_load(f) @@ -221,6 +399,19 @@ def load_folder_config(path: Path) -> FolderConfig: include, exclude = _parse_columns_filter(raw, f"Config {path}") + # -- folder-level partition settings ------------------------------------ + partition_by = _parse_partition_by( + raw.get("partition_by"), f"Config {path}" + ) + max_partitions = _parse_max_partitions( + raw.get("max_partitions"), f"Config {path}" + ) + _validate_partition_vs_columns(partition_by, exclude, f"Config {path}") + + # -- folder-level index settings ---------------------------------------- + indexes = _parse_indexes(raw.get("indexes"), f"Config {path}") + _validate_indexes_vs_columns(indexes, exclude, f"Config {path}") + explicit: List[_ExplicitPattern] = [] clusters_raw = raw.get("clusters") or [] if not isinstance(clusters_raw, list): @@ -242,6 +433,26 @@ def load_folder_config(path: Path) -> FolderConfig: else None ) c_include, c_exclude = _parse_columns_filter(entry, where) + + # -- per-cluster partition settings --------------------------------- + c_partition_by = _parse_partition_by( + entry.get("partition_by"), where, allow_none=True + ) + c_max_partitions = _parse_max_partitions( + entry.get("max_partitions"), where, allow_none=True + ) + # Validate partition_by vs the effective exclude for this cluster. 
+ effective_exclude = c_exclude if c_exclude is not None else exclude + effective_pb = c_partition_by if c_partition_by is not None else partition_by + _validate_partition_vs_columns(effective_pb, effective_exclude, where) + + # -- per-cluster index settings ------------------------------------- + c_indexes = _parse_indexes( + entry.get("indexes"), where, allow_none=True + ) + effective_idx = c_indexes if c_indexes is not None else indexes + _validate_indexes_vs_columns(effective_idx, effective_exclude, where) + explicit.append( _ExplicitPattern( pattern=compiled, @@ -250,6 +461,9 @@ def load_folder_config(path: Path) -> FolderConfig: if_exists=c_if_exists, include=c_include, exclude=c_exclude, + partition_by=c_partition_by, + max_partitions=c_max_partitions, + indexes=c_indexes, ) ) @@ -261,6 +475,9 @@ def load_folder_config(path: Path) -> FolderConfig: include=include, exclude=exclude, explicit=explicit, + partition_by=partition_by, + max_partitions=max_partitions, + indexes=indexes, ) @@ -300,6 +517,13 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: order; files matched by an earlier pattern are removed from the pool before the next pattern runs. A file matching two patterns triggers a hard error (that's almost always a config bug). + + Partition settings are resolved per cluster: + + * For explicit clusters, ``partition_by`` / ``max_partitions`` from the + cluster entry override the folder defaults when present. ``None`` + means "inherit"; an explicit ``[]`` disables partitioning. + * For auto-detected clusters, folder defaults are inherited directly. 
""" if not cfg.folder.exists() or not cfg.folder.is_dir(): raise FileNotFoundError(f"Folder not found or not a directory: {cfg.folder}") @@ -320,6 +544,21 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: remaining = list(pool) for patt in cfg.explicit: + # Resolve partition_by: None = inherit folder, [] = disable, list = override + resolved_pb = ( + patt.partition_by if patt.partition_by is not None + else cfg.partition_by + ) + resolved_mp = ( + patt.max_partitions if patt.max_partitions is not None + else cfg.max_partitions + ) + # Resolve indexes: None = inherit folder, [] = disable, list = override + resolved_idx = ( + patt.indexes if patt.indexes is not None + else cfg.indexes + ) + matched = [f for f in remaining if patt.pattern.search(f.name)] if not matched: # Not an error - the folder might legitimately not contain files @@ -333,6 +572,9 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: exclude=patt.exclude if patt.exclude is not None else cfg.exclude, source="explicit", pattern=patt.raw_pattern, + partition_by=resolved_pb, + max_partitions=resolved_mp, + indexes=resolved_idx, ) ) continue @@ -346,6 +588,9 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: exclude=patt.exclude if patt.exclude is not None else cfg.exclude, source="explicit", pattern=patt.raw_pattern, + partition_by=resolved_pb, + max_partitions=resolved_mp, + indexes=resolved_idx, ) ) @@ -363,6 +608,9 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: include=cfg.include, exclude=cfg.exclude, source="auto", + partition_by=cfg.partition_by, + max_partitions=cfg.max_partitions, + indexes=cfg.indexes, ) ) @@ -375,6 +623,7 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: def _infer_cluster_schema(path: Path, include, exclude): + """Infer the Postgres column schema from a SAS file preview.""" preview_df, meta = read_sas_preview(path) preview_df = apply_column_filter(preview_df, include, exclude) total_rows = 
getattr(meta, "number_rows", None) @@ -382,20 +631,103 @@ def _infer_cluster_schema(path: Path, include, exclude): return columns +def _discover_cluster_partitions( + cluster: ClusterSpec, + columns: Dict, +) -> dict: + """Scan ALL files in ``cluster`` to discover partition values. + + Returns a nested partition-value tree suitable for passing to + :func:`load_sas.render_partition_ddl` and :func:`load_sas.create_table`. + Each file is scanned chunk-by-chunk so the full dataset is never + materialized in memory. + """ + merged: dict = {} + for path in cluster.files: + def _filtered_chunks(p=path): + for chunk_df, _chunk_meta in iter_sas_chunks(p): + yield apply_column_filter( + chunk_df, cluster.include, cluster.exclude + ) + + file_tree = discover_partition_values_chunked( + _filtered_chunks(), cluster.partition_by, columns, + ) + _merge_partition_trees(merged, file_tree) + return merged + + def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: """Load every file in ``cluster`` into one table. Returns total rows loaded. - The caller owns transaction boundaries. This function does NOT commit or - roll back - :func:`main` does that per cluster so one bad cluster - doesn't poison the rest of the run. + When ``cluster.partition_by`` is non-empty, partition values are + discovered across ALL files before table creation so the full partition + tree exists before any data is copied. + + Commits happen per chunk inside :func:`load_sas.copy_dataframes`. If a + file mid-cluster fails, earlier chunks - including chunks from earlier + files in the cluster - stay committed; only the in-flight chunk is + rolled back by :func:`main`. 
""" if not cluster.files: return 0 first, *rest = cluster.files first_columns = _infer_cluster_schema(first, cluster.include, cluster.exclude) + + # -- Validate index columns early --------------------------------------- + if cluster.indexes: + missing_icols = [ + c for c in cluster.indexes if c not in first_columns + ] + if missing_icols: + raise ValueError( + f"cluster {cluster.tablename!r}: indexes references " + f"columns not present in the inferred schema: {missing_icols}" + ) + + # -- Partition support -------------------------------------------------- + partition_values: Optional[dict] = None + if cluster.partition_by: + # Validate that all partition_by columns exist in the inferred schema. + missing_pcols = [ + c for c in cluster.partition_by if c not in first_columns + ] + if missing_pcols: + raise ValueError( + f"cluster {cluster.tablename!r}: partition_by references " + f"columns not present in the inferred schema: {missing_pcols}" + ) + + # Discover partition values across ALL files in the cluster. + # In append mode the partitions already exist, so skip the scan. 
+ if cluster.if_exists == "append": + print( + " [info] append mode: skipping partition discovery " + "(partitions assumed to exist)", + file=sys.stderr, + ) + else: + print( + f" discovering partition values across " + f"{len(cluster.files)} file(s)...", + file=sys.stderr, + ) + partition_values = _discover_cluster_partitions( + cluster, first_columns, + ) + total_parts = _count_partitions(partition_values) + print( + f" discovered {total_parts:,} partition table(s) " + f"across {len(cluster.partition_by)} level(s)", + file=sys.stderr, + ) + create_table( - conn, schemaname, cluster.tablename, first_columns, cluster.if_exists + conn, schemaname, cluster.tablename, first_columns, cluster.if_exists, + partition_by=cluster.partition_by or None, + partition_values=partition_values, + max_partitions=cluster.max_partitions, ) total = 0 @@ -407,14 +739,18 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: for path in rest: columns = _infer_cluster_schema(path, cluster.include, cluster.exclude) # Uses the same check that if_exists=append runs. A type mismatch or - # missing column aborts the cluster; the transaction rollback in - # main() keeps the table from ending up half-loaded. + # missing column aborts the cluster; because chunks commit as they + # load, earlier chunks in the cluster remain in the table. assert_schema_compatible(conn, schemaname, cluster.tablename, columns) total += _stream_file( conn, schemaname, cluster.tablename, path, columns, cluster.include, cluster.exclude, ) + # -- Index support ------------------------------------------------------ + if cluster.indexes: + create_indexes(conn, schemaname, cluster.tablename, cluster.indexes) + return total @@ -458,8 +794,10 @@ def _build_argparser() -> argparse.ArgumentParser: "--dry-run", action="store_true", help=( - "Print discovered clusters and the inferred CREATE TABLE for " - "each; don't touch Postgres." 
+ "Print discovered clusters and the inferred DDL for each " + "(CREATE TABLE plus partition DDL when applicable). For " + "partitioned clusters all files are scanned to discover " + "partition values. The database is never touched." ), ) p.add_argument( @@ -486,9 +824,15 @@ def _describe_cluster(cluster: ClusterSpec) -> str: if cluster.pattern: src += f" pattern={cluster.pattern!r}" files = ", ".join(f.name for f in cluster.files) or "(no matching files)" + parts = "" + if cluster.partition_by: + parts = f"\n partition_by: {cluster.partition_by}" + idx = "" + if cluster.indexes: + idx = f"\n indexes: {cluster.indexes}" return ( f"cluster {cluster.tablename!r} [{src}] if_exists={cluster.if_exists}\n" - f" files: {files}" + f" files: {files}{parts}{idx}" ) @@ -521,9 +865,68 @@ def main(argv: Optional[List[str]] = None) -> int: if args.dry_run: print() for c in loadable: - print(f"--- CREATE TABLE for cluster {c.tablename!r} ---") + print(f"--- DDL for cluster {c.tablename!r} ---") columns = _infer_cluster_schema(c.files[0], c.include, c.exclude) - print(render_create_table(cfg.schemaname, c.tablename, columns)) + # Print parent CREATE TABLE (with PARTITION BY if applicable). + print( + render_create_table( + cfg.schemaname, c.tablename, columns, + partition_by=c.partition_by or None, + ) + ) + # Print child partition DDL when the cluster is partitioned. + if c.partition_by: + # Validate partition columns exist in the schema. 
+ missing_pcols = [ + col for col in c.partition_by if col not in columns + ] + if missing_pcols: + print( + f" [error] partition_by references columns not in " + f"schema: {missing_pcols}", + file=sys.stderr, + ) + else: + print( + f" discovering partition values across " + f"{len(c.files)} file(s)...", + file=sys.stderr, + ) + partition_values = _discover_cluster_partitions( + c, columns, + ) + total_parts = _count_partitions(partition_values) + print( + f" discovered {total_parts:,} partition table(s) " + f"across {len(c.partition_by)} level(s)", + file=sys.stderr, + ) + child_stmts = render_partition_ddl( + cfg.schemaname, c.tablename, c.partition_by, + partition_values, columns, + max_partitions=c.max_partitions, + ) + for stmt in child_stmts: + print() + print(stmt) + # Print CREATE INDEX DDL when the cluster has indexes. + if c.indexes: + missing_icols = [ + col for col in c.indexes if col not in columns + ] + if missing_icols: + print( + f" [error] indexes references columns not in " + f"schema: {missing_icols}", + file=sys.stderr, + ) + else: + idx_stmts = render_create_indexes( + cfg.schemaname, c.tablename, c.indexes, + ) + for stmt in idx_stmts: + print() + print(stmt) print() return 0 diff --git a/generic_loader/load_sas.py b/generic_loader/load_sas.py index 83b4f28..74f5ff8 100644 --- a/generic_loader/load_sas.py +++ b/generic_loader/load_sas.py @@ -194,8 +194,8 @@ will fail mid-stream and the whole transaction rolls back. Set matters more than speed. Streaming loads use :func:`iter_sas_chunks` + :func:`copy_dataframes`, which -share one cursor and transaction so a failure mid-file rolls back the whole -load. +commit each chunk as it is copied so an interrupted load retains the rows +that were already written. 7. 
Tunables ----------- @@ -217,9 +217,13 @@ from __future__ import annotations import argparse import datetime as dt import getpass +import hashlib import io import json +import logging +import math import os +import re import sys from dataclasses import dataclass, field from pathlib import Path @@ -233,6 +237,9 @@ import yaml from dotenv import load_dotenv +logger = logging.getLogger(__name__) + + # --------------------------------------------------------------------------- # Top-level tunables # --------------------------------------------------------------------------- @@ -263,6 +270,9 @@ gentler on memory.""" VALID_IF_EXISTS = ("fail", "replace", "append") +_PG_IDENT_MAX_LEN = 63 +"""PostgreSQL maximum identifier length in bytes (characters for ASCII).""" + # --------------------------------------------------------------------------- # Dataclasses @@ -277,6 +287,9 @@ class LoaderConfig: if_exists: str = "fail" include: Optional[List[str]] = None exclude: Optional[List[str]] = None + partition_by: List[str] = field(default_factory=list) + max_partitions: int = 10_000 + indexes: List[str] = field(default_factory=list) @dataclass @@ -384,6 +397,109 @@ def load_config(path: Path) -> LoaderConfig: if exclude is not None and not isinstance(exclude, list): raise ValueError(f"Config {path}: 'exclude' must be a list of column names.") + # -- partition_by ------------------------------------------------------- + raw_pb = raw.get("partition_by") + if raw_pb is None or (isinstance(raw_pb, list) and len(raw_pb) == 0): + partition_by: List[str] = [] + elif isinstance(raw_pb, str): + if not raw_pb.strip(): + raise ValueError(f"Config {path}: 'partition_by' string must be non-empty.") + partition_by = [raw_pb.strip()] + elif isinstance(raw_pb, list): + partition_by = [] + for i, item in enumerate(raw_pb): + if not isinstance(item, str) or not item.strip(): + raise ValueError( + f"Config {path}: 'partition_by[{i}]' must be a non-empty string." 
+ ) + partition_by.append(str(item).strip()) + if len(partition_by) != len(set(partition_by)): + raise ValueError( + f"Config {path}: 'partition_by' contains duplicate column names." + ) + else: + raise ValueError( + f"Config {path}: 'partition_by' must be a string or list of strings." + ) + + # Validate partition_by vs include/exclude + if partition_by: + inc_list = [str(c) for c in include] if include is not None else None + exc_list = [str(c) for c in exclude] if exclude is not None else None + if inc_list is not None: + missing_in_include = [c for c in partition_by if c not in inc_list] + if missing_in_include: + raise ValueError( + f"Config {path}: 'include' omits partition_by columns: " + f"{missing_in_include}" + ) + if exc_list is not None: + excluded_parts = [c for c in partition_by if c in exc_list] + if excluded_parts: + raise ValueError( + f"Config {path}: 'exclude' removes partition_by columns: " + f"{excluded_parts}" + ) + + # -- max_partitions ----------------------------------------------------- + raw_mp = raw.get("max_partitions", 10_000) + try: + max_partitions = int(raw_mp) + except (TypeError, ValueError): + raise ValueError( + f"Config {path}: 'max_partitions' must be a positive integer, " + f"got {raw_mp!r}" + ) + if max_partitions <= 0: + raise ValueError( + f"Config {path}: 'max_partitions' must be a positive integer, " + f"got {max_partitions}" + ) + + # -- indexes ------------------------------------------------------------ + raw_idx = raw.get("indexes") + if raw_idx is None or (isinstance(raw_idx, list) and len(raw_idx) == 0): + indexes: List[str] = [] + elif isinstance(raw_idx, str): + if not raw_idx.strip(): + raise ValueError(f"Config {path}: 'indexes' string must be non-empty.") + indexes = [raw_idx.strip()] + elif isinstance(raw_idx, list): + indexes = [] + for i, item in enumerate(raw_idx): + if not isinstance(item, str) or not item.strip(): + raise ValueError( + f"Config {path}: 'indexes[{i}]' must be a non-empty string." 
+ ) + indexes.append(str(item).strip()) + if len(indexes) != len(set(indexes)): + raise ValueError( + f"Config {path}: 'indexes' contains duplicate column names." + ) + else: + raise ValueError( + f"Config {path}: 'indexes' must be a string or list of strings." + ) + + # Validate indexes vs include/exclude + if indexes: + inc_list = [str(c) for c in include] if include is not None else None + exc_list = [str(c) for c in exclude] if exclude is not None else None + if exc_list is not None: + excluded_idx = [c for c in indexes if c in exc_list] + if excluded_idx: + raise ValueError( + f"Config {path}: 'exclude' removes index columns: " + f"{excluded_idx}" + ) + if inc_list is not None: + missing_in_include = [c for c in indexes if c not in inc_list] + if missing_in_include: + raise ValueError( + f"Config {path}: 'include' omits index columns: " + f"{missing_in_include}" + ) + return LoaderConfig( filename=filename, schemaname=schemaname, @@ -391,6 +507,9 @@ def load_config(path: Path) -> LoaderConfig: if_exists=if_exists, include=[str(c) for c in include] if include is not None else None, exclude=[str(c) for c in exclude] if exclude is not None else None, + partition_by=partition_by, + max_partitions=max_partitions, + indexes=indexes, ) @@ -753,24 +872,48 @@ def _table_exists(conn, schema: str, table: str) -> bool: return cur.fetchone() is not None -def render_create_table(schema: str, table: str, columns: Dict[str, ColumnSpec]) -> str: +def render_create_table( + schema: str, + table: str, + columns: Dict[str, ColumnSpec], + *, + partition_by: Optional[List[str]] = None, +) -> str: + """Render a ``CREATE TABLE`` statement. + + When ``partition_by`` is provided and non-empty, appends a + ``PARTITION BY LIST ("first_field")`` clause to the statement. 
+ """ lines = [] for spec in columns.values(): null_clause = "" if spec.nullable else " NOT NULL" lines.append(f" {_quote_ident(spec.name)} {spec.postgres_type}{null_clause}") body = ",\n".join(lines) - return f"CREATE TABLE {_qualified(schema, table)} (\n{body}\n);" + suffix = "" + if partition_by: + suffix = f"\nPARTITION BY LIST ({_quote_ident(partition_by[0])})" + return f"CREATE TABLE {_qualified(schema, table)} (\n{body}\n){suffix};" -def _create_table_sql(conn, schema: str, table: str, columns: Dict[str, ColumnSpec]) -> None: - sql = render_create_table(schema, table, columns) +def _create_table_sql( + conn, + schema: str, + table: str, + columns: Dict[str, ColumnSpec], + *, + partition_by: Optional[List[str]] = None, +) -> None: + """Execute a ``CREATE TABLE`` statement, optionally with partitioning.""" + sql = render_create_table(schema, table, columns, partition_by=partition_by) with conn.cursor() as cur: cur.execute(sql) -def _drop_table(conn, schema: str, table: str) -> None: +def _drop_table(conn, schema: str, table: str, *, cascade: bool = False) -> None: + """Drop a table, optionally with CASCADE for partitioned tables.""" + tail = " CASCADE" if cascade else "" with conn.cursor() as cur: - cur.execute(f"DROP TABLE {_qualified(schema, table)}") + cur.execute(f"DROP TABLE {_qualified(schema, table)}{tail}") # Normalization table: map both loader-emitted and Postgres-reported type @@ -815,8 +958,6 @@ def _normalize_type(pg_type: str) -> str: stripped = pg_type.strip().upper() # Remove trailing (n) / (p,s) before the space-separated tail. # Examples: "VARCHAR(10)" -> "VARCHAR"; "TIMESTAMP(6) WITHOUT TIME ZONE" -> "TIMESTAMP WITHOUT TIME ZONE" - import re - stripped = re.sub(r"\(\s*\d+\s*(?:,\s*\d+\s*)?\)", "", stripped).strip() # Collapse doubled whitespace after paren removal. 
stripped = re.sub(r"\s+", " ", stripped) @@ -893,11 +1034,28 @@ def create_table( table_name: str, columns: Dict[str, ColumnSpec], if_exists: str, + *, + partition_by: Optional[List[str]] = None, + partition_values: Optional[dict] = None, + max_partitions: int = 10_000, ) -> None: - """Create (or verify) the target table according to ``if_exists``.""" + """Create (or verify) the target table according to ``if_exists``. + + When ``partition_by`` is provided and non-empty, the parent table is + created with ``PARTITION BY LIST`` and all child partition DDL from + :func:`render_partition_ddl` is executed immediately after. + + For ``replace`` mode the existing table is dropped with ``CASCADE`` so + all child partitions are removed automatically. + + For ``append`` mode partition creation is skipped entirely — the + partitions are assumed to already exist from the original creation. + """ if if_exists not in VALID_IF_EXISTS: raise ValueError(f"if_exists must be one of {VALID_IF_EXISTS}, got {if_exists!r}") + is_partitioned = bool(partition_by) + exists = _table_exists(conn, schema_name, table_name) if exists: if if_exists == "fail": @@ -905,14 +1063,577 @@ def create_table( f"Table {schema_name}.{table_name} already exists and if_exists=fail" ) if if_exists == "replace": - _drop_table(conn, schema_name, table_name) - _create_table_sql(conn, schema_name, table_name, columns) + _drop_table(conn, schema_name, table_name, cascade=is_partitioned) + _create_table_sql( + conn, schema_name, table_name, columns, + partition_by=partition_by, + ) + if is_partitioned and partition_values is not None: + ddl_stmts = render_partition_ddl( + schema_name, table_name, partition_by, partition_values, + columns, max_partitions=max_partitions, + ) + with conn.cursor() as cur: + for stmt in ddl_stmts: + cur.execute(stmt) return if if_exists == "append": _assert_schema_compatible(conn, schema_name, table_name, columns) return else: - _create_table_sql(conn, schema_name, table_name, columns) 
def _sanitize_partition_value(value: Any, parent_table: str = "") -> str:
    """Turn a partition value into a deterministic table-name suffix.

    The value is stringified (dates, datetimes and times via ``isoformat()``),
    lowercased, and every run of characters outside ``[a-z0-9]`` is replaced
    by a single ``_``; leading and trailing underscores are removed.

    Special cases:
    - ``None`` / float NaN produce ``null``.
    - An empty string produces ``empty``.
    - Any other value that sanitizes to nothing (e.g. ``"--"``) produces
      ``value``.

    When ``parent_table`` is given, the suffix is shortened so that
    ``{parent_table}_{suffix}`` fits in ``_PG_IDENT_MAX_LEN`` characters.

    Raises:
        ValueError: if ``parent_table`` leaves no room for even a
            one-character suffix.
    """
    missing = value is None or (isinstance(value, float) and pd.isna(value))
    if missing:
        text = "null"
    elif isinstance(value, (dt.date, dt.time)):
        # dt.datetime is a dt.date subclass, so it is covered here as well.
        text = value.isoformat()
    else:
        text = str(value)

    # Each maximal run of non-alphanumerics (underscores included) collapses
    # to one "_", so no separate "collapse repeated underscores" pass is needed.
    slug = re.sub(r"[^a-z0-9]+", "_", text.lower()).strip("_")

    if not slug:
        # "null" can never sanitize to nothing, so only two fallbacks remain.
        slug = "empty" if isinstance(value, str) and value == "" else "value"

    if not parent_table:
        return slug

    # Room left once the parent name and the "_" separator are accounted for.
    budget = _PG_IDENT_MAX_LEN - len(parent_table) - 1
    if budget < 1:
        raise ValueError(
            f"Parent table name {parent_table!r} is too long "
            f"({len(parent_table)} chars) to create child partitions."
        )
    if len(slug) > budget:
        slug = slug[:budget].rstrip("_")
    return slug
def _normalize_partition_value(value: Any, pg_type: str) -> Any:
    """Normalize a raw partition value to the Python value used as its key.

    Partition discovery deduplicates on the value returned here, so two raw
    representations that normalize to the same result share one partition.

    Args:
        value: Raw cell value (Python or numpy scalar, string, date, ...).
        pg_type: PostgreSQL type name of the column (case-insensitive).

    Returns:
        ``None`` for nulls and for values that cannot be interpreted as
        ``pg_type``; otherwise an ``int``, ``float``, ``date``, ``datetime``,
        ``bool`` or ``str`` depending on ``pg_type``. TIME values are
        delegated to ``_seconds_to_time``. Unknown types are returned
        unchanged, with numpy scalars unwrapped via ``.item()``.
    """
    if value is None:
        return None
    try:
        # Covers float NaN as well as pandas null scalars (NaT, NA).
        if pd.isna(value):
            return None
    except (TypeError, ValueError):
        # pd.isna() on a non-scalar has no single truth value; treat the
        # value as non-null and let the type branches below decide.
        pass

    pg_upper = pg_type.upper()

    if pg_upper in ("INTEGER", "BIGINT", "SMALLINT", "INT", "INT4", "INT8", "INT2"):
        if hasattr(value, "item"):
            value = value.item()  # numpy scalar -> Python scalar
        if isinstance(value, str):
            value = value.strip()
            if value == "":
                return None
            try:
                # Parse integer literals directly: a detour through float
                # would silently round anything above 2**53.
                return int(value)
            except ValueError:
                pass  # e.g. "12.0" — fall through to the float-based parse
        elif isinstance(value, int):
            return int(value)  # exact; also maps bool to 0/1
        try:
            return int(float(value))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: +/-inf has no integer representation.
            return None

    if pg_upper in ("DOUBLE PRECISION", "REAL", "NUMERIC", "DECIMAL",
                    "FLOAT4", "FLOAT8"):
        if isinstance(value, str):
            value = value.strip()
            if value == "":
                return None
        try:
            result = float(value)
        except (TypeError, ValueError, OverflowError):
            return None
        return None if math.isnan(result) else result

    if pg_upper == "DATE":
        # datetime must be tested first: it is a subclass of date.
        if isinstance(value, dt.datetime):
            return value.date()
        if isinstance(value, dt.date):
            return value
        if isinstance(value, str):
            if value.strip() == "":
                return None
            try:
                return dt.date.fromisoformat(value.strip())
            except (ValueError, TypeError):
                return None
        return None

    if pg_upper in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE",
                    "TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ"):
        if isinstance(value, dt.datetime):
            # pd.Timestamp subclasses datetime and is returned unchanged here
            # (a separate pd.Timestamp branch after this test is unreachable).
            return value
        if isinstance(value, dt.date):
            return dt.datetime(value.year, value.month, value.day)
        if isinstance(value, str):
            if value.strip() == "":
                return None
            try:
                return dt.datetime.fromisoformat(value.strip())
            except (ValueError, TypeError):
                return None
        return None

    if pg_upper in ("TIME", "TIME WITHOUT TIME ZONE",
                    "TIME WITH TIME ZONE", "TIMETZ"):
        return _seconds_to_time(value)

    if pg_upper in ("BOOLEAN", "BOOL"):
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return bool(value)
        if isinstance(value, str):
            return value.strip().lower() in ("true", "1", "t", "yes")
        return None

    # Text-like types: '' is folded into None because copy_dataframes()
    # sends empty strings with NULL ''.
    if pg_upper in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER", "BPCHAR"):
        if isinstance(value, str):
            return None if value == "" else value
        return str(value)

    # Unknown type: hand the value back, unwrapping numpy scalars.
    if hasattr(value, "item"):
        return value.item()
    return value


def discover_partition_values(
    df: pd.DataFrame,
    partition_by: list[str],
    columns: Optional[Dict[str, ColumnSpec]] = None,
) -> dict:
    """Build a nested structure of unique partition values from a DataFrame.

    For ``partition_by = ['state', 'zip']`` returns::

        {
            'MO': {'63101': {}, '63102': {}},
            'IL': {'62001': {}, '62002': {}}
        }

    Args:
        df: Frame containing every column named in ``partition_by``.
        partition_by: Partition columns, outermost level first.
        columns: Optional column specs. When a partition column is present,
            its values are passed through :func:`_normalize_partition_value`
            so keys are the normalized values rather than raw source values.

    Returns:
        Nested dict keyed by Python-native values (numpy scalars are
        unwrapped). Nulls share the single key ``None``. Leaves are empty
        dicts. An empty ``partition_by`` or an empty frame yields ``{}``.
    """
    if not partition_by:
        return {}

    def _routed(raw: Any, field: str) -> Any:
        """Return the native, normalized key for one raw cell of ``field``."""
        if raw is None or (isinstance(raw, float) and pd.isna(raw)):
            key = None
        elif hasattr(raw, "item"):
            key = raw.item()  # numpy scalar -> Python scalar
        else:
            key = raw
        if columns and field in columns:
            key = _normalize_partition_value(key, columns[field].postgres_type)
        return key

    def _matches(raw_val: Any, target: Any, field_name: str) -> bool:
        """Check whether a raw value normalizes to ``target``."""
        key = _routed(raw_val, field_name)
        if target is None:
            return key is None
        return key == target

    def _build_level(sub_df: pd.DataFrame, fields: list[str]) -> dict:
        if not fields or sub_df.empty:
            return {}

        field, remaining = fields[0], fields[1:]
        series = sub_df[field]
        result: dict = {}

        for raw_val in series.unique():
            val = _routed(raw_val, field)
            if val in result:
                # Another raw value already normalized to this key; its row
                # mask covered these rows too, so the subtree is complete.
                continue
            if not remaining:
                result[val] = {}
                continue

            # Select rows by their *normalized* value so every row that ends
            # up under this key contributes its child values.
            mask = series.map(lambda v, target=val: _matches(v, target, field))
            if val is None:
                # Also keep anything pandas itself reports as null.
                mask = series.isna() | mask
            result[val] = _build_level(sub_df[mask], remaining)

        return result

    return _build_level(df, list(partition_by))
def _merge_partition_trees(target: dict, source: dict) -> None:
    """Merge ``source`` partition tree into ``target`` in place.

    Both trees are nested dicts where keys are partition values and values
    are either empty dicts (leaf) or nested dicts (intermediate levels).

    ``source`` is never modified and none of its nested dicts are stored in
    ``target``: every node added to ``target`` is a fresh dict. Later merges
    into ``target`` therefore cannot leak back into a tree that was merged
    earlier.
    """
    for key, subtree in source.items():
        # setdefault() creates a new, target-owned node for unseen keys;
        # recursing on an empty subtree is a no-op.
        _merge_partition_trees(target.setdefault(key, {}), subtree)


def _count_partitions(tree: dict) -> int:
    """Count total partition tables in a nested partition tree.

    Every key at every level is one partition table, so the result is the
    number of nodes in the tree (0 for an empty tree).
    """
    return sum(1 + _count_partitions(children) for children in tree.values())
def _render_partition_ddl_recursive(
    schema: str,
    parent_table: str,
    partition_by: list[str],
    values: dict,
    column_specs: Dict[str, ColumnSpec],
    depth: int,
    statements: list[str],
    used_names: Optional[set] = None,
) -> None:
    """Recursively generate partition DDL statements (depth-first).

    Args:
        schema: Schema that holds the parent and all child tables.
        parent_table: Table the partitions at this level are attached to.
        partition_by: Partition columns, outermost level first.
        values: Partition tree for this level (value -> subtree).
        column_specs: Column specs; the partition column's ``postgres_type``
            selects the literal format (``TEXT`` when the column is unknown).
        depth: Index into ``partition_by`` for this level.
        statements: Output list; DDL strings are appended in execution order.
        used_names: Table names already claimed anywhere in this partition
            tree. Callers normally omit it; it is created on the outermost
            call and shared by every recursive call.

    Child tables are named ``{parent_table}_{sanitized value}``. Because the
    children of all levels live in the same schema, a name is checked against
    the whole tree, not only against its siblings: e.g. top-level value
    ``"a b"`` and the chain ``"a"`` -> ``"b"`` both sanitize to ``t_a_b``.
    A clashing name gets an 8-character hash of the value appended; the stem
    is shortened first so the hash always survives the identifier limit.
    """
    if used_names is None:
        used_names = {parent_table}

    field_name = partition_by[depth]
    next_field = partition_by[depth + 1] if depth + 1 < len(partition_by) else None
    field_spec = column_specs.get(field_name)
    pg_type = field_spec.postgres_type if field_spec else "TEXT"

    # Sort values deterministically: None first, then by string representation
    def _sort_key(val: Any) -> Tuple[int, str]:
        if val is None:
            return (0, "")
        return (1, str(val))

    for val in sorted(values, key=_sort_key):
        children = values[val]
        token = _sanitize_partition_value(val, parent_table)
        base_name = f"{parent_table}_{token}"

        child_name = base_name
        attempt = 0
        while child_name in used_names:
            # First attempt hashes repr(val) alone; later attempts (only
            # needed if the hashed name is itself taken) add a counter.
            seed = repr(val) if attempt == 0 else f"{val!r}#{attempt}"
            suffix = "_" + hashlib.sha256(seed.encode()).hexdigest()[:8]
            # Trim the stem, never the hash, to stay within the limit.
            stem = base_name[: _PG_IDENT_MAX_LEN - len(suffix)].rstrip("_")
            child_name = f"{stem}{suffix}"
            attempt += 1
        used_names.add(child_name)

        literal = _render_partition_value_literal(val, pg_type)
        head = (
            f"CREATE TABLE {_qualified(schema, child_name)} "
            f"PARTITION OF {_qualified(schema, parent_table)} "
            f"FOR VALUES IN ({literal})"
        )

        if next_field is None:
            # Leaf partition
            statements.append(f"{head};")
            continue

        # Intermediate partition: itself partitioned by the next field
        statements.append(
            f"{head} PARTITION BY LIST ({_quote_ident(next_field)});"
        )
        if children:
            _render_partition_ddl_recursive(
                schema, child_name, partition_by, children,
                column_specs, depth + 1, statements, used_names,
            )
def create_indexes(
    conn,
    schema: str,
    tablename: str,
    indexes: List[str],
) -> None:
    """Create one index per column in *indexes*, best-effort.

    The DDL comes from :func:`render_create_indexes`. Each statement is
    executed and committed on its own, and a line is written to stderr for
    every success or failure. A failing statement is rolled back and
    reported as a warning; the remaining indexes are still attempted and no
    exception is propagated.

    NOTE(review): the name shown in the log is always the plain
    ``ix_{tablename}_{column}`` form, which can differ from the name in the
    executed statement when :func:`render_create_indexes` shortens it.
    """
    ddl = render_create_indexes(schema, tablename, indexes)
    with conn.cursor() as cursor:
        for sql, column in zip(ddl, indexes):
            index_label = f"ix_{tablename}_{column}"
            location = f"{schema}.{tablename}({column})"
            try:
                cursor.execute(sql)
                # Commit per index so a later failure cannot undo this one.
                conn.commit()
                print(
                    f"[info] created index {index_label} on {location}",
                    file=sys.stderr,
                )
            except Exception as exc:
                # Clear the aborted transaction before the next statement.
                conn.rollback()
                print(
                    f"[warn] failed to create index {index_label} "
                    f"on {location}: {exc}",
                    file=sys.stderr,
                )
+ if cfg.partition_by: + missing_pcols = [c for c in cfg.partition_by if c not in columns] + if missing_pcols: + raise ValueError( + f"partition_by references columns not present in the " + f"(filtered) schema: {missing_pcols}" + ) + + # Validate index columns exist in the schema after filtering. + if cfg.indexes: + missing_icols = [c for c in cfg.indexes if c not in columns] + if missing_icols: + raise ValueError( + f"indexes references columns not present in the " + f"(filtered) schema: {missing_icols}" + ) + if args.validate: manifest_path = cfg.filename.with_suffix("").with_suffix(".expected.json") # The above strips .xpt then appends .expected.json, e.g. @@ -1217,8 +1959,59 @@ def main(argv: Optional[List[str]] = None) -> int: return 1 print(f"validation OK ({len(columns)} columns match {manifest_path.name})") + # -- Partition value discovery ------------------------------------------ + # If partitioned, scan the ENTIRE file to discover all unique partition + # values. The preview is only the first N rows and may miss values. + # In append mode the partitions already exist, so skip the costly scan. 
+ partition_values: Optional[dict] = None + if cfg.partition_by and cfg.if_exists != "append": + print(" discovering partition values (full file scan)...", file=sys.stderr) + + def _discovery_chunks(): + for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename): + yield apply_column_filter(chunk_df, cfg.include, cfg.exclude) + + partition_values = discover_partition_values_chunked( + _discovery_chunks(), cfg.partition_by, columns, + ) + total_parts = _count_partitions(partition_values) + print( + f" discovered {total_parts:,} partition tables " + f"across {len(cfg.partition_by)} level(s)", + file=sys.stderr, + ) + elif cfg.partition_by and cfg.if_exists == "append": + print( + " [info] append mode: skipping partition discovery " + "(partitions assumed to exist)", + file=sys.stderr, + ) + if args.dry_run: - print(render_create_table(cfg.schemaname, cfg.tablename, columns)) + # Print the parent CREATE TABLE (with PARTITION BY if applicable). + parent_ddl = render_create_table( + cfg.schemaname, cfg.tablename, columns, + partition_by=cfg.partition_by or None, + ) + print(parent_ddl) + # Print child partition DDL if partitioned. + if cfg.partition_by and partition_values: + child_stmts = render_partition_ddl( + cfg.schemaname, cfg.tablename, cfg.partition_by, + partition_values, columns, + max_partitions=cfg.max_partitions, + ) + for stmt in child_stmts: + print() + print(stmt) + # Print CREATE INDEX DDL if indexes are configured. 
+ if cfg.indexes: + idx_stmts = render_create_indexes( + cfg.schemaname, cfg.tablename, cfg.indexes, + ) + for stmt in idx_stmts: + print() + print(stmt) return 0 # Release the preview frame before opening the stream - lets the GC reclaim @@ -1241,11 +2034,18 @@ def main(argv: Optional[List[str]] = None) -> int: conn = connect(user=db_user, password=db_password) conn.autocommit = False try: - create_table(conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists) + create_table( + conn, cfg.schemaname, cfg.tablename, columns, cfg.if_exists, + partition_by=cfg.partition_by or None, + partition_values=partition_values, + max_partitions=cfg.max_partitions, + ) inserted = copy_dataframes( conn, cfg.schemaname, cfg.tablename, _filtered_chunks(), columns ) conn.commit() + if cfg.indexes: + create_indexes(conn, cfg.schemaname, cfg.tablename, cfg.indexes) except Exception: conn.rollback() raise @@ -1256,6 +2056,9 @@ def main(argv: Optional[List[str]] = None) -> int: f"loaded {inserted} rows into {cfg.schemaname}.{cfg.tablename} " f"({len(columns)} columns)" ) + if cfg.partition_by and partition_values: + total_parts = _count_partitions(partition_values) + print(f"partitioned by {cfg.partition_by} ({total_parts:,} partition tables)") print("final schema:") print(_format_columns_summary(columns)) return 0 diff --git a/generic_loader/sample_config.yaml b/generic_loader/sample_config.yaml index 791205a..c487769 100644 --- a/generic_loader/sample_config.yaml +++ b/generic_loader/sample_config.yaml @@ -15,3 +15,26 @@ tablename: kitchensink # What to do if the target table already exists: fail | replace | append # Defaults to fail. if_exists: append + +# partition_by: Partition the table by unique values of these columns. +# Columns are applied in cascading order (first column = top-level partition). +# Requires if_exists: replace or fail (not append for initial creation). 
+# Single field: +# partition_by: state +# Multiple fields (cascading): +# partition_by: +# - state +# - zip +# +# max_partitions: Warning threshold for total partition count (default: 10000). +# If the number of partitions exceeds this, a warning is logged but loading continues. +# max_partitions: 10000 + +# indexes: Create B-tree indexes on these columns after data loading. +# Indexes are created with IF NOT EXISTS for safe use with append mode. +# Single column: +# indexes: state +# Multiple columns (one index per column): +# indexes: +# - state +# - zip diff --git a/generic_loader/sample_folder_config.yaml b/generic_loader/sample_folder_config.yaml index e2ddfda..066d840 100644 --- a/generic_loader/sample_folder_config.yaml +++ b/generic_loader/sample_folder_config.yaml @@ -31,6 +31,30 @@ auto_detect: true # exclude: # - ALLNULL +# Folder-level partition_by: Partition every cluster's table by unique values +# of these columns. Inherited by all clusters unless overridden per-cluster. +# Requires if_exists: replace or fail (not append for initial creation). +# Single field: +# partition_by: state +# Multiple fields (cascading): +# partition_by: +# - state +# - zip +# +# Folder-level max_partitions: Warning threshold for total partition count +# (default: 10000). Inherited by all clusters unless overridden per-cluster. +# max_partitions: 10000 + +# Folder-level indexes: Create B-tree indexes on these columns after data +# loading. Inherited by all clusters unless overridden per-cluster. +# Indexes are created with IF NOT EXISTS for safe use with append mode. +# Single column: +# indexes: state +# Multiple columns (one index per column): +# indexes: +# - state +# - zip + # Explicit cluster patterns. Each pattern is matched against the file # *basename*. Files matched by a pattern are pulled out of the auto-detect # pool, so explicit and auto clusters compose cleanly. 
@@ -48,6 +72,26 @@ clusters: # tablename: group_b # if_exists: append + # Per-cluster partition_by / max_partitions override. These take precedence + # over the folder-level defaults above. + # + # - pattern: '^group_c\d+\.xpt$' + # tablename: group_c + # partition_by: + # - region + # - year + # max_partitions: 500 + + # Per-cluster indexes override. Takes precedence over the folder-level + # indexes default above. An explicit empty list disables indexing for + # this cluster even when the folder default has indexes. + # + # - pattern: '^group_d\d+\.xpt$' + # tablename: group_d + # indexes: + # - region + # - year + # With only the gq pattern explicit, auto_detect: true will still bucket # group_b1.xpt + group_b2.xpt into a "group_b" cluster and the lone # standalone.xpt into a "standalone" cluster. See generate_sample_folder.py