"""Folder-level SAS-to-Postgres loader.

Wraps :mod:`load_sas` so an entire directory of SAS files can be ingested in
one invocation. A directory often contains several *clusters* of files that
share a schema (e.g. ``group_a1.sas7bdat``, ``group_a2.sas7bdat``, ...). Each
cluster becomes one Postgres table; files inside a cluster are appended to it.

-------------------------------------------------------------------------------
USAGE
-------------------------------------------------------------------------------

1. YAML config
--------------

::

    folder: samples/folder_test   # required; relative paths resolve against
                                  # the config file's directory (falling back
                                  # to the path as written if that location
                                  # does not exist)
    schemaname: public            # required

    # Optional. One of: fail | replace | append. Default: fail.
    # Applied to the first file of each cluster (subsequent files in the
    # cluster always run through the append-mode compatibility check).
    if_exists: fail

    # Optional. Default: true. When true, files that don't match any explicit
    # pattern below are grouped by their common prefix (trailing digits, and
    # optional trailing separators, are stripped from each file stem).
    auto_detect: true

    # Optional. Columns to force-include or force-exclude across every file.
    # include and exclude are mutually exclusive.
    # include: [ID, INTCOL]
    # exclude: [ALLNULL]

    # Optional folder default for LIST partitioning. Omit or set [] for no
    # partitioning. Accepts a single string or a list of column names.
    # partition_by:
    #   - state
    #   - zip

    # Optional folder default threshold. Default: 10000.
    # max_partitions: 10000

    # Optional folder default for columns to index once a cluster has
    # finished loading. Omit or set [] for no indexes. Accepts a single
    # string or a list of column names.
    # indexes:
    #   - ID

    # Optional explicit cluster patterns. Each pattern is matched against the
    # file *basename*. Matched files are pulled out of the auto-detect pool.
    # Per-cluster if_exists/include/exclude/partition_by/max_partitions/indexes
    # override the folder-level defaults.
    clusters:
      - pattern: '^group_a\\d+\\.sas7bdat$'
        tablename: group_a
      - pattern: '^group_b\\d+\\.sas7bdat$'
        tablename: group_b
        if_exists: replace

2.
Command-line interface
-------------------------

::

    python load_folder.py --config folder_config.yaml [--dry-run] [--fail-fast]
        [--dbcreds] [--no-prescan] [--workers N]

Flags:

    --config PATH   Required. Path to the YAML config above.
    --dry-run       Print the discovered clusters and the inferred DDL for
                    each (CREATE TABLE plus partition and CREATE INDEX DDL
                    when applicable). For partitioned clusters all files are
                    scanned to discover partition values. The database is
                    never touched.
    --fail-fast     Abort the whole run on the first cluster failure. Default
                    is to log the failure, roll back that cluster's in-flight
                    chunk, and keep going. Chunks are committed as they load,
                    so rows already committed for the failing cluster
                    (including rows from its earlier files) stay in the table.
    --dbcreds       Prompt interactively for the database username and
                    password instead of reading ``PGUSER`` / ``PGPASSWORD``
                    from the environment or ``.env`` file. The password
                    prompt does not echo. Has no effect with ``--dry-run``
                    (no connection is opened).
    --no-prescan    Skip the per-file metadata scan that feeds the folder-wide
                    progress bar's ETA. The bar still shows rows loaded, rate,
                    and elapsed time; it just cannot estimate time remaining.
    --workers N     Number of worker processes for the append phase
                    (default 1). The first file of each cluster always loads
                    serially to create the table; with N>1 the remaining
                    files load in parallel, each worker on its own database
                    connection.

Exit codes:

    0 - every cluster loaded successfully (or dry-run completed)
    1 - at least one cluster failed (details on stderr)
    2 - folder does not exist / contains no SAS files

3. Discovery rules
------------------

* Supported extensions: ``.sas7bdat``, ``.xpt``, ``.xport`` (matches
  :mod:`load_sas`). The folder is not scanned recursively.

* Explicit patterns are tried in order. A file matched by one pattern is
  removed from the pool before the next pattern runs, so earlier patterns win
  in case of overlap. In practice, overlap is rejected during cluster
  discovery: if any file in the folder matches two patterns, a ``ValueError``
  is raised, because a file matching two patterns is almost always a bug.

* Auto-detect groups remaining files by ``re.sub(r'\\d+$', '', stem)`` with
  any trailing ``_`` / ``-`` stripped afterward. A stem without trailing
  digits keeps itself (minus any trailing separators) as its key, so it forms
  a singleton cluster named after the stem unless a numbered sibling shares
  that prefix.

* Within a cluster, files are sorted **numerically** by the last digit group
  in the stem, so ``..._9_...`` comes before ``..._10_...`` / ``..._40_...``
  regardless of zero-padding. The first file in that order drives schema
  inference; the rest are checked against that schema via
  :func:`load_sas.assert_schema_compatible`.
Gaps in the numeric sequence (missing ``3``, ``7``, ``14``, ...) are irrelevant - whatever files are present get loaded in numeric order. * Auto-detect only recognizes *trailing* digit runs. File names where the varying number sits in the middle of the stem (surrounded by other name components) are not grouped by auto-detect - each becomes its own singleton cluster. Use an explicit pattern to bucket them:: clusters: - pattern: '^year2020_regionA_\\d+_detail\\.sas7bdat$' tablename: year2020_regionA_detail The regex still matches any digit width, so numbers like ``9`` and ``40`` both land in the same cluster and the numeric sort above puts ``9`` before ``40``. 4. Library usage ---------------- :: from load_folder import load_folder_config, discover_clusters, load_cluster from load_sas import connect cfg = load_folder_config("folder_config.yaml") clusters = discover_clusters(cfg) conn = connect() try: for cluster in clusters: load_cluster(conn, cluster, cfg.schemaname) finally: conn.close() """ from __future__ import annotations import argparse import getpass import multiprocessing as mp import os import queue as _queue_mod import re import sys import threading from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import yaml from dotenv import load_dotenv from tqdm import tqdm from load_sas import ( VALID_IF_EXISTS, _count_partitions, _merge_partition_trees, apply_column_filter, assert_schema_compatible, connect, copy_dataframes, create_indexes, create_table, discover_partition_values_chunked, infer_schema, iter_sas_chunks, read_sas_metadata, read_sas_preview, render_create_indexes, render_create_table, render_partition_ddl, ) SAS_EXTENSIONS = (".sas7bdat", ".xpt", ".xport") # --------------------------------------------------------------------------- # Dataclasses # 
# ---------------------------------------------------------------------------


@dataclass
class ClusterSpec:
    """Resolved per-cluster load settings.

    One ``ClusterSpec`` maps a set of SAS files onto one target table.
    ``partition_by``, ``max_partitions``, and ``indexes`` are resolved from
    the folder defaults and any per-cluster overrides during
    :func:`discover_clusters`.
    """

    tablename: str
    # Files in load order (see ``_cluster_sort_key``); the first one drives
    # schema inference. Empty for an explicit pattern that matched nothing.
    files: List[Path]
    if_exists: str
    include: Optional[List[str]]
    exclude: Optional[List[str]]
    source: str  # "explicit" or "auto"
    pattern: Optional[str] = None  # raw regex text for explicit clusters
    partition_by: List[str] = field(default_factory=list)
    max_partitions: int = 10_000
    indexes: List[str] = field(default_factory=list)


@dataclass
class _ExplicitPattern:
    """Parsed form of a single ``clusters[*]`` YAML entry.

    ``partition_by`` defaults to ``None`` meaning "inherit from folder
    level". An explicit empty list ``[]`` means "disable partitioning for
    this cluster". ``max_partitions`` defaults to ``None`` meaning "inherit
    from folder level". ``indexes`` defaults to ``None`` meaning "inherit
    from folder level".
    """

    pattern: re.Pattern
    raw_pattern: str
    tablename: str
    if_exists: Optional[str] = None
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
    partition_by: Optional[List[str]] = None
    max_partitions: Optional[int] = None
    indexes: Optional[List[str]] = None


@dataclass
class FolderConfig:
    """Folder-level configuration parsed from YAML.

    ``partition_by``, ``max_partitions``, and ``indexes`` serve as defaults
    for every cluster unless overridden at the cluster level.
    """

    folder: Path
    schemaname: str
    if_exists: str = "fail"
    auto_detect: bool = True
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
    explicit: List[_ExplicitPattern] = field(default_factory=list)
    partition_by: List[str] = field(default_factory=list)
    max_partitions: int = 10_000
    indexes: List[str] = field(default_factory=list)


# ---------------------------------------------------------------------------
# Config loading
# ---------------------------------------------------------------------------


def _validate_if_exists(value: Any, where: str) -> str:
    """Return ``value`` lower-cased if it is one of ``VALID_IF_EXISTS``.

    Raises:
        ValueError: the value is not an accepted ``if_exists`` mode.
    """
    s = str(value).lower()
    if s not in VALID_IF_EXISTS:
        raise ValueError(
            f"{where}: if_exists={value!r} is not one of {VALID_IF_EXISTS}"
        )
    return s


def _parse_columns_filter(
    raw: Dict[str, Any], where: str
) -> Tuple[Optional[List[str]], Optional[List[str]]]:
    """Read the ``include`` / ``exclude`` column lists from a config mapping.

    Returns ``(include, exclude)``; each is ``None`` when its key is absent,
    otherwise a list of column names coerced to ``str``.

    Raises:
        ValueError: both keys are present, or either is not a list.
    """
    include = raw.get("include")
    exclude = raw.get("exclude")
    if include is not None and exclude is not None:
        raise ValueError(f"{where}: 'include' and 'exclude' are mutually exclusive.")
    if include is not None and not isinstance(include, list):
        raise ValueError(f"{where}: 'include' must be a list of column names.")
    if exclude is not None and not isinstance(exclude, list):
        raise ValueError(f"{where}: 'exclude' must be a list of column names.")
    include_out = [str(c) for c in include] if include is not None else None
    exclude_out = [str(c) for c in exclude] if exclude is not None else None
    return include_out, exclude_out


def _parse_column_list(
    raw_value: Any, where: str, key: str, *, allow_none: bool = False
) -> Optional[List[str]]:
    """Parse a "string or list of column names" YAML value stored under ``key``.

    Shared implementation behind :func:`_parse_partition_by` and
    :func:`_parse_indexes`; ``key`` only affects the error messages.

    Returns a list of non-empty, stripped, unique column names. An omitted
    value (``None``) returns ``None`` when ``allow_none`` is True ("inherit
    from folder level") and ``[]`` otherwise. An explicit ``[]`` always
    returns ``[]``.

    Raises:
        ValueError: empty/blank names, duplicates, or a value that is neither
            a string nor a list.
    """
    if raw_value is None:
        return None if allow_none else []
    if isinstance(raw_value, str):
        if not raw_value.strip():
            raise ValueError(f"{where}: '{key}' string must be non-empty.")
        return [raw_value.strip()]
    if isinstance(raw_value, list):
        result: List[str] = []
        for i, item in enumerate(raw_value):
            if not isinstance(item, str) or not item.strip():
                raise ValueError(
                    f"{where}: '{key}[{i}]' must be a non-empty string."
                )
            result.append(item.strip())
        if len(result) != len(set(result)):
            raise ValueError(
                f"{where}: '{key}' contains duplicate column names."
            )
        return result
    raise ValueError(
        f"{where}: '{key}' must be a string or list of strings."
    )


def _parse_partition_by(
    raw_value: Any, where: str, *, allow_none: bool = False
) -> Optional[List[str]]:
    """Parse a ``partition_by`` value from YAML.

    Returns a list of non-empty, unique column name strings. When
    ``allow_none`` is True (used for per-cluster entries), an omitted key
    returns ``None`` to signal "inherit from folder level". An explicit empty
    list ``[]`` always returns ``[]``.
    """
    return _parse_column_list(
        raw_value, where, "partition_by", allow_none=allow_none
    )


def _parse_max_partitions(
    raw_value: Any, where: str, *, allow_none: bool = False
) -> Optional[int]:
    """Parse a ``max_partitions`` value from YAML.

    Returns a positive integer. When ``allow_none`` is True (used for
    per-cluster entries), an omitted key returns ``None`` to signal "inherit
    from folder level"; otherwise an omitted key yields the default 10_000.
    Integer-valued strings such as ``"100"`` are accepted.

    Raises:
        ValueError: the value is not a positive integer. Booleans, fractional
            floats and non-finite floats are rejected explicitly rather than
            silently coerced by ``int()``.
    """
    if raw_value is None:
        return None if allow_none else 10_000
    # bool is an int subclass (YAML ``true``/``yes`` -> 1) and int() silently
    # truncates 2.5 -> 2 (and raises OverflowError on .inf); none of these is
    # a meaningful partition cap, so reject them up front.
    if isinstance(raw_value, bool) or (
        isinstance(raw_value, float) and not raw_value.is_integer()
    ):
        raise ValueError(
            f"{where}: 'max_partitions' must be a positive integer, "
            f"got {raw_value!r}"
        )
    try:
        value = int(raw_value)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"{where}: 'max_partitions' must be a positive integer, "
            f"got {raw_value!r}"
        ) from exc
    if value <= 0:
        raise ValueError(
            f"{where}: 'max_partitions' must be a positive integer, "
            f"got {value}"
        )
    return value


def _validate_partition_vs_columns(
    partition_by: List[str],
    exclude: Optional[List[str]],
    where: str,
) -> None:
    """Raise if any ``partition_by`` column is in the ``exclude`` list."""
    if not partition_by or exclude is None:
        return
    excluded_parts = [c for c in partition_by if c in exclude]
    if excluded_parts:
        raise ValueError(
            f"{where}: 'exclude' removes partition_by columns: {excluded_parts}"
        )


def _parse_indexes(
    raw_value: Any, where: str, *, allow_none: bool = False
) -> Optional[List[str]]:
    """Parse an ``indexes`` value from YAML.

    Returns a list of non-empty, unique column name strings. When
    ``allow_none`` is True (used for per-cluster entries), an omitted key
    returns ``None`` to signal "inherit from folder level". An explicit empty
    list ``[]`` always returns ``[]``.
    """
    return _parse_column_list(
        raw_value, where, "indexes", allow_none=allow_none
    )


def _validate_indexes_vs_columns(
    indexes: List[str],
    exclude: Optional[List[str]],
    where: str,
) -> None:
    """Raise if any ``indexes`` column is in the ``exclude`` list."""
    if not indexes or exclude is None:
        return
    excluded_idx = [c for c in indexes if c in exclude]
    if excluded_idx:
        raise ValueError(
            f"{where}: 'exclude' removes index columns: {excluded_idx}"
        )


def load_folder_config(path: Path) -> FolderConfig:
    """Parse and validate the folder-level YAML config at ``path``.

    ``folder`` and ``schemaname`` are required. A relative ``folder`` is
    resolved against the config file's directory; if that location does not
    exist, the path is kept exactly as written (so it resolves against the
    current working directory later).

    Supports optional ``partition_by``, ``max_partitions``, and ``indexes``
    at both the folder level (defaults for all clusters) and per explicit
    cluster entry (overrides the folder default).

    Raises:
        ValueError: the config is structurally invalid (not a mapping,
            missing required keys, bad types, conflicting include/exclude,
            invalid regex, partition/index columns removed by ``exclude``).
    """
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    if not isinstance(raw, dict):
        raise ValueError(f"Config at {path} must be a YAML mapping at the top level.")

    missing = [k for k in ("folder", "schemaname") if k not in raw]
    if missing:
        raise ValueError(f"Config {path} missing required keys: {', '.join(missing)}")

    top_where = f"Config {path}"

    folder = Path(raw["folder"])
    if not folder.is_absolute():
        candidate = (path.parent / folder).resolve()
        # Prefer the config-relative location; fall back to the path as
        # written when it does not exist there.
        folder = candidate if candidate.exists() else folder
    schemaname = str(raw["schemaname"])
    if_exists = _validate_if_exists(raw.get("if_exists", "fail"), top_where)
    auto_detect = bool(raw.get("auto_detect", True))
    include, exclude = _parse_columns_filter(raw, top_where)

    # -- folder-level partition settings ------------------------------------
    partition_by = _parse_partition_by(raw.get("partition_by"), top_where)
    max_partitions = _parse_max_partitions(raw.get("max_partitions"), top_where)
    _validate_partition_vs_columns(partition_by, exclude, top_where)

    # -- folder-level index settings ----------------------------------------
    indexes = _parse_indexes(raw.get("indexes"), top_where)
    _validate_indexes_vs_columns(indexes, exclude, top_where)

    explicit: List[_ExplicitPattern] = []
    clusters_raw = raw.get("clusters") or []
    if not isinstance(clusters_raw, list):
        raise ValueError(f"{top_where}: 'clusters' must be a list if present.")
    for i, entry in enumerate(clusters_raw):
        where = f"{top_where} clusters[{i}]"
        if not isinstance(entry, dict):
            raise ValueError(f"{where} must be a mapping.")
        if "pattern" not in entry or "tablename" not in entry:
            raise ValueError(f"{where} must include 'pattern' and 'tablename'.")
        raw_pat = str(entry["pattern"])
        try:
            compiled = re.compile(raw_pat)
        except re.error as e:
            raise ValueError(f"{where}: invalid regex {raw_pat!r}: {e}") from e
        c_if_exists = (
            _validate_if_exists(entry["if_exists"], where)
            if "if_exists" in entry
            else None
        )
        c_include, c_exclude = _parse_columns_filter(entry, where)

        # -- per-cluster partition settings ---------------------------------
        c_partition_by = _parse_partition_by(
            entry.get("partition_by"), where, allow_none=True
        )
        c_max_partitions = _parse_max_partitions(
            entry.get("max_partitions"), where, allow_none=True
        )

        # Validate partition_by vs the effective exclude for this cluster.
        # NOTE(review): a per-cluster ``include`` combined with a folder-level
        # ``exclude`` leaves both set on the resulting ClusterSpec (see
        # discover_clusters); confirm apply_column_filter handles that case.
        effective_exclude = c_exclude if c_exclude is not None else exclude
        effective_pb = c_partition_by if c_partition_by is not None else partition_by
        _validate_partition_vs_columns(effective_pb, effective_exclude, where)

        # -- per-cluster index settings -------------------------------------
        c_indexes = _parse_indexes(
            entry.get("indexes"), where, allow_none=True
        )
        effective_idx = c_indexes if c_indexes is not None else indexes
        _validate_indexes_vs_columns(effective_idx, effective_exclude, where)

        explicit.append(
            _ExplicitPattern(
                pattern=compiled,
                raw_pattern=raw_pat,
                tablename=str(entry["tablename"]),
                if_exists=c_if_exists,
                include=c_include,
                exclude=c_exclude,
                partition_by=c_partition_by,
                max_partitions=c_max_partitions,
                indexes=c_indexes,
            )
        )

    return FolderConfig(
        folder=folder,
        schemaname=schemaname,
        if_exists=if_exists,
        auto_detect=auto_detect,
        include=include,
        exclude=exclude,
        explicit=explicit,
        partition_by=partition_by,
        max_partitions=max_partitions,
        indexes=indexes,
    )


# ---------------------------------------------------------------------------
# Cluster discovery
# ---------------------------------------------------------------------------


_TRAILING_DIGIT_RE = re.compile(r"\d+$")
_DIGIT_GROUP_RE = re.compile(r"\d+")


def _auto_prefix(stem: str) -> str:
    """Derive the auto-detect cluster key for a file stem.

    Strips a trailing digit run, then any trailing ``_`` / ``-`` separators,
    so ``group_a1`` / ``group_a_2`` / ``group_a-3`` all land in the same
    ``group_a`` bucket. Separators are stripped even when there were no
    digits (``foo_`` -> ``foo``). If stripping leaves nothing (a stem made
    only of digits/separators, e.g. ``2020``), the original stem is the key.
    """
    stripped = _TRAILING_DIGIT_RE.sub("", stem)
    stripped = stripped.rstrip("_-")
    return stripped or stem


def _cluster_sort_key(path: Path) -> Tuple[int, str]:
    """Sort key for ordering files within a cluster.

    Sorts numerically by the LAST digit group in the stem so ``_9`` comes
    before ``_10`` / ``_40`` regardless of width, and so a file named
    ``foo_9_detail`` lands before ``foo_40_detail``. The first file under
    this order is the one whose schema is inferred and used to create the
    target table; sorting numerically keeps that choice stable as the file
    set grows.

    Files with no digits fall to ``-1`` so they sort before numbered files;
    the stem is a tiebreaker for reproducibility.
    """
    digits = _DIGIT_GROUP_RE.findall(path.stem)
    n = int(digits[-1]) if digits else -1
    return (n, path.stem)


def _list_sas_files(folder: Path) -> List[Path]:
    """Return the SAS files directly inside ``folder`` (non-recursive), sorted."""
    files: List[Path] = []
    for p in sorted(folder.iterdir()):
        if p.is_file() and p.suffix.lower() in SAS_EXTENSIONS:
            files.append(p)
    return files


def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
    """Enumerate ``cfg.folder`` and bucket files into ``ClusterSpec`` objects.

    Pure/IO-bounded: the only filesystem access is listing ``cfg.folder``.
    No SAS file is opened here.

    Explicit patterns are applied first, in config order; files matched by an
    earlier pattern are removed from the pool before the next pattern runs.
    A file matching two patterns triggers a hard error (that's almost always
    a config bug). Every explicit pattern yields a cluster, even one that
    matched no files (its ``files`` list is then empty).

    Partition settings are resolved per cluster:

    * For explicit clusters, ``partition_by`` / ``max_partitions`` from the
      cluster entry override the folder defaults when present. ``None``
      means "inherit"; an explicit ``[]`` disables partitioning.
    * For auto-detected clusters, folder defaults are inherited directly.

    Raises:
        FileNotFoundError: ``cfg.folder`` is missing or not a directory.
        ValueError: a file matches more than one explicit pattern.
    """
    if not cfg.folder.exists() or not cfg.folder.is_dir():
        raise FileNotFoundError(f"Folder not found or not a directory: {cfg.folder}")

    pool = _list_sas_files(cfg.folder)
    clusters: List[ClusterSpec] = []

    # Detect cross-pattern overlap up front for a clearer error message.
    for i, p_i in enumerate(cfg.explicit):
        for p_j in cfg.explicit[i + 1:]:
            for f in pool:
                if p_i.pattern.search(f.name) and p_j.pattern.search(f.name):
                    raise ValueError(
                        f"File {f.name!r} matches multiple explicit patterns: "
                        f"{p_i.raw_pattern!r} and {p_j.raw_pattern!r}"
                    )

    remaining = list(pool)
    for patt in cfg.explicit:
        matched = [f for f in remaining if patt.pattern.search(f.name)]
        if matched:
            matched_set = set(matched)
            remaining = [f for f in remaining if f not in matched_set]
        # An empty match is not an error - the folder might legitimately not
        # contain files for this pattern on a given run. The empty cluster is
        # still returned so the CLI can report it as "(no matching files)".
        clusters.append(
            ClusterSpec(
                tablename=patt.tablename,
                files=sorted(matched, key=_cluster_sort_key),
                if_exists=patt.if_exists or cfg.if_exists,
                include=patt.include if patt.include is not None else cfg.include,
                exclude=patt.exclude if patt.exclude is not None else cfg.exclude,
                source="explicit",
                pattern=patt.raw_pattern,
                # None = inherit folder, [] = disable, list = override.
                partition_by=(
                    patt.partition_by
                    if patt.partition_by is not None
                    else cfg.partition_by
                ),
                max_partitions=(
                    patt.max_partitions
                    if patt.max_partitions is not None
                    else cfg.max_partitions
                ),
                indexes=patt.indexes if patt.indexes is not None else cfg.indexes,
            )
        )

    if cfg.auto_detect and remaining:
        buckets: Dict[str, List[Path]] = {}
        for f in remaining:
            buckets.setdefault(_auto_prefix(f.stem), []).append(f)
        for key in sorted(buckets):
            clusters.append(
                ClusterSpec(
                    tablename=key,
                    files=sorted(buckets[key], key=_cluster_sort_key),
                    if_exists=cfg.if_exists,
                    include=cfg.include,
                    exclude=cfg.exclude,
                    source="auto",
                    partition_by=cfg.partition_by,
                    max_partitions=cfg.max_partitions,
                    indexes=cfg.indexes,
                )
            )

    return clusters


# ---------------------------------------------------------------------------
# Per-cluster load
# ---------------------------------------------------------------------------


def _infer_cluster_schema(
    path: Path, include, exclude
) -> Tuple[Dict, Optional[int]]:
    """Infer the Postgres column schema from a SAS file preview.

    Returns ``(columns, total_rows)``. ``total_rows`` comes from the
    pyreadstat metadata (the file's declared row count) and is threaded
    through to :func:`_stream_file` so the tqdm progress bar has a real
    denominator instead of an indeterminate spinner.
    """
    preview_df, meta = read_sas_preview(path)
    preview_df = apply_column_filter(preview_df, include, exclude)
    total_rows = getattr(meta, "number_rows", None)
    columns = infer_schema(preview_df, meta, total_rows=total_rows)
    return columns, total_rows


def _discover_cluster_partitions(
    cluster: ClusterSpec,
    columns: Dict,
) -> dict:
    """Scan ALL files in ``cluster`` to discover partition values.

    Returns a nested partition-value tree suitable for passing to
    :func:`load_sas.render_partition_ddl` and :func:`load_sas.create_table`.
    Each file is scanned chunk-by-chunk so the full dataset is never
    materialized in memory.
    """

    def _filtered_chunks(p: Path):
        # Apply the cluster's column filter so discovered values reflect the
        # same columns that will actually be loaded.
        for chunk_df, _chunk_meta in iter_sas_chunks(p):
            yield apply_column_filter(chunk_df, cluster.include, cluster.exclude)

    merged: dict = {}
    for path in cluster.files:
        file_tree = discover_partition_values_chunked(
            _filtered_chunks(path),
            cluster.partition_by,
            columns,
        )
        _merge_partition_trees(merged, file_tree)
    return merged


def load_cluster(
    conn,
    cluster: ClusterSpec,
    schemaname: str,
    *,
    workers: int = 1,
    progress_queue: Any = None,
    db_overrides: Optional[Dict[str, Optional[str]]] = None,
) -> int:
    """Load every file in ``cluster`` into one table. Returns total rows loaded.

    When ``cluster.partition_by`` is non-empty, partition values are
    discovered across ALL files before table creation so the full partition
    tree exists before any data is copied (skipped for ``if_exists=append``,
    where the partitions are assumed to exist already).

    Commits happen per chunk inside :func:`load_sas.copy_dataframes`. If a
    file mid-cluster fails, earlier chunks - including chunks from earlier
    files in the cluster - stay committed; only the in-flight chunk is rolled
    back by :func:`main`.

    ``workers`` controls parallelism for the *append* phase. The first file
    always runs serially on ``conn`` (to create the table and, when
    partitioned, pre-create partitions). When ``workers > 1`` the remaining
    files dispatch to a ``ProcessPoolExecutor``; each worker opens its own
    psycopg2 connection, re-infers the per-file schema, runs the same
    :func:`load_sas.assert_schema_compatible` check the serial path uses, and
    streams chunks via COPY. Workers report per-chunk row counts to
    ``progress_queue`` so the caller can drive a single aggregated tqdm bar
    regardless of how many workers are in flight.

    ``db_overrides`` carries ``{"user", "password"}`` into workers when the
    caller prompted for credentials interactively; leave ``None`` to let
    workers read the standard libpq environment variables on their own.

    Raises:
        ValueError: ``indexes`` or ``partition_by`` name columns missing from
            the first file's inferred schema.
        RuntimeError: one or more parallel append workers failed.
    """
    if not cluster.files:
        return 0

    first, *rest = cluster.files
    first_columns, first_total_rows = _infer_cluster_schema(
        first, cluster.include, cluster.exclude
    )

    # -- Validate index columns early (before any DDL runs) -----------------
    if cluster.indexes:
        missing_icols = [c for c in cluster.indexes if c not in first_columns]
        if missing_icols:
            raise ValueError(
                f"cluster {cluster.tablename!r}: indexes references "
                f"columns not present in the inferred schema: {missing_icols}"
            )

    # -- Partition support --------------------------------------------------
    partition_values: Optional[dict] = None
    if cluster.partition_by:
        # Validate that all partition_by columns exist in the inferred schema.
        missing_pcols = [c for c in cluster.partition_by if c not in first_columns]
        if missing_pcols:
            raise ValueError(
                f"cluster {cluster.tablename!r}: partition_by references "
                f"columns not present in the inferred schema: {missing_pcols}"
            )

        # Discover partition values across ALL files in the cluster.
        # In append mode the partitions already exist, so skip the scan.
        if cluster.if_exists == "append":
            print(
                " [info] append mode: skipping partition discovery "
                "(partitions assumed to exist)",
                file=sys.stderr,
            )
        else:
            print(
                f" discovering partition values across "
                f"{len(cluster.files)} file(s)...",
                file=sys.stderr,
            )
            partition_values = _discover_cluster_partitions(
                cluster, first_columns,
            )
            total_parts = _count_partitions(partition_values)
            print(
                f" discovered {total_parts:,} partition table(s) "
                f"across {len(cluster.partition_by)} level(s)",
                file=sys.stderr,
            )

    create_table(
        conn,
        schemaname,
        cluster.tablename,
        first_columns,
        cluster.if_exists,
        partition_by=cluster.partition_by or None,
        partition_values=partition_values,
        max_partitions=cluster.max_partitions,
    )

    total = _stream_file(
        conn,
        schemaname,
        cluster.tablename,
        first,
        first_columns,
        cluster.include,
        cluster.exclude,
        total_rows=first_total_rows,
        progress_queue=progress_queue,
    )
    # Commit the first file (and the CREATE TABLE) before spawning workers
    # so their ``assert_schema_compatible`` probes actually see the new
    # table. Without this, worker connections started mid-transaction on
    # the main connection would see nothing in information_schema.
    conn.commit()

    if rest:
        if workers > 1:
            total += _load_remaining_files_parallel(
                rest,
                schemaname,
                cluster.tablename,
                cluster.include,
                cluster.exclude,
                workers=workers,
                progress_queue=progress_queue,
                db_overrides=db_overrides,
            )
        else:
            for path in rest:
                columns, path_total_rows = _infer_cluster_schema(
                    path, cluster.include, cluster.exclude
                )
                # Uses the same check that if_exists=append runs. A type
                # mismatch or missing column aborts the cluster; because
                # chunks commit as they load, earlier chunks in the
                # cluster remain in the table.
                assert_schema_compatible(
                    conn, schemaname, cluster.tablename, columns
                )
                total += _stream_file(
                    conn,
                    schemaname,
                    cluster.tablename,
                    path,
                    columns,
                    cluster.include,
                    cluster.exclude,
                    total_rows=path_total_rows,
                    progress_queue=progress_queue,
                )

    # -- Index support: build indexes once, after all rows are in -----------
    if cluster.indexes:
        create_indexes(conn, schemaname, cluster.tablename, cluster.indexes)

    return total


def _stream_file(
    conn,
    schemaname: str,
    tablename: str,
    path: Path,
    columns,
    include,
    exclude,
    *,
    total_rows: Optional[int] = None,
    progress_queue: Any = None,
) -> int:
    """Stream ``path`` into an existing table chunk by chunk.

    When ``progress_queue`` is provided, each chunk's row count is published
    to the queue as ``("rows", n)`` tuples instead of being rendered to a
    per-file tqdm bar. That lets :func:`main` drive a single folder-wide
    progress bar from a background drainer thread, which is the only way to
    keep a coherent progress view when the folder loader is running files in
    parallel workers.

    Returns whatever :func:`load_sas.copy_dataframes` returns (used by the
    callers as the number of rows loaded).
    """

    def _chunks():
        if progress_queue is not None:
            for chunk_df, _chunk_meta in iter_sas_chunks(path):
                chunk_df = apply_column_filter(chunk_df, include, exclude)
                progress_queue.put(("rows", len(chunk_df)))
                yield chunk_df
            return
        pbar = tqdm(
            total=total_rows,
            unit="row",
            unit_scale=True,
            desc=f" {path.name}",
            file=sys.stderr,
            dynamic_ncols=True,
        )
        try:
            for chunk_df, _chunk_meta in iter_sas_chunks(path):
                chunk_df = apply_column_filter(chunk_df, include, exclude)
                pbar.update(len(chunk_df))
                yield chunk_df
        finally:
            pbar.close()

    return copy_dataframes(conn, schemaname, tablename, _chunks(), columns)


# ---------------------------------------------------------------------------
# Parallel append workers
# ---------------------------------------------------------------------------


def _worker_load_append_file(
    path_str: str,
    schemaname: str,
    tablename: str,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
    progress_queue: Any,
    db_overrides: Optional[Dict[str, Optional[str]]],
) -> Tuple[str, int, Optional[str]]:
    """Worker process: load one SAS file in append mode.

    Runs in a subprocess spawned by :func:`_load_remaining_files_parallel`.
    Opens its own psycopg2 connection, re-infers the per-file schema (so
    per-file ``INTEGER`` vs ``BIGINT`` drift is caught by the existing
    schema-compat check just like in the serial path), and streams chunks
    via ``COPY``. Row counts are published to the shared queue for the main
    process's global tqdm bar.

    Returns ``(path_str, rows_loaded, error_or_None)`` - failures are
    returned rather than raised so the parent can aggregate results across
    workers without losing partial progress.
    """
    # Imports are local so the function is self-contained in the child
    # process.
    from pathlib import Path as _Path

    from dotenv import load_dotenv as _load_dotenv

    from load_sas import (
        apply_column_filter as _apply_column_filter,
        assert_schema_compatible as _assert_schema_compatible,
        connect as _connect,
        copy_dataframes as _copy_dataframes,
        infer_schema as _infer_schema,
        iter_sas_chunks as _iter_sas_chunks,
        read_sas_preview as _read_sas_preview,
    )

    _load_dotenv()
    path = _Path(path_str)
    try:
        preview_df, meta = _read_sas_preview(path)
        preview_df = _apply_column_filter(preview_df, include, exclude)
        total_rows = getattr(meta, "number_rows", None)
        columns = _infer_schema(preview_df, meta, total_rows=total_rows)

        user = db_overrides.get("user") if db_overrides else None
        password = db_overrides.get("password") if db_overrides else None
        conn = _connect(user=user, password=password)
        conn.autocommit = False
        try:
            _assert_schema_compatible(conn, schemaname, tablename, columns)

            def _chunks():
                for chunk_df, _chunk_meta in _iter_sas_chunks(path):
                    chunk_df = _apply_column_filter(chunk_df, include, exclude)
                    if progress_queue is not None:
                        progress_queue.put(("rows", len(chunk_df)))
                    yield chunk_df

            rows = _copy_dataframes(
                conn, schemaname, tablename, _chunks(), columns
            )
            conn.commit()
            return (path_str, rows, None)
        finally:
            conn.close()
    except Exception as e:
        return (path_str, 0, f"{type(e).__name__}: {e}")


def _load_remaining_files_parallel(
    files: List[Path],
    schemaname: str,
    tablename: str,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
    *,
    workers: int,
    progress_queue: Any,
    db_overrides: Optional[Dict[str, Optional[str]]],
) -> int:
    """Run append-mode loads for ``files`` across a process pool.

    Each file is an independent unit of work submitted to
    ``ProcessPoolExecutor``. Workers infer schema, validate compatibility,
    and stream via COPY just like the serial path. Failures are collected
    and re-raised as a single ``RuntimeError`` at the end so that all other
    workers' rows still get committed. This includes failures the worker
    could not report itself, such as a worker process dying
    (``BrokenProcessPool``) or arguments/results failing to pickle. Those
    surface as exceptions from ``Future.result()`` and are attributed to the
    file that was submitted.

    Returns the total rows reported by successful workers.

    Raises:
        RuntimeError: one or more files failed; the message lists each
            failing path and its error.
    """
    total = 0
    errors: List[Tuple[str, str]] = []
    with ProcessPoolExecutor(max_workers=workers) as pool:
        future_to_path = {
            pool.submit(
                _worker_load_append_file,
                str(p),
                schemaname,
                tablename,
                include,
                exclude,
                progress_queue,
                db_overrides,
            ): str(p)
            for p in files
        }
        for fut in as_completed(future_to_path):
            try:
                path_str, rows, err = fut.result()
            except Exception as e:
                # The worker never got to return its (path, rows, error)
                # tuple; record the failure instead of abandoning the
                # aggregation of every other worker's result.
                path_str = future_to_path[fut]
                rows, err = 0, f"{type(e).__name__}: {e}"
            if err is not None:
                errors.append((path_str, err))
            else:
                total += rows

    if errors:
        joined = "\n".join(f" {p}: {e}" for p, e in errors)
        raise RuntimeError(
            f"{len(errors)} worker(s) failed while appending to "
            f"{schemaname}.{tablename}:\n{joined}"
        )
    return total


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def _build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser for :func:`main`."""
    p = argparse.ArgumentParser(
        description=(
            "Load every SAS file in a folder into Postgres, grouping files "
            "into clusters that each become one table."
        ),
    )
    p.add_argument("--config", required=True, type=Path, help="Path to YAML config")
    p.add_argument(
        "--dry-run",
        action="store_true",
        help=(
            "Print discovered clusters and the inferred DDL for each "
            "(CREATE TABLE plus partition and CREATE INDEX DDL when "
            "applicable). For partitioned clusters all files are scanned "
            "to discover partition values. The database is never touched."
        ),
    )
    p.add_argument(
        "--fail-fast",
        action="store_true",
        help=(
            "Abort on the first cluster failure. Default is to roll back "
            "the failing cluster's in-flight chunk (chunks already "
            "committed stay in the table) and continue with the next one."
        ),
    )
    p.add_argument(
        "--dbcreds",
        action="store_true",
        help=(
            "Prompt for database username and password instead of reading "
            "PGUSER / PGPASSWORD from the environment or .env file."
        ),
    )
    p.add_argument(
        "--no-prescan",
        action="store_true",
        help=(
            "Skip the per-file metadata scan that populates the folder-wide "
            "tqdm ETA. Useful when the folder is large (half-hour+ pre-scan) "
            "or when you're iterating quickly on a failure. Without the "
            "pre-scan the progress bar still shows rows loaded, rate, and "
            "elapsed time - it just can't estimate remaining time."
        ),
    )
    p.add_argument(
        "--workers",
        type=int,
        default=1,
        metavar="N",
        help=(
            "Number of worker processes for the append phase. With N=1 "
            "(default) files load serially on the main connection. With "
            "N>1 the first file of each cluster still runs serially (to "
            "create the table), then the remaining files load in parallel "
            "across N processes, each with its own psycopg2 connection. "
            "On a big box try N close to your core count. When N>1 the "
            "per-chunk row target drops to 500,000 unless you've pinned "
            "GENERIC_LOADER_CHUNK_ROWS, so peak memory stays bounded."
        ),
    )
    return p


def _describe_cluster(cluster: ClusterSpec) -> str:
    """Render a short, human-readable summary of ``cluster`` for the CLI."""
    src = f"{cluster.source}"
    if cluster.pattern:
        src += f" pattern={cluster.pattern!r}"
    files = ", ".join(f.name for f in cluster.files) or "(no matching files)"
    parts = ""
    if cluster.partition_by:
        parts = f"\n partition_by: {cluster.partition_by}"
    idx = ""
    if cluster.indexes:
        idx = f"\n indexes: {cluster.indexes}"
    return (
        f"cluster {cluster.tablename!r} [{src}] if_exists={cluster.if_exists}\n"
        f" files: {files}{parts}{idx}"
    )


def main(argv: Optional[List[str]] = None) -> int:
    args = _build_argparser().parse_args(argv)
    load_dotenv()

    cfg = load_folder_config(args.config)
    if not cfg.folder.exists() or not cfg.folder.is_dir():
        print(f"error: folder not found: {cfg.folder}", file=sys.stderr)
        return 2

    clusters = discover_clusters(cfg)
    loadable = [c for c in clusters if c.files]
    if not loadable:
        print(
            f"error: no SAS files found in {cfg.folder} "
            f"(looked for {', '.join(SAS_EXTENSIONS)})",
            file=sys.stderr,
        )
        return 2

    print(f"discovered {len(loadable)} cluster(s) in {cfg.folder}:")
    for c in clusters:
        print(_describe_cluster(c))

    if args.dry_run:
        print()
        for c in loadable:
            print(f"--- DDL for cluster {c.tablename!r} ---")
            columns, _ = _infer_cluster_schema(c.files[0], c.include, c.exclude)
            # Print parent CREATE TABLE (with PARTITION BY if applicable).
            print(
                render_create_table(
                    cfg.schemaname,
                    c.tablename,
                    columns,
                    partition_by=c.partition_by or None,
                )
            )
            # Print child partition DDL when the cluster is partitioned.
            if c.partition_by:
                # Validate partition columns exist in the schema.
                missing_pcols = [
                    col for col in c.partition_by if col not in columns
                ]
                if missing_pcols:
                    print(
                        f" [error] partition_by references columns not in "
                        f"schema: {missing_pcols}",
                        file=sys.stderr,
                    )
                else:
                    print(
                        f" discovering partition values across "
                        f"{len(c.files)} file(s)...",
                        file=sys.stderr,
                    )
                    partition_values = _discover_cluster_partitions(
                        c, columns,
                    )
                    total_parts = _count_partitions(partition_values)
                    print(
                        f" discovered {total_parts:,} partition table(s) "
                        f"across {len(c.partition_by)} level(s)",
                        file=sys.stderr,
                    )
                    child_stmts = render_partition_ddl(
                        cfg.schemaname,
                        c.tablename,
                        c.partition_by,
                        partition_values,
                        columns,
                        max_partitions=c.max_partitions,
                    )
                    for stmt in child_stmts:
                        print()
                        print(stmt)
            # Print CREATE INDEX DDL when the cluster has indexes.
            if c.indexes:
                missing_icols = [
                    col for col in c.indexes if col not in columns
                ]
                if missing_icols:
                    print(
                        f" [error] indexes references columns not in "
                        f"schema: {missing_icols}",
                        file=sys.stderr,
                    )
                else:
                    idx_stmts = render_create_indexes(
                        cfg.schemaname,
                        c.tablename,
                        c.indexes,
                    )
                    for stmt in idx_stmts:
                        print()
                        print(stmt)
            print()
        return 0

    db_user = db_password = None
    if args.dbcreds:
        db_user = input("Database username: ")
        db_password = getpass.getpass("Database password: ")
    db_overrides: Optional[Dict[str, Optional[str]]] = (
        {"user": db_user, "password": db_password} if args.dbcreds else None
    )

    workers = max(1, int(args.workers))
    # When running parallel workers, bound peak memory: each worker buffers a
    # chunk (read + prepared + serialized) so total memory scales with
    # workers × chunk_rows × avg_row_bytes. Drop the default chunk target to
    # 500k unless the operator has explicitly pinned it.
Setting the env var # before workers spawn means they inherit it through forkserver / spawn. if ( workers > 1 and "GENERIC_LOADER_CHUNK_ROWS" not in os.environ ): os.environ["GENERIC_LOADER_CHUNK_ROWS"] = "500000" print( "[info] parallel mode: bounding per-chunk rows to 500,000. " "Pin GENERIC_LOADER_CHUNK_ROWS to override.", file=sys.stderr, ) # -- Metadata pre-scan ----------------------------------------------------- # Sum ``number_rows`` across every file so the tqdm bar has a real # denominator. ``read_sas_metadata`` uses pyreadstat's ``metadataonly=True`` # fast path, but on multi-GB sas7bdat files that still reads tens of MB # of scattered subheader pages per file - sequentially that's minutes for # a 52-file folder. pyreadstat releases the GIL during I/O and C decoding, # so a ThreadPool gives near-linear scaling until the disk saturates. # ``--no-prescan`` bypasses the scan entirely; the progress bar then runs # without an ETA - useful when pre-scan itself is expensive (half hour+ # on very large files) or when debugging iteratively. 
all_files: List[Path] = [p for c in loadable for p in c.files] grand_total: Optional[int] = 0 if args.no_prescan: grand_total = None print( f"[info] --no-prescan set: skipping row-count pre-scan for " f"{len(all_files)} file(s); progress bar will show rate + " f"elapsed but no ETA.", file=sys.stderr, ) else: prescan_workers = min(16, max(1, len(all_files))) print( f"pre-scanning row counts for {len(all_files)} file(s) " f"across {prescan_workers} thread(s)...", file=sys.stderr, ) def _scan_one(p: Path) -> Tuple[Path, Optional[int], Optional[str]]: try: meta = read_sas_metadata(p) n = getattr(meta, "number_rows", None) return (p, int(n) if n is not None else None, None) except Exception as e: return (p, None, str(e)) unknown_total_files: List[str] = [] running_total = 0 with ThreadPoolExecutor(max_workers=prescan_workers) as tpool: prescan_bar = tqdm( total=len(all_files), unit="file", desc=" prescanning", file=sys.stderr, dynamic_ncols=True, ) try: for p, n, err in tpool.map(_scan_one, all_files): prescan_bar.update(1) if err is not None: unknown_total_files.append(f"{p.name} ({err})") elif n is None: unknown_total_files.append(p.name) else: running_total += n finally: prescan_bar.close() if unknown_total_files: print( f"[warn] could not read row count from " f"{len(unknown_total_files)} file(s); progress bar ETA will " f"be approximate.", file=sys.stderr, ) print( f" total rows across folder: {running_total:,}", file=sys.stderr, ) grand_total = running_total # -- Shared progress plumbing --------------------------------------------- # The queue crosses process boundaries when workers > 1 (managed proxy) # and is a plain in-process queue otherwise; the put/get contract is # identical either way. A daemon thread drains it and advances the one # tqdm bar that spans the whole folder load. 
manager: Optional[Any] = None progress_queue: Any if workers > 1: manager = mp.Manager() progress_queue = manager.Queue() else: progress_queue = _queue_mod.Queue() pbar = tqdm( total=grand_total or None, unit="row", unit_scale=True, desc=f"{cfg.folder.name}", file=sys.stderr, dynamic_ncols=True, ) stop_drainer = threading.Event() def _drainer() -> None: while not stop_drainer.is_set(): try: event = progress_queue.get(timeout=0.1) except _queue_mod.Empty: continue except (EOFError, OSError): return if not event: continue kind = event[0] if kind == "rows": pbar.update(event[1]) drainer_thread = threading.Thread(target=_drainer, daemon=True) drainer_thread.start() conn = connect(user=db_user, password=db_password) conn.autocommit = False failures: List[Tuple[str, Exception]] = [] totals: List[Tuple[str, int, int]] = [] # (tablename, files, rows) try: for cluster in loadable: print( f"\n>>> loading cluster {cluster.tablename!r} " f"({len(cluster.files)} file(s)) " f"[workers={workers}]" ) try: rows = load_cluster( conn, cluster, cfg.schemaname, workers=workers, progress_queue=progress_queue, db_overrides=db_overrides, ) conn.commit() totals.append((cluster.tablename, len(cluster.files), rows)) print( f" -> loaded {rows:,} row(s) into " f"{cfg.schemaname}.{cluster.tablename}" ) except Exception as e: conn.rollback() failures.append((cluster.tablename, e)) print( f" !! cluster {cluster.tablename!r} failed: {e}", file=sys.stderr, ) if args.fail_fast: break finally: # Drain any pending progress events before shutting the bar down so # the final rendered total matches what actually landed. stop_drainer.set() drainer_thread.join(timeout=2.0) pbar.close() conn.close() if manager is not None: manager.shutdown() print("\n=== summary ===") for name, fcount, rows in totals: print(f" ok {name}: {fcount} file(s), {rows:,} row(s)") for name, err in failures: print(f" FAIL {name}: {err}", file=sys.stderr) return 1 if failures else 0 if __name__ == "__main__": sys.exit(main())