diff --git a/generic_loader/load_folder.py b/generic_loader/load_folder.py index 82fe459..d503276 100644 --- a/generic_loader/load_folder.py +++ b/generic_loader/load_folder.py @@ -32,6 +32,17 @@ USAGE # include: [ID, INTCOL] # exclude: [ALLNULL] + # Optional folder default for explicit column type overrides. These + # win over the cluster-wide auto-union computed during pre-scan; set + # them when a column's SAS-level type varies across files (e.g. phone + # IDs stored as CHAR in some years and NUM in others) and you want to + # pin the Postgres type yourself rather than accept the auto-derived + # one. Per-cluster column_types inside each clusters[*] entry are + # merged on top of this map. + # column_types: + # RESP_PH_PREFIX_ID: TEXT + # SOME_BIGINT_COL: BIGINT + # Optional folder default for LIST partitioning. Omit or set [] for no # partitioning. Accepts a single string or a list of column names. # partition_by: @@ -43,14 +54,16 @@ USAGE # Optional explicit cluster patterns. Each pattern is matched against the # file *basename*. Matched files are pulled out of the auto-detect pool. - # Per-cluster if_exists/include/exclude/partition_by/max_partitions - # override the folder-level defaults. + # Per-cluster if_exists/include/exclude/partition_by/max_partitions/ + # column_types override the folder-level defaults. clusters: - pattern: '^group_a\\d+\\.sas7bdat$' tablename: group_a - pattern: '^group_b\\d+\\.sas7bdat$' tablename: group_b if_exists: replace + column_types: + PHONE_PREFIX: TEXT 2. 
Command-line interface ------------------------- @@ -132,14 +145,25 @@ from __future__ import annotations import argparse import getpass +import multiprocessing as mp +import os +import queue as _queue_mod import re import sys +import threading +from concurrent.futures import ( + CancelledError, + ProcessPoolExecutor, + ThreadPoolExecutor, + as_completed, +) from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import yaml from dotenv import load_dotenv +from tqdm import tqdm from load_sas import ( VALID_IF_EXISTS, @@ -152,12 +176,15 @@ from load_sas import ( create_indexes, create_table, discover_partition_values_chunked, + extract_union_metadata, infer_schema, iter_sas_chunks, + read_sas_metadata, read_sas_preview, render_create_indexes, render_create_table, render_partition_ddl, + union_column_types, ) @@ -175,7 +202,11 @@ class ClusterSpec: ``partition_by``, ``max_partitions``, and ``indexes`` are resolved from the folder defaults and any per-cluster overrides during - :func:`discover_clusters`. + :func:`discover_clusters`. ``column_types`` holds the effective type + overrides for this cluster: user-supplied YAML entries merged on top + of the auto-union result computed during pre-scan (see :func:`main`). + The same dict is threaded through to workers so every file in the + cluster infers the same schema. """ tablename: str @@ -188,6 +219,8 @@ class ClusterSpec: partition_by: List[str] = field(default_factory=list) max_partitions: int = 10_000 indexes: List[str] = field(default_factory=list) + column_types: Dict[str, str] = field(default_factory=dict) + all_nullable: bool = False @dataclass @@ -198,6 +231,8 @@ class _ExplicitPattern: An explicit empty list ``[]`` means "disable partitioning for this cluster". ``max_partitions`` defaults to ``None`` meaning "inherit from folder level". ``indexes`` defaults to ``None`` meaning "inherit from folder level". 
+ ``column_types`` defaults to ``None`` meaning "inherit from folder level"; + an explicit ``{}`` means "no user overrides for this cluster". """ pattern: re.Pattern @@ -209,6 +244,8 @@ class _ExplicitPattern: partition_by: Optional[List[str]] = None max_partitions: Optional[int] = None indexes: Optional[List[str]] = None + column_types: Optional[Dict[str, str]] = None + all_nullable: Optional[bool] = None @dataclass @@ -217,6 +254,9 @@ class FolderConfig: ``partition_by``, ``max_partitions``, and ``indexes`` serve as defaults for every cluster unless overridden at the cluster level. + ``column_types`` is a ``{column_name: postgres_type_str}`` map of + user-supplied type overrides that win over the auto-union computed + during pre-scan. """ folder: Path @@ -229,6 +269,8 @@ class FolderConfig: partition_by: List[str] = field(default_factory=list) max_partitions: int = 10_000 indexes: List[str] = field(default_factory=list) + column_types: Dict[str, str] = field(default_factory=dict) + all_nullable: bool = False # --------------------------------------------------------------------------- @@ -389,6 +431,40 @@ def _validate_indexes_vs_columns( ) +def _parse_column_types( + raw_value: Any, where: str, *, allow_none: bool = False +) -> Optional[Dict[str, str]]: + """Parse a ``column_types`` mapping from YAML. + + The value must be a mapping ``{column_name: pg_type_str}``. Keys and + values are whitespace-stripped strings; empty strings raise. When + ``allow_none`` is True (used for per-cluster entries), an omitted key + returns ``None`` to mean "inherit from folder level"; an explicit + empty mapping returns ``{}`` (no overrides for this cluster). + """ + if raw_value is None: + return None if allow_none else {} + if not isinstance(raw_value, dict): + raise ValueError( + f"{where}: 'column_types' must be a mapping of " + f"{{column_name: postgres_type}}." 
+ ) + out: Dict[str, str] = {} + for k, v in raw_value.items(): + key = str(k).strip() + if not key: + raise ValueError( + f"{where}: 'column_types' contains an empty column name." + ) + if not isinstance(v, str) or not v.strip(): + raise ValueError( + f"{where}: 'column_types[{key}]' must be a non-empty " + f"Postgres type string (got {v!r})." + ) + out[key] = v.strip() + return out + + def load_folder_config(path: Path) -> FolderConfig: """Parse and validate the folder-level YAML config at ``path``. @@ -431,6 +507,22 @@ def load_folder_config(path: Path) -> FolderConfig: indexes = _parse_indexes(raw.get("indexes"), f"Config {path}") _validate_indexes_vs_columns(indexes, exclude, f"Config {path}") + # -- folder-level column_types overrides -------------------------------- + column_types = _parse_column_types( + raw.get("column_types"), f"Config {path}" + ) + + # -- folder-level all_nullable ----------------------------------------- + # Sets the default for every cluster. Per-cluster ``all_nullable`` wins + # when present; the CLI ``--all-nullable`` flag trumps both. + raw_an_folder = raw.get("all_nullable", False) + if not isinstance(raw_an_folder, bool): + raise ValueError( + f"Config {path}: 'all_nullable' must be a boolean " + f"(got {raw_an_folder!r})." 
+ ) + all_nullable_default = bool(raw_an_folder) + explicit: List[_ExplicitPattern] = [] clusters_raw = raw.get("clusters") or [] if not isinstance(clusters_raw, list): @@ -472,6 +564,24 @@ def load_folder_config(path: Path) -> FolderConfig: effective_idx = c_indexes if c_indexes is not None else indexes _validate_indexes_vs_columns(effective_idx, effective_exclude, where) + # -- per-cluster column_types overrides ----------------------------- + c_column_types = _parse_column_types( + entry.get("column_types"), where, allow_none=True + ) + + # -- per-cluster all_nullable --------------------------------------- + c_all_nullable: Optional[bool] + if "all_nullable" in entry: + raw_c_an = entry["all_nullable"] + if not isinstance(raw_c_an, bool): + raise ValueError( + f"{where}: 'all_nullable' must be a boolean " + f"(got {raw_c_an!r})." + ) + c_all_nullable = bool(raw_c_an) + else: + c_all_nullable = None + explicit.append( _ExplicitPattern( pattern=compiled, @@ -483,6 +593,8 @@ def load_folder_config(path: Path) -> FolderConfig: partition_by=c_partition_by, max_partitions=c_max_partitions, indexes=c_indexes, + column_types=c_column_types, + all_nullable=c_all_nullable, ) ) @@ -497,6 +609,8 @@ def load_folder_config(path: Path) -> FolderConfig: partition_by=partition_by, max_partitions=max_partitions, indexes=indexes, + column_types=column_types or {}, + all_nullable=all_nullable_default, ) @@ -594,6 +708,20 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: patt.indexes if patt.indexes is not None else cfg.indexes ) + # Resolve column_types: user overrides only. The auto-union adds + # more entries later (in :func:`main`) after the metadata pre-scan. + # None = inherit folder, {} = no cluster-level overrides, dict = + # cluster-level overrides that win over folder-level entries. 
+ if patt.column_types is None: + resolved_ct: Dict[str, str] = dict(cfg.column_types) + else: + resolved_ct = {**cfg.column_types, **patt.column_types} + + # Resolve all_nullable: None = inherit folder, bool = override. + resolved_an = ( + patt.all_nullable if patt.all_nullable is not None + else cfg.all_nullable + ) matched = [f for f in remaining if patt.pattern.search(f.name)] if not matched: @@ -611,6 +739,8 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: partition_by=resolved_pb, max_partitions=resolved_mp, indexes=resolved_idx, + column_types=dict(resolved_ct), + all_nullable=resolved_an, ) ) continue @@ -627,6 +757,8 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: partition_by=resolved_pb, max_partitions=resolved_mp, indexes=resolved_idx, + column_types=dict(resolved_ct), + all_nullable=resolved_an, ) ) @@ -647,6 +779,8 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: partition_by=cfg.partition_by, max_partitions=cfg.max_partitions, indexes=cfg.indexes, + column_types=dict(cfg.column_types), + all_nullable=cfg.all_nullable, ) ) @@ -658,13 +792,35 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: # --------------------------------------------------------------------------- -def _infer_cluster_schema(path: Path, include, exclude): - """Infer the Postgres column schema from a SAS file preview.""" +def _infer_cluster_schema( + path: Path, + include, + exclude, + *, + column_types: Optional[Dict[str, str]] = None, + force_nullable: bool = False, +) -> Tuple[Dict, Optional[int]]: + """Infer the Postgres column schema from a SAS file preview. + + Returns ``(columns, total_rows)``. ``total_rows`` comes from the + pyreadstat metadata (the file's declared row count) and is threaded + through to :func:`_stream_file` so the tqdm progress bar has a real + denominator instead of an indeterminate spinner. 
``column_types`` + lets the caller pin specific columns to a chosen Postgres type + (typically the merged auto-union + YAML overrides for the cluster). + ``force_nullable`` stamps every column nullable regardless of what + the preview shows - see :func:`load_sas.infer_schema`. + """ preview_df, meta = read_sas_preview(path) preview_df = apply_column_filter(preview_df, include, exclude) total_rows = getattr(meta, "number_rows", None) - columns = infer_schema(preview_df, meta, total_rows=total_rows) - return columns + columns = infer_schema( + preview_df, meta, + total_rows=total_rows, + column_types=column_types, + force_nullable=force_nullable, + ) + return columns, total_rows def _discover_cluster_partitions( @@ -693,7 +849,16 @@ def _discover_cluster_partitions( return merged -def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: +def load_cluster( + conn, + cluster: ClusterSpec, + schemaname: str, + *, + workers: int = 1, + progress_queue: Any = None, + db_overrides: Optional[Dict[str, Optional[str]]] = None, + abort_on_first_failure: bool = False, +) -> int: """Load every file in ``cluster`` into one table. Returns total rows loaded. When ``cluster.partition_by`` is non-empty, partition values are @@ -704,12 +869,31 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: file mid-cluster fails, earlier chunks - including chunks from earlier files in the cluster - stay committed; only the in-flight chunk is rolled back by :func:`main`. + + ``workers`` controls parallelism for streaming. With ``workers == 1`` + every file streams on ``conn`` in sequence. With ``workers > 1`` the + main connection only does ``CREATE TABLE`` (and, for partitioned + clusters, partition discovery + pre-creation), commits, then dispatches + *every* file - including the first - to a ``ProcessPoolExecutor``. 
Each + worker opens its own psycopg2 connection, re-infers the per-file schema, + runs the same :func:`load_sas.assert_schema_compatible` check the serial + path uses, and streams chunks via COPY. Workers report per-chunk row + counts to ``progress_queue`` so the caller can drive a single aggregated + tqdm bar regardless of how many workers are in flight. + + ``db_overrides`` carries ``{"user", "password"}`` into workers when the + caller prompted for credentials interactively; leave ``None`` to let + workers read the standard libpq environment variables on their own. """ if not cluster.files: return 0 first, *rest = cluster.files - first_columns = _infer_cluster_schema(first, cluster.include, cluster.exclude) + first_columns, first_total_rows = _infer_cluster_schema( + first, cluster.include, cluster.exclude, + column_types=cluster.column_types, + force_nullable=cluster.all_nullable, + ) # -- Validate index columns early --------------------------------------- if cluster.indexes: @@ -767,21 +951,61 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: ) total = 0 - total += _stream_file( - conn, schemaname, cluster.tablename, first, first_columns, - cluster.include, cluster.exclude, - ) - for path in rest: - columns = _infer_cluster_schema(path, cluster.include, cluster.exclude) - # Uses the same check that if_exists=append runs. A type mismatch or - # missing column aborts the cluster; because chunks commit as they - # load, earlier chunks in the cluster remain in the table. - assert_schema_compatible(conn, schemaname, cluster.tablename, columns) - total += _stream_file( - conn, schemaname, cluster.tablename, path, columns, - cluster.include, cluster.exclude, + if workers > 1: + # Parallel path: commit the (empty) table now so worker subprocesses' + # ``assert_schema_compatible`` probes can actually see it via + # ``information_schema``, then dispatch *every* file (first + + # rest) to the pool. 
The previous design streamed the first file + # on the main connection before spawning workers, which made the + # serial first-file phase the long pole on big-file clusters + # (e.g. 52 × 5-50 GB). Now ``CREATE TABLE`` is the only serial + # work and it takes milliseconds. + conn.commit() + total += _load_remaining_files_parallel( + cluster.files, + schemaname, + cluster.tablename, + cluster.include, + cluster.exclude, + workers=workers, + progress_queue=progress_queue, + db_overrides=db_overrides, + column_types=cluster.column_types, + force_nullable=cluster.all_nullable, + abort_on_first_failure=abort_on_first_failure, ) + else: + # Serial path: stream the first file on the main connection, then + # iterate the rest. Worth keeping separate from the parallel path + # because spawning a single-worker pool just to load files in + # series would be pure overhead. + total += _stream_file( + conn, schemaname, cluster.tablename, first, first_columns, + cluster.include, cluster.exclude, + total_rows=first_total_rows, + progress_queue=progress_queue, + ) + conn.commit() + for path in rest: + columns, path_total_rows = _infer_cluster_schema( + path, cluster.include, cluster.exclude, + column_types=cluster.column_types, + force_nullable=cluster.all_nullable, + ) + # Uses the same check that if_exists=append runs. A type + # mismatch or missing column aborts the cluster; because + # chunks commit as they load, earlier chunks in the + # cluster remain in the table. 
+ assert_schema_compatible( + conn, schemaname, cluster.tablename, columns + ) + total += _stream_file( + conn, schemaname, cluster.tablename, path, columns, + cluster.include, cluster.exclude, + total_rows=path_total_rows, + progress_queue=progress_queue, + ) # -- Index support ------------------------------------------------------ if cluster.indexes: @@ -798,21 +1022,324 @@ def _stream_file( columns, include, exclude, + *, + total_rows: Optional[int] = None, + progress_queue: Any = None, ) -> int: + """Stream ``path`` into an existing table chunk by chunk. + + When ``progress_queue`` is provided, each chunk's row count is published + to the queue as ``("rows", n)`` tuples instead of being rendered to a + per-file tqdm bar. That lets :func:`main` drive a single folder-wide + progress bar from a background drainer thread, which is the only way + to keep a coherent progress view when the folder loader is running + files in parallel workers. + """ def _chunks(): - seen = 0 - for chunk_df, _chunk_meta in iter_sas_chunks(path): - chunk_df = apply_column_filter(chunk_df, include, exclude) - seen += len(chunk_df) - print( - f" {path.name}: streaming... 
{seen:,} rows", - file=sys.stderr, - ) - yield chunk_df + if progress_queue is not None: + for chunk_df, _chunk_meta in iter_sas_chunks(path): + chunk_df = apply_column_filter(chunk_df, include, exclude) + progress_queue.put(("rows", len(chunk_df))) + yield chunk_df + return + + pbar = tqdm( + total=total_rows, + unit="row", + unit_scale=True, + desc=f" {path.name}", + file=sys.stderr, + dynamic_ncols=True, + ) + try: + for chunk_df, _chunk_meta in iter_sas_chunks(path): + chunk_df = apply_column_filter(chunk_df, include, exclude) + pbar.update(len(chunk_df)) + yield chunk_df + finally: + pbar.close() return copy_dataframes(conn, schemaname, tablename, _chunks(), columns) +# --------------------------------------------------------------------------- +# Parallel append workers +# --------------------------------------------------------------------------- + + +def _worker_load_append_file( + path_str: str, + schemaname: str, + tablename: str, + include: Optional[List[str]], + exclude: Optional[List[str]], + progress_queue: Any, + db_overrides: Optional[Dict[str, Optional[str]]], + column_types: Optional[Dict[str, str]] = None, + force_nullable: bool = False, +) -> Tuple[str, int, Optional[str]]: + """Worker process: load one SAS file in append mode. + + Runs in a subprocess spawned by :func:`_load_remaining_files_parallel`. + Opens its own psycopg2 connection, re-infers the per-file schema (so + per-file ``INTEGER`` vs ``BIGINT`` drift is caught by the existing + schema-compat check just like in the serial path), and streams chunks + via ``COPY``. Row counts are published to the shared queue for the + main process's global tqdm bar. + + Returns ``(path_str, rows_loaded, error_or_None)`` - failures are + returned rather than raised so the parent can aggregate results + across workers without losing partial progress. 
+ """ + from pathlib import Path as _Path + + from dotenv import load_dotenv as _load_dotenv + + import ctypes + import ctypes.util + import gc + + from load_sas import ( + apply_column_filter as _apply_column_filter, + assert_schema_compatible as _assert_schema_compatible, + connect as _connect, + copy_dataframes as _copy_dataframes, + infer_schema as _infer_schema, + iter_sas_chunks as _iter_sas_chunks, + read_sas_preview as _read_sas_preview, + ) + + _load_dotenv() + + path = _Path(path_str) + try: + preview_df, meta = _read_sas_preview(path) + preview_df = _apply_column_filter(preview_df, include, exclude) + total_rows = getattr(meta, "number_rows", None) + columns = _infer_schema( + preview_df, meta, + total_rows=total_rows, + column_types=column_types, + force_nullable=force_nullable, + ) + # Drop the preview ASAP - on a 2M-row wide file it's hundreds of MB + # and we never need it again after schema inference. + del preview_df, meta + + user = db_overrides.get("user") if db_overrides else None + password = db_overrides.get("password") if db_overrides else None + conn = _connect(user=user, password=password) + conn.autocommit = False + try: + _assert_schema_compatible(conn, schemaname, tablename, columns) + + def _chunks(): + for chunk_df, _chunk_meta in _iter_sas_chunks(path): + chunk_df = _apply_column_filter(chunk_df, include, exclude) + if progress_queue is not None: + progress_queue.put(("rows", len(chunk_df))) + yield chunk_df + + rows = _copy_dataframes( + conn, schemaname, tablename, _chunks(), columns + ) + conn.commit() + return (path_str, rows, None) + finally: + conn.close() + except Exception as e: + import traceback as _traceback + tb = _traceback.format_exc() + # Keep the one-line summary (what the tqdm [FAIL] print uses) but + # tack on the full traceback so the final cluster-failure block + # shows the file/line that crashed. 
Without this, ``ProcessPool`` + # workers lose every frame of context - you get "FloatingPointError: + # overflow encountered in multiply" with no hint of where inside + # the pandas/numpy/pyarrow stack it happened. + return (path_str, 0, f"{type(e).__name__}: {e}\n{tb}") + finally: + # Hand memory back to the OS before the worker is recycled (or before + # ``max_tasks_per_child`` rotates this process). Three layers, each + # of which independently retains memory across calls: + # + # 1. pyarrow's memory pool aggressively reuses buffers - explicitly + # release_unused() returns them to the allocator. + # 2. Python's GC: cyclic refs from pandas/pyarrow chains aren't + # collected until a generation tick; force one now. + # 3. glibc's ptmalloc keeps freed heap in per-thread arenas instead + # of munmap'ing it back. ``malloc_trim(0)`` is the explicit ask. + # No-op (silently) on platforms without the symbol (macOS, etc). + try: + import pyarrow as _pa + _pa.default_memory_pool().release_unused() + except Exception: + pass + gc.collect() + try: + _libc_name = ctypes.util.find_library("c") + if _libc_name: + _libc = ctypes.CDLL(_libc_name) + if hasattr(_libc, "malloc_trim"): + _libc.malloc_trim(0) + except Exception: + pass + + +def _load_remaining_files_parallel( + files: List[Path], + schemaname: str, + tablename: str, + include: Optional[List[str]], + exclude: Optional[List[str]], + *, + workers: int, + progress_queue: Any, + db_overrides: Optional[Dict[str, Optional[str]]], + column_types: Optional[Dict[str, str]] = None, + force_nullable: bool = False, + abort_on_first_failure: bool = False, +) -> int: + """Run append-mode loads for ``files`` across a process pool. + + Each file is an independent unit of work submitted to + ``ProcessPoolExecutor``. Workers infer schema, validate compatibility, + and stream via COPY just like the serial path. 
The table itself must + already exist (and be committed) before this is called - the worker + schema-compat probes read ``information_schema``, which won't see an + uncommitted ``CREATE TABLE``. Failures are surfaced two ways: + + 1. **Live stderr feed.** Every worker's outcome - success or failure + - is written to stderr the moment ``as_completed`` hands it back, + via ``tqdm.write`` so the active progress bar stays intact. This + turns the previous "wait 30 minutes for the last worker to drain + before you find out the first 31 failed at minute 2" footgun into + an immediate notification, which matters most when one bad file + (e.g. schema drift, OOM, bad data) torpedoes most of the pool. + 2. **Final aggregate raise.** Errors are still collected and raised as + one ``RuntimeError`` after the pool drains so the per-cluster + success/fail summary in :func:`main` stays accurate. On + ``KeyboardInterrupt`` with failures already collected we raise + that same ``RuntimeError`` shape, so Ctrl-C still prints what + failed up to that point; with no failures collected the + ``KeyboardInterrupt`` propagates unchanged. + """ + total = 0 + errors: List[Tuple[str, str]] = [] + completed = 0 + n_files = len(files) + + # ``max_tasks_per_child=1`` (only applied on Python 3.11+, see the + # version check below) recycles each worker process after every file. + # Without this, glibc/pyarrow/pyreadstat all retain peak-water memory + # inside long-lived workers; over a multi-hour run the sum across + # workers grows even though individual chunks have been freed at the + # Python level. Recycling gives the OS the memory back - the cost is a + # fresh (spawned, not forked) process + python interpreter startup per + # file (~1-2 s), which is noise next to multi-GB sas7bdat reads. 
+ pool_kwargs: Dict[str, Any] = {"max_workers": workers} + if sys.version_info >= (3, 11): + pool_kwargs["max_tasks_per_child"] = 1 + + with ProcessPoolExecutor(**pool_kwargs) as pool: + futures = [ + pool.submit( + _worker_load_append_file, + str(p), + schemaname, + tablename, + include, + exclude, + progress_queue, + db_overrides, + column_types, + force_nullable, + ) + for p in files + ] + aborted = False + try: + for fut in as_completed(futures): + # ``--abort-on-first-failure`` cancels still-pending + # futures; ``as_completed`` yields them anyway and + # ``.result()`` raises ``CancelledError``. Skip those + # quietly - the abort log already accounted for the + # cancellation count. + try: + path_str, rows, err = fut.result() + except CancelledError: + continue + completed += 1 + name = Path(path_str).name + if err is not None: + errors.append((path_str, err)) + # ``tqdm.write`` writes through the bar without + # garbling it; bare ``print`` would interleave with + # the rendered progress line. + tqdm.write( + f"[FAIL {completed}/{n_files}] {name}: {err}", + file=sys.stderr, + ) + if abort_on_first_failure and not aborted: + # Cancel anything still queued. Already-running + # workers can't be interrupted portably, so they + # keep going - we just stop dispatching new files + # and stop counting their results toward total. + # Set the flag once so we don't spam a cancel + # storm if multiple workers fail at the same time. + aborted = True + cancelled = 0 + for f in futures: + if not f.done() and f.cancel(): + cancelled += 1 + tqdm.write( + f"[abort] --abort-on-first-failure: cancelled " + f"{cancelled} pending file(s); waiting for " + f"{n_files - completed - cancelled} in-flight " + f"worker(s) to finish.", + file=sys.stderr, + ) + else: + tqdm.write( + f"[done {completed}/{n_files}] {name}: " + f"{rows:,} row(s)", + file=sys.stderr, + ) + total += rows + except KeyboardInterrupt: + # Cancel anything still queued so we exit promptly. 
Already- + # running workers will run to completion (Python can't kill + # a child mid-syscall portably) but at least pending tasks + # don't keep firing. We re-raise as the same RuntimeError + # shape used below so the caller's per-cluster summary path + # still sees the partial failure list instead of an opaque + # KeyboardInterrupt with no context about which workers + # had already failed. + for f in futures: + if not f.done(): + f.cancel() + tqdm.write( + f"[interrupt] Ctrl-C after {completed}/{n_files} file(s); " + f"{len(errors)} failure(s) collected so far.", + file=sys.stderr, + ) + if errors: + joined = "\n".join(f" {p}: {e}" for p, e in errors) + raise RuntimeError( + f"interrupted after {len(errors)} worker(s) failed " + f"while appending to {schemaname}.{tablename}:\n{joined}" + ) + raise + + if errors: + joined = "\n".join(f" {p}: {e}" for p, e in errors) + raise RuntimeError( + f"{len(errors)} worker(s) failed while appending to " + f"{schemaname}.{tablename}:\n{joined}" + ) + + return total + + # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- @@ -844,6 +1371,20 @@ def _build_argparser() -> argparse.ArgumentParser: "cluster back and continue with the next one." ), ) + p.add_argument( + "--abort-on-first-failure", + action="store_true", + help=( + "Within a single cluster's parallel append, cancel every " + "still-pending worker the instant one worker fails. Use when " + "you know the failure is systemic (schema drift, bad creds, " + "OOM) and don't want to wait for the slow files to drain " + "before getting your prompt back. Already-running workers " + "still finish their current file - Python can't kill children " + "mid-syscall - but new files won't be dispatched. Orthogonal " + "to --fail-fast, which controls what happens between clusters." 
+ ), + ) p.add_argument( "--dbcreds", action="store_true", @@ -852,6 +1393,58 @@ def _build_argparser() -> argparse.ArgumentParser: "PGUSER / PGPASSWORD from the environment or .env file." ), ) + p.add_argument( + "--chunk-rows", + type=int, + default=None, + metavar="N", + help=( + "Per-chunk row target for pyreadstat streaming and COPY. " + "Overrides both the GENERIC_LOADER_CHUNK_ROWS env var and the " + "auto-scaling applied when --workers > 1. Peak memory per " + "worker is roughly 4 × N × avg_row_bytes; with wide sas7bdat " + "files (~4 KB/row) and 32 workers, N=100000 is a safe starting " + "point on a 128 GB box." + ), + ) + p.add_argument( + "--no-prescan", + action="store_true", + help=( + "Skip the per-file metadata scan that populates the folder-wide " + "tqdm ETA. Useful when the folder is large (half-hour+ pre-scan) " + "or when you're iterating quickly on a failure. Without the " + "pre-scan the progress bar still shows rows, rate, and elapsed time " + "but no ETA, and the cluster-wide schema auto-union is skipped." + ), + ) + p.add_argument( + "--all-nullable", + action="store_true", + help=( + "Stamp every column nullable in the generated schema, bypassing " + "NOT NULL inference for every cluster. Use when sampled rows " + "wrongly suggest a column has no nulls and COPY fails mid-load " + "on the first null it hits. Overrides the per-cluster and " + "folder-level ``all_nullable`` YAML settings when set." + ), + ) + p.add_argument( + "--workers", + type=int, + default=1, + metavar="N", + help=( + "Number of worker processes for the append phase. With N=1 " + "(default) files load serially on the main connection. With " + "N>1 only CREATE TABLE (plus partition setup) runs on the main " + "connection; every file, including the first, then loads in parallel " + "across N processes, each with its own psycopg2 connection. " + "On a big box try N close to your core count. 
When N>1 the " + "per-chunk row target auto-scales down with N (500,000 at most) unless " + "--chunk-rows or GENERIC_LOADER_CHUNK_ROWS pins it, keeping peak memory bounded." + ), + ) return p @@ -884,6 +1477,20 @@ def main(argv: Optional[List[str]] = None) -> int: return 2 clusters = discover_clusters(cfg) + + # CLI override: --all-nullable trumps both folder-level and per-cluster + # YAML ``all_nullable`` settings. Applied here (before any schema work) + # so every downstream path - dry-run, pre-scan, worker dispatch - sees + # the same flag on the ClusterSpec. + if args.all_nullable: + for c in clusters: + c.all_nullable = True + print( + "[info] --all-nullable set: stamping every column nullable " + "across all clusters (NOT NULL inference disabled).", + file=sys.stderr, + ) + loadable = [c for c in clusters if c.files] if not loadable: @@ -902,7 +1509,15 @@ def main(argv: Optional[List[str]] = None) -> int: print() for c in loadable: print(f"--- DDL for cluster {c.tablename!r} ---") - columns = _infer_cluster_schema(c.files[0], c.include, c.exclude) + # Dry-run skips the pre-scan (so no auto-union) but user-supplied + # ``column_types`` from YAML are already baked into ``c.column_types`` + # by ``discover_clusters`` - honor them here so the previewed DDL + # matches what a real load would produce on a single-file cluster. + columns, _ = _infer_cluster_schema( + c.files[0], c.include, c.exclude, + column_types=c.column_types, + force_nullable=c.all_nullable, + ) # Print parent CREATE TABLE (with PARTITION BY if applicable). 
print( render_create_table( @@ -970,6 +1585,233 @@ def main(argv: Optional[List[str]] = None) -> int: if args.dbcreds: db_user = input("Database username: ") db_password = getpass.getpass("Database password: ") + db_overrides: Optional[Dict[str, Optional[str]]] = ( + {"user": db_user, "password": db_password} if args.dbcreds else None + ) + + workers = max(1, int(args.workers)) + + # Per-worker peak memory ~= chunk_rows × avg_row_bytes × ~4 (the original + # pyreadstat DataFrame, the type-coerced ``prepared`` copy, the pyarrow + # table, and the serialized CSV buffer can all be alive simultaneously). + # With 32 workers and 500k rows × wide sas7bdat that's easily >128 GB - + # the default the loader shipped with OOM'd on a c6i.32xlarge box. Scale + # the auto target inversely with worker count so total memory stays + # roughly flat regardless of how many workers you pick. Floor of 50k + # keeps per-chunk overhead amortized; ceiling of 500k is where pyarrow + # / pyreadstat buffer spikes start to dominate. + # + # Order of precedence (most wins): + # 1. ``--chunk-rows N`` CLI flag (if provided) + # 2. ``GENERIC_LOADER_CHUNK_ROWS`` env var (if already set) + # 3. Auto-pick based on ``workers`` + if args.chunk_rows is not None: + os.environ["GENERIC_LOADER_CHUNK_ROWS"] = str(int(args.chunk_rows)) + print( + f"[info] --chunk-rows {args.chunk_rows:,}: pinning per-chunk " + f"row target (overrides auto-scaling).", + file=sys.stderr, + ) + elif "GENERIC_LOADER_CHUNK_ROWS" in os.environ: + print( + f"[info] honoring GENERIC_LOADER_CHUNK_ROWS=" + f"{os.environ['GENERIC_LOADER_CHUNK_ROWS']} from environment.", + file=sys.stderr, + ) + elif workers > 1: + auto_rows = max(50_000, min(500_000, 3_200_000 // workers)) + os.environ["GENERIC_LOADER_CHUNK_ROWS"] = str(auto_rows) + print( + f"[info] parallel mode (workers={workers}): auto-scaled " + f"per-chunk rows to {auto_rows:,}. 
" + f"Use --chunk-rows N to override if you have RAM headroom.", + file=sys.stderr, + ) + + # -- Metadata pre-scan ----------------------------------------------------- + # Sum ``number_rows`` across every file so the tqdm bar has a real + # denominator, AND collect the per-column (readstat_type, sas_format) + # tuples so we can union schemas across files in a cluster before any + # CREATE TABLE runs. ``read_sas_metadata`` uses pyreadstat's + # ``metadataonly=True`` fast path, but on multi-GB sas7bdat files + # that still reads tens of MB of scattered subheader pages per file - + # sequentially that's minutes for a 52-file folder. pyreadstat + # releases the GIL during I/O and C decoding, so a ThreadPool gives + # near-linear scaling until the disk saturates. ``--no-prescan`` + # bypasses the scan entirely; the progress bar then runs without an + # ETA *and* the auto-union is skipped (user overrides from YAML + # still apply). + all_files: List[Path] = [p for c in loadable for p in c.files] + grand_total: Optional[int] = 0 + file_meta_by_path: Dict[str, Dict[str, Tuple[str, Optional[str]]]] = {} + + if args.no_prescan: + grand_total = None + print( + f"[info] --no-prescan set: skipping row-count pre-scan for " + f"{len(all_files)} file(s); progress bar will show rate + " + f"elapsed but no ETA. 
Cluster-wide schema auto-union is also " + f"disabled; only user-specified column_types overrides apply.", + file=sys.stderr, + ) + else: + prescan_workers = min(16, max(1, len(all_files))) + print( + f"pre-scanning row counts + per-column metadata for " + f"{len(all_files)} file(s) across {prescan_workers} thread(s)...", + file=sys.stderr, + ) + + def _scan_one( + p: Path, + ) -> Tuple[ + Path, + Optional[int], + Optional[Dict[str, Tuple[str, Optional[str]]]], + Optional[str], + ]: + try: + meta = read_sas_metadata(p) + n = getattr(meta, "number_rows", None) + col_meta = extract_union_metadata(meta) + return ( + p, + int(n) if n is not None else None, + col_meta, + None, + ) + except Exception as e: + return (p, None, None, str(e)) + + unknown_total_files: List[str] = [] + running_total = 0 + with ThreadPoolExecutor(max_workers=prescan_workers) as tpool: + prescan_bar = tqdm( + total=len(all_files), + unit="file", + desc=" prescanning", + file=sys.stderr, + dynamic_ncols=True, + ) + try: + for p, n, col_meta, err in tpool.map(_scan_one, all_files): + prescan_bar.update(1) + if err is not None: + unknown_total_files.append(f"{p.name} ({err})") + elif n is None: + unknown_total_files.append(p.name) + else: + running_total += n + if col_meta is not None: + file_meta_by_path[str(p)] = col_meta + finally: + prescan_bar.close() + + if unknown_total_files: + print( + f"[warn] could not read row count from " + f"{len(unknown_total_files)} file(s); progress bar ETA will " + f"be approximate.", + file=sys.stderr, + ) + print( + f" total rows across folder: {running_total:,}", + file=sys.stderr, + ) + grand_total = running_total + + # -- Cluster-wide schema auto-union --------------------------------------- + # For each cluster, compute ``auto_types`` from the union of every + # file's metadata (see :func:`load_sas.union_column_types`). 
Merge with + # any user-supplied YAML overrides (user wins) and attach the result + # back onto the cluster so every later read - first-file inference, + # worker inference, schema-compat check - sees the same frozen schema. + # With ``--no-prescan`` the file_meta_by_path dict is empty and + # ``auto_types`` resolves to {}, so only the YAML overrides survive. + for c in loadable: + per_file = [ + file_meta_by_path[str(p)] + for p in c.files + if str(p) in file_meta_by_path + ] + auto_types = union_column_types(per_file) if per_file else {} + user_overrides = dict(c.column_types) # already merged folder+cluster + # User-supplied overrides win over the auto-union. + merged = {**auto_types, **user_overrides} + c.column_types = merged + + if auto_types: + # Only call out columns where auto-union *changed* something + # relative to the default "first file wins" inference. We + # don't have the default inference in hand at this point, so + # log the full resolved map at a debug-friendly level - it's + # bounded by column count and the user asked for visibility + # into what got overridden. + shown = auto_types + if user_overrides: + # Distinguish the user-forced entries in the log so it's + # obvious which types came from YAML. 
+ shown = { + col: ( + f"{user_overrides[col]} (user override)" + if col in user_overrides + else pg + ) + for col, pg in merged.items() + } + print( + f"[info] cluster {c.tablename!r}: auto-union derived " + f"{len(auto_types)} column type(s) across " + f"{len(per_file)} file(s): {shown}", + file=sys.stderr, + ) + elif user_overrides and args.no_prescan: + print( + f"[info] cluster {c.tablename!r}: using {len(user_overrides)} " + f"user-supplied column_types override(s); auto-union " + f"disabled by --no-prescan.", + file=sys.stderr, + ) + + # -- Shared progress plumbing --------------------------------------------- + # The queue crosses process boundaries when workers > 1 (managed proxy) + # and is a plain in-process queue otherwise; the put/get contract is + # identical either way. A daemon thread drains it and advances the one + # tqdm bar that spans the whole folder load. + manager: Optional[Any] = None + progress_queue: Any + if workers > 1: + manager = mp.Manager() + progress_queue = manager.Queue() + else: + progress_queue = _queue_mod.Queue() + + pbar = tqdm( + total=grand_total or None, + unit="row", + unit_scale=True, + desc=f"{cfg.folder.name}", + file=sys.stderr, + dynamic_ncols=True, + ) + stop_drainer = threading.Event() + + def _drainer() -> None: + while not stop_drainer.is_set(): + try: + event = progress_queue.get(timeout=0.1) + except _queue_mod.Empty: + continue + except (EOFError, OSError): + return + if not event: + continue + kind = event[0] + if kind == "rows": + pbar.update(event[1]) + + drainer_thread = threading.Thread(target=_drainer, daemon=True) + drainer_thread.start() conn = connect(user=db_user, password=db_password) conn.autocommit = False @@ -979,10 +1821,19 @@ def main(argv: Optional[List[str]] = None) -> int: for cluster in loadable: print( f"\n>>> loading cluster {cluster.tablename!r} " - f"({len(cluster.files)} file(s))" + f"({len(cluster.files)} file(s)) " + f"[workers={workers}]" ) try: - rows = load_cluster(conn, cluster, 
cfg.schemaname) + rows = load_cluster( + conn, + cluster, + cfg.schemaname, + workers=workers, + progress_queue=progress_queue, + db_overrides=db_overrides, + abort_on_first_failure=args.abort_on_first_failure, + ) conn.commit() totals.append((cluster.tablename, len(cluster.files), rows)) print( @@ -999,7 +1850,14 @@ def main(argv: Optional[List[str]] = None) -> int: if args.fail_fast: break finally: + # Drain any pending progress events before shutting the bar down so + # the final rendered total matches what actually landed. + stop_drainer.set() + drainer_thread.join(timeout=2.0) + pbar.close() conn.close() + if manager is not None: + manager.shutdown() print("\n=== summary ===") for name, fcount, rows in totals: diff --git a/generic_loader/load_sas.py b/generic_loader/load_sas.py index 74f5ff8..8db1d5c 100644 --- a/generic_loader/load_sas.py +++ b/generic_loader/load_sas.py @@ -183,15 +183,17 @@ Priority order used by :func:`infer_schema`: value exceeds the int32 range ``NUMERIC_INT_RANGE``); otherwise ``DOUBLE PRECISION``. -Type inference scans only the first ``TYPE_INFERENCE_SAMPLE_ROWS`` rows for -performance on large files. The CLI enforces this at read time via -:func:`read_sas_preview`, so the whole file is never materialized just to pick -types. Sampled specs carry an ``inferred_from_sample`` marker and the usual -tradeoffs: if the first N rows fit ``INTEGER`` but a later row exceeds int32, -or a column had no nulls in the preview but does later in the file, ``COPY`` -will fail mid-stream and the whole transaction rolls back. Set -``TYPE_INFERENCE_SAMPLE_ROWS = None`` to scan every row when exact typing -matters more than speed. +Type inference scans the whole file by default (``TYPE_INFERENCE_SAMPLE_ROWS += None``) so type + nullability are both computed against every row. The CLI +materializes the file once for schema inference, then re-streams it chunk by +chunk into ``COPY``; peak memory is roughly one full dataframe. 
Override +``TYPE_INFERENCE_SAMPLE_ROWS`` to an integer cap if you're on a host that +can't hold the file in memory - but know that sampled specs carry the usual +risks: a later row may exceed the inferred integer range, or a column that +had no nulls in the preview may carry nulls later in the file (which then +detonates ``COPY`` because the sampled spec stamped it ``NOT NULL``). Seen +in production on a 2.5M-row file with ~6k null MAFIDs past the 10k-row +preview - the entire load aborted mid-stream. Streaming loads use :func:`iter_sas_chunks` + :func:`copy_dataframes`, which commit each chunk as it is copied so an interrupted load retains the rows @@ -225,16 +227,48 @@ import math import os import re import sys +import warnings from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Tuple +import numpy as np import pandas as pd import psycopg2 import psycopg2.extensions +import pyarrow as pa +import pyarrow.csv as pa_csv import pyreadstat import yaml from dotenv import load_dotenv +from pandas.errors import PerformanceWarning +from tqdm import tqdm + +# ``_prepare_for_copy`` builds its output frame one column at a time with +# ``out[name] = ...``. On wide SAS files (~100+ columns) pandas prints a +# ``PerformanceWarning: DataFrame is highly fragmented`` once per chunk to +# nudge callers toward ``pd.concat(axis=1, ...)``. The fragmentation only +# matters for row-oriented ops or in-place ``.copy()``; we hand the frame +# straight to ``pyarrow.Table.from_pandas`` which reads columns +# independently, so the warning is pure noise for our pipeline. Filter it +# at import time - narrow category match so nothing else is suppressed. +warnings.filterwarnings("ignore", category=PerformanceWarning) + +# Turn numpy's "raise on float overflow" (and friends) into silent inf/nan +# production, module-wide. 
Pandas ships with ``np.errstate(over="raise")`` +# wrapped around several internal ops (most painfully, the multiply inside +# ``pd.to_datetime(unit="s")`` that converts SAS epoch -> nanoseconds). +# Our data routinely carries ``inf`` / huge sentinels, which trip that +# ``raise`` and blow up an entire worker before ``errors="coerce"`` gets +# a chance to turn them into NaT. Even with ``_safe_numeric_to_datetime`` +# pre-masking the obvious cases, other code paths (pandas object-dtype +# datetime parsing, pyarrow type promotion, pyreadstat) can also trigger. +# Setting a process-wide ``seterr`` is a heavier hammer than an +# ``errstate`` block but survives library internals that don't explicitly +# rewrap it. Downside: a real overflow bug in new code would now silently +# produce inf/nan instead of raising - acceptable for a bulk loader where +# "don't crash on bad rows, null them and move on" is the whole point. +np.seterr(over="ignore", invalid="ignore", divide="ignore") logger = logging.getLogger(__name__) @@ -255,17 +289,29 @@ values; too small a sample is easy to mis-infer.""" NUMERIC_INT_RANGE = (-2_147_483_648, 2_147_483_647) """INTEGER bounds; anything outside becomes BIGINT.""" -TYPE_INFERENCE_SAMPLE_ROWS: Optional[int] = 10_000 +TYPE_INFERENCE_SAMPLE_ROWS: Optional[int] = None """Cap on rows inspected during per-column type inference. Also governs how many rows :func:`read_sas_preview` pulls from the file for dry-run / validate / -schema-inference flows. Set to ``None`` to scan every row (and read the whole -file into memory for the preview step - don't do this on multi-hundred-million -row files).""" +schema-inference flows. -DEFAULT_CHUNK_ROWS = 100_000 +Default is ``None`` (scan every row, reading the whole file into memory for +the schema-inference step). 
That's the only honest setting for nullability: +any integer cap lets a column look ``NOT NULL`` across the first N rows +while the file actually holds rare nulls past the window, which then +detonates ``COPY`` mid-stream (seen in production on a 2.5M-row file where +~6k MAFIDs were null past the 10k-row preview). If you're loading a file +so large that a full read won't fit in memory, set this to an integer cap +and accept that sampled specs can't be trusted for ``NOT NULL``.""" + +DEFAULT_CHUNK_ROWS = 2_000_000 """Rows per chunk when streaming a SAS file into ``COPY``. Larger values mean -fewer COPY round-trips but more peak memory per chunk; smaller values are -gentler on memory.""" +fewer COPY round-trips and lower per-row overhead but more peak memory per +chunk; smaller values are gentler on memory. + +The chunk size can be overridden at runtime via the +``GENERIC_LOADER_CHUNK_ROWS`` environment variable (read inside +:func:`iter_sas_chunks`), so ``.env``-driven overrides work without code +changes. Explicit ``chunksize=`` kwargs still win over both.""" VALID_IF_EXISTS = ("fail", "replace", "append") @@ -290,6 +336,8 @@ class LoaderConfig: partition_by: List[str] = field(default_factory=list) max_partitions: int = 10_000 indexes: List[str] = field(default_factory=list) + column_types: Dict[str, str] = field(default_factory=dict) + all_nullable: bool = False @dataclass @@ -500,6 +548,48 @@ def load_config(path: Path) -> LoaderConfig: f"{missing_in_include}" ) + # -- column_types ------------------------------------------------------- + # Optional ``{column_name: pg_type}`` escape hatch that bypasses + # automatic type inference for specific columns. Useful when + # pyreadstat reports a column as NUM but the downstream consumer + # expects TEXT (e.g. phone-id columns), or when a column has drifted + # between CHAR and NUM across file versions and you want to pin + # TEXT up front. See also :func:`infer_schema`. 
+ raw_ct = raw.get("column_types") + column_types: Dict[str, str] = {} + if raw_ct is not None: + if not isinstance(raw_ct, dict): + raise ValueError( + f"Config {path}: 'column_types' must be a mapping of " + f"{{column_name: postgres_type}}." + ) + for k, v in raw_ct.items(): + key = str(k).strip() + if not key: + raise ValueError( + f"Config {path}: 'column_types' contains an empty " + f"column name." + ) + if not isinstance(v, str) or not v.strip(): + raise ValueError( + f"Config {path}: 'column_types[{key}]' must be a " + f"non-empty Postgres type string (got {v!r})." + ) + column_types[key] = v.strip() + + # -- all_nullable ------------------------------------------------------- + # When inference wrongly stamps a column NOT NULL (sampled rows happened + # to be dense; later rows carry nulls) downstream COPYs fail mid-stream. + # Set ``all_nullable: true`` in the YAML to stamp every column nullable + # up front. The CLI flag ``--all-nullable`` overrides this to ``true`` + # if set. + raw_an = raw.get("all_nullable", False) + if not isinstance(raw_an, bool): + raise ValueError( + f"Config {path}: 'all_nullable' must be a boolean (got {raw_an!r})." + ) + all_nullable = bool(raw_an) + return LoaderConfig( filename=filename, schemaname=schemaname, @@ -510,6 +600,8 @@ def load_config(path: Path) -> LoaderConfig: partition_by=partition_by, max_partitions=max_partitions, indexes=indexes, + column_types=column_types, + all_nullable=all_nullable, ) @@ -563,16 +655,46 @@ def read_sas_preview( return reader(str(Path(path)), row_limit=row_limit, **kwargs) +def read_sas_metadata(path: Path) -> Any: + """Read only the metadata (no rows) from a SAS file. + + Uses pyreadstat's ``metadataonly=True`` fast path: the reader decodes + the file header (column names, formats, total row count, etc.) and + returns without touching the data pages. 
Orders of magnitude faster + than :func:`read_sas_preview` when all you need is + ``meta.number_rows`` - typically a few ms per sas7bdat file, which + makes it cheap to pre-scan a whole folder to populate a global + progress bar. + """ + reader, kwargs = _sas_reader(path) + _, meta = reader(str(Path(path)), metadataonly=True, **kwargs) + return meta + + def iter_sas_chunks( path: Path, *, - chunksize: int = DEFAULT_CHUNK_ROWS, + chunksize: Optional[int] = None, ): """Yield ``(df_chunk, meta)`` tuples for streaming loads. Thin wrapper over ``pyreadstat.read_file_in_chunks`` that picks the right underlying reader by extension and threads through our encoding defaults. + + When ``chunksize`` is ``None`` (the default), the effective value comes + from the ``GENERIC_LOADER_CHUNK_ROWS`` environment variable if set and + parseable, otherwise from :data:`DEFAULT_CHUNK_ROWS`. An explicit int + always wins. """ + if chunksize is None: + raw = os.environ.get("GENERIC_LOADER_CHUNK_ROWS") + if raw is not None: + try: + chunksize = int(raw) + except ValueError: + chunksize = DEFAULT_CHUNK_ROWS + else: + chunksize = DEFAULT_CHUNK_ROWS reader, kwargs = _sas_reader(path) yield from pyreadstat.read_file_in_chunks( reader, str(Path(path)), chunksize=chunksize, **kwargs @@ -590,7 +712,13 @@ def apply_column_filter( exclude: Optional[List[str]], ) -> pd.DataFrame: """Restrict ``df`` to the requested columns. Names missing from the frame - raise a clear error rather than silently dropping.""" + raise a clear error rather than silently dropping. + + Returns the input frame (or a column-sliced view / drop result) without + an extra ``.copy()`` — downstream (:func:`_prepare_for_copy`) reads the + frame into a freshly built output and never mutates its input, so the + copies were pure overhead on every streamed chunk. 
+ """ if include is not None and exclude is not None: raise ValueError("include and exclude are mutually exclusive.") @@ -598,15 +726,15 @@ def apply_column_filter( missing = [c for c in include if c not in df.columns] if missing: raise ValueError(f"include references unknown columns: {missing}") - return df.loc[:, list(include)].copy() + return df.loc[:, list(include)] if exclude is not None: missing = [c for c in exclude if c not in df.columns] if missing: raise ValueError(f"exclude references unknown columns: {missing}") - return df.drop(columns=list(exclude)).copy() + return df.drop(columns=list(exclude)) - return df.copy() + return df # --------------------------------------------------------------------------- @@ -634,6 +762,126 @@ def _format_driven_type(sas_format: Optional[str]) -> Optional[str]: return None +_DECIMAL_FORMAT_RE = re.compile(r"\.(\d+)") + + +def _format_hints_decimal(sas_format: Optional[str]) -> bool: + """True if a numeric SAS format string explicitly carries decimal places. + + SAS numeric formats are ``NAMEw.d``; ``d > 0`` means the variable was + intended to render with ``d`` decimal digits (COMMA10.2, F8.3, ...). + A bare width like ``BEST12.`` or ``F8.`` has no digits after the dot + and is treated as integer-presenting. Used by + :func:`union_column_types` to pick BIGINT vs DOUBLE PRECISION when a + column is numeric in every file of a cluster. + """ + if not sas_format: + return False + m = _DECIMAL_FORMAT_RE.search(sas_format) + if not m: + return False + try: + return int(m.group(1)) > 0 + except ValueError: + return False + + +def extract_union_metadata( + meta: Any, +) -> Dict[str, Tuple[str, Optional[str]]]: + """Pull the (readstat_type, sas_format) pair for every column in ``meta``. + + Returns a plain dict that's safe to pass between processes and to + :func:`union_column_types`. ``readstat_type`` is the simplified type + reported by pyreadstat: ``"string"`` for SAS CHAR, ``"double"`` for + SAS NUM. 
``sas_format`` comes from ``meta.original_variable_types`` + and drives date/datetime detection during union. + """ + var_types = dict(getattr(meta, "variable_types", None) or {}) + formats = dict(getattr(meta, "original_variable_types", None) or {}) + names = list( + getattr(meta, "column_names", None) + or list(var_types.keys()) + or list(formats.keys()) + ) + out: Dict[str, Tuple[str, Optional[str]]] = {} + for col in names: + rtype = str(var_types.get(col, "")) if var_types else "" + fmt = formats.get(col) + out[col] = (rtype, fmt if fmt else None) + return out + + +def union_column_types( + per_file_metas: Iterable[Dict[str, Tuple[str, Optional[str]]]], +) -> Dict[str, str]: + """Derive one Postgres type per column that's safe across every file. + + ``per_file_metas`` is an iterable (one entry per file in a cluster) of + ``{column_name: (readstat_type, sas_format)}`` dicts as produced by + :func:`extract_union_metadata`. + + Rules, evaluated per column: + + * **CHAR/NUM drift wins TEXT.** If any file stores the column as CHAR + (``readstat_type != "double"``) the union is ``TEXT``. This covers + the phone-id case where some years stored ``RESP_PH_PREFIX_ID`` as + CHAR and others as NUM. + * **All NUM, format hints DATETIME → TIMESTAMP.** Any file whose + format resolves to ``TIMESTAMP`` (via :func:`_format_driven_type`) + pins the column to ``TIMESTAMP`` even if other files left the + format blank. + * **All NUM, format hints DATE → DATE.** Same idea for date-only + formats. + * **All NUM, any decimal hint → DOUBLE PRECISION.** A ``w.d`` format + with ``d > 0`` in any file implies fractional values somewhere. + * **All NUM, no useful hint → DOUBLE PRECISION.** SAS numeric + formats are *display* formats, not storage constraints - a + ``BEST12.`` / ``F8.`` / blank-format column can still hold floats, + and pyreadstat hands back plain ``float64`` regardless. 
Defaulting + to ``DOUBLE PRECISION`` here costs the same 8 bytes as ``BIGINT`` + but can't fail on real data. For columns that truly are + integer-only and you want ``BIGINT`` semantics in queries, pin + them via a ``column_types`` override. + + Columns missing from a given file are simply skipped for that file; + the union is computed over whichever files *did* supply the column. + Columns that never appear anywhere are omitted from the result. + """ + per_col: Dict[str, List[Tuple[str, Optional[str]]]] = {} + for meta in per_file_metas: + for col, pair in meta.items(): + per_col.setdefault(col, []).append(pair) + + result: Dict[str, str] = {} + for col, entries in per_col.items(): + any_char = any( + rtype and rtype.lower() != "double" for rtype, _ in entries + ) + if any_char: + result[col] = "TEXT" + continue + formats = [fmt for _, fmt in entries if fmt] + driven = [_format_driven_type(f) for f in formats] + if "TIMESTAMP" in driven: + result[col] = "TIMESTAMP" + elif "DATE" in driven: + result[col] = "DATE" + else: + # Safe default: DOUBLE PRECISION. The BIGINT default we tried + # first failed the moment a file contained a fractional + # value in a column whose format didn't carry a decimal + # hint (very common: SAS ``BEST12.`` / ``F8.`` are display + # formats, not storage constraints, so the underlying + # 8-byte float can hold any value). Same storage cost as + # BIGINT, handles both integer- and float-valued data, and + # keeps loads from failing mid-cluster. Use a + # ``column_types`` override to pin specific columns to + # ``BIGINT`` when you want integer semantics in queries. 
+ result[col] = "DOUBLE PRECISION" + return result + + def _all_null(series: pd.Series) -> bool: if pd.api.types.is_object_dtype(series): return bool(series.map(lambda v: v is None or (isinstance(v, str) and v == "") or (isinstance(v, float) and pd.isna(v))).all()) @@ -759,6 +1007,8 @@ def infer_schema( *, coerce_chars: bool = COERCE_CHAR_COLUMNS, total_rows: Optional[int] = None, + column_types: Optional[Dict[str, str]] = None, + force_nullable: bool = False, ) -> Dict[str, ColumnSpec]: """Infer a Postgres column spec for each column in ``df``. @@ -774,11 +1024,30 @@ def infer_schema( ``total_rows`` lets callers who already sampled the frame (e.g. via :func:`read_sas_preview`) report the real file size in the per-column "inferred from first N of M rows" note. Falls back to ``len(df)``. + + ``column_types`` is an optional map ``{column_name: pg_type_str}`` + whose entries bypass inference entirely - the caller has already + decided the type (e.g. via :func:`union_column_types` across a + cluster, or a YAML ``column_types`` override). Nullability is still + computed from the data. Columns in ``column_types`` that don't exist + in ``df`` are ignored so a shared override dict can apply to clusters + with different column sets. + + ``force_nullable=True`` stamps every column nullable regardless of + what the data sample shows. Escape hatch for when inference marks a + column ``NOT NULL`` because the sampled rows happened to be dense but + downstream files carry nulls in that column - common with cluster + loads where one file's preview can't speak for the rest. Cheaper than + trying to sharpen the sampler: widen the column and move on. """ original_formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {}) - # Row-walking type probes run on a bounded head slice; nullability and the - # all-null check still see every row so NOT NULL declarations stay honest. 
+ # When ``TYPE_INFERENCE_SAMPLE_ROWS`` is an integer cap, row-walking type + # probes run on the head slice for speed; nullability and the all-null + # check still walk every row of ``df``. That's only honest when the + # caller handed us the full file - with the default cap of ``None`` the + # CLI does exactly that. Callers who pass a partial preview and a tight + # integer cap accept that ``NOT NULL`` can be wrong for rare-null columns. df_rows = len(df) effective_total = total_rows if total_rows is not None else df_rows if TYPE_INFERENCE_SAMPLE_ROWS is not None and df_rows > TYPE_INFERENCE_SAMPLE_ROWS: @@ -789,6 +1058,8 @@ def infer_schema( sample_size = df_rows sampled = sample_size < effective_total + overrides: Dict[str, str] = dict(column_types or {}) + # Temporarily flip the module-level flag if the caller asked us to. global COERCE_CHAR_COLUMNS saved = COERCE_CHAR_COLUMNS @@ -801,6 +1072,27 @@ def infer_schema( sas_format = original_formats.get(col) notes: List[str] = [] + if col in overrides: + pg_type = overrides[col] + notes.append( + f"type forced to {pg_type} via column_types override" + ) + if force_nullable: + nullable = True + notes.append("nullable forced via --all-nullable") + else: + nullable = _is_nullable(series) + out[col] = ColumnSpec( + name=col, + postgres_type=pg_type, + nullable=nullable, + sas_format=sas_format, + source_dtype=str(series.dtype), + notes=notes, + sampled=sampled, + ) + continue + pg_type = _format_driven_type(sas_format) if pg_type is None: @@ -831,7 +1123,11 @@ def infer_schema( f"{effective_total:,} rows" ) - nullable = _is_nullable(series) + if force_nullable: + nullable = True + notes.append("nullable forced via --all-nullable") + else: + nullable = _is_nullable(series) out[col] = ColumnSpec( name=col, @@ -964,6 +1260,30 @@ def _normalize_type(pg_type: str) -> str: return _TYPE_NORMALIZATION.get(stripped, stripped.lower()) +# Widening pairs: (inferred_from_source, existing_in_target). 
When the +# incoming spec is narrower than the target we accept it - the value is +# guaranteed to fit, and ``_prepare_for_copy`` already emits ``COPY`` +# payloads that Postgres silently promotes to the wider column type. The +# INVERSE direction stays a hard failure: a BIGINT value does not fit in +# an INTEGER column, so we must not let a cluster whose first file had +# only small ints accept a later file with a value past int32. Comes up +# most often on cluster loads where file 1 pushed the target to BIGINT +# (a single value > 2_147_483_647) and file N happens to sit entirely +# within int32 range - strict equality would reject file N even though +# the copy is trivially safe. +_WIDENING_COMPATIBLE: set = { + ("smallint", "integer"), + ("smallint", "bigint"), + ("integer", "bigint"), + ("real", "double precision"), + # INTEGER / BIGINT into DOUBLE PRECISION is lossless for int32 and + # exact up to 2**53 for int64, which covers every value pandas could + # have carried through as Int64 without wrapping anyway. + ("integer", "double precision"), + ("bigint", "double precision"), +} + + def _assert_schema_compatible( conn, schema: str, table: str, columns: Dict[str, ColumnSpec] ) -> None: @@ -990,11 +1310,22 @@ def _assert_schema_compatible( inferred_norm = _normalize_type(spec.postgres_type) target_norm = _normalize_type(target_type) if inferred_norm != target_norm: - mismatches.append( - f"column {name!r}: inferred {spec.postgres_type} " - f"(normalized {inferred_norm!r}) but target is {target_type} " - f"(normalized {target_norm!r})" - ) + if (inferred_norm, target_norm) in _WIDENING_COMPATIBLE: + # Narrower inferred type fits inside the wider target. + # Accept silently-but-noisily so the operator knows the + # file came in with a smaller range than the cluster's + # target was sized for. 
+ warnings.append( + f"column {name!r}: inferred {spec.postgres_type} " + f"(narrower than target {target_type}); accepting - " + f"values fit in the wider target type" + ) + else: + mismatches.append( + f"column {name!r}: inferred {spec.postgres_type} " + f"(normalized {inferred_norm!r}) but target is {target_type} " + f"(normalized {target_norm!r})" + ) target_is_notnull = (target_nullable == "NO") if spec.nullable and target_is_notnull: warnings.append( @@ -1661,9 +1992,151 @@ def _seconds_to_time(v: Any) -> Optional[dt.time]: return dt.time(h, m, s) +# Safe outer bound for the numeric->datetime conversion below. The true +# ceiling is ``pd.Timestamp.max`` (2262-04-11), which in seconds since 1960 +# is ~9.52e9. We pick a much tighter bound - year ~2200, ~7.6e9 seconds, +# ~87600 days - because (a) any real SAS data past ~2100 is garbage anyway, +# and (b) staying well inside the float64 + datetime64[ns] windows gives +# pandas' internals zero room to trip the ``over="raise"`` they wrap +# around the ns-multiply. ``7.5e9 * 1e9 = 7.5e18``, comfortably under both +# ``int64.max`` (~9.22e18) and float64 overflow (~1.8e308). +_SAS_DATETIME_SAFE_S = 7_500_000_000 +_SAS_DATETIME_SAFE_D = 87_000 + + +def _safe_numeric_to_datetime( + series: pd.Series, + *, + unit: str, + column_name: str, + target_type: str, +) -> pd.Series: + """Convert a numeric SAS-epoch series to ``datetime64[ns]`` without letting + one stray cell take down the worker. + + Failure modes seen in production: + + * ``np.inf`` / ``-np.inf`` slipping through pyreadstat (SAS missing-value + sentinels, divide-by-zero in the source, uninitialized cells). + * Absurdly large finite floats (e.g. ``1.7e308``) where ``value * 1e9`` + overflows float64. + * Values between ``pd.Timestamp.max`` and float64 safety (~9.5e9 to 1e308 + seconds) where the nanosecond multiply silently produces garbage or + overflows int64. 
+ + All of these trigger ``FloatingPointError: overflow encountered in multiply`` + inside ``pd.to_datetime`` because pandas wraps the multiply in + ``np.errstate(over="raise")`` -- our outer ``errors="coerce"`` never + gets a chance to turn the bad value into ``NaT``. + + Strategy, belt + suspenders + airbag: + + 1. Coerce to float64 up front. Object-dtype branches hand us mixed + int/float/str; ``pd.to_numeric(errors="coerce")`` parses what it can + and NaNs the rest, so we hit the rest of this function with a + pristine float series. + 2. Mask non-finite values and anything outside the safe epoch window to + NaN *before* ``pd.to_datetime`` sees them. + 3. Run the conversion under a permissive ``errstate``. + 4. If that still raises (some pandas version internally re-enables + ``over="raise"`` in a way ``errstate`` can't override), catch it + and return all-NaT for the column with a loud warning. Better a + NULL column in one chunk than a dead worker + no diagnostics. + + Emits one stderr line per chunk per affected column so silent data + loss doesn't sneak by. + """ + if not pd.api.types.is_float_dtype(series): + series = pd.to_numeric(series, errors="coerce").astype("float64") + + arr = series.to_numpy(dtype="float64", copy=False, na_value=np.nan) + if unit == "s": + bound = _SAS_DATETIME_SAFE_S + elif unit == "D": + bound = _SAS_DATETIME_SAFE_D + else: + bound = _SAS_DATETIME_SAFE_S + with np.errstate(over="ignore", invalid="ignore", divide="ignore"): + finite_mask = np.isfinite(arr) + # ``np.abs(inf) -> inf``, ``np.abs(nan) -> nan``; both compare False + # to ``bound``, so ``in_range_mask`` already excludes non-finite + # values. The explicit ``finite_mask &`` below is belt-and-suspenders + # in case a future numpy changes that semantic. 
+ in_range_mask = np.abs(arr) < bound + keep_mask = finite_mask & in_range_mask + was_present = ~np.isnan(arr) + coerced = int(((~keep_mask) & was_present).sum()) + if coerced: + tqdm.write( + f"[warn] {target_type} column {column_name!r}: {coerced:,} " + f"row(s) had non-representable values (Inf/NaN/out-of-range), " + f"coerced to NULL", + file=sys.stderr, + ) + cleaned_arr = np.where(keep_mask, arr, np.nan) + cleaned = pd.Series(cleaned_arr, index=series.index) + try: + with np.errstate(over="ignore", invalid="ignore", divide="ignore"): + return pd.to_datetime( + cleaned, unit=unit, origin="1960-01-01", errors="coerce", + ) + except (FloatingPointError, OverflowError, ValueError) as exc: + tqdm.write( + f"[error] {target_type} column {column_name!r}: " + f"pd.to_datetime raised {type(exc).__name__}: {exc}; " + f"returning NaT for the entire chunk. This usually means one " + f"or more values slipped past the pre-mask (bound={bound}). " + f"Consider setting the column to TEXT via column_types if this " + f"recurs.", + file=sys.stderr, + ) + return pd.Series(pd.NaT, index=series.index, dtype="datetime64[ns]") + + +def _safe_object_to_datetime( + series: pd.Series, + *, + column_name: str, + target_type: str, +) -> pd.Series: + """Object-dtype to datetime. Shares the safety net (errstate + + try/except) with :func:`_safe_numeric_to_datetime`. If the column is + actually numeric-flavored (e.g. SAS wrote numbers into an object + column), route to the numeric path; otherwise parse with ``to_datetime`` + on the object itself. 
+ """ + coerced = series.replace({"": None}) + numeric = pd.to_numeric(coerced, errors="coerce") + all_numeric = numeric.notna().sum() == coerced.notna().sum() + if all_numeric and coerced.notna().any(): + return _safe_numeric_to_datetime( + numeric, unit="s", column_name=column_name, target_type=target_type, + ) + try: + with np.errstate(over="ignore", invalid="ignore", divide="ignore"): + return pd.to_datetime(coerced, errors="coerce") + except (FloatingPointError, OverflowError, ValueError) as exc: + tqdm.write( + f"[error] {target_type} column {column_name!r}: " + f"pd.to_datetime raised {type(exc).__name__}: {exc}; " + f"returning NaT for the entire chunk.", + file=sys.stderr, + ) + return pd.Series(pd.NaT, index=series.index, dtype="datetime64[ns]") + + def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame: """Materialize a copy of ``df`` with each column in the right shape for ``to_csv`` so the CSV lands as valid input for the target Postgres type. + + Per-column conversions are vectorized (``.astype`` / ``pd.to_datetime`` / + ``.mask`` / ``.fillna``) instead of the element-wise ``.map(func)`` + loops this function used to run. That was the single largest per-chunk + CPU cost on text-heavy loads - a 40-column × 100k-row chunk was issuing + ~4M Python-level function calls just to cast strings. TIME columns are + still the ``.map`` path because SAS TIME8 is stored as seconds and the + clamp-to-24h logic doesn't fit cleanly in vector form; they're also + rare in practice. 
""" out = pd.DataFrame(index=df.index) for name, spec in columns.items(): @@ -1686,59 +2159,86 @@ def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.Da if pd.api.types.is_datetime64_any_dtype(series): out[name] = series.dt.date elif pd.api.types.is_object_dtype(series): - def _to_date(v: Any) -> Optional[dt.date]: - if v is None or (isinstance(v, float) and pd.isna(v)): - return None - if isinstance(v, dt.datetime): - return v.date() - if isinstance(v, dt.date): - return v - if isinstance(v, str): - if v == "": - return None - try: - return dt.date.fromisoformat(v) - except ValueError: - return None - return None - out[name] = series.map(_to_date) + # Vectorized parse: empty strings / None / unparseable -> NaT, + # then .dt.date yields date objects or NaT. NaT serializes as + # an empty CSV field (matching ``NULL ''`` in COPY). Routed + # through ``_safe_object_to_datetime`` so an object column + # that actually contains SAS-epoch numerics (seen when one + # file of a cluster stores the column as NUM and another as + # CHAR + the union flipped it to TEXT-then-DATE) can't trip + # the overflow-in-multiply bug. + parsed = _safe_object_to_datetime( + series, column_name=name, target_type="DATE", + ) + out[name] = parsed.dt.date + elif pd.api.types.is_numeric_dtype(series): + # pyreadstat couldn't decode the SAS format (some + # ``DATEw.``/``YYMMDDw.`` variants and all custom formats slip + # through) so the column came back as float64: days since + # 1960-01-01, the SAS epoch. Without this branch the raw + # number would hit COPY and Postgres rejects it with + # ``invalid input syntax for type date``. 
+ parsed = _safe_numeric_to_datetime( + series, unit="D", column_name=name, target_type="DATE", + ) + out[name] = parsed.dt.date else: out[name] = series elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"): if pd.api.types.is_datetime64_any_dtype(series): out[name] = series elif pd.api.types.is_object_dtype(series): - def _to_dt(v: Any) -> Optional[dt.datetime]: - if v is None or (isinstance(v, float) and pd.isna(v)): - return None - if isinstance(v, dt.datetime): - return v - if isinstance(v, dt.date): - return dt.datetime(v.year, v.month, v.day) - if isinstance(v, pd.Timestamp): - return v.to_pydatetime() if not pd.isna(v) else None - if isinstance(v, str): - if v == "": - return None - try: - return dt.datetime.fromisoformat(v) - except ValueError: - return None - return None - out[name] = series.map(_to_dt) + # Same rationale as the DATE object branch above: route + # through the safety net so numeric-flavored object columns + # can't blow us up during the ns multiply. + out[name] = _safe_object_to_datetime( + series, column_name=name, target_type="TIMESTAMP", + ) + elif pd.api.types.is_numeric_dtype(series): + # Same story as the DATE branch above, but SAS datetimes are + # *seconds* since 1960-01-01 (fractional seconds for + # ``DATETIMEw.d``). Example caught in the wild: + # ``1915465463.615`` -> 2020-09-13 05:44:23.615. + out[name] = _safe_numeric_to_datetime( + series, unit="s", column_name=name, target_type="TIMESTAMP", + ) else: out[name] = series elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"): out[name] = series.map(_seconds_to_time) elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"): - # Leave empty strings as "" so `NULL ''` in COPY turns them into NULL. - def _to_str(v: Any) -> Any: - if v is None: - return "" - if isinstance(v, float) and pd.isna(v): - return "" - return str(v) - out[name] = series.map(_to_str) + # Render every cell as a string and blank out nulls. 
``NULL ''`` + # in the COPY statement turns the blanks back into SQL NULL. + # astype(str) stringifies NaN/None to the literal "nan"/"None", + # so we mask those after the fact rather than branching per cell. + na_mask = series.isna() + if pd.api.types.is_numeric_dtype(series): + # Hit when a column was auto-unioned to TEXT because at + # least one file of the cluster stored it as CHAR but this + # particular file stored it as NUM (typical of SAS phone-id + # columns). Default float formatting would emit "123.0" - + # which doesn't match the plain "123" coming from the CHAR + # files. When the whole chunk is integer-valued, round to + # int before stringifying; when any fractional value is + # present we leave float formatting alone so we don't + # silently drop precision. + nonnull = series.dropna() + int_like = False + if not nonnull.empty: + try: + int_like = bool(((nonnull % 1) == 0).all()) + except TypeError: + int_like = False + if int_like: + # ``Int64`` preserves NA; ``.astype(str)`` renders NA + # as '', which we then mask out alongside original + # NaNs. + as_str = series.astype("Int64").astype(str) + out[name] = as_str.mask(na_mask, "") + else: + out[name] = series.astype(str).mask(na_mask, "") + else: + out[name] = series.astype(str).mask(na_mask, "") elif pg == "BOOLEAN": out[name] = series.astype("boolean") if series.dtype != object else series else: @@ -1746,6 +2246,25 @@ def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.Da return out +def _serialize_chunk_csv(prepared: pd.DataFrame) -> io.BytesIO: + """Serialize a prepared frame into a CSV buffer for ``COPY FROM STDIN``. + + Uses ``pyarrow.csv.write_csv`` (typically 5-10× faster than pandas' + pure-Python ``to_csv`` on wide/text-heavy frames). Null cells serialize + as empty strings and date/timestamp values land in ISO 8601 form, both + of which Postgres accepts under ``FORMAT csv, NULL ''``. 
+ """ + table = pa.Table.from_pandas(prepared, preserve_index=False) + buf = io.BytesIO() + pa_csv.write_csv( + table, + buf, + write_options=pa_csv.WriteOptions(include_header=False), + ) + buf.seek(0) + return buf + + def copy_dataframes( conn, schema_name: str, @@ -1768,23 +2287,43 @@ def copy_dataframes( ) total = 0 + # Pull chunks one at a time so each ``df`` is unreferenced before the + # generator reads the next one. Without this the loop-variable binding + # of a ``for df in dfs:`` keeps the previous chunk alive during the + # next pyreadstat read, pushing peak memory to 5-6× chunk size per + # worker (old df + incoming df + prepared + pyarrow table + CSV buf). + # With explicit drops we cap peak at ~2× chunk size: ``df`` goes away + # once ``prepared`` exists, ``prepared`` once ``buf`` exists, ``buf`` + # once COPY has consumed it. Matters most in parallel mode where + # 32 × per-worker peak can exhaust a 128 GB host. + dfs_iter = iter(dfs) with conn.cursor() as cur: - for df in dfs: + while True: + try: + df = next(dfs_iter) + except StopIteration: + break if df.empty: + del df continue prepared = _prepare_for_copy(df, columns) - buf = io.StringIO() - prepared.to_csv( - buf, - index=False, - header=False, - na_rep="", - date_format="%Y-%m-%d %H:%M:%S", - ) - buf.seek(0) + del df + n = len(prepared) + buf = _serialize_chunk_csv(prepared) + del prepared cur.copy_expert(sql, buf) + del buf conn.commit() - total += len(prepared) + total += n + # Hand pyarrow's pool memory back between chunks. Without this, + # arrow's internal buffer pool keeps the high-water bytes + # reserved across the worker's lifetime - inside long-running + # workers this presents as steadily climbing RSS even with the + # ``del``s above. Cheap (microseconds); call it every chunk. 
+ try: + pa.default_memory_pool().release_unused() + except Exception: + pass return total @@ -1899,6 +2438,16 @@ def _build_argparser() -> argparse.ArgumentParser: "PGUSER / PGPASSWORD from the environment or .env file." ), ) + p.add_argument( + "--all-nullable", + action="store_true", + help=( + "Stamp every column nullable in the generated schema, bypassing " + "NOT NULL inference. Use when sampled rows wrongly suggest a " + "column has no nulls. Overrides ``all_nullable`` in the YAML " + "config when set." + ), + ) return p @@ -1921,13 +2470,23 @@ def main(argv: Optional[List[str]] = None) -> int: print(f"error: SAS file not found: {cfg.filename}", file=sys.stderr) return 2 - # Schema inference uses a bounded preview read so we never load a - # hundreds-of-millions-of-rows file into memory just to pick types. - # NB: ``meta.number_rows`` on a ``row_limit``-ed read reflects rows - # returned, not the file's total, so we don't trust it here. + # Schema inference reads the whole file so type + nullability are + # computed against every row. That's what the target host has the + # resources for and is the only way to honestly emit ``NOT NULL`` - + # a bounded preview routinely missed the ~0.2% of rows with nulls on + # otherwise-dense keys (e.g. MAFID). If you're on a box that can't + # fit the file in memory, override ``TYPE_INFERENCE_SAMPLE_ROWS`` to + # an integer cap and know that sampled specs may stamp ``NOT NULL`` + # on columns whose nulls live past the window. preview_df, meta = read_sas_preview(cfg.filename) preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude) - columns = infer_schema(preview_df, meta) + force_nullable = args.all_nullable or cfg.all_nullable + columns = infer_schema( + preview_df, + meta, + column_types=cfg.column_types, + force_nullable=force_nullable, + ) # Validate partition columns exist in the schema after filtering. 
if cfg.partition_by: @@ -2018,13 +2577,24 @@ def main(argv: Optional[List[str]] = None) -> int: # it while we're holding a Postgres transaction open. del preview_df + total_rows = getattr(meta, "number_rows", None) + def _filtered_chunks(): - seen = 0 - for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename): - chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude) - seen += len(chunk_df) - print(f" streaming... {seen:,} rows", file=sys.stderr) - yield chunk_df + pbar = tqdm( + total=total_rows, + unit="row", + unit_scale=True, + desc=f" {cfg.filename.name}", + file=sys.stderr, + dynamic_ncols=True, + ) + try: + for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename): + chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude) + pbar.update(len(chunk_df)) + yield chunk_df + finally: + pbar.close() db_user = db_password = None if args.dbcreds: diff --git a/generic_loader/sample_config.yaml b/generic_loader/sample_config.yaml index c487769..223cc1a 100644 --- a/generic_loader/sample_config.yaml +++ b/generic_loader/sample_config.yaml @@ -38,3 +38,24 @@ if_exists: append # indexes: # - state # - zip + +# column_types: Explicit {column_name: postgres_type} overrides that +# bypass automatic type inference for the listed columns. Useful when +# pyreadstat reports a column as NUM but you want it stored as TEXT +# (phone/ID columns that are conceptually strings), or when a column's +# inferred type is off for any other reason. Columns not listed here +# fall through to the normal inference path. Nullability is always +# computed from the data. +# +# column_types: +# RESP_PH_PREFIX_ID: TEXT +# SOMELONG_ID: BIGINT + +# all_nullable: If true, every column is stamped nullable in the generated +# schema; NOT NULL inference is skipped entirely. Use this when the sampler +# wrongly concludes a column has no nulls (e.g. a dense sample followed by +# rare-null data downstream) and COPY blows up mid-load on the first null +# it hits. Off by default. 
The CLI flag --all-nullable overrides this to +# true when set. +# +# all_nullable: false diff --git a/generic_loader/sample_folder_config.yaml b/generic_loader/sample_folder_config.yaml index 5740c3f..4cb394c 100644 --- a/generic_loader/sample_folder_config.yaml +++ b/generic_loader/sample_folder_config.yaml @@ -61,15 +61,52 @@ auto_detect: true # - state # - zip +# Folder-level column_types: Explicit {column_name: postgres_type} map that +# bypasses automatic type inference for the listed columns. Applied to +# every cluster; any column_types a cluster supplies itself are merged +# on top (cluster entries win on conflict). +# +# During --workers>1 runs the pre-scan derives a cluster-wide "auto-union" +# type per column (e.g. any file stores the column as CHAR -> TEXT; all +# NUM with any format hinting decimals -> DOUBLE PRECISION; otherwise +# BIGINT). Entries in column_types here win over that auto-union - use +# them when the auto result is wrong or when --no-prescan disables the +# auto-union and you still need to pin a column. +# +# Valid type strings are anything the CREATE TABLE DDL accepts (TEXT, +# INTEGER, BIGINT, DOUBLE PRECISION, DATE, TIMESTAMP, ...). Columns that +# don't exist in a given file are simply ignored for that file. +# +# column_types: +# RESP_PH_PREFIX_ID: TEXT +# RESP_PH_SUFFIX_ID: TEXT +# SOMELONG_ID: BIGINT + +# Folder-level all_nullable: If true, every column of every cluster is +# stamped nullable in the generated schema; NOT NULL inference is skipped +# entirely. Use this when the sampler wrongly concludes a column has no +# nulls (sampled rows happened to be dense, but later files in the cluster +# carry nulls) and COPY blows up mid-load. Inherited by all clusters +# unless a cluster supplies its own all_nullable. The CLI flag +# --all-nullable overrides both this and any per-cluster setting when +# passed. Off by default. +# +# all_nullable: false + # Explicit cluster patterns.
Each pattern is matched against the file # *basename*. Files matched by a pattern are pulled out of the auto-detect # pool, so explicit and auto clusters compose cleanly. # -# `tablename` is required. `if_exists`, `include`, and `exclude` are -# optional per-cluster overrides of the folder-level defaults above. +# `tablename` is required. `if_exists`, `include`, `exclude`, +# `column_types`, and `all_nullable` are optional per-cluster overrides +# of the folder-level defaults above. Cluster-level column_types entries +# win over folder-level entries for the same column. clusters: - pattern: '^group_a\d+\.xpt$' tablename: group_a + # column_types: + # INTCOL: TEXT + # all_nullable: true # per-cluster override of the folder-level default # Example of an explicit override. Uncomment to force the group_b cluster to # append instead of replace even though the folder default is "replace": diff --git a/requirements.txt b/requirements.txt index 3422064..c09f442 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,10 @@ pandas>=2.0,<3.0 pyreadstat>=1.2,<2.0 numpy>=2.1,<3.0 +pyarrow>=22.0,<24.0 pyyaml>=6.0,<7.0 psycopg2-binary>=2.9,<3.0 python-dotenv>=1.0,<2.0 boto3>=1.28,<2.0 +openpyxl>=3.1,<4.0 +tqdm>=4.66,<5.0 diff --git a/utils/sas_profiler.py b/utils/sas_profiler.py new file mode 100644 index 0000000..3adc4d6 --- /dev/null +++ b/utils/sas_profiler.py @@ -0,0 +1,1138 @@ +"""Standalone utility that profiles a single local SAS file and writes an +Excel report with drop, partition, and index candidates plus type-inference +warnings.
+ +Configure the constants below and run:: + + python3 utils/sas_profiler.py + +Or override any of them from the command line:: + + python3 utils/sas_profiler.py \ + --file ./data/mystate.sas7bdat \ + --out ./reports/mystate_profile.xlsx + +The report is a paste-ready companion to +``generic_loader/load_sas.py`` and ``generic_loader/load_folder.py``: the +"inferred Postgres type" column uses the loader's own ``infer_schema`` so the +drop / partition / index suggestions map one-to-one onto valid YAML config +entries for those scripts. + +Supported inputs: ``.sas7bdat`` / ``.xpt`` / ``.xport`` (whatever the loader +can read). + +Python 3.10+ compatible. +""" + +from __future__ import annotations + +import argparse +import collections +import datetime as dt +import math +import os +import re +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +# The loader lives in a sibling directory that is *not* a proper package +# (no __init__.py). Its own modules import each other by bare name, so we +# add the directory to sys.path before importing it here. 
+_REPO_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(_REPO_ROOT / "generic_loader")) + +import numpy as np # noqa: E402 +import pandas as pd # noqa: E402 +from openpyxl import Workbook # noqa: E402 +from openpyxl.styles import Alignment, Font, PatternFill # noqa: E402 +from openpyxl.utils import get_column_letter # noqa: E402 + +from load_sas import ( # noqa: E402 + NUMERIC_INT_RANGE, + ColumnSpec, + infer_schema, + iter_sas_chunks, + read_sas_preview, +) + + +# --------------------------------------------------------------------------- +# Configuration - edit these before running, or override via CLI flags +# --------------------------------------------------------------------------- + +SAS_PATH: str = "./generic_loader/samples/sample_kitchensink.xpt" +"""Local path to the .sas7bdat / .xpt / .xport file to profile.""" + +OUTPUT_XLSX: str = "./sas_profile.xlsx" +"""Where to write the Excel report.""" + +HIGH_NULL_PCT: float = 95.0 +"""Columns whose null percentage meets or exceeds this threshold are flagged +as drop candidates.""" + +INDEX_UNIQUENESS_PCT: float = 95.0 +"""Columns whose distinct/non-null ratio meets or exceeds this threshold are +flagged as index candidates.""" + +PARTITION_MIN_FILL_PCT: float = 95.0 +"""Name-matched partition candidates must be non-null in at least this +fraction of rows.""" + +PRE_SHARDED_MAX_DISTINCT: int = 3 +"""A name-matched column with <= this many distinct values is treated as +pre-sharded ("this file is one slice; sibling files have the other values") +rather than as a ready-to-partition observed column.""" + +DISTINCT_CAP: int = 10_000 +"""Max size of the per-column distinct-value set. Exceeding this marks the +column as ``distinct_overflow`` and we report ">= CAP" in the xlsx.""" + +TOP_N_VALUES: int = 5 +"""Number of most-frequent values tracked per column.""" + +PREVIEW_ROWS_FOR_INFERENCE: int = 10_000 +"""Rows pulled from the file for the loader's schema inference. 
Matches +``load_sas.TYPE_INFERENCE_SAMPLE_ROWS`` so suggestions track the loader.""" + +PROFILE_CHUNK_ROWS: int = 5_000_000 +"""Rows per streaming chunk while profiling. Larger chunks amortize +pyreadstat / pandas overhead, and the profiler is typically run on a +beefy box (e.g. a 128 GB EC2) rather than a laptop, so the default is +set aggressively. + +Rough peak-memory estimate while a chunk is in flight: + + peak_bytes ~= chunksize * num_cols * ~50 bytes/cell * 2-3x + +(The 2-3x factor covers pyreadstat's read buffer + pandas frame +construction temporaries.) At 5M rows x 50 cols that's roughly 10-20 GB, +which is comfortable on a 128 GB host but would OOM a laptop. + +If you have lots of RAM and a very wide file, lower this; if you have a +narrow file and want max throughput, bump it higher with ``--chunksize`` +(the profiler will happily take 20M+ per chunk). If ``chunksize`` is +larger than the file, pyreadstat just hands back one chunk.""" + + +PARTITION_NAME_PATTERNS: Tuple[re.Pattern, ...] = ( + # ``state`` or ``state_code`` / ``statecode`` appearing as a full token + # anywhere in the column name. Uses underscore / start / end as token + # boundaries so we catch STATE, STATE_CODE, HOME_STATE, + # ADDR_LINE3_STATE, BIRTH_STATE_CODE, etc. without matching STATUS, + # ESTATE, INTERSTATE, or STATEWIDE. + re.compile(r"(?:^|_)state(?:_?code)?(?:_|$)", re.IGNORECASE), +) +"""Only columns whose name matches one of these patterns are ever considered +partition candidates. This deliberately ignores generic low-cardinality +signals (status flags, boolean columns, etc.) because in practice the only +useful partition key in this codebase is STATE. Add more patterns here if +that ever stops being true.""" + + +INDEX_NAME_PATTERNS: Tuple[re.Pattern, ...] 
= ( + re.compile(r"^id$", re.IGNORECASE), + re.compile(r"_id$", re.IGNORECASE), + re.compile(r"_key$", re.IGNORECASE), + re.compile(r"^pk_", re.IGNORECASE), +) +"""Name-bonus patterns for index-candidate ranking.""" + + +# --------------------------------------------------------------------------- +# Per-column streaming aggregator +# --------------------------------------------------------------------------- + + +@dataclass +class _ColumnStats: + """Accumulators updated chunk-by-chunk while streaming the file.""" + + name: str + n_total: int = 0 + n_null: int = 0 + n_empty_str: int = 0 + + distinct: set = field(default_factory=set) + distinct_overflow: bool = False + + top_counts: "collections.Counter[Any]" = field(default_factory=collections.Counter) + + min_val: Any = None + max_val: Any = None + + # Numeric running stats (Welford would be nicer but sum/sum-sq is plenty + # here for a "help me pick columns" report). + numeric_sum: float = 0.0 + numeric_sumsq: float = 0.0 + numeric_count: int = 0 + + # String byte-length stats (helps flag oversized TEXT columns). + str_max_bytes: int = 0 + str_sum_bytes: int = 0 + str_count: int = 0 + + samples: List[Any] = field(default_factory=list) + + def update(self, series: pd.Series) -> None: + """Fold one chunk's worth of this column into the accumulator. + + Implementation notes (this method is the dominant per-file cost): + + - All masks are vectorized - no ``Series.map(lambda ...)`` loops. + - Distinct tracking uses ``Series.value_counts`` so we iterate at + most once per *unique* value in the chunk (in C), not once per + row. + - Once ``distinct_overflow`` is set and ``top_counts`` is full, + subsequent chunks skip the value-counts pass entirely - we already + know the column is too varied to be a partition / drop candidate + and we already have the top-N. 
+ """ + n = len(series) + if n == 0: + return + self.n_total += n + + is_object = pd.api.types.is_object_dtype(series) + is_numeric = pd.api.types.is_numeric_dtype(series) + is_datetime = pd.api.types.is_datetime64_any_dtype(series) + + if is_object: + # Vectorized equivalent of load_sas._char_missing_mask: treat + # None / NaN / empty string as missing. ``series == ""`` is + # False for non-string values so we don't need per-element type + # checks. + na_mask = series.isna() + empty_mask = (series == "") & ~na_mask + miss_mask = na_mask | empty_mask + self.n_empty_str += int(empty_mask.sum()) + else: + miss_mask = series.isna() + + self.n_null += int(miss_mask.sum()) + non_null = series[~miss_mask] if miss_mask.any() else series + if non_null.empty: + return + + # -- Numeric stats (C-level) --------------------------------------- + if is_numeric: + arr = non_null.to_numpy(dtype="float64", copy=False, na_value=np.nan) + # NaN-safe aggregates in one pass each (all C-level). + self.numeric_sum += float(np.nansum(arr)) + self.numeric_sumsq += float(np.nansum(arr * arr)) + self.numeric_count += int(arr.size) + cmin = float(np.nanmin(arr)) if arr.size else None + cmax = float(np.nanmax(arr)) if arr.size else None + if cmin is not None and (self.min_val is None or cmin < self.min_val): + self.min_val = cmin + if cmax is not None and (self.max_val is None or cmax > self.max_val): + self.max_val = cmax + + elif is_datetime: + cmin = non_null.min() + cmax = non_null.max() + if self.min_val is None or cmin < self.min_val: + self.min_val = cmin + if self.max_val is None or cmax > self.max_val: + self.max_val = cmax + + # -- String length stats via vectorized str.len -------------------- + # ``.str.len()`` is C-fast; for ASCII-dominated SAS data it matches + # UTF-8 byte length closely enough for the "oversized TEXT" flag. 
+ if is_object: + lens = non_null.astype(str, copy=False).str.len() + lens = lens.dropna() + if not lens.empty: + bmax = int(lens.max()) + if bmax > self.str_max_bytes: + self.str_max_bytes = bmax + self.str_sum_bytes += int(lens.sum()) + self.str_count += int(lens.size) + + # -- Samples (tiny slice; free) ------------------------------------ + if len(self.samples) < 3: + needed = 3 - len(self.samples) + self.samples.extend(non_null.head(needed).tolist()) + + # -- Distinct / top_counts (vectorized via value_counts) ----------- + # Skip altogether once we're saturated: distinct is already known + # to be > DISTINCT_CAP and top_counts has its DISTINCT_CAP slots + # filled, so further value_counts calls can only bump existing + # keys - info we don't need for any of the classifiers. + top_full = len(self.top_counts) >= DISTINCT_CAP + if self.distinct_overflow and top_full: + return + + try: + vc = non_null.value_counts(sort=False) + except TypeError: + # Unhashable values (list/dict). Drop the column from both + # distinct and top-N tracking. + self.distinct_overflow = True + return + + if vc.empty: + return + + if not self.distinct_overflow: + # Only *new* values need to be considered for the distinct set. + for val in vc.index: + if val in self.distinct: + continue + if len(self.distinct) >= DISTINCT_CAP: + self.distinct_overflow = True + break + self.distinct.add(val) + + if not top_full: + # Bulk-merge known keys; cap adds for new keys. + for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()): + if val in self.top_counts: + self.top_counts[val] += int(count) + elif len(self.top_counts) < DISTINCT_CAP: + self.top_counts[val] = int(count) + # else: silently skip - we're past the cap. + else: + # Only existing keys can grow. 
+ tc = self.top_counts + for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()): + if val in tc: + tc[val] += int(count) + + # -- Derived properties ------------------------------------------------ + + @property + def n_non_null(self) -> int: + return self.n_total - self.n_null + + @property + def null_pct(self) -> float: + if self.n_total == 0: + return 0.0 + return 100.0 * self.n_null / self.n_total + + @property + def fill_pct(self) -> float: + return 100.0 - self.null_pct + + @property + def distinct_count(self) -> int: + return len(self.distinct) + + @property + def distinct_display(self) -> str: + if self.distinct_overflow: + return f">= {DISTINCT_CAP:,}" + return f"{self.distinct_count:,}" + + @property + def mean(self) -> Optional[float]: + if self.numeric_count == 0: + return None + return self.numeric_sum / self.numeric_count + + @property + def std(self) -> Optional[float]: + if self.numeric_count < 2: + return None + mean = self.mean + var = self.numeric_sumsq / self.numeric_count - (mean * mean) + # Guard against tiny negative from floating point noise. 
+ if var < 0: + var = 0.0 + return math.sqrt(var) + + @property + def top_value(self) -> Tuple[Any, int]: + if not self.top_counts: + return (None, 0) + return self.top_counts.most_common(1)[0] + + def top_values(self, n: int = TOP_N_VALUES) -> List[Tuple[Any, int]]: + return self.top_counts.most_common(n) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _matches_any(patterns: Tuple[re.Pattern, ...], name: str) -> bool: + return any(p.search(name) for p in patterns) + + +def _format_size(n_bytes: int) -> str: + size = float(n_bytes) + for unit in ("B", "KB", "MB", "GB", "TB"): + if size < 1024.0 or unit == "TB": + return f"{size:,.1f} {unit}" + size /= 1024.0 + return f"{size:,.1f} TB" + + +def _format_value(val: Any) -> str: + """Render a single Python value for display in the spreadsheet.""" + if val is None: + return "" + if isinstance(val, float) and pd.isna(val): + return "" + if isinstance(val, (pd.Timestamp, dt.date, dt.datetime)): + return str(val) + return repr(val) if isinstance(val, str) else str(val) + + +def _format_top_values(pairs: List[Tuple[Any, int]]) -> str: + if not pairs: + return "" + return ", ".join(f"{_format_value(v)} ({c:,})" for v, c in pairs) + + +def _format_samples(samples: List[Any]) -> str: + if not samples: + return "(all null)" + return ", ".join(_format_value(v) for v in samples) + + +# --------------------------------------------------------------------------- +# Streaming profile +# --------------------------------------------------------------------------- + + +def profile_file( + path: Path, + *, + chunksize: Optional[int] = None, +) -> Tuple[Dict[str, _ColumnStats], Dict[str, ColumnSpec], Any, int]: + """Stream ``path`` once, returning (stats, columns, meta, total_rows). 
+ + ``columns`` is the loader's inferred schema from the first + ``PREVIEW_ROWS_FOR_INFERENCE`` rows - identical to what ``load_sas`` + would use. ``stats`` are the full-file observations we add on top. + """ + preview_df, meta = read_sas_preview(path, rows=PREVIEW_ROWS_FOR_INFERENCE) + total_rows_hint = getattr(meta, "number_rows", None) + columns = infer_schema(preview_df, meta, total_rows=total_rows_hint) + + stats: Dict[str, _ColumnStats] = { + name: _ColumnStats(name=name) for name in columns + } + + total_rows = 0 + effective_chunksize = chunksize if chunksize is not None else PROFILE_CHUNK_ROWS + kwargs = {"chunksize": effective_chunksize} + # pyreadstat + pandas are both C-level; the per-chunk overhead we pay + # is dominated by the value_counts passes in _ColumnStats.update, so + # the profile runs O(total_rows) with a small constant. + import time + started_at = time.monotonic() + last_print_at = started_at + for chunk_df, _chunk_meta in iter_sas_chunks(path, **kwargs): + total_rows += len(chunk_df) + for name, cs in stats.items(): + if name not in chunk_df.columns: + continue + cs.update(chunk_df[name]) + now = time.monotonic() + # Throttle progress output to ~one line per 2 seconds so huge files + # don't spam stderr but small files still print at least once. + if now - last_print_at >= 2.0: + elapsed = now - started_at + rate = total_rows / elapsed if elapsed > 0 else 0.0 + print( + f" profiling... 
# ---------------------------------------------------------------------------
# Classifiers
# ---------------------------------------------------------------------------


@dataclass
class _DropCandidate:
    """A column the report suggests listing under ``exclude``."""

    name: str
    reason: str  # short cause, e.g. "all-null" or "null_pct=99.2%"


@dataclass
class _PartitionCandidate:
    """A name-matched column that could be used as a ``partition_by`` key."""

    name: str
    kind: str  # "observed" or "pre_sharded"
    distinct_count: int
    fill_pct: float
    top_values: str
    observed_values_in_file: str
    note: str
    score: float  # higher sorts first


@dataclass
class _IndexCandidate:
    """A column unique enough to be suggested as an index."""

    name: str
    uniqueness_pct: float
    distinct_count: int
    fill_pct: float
    name_bonus: bool  # True when the name matched INDEX_NAME_PATTERNS
    note: str
    score: float  # higher sorts first


@dataclass
class _TypeWarning:
    """A mismatch between the preview-inferred schema and the full file."""

    name: str
    severity: str  # "info" | "warn" | "error"
    message: str


def _is_constant_like(cs: _ColumnStats) -> bool:
    """True when the column is effectively a single value (possibly with
    a handful of nulls / empties mixed in)."""
    if cs.n_non_null == 0:
        return False
    return cs.distinct_count == 1 and not cs.distinct_overflow


def _classify_partitions(
    stats: Dict[str, _ColumnStats],
    *,
    partition_min_fill_pct: float,
    pre_sharded_max_distinct: int,
) -> List[_PartitionCandidate]:
    """Return name-matched partition candidates, best score first."""
    found: List[_PartitionCandidate] = []
    for name, cs in stats.items():
        if not _matches_any(PARTITION_NAME_PATTERNS, name):
            continue
        if cs.n_total == 0 or cs.n_non_null == 0:
            continue
        if cs.fill_pct < partition_min_fill_pct:
            continue

        is_pre_sharded = (
            not cs.distinct_overflow
            and cs.distinct_count <= pre_sharded_max_distinct
        )
        observed = _format_top_values(cs.top_values(pre_sharded_max_distinct))

        if is_pre_sharded:
            note = (
                f"pre-sharded: this file only contains {cs.distinct_count} "
                f"distinct value(s) ({observed}); keep the column and set "
                "partition_by at the load_folder level so sibling files merge "
                "into separate partitions of one table"
            )
        else:
            note = (
                f"observed {cs.distinct_display} distinct value(s) across "
                f"{cs.fill_pct:.1f}% of rows; LIST partitioning will create "
                "one child table per distinct value"
            )

        found.append(
            _PartitionCandidate(
                name=name,
                kind="pre_sharded" if is_pre_sharded else "observed",
                distinct_count=cs.distinct_count,
                fill_pct=cs.fill_pct,
                top_values=_format_top_values(cs.top_values()),
                observed_values_in_file=observed,
                note=note,
                # Pre-sharded beats observed as the snippet's top pick.
                score=(1_000_000.0 if is_pre_sharded else 500_000.0)
                + cs.fill_pct,
            )
        )

    found.sort(key=lambda p: p.score, reverse=True)
    return found


def _classify_drops(
    stats: Dict[str, _ColumnStats],
    *,
    skip: set,
    high_null_pct: float,
) -> List[_DropCandidate]:
    """Return drop candidates in column order, ignoring names in ``skip``."""
    found: List[_DropCandidate] = []
    for name, cs in stats.items():
        if name in skip or cs.n_total == 0:
            continue

        reason: Optional[str] = None
        if cs.n_null == cs.n_total:
            reason = "all-null"
        elif (
            cs.n_non_null > 0
            and cs.distinct_count == 0
            and not cs.distinct_overflow
        ):
            # Non-null but nothing hashable captured - treat as opaque.
            reason = "all-empty / unhashable"
        elif cs.n_non_null == cs.n_empty_str and cs.n_empty_str > 0:
            reason = "all-empty"
        elif _is_constant_like(cs):
            only_val = next(iter(cs.distinct))
            reason = f"constant={_format_value(only_val)}"
        elif cs.null_pct >= high_null_pct:
            reason = f"null_pct={cs.null_pct:.1f}%"

        if reason is not None:
            found.append(_DropCandidate(name=name, reason=reason))
    return found


def _classify_indexes(
    stats: Dict[str, _ColumnStats],
    columns: Dict[str, ColumnSpec],
    *,
    skip: set,
    index_uniqueness_pct: float,
) -> List[_IndexCandidate]:
    """Return index candidates, best score first, ignoring names in ``skip``."""
    found: List[_IndexCandidate] = []
    for name, cs in stats.items():
        if name in skip:
            continue
        if columns.get(name) is None:
            continue
        if cs.n_non_null == 0:
            continue

        if cs.distinct_overflow:
            # Super-high-cardinality -> perfect candidate for an index.
            uniqueness = 100.0
            distinct_count = DISTINCT_CAP  # display sentinel
        else:
            uniqueness = 100.0 * cs.distinct_count / cs.n_non_null
            distinct_count = cs.distinct_count
        if uniqueness < index_uniqueness_pct:
            continue

        name_bonus = _matches_any(INDEX_NAME_PATTERNS, name)
        notes: List[str] = []
        if name_bonus:
            notes.append("name matches INDEX_NAME_PATTERNS (ID/KEY-ish)")
        if cs.distinct_overflow:
            notes.append(
                f"distinct tracking capped at {DISTINCT_CAP:,}; "
                "treating as high-cardinality"
            )

        # Rank: name match dominates, then raw uniqueness, then fill.
        score = (
            (500_000.0 if name_bonus else 0.0)
            + uniqueness
            + cs.fill_pct / 100.0
        )

        found.append(
            _IndexCandidate(
                name=name,
                uniqueness_pct=uniqueness,
                distinct_count=distinct_count,
                fill_pct=cs.fill_pct,
                name_bonus=name_bonus,
                note="; ".join(notes),
                score=score,
            )
        )

    found.sort(key=lambda i: i.score, reverse=True)
    return found


def _classify_type_warnings(
    stats: Dict[str, _ColumnStats],
    columns: Dict[str, ColumnSpec],
) -> List[_TypeWarning]:
    """Compare the preview-inferred schema with full-file observations."""
    found: List[_TypeWarning] = []
    for name, cs in stats.items():
        spec = columns.get(name)
        if spec is None:
            continue

        # Re-surface whatever the loader's own inference already flagged in
        # notes - these are genuinely useful for the user to see without
        # having to dry-run the loader.
        for note in spec.notes:
            found.append(_TypeWarning(name=name, severity="info", message=note))
        if spec.sampled:
            found.append(
                _TypeWarning(
                    name=name,
                    severity="info",
                    message=(
                        "loader inferred type from a bounded preview; "
                        "sampled=True"
                    ),
                )
            )

        pg_type = spec.postgres_type.upper()

        # Preview said NOT NULL but the full file has nulls - loader would
        # have emitted NOT NULL and then choked on COPY.
        if not spec.nullable and cs.n_null > 0:
            found.append(
                _TypeWarning(
                    name=name,
                    severity="error",
                    message=(
                        f"preview saw zero nulls (NOT NULL) but full file has "
                        f"{cs.n_null:,} null(s); COPY would fail under the "
                        "loader's inferred NOT NULL"
                    ),
                )
            )

        # INTEGER range check against the full-file observed min/max.
        if pg_type == "INTEGER" and cs.numeric_count > 0:
            lo, hi = NUMERIC_INT_RANGE
            vmin = cs.min_val if cs.min_val is not None else 0
            vmax = cs.max_val if cs.max_val is not None else 0
            try:
                overflows = vmin < lo or vmax > hi
            except TypeError:
                # min/max are not comparable with ints; say so instead of
                # silently pretending the range was verified.
                found.append(
                    _TypeWarning(
                        name=name,
                        severity="info",
                        message=(
                            f"could not compare observed min/max ({vmin!r}, "
                            f"{vmax!r}) against the int4 range; INTEGER "
                            "range check skipped"
                        ),
                    )
                )
            else:
                if overflows:
                    found.append(
                        _TypeWarning(
                            name=name,
                            severity="error",
                            message=(
                                f"loader inferred INTEGER from the preview but "
                                f"full-file range [{vmin}, {vmax}] overflows "
                                f"int4 {NUMERIC_INT_RANGE}; BIGINT required"
                            ),
                        )
                    )

        # Preview said all-null (loader defaults to TEXT) but data exists.
        was_all_null_preview = any("all-null column" in n for n in spec.notes)
        if was_all_null_preview and cs.n_non_null > 0:
            found.append(
                _TypeWarning(
                    name=name,
                    severity="warn",
                    message=(
                        "preview was all-null so loader defaulted to TEXT, "
                        f"but full file has {cs.n_non_null:,} non-null "
                        "value(s); consider a tighter include/exclude or "
                        "re-inferring with TYPE_INFERENCE_SAMPLE_ROWS=None"
                    ),
                )
            )
    return found


def classify(
    stats: Dict[str, _ColumnStats],
    columns: Dict[str, ColumnSpec],
    *,
    high_null_pct: float,
    index_uniqueness_pct: float,
    partition_min_fill_pct: float,
    pre_sharded_max_distinct: int,
) -> Tuple[
    List[_DropCandidate],
    List[_PartitionCandidate],
    List[_IndexCandidate],
    List[_TypeWarning],
]:
    """Turn per-column stats + the loader's schema into four ranked lists.

    Partition candidates are restricted to columns whose name matches
    :data:`PARTITION_NAME_PATTERNS` - in practice STATE / STATE_CODE. A
    generic "low-cardinality = partition candidate" heuristic produces too
    much noise for this codebase, so we only surface columns we're confident
    about by name.

    Returns:
        ``(drops, partitions, indexes, warnings)``. Partitions and indexes
        are sorted by descending score; drops and warnings follow the
        iteration order of ``stats``.
    """
    # Partitions go first so pre-sharded STATE columns don't get silently
    # dropped for being "constant".
    partitions = _classify_partitions(
        stats,
        partition_min_fill_pct=partition_min_fill_pct,
        pre_sharded_max_distinct=pre_sharded_max_distinct,
    )
    partition_names = {p.name for p in partitions}

    drops = _classify_drops(
        stats, skip=partition_names, high_null_pct=high_null_pct
    )

    # A column already suggested for dropping or partitioning is not also
    # suggested as an index.
    indexes = _classify_indexes(
        stats,
        columns,
        skip=partition_names | {d.name for d in drops},
        index_uniqueness_pct=index_uniqueness_pct,
    )

    warnings = _classify_type_warnings(stats, columns)

    return drops, partitions, indexes, warnings
# ---------------------------------------------------------------------------
# YAML snippet
# ---------------------------------------------------------------------------


# Names that are plain identifiers can be written bare in YAML ...
_YAML_PLAIN_NAME = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
# ... unless a YAML 1.1 loader would resolve them to a bool or null.
_YAML_RESERVED_WORDS = frozenset(
    {"y", "n", "yes", "no", "true", "false", "on", "off", "null"}
)


def _yaml_scalar(name: str) -> str:
    """Return ``name`` as a YAML scalar that loads back as the same string.

    Plain identifiers pass through unchanged. Anything else - including
    column names such as ``NO`` or ``NULL`` that YAML would otherwise read
    as a boolean / null - is single-quoted, with embedded quotes doubled.
    """
    text = str(name)
    if _YAML_PLAIN_NAME.fullmatch(text) and text.lower() not in _YAML_RESERVED_WORDS:
        return text
    return "'" + text.replace("'", "''") + "'"


def render_yaml_snippet(
    drops: List[_DropCandidate],
    partitions: List[_PartitionCandidate],
    indexes: List[_IndexCandidate],
) -> str:
    """Produce a paste-ready YAML snippet for the loader config.

    The snippet has three sections separated by blank lines: ``exclude``
    (every drop candidate with its reason as a trailing comment),
    ``partition_by`` (only the first, i.e. top-ranked, partition candidate)
    and ``indexes``. An empty section is replaced by an explanatory comment.
    Column names are quoted when a bare scalar would not round-trip.
    """
    lines: List[str] = [
        "# Suggested additions to your load_sas.py / load_folder.py config"
    ]

    if drops:
        lines.append("exclude:")
        lines.extend(
            f" - {_yaml_scalar(d.name)} # {d.reason}" for d in drops
        )
    else:
        lines.append("# (no drop candidates found)")

    lines.append("")

    if partitions:
        top = partitions[0]
        if top.kind == "pre_sharded":
            lines.append(
                f"# !! PRE-SHARDED: this file only contains "
                f"{top.name} = {top.observed_values_in_file}."
            )
            lines.append(
                "# !! Keep the column in the schema and set partition_by at the "
                "load_folder level"
            )
            lines.append(
                "# !! so sibling files merge into one table under separate "
                "partitions."
            )
        lines.append("partition_by:")
        lines.append(f" - {_yaml_scalar(top.name)}")
    else:
        lines.append(
            "# (no partition candidates found - no column matched "
            "PARTITION_NAME_PATTERNS)"
        )

    lines.append("")

    if indexes:
        lines.append("indexes:")
        for i in indexes:
            bonus = " (name match)" if i.name_bonus else ""
            lines.append(
                f" - {_yaml_scalar(i.name)} "
                f"# uniqueness={i.uniqueness_pct:.1f}%{bonus}"
            )
    else:
        lines.append("# (no index candidates found)")

    return "\n".join(lines)
# ---------------------------------------------------------------------------
# XLSX writer
# ---------------------------------------------------------------------------


# Shared cell styles: white-on-blue bold header, plus the two row highlights
# used for "warn" and "error" rows.
_HEADER_FONT = Font(bold=True, color="FFFFFF")
_HEADER_FILL = PatternFill("solid", fgColor="305496")
_WARN_FILL = PatternFill("solid", fgColor="FFE699")
_ERROR_FILL = PatternFill("solid", fgColor="F4B183")


def _write_header(ws, headers: List[str]) -> None:
    """Write ``headers`` into row 1 with the header style and freeze that row."""
    column = 0
    for label in headers:
        column += 1
        head = ws.cell(row=1, column=column, value=label)
        head.font = _HEADER_FONT
        head.fill = _HEADER_FILL
        head.alignment = Alignment(vertical="center")
    ws.freeze_panes = "A2"


def _autosize(ws, *, max_width: int = 60) -> None:
    """Set each column's width from its longest cell, within [10, max_width].

    Only the first line of a multi-line cell is measured, so a long YAML
    cell doesn't push everything else ultra-wide.
    """
    for col_cells in ws.columns:
        first_line_lengths = [
            min(len(str(cell.value).split("\n", 1)[0]), max_width)
            for cell in col_cells
            if cell.value is not None
        ]
        widest = max(first_line_lengths, default=0)
        letter = get_column_letter(col_cells[0].column)
        # Two characters of padding, never narrower than 10.
        ws.column_dimensions[letter].width = min(max(widest + 2, 10), max_width)
def _write_overview(
    ws,
    *,
    path: Path,
    size_bytes: int,
    total_rows: int,
    total_cols: int,
    thresholds: Dict[str, Any],
) -> None:
    """Fill the overview sheet with file facts and the thresholds in effect.

    Writes a two-column Field/Value table: file path, size, extension, row
    and column totals, generation time, then one ``threshold: <name>`` row
    per entry of ``thresholds``.
    """
    # Reuse the shared header writer so this sheet's header row is styled
    # like every other sheet's.
    _write_header(ws, ["Field", "Value"])

    rows = [
        ("File path", str(path)),
        ("File size", _format_size(size_bytes)),
        ("Extension", path.suffix.lower()),
        ("Total rows", f"{total_rows:,}"),
        ("Total columns", f"{total_cols:,}"),
        ("Generated at", dt.datetime.now().isoformat(timespec="seconds")),
    ]
    rows.extend((f"threshold: {k}", str(v)) for k, v in thresholds.items())

    for i, (k, v) in enumerate(rows, start=2):
        ws.cell(row=i, column=1, value=k)
        ws.cell(row=i, column=2, value=v)

    _autosize(ws)


def _write_columns(
    ws,
    stats: Dict[str, _ColumnStats],
    columns: Dict[str, ColumnSpec],
) -> None:
    """Write one row per profiled column: inferred schema plus observed stats.

    Cells that depend on the inferred ``ColumnSpec`` are left blank when
    ``columns`` has no entry for the name.
    """
    headers = [
        "column", "sas_format", "source_dtype", "inferred_postgres_type",
        "nullable", "n_total", "n_null", "null_pct", "distinct_count",
        "min", "max", "mean", "std",
        "top_value", "top_count",
        "max_str_bytes", "mean_str_bytes",
        "sample_values", "notes",
    ]
    _write_header(ws, headers)

    for row_idx, (name, cs) in enumerate(stats.items(), start=2):
        spec = columns.get(name)
        top_val, top_count = cs.top_value
        mean_bytes = (cs.str_sum_bytes / cs.str_count) if cs.str_count else None
        if spec is None:
            # Unknown nullability is shown blank like the other spec cells,
            # not as a definite "NO".
            nullable = ""
        else:
            nullable = "YES" if spec.nullable else "NO"
        values = [
            name,
            spec.sas_format if spec else "",
            spec.source_dtype if spec else "",
            spec.postgres_type if spec else "",
            nullable,
            cs.n_total,
            cs.n_null,
            round(cs.null_pct, 3),
            cs.distinct_display,
            _format_value(cs.min_val),
            _format_value(cs.max_val),
            round(cs.mean, 6) if cs.mean is not None else "",
            round(cs.std, 6) if cs.std is not None else "",
            _format_value(top_val),
            top_count or "",
            cs.str_max_bytes or "",
            round(mean_bytes, 2) if mean_bytes is not None else "",
            _format_samples(cs.samples),
            "; ".join(spec.notes) if spec and spec.notes else "",
        ]
        for col_idx, v in enumerate(values, start=1):
            ws.cell(row=row_idx, column=col_idx, value=v)

    _autosize(ws)
def _write_drop(ws, drops: List[_DropCandidate]) -> None:
    """Write the drop-candidate sheet: one (column, reason) row each."""
    _write_header(ws, ["column", "reason"])
    if not drops:
        ws.cell(row=2, column=1, value="(no drop candidates)")
    sheet_row = 1
    for cand in drops:
        sheet_row += 1
        ws.cell(row=sheet_row, column=1, value=cand.name)
        ws.cell(row=sheet_row, column=2, value=cand.reason)
    _autosize(ws)


def _write_partition(ws, partitions: List[_PartitionCandidate]) -> None:
    """Write the ranked partition-candidate sheet.

    Rows keep the order of ``partitions``; rows whose kind is
    ``pre_sharded`` are highlighted with the warning fill.
    """
    headers = [
        "rank", "column", "kind", "distinct_count", "fill_pct",
        "observed_values_in_file", "top_values", "score", "note",
    ]
    _write_header(ws, headers)
    if not partitions:
        ws.cell(row=2, column=1, value="(no partition candidates)")
    n_cols = len(headers)
    for rank, cand in enumerate(partitions, start=1):
        sheet_row = rank + 1  # row 1 holds the header
        cells = (
            rank,
            cand.name,
            cand.kind,
            cand.distinct_count,
            round(cand.fill_pct, 3),
            cand.observed_values_in_file,
            cand.top_values,
            round(cand.score, 3),
            cand.note,
        )
        for col, value in enumerate(cells, start=1):
            ws.cell(row=sheet_row, column=col, value=value)
        if cand.kind == "pre_sharded":
            for col in range(1, n_cols + 1):
                ws.cell(row=sheet_row, column=col).fill = _WARN_FILL
    _autosize(ws)


def _write_index(ws, indexes: List[_IndexCandidate]) -> None:
    """Write the ranked index-candidate sheet, one row per candidate."""
    headers = [
        "rank", "column", "uniqueness_pct", "distinct_count", "fill_pct",
        "name_bonus", "score", "note",
    ]
    _write_header(ws, headers)
    if not indexes:
        ws.cell(row=2, column=1, value="(no index candidates)")
    for rank, cand in enumerate(indexes, start=1):
        cells = (
            rank,
            cand.name,
            round(cand.uniqueness_pct, 3),
            cand.distinct_count,
            round(cand.fill_pct, 3),
            "YES" if cand.name_bonus else "NO",
            round(cand.score, 3),
            cand.note,
        )
        for col, value in enumerate(cells, start=1):
            ws.cell(row=rank + 1, column=col, value=value)
    _autosize(ws)
def _write_warnings(ws, warnings: List[_TypeWarning]) -> None:
    """Write the type-warning sheet, colouring "error" and "warn" rows."""
    headers = ["column", "severity", "message"]
    _write_header(ws, headers)
    if not warnings:
        ws.cell(row=2, column=1, value="(no type warnings)")
    # Severities without an entry here (e.g. "info") stay unfilled.
    fill_by_severity = {"error": _ERROR_FILL, "warn": _WARN_FILL}
    n_cols = len(headers)
    for sheet_row, item in enumerate(warnings, start=2):
        for col, value in enumerate(
            (item.name, item.severity, item.message), start=1
        ):
            ws.cell(row=sheet_row, column=col, value=value)
        row_fill = fill_by_severity.get(item.severity)
        if row_fill is not None:
            for col in range(1, n_cols + 1):
                ws.cell(row=sheet_row, column=col).fill = row_fill
    _autosize(ws)


def _write_yaml_sheet(ws, snippet: str) -> None:
    """Put the whole YAML ``snippet`` into cell A2 under a styled title."""
    title = ws.cell(
        row=1,
        column=1,
        value="YAML suggestion (paste into your loader config)",
    )
    title.font = _HEADER_FONT
    ws.cell(row=1, column=1).fill = _HEADER_FILL
    body = ws.cell(row=2, column=1, value=snippet)
    body.alignment = Alignment(wrap_text=True, vertical="top")
    # Fixed, roomy width for YAML; no row height is set explicitly here.
    ws.column_dimensions["A"].width = 100
def write_report(
    out_path: Path,
    *,
    path: Path,
    size_bytes: int,
    total_rows: int,
    stats: Dict[str, _ColumnStats],
    columns: Dict[str, ColumnSpec],
    drops: List[_DropCandidate],
    partitions: List[_PartitionCandidate],
    indexes: List[_IndexCandidate],
    warnings: List[_TypeWarning],
    yaml_snippet: str,
    thresholds: Dict[str, Any],
) -> None:
    """Assemble the workbook and save it to ``out_path``.

    Sheets, in order: Overview, Columns, Drop candidates, Partition
    candidates, Index candidates, Type warnings, YAML suggestion.
    """
    wb = Workbook()
    # A new workbook already has one active sheet; it becomes the overview.
    ws = wb.active
    ws.title = "Overview"
    _write_overview(
        ws,
        path=path,
        size_bytes=size_bytes,
        total_rows=total_rows,
        total_cols=len(columns),
        thresholds=thresholds,
    )
    _write_columns(wb.create_sheet("Columns"), stats, columns)
    _write_drop(wb.create_sheet("Drop candidates"), drops)
    _write_partition(wb.create_sheet("Partition candidates"), partitions)
    _write_index(wb.create_sheet("Index candidates"), indexes)
    _write_warnings(wb.create_sheet("Type warnings"), warnings)
    _write_yaml_sheet(wb.create_sheet("YAML suggestion"), yaml_snippet)
    wb.save(out_path)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def _positive_int(text: str) -> int:
    """argparse ``type=`` callable accepting only integers greater than zero."""
    value = int(text)  # a ValueError here is reported by argparse itself
    if value <= 0:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {text!r}"
        )
    return value


def _build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser; defaults come from module constants."""
    p = argparse.ArgumentParser(
        description=(
            "Profile a local SAS file (.sas7bdat / .xpt / .xport) and write "
            "an Excel report with drop, partition_by, and index suggestions "
            "for generic_loader/load_sas.py and load_folder.py."
        ),
    )
    p.add_argument("--file", type=Path, default=Path(SAS_PATH),
                   help=f"Path to the SAS file to profile (default: {SAS_PATH!r}).")
    p.add_argument("--out", type=Path, default=Path(OUTPUT_XLSX),
                   help=f"Where to write the .xlsx report (default: {OUTPUT_XLSX!r}).")
    p.add_argument("--high-null-pct", type=float, default=HIGH_NULL_PCT,
                   help="Null percentage at/above which a column is a drop candidate.")
    p.add_argument("--index-uniqueness-pct", type=float, default=INDEX_UNIQUENESS_PCT,
                   help="Uniqueness (distinct/non-null) at/above which a column is an index candidate.")
    p.add_argument("--partition-min-fill-pct", type=float, default=PARTITION_MIN_FILL_PCT,
                   help="Fill percentage below which a name-matched column is not a partition candidate.")
    p.add_argument("--pre-sharded-max-distinct", type=int, default=PRE_SHARDED_MAX_DISTINCT,
                   help="Largest distinct-value count at which a partition column still counts as pre-sharded.")
    p.add_argument(
        "--chunksize", type=_positive_int, default=None,
        help=(
            "Rows per streaming read. Bigger chunks amortize pyreadstat / "
            "pandas overhead (faster for huge files) but use more peak "
            f"memory. Defaults to PROFILE_CHUNK_ROWS ({PROFILE_CHUNK_ROWS:,})."
        ),
    )
    return p
def main(argv: Optional[List[str]] = None) -> int:
    """Profile one SAS file and write the XLSX report.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        ``0`` on success, ``2`` when the input path is missing or is not a
        regular file.
    """
    args = _build_argparser().parse_args(argv)

    path: Path = args.file
    out_path: Path = args.out

    if not path.exists():
        print(f"error: SAS file not found: {path}", file=sys.stderr)
        return 2
    # A directory passes exists(); reject it here with a clear message
    # instead of failing later inside the reader.
    if not path.is_file():
        print(f"error: not a regular file: {path}", file=sys.stderr)
        return 2

    print(f"profiling {path} -> {out_path}", file=sys.stderr)
    stats, columns, meta, total_rows = profile_file(path, chunksize=args.chunksize)

    drops, partitions, indexes, warnings = classify(
        stats, columns,
        high_null_pct=args.high_null_pct,
        index_uniqueness_pct=args.index_uniqueness_pct,
        partition_min_fill_pct=args.partition_min_fill_pct,
        pre_sharded_max_distinct=args.pre_sharded_max_distinct,
    )

    yaml_snippet = render_yaml_snippet(drops, partitions, indexes)

    # Recorded on the overview sheet so a report states the settings that
    # produced it.
    thresholds = {
        "HIGH_NULL_PCT": args.high_null_pct,
        "INDEX_UNIQUENESS_PCT": args.index_uniqueness_pct,
        "PARTITION_MIN_FILL_PCT": args.partition_min_fill_pct,
        "PRE_SHARDED_MAX_DISTINCT": args.pre_sharded_max_distinct,
        "PARTITION_NAME_PATTERNS": ", ".join(p.pattern for p in PARTITION_NAME_PATTERNS),
        "DISTINCT_CAP": DISTINCT_CAP,
        "TOP_N_VALUES": TOP_N_VALUES,
        "PREVIEW_ROWS_FOR_INFERENCE": PREVIEW_ROWS_FOR_INFERENCE,
    }

    out_path.parent.mkdir(parents=True, exist_ok=True)
    write_report(
        out_path,
        path=path,
        size_bytes=os.path.getsize(path),
        total_rows=total_rows,
        stats=stats,
        columns=columns,
        drops=drops,
        partitions=partitions,
        indexes=indexes,
        warnings=warnings,
        yaml_snippet=yaml_snippet,
        thresholds=thresholds,
    )

    print(
        f"wrote {out_path} ({len(stats)} columns, {total_rows:,} rows scanned)\n"
        f" drops: {len(drops)}\n"
        f" partitions: {len(partitions)}\n"
        f" indexes: {len(indexes)}\n"
        f" warnings: {len(warnings)}",
        file=sys.stderr,
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())