diff --git a/generic_loader/load_folder.py b/generic_loader/load_folder.py index 82fe459..b789888 100644 --- a/generic_loader/load_folder.py +++ b/generic_loader/load_folder.py @@ -1,8 +1,9 @@ -"""Folder-level SAS-to-Postgres loader. +"""Folder-level data-to-Postgres loader. -Wraps :mod:`load_sas` so an entire directory of SAS files can be ingested in -one invocation. A directory often contains several *clusters* of files that -share a schema (e.g. ``group_a1.sas7bdat``, ``group_a2.sas7bdat``, ...). Each +Wraps :mod:`load_sas` so an entire directory of data files (SAS or delimited +text) can be ingested in one invocation. A directory often contains several +*clusters* of files that share a schema (e.g. ``group_a1.sas7bdat``, +``group_a2.sas7bdat``, ... or ``group_a1.csv``, ``group_a2.csv``, ...). Each cluster becomes one Postgres table; files inside a cluster are appended to it. ------------------------------------------------------------------------------- @@ -32,6 +33,17 @@ USAGE # include: [ID, INTCOL] # exclude: [ALLNULL] + # Optional folder default for explicit column type overrides. These + # win over the cluster-wide auto-union computed during pre-scan; set + # them when a column's SAS-level type varies across files (e.g. phone + # IDs stored as CHAR in some years and NUM in others) and you want to + # pin the Postgres type yourself rather than accept the auto-derived + # one. Per-cluster column_types inside each clusters[*] entry are + # merged on top of this map. + # column_types: + # RESP_PH_PREFIX_ID: TEXT + # SOME_BIGINT_COL: BIGINT + # Optional folder default for LIST partitioning. Omit or set [] for no # partitioning. Accepts a single string or a list of column names. # partition_by: @@ -43,14 +55,16 @@ USAGE # Optional explicit cluster patterns. Each pattern is matched against the # file *basename*. Matched files are pulled out of the auto-detect pool. - # Per-cluster if_exists/include/exclude/partition_by/max_partitions - # override the folder-level defaults. + # Per-cluster if_exists/include/exclude/partition_by/max_partitions/ + # column_types override the folder-level defaults. clusters: - pattern: '^group_a\\d+\\.sas7bdat$' tablename: group_a - pattern: '^group_b\\d+\\.sas7bdat$' tablename: group_b if_exists: replace + column_types: + PHONE_PREFIX: TEXT 2. Command-line interface ------------------------- @@ -77,12 +91,14 @@ Flags: Exit codes: 0 - every cluster loaded successfully (or dry-run completed) 1 - at least one cluster failed (details on stderr) - 2 - folder does not exist / contains no SAS files + 2 - folder does not exist / contains no data files 3. Discovery rules ------------------ -* Supported extensions: ``.sas7bdat``, ``.xpt``, ``.xport`` (matches - :mod:`load_sas`). The folder is not scanned recursively. +* Supported SAS extensions: ``.sas7bdat``, ``.xpt``, ``.xport``. + Supported text extensions: ``.txt``, ``.csv``, ``.tsv``. + The ``file_type`` config key controls which set is used. + The folder is not scanned recursively. * Explicit patterns are tried in order. A file matched by one pattern is removed from the pool before the next pattern runs, so earlier patterns win in case of overlap. Overlap between patterns is flagged as an error @@ -132,18 +148,32 @@ from __future__ import annotations import argparse import getpass +import multiprocessing as mp +import os +import queue as _queue_mod import re import sys +import threading +from concurrent.futures import ( + CancelledError, + ProcessPoolExecutor, + ThreadPoolExecutor, + as_completed, +) from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import yaml from dotenv import load_dotenv +from tqdm import tqdm from load_sas import ( + TEXT_EXTENSIONS, + VALID_FILE_TYPES, VALID_IF_EXISTS, _count_partitions, + _is_text_file, _merge_partition_trees, apply_column_filter, assert_schema_compatible, @@ -152,16 +182,20 @@ from load_sas import ( create_indexes, create_table, discover_partition_values_chunked, + extract_union_metadata, infer_schema, iter_sas_chunks, + read_sas_metadata, read_sas_preview, render_create_indexes, render_create_table, render_partition_ddl, + union_column_types, ) SAS_EXTENSIONS = (".sas7bdat", ".xpt", ".xport") +SUPPORTED_EXTENSIONS = SAS_EXTENSIONS + TEXT_EXTENSIONS # --------------------------------------------------------------------------- @@ -175,7 +209,16 @@ class ClusterSpec: ``partition_by``, ``max_partitions``, and ``indexes`` are resolved from the folder defaults and any per-cluster overrides during - :func:`discover_clusters`. + :func:`discover_clusters`. ``column_types`` holds the effective type + overrides for this cluster: user-supplied YAML entries merged on top + of the auto-union result computed during pre-scan (see :func:`main`). + The same dict is threaded through to workers so every file in the + cluster infers the same schema. + + Text-file config (``file_type``, ``delimiter``, ``text_encoding``, + ``quotechar``) is propagated from the folder config during + :func:`discover_clusters` so that reader functions receive the correct + parameters for delimited text files. """ tablename: str @@ -188,6 +231,12 @@ class ClusterSpec: partition_by: List[str] = field(default_factory=list) max_partitions: int = 10_000 indexes: List[str] = field(default_factory=list) + column_types: Dict[str, str] = field(default_factory=dict) + all_nullable: bool = False + file_type: str = "sas" + delimiter: str = "," + text_encoding: str = "utf-8" + quotechar: str = '"' @dataclass @@ -198,6 +247,8 @@ class _ExplicitPattern: An explicit empty list ``[]`` means "disable partitioning for this cluster". ``max_partitions`` defaults to ``None`` meaning "inherit from folder level". ``indexes`` defaults to ``None`` meaning "inherit from folder level". + ``column_types`` defaults to ``None`` meaning "inherit from folder level"; + an explicit ``{}`` means "no user overrides for this cluster". """ pattern: re.Pattern @@ -209,6 +260,8 @@ class _ExplicitPattern: partition_by: Optional[List[str]] = None max_partitions: Optional[int] = None indexes: Optional[List[str]] = None + column_types: Optional[Dict[str, str]] = None + all_nullable: Optional[bool] = None @dataclass @@ -217,6 +270,14 @@ class FolderConfig: ``partition_by``, ``max_partitions``, and ``indexes`` serve as defaults for every cluster unless overridden at the cluster level. + ``column_types`` is a ``{column_name: postgres_type_str}`` map of + user-supplied type overrides that win over the auto-union computed + during pre-scan. + + Text-file config (``file_type``, ``delimiter``, ``text_encoding``, + ``quotechar``) controls how delimited text files are discovered and + read. These fields are ignored when ``file_type`` is ``"sas"`` + (the default). """ folder: Path @@ -229,6 +290,12 @@ class FolderConfig: partition_by: List[str] = field(default_factory=list) max_partitions: int = 10_000 indexes: List[str] = field(default_factory=list) + column_types: Dict[str, str] = field(default_factory=dict) + all_nullable: bool = False + file_type: str = "sas" + delimiter: str = "," + text_encoding: str = "utf-8" + quotechar: str = '"' # --------------------------------------------------------------------------- @@ -389,6 +456,40 @@ def _validate_indexes_vs_columns( ) +def _parse_column_types( + raw_value: Any, where: str, *, allow_none: bool = False +) -> Optional[Dict[str, str]]: + """Parse a ``column_types`` mapping from YAML. + + The value must be a mapping ``{column_name: pg_type_str}``. Keys and + values are whitespace-stripped strings; empty strings raise. When + ``allow_none`` is True (used for per-cluster entries), an omitted key + returns ``None`` to mean "inherit from folder level"; an explicit + empty mapping returns ``{}`` (no overrides for this cluster). + """ + if raw_value is None: + return None if allow_none else {} + if not isinstance(raw_value, dict): + raise ValueError( + f"{where}: 'column_types' must be a mapping of " + f"{{column_name: postgres_type}}." + ) + out: Dict[str, str] = {} + for k, v in raw_value.items(): + key = str(k).strip() + if not key: + raise ValueError( + f"{where}: 'column_types' contains an empty column name." + ) + if not isinstance(v, str) or not v.strip(): + raise ValueError( + f"{where}: 'column_types[{key}]' must be a non-empty " + f"Postgres type string (got {v!r})." + ) + out[key] = v.strip() + return out + + def load_folder_config(path: Path) -> FolderConfig: """Parse and validate the folder-level YAML config at ``path``. @@ -431,6 +532,46 @@ def load_folder_config(path: Path) -> FolderConfig: indexes = _parse_indexes(raw.get("indexes"), f"Config {path}") _validate_indexes_vs_columns(indexes, exclude, f"Config {path}") + # -- folder-level column_types overrides -------------------------------- + column_types = _parse_column_types( + raw.get("column_types"), f"Config {path}" + ) + + # -- folder-level all_nullable ----------------------------------------- + # Sets the default for every cluster. Per-cluster ``all_nullable`` wins + # when present; the CLI ``--all-nullable`` flag trumps both. + raw_an_folder = raw.get("all_nullable", False) + if not isinstance(raw_an_folder, bool): + raise ValueError( + f"Config {path}: 'all_nullable' must be a boolean " + f"(got {raw_an_folder!r})." + ) + all_nullable_default = bool(raw_an_folder) + + # -- file_type ---------------------------------------------------------- + file_type = str(raw.get("file_type", "sas")).lower() + if file_type not in VALID_FILE_TYPES: + raise ValueError( + f"Config {path}: file_type={file_type!r} is not one of " + f"{VALID_FILE_TYPES}" + ) + + # -- text-file-specific fields ------------------------------------------ + raw_delim = raw.get("delimiter", ",") + if isinstance(raw_delim, str): + delim_lower = raw_delim.lower().strip() + if delim_lower in ("tab", "\\t"): + delimiter = "\t" + elif delim_lower in ("pipe", "|"): + delimiter = "|" + else: + delimiter = raw_delim + else: + delimiter = str(raw_delim) + + text_encoding = str(raw.get("text_encoding", "utf-8")) + quotechar = str(raw.get("quotechar", '"')) + explicit: List[_ExplicitPattern] = [] clusters_raw = raw.get("clusters") or [] if not isinstance(clusters_raw, list): @@ -472,6 +613,24 @@ def load_folder_config(path: Path) -> FolderConfig: effective_idx = c_indexes if c_indexes is not None else indexes _validate_indexes_vs_columns(effective_idx, effective_exclude, where) + # -- per-cluster column_types overrides ----------------------------- + c_column_types = _parse_column_types( + entry.get("column_types"), where, allow_none=True + ) + + # -- per-cluster all_nullable --------------------------------------- + c_all_nullable: Optional[bool] + if "all_nullable" in entry: + raw_c_an = entry["all_nullable"] + if not isinstance(raw_c_an, bool): + raise ValueError( + f"{where}: 'all_nullable' must be a boolean " + f"(got {raw_c_an!r})." + ) + c_all_nullable = bool(raw_c_an) + else: + c_all_nullable = None + explicit.append( _ExplicitPattern( pattern=compiled, @@ -483,6 +642,8 @@ def load_folder_config(path: Path) -> FolderConfig: partition_by=c_partition_by, max_partitions=c_max_partitions, indexes=c_indexes, + column_types=c_column_types, + all_nullable=c_all_nullable, ) ) @@ -497,6 +658,12 @@ def load_folder_config(path: Path) -> FolderConfig: partition_by=partition_by, max_partitions=max_partitions, indexes=indexes, + column_types=column_types or {}, + all_nullable=all_nullable_default, + file_type=file_type, + delimiter=delimiter, + text_encoding=text_encoding, + quotechar=quotechar, ) @@ -537,21 +704,52 @@ def _cluster_sort_key(path: Path) -> Tuple[int, str]: return (n, path.stem) -def _list_sas_files(folder: Path) -> List[Path]: +def _list_data_files(folder: Path, file_type: str = "sas") -> List[Path]: + """List data files in ``folder`` filtered by ``file_type``. + + When ``file_type`` is ``"text"``, only text extensions are matched. + When ``file_type`` is ``"sas"`` (the default), only SAS extensions are + matched. This keeps SAS and text file pools separate so a folder + containing both types doesn't accidentally mix them. + """ + if file_type == "text": + extensions = TEXT_EXTENSIONS + else: + extensions = SAS_EXTENSIONS files: List[Path] = [] for p in sorted(folder.iterdir()): - if p.is_file() and p.suffix.lower() in SAS_EXTENSIONS: + if p.is_file() and p.suffix.lower() in extensions: files.append(p) return files +def _list_sas_files(folder: Path) -> List[Path]: + """Backward-compatible wrapper around :func:`_list_data_files`.""" + return _list_data_files(folder, file_type="sas") + + +def _build_text_kw(cluster: ClusterSpec) -> Dict[str, Any]: + """Build the text-file keyword arguments dict from a cluster's config. + + Returns a dict suitable for spreading into :func:`read_sas_preview`, + :func:`read_sas_metadata`, :func:`iter_sas_chunks`, etc. For SAS + file_type clusters the dict still carries the defaults, which the + reader functions ignore for SAS extensions. + """ + return dict( + delimiter=cluster.delimiter, + text_encoding=cluster.text_encoding, + quotechar=cluster.quotechar, + ) + + def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: """Enumerate ``cfg.folder`` and bucket files into ``ClusterSpec`` objects. Pure/IO-bounded: the only filesystem access is listing ``cfg.folder``. No - SAS file is opened here. Explicit patterns are applied first, in config + data file is opened here. Explicit patterns are applied first, in config order; files matched by an earlier pattern are removed from the pool - before the next pattern runs. A file matching two patterns triggers a + before the next pattern runs. A file matching two patterns triggers a hard error (that's almost always a config bug). Partition settings are resolved per cluster: @@ -560,13 +758,24 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: cluster entry override the folder defaults when present. ``None`` means "inherit"; an explicit ``[]`` disables partitioning. * For auto-detected clusters, folder defaults are inherited directly. + + Text-file config (``file_type``, ``delimiter``, ``text_encoding``, + ``quotechar``) is propagated from the folder config to every cluster. """ if not cfg.folder.exists() or not cfg.folder.is_dir(): raise FileNotFoundError(f"Folder not found or not a directory: {cfg.folder}") - pool = _list_sas_files(cfg.folder) + pool = _list_data_files(cfg.folder, file_type=cfg.file_type) clusters: List[ClusterSpec] = [] + # Text-file kwargs to propagate to every cluster. + _text_fields = dict( + file_type=cfg.file_type, + delimiter=cfg.delimiter, + text_encoding=cfg.text_encoding, + quotechar=cfg.quotechar, + ) + # Detect cross-pattern overlap up front for a clearer error message. for i, p_i in enumerate(cfg.explicit): for j in range(i + 1, len(cfg.explicit)): @@ -594,6 +803,20 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: patt.indexes if patt.indexes is not None else cfg.indexes ) + # Resolve column_types: user overrides only. The auto-union adds + # more entries later (in :func:`main`) after the metadata pre-scan. + # None = inherit folder, {} = no cluster-level overrides, dict = + # cluster-level overrides that win over folder-level entries. + if patt.column_types is None: + resolved_ct: Dict[str, str] = dict(cfg.column_types) + else: + resolved_ct = {**cfg.column_types, **patt.column_types} + + # Resolve all_nullable: None = inherit folder, bool = override. + resolved_an = ( + patt.all_nullable if patt.all_nullable is not None + else cfg.all_nullable + ) matched = [f for f in remaining if patt.pattern.search(f.name)] if not matched: @@ -611,6 +834,9 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: partition_by=resolved_pb, max_partitions=resolved_mp, indexes=resolved_idx, + column_types=dict(resolved_ct), + all_nullable=resolved_an, + **_text_fields, ) ) continue @@ -627,6 +853,9 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: partition_by=resolved_pb, max_partitions=resolved_mp, indexes=resolved_idx, + column_types=dict(resolved_ct), + all_nullable=resolved_an, + **_text_fields, ) ) @@ -647,6 +876,9 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: partition_by=cfg.partition_by, max_partitions=cfg.max_partitions, indexes=cfg.indexes, + column_types=dict(cfg.column_types), + all_nullable=cfg.all_nullable, + **_text_fields, ) ) @@ -658,13 +890,40 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: # --------------------------------------------------------------------------- -def _infer_cluster_schema(path: Path, include, exclude): - """Infer the Postgres column schema from a SAS file preview.""" - preview_df, meta = read_sas_preview(path) +def _infer_cluster_schema( + path: Path, + include, + exclude, + *, + column_types: Optional[Dict[str, str]] = None, + force_nullable: bool = False, + text_kw: Optional[Dict[str, Any]] = None, +) -> Tuple[Dict, Optional[int]]: + """Infer the Postgres column schema from a data file preview. + + Returns ``(columns, total_rows)``. ``total_rows`` comes from the + file metadata (the file's declared row count) and is threaded + through to :func:`_stream_file` so the tqdm progress bar has a real + denominator instead of an indeterminate spinner. ``column_types`` + lets the caller pin specific columns to a chosen Postgres type + (typically the merged auto-union + YAML overrides for the cluster). + ``force_nullable`` stamps every column nullable regardless of what + the preview shows - see :func:`load_sas.infer_schema`. + + ``text_kw`` carries ``delimiter``, ``text_encoding``, ``quotechar`` + through to :func:`read_sas_preview` for text file dispatch. + """ + _tkw = text_kw or {} + preview_df, meta = read_sas_preview(path, **_tkw) preview_df = apply_column_filter(preview_df, include, exclude) total_rows = getattr(meta, "number_rows", None) - columns = infer_schema(preview_df, meta, total_rows=total_rows) - return columns + columns = infer_schema( + preview_df, meta, + total_rows=total_rows, + column_types=column_types, + force_nullable=force_nullable, + ) + return columns, total_rows def _discover_cluster_partitions( @@ -678,10 +937,11 @@ def _discover_cluster_partitions( Each file is scanned chunk-by-chunk so the full dataset is never materialized in memory. """ + tkw = _build_text_kw(cluster) merged: dict = {} for path in cluster.files: def _filtered_chunks(p=path): - for chunk_df, _chunk_meta in iter_sas_chunks(p): + for chunk_df, _chunk_meta in iter_sas_chunks(p, **tkw): yield apply_column_filter( chunk_df, cluster.include, cluster.exclude ) @@ -693,7 +953,16 @@ def _discover_cluster_partitions( return merged -def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: +def load_cluster( + conn, + cluster: ClusterSpec, + schemaname: str, + *, + workers: int = 1, + progress_queue: Any = None, + db_overrides: Optional[Dict[str, Optional[str]]] = None, + abort_on_first_failure: bool = False, +) -> int: """Load every file in ``cluster`` into one table. Returns total rows loaded. When ``cluster.partition_by`` is non-empty, partition values are @@ -704,12 +973,34 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: file mid-cluster fails, earlier chunks - including chunks from earlier files in the cluster - stay committed; only the in-flight chunk is rolled back by :func:`main`. + + ``workers`` controls parallelism for streaming. With ``workers == 1`` + every file streams on ``conn`` in sequence. With ``workers > 1`` the + main connection only does ``CREATE TABLE`` (and, for partitioned + clusters, partition discovery + pre-creation), commits, then dispatches + *every* file - including the first - to a ``ProcessPoolExecutor``. Each + worker opens its own psycopg2 connection, re-infers the per-file schema, + runs the same :func:`load_sas.assert_schema_compatible` check the serial + path uses, and streams chunks via COPY. Workers report per-chunk row + counts to ``progress_queue`` so the caller can drive a single aggregated + tqdm bar regardless of how many workers are in flight. + + ``db_overrides`` carries ``{"user", "password"}`` into workers when the + caller prompted for credentials interactively; leave ``None`` to let + workers read the standard libpq environment variables on their own. """ if not cluster.files: return 0 + tkw = _build_text_kw(cluster) + first, *rest = cluster.files - first_columns = _infer_cluster_schema(first, cluster.include, cluster.exclude) + first_columns, first_total_rows = _infer_cluster_schema( + first, cluster.include, cluster.exclude, + column_types=cluster.column_types, + force_nullable=cluster.all_nullable, + text_kw=tkw, + ) # -- Validate index columns early --------------------------------------- if cluster.indexes: @@ -767,21 +1058,65 @@ def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int: ) total = 0 - total += _stream_file( - conn, schemaname, cluster.tablename, first, first_columns, - cluster.include, cluster.exclude, - ) - for path in rest: - columns = _infer_cluster_schema(path, cluster.include, cluster.exclude) - # Uses the same check that if_exists=append runs. A type mismatch or - # missing column aborts the cluster; because chunks commit as they - # load, earlier chunks in the cluster remain in the table. - assert_schema_compatible(conn, schemaname, cluster.tablename, columns) - total += _stream_file( - conn, schemaname, cluster.tablename, path, columns, - cluster.include, cluster.exclude, + if workers > 1: + # Parallel path: commit the (empty) table now so worker subprocesses' + # ``assert_schema_compatible`` probes can actually see it via + # ``information_schema``, then dispatch *every* file (first + + # rest) to the pool. The previous design streamed the first file + # on the main connection before spawning workers, which made the + # serial first-file phase the long pole on big-file clusters + # (e.g. 52 × 5-50 GB). Now ``CREATE TABLE`` is the only serial + # work and it takes milliseconds. + conn.commit() + total += _load_remaining_files_parallel( + cluster.files, + schemaname, + cluster.tablename, + cluster.include, + cluster.exclude, + workers=workers, + progress_queue=progress_queue, + db_overrides=db_overrides, + column_types=cluster.column_types, + force_nullable=cluster.all_nullable, + abort_on_first_failure=abort_on_first_failure, + text_kw=tkw, ) + else: + # Serial path: stream the first file on the main connection, then + # iterate the rest. Worth keeping separate from the parallel path + # because spawning a single-worker pool just to load files in + # series would be pure overhead. + total += _stream_file( + conn, schemaname, cluster.tablename, first, first_columns, + cluster.include, cluster.exclude, + total_rows=first_total_rows, + progress_queue=progress_queue, + text_kw=tkw, + ) + conn.commit() + for path in rest: + columns, path_total_rows = _infer_cluster_schema( + path, cluster.include, cluster.exclude, + column_types=cluster.column_types, + force_nullable=cluster.all_nullable, + text_kw=tkw, + ) + # Uses the same check that if_exists=append runs. A type + # mismatch or missing column aborts the cluster; because + # chunks commit as they load, earlier chunks in the + # cluster remain in the table. + assert_schema_compatible( + conn, schemaname, cluster.tablename, columns + ) + total += _stream_file( + conn, schemaname, cluster.tablename, path, columns, + cluster.include, cluster.exclude, + total_rows=path_total_rows, + progress_queue=progress_queue, + text_kw=tkw, + ) # -- Index support ------------------------------------------------------ if cluster.indexes: @@ -798,21 +1133,337 @@ def _stream_file( columns, include, exclude, + *, + total_rows: Optional[int] = None, + progress_queue: Any = None, + text_kw: Optional[Dict[str, Any]] = None, ) -> int: + """Stream ``path`` into an existing table chunk by chunk. + + When ``progress_queue`` is provided, each chunk's row count is published + to the queue as ``("rows", n)`` tuples instead of being rendered to a + per-file tqdm bar. That lets :func:`main` drive a single folder-wide + progress bar from a background drainer thread, which is the only way + to keep a coherent progress view when the folder loader is running + files in parallel workers. + + ``text_kw`` carries ``delimiter``, ``text_encoding``, ``quotechar`` + through to :func:`iter_sas_chunks` for text file dispatch. + """ + _tkw = text_kw or {} + def _chunks(): - seen = 0 - for chunk_df, _chunk_meta in iter_sas_chunks(path): - chunk_df = apply_column_filter(chunk_df, include, exclude) - seen += len(chunk_df) - print( - f" {path.name}: streaming... {seen:,} rows", - file=sys.stderr, - ) - yield chunk_df + if progress_queue is not None: + for chunk_df, _chunk_meta in iter_sas_chunks(path, **_tkw): + chunk_df = apply_column_filter(chunk_df, include, exclude) + progress_queue.put(("rows", len(chunk_df))) + yield chunk_df + return + + pbar = tqdm( + total=total_rows, + unit="row", + unit_scale=True, + desc=f" {path.name}", + file=sys.stderr, + dynamic_ncols=True, + ) + try: + for chunk_df, _chunk_meta in iter_sas_chunks(path, **_tkw): + chunk_df = apply_column_filter(chunk_df, include, exclude) + pbar.update(len(chunk_df)) + yield chunk_df + finally: + pbar.close() return copy_dataframes(conn, schemaname, tablename, _chunks(), columns) +# --------------------------------------------------------------------------- +# Parallel append workers +# --------------------------------------------------------------------------- + + +def _worker_load_append_file( + path_str: str, + schemaname: str, + tablename: str, + include: Optional[List[str]], + exclude: Optional[List[str]], + progress_queue: Any, + db_overrides: Optional[Dict[str, Optional[str]]], + column_types: Optional[Dict[str, str]] = None, + force_nullable: bool = False, + text_kw: Optional[Dict[str, Any]] = None, +) -> Tuple[str, int, Optional[str]]: + """Worker process: load one data file in append mode. + + Runs in a subprocess spawned by :func:`_load_remaining_files_parallel`. + Opens its own psycopg2 connection, re-infers the per-file schema (so + per-file ``INTEGER`` vs ``BIGINT`` drift is caught by the existing + schema-compat check just like in the serial path), and streams chunks + via ``COPY``. Row counts are published to the shared queue for the + main process's global tqdm bar. + + ``text_kw`` carries ``delimiter``, ``text_encoding``, ``quotechar`` + through to the reader functions for text file dispatch. + + Returns ``(path_str, rows_loaded, error_or_None)`` - failures are + returned rather than raised so the parent can aggregate results + across workers without losing partial progress. + """ + from pathlib import Path as _Path + + from dotenv import load_dotenv as _load_dotenv + + import ctypes + import ctypes.util + import gc + + from load_sas import ( + apply_column_filter as _apply_column_filter, + assert_schema_compatible as _assert_schema_compatible, + connect as _connect, + copy_dataframes as _copy_dataframes, + infer_schema as _infer_schema, + iter_sas_chunks as _iter_sas_chunks, + read_sas_preview as _read_sas_preview, + ) + + _load_dotenv() + + _tkw = text_kw or {} + path = _Path(path_str) + try: + preview_df, meta = _read_sas_preview(path, **_tkw) + preview_df = _apply_column_filter(preview_df, include, exclude) + total_rows = getattr(meta, "number_rows", None) + columns = _infer_schema( + preview_df, meta, + total_rows=total_rows, + column_types=column_types, + force_nullable=force_nullable, + ) + # Drop the preview ASAP - on a 2M-row wide file it's hundreds of MB + # and we never need it again after schema inference. + del preview_df, meta + + user = db_overrides.get("user") if db_overrides else None + password = db_overrides.get("password") if db_overrides else None + conn = _connect(user=user, password=password) + conn.autocommit = False + try: + _assert_schema_compatible(conn, schemaname, tablename, columns) + + def _chunks(): + for chunk_df, _chunk_meta in _iter_sas_chunks(path, **_tkw): + chunk_df = _apply_column_filter(chunk_df, include, exclude) + if progress_queue is not None: + progress_queue.put(("rows", len(chunk_df))) + yield chunk_df + + rows = _copy_dataframes( + conn, schemaname, tablename, _chunks(), columns + ) + conn.commit() + return (path_str, rows, None) + finally: + conn.close() + except Exception as e: + import traceback as _traceback + tb = _traceback.format_exc() + # Keep the one-line summary (what the tqdm [FAIL] print uses) but + # tack on the full traceback so the final cluster-failure block + # shows the file/line that crashed. Without this, ``ProcessPool`` + # workers lose every frame of context - you get "FloatingPointError: + # overflow encountered in multiply" with no hint of where inside + # the pandas/numpy/pyarrow stack it happened. + return (path_str, 0, f"{type(e).__name__}: {e}\n{tb}") + finally: + # Hand memory back to the OS before the worker is recycled (or before + # ``max_tasks_per_child`` rotates this process). Three layers, each + # of which independently retains memory across calls: + # + # 1. pyarrow's memory pool aggressively reuses buffers - explicitly + # release_unused() returns them to the allocator. + # 2. Python's GC: cyclic refs from pandas/pyarrow chains aren't + # collected until a generation tick; force one now. + # 3. glibc's ptmalloc keeps freed heap in per-thread arenas instead + # of munmap'ing it back. ``malloc_trim(0)`` is the explicit ask. + # No-op (silently) on platforms without the symbol (macOS, etc). + try: + import pyarrow as _pa + _pa.default_memory_pool().release_unused() + except Exception: + pass + gc.collect() + try: + _libc_name = ctypes.util.find_library("c") + if _libc_name: + _libc = ctypes.CDLL(_libc_name) + if hasattr(_libc, "malloc_trim"): + _libc.malloc_trim(0) + except Exception: + pass + + +def _load_remaining_files_parallel( + files: List[Path], + schemaname: str, + tablename: str, + include: Optional[List[str]], + exclude: Optional[List[str]], + *, + workers: int, + progress_queue: Any, + db_overrides: Optional[Dict[str, Optional[str]]], + column_types: Optional[Dict[str, str]] = None, + force_nullable: bool = False, + abort_on_first_failure: bool = False, + text_kw: Optional[Dict[str, Any]] = None, +) -> int: + """Run append-mode loads for ``files`` across a process pool. + + Each file is an independent unit of work submitted to + ``ProcessPoolExecutor``. Workers infer schema, validate compatibility, + and stream via COPY just like the serial path. The table itself must + already exist (and be committed) before this is called - the worker + schema-compat probes read ``information_schema``, which won't see an + uncommitted ``CREATE TABLE``. Failures are surfaced two ways: + + 1. **Live stderr feed.** Every worker's outcome - success or failure + - is written to stderr the moment ``as_completed`` hands it back, + via ``tqdm.write`` so the active progress bar stays intact. This + turns the previous "wait 30 minutes for the last worker to drain + before you find out the first 31 failed at minute 2" footgun into + an immediate notification, which matters most when one bad file + (e.g. schema drift, OOM, bad data) torpedoes most of the pool. + 2. **Final aggregate raise.** Errors are still collected and raised as + one ``RuntimeError`` after the pool drains so the per-cluster + success/fail summary in :func:`main` stays accurate. On + ``KeyboardInterrupt`` we re-raise after wrapping the partial + error list into the same ``RuntimeError`` shape, so Ctrl-C still + prints what failed up to that point instead of silently + discarding it. + """ + total = 0 + errors: List[Tuple[str, str]] = [] + completed = 0 + n_files = len(files) + + # ``max_tasks_per_child=1`` recycles each worker process after every + # file. Without this, glibc/pyarrow/pyreadstat all retain peak-water + # memory inside long-lived workers; over a multi-hour run the sum + # across workers monotonically grows even though individual chunks + # have been freed at the Python level. Recycling per file gives the + # OS the memory back unconditionally - the only cost is one fork + + # python interpreter startup per file (~1-2 s), which is noise next + # to multi-GB sas7bdat reads. + pool_kwargs: Dict[str, Any] = {"max_workers": workers} + if sys.version_info >= (3, 11): + pool_kwargs["max_tasks_per_child"] = 1 + + with ProcessPoolExecutor(**pool_kwargs) as pool: + futures = [ + pool.submit( + _worker_load_append_file, + str(p), + schemaname, + tablename, + include, + exclude, + progress_queue, + db_overrides, + column_types, + force_nullable, + text_kw, + ) + for p in files + ] + aborted = False + try: + for fut in as_completed(futures): + # ``--abort-on-first-failure`` cancels still-pending + # futures; ``as_completed`` yields them anyway and + # ``.result()`` raises ``CancelledError``. Skip those + # quietly - the abort log already accounted for the + # cancellation count. + try: + path_str, rows, err = fut.result() + except CancelledError: + continue + completed += 1 + name = Path(path_str).name + if err is not None: + errors.append((path_str, err)) + # ``tqdm.write`` writes through the bar without + # garbling it; bare ``print`` would interleave with + # the rendered progress line. + tqdm.write( + f"[FAIL {completed}/{n_files}] {name}: {err}", + file=sys.stderr, + ) + if abort_on_first_failure and not aborted: + # Cancel anything still queued. Already-running + # workers can't be interrupted portably, so they + # keep going - we just stop dispatching new files + # and stop counting their results toward total. + # Set the flag once so we don't spam a cancel + # storm if multiple workers fail at the same time. + aborted = True + cancelled = 0 + for f in futures: + if not f.done() and f.cancel(): + cancelled += 1 + tqdm.write( + f"[abort] --abort-on-first-failure: cancelled " + f"{cancelled} pending file(s); waiting for " + f"{n_files - completed - cancelled} in-flight " + f"worker(s) to finish.", + file=sys.stderr, + ) + else: + tqdm.write( + f"[done {completed}/{n_files}] {name}: " + f"{rows:,} row(s)", + file=sys.stderr, + ) + total += rows + except KeyboardInterrupt: + # Cancel anything still queued so we exit promptly. Already- + # running workers will run to completion (Python can't kill + # a child mid-syscall portably) but at least pending tasks + # don't keep firing. We re-raise as the same RuntimeError + # shape used below so the caller's per-cluster summary path + # still sees the partial failure list instead of an opaque + # KeyboardInterrupt with no context about which workers + # had already failed. + for f in futures: + if not f.done(): + f.cancel() + tqdm.write( + f"[interrupt] Ctrl-C after {completed}/{n_files} file(s); " + f"{len(errors)} failure(s) collected so far.", + file=sys.stderr, + ) + if errors: + joined = "\n".join(f" {p}: {e}" for p, e in errors) + raise RuntimeError( + f"interrupted after {len(errors)} worker(s) failed " + f"while appending to {schemaname}.{tablename}:\n{joined}" + ) + raise + + if errors: + joined = "\n".join(f" {p}: {e}" for p, e in errors) + raise RuntimeError( + f"{len(errors)} worker(s) failed while appending to " + f"{schemaname}.{tablename}:\n{joined}" + ) + + return total + + # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- @@ -821,8 +1472,8 @@ def _stream_file( def _build_argparser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( description=( - "Load every SAS file in a folder into Postgres, grouping files " - "into clusters that each become one table." + "Load every data file (SAS or delimited text) in a folder into " + "Postgres, grouping files into clusters that each become one table." ), ) p.add_argument("--config", required=True, type=Path, help="Path to YAML config") @@ -844,6 +1495,20 @@ def _build_argparser() -> argparse.ArgumentParser: "cluster back and continue with the next one." ), ) + p.add_argument( + "--abort-on-first-failure", + action="store_true", + help=( + "Within a single cluster's parallel append, cancel every " + "still-pending worker the instant one worker fails. Use when " + "you know the failure is systemic (schema drift, bad creds, " + "OOM) and don't want to wait for the slow files to drain " + "before getting your prompt back. Already-running workers " + "still finish their current file - Python can't kill children " + "mid-syscall - but new files won't be dispatched. Orthogonal " + "to --fail-fast, which controls what happens between clusters." + ), + ) p.add_argument( "--dbcreds", action="store_true", @@ -852,6 +1517,58 @@ def _build_argparser() -> argparse.ArgumentParser: "PGUSER / PGPASSWORD from the environment or .env file." ), ) + p.add_argument( + "--chunk-rows", + type=int, + default=None, + metavar="N", + help=( + "Per-chunk row target for streaming and COPY. " + "Overrides both the GENERIC_LOADER_CHUNK_ROWS env var and the " + "auto-scaling applied when --workers > 1. Peak memory per " + "worker is roughly 4 × N × avg_row_bytes; with wide data " + "files (~4 KB/row) and 32 workers, N=100000 is a safe starting " + "point on a 128 GB box." + ), + ) + p.add_argument( + "--no-prescan", + action="store_true", + help=( + "Skip the per-file metadata scan that populates the folder-wide " + "tqdm ETA. Useful when the folder is large (half-hour+ pre-scan) " + "or when you're iterating quickly on a failure. Without the " + "pre-scan the progress bar still shows rows loaded, rate, and " + "elapsed time - it just can't estimate remaining time." + ), + ) + p.add_argument( + "--all-nullable", + action="store_true", + help=( + "Stamp every column nullable in the generated schema, bypassing " + "NOT NULL inference for every cluster. Use when sampled rows " + "wrongly suggest a column has no nulls and COPY fails mid-load " + "on the first null it hits. Overrides the per-cluster and " + "folder-level ``all_nullable`` YAML settings when set." + ), + ) + p.add_argument( + "--workers", + type=int, + default=1, + metavar="N", + help=( + "Number of worker processes for the append phase. With N=1 " + "(default) files load serially on the main connection. With " + "N>1 the first file of each cluster still runs serially (to " + "create the table), then the remaining files load in parallel " + "across N processes, each with its own psycopg2 connection. " + "On a big box try N close to your core count. When N>1 the " + "per-chunk row target drops to 500,000 unless you've pinned " + "GENERIC_LOADER_CHUNK_ROWS, so peak memory stays bounded." + ), + ) return p @@ -884,12 +1601,32 @@ def main(argv: Optional[List[str]] = None) -> int: return 2 clusters = discover_clusters(cfg) + + # CLI override: --all-nullable trumps both folder-level and per-cluster + # YAML ``all_nullable`` settings. Applied here (before any schema work) + # so every downstream path - dry-run, pre-scan, worker dispatch - sees + # the same flag on the ClusterSpec. + if args.all_nullable: + for c in clusters: + c.all_nullable = True + print( + "[info] --all-nullable set: stamping every column nullable " + "across all clusters (NOT NULL inference disabled).", + file=sys.stderr, + ) + loadable = [c for c in clusters if c.files] if not loadable: + if cfg.file_type == "text": + ext_label = ", ".join(TEXT_EXTENSIONS) + kind = "text" + else: + ext_label = ", ".join(SAS_EXTENSIONS) + kind = "SAS" print( - f"error: no SAS files found in {cfg.folder} " - f"(looked for {', '.join(SAS_EXTENSIONS)})", + f"error: no {kind} files found in {cfg.folder} " + f"(looked for {ext_label})", file=sys.stderr, ) return 2 @@ -902,7 +1639,17 @@ def main(argv: Optional[List[str]] = None) -> int: print() for c in loadable: print(f"--- DDL for cluster {c.tablename!r} ---") - columns = _infer_cluster_schema(c.files[0], c.include, c.exclude) + # Dry-run skips the pre-scan (so no auto-union) but user-supplied + # ``column_types`` from YAML are already baked into ``c.column_types`` + # by ``discover_clusters`` - honor them here so the previewed DDL + # matches what a real load would produce on a single-file cluster. + _dry_tkw = _build_text_kw(c) + columns, _ = _infer_cluster_schema( + c.files[0], c.include, c.exclude, + column_types=c.column_types, + force_nullable=c.all_nullable, + text_kw=_dry_tkw, + ) # Print parent CREATE TABLE (with PARTITION BY if applicable). print( render_create_table( @@ -970,6 +1717,238 @@ def main(argv: Optional[List[str]] = None) -> int: if args.dbcreds: db_user = input("Database username: ") db_password = getpass.getpass("Database password: ") + db_overrides: Optional[Dict[str, Optional[str]]] = ( + {"user": db_user, "password": db_password} if args.dbcreds else None + ) + + workers = max(1, int(args.workers)) + + # Per-worker peak memory ~= chunk_rows × avg_row_bytes × ~4 (the original + # pyreadstat DataFrame, the type-coerced ``prepared`` copy, the pyarrow + # table, and the serialized CSV buffer can all be alive simultaneously). + # With 32 workers and 500k rows × wide sas7bdat that's easily >128 GB - + # the default the loader shipped with OOM'd on a c6i.32xlarge box. Scale + # the auto target inversely with worker count so total memory stays + # roughly flat regardless of how many workers you pick. Floor of 50k + # keeps per-chunk overhead amortized; ceiling of 500k is where pyarrow + # / pyreadstat buffer spikes start to dominate. + # + # Order of precedence (most wins): + # 1. ``--chunk-rows N`` CLI flag (if provided) + # 2. ``GENERIC_LOADER_CHUNK_ROWS`` env var (if already set) + # 3. Auto-pick based on ``workers`` + if args.chunk_rows is not None: + os.environ["GENERIC_LOADER_CHUNK_ROWS"] = str(int(args.chunk_rows)) + print( + f"[info] --chunk-rows {args.chunk_rows:,}: pinning per-chunk " + f"row target (overrides auto-scaling).", + file=sys.stderr, + ) + elif "GENERIC_LOADER_CHUNK_ROWS" in os.environ: + print( + f"[info] honoring GENERIC_LOADER_CHUNK_ROWS=" + f"{os.environ['GENERIC_LOADER_CHUNK_ROWS']} from environment.", + file=sys.stderr, + ) + elif workers > 1: + auto_rows = max(50_000, min(500_000, 3_200_000 // workers)) + os.environ["GENERIC_LOADER_CHUNK_ROWS"] = str(auto_rows) + print( + f"[info] parallel mode (workers={workers}): auto-scaled " + f"per-chunk rows to {auto_rows:,}. " + f"Use --chunk-rows N to override if you have RAM headroom.", + file=sys.stderr, + ) + + # -- Metadata pre-scan ----------------------------------------------------- + # Sum ``number_rows`` across every file so the tqdm bar has a real + # denominator, AND collect the per-column (readstat_type, sas_format) + # tuples so we can union schemas across files in a cluster before any + # CREATE TABLE runs. ``read_sas_metadata`` uses pyreadstat's + # ``metadataonly=True`` fast path, but on multi-GB sas7bdat files + # that still reads tens of MB of scattered subheader pages per file - + # sequentially that's minutes for a 52-file folder. pyreadstat + # releases the GIL during I/O and C decoding, so a ThreadPool gives + # near-linear scaling until the disk saturates. ``--no-prescan`` + # bypasses the scan entirely; the progress bar then runs without an + # ETA *and* the auto-union is skipped (user overrides from YAML + # still apply). + all_files: List[Path] = [p for c in loadable for p in c.files] + grand_total: Optional[int] = 0 + file_meta_by_path: Dict[str, Dict[str, Tuple[str, Optional[str]]]] = {} + + if args.no_prescan: + grand_total = None + print( + f"[info] --no-prescan set: skipping row-count pre-scan for " + f"{len(all_files)} file(s); progress bar will show rate + " + f"elapsed but no ETA. Cluster-wide schema auto-union is also " + f"disabled; only user-specified column_types overrides apply.", + file=sys.stderr, + ) + else: + prescan_workers = min(16, max(1, len(all_files))) + print( + f"pre-scanning row counts + per-column metadata for " + f"{len(all_files)} file(s) across {prescan_workers} thread(s)...", + file=sys.stderr, + ) + + def _scan_one( + p: Path, + ) -> Tuple[ + Path, + Optional[int], + Optional[Dict[str, Tuple[str, Optional[str]]]], + Optional[str], + ]: + try: + _prescan_tkw = dict( + delimiter=cfg.delimiter, + text_encoding=cfg.text_encoding, + quotechar=cfg.quotechar, + ) + meta = read_sas_metadata(p, **_prescan_tkw) + n = getattr(meta, "number_rows", None) + col_meta = extract_union_metadata(meta) + return ( + p, + int(n) if n is not None else None, + col_meta, + None, + ) + except Exception as e: + return (p, None, None, str(e)) + + unknown_total_files: List[str] = [] + running_total = 0 + with ThreadPoolExecutor(max_workers=prescan_workers) as tpool: + prescan_bar = tqdm( + total=len(all_files), + unit="file", + desc=" prescanning", + file=sys.stderr, + dynamic_ncols=True, + ) + try: + for p, n, col_meta, err in tpool.map(_scan_one, all_files): + prescan_bar.update(1) + if err is not None: + unknown_total_files.append(f"{p.name} ({err})") + elif n is None: + unknown_total_files.append(p.name) + else: + running_total += n + if col_meta is not None: + file_meta_by_path[str(p)] = col_meta + finally: + prescan_bar.close() + + if unknown_total_files: + print( + f"[warn] could not read row count from " + f"{len(unknown_total_files)} file(s); progress bar ETA will " + f"be approximate.", + file=sys.stderr, + ) + print( + f" total rows across folder: {running_total:,}", + file=sys.stderr, + ) + grand_total = running_total + + # -- Cluster-wide schema auto-union --------------------------------------- + # For each cluster, compute ``auto_types`` from the union of every + # file's metadata (see :func:`load_sas.union_column_types`). Merge with + # any user-supplied YAML overrides (user wins) and attach the result + # back onto the cluster so every later read - first-file inference, + # worker inference, schema-compat check - sees the same frozen schema. + # With ``--no-prescan`` the file_meta_by_path dict is empty and + # ``auto_types`` resolves to {}, so only the YAML overrides survive. + for c in loadable: + per_file = [ + file_meta_by_path[str(p)] + for p in c.files + if str(p) in file_meta_by_path + ] + auto_types = union_column_types(per_file) if per_file else {} + user_overrides = dict(c.column_types) # already merged folder+cluster + # User-supplied overrides win over the auto-union. + merged = {**auto_types, **user_overrides} + c.column_types = merged + + if auto_types: + # Only call out columns where auto-union *changed* something + # relative to the default "first file wins" inference. We + # don't have the default inference in hand at this point, so + # log the full resolved map at a debug-friendly level - it's + # bounded by column count and the user asked for visibility + # into what got overridden. + shown = auto_types + if user_overrides: + # Distinguish the user-forced entries in the log so it's + # obvious which types came from YAML. + shown = { + col: ( + f"{user_overrides[col]} (user override)" + if col in user_overrides + else pg + ) + for col, pg in merged.items() + } + print( + f"[info] cluster {c.tablename!r}: auto-union derived " + f"{len(auto_types)} column type(s) across " + f"{len(per_file)} file(s): {shown}", + file=sys.stderr, + ) + elif user_overrides and args.no_prescan: + print( + f"[info] cluster {c.tablename!r}: using {len(user_overrides)} " + f"user-supplied column_types override(s); auto-union " + f"disabled by --no-prescan.", + file=sys.stderr, + ) + + # -- Shared progress plumbing --------------------------------------------- + # The queue crosses process boundaries when workers > 1 (managed proxy) + # and is a plain in-process queue otherwise; the put/get contract is + # identical either way. A daemon thread drains it and advances the one + # tqdm bar that spans the whole folder load. + manager: Optional[Any] = None + progress_queue: Any + if workers > 1: + manager = mp.Manager() + progress_queue = manager.Queue() + else: + progress_queue = _queue_mod.Queue() + + pbar = tqdm( + total=grand_total or None, + unit="row", + unit_scale=True, + desc=f"{cfg.folder.name}", + file=sys.stderr, + dynamic_ncols=True, + ) + stop_drainer = threading.Event() + + def _drainer() -> None: + while not stop_drainer.is_set(): + try: + event = progress_queue.get(timeout=0.1) + except _queue_mod.Empty: + continue + except (EOFError, OSError): + return + if not event: + continue + kind = event[0] + if kind == "rows": + pbar.update(event[1]) + + drainer_thread = threading.Thread(target=_drainer, daemon=True) + drainer_thread.start() conn = connect(user=db_user, password=db_password) conn.autocommit = False @@ -979,10 +1958,19 @@ def main(argv: Optional[List[str]] = None) -> int: for cluster in loadable: print( f"\n>>> loading cluster {cluster.tablename!r} " - f"({len(cluster.files)} file(s))" + f"({len(cluster.files)} file(s)) " + f"[workers={workers}]" ) try: - rows = load_cluster(conn, cluster, cfg.schemaname) + rows = load_cluster( + conn, + cluster, + cfg.schemaname, + workers=workers, + progress_queue=progress_queue, + db_overrides=db_overrides, + abort_on_first_failure=args.abort_on_first_failure, + ) conn.commit() totals.append((cluster.tablename, len(cluster.files), rows)) print( @@ -999,7 +1987,14 @@ def main(argv: Optional[List[str]] = None) -> int: if args.fail_fast: break finally: + # Drain any pending progress events before shutting the bar down so + # the final rendered total matches what actually landed. + stop_drainer.set() + drainer_thread.join(timeout=2.0) + pbar.close() conn.close() + if manager is not None: + manager.shutdown() print("\n=== summary ===") for name, fcount, rows in totals: diff --git a/generic_loader/load_sas.py b/generic_loader/load_sas.py index 74f5ff8..ed65ef8 100644 --- a/generic_loader/load_sas.py +++ b/generic_loader/load_sas.py @@ -1,4 +1,4 @@ -"""Per-file SAS-to-Postgres loader. +"""Per-file data-to-Postgres loader (SAS and delimited text). Library-style functions plus a thin CLI wrapper. Designed so an orchestrator can wrap the library for directory/batch mode; orchestration is out of scope @@ -16,10 +16,11 @@ USAGE Supported inputs: * ``.sas7bdat`` (read with ``encoding="latin-1"``) * ``.xpt`` / ``.xport`` (SAS transport files) + * ``.csv`` / ``.tsv`` / ``.txt`` (delimited text files with headers) 1. YAML config -------------- -Every invocation is driven by a YAML file describing one SAS file to load:: +Every invocation is driven by a YAML file describing one data file to load:: filename: samples/sample_kitchensink.xpt # required; relative paths are # resolved against the config @@ -183,15 +184,17 @@ Priority order used by :func:`infer_schema`: value exceeds the int32 range ``NUMERIC_INT_RANGE``); otherwise ``DOUBLE PRECISION``. -Type inference scans only the first ``TYPE_INFERENCE_SAMPLE_ROWS`` rows for -performance on large files. The CLI enforces this at read time via -:func:`read_sas_preview`, so the whole file is never materialized just to pick -types. Sampled specs carry an ``inferred_from_sample`` marker and the usual -tradeoffs: if the first N rows fit ``INTEGER`` but a later row exceeds int32, -or a column had no nulls in the preview but does later in the file, ``COPY`` -will fail mid-stream and the whole transaction rolls back. Set -``TYPE_INFERENCE_SAMPLE_ROWS = None`` to scan every row when exact typing -matters more than speed. +Type inference scans the whole file by default (``TYPE_INFERENCE_SAMPLE_ROWS += None``) so type + nullability are both computed against every row. The CLI +materializes the file once for schema inference, then re-streams it chunk by +chunk into ``COPY``; peak memory is roughly one full dataframe. Override +``TYPE_INFERENCE_SAMPLE_ROWS`` to an integer cap if you're on a host that +can't hold the file in memory - but know that sampled specs carry the usual +risks: a later row may exceed the inferred integer range, or a column that +had no nulls in the preview may carry nulls later in the file (which then +detonates ``COPY`` because the sampled spec stamped it ``NOT NULL``). Seen +in production on a 2.5M-row file with ~6k null MAFIDs past the 10k-row +preview - the entire load aborted mid-stream. Streaming loads use :func:`iter_sas_chunks` + :func:`copy_dataframes`, which commit each chunk as it is copied so an interrupted load retains the rows @@ -225,16 +228,48 @@ import math import os import re import sys +import warnings from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Tuple +import numpy as np import pandas as pd import psycopg2 import psycopg2.extensions +import pyarrow as pa +import pyarrow.csv as pa_csv import pyreadstat import yaml from dotenv import load_dotenv +from pandas.errors import PerformanceWarning +from tqdm import tqdm + +# ``_prepare_for_copy`` builds its output frame one column at a time with +# ``out[name] = ...``. On wide SAS files (~100+ columns) pandas prints a +# ``PerformanceWarning: DataFrame is highly fragmented`` once per chunk to +# nudge callers toward ``pd.concat(axis=1, ...)``. The fragmentation only +# matters for row-oriented ops or in-place ``.copy()``; we hand the frame +# straight to ``pyarrow.Table.from_pandas`` which reads columns +# independently, so the warning is pure noise for our pipeline. Filter it +# at import time - narrow category match so nothing else is suppressed. +warnings.filterwarnings("ignore", category=PerformanceWarning) + +# Turn numpy's "raise on float overflow" (and friends) into silent inf/nan +# production, module-wide. Pandas ships with ``np.errstate(over="raise")`` +# wrapped around several internal ops (most painfully, the multiply inside +# ``pd.to_datetime(unit="s")`` that converts SAS epoch -> nanoseconds). +# Our data routinely carries ``inf`` / huge sentinels, which trip that +# ``raise`` and blow up an entire worker before ``errors="coerce"`` gets +# a chance to turn them into NaT. Even with ``_safe_numeric_to_datetime`` +# pre-masking the obvious cases, other code paths (pandas object-dtype +# datetime parsing, pyarrow type promotion, pyreadstat) can also trigger. +# Setting a process-wide ``seterr`` is a heavier hammer than an +# ``errstate`` block but survives library internals that don't explicitly +# rewrap it. Downside: a real overflow bug in new code would now silently +# produce inf/nan instead of raising - acceptable for a bulk loader where +# "don't crash on bad rows, null them and move on" is the whole point. +np.seterr(over="ignore", invalid="ignore", divide="ignore") logger = logging.getLogger(__name__) @@ -255,21 +290,39 @@ values; too small a sample is easy to mis-infer.""" NUMERIC_INT_RANGE = (-2_147_483_648, 2_147_483_647) """INTEGER bounds; anything outside becomes BIGINT.""" -TYPE_INFERENCE_SAMPLE_ROWS: Optional[int] = 10_000 +TYPE_INFERENCE_SAMPLE_ROWS: Optional[int] = None """Cap on rows inspected during per-column type inference. Also governs how many rows :func:`read_sas_preview` pulls from the file for dry-run / validate / -schema-inference flows. Set to ``None`` to scan every row (and read the whole -file into memory for the preview step - don't do this on multi-hundred-million -row files).""" +schema-inference flows. -DEFAULT_CHUNK_ROWS = 100_000 +Default is ``None`` (scan every row, reading the whole file into memory for +the schema-inference step). That's the only honest setting for nullability: +any integer cap lets a column look ``NOT NULL`` across the first N rows +while the file actually holds rare nulls past the window, which then +detonates ``COPY`` mid-stream (seen in production on a 2.5M-row file where +~6k MAFIDs were null past the 10k-row preview). If you're loading a file +so large that a full read won't fit in memory, set this to an integer cap +and accept that sampled specs can't be trusted for ``NOT NULL``.""" + +DEFAULT_CHUNK_ROWS = 2_000_000 """Rows per chunk when streaming a SAS file into ``COPY``. Larger values mean -fewer COPY round-trips but more peak memory per chunk; smaller values are -gentler on memory.""" +fewer COPY round-trips and lower per-row overhead but more peak memory per +chunk; smaller values are gentler on memory. + +The chunk size can be overridden at runtime via the +``GENERIC_LOADER_CHUNK_ROWS`` environment variable (read inside +:func:`iter_sas_chunks`), so ``.env``-driven overrides work without code +changes. Explicit ``chunksize=`` kwargs still win over both.""" VALID_IF_EXISTS = ("fail", "replace", "append") +VALID_FILE_TYPES = ("sas", "text") +"""Supported ``file_type`` values in the YAML config.""" + +TEXT_EXTENSIONS = (".txt", ".csv", ".tsv") +"""File extensions recognised as delimited text files.""" + _PG_IDENT_MAX_LEN = 63 """PostgreSQL maximum identifier length in bytes (characters for ASCII).""" @@ -279,6 +332,20 @@ _PG_IDENT_MAX_LEN = 63 # --------------------------------------------------------------------------- +@dataclass +class TextFileMetadata: + """Minimal metadata object for text files, mimicking pyreadstat metadata. + + Provides the same attribute surface that :func:`infer_schema` reads from + pyreadstat metadata objects: ``column_names``, ``column_labels``, + ``original_variable_types``, and ``number_rows``. + """ + column_names: List[str] + column_labels: List[str] + original_variable_types: Dict[str, str] + number_rows: Optional[int] = None + + @dataclass class LoaderConfig: filename: Path @@ -290,6 +357,12 @@ class LoaderConfig: partition_by: List[str] = field(default_factory=list) max_partitions: int = 10_000 indexes: List[str] = field(default_factory=list) + column_types: Dict[str, str] = field(default_factory=dict) + all_nullable: bool = False + file_type: str = "sas" + delimiter: str = "," + text_encoding: str = "utf-8" + quotechar: str = '"' @dataclass @@ -500,6 +573,73 @@ def load_config(path: Path) -> LoaderConfig: f"{missing_in_include}" ) + # -- column_types ------------------------------------------------------- + # Optional ``{column_name: pg_type}`` escape hatch that bypasses + # automatic type inference for specific columns. Useful when + # pyreadstat reports a column as NUM but the downstream consumer + # expects TEXT (e.g. phone-id columns), or when a column has drifted + # between CHAR and NUM across file versions and you want to pin + # TEXT up front. See also :func:`infer_schema`. + raw_ct = raw.get("column_types") + column_types: Dict[str, str] = {} + if raw_ct is not None: + if not isinstance(raw_ct, dict): + raise ValueError( + f"Config {path}: 'column_types' must be a mapping of " + f"{{column_name: postgres_type}}." + ) + for k, v in raw_ct.items(): + key = str(k).strip() + if not key: + raise ValueError( + f"Config {path}: 'column_types' contains an empty " + f"column name." + ) + if not isinstance(v, str) or not v.strip(): + raise ValueError( + f"Config {path}: 'column_types[{key}]' must be a " + f"non-empty Postgres type string (got {v!r})." + ) + column_types[key] = v.strip() + + # -- all_nullable ------------------------------------------------------- + # When inference wrongly stamps a column NOT NULL (sampled rows happened + # to be dense; later rows carry nulls) downstream COPYs fail mid-stream. + # Set ``all_nullable: true`` in the YAML to stamp every column nullable + # up front. The CLI flag ``--all-nullable`` overrides this to ``true`` + # if set. + raw_an = raw.get("all_nullable", False) + if not isinstance(raw_an, bool): + raise ValueError( + f"Config {path}: 'all_nullable' must be a boolean (got {raw_an!r})." + ) + all_nullable = bool(raw_an) + + # -- file_type ---------------------------------------------------------- + file_type = str(raw.get("file_type", "sas")).lower() + if file_type not in VALID_FILE_TYPES: + raise ValueError( + f"Config {path}: file_type={file_type!r} is not one of " + f"{VALID_FILE_TYPES}" + ) + + # -- text-file-specific fields ------------------------------------------ + # Only validated when file_type == "text"; harmless defaults otherwise. + raw_delim = raw.get("delimiter", ",") + if isinstance(raw_delim, str): + delim_lower = raw_delim.lower().strip() + if delim_lower in ("tab", "\\t"): + delimiter = "\t" + elif delim_lower in ("pipe", "|"): + delimiter = "|" + else: + delimiter = raw_delim + else: + delimiter = str(raw_delim) + + text_encoding = str(raw.get("text_encoding", "utf-8")) + quotechar = str(raw.get("quotechar", '"')) + return LoaderConfig( filename=filename, schemaname=schemaname, @@ -510,6 +650,12 @@ def load_config(path: Path) -> LoaderConfig: partition_by=partition_by, max_partitions=max_partitions, indexes=indexes, + column_types=column_types, + all_nullable=all_nullable, + file_type=file_type, + delimiter=delimiter, + text_encoding=text_encoding, + quotechar=quotechar, ) @@ -518,6 +664,11 @@ def load_config(path: Path) -> LoaderConfig: # --------------------------------------------------------------------------- +def _is_text_file(path: Path) -> bool: + """Return True if ``path`` has a recognised delimited-text extension.""" + return Path(path).suffix.lower() in TEXT_EXTENSIONS + + def _sas_reader(path: Path) -> Tuple[Any, Dict[str, Any]]: """Return ``(pyreadstat_reader, extra_kwargs)`` for ``path``. @@ -535,13 +686,192 @@ def _sas_reader(path: Path) -> Tuple[Any, Dict[str, Any]]: raise ValueError(f"Unsupported SAS file extension: {suffix}") -def read_sas(path: Path) -> Tuple[pd.DataFrame, Any]: - """Read an entire SAS file into memory. Only safe for small files. +# --------------------------------------------------------------------------- +# Text file readers +# --------------------------------------------------------------------------- + + +def _count_text_lines(path: Path, encoding: str = "utf-8") -> int: + """Count data rows in a text file (excludes the header line). + + Reads the file in binary chunks for speed; counts newlines and + subtracts one for the header. + """ + count = 0 + with open(path, "rb") as fh: + for chunk in iter(lambda: fh.read(1 << 20), b""): + count += chunk.count(b"\n") + # If the file doesn't end with a newline the last line is still a row. + # But the first line is the header, so subtract 1. + # Edge case: empty file or header-only -> 0 rows. + return max(0, count - 1) if count > 0 else 0 + + +def _build_text_metadata( + column_names: List[str], + number_rows: Optional[int] = None, +) -> TextFileMetadata: + """Build a :class:`TextFileMetadata` from column names and an optional + row count.""" + return TextFileMetadata( + column_names=list(column_names), + column_labels=list(column_names), + original_variable_types={}, + number_rows=number_rows, + ) + + +def read_text( + path: Path, + delimiter: str = ",", + encoding: str = "utf-8", + quotechar: str = '"', +) -> Tuple[pd.DataFrame, TextFileMetadata]: + """Read an entire delimited text file into memory. + + Returns ``(DataFrame, TextFileMetadata)`` — the metadata object carries + the same attributes that :func:`infer_schema` reads from pyreadstat + metadata. + """ + path = Path(path) + df = pd.read_csv( + path, + delimiter=delimiter, + encoding=encoding, + quotechar=quotechar, + dtype=str, + keep_default_na=True, + na_values=[""], + ) + meta = _build_text_metadata(list(df.columns), number_rows=len(df)) + return df, meta + + +def read_text_preview( + path: Path, + delimiter: str = ",", + encoding: str = "utf-8", + quotechar: str = '"', + rows: Optional[int] = None, +) -> Tuple[pd.DataFrame, TextFileMetadata]: + """Read the first ``rows`` records from a delimited text file. + + When ``rows`` is ``None`` or 0, reads the entire file (matching the + semantics of :func:`read_sas_preview`). + """ + path = Path(path) + nrows = int(rows) if rows else None + df = pd.read_csv( + path, + delimiter=delimiter, + encoding=encoding, + quotechar=quotechar, + nrows=nrows, + dtype=str, + keep_default_na=True, + na_values=[""], + ) + # For total row count, do a fast line count when we only read a preview. + if nrows is not None and nrows > 0: + total = _count_text_lines(path, encoding) + else: + total = len(df) + meta = _build_text_metadata(list(df.columns), number_rows=total) + return df, meta + + +def read_text_metadata( + path: Path, + delimiter: str = ",", + encoding: str = "utf-8", + quotechar: str = '"', +) -> TextFileMetadata: + """Read only the header and line count from a delimited text file. + + Fast path: reads the first line for column names and counts newlines + for the row total without materializing a DataFrame. + """ + path = Path(path) + # Read just the header row. + df_header = pd.read_csv( + path, + delimiter=delimiter, + encoding=encoding, + quotechar=quotechar, + nrows=0, + ) + column_names = list(df_header.columns) + total = _count_text_lines(path, encoding) + return _build_text_metadata(column_names, number_rows=total) + + +def iter_text_chunks( + path: Path, + delimiter: str = ",", + encoding: str = "utf-8", + quotechar: str = '"', + chunksize: Optional[int] = None, +): + """Yield ``(df_chunk, meta)`` tuples for streaming text file loads. + + Uses ``pandas.read_csv()`` with ``chunksize`` for memory-efficient + iteration. The metadata object is rebuilt for each chunk with the + chunk's column names and ``number_rows`` set to the total file rows + (computed once up front). + """ + path = Path(path) + if chunksize is None: + raw_env = os.environ.get("GENERIC_LOADER_CHUNK_ROWS") + if raw_env is not None: + try: + chunksize = int(raw_env) + except ValueError: + chunksize = DEFAULT_CHUNK_ROWS + else: + chunksize = DEFAULT_CHUNK_ROWS + + total = _count_text_lines(path, encoding) + + reader = pd.read_csv( + path, + delimiter=delimiter, + encoding=encoding, + quotechar=quotechar, + chunksize=chunksize, + dtype=str, + keep_default_na=True, + na_values=[""], + ) + for chunk_df in reader: + meta = _build_text_metadata(list(chunk_df.columns), number_rows=total) + yield chunk_df, meta + + +# --------------------------------------------------------------------------- +# Unified reader dispatch +# --------------------------------------------------------------------------- + + +def read_sas( + path: Path, + *, + delimiter: str = ",", + text_encoding: str = "utf-8", + quotechar: str = '"', +) -> Tuple[pd.DataFrame, Any]: + """Read an entire SAS or delimited text file into memory. + + For SAS files (``.sas7bdat``, ``.xpt``, ``.xport``), delegates to + pyreadstat. For text files (``.txt``, ``.csv``, ``.tsv``), delegates + to :func:`read_text`. The text-specific parameters are ignored for SAS + files. Kept for backward compatibility and tests; the CLI now uses :func:`read_sas_preview` + :func:`iter_sas_chunks` so it never materializes the whole frame at once. """ + if _is_text_file(path): + return read_text(path, delimiter=delimiter, encoding=text_encoding, quotechar=quotechar) reader, kwargs = _sas_reader(path) return reader(str(Path(path)), **kwargs) @@ -550,29 +880,98 @@ def read_sas_preview( path: Path, *, rows: Optional[int] = None, + delimiter: str = ",", + text_encoding: str = "utf-8", + quotechar: str = '"', ) -> Tuple[pd.DataFrame, Any]: """Read the first ``rows`` records from ``path`` plus its metadata. Defaults to ``TYPE_INFERENCE_SAMPLE_ROWS`` when ``rows`` is not given. Passing ``rows=None`` with ``TYPE_INFERENCE_SAMPLE_ROWS=None`` reads the whole file (pyreadstat treats ``row_limit=0`` as unlimited). + + For text files, delegates to :func:`read_text_preview`. """ - reader, kwargs = _sas_reader(path) effective = rows if rows is not None else TYPE_INFERENCE_SAMPLE_ROWS + if _is_text_file(path): + return read_text_preview( + path, + delimiter=delimiter, + encoding=text_encoding, + quotechar=quotechar, + rows=effective, + ) + reader, kwargs = _sas_reader(path) row_limit = int(effective) if effective else 0 return reader(str(Path(path)), row_limit=row_limit, **kwargs) +def read_sas_metadata( + path: Path, + *, + delimiter: str = ",", + text_encoding: str = "utf-8", + quotechar: str = '"', +) -> Any: + """Read only the metadata (no rows) from a SAS or text file. + + Uses pyreadstat's ``metadataonly=True`` fast path for SAS files: the + reader decodes the file header (column names, formats, total row count, + etc.) and returns without touching the data pages. Orders of magnitude + faster than :func:`read_sas_preview` when all you need is + ``meta.number_rows`` - typically a few ms per sas7bdat file, which + makes it cheap to pre-scan a whole folder to populate a global + progress bar. + + For text files, delegates to :func:`read_text_metadata`. + """ + if _is_text_file(path): + return read_text_metadata( + path, delimiter=delimiter, encoding=text_encoding, quotechar=quotechar, + ) + reader, kwargs = _sas_reader(path) + _, meta = reader(str(Path(path)), metadataonly=True, **kwargs) + return meta + + def iter_sas_chunks( path: Path, *, - chunksize: int = DEFAULT_CHUNK_ROWS, + chunksize: Optional[int] = None, + delimiter: str = ",", + text_encoding: str = "utf-8", + quotechar: str = '"', ): """Yield ``(df_chunk, meta)`` tuples for streaming loads. Thin wrapper over ``pyreadstat.read_file_in_chunks`` that picks the right underlying reader by extension and threads through our encoding defaults. + + When ``chunksize`` is ``None`` (the default), the effective value comes + from the ``GENERIC_LOADER_CHUNK_ROWS`` environment variable if set and + parseable, otherwise from :data:`DEFAULT_CHUNK_ROWS`. An explicit int + always wins. + + For text files, delegates to :func:`iter_text_chunks`. """ + if _is_text_file(path): + yield from iter_text_chunks( + path, + delimiter=delimiter, + encoding=text_encoding, + quotechar=quotechar, + chunksize=chunksize, + ) + return + if chunksize is None: + raw = os.environ.get("GENERIC_LOADER_CHUNK_ROWS") + if raw is not None: + try: + chunksize = int(raw) + except ValueError: + chunksize = DEFAULT_CHUNK_ROWS + else: + chunksize = DEFAULT_CHUNK_ROWS reader, kwargs = _sas_reader(path) yield from pyreadstat.read_file_in_chunks( reader, str(Path(path)), chunksize=chunksize, **kwargs @@ -590,7 +989,13 @@ def apply_column_filter( exclude: Optional[List[str]], ) -> pd.DataFrame: """Restrict ``df`` to the requested columns. Names missing from the frame - raise a clear error rather than silently dropping.""" + raise a clear error rather than silently dropping. + + Returns the input frame (or a column-sliced view / drop result) without + an extra ``.copy()`` — downstream (:func:`_prepare_for_copy`) reads the + frame into a freshly built output and never mutates its input, so the + copies were pure overhead on every streamed chunk. + """ if include is not None and exclude is not None: raise ValueError("include and exclude are mutually exclusive.") @@ -598,15 +1003,15 @@ def apply_column_filter( missing = [c for c in include if c not in df.columns] if missing: raise ValueError(f"include references unknown columns: {missing}") - return df.loc[:, list(include)].copy() + return df.loc[:, list(include)] if exclude is not None: missing = [c for c in exclude if c not in df.columns] if missing: raise ValueError(f"exclude references unknown columns: {missing}") - return df.drop(columns=list(exclude)).copy() + return df.drop(columns=list(exclude)) - return df.copy() + return df # --------------------------------------------------------------------------- @@ -634,6 +1039,126 @@ def _format_driven_type(sas_format: Optional[str]) -> Optional[str]: return None +_DECIMAL_FORMAT_RE = re.compile(r"\.(\d+)") + + +def _format_hints_decimal(sas_format: Optional[str]) -> bool: + """True if a numeric SAS format string explicitly carries decimal places. + + SAS numeric formats are ``NAMEw.d``; ``d > 0`` means the variable was + intended to render with ``d`` decimal digits (COMMA10.2, F8.3, ...). + A bare width like ``BEST12.`` or ``F8.`` has no digits after the dot + and is treated as integer-presenting. Used by + :func:`union_column_types` to pick BIGINT vs DOUBLE PRECISION when a + column is numeric in every file of a cluster. + """ + if not sas_format: + return False + m = _DECIMAL_FORMAT_RE.search(sas_format) + if not m: + return False + try: + return int(m.group(1)) > 0 + except ValueError: + return False + + +def extract_union_metadata( + meta: Any, +) -> Dict[str, Tuple[str, Optional[str]]]: + """Pull the (readstat_type, sas_format) pair for every column in ``meta``. + + Returns a plain dict that's safe to pass between processes and to + :func:`union_column_types`. ``readstat_type`` is the simplified type + reported by pyreadstat: ``"string"`` for SAS CHAR, ``"double"`` for + SAS NUM. ``sas_format`` comes from ``meta.original_variable_types`` + and drives date/datetime detection during union. + """ + var_types = dict(getattr(meta, "variable_types", None) or {}) + formats = dict(getattr(meta, "original_variable_types", None) or {}) + names = list( + getattr(meta, "column_names", None) + or list(var_types.keys()) + or list(formats.keys()) + ) + out: Dict[str, Tuple[str, Optional[str]]] = {} + for col in names: + rtype = str(var_types.get(col, "")) if var_types else "" + fmt = formats.get(col) + out[col] = (rtype, fmt if fmt else None) + return out + + +def union_column_types( + per_file_metas: Iterable[Dict[str, Tuple[str, Optional[str]]]], +) -> Dict[str, str]: + """Derive one Postgres type per column that's safe across every file. + + ``per_file_metas`` is an iterable (one entry per file in a cluster) of + ``{column_name: (readstat_type, sas_format)}`` dicts as produced by + :func:`extract_union_metadata`. + + Rules, evaluated per column: + + * **CHAR/NUM drift wins TEXT.** If any file stores the column as CHAR + (``readstat_type != "double"``) the union is ``TEXT``. This covers + the phone-id case where some years stored ``RESP_PH_PREFIX_ID`` as + CHAR and others as NUM. + * **All NUM, format hints DATETIME → TIMESTAMP.** Any file whose + format resolves to ``TIMESTAMP`` (via :func:`_format_driven_type`) + pins the column to ``TIMESTAMP`` even if other files left the + format blank. + * **All NUM, format hints DATE → DATE.** Same idea for date-only + formats. + * **All NUM, any decimal hint → DOUBLE PRECISION.** A ``w.d`` format + with ``d > 0`` in any file implies fractional values somewhere. + * **All NUM, no useful hint → DOUBLE PRECISION.** SAS numeric + formats are *display* formats, not storage constraints - a + ``BEST12.`` / ``F8.`` / blank-format column can still hold floats, + and pyreadstat hands back plain ``float64`` regardless. Defaulting + to ``DOUBLE PRECISION`` here costs the same 8 bytes as ``BIGINT`` + but can't fail on real data. For columns that truly are + integer-only and you want ``BIGINT`` semantics in queries, pin + them via a ``column_types`` override. + + Columns missing from a given file are simply skipped for that file; + the union is computed over whichever files *did* supply the column. + Columns that never appear anywhere are omitted from the result. + """ + per_col: Dict[str, List[Tuple[str, Optional[str]]]] = {} + for meta in per_file_metas: + for col, pair in meta.items(): + per_col.setdefault(col, []).append(pair) + + result: Dict[str, str] = {} + for col, entries in per_col.items(): + any_char = any( + rtype and rtype.lower() != "double" for rtype, _ in entries + ) + if any_char: + result[col] = "TEXT" + continue + formats = [fmt for _, fmt in entries if fmt] + driven = [_format_driven_type(f) for f in formats] + if "TIMESTAMP" in driven: + result[col] = "TIMESTAMP" + elif "DATE" in driven: + result[col] = "DATE" + else: + # Safe default: DOUBLE PRECISION. The BIGINT default we tried + # first failed the moment a file contained a fractional + # value in a column whose format didn't carry a decimal + # hint (very common: SAS ``BEST12.`` / ``F8.`` are display + # formats, not storage constraints, so the underlying + # 8-byte float can hold any value). Same storage cost as + # BIGINT, handles both integer- and float-valued data, and + # keeps loads from failing mid-cluster. Use a + # ``column_types`` override to pin specific columns to + # ``BIGINT`` when you want integer semantics in queries. + result[col] = "DOUBLE PRECISION" + return result + + def _all_null(series: pd.Series) -> bool: if pd.api.types.is_object_dtype(series): return bool(series.map(lambda v: v is None or (isinstance(v, str) and v == "") or (isinstance(v, float) and pd.isna(v))).all()) @@ -759,6 +1284,8 @@ def infer_schema( *, coerce_chars: bool = COERCE_CHAR_COLUMNS, total_rows: Optional[int] = None, + column_types: Optional[Dict[str, str]] = None, + force_nullable: bool = False, ) -> Dict[str, ColumnSpec]: """Infer a Postgres column spec for each column in ``df``. @@ -774,11 +1301,30 @@ def infer_schema( ``total_rows`` lets callers who already sampled the frame (e.g. via :func:`read_sas_preview`) report the real file size in the per-column "inferred from first N of M rows" note. Falls back to ``len(df)``. + + ``column_types`` is an optional map ``{column_name: pg_type_str}`` + whose entries bypass inference entirely - the caller has already + decided the type (e.g. via :func:`union_column_types` across a + cluster, or a YAML ``column_types`` override). Nullability is still + computed from the data. Columns in ``column_types`` that don't exist + in ``df`` are ignored so a shared override dict can apply to clusters + with different column sets. + + ``force_nullable=True`` stamps every column nullable regardless of + what the data sample shows. Escape hatch for when inference marks a + column ``NOT NULL`` because the sampled rows happened to be dense but + downstream files carry nulls in that column - common with cluster + loads where one file's preview can't speak for the rest. Cheaper than + trying to sharpen the sampler: widen the column and move on. """ original_formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {}) - # Row-walking type probes run on a bounded head slice; nullability and the - # all-null check still see every row so NOT NULL declarations stay honest. + # When ``TYPE_INFERENCE_SAMPLE_ROWS`` is an integer cap, row-walking type + # probes run on the head slice for speed; nullability and the all-null + # check still walk every row of ``df``. That's only honest when the + # caller handed us the full file - with the default cap of ``None`` the + # CLI does exactly that. Callers who pass a partial preview and a tight + # integer cap accept that ``NOT NULL`` can be wrong for rare-null columns. df_rows = len(df) effective_total = total_rows if total_rows is not None else df_rows if TYPE_INFERENCE_SAMPLE_ROWS is not None and df_rows > TYPE_INFERENCE_SAMPLE_ROWS: @@ -789,6 +1335,8 @@ def infer_schema( sample_size = df_rows sampled = sample_size < effective_total + overrides: Dict[str, str] = dict(column_types or {}) + # Temporarily flip the module-level flag if the caller asked us to. global COERCE_CHAR_COLUMNS saved = COERCE_CHAR_COLUMNS @@ -801,6 +1349,27 @@ def infer_schema( sas_format = original_formats.get(col) notes: List[str] = [] + if col in overrides: + pg_type = overrides[col] + notes.append( + f"type forced to {pg_type} via column_types override" + ) + if force_nullable: + nullable = True + notes.append("nullable forced via --all-nullable") + else: + nullable = _is_nullable(series) + out[col] = ColumnSpec( + name=col, + postgres_type=pg_type, + nullable=nullable, + sas_format=sas_format, + source_dtype=str(series.dtype), + notes=notes, + sampled=sampled, + ) + continue + pg_type = _format_driven_type(sas_format) if pg_type is None: @@ -831,7 +1400,11 @@ def infer_schema( f"{effective_total:,} rows" ) - nullable = _is_nullable(series) + if force_nullable: + nullable = True + notes.append("nullable forced via --all-nullable") + else: + nullable = _is_nullable(series) out[col] = ColumnSpec( name=col, @@ -964,6 +1537,30 @@ def _normalize_type(pg_type: str) -> str: return _TYPE_NORMALIZATION.get(stripped, stripped.lower()) +# Widening pairs: (inferred_from_source, existing_in_target). When the +# incoming spec is narrower than the target we accept it - the value is +# guaranteed to fit, and ``_prepare_for_copy`` already emits ``COPY`` +# payloads that Postgres silently promotes to the wider column type. The +# INVERSE direction stays a hard failure: a BIGINT value does not fit in +# an INTEGER column, so we must not let a cluster whose first file had +# only small ints accept a later file with a value past int32. Comes up +# most often on cluster loads where file 1 pushed the target to BIGINT +# (a single value > 2_147_483_647) and file N happens to sit entirely +# within int32 range - strict equality would reject file N even though +# the copy is trivially safe. +_WIDENING_COMPATIBLE: set = { + ("smallint", "integer"), + ("smallint", "bigint"), + ("integer", "bigint"), + ("real", "double precision"), + # INTEGER / BIGINT into DOUBLE PRECISION is lossless for int32 and + # exact up to 2**53 for int64, which covers every value pandas could + # have carried through as Int64 without wrapping anyway. + ("integer", "double precision"), + ("bigint", "double precision"), +} + + def _assert_schema_compatible( conn, schema: str, table: str, columns: Dict[str, ColumnSpec] ) -> None: @@ -990,11 +1587,22 @@ def _assert_schema_compatible( inferred_norm = _normalize_type(spec.postgres_type) target_norm = _normalize_type(target_type) if inferred_norm != target_norm: - mismatches.append( - f"column {name!r}: inferred {spec.postgres_type} " - f"(normalized {inferred_norm!r}) but target is {target_type} " - f"(normalized {target_norm!r})" - ) + if (inferred_norm, target_norm) in _WIDENING_COMPATIBLE: + # Narrower inferred type fits inside the wider target. + # Accept silently-but-noisily so the operator knows the + # file came in with a smaller range than the cluster's + # target was sized for. + warnings.append( + f"column {name!r}: inferred {spec.postgres_type} " + f"(narrower than target {target_type}); accepting - " + f"values fit in the wider target type" + ) + else: + mismatches.append( + f"column {name!r}: inferred {spec.postgres_type} " + f"(normalized {inferred_norm!r}) but target is {target_type} " + f"(normalized {target_norm!r})" + ) target_is_notnull = (target_nullable == "NO") if spec.nullable and target_is_notnull: warnings.append( @@ -1661,9 +2269,151 @@ def _seconds_to_time(v: Any) -> Optional[dt.time]: return dt.time(h, m, s) +# Safe outer bound for the numeric->datetime conversion below. The true +# ceiling is ``pd.Timestamp.max`` (2262-04-11), which in seconds since 1960 +# is ~9.52e9. We pick a much tighter bound - year ~2200, ~7.6e9 seconds, +# ~87600 days - because (a) any real SAS data past ~2100 is garbage anyway, +# and (b) staying well inside the float64 + datetime64[ns] windows gives +# pandas' internals zero room to trip the ``over="raise"`` they wrap +# around the ns-multiply. ``7.5e9 * 1e9 = 7.5e18``, comfortably under both +# ``int64.max`` (~9.22e18) and float64 overflow (~1.8e308). +_SAS_DATETIME_SAFE_S = 7_500_000_000 +_SAS_DATETIME_SAFE_D = 87_000 + + +def _safe_numeric_to_datetime( + series: pd.Series, + *, + unit: str, + column_name: str, + target_type: str, +) -> pd.Series: + """Convert a numeric SAS-epoch series to ``datetime64[ns]`` without letting + one stray cell take down the worker. + + Failure modes seen in production: + + * ``np.inf`` / ``-np.inf`` slipping through pyreadstat (SAS missing-value + sentinels, divide-by-zero in the source, uninitialized cells). + * Absurdly large finite floats (e.g. ``1.7e308``) where ``value * 1e9`` + overflows float64. + * Values between ``pd.Timestamp.max`` and float64 safety (~9.5e9 to 1e308 + seconds) where the nanosecond multiply silently produces garbage or + overflows int64. + + All of these trigger ``FloatingPointError: overflow encountered in multiply`` + inside ``pd.to_datetime`` because pandas wraps the multiply in + ``np.errstate(over="raise")`` -- our outer ``errors="coerce"`` never + gets a chance to turn the bad value into ``NaT``. + + Strategy, belt + suspenders + airbag: + + 1. Coerce to float64 up front. Object-dtype branches hand us mixed + int/float/str; ``pd.to_numeric(errors="coerce")`` parses what it can + and NaNs the rest, so we hit the rest of this function with a + pristine float series. + 2. Mask non-finite values and anything outside the safe epoch window to + NaN *before* ``pd.to_datetime`` sees them. + 3. Run the conversion under a permissive ``errstate``. + 4. If that still raises (some pandas version internally re-enables + ``over="raise"`` in a way ``errstate`` can't override), catch it + and return all-NaT for the column with a loud warning. Better a + NULL column in one chunk than a dead worker + no diagnostics. + + Emits one stderr line per chunk per affected column so silent data + loss doesn't sneak by. + """ + if not pd.api.types.is_float_dtype(series): + series = pd.to_numeric(series, errors="coerce").astype("float64") + + arr = series.to_numpy(dtype="float64", copy=False, na_value=np.nan) + if unit == "s": + bound = _SAS_DATETIME_SAFE_S + elif unit == "D": + bound = _SAS_DATETIME_SAFE_D + else: + bound = _SAS_DATETIME_SAFE_S + with np.errstate(over="ignore", invalid="ignore", divide="ignore"): + finite_mask = np.isfinite(arr) + # ``np.abs(inf) -> inf``, ``np.abs(nan) -> nan``; both compare False + # to ``bound``, so ``in_range_mask`` already excludes non-finite + # values. The explicit ``finite_mask &`` below is belt-and-suspenders + # in case a future numpy changes that semantic. + in_range_mask = np.abs(arr) < bound + keep_mask = finite_mask & in_range_mask + was_present = ~np.isnan(arr) + coerced = int(((~keep_mask) & was_present).sum()) + if coerced: + tqdm.write( + f"[warn] {target_type} column {column_name!r}: {coerced:,} " + f"row(s) had non-representable values (Inf/NaN/out-of-range), " + f"coerced to NULL", + file=sys.stderr, + ) + cleaned_arr = np.where(keep_mask, arr, np.nan) + cleaned = pd.Series(cleaned_arr, index=series.index) + try: + with np.errstate(over="ignore", invalid="ignore", divide="ignore"): + return pd.to_datetime( + cleaned, unit=unit, origin="1960-01-01", errors="coerce", + ) + except (FloatingPointError, OverflowError, ValueError) as exc: + tqdm.write( + f"[error] {target_type} column {column_name!r}: " + f"pd.to_datetime raised {type(exc).__name__}: {exc}; " + f"returning NaT for the entire chunk. This usually means one " + f"or more values slipped past the pre-mask (bound={bound}). " + f"Consider setting the column to TEXT via column_types if this " + f"recurs.", + file=sys.stderr, + ) + return pd.Series(pd.NaT, index=series.index, dtype="datetime64[ns]") + + +def _safe_object_to_datetime( + series: pd.Series, + *, + column_name: str, + target_type: str, +) -> pd.Series: + """Object-dtype to datetime. Shares the safety net (errstate + + try/except) with :func:`_safe_numeric_to_datetime`. If the column is + actually numeric-flavored (e.g. SAS wrote numbers into an object + column), route to the numeric path; otherwise parse with ``to_datetime`` + on the object itself. + """ + coerced = series.replace({"": None}) + numeric = pd.to_numeric(coerced, errors="coerce") + all_numeric = numeric.notna().sum() == coerced.notna().sum() + if all_numeric and coerced.notna().any(): + return _safe_numeric_to_datetime( + numeric, unit="s", column_name=column_name, target_type=target_type, + ) + try: + with np.errstate(over="ignore", invalid="ignore", divide="ignore"): + return pd.to_datetime(coerced, errors="coerce") + except (FloatingPointError, OverflowError, ValueError) as exc: + tqdm.write( + f"[error] {target_type} column {column_name!r}: " + f"pd.to_datetime raised {type(exc).__name__}: {exc}; " + f"returning NaT for the entire chunk.", + file=sys.stderr, + ) + return pd.Series(pd.NaT, index=series.index, dtype="datetime64[ns]") + + def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame: """Materialize a copy of ``df`` with each column in the right shape for ``to_csv`` so the CSV lands as valid input for the target Postgres type. + + Per-column conversions are vectorized (``.astype`` / ``pd.to_datetime`` / + ``.mask`` / ``.fillna``) instead of the element-wise ``.map(func)`` + loops this function used to run. That was the single largest per-chunk + CPU cost on text-heavy loads - a 40-column × 100k-row chunk was issuing + ~4M Python-level function calls just to cast strings. TIME columns are + still the ``.map`` path because SAS TIME8 is stored as seconds and the + clamp-to-24h logic doesn't fit cleanly in vector form; they're also + rare in practice. """ out = pd.DataFrame(index=df.index) for name, spec in columns.items(): @@ -1686,59 +2436,86 @@ def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.Da if pd.api.types.is_datetime64_any_dtype(series): out[name] = series.dt.date elif pd.api.types.is_object_dtype(series): - def _to_date(v: Any) -> Optional[dt.date]: - if v is None or (isinstance(v, float) and pd.isna(v)): - return None - if isinstance(v, dt.datetime): - return v.date() - if isinstance(v, dt.date): - return v - if isinstance(v, str): - if v == "": - return None - try: - return dt.date.fromisoformat(v) - except ValueError: - return None - return None - out[name] = series.map(_to_date) + # Vectorized parse: empty strings / None / unparseable -> NaT, + # then .dt.date yields date objects or NaT. NaT serializes as + # an empty CSV field (matching ``NULL ''`` in COPY). Routed + # through ``_safe_object_to_datetime`` so an object column + # that actually contains SAS-epoch numerics (seen when one + # file of a cluster stores the column as NUM and another as + # CHAR + the union flipped it to TEXT-then-DATE) can't trip + # the overflow-in-multiply bug. + parsed = _safe_object_to_datetime( + series, column_name=name, target_type="DATE", + ) + out[name] = parsed.dt.date + elif pd.api.types.is_numeric_dtype(series): + # pyreadstat couldn't decode the SAS format (some + # ``DATEw.``/``YYMMDDw.`` variants and all custom formats slip + # through) so the column came back as float64: days since + # 1960-01-01, the SAS epoch. Without this branch the raw + # number would hit COPY and Postgres rejects it with + # ``invalid input syntax for type date``. + parsed = _safe_numeric_to_datetime( + series, unit="D", column_name=name, target_type="DATE", + ) + out[name] = parsed.dt.date else: out[name] = series elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"): if pd.api.types.is_datetime64_any_dtype(series): out[name] = series elif pd.api.types.is_object_dtype(series): - def _to_dt(v: Any) -> Optional[dt.datetime]: - if v is None or (isinstance(v, float) and pd.isna(v)): - return None - if isinstance(v, dt.datetime): - return v - if isinstance(v, dt.date): - return dt.datetime(v.year, v.month, v.day) - if isinstance(v, pd.Timestamp): - return v.to_pydatetime() if not pd.isna(v) else None - if isinstance(v, str): - if v == "": - return None - try: - return dt.datetime.fromisoformat(v) - except ValueError: - return None - return None - out[name] = series.map(_to_dt) + # Same rationale as the DATE object branch above: route + # through the safety net so numeric-flavored object columns + # can't blow us up during the ns multiply. + out[name] = _safe_object_to_datetime( + series, column_name=name, target_type="TIMESTAMP", + ) + elif pd.api.types.is_numeric_dtype(series): + # Same story as the DATE branch above, but SAS datetimes are + # *seconds* since 1960-01-01 (fractional seconds for + # ``DATETIMEw.d``). Example caught in the wild: + # ``1915465463.615`` -> 2020-09-13 05:44:23.615. + out[name] = _safe_numeric_to_datetime( + series, unit="s", column_name=name, target_type="TIMESTAMP", + ) else: out[name] = series elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"): out[name] = series.map(_seconds_to_time) elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"): - # Leave empty strings as "" so `NULL ''` in COPY turns them into NULL. - def _to_str(v: Any) -> Any: - if v is None: - return "" - if isinstance(v, float) and pd.isna(v): - return "" - return str(v) - out[name] = series.map(_to_str) + # Render every cell as a string and blank out nulls. ``NULL ''`` + # in the COPY statement turns the blanks back into SQL NULL. + # astype(str) stringifies NaN/None to the literal "nan"/"None", + # so we mask those after the fact rather than branching per cell. + na_mask = series.isna() + if pd.api.types.is_numeric_dtype(series): + # Hit when a column was auto-unioned to TEXT because at + # least one file of the cluster stored it as CHAR but this + # particular file stored it as NUM (typical of SAS phone-id + # columns). Default float formatting would emit "123.0" - + # which doesn't match the plain "123" coming from the CHAR + # files. When the whole chunk is integer-valued, round to + # int before stringifying; when any fractional value is + # present we leave float formatting alone so we don't + # silently drop precision. + nonnull = series.dropna() + int_like = False + if not nonnull.empty: + try: + int_like = bool(((nonnull % 1) == 0).all()) + except TypeError: + int_like = False + if int_like: + # ``Int64`` preserves NA; ``.astype(str)`` renders NA + # as '', which we then mask out alongside original + # NaNs. + as_str = series.astype("Int64").astype(str) + out[name] = as_str.mask(na_mask, "") + else: + out[name] = series.astype(str).mask(na_mask, "") + else: + out[name] = series.astype(str).mask(na_mask, "") elif pg == "BOOLEAN": out[name] = series.astype("boolean") if series.dtype != object else series else: @@ -1746,6 +2523,25 @@ def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.Da return out +def _serialize_chunk_csv(prepared: pd.DataFrame) -> io.BytesIO: + """Serialize a prepared frame into a CSV buffer for ``COPY FROM STDIN``. + + Uses ``pyarrow.csv.write_csv`` (typically 5-10× faster than pandas' + pure-Python ``to_csv`` on wide/text-heavy frames). Null cells serialize + as empty strings and date/timestamp values land in ISO 8601 form, both + of which Postgres accepts under ``FORMAT csv, NULL ''``. + """ + table = pa.Table.from_pandas(prepared, preserve_index=False) + buf = io.BytesIO() + pa_csv.write_csv( + table, + buf, + write_options=pa_csv.WriteOptions(include_header=False), + ) + buf.seek(0) + return buf + + def copy_dataframes( conn, schema_name: str, @@ -1768,23 +2564,43 @@ def copy_dataframes( ) total = 0 + # Pull chunks one at a time so each ``df`` is unreferenced before the + # generator reads the next one. Without this the loop-variable binding + # of a ``for df in dfs:`` keeps the previous chunk alive during the + # next pyreadstat read, pushing peak memory to 5-6× chunk size per + # worker (old df + incoming df + prepared + pyarrow table + CSV buf). + # With explicit drops we cap peak at ~2× chunk size: ``df`` goes away + # once ``prepared`` exists, ``prepared`` once ``buf`` exists, ``buf`` + # once COPY has consumed it. Matters most in parallel mode where + # 32 × per-worker peak can exhaust a 128 GB host. + dfs_iter = iter(dfs) with conn.cursor() as cur: - for df in dfs: + while True: + try: + df = next(dfs_iter) + except StopIteration: + break if df.empty: + del df continue prepared = _prepare_for_copy(df, columns) - buf = io.StringIO() - prepared.to_csv( - buf, - index=False, - header=False, - na_rep="", - date_format="%Y-%m-%d %H:%M:%S", - ) - buf.seek(0) + del df + n = len(prepared) + buf = _serialize_chunk_csv(prepared) + del prepared cur.copy_expert(sql, buf) + del buf conn.commit() - total += len(prepared) + total += n + # Hand pyarrow's pool memory back between chunks. Without this, + # arrow's internal buffer pool keeps the high-water bytes + # reserved across the worker's lifetime - inside long-running + # workers this presents as steadily climbing RSS even with the + # ``del``s above. Cheap (microseconds); call it every chunk. + try: + pa.default_memory_pool().release_unused() + except Exception: + pass return total @@ -1875,7 +2691,7 @@ def validate_against_manifest( def _build_argparser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( - description="Load a single SAS file (XPT or sas7bdat) into Postgres.", + description="Load a single data file (SAS or delimited text) into Postgres.", ) p.add_argument("--config", required=True, type=Path, help="Path to YAML config") p.add_argument( @@ -1899,6 +2715,16 @@ def _build_argparser() -> argparse.ArgumentParser: "PGUSER / PGPASSWORD from the environment or .env file." ), ) + p.add_argument( + "--all-nullable", + action="store_true", + help=( + "Stamp every column nullable in the generated schema, bypassing " + "NOT NULL inference. Use when sampled rows wrongly suggest a " + "column has no nulls. Overrides ``all_nullable`` in the YAML " + "config when set." + ), + ) return p @@ -1918,16 +2744,35 @@ def main(argv: Optional[List[str]] = None) -> int: cfg = load_config(args.config) if not cfg.filename.exists(): - print(f"error: SAS file not found: {cfg.filename}", file=sys.stderr) + file_label = "text file" if cfg.file_type == "text" else "SAS file" + print(f"error: {file_label} not found: {cfg.filename}", file=sys.stderr) return 2 - # Schema inference uses a bounded preview read so we never load a - # hundreds-of-millions-of-rows file into memory just to pick types. - # NB: ``meta.number_rows`` on a ``row_limit``-ed read reflects rows - # returned, not the file's total, so we don't trust it here. - preview_df, meta = read_sas_preview(cfg.filename) + # Build kwargs dict for text-file parameters. These are passed through + # to the unified reader functions and silently ignored for SAS files. + _text_kw: Dict[str, Any] = dict( + delimiter=cfg.delimiter, + text_encoding=cfg.text_encoding, + quotechar=cfg.quotechar, + ) + + # Schema inference reads the whole file so type + nullability are + # computed against every row. That's what the target host has the + # resources for and is the only way to honestly emit ``NOT NULL`` - + # a bounded preview routinely missed the ~0.2% of rows with nulls on + # otherwise-dense keys (e.g. MAFID). If you're on a box that can't + # fit the file in memory, override ``TYPE_INFERENCE_SAMPLE_ROWS`` to + # an integer cap and know that sampled specs may stamp ``NOT NULL`` + # on columns whose nulls live past the window. + preview_df, meta = read_sas_preview(cfg.filename, **_text_kw) preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude) - columns = infer_schema(preview_df, meta) + force_nullable = args.all_nullable or cfg.all_nullable + columns = infer_schema( + preview_df, + meta, + column_types=cfg.column_types, + force_nullable=force_nullable, + ) # Validate partition columns exist in the schema after filtering. if cfg.partition_by: @@ -1968,7 +2813,7 @@ def main(argv: Optional[List[str]] = None) -> int: print(" discovering partition values (full file scan)...", file=sys.stderr) def _discovery_chunks(): - for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename): + for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename, **_text_kw): yield apply_column_filter(chunk_df, cfg.include, cfg.exclude) partition_values = discover_partition_values_chunked( @@ -2018,13 +2863,24 @@ def main(argv: Optional[List[str]] = None) -> int: # it while we're holding a Postgres transaction open. del preview_df + total_rows = getattr(meta, "number_rows", None) + def _filtered_chunks(): - seen = 0 - for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename): - chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude) - seen += len(chunk_df) - print(f" streaming... {seen:,} rows", file=sys.stderr) - yield chunk_df + pbar = tqdm( + total=total_rows, + unit="row", + unit_scale=True, + desc=f" {cfg.filename.name}", + file=sys.stderr, + dynamic_ncols=True, + ) + try: + for chunk_df, _chunk_meta in iter_sas_chunks(cfg.filename, **_text_kw): + chunk_df = apply_column_filter(chunk_df, cfg.include, cfg.exclude) + pbar.update(len(chunk_df)) + yield chunk_df + finally: + pbar.close() db_user = db_password = None if args.dbcreds: diff --git a/generic_loader/sample_config.yaml b/generic_loader/sample_config.yaml index c487769..38c1e6a 100644 --- a/generic_loader/sample_config.yaml +++ b/generic_loader/sample_config.yaml @@ -16,6 +16,23 @@ tablename: kitchensink # Defaults to fail. if_exists: append +# file_type: Type of data file to load. One of: sas | text. Default: sas. +# sas - SAS files (.sas7bdat, .xpt, .xport) read via pyreadstat +# text - Delimited text files (.txt, .csv, .tsv) read via pandas +# file_type: sas + +# delimiter: Column delimiter for text files. Only used when file_type: text. +# Accepts: "," (comma, default), "tab" or "\t" (tab), "pipe" or "|" (pipe), +# or any single character. +# delimiter: "," + +# text_encoding: Character encoding for text files. Default: utf-8. +# Common alternatives: latin-1, cp1252, iso-8859-1. +# text_encoding: utf-8 + +# quotechar: Quote character for text files. Default: '"' (double quote). +# quotechar: '"' + # partition_by: Partition the table by unique values of these columns. # Columns are applied in cascading order (first column = top-level partition). # Requires if_exists: replace or fail (not append for initial creation). @@ -38,3 +55,24 @@ if_exists: append # indexes: # - state # - zip + +# column_types: Explicit {column_name: postgres_type} overrides that +# bypass automatic type inference for the listed columns. Useful when +# pyreadstat reports a column as NUM but you want it stored as TEXT +# (phone/ID columns that are conceptually strings), or when a column's +# inferred type is off for any other reason. Columns not listed here +# fall through to the normal inference path. Nullability is always +# computed from the data. +# +# column_types: +# RESP_PH_PREFIX_ID: TEXT +# SOMELONG_ID: BIGINT + +# all_nullable: If true, every column is stamped nullable in the generated +# schema; NOT NULL inference is skipped entirely. Use this when the sampler +# wrongly concludes a column has no nulls (e.g. a dense sample followed by +# rare-null data downstream) and COPY blows up mid-load on the first null +# it hits. Off by default. The CLI flag --all-nullable overrides this to +# true when set. +# +# all_nullable: false diff --git a/generic_loader/sample_folder_config.yaml b/generic_loader/sample_folder_config.yaml index 5740c3f..909c57b 100644 --- a/generic_loader/sample_folder_config.yaml +++ b/generic_loader/sample_folder_config.yaml @@ -27,6 +27,25 @@ if_exists: replace # see the embedded-digit example near the bottom of this file. auto_detect: true +# file_type: Type of data files in this folder. One of: sas | text. Default: sas. +# sas - SAS files (.sas7bdat, .xpt, .xport) read via pyreadstat +# text - Delimited text files (.txt, .csv, .tsv) read via pandas +# When set to 'text', the folder scanner looks for .txt/.csv/.tsv files +# instead of .sas7bdat/.xpt/.xport files. +# file_type: sas + +# delimiter: Column delimiter for text files. Only used when file_type: text. +# Accepts: "," (comma, default), "tab" or "\t" (tab), "pipe" or "|" (pipe), +# or any single character. +# delimiter: "," + +# text_encoding: Character encoding for text files. Default: utf-8. +# Common alternatives: latin-1, cp1252, iso-8859-1. +# text_encoding: utf-8 + +# quotechar: Quote character for text files. Default: '"' (double quote). +# quotechar: '"' + # Folder-level column filter. Every file in every cluster passes through # this filter. `include` and `exclude` are mutually exclusive. A cluster can # override these via its own `include` / `exclude` keys. @@ -61,15 +80,52 @@ auto_detect: true # - state # - zip +# Folder-level column_types: Explicit {column_name: postgres_type} map that +# bypasses automatic type inference for the listed columns. Applied to +# every cluster unless a cluster supplies its own column_types, which are +# merged on top (cluster entries win on conflict). +# +# During --workers>1 runs the pre-scan derives a cluster-wide "auto-union" +# type per column (e.g. any file stores the column as CHAR -> TEXT; all +# NUM with any format hinting decimals -> DOUBLE PRECISION; otherwise +# BIGINT). Entries in column_types here win over that auto-union - use +# them when the auto result is wrong or when --no-prescan disables the +# auto-union and you still need to pin a column. +# +# Valid type strings are anything the CREATE TABLE DDL accepts (TEXT, +# INTEGER, BIGINT, DOUBLE PRECISION, DATE, TIMESTAMP, ...). Columns that +# don't exist in a given file are simply ignored for that file. +# +# column_types: +# RESP_PH_PREFIX_ID: TEXT +# RESP_PH_SUFFIX_ID: TEXT +# SOMELONG_ID: BIGINT + +# Folder-level all_nullable: If true, every column of every cluster is +# stamped nullable in the generated schema; NOT NULL inference is skipped +# entirely. Use this when the sampler wrongly concludes a column has no +# nulls (sampled rows happened to be dense, but later files in the cluster +# carry nulls) and COPY blows up mid-load. Inherited by all clusters +# unless a cluster supplies its own all_nullable. The CLI flag +# --all-nullable overrides both this and any per-cluster setting when +# passed. Off by default. +# +# all_nullable: false + # Explicit cluster patterns. Each pattern is matched against the file # *basename*. Files matched by a pattern are pulled out of the auto-detect # pool, so explicit and auto clusters compose cleanly. # -# `tablename` is required. `if_exists`, `include`, and `exclude` are -# optional per-cluster overrides of the folder-level defaults above. +# `tablename` is required. `if_exists`, `include`, `exclude`, and +# `column_types` are optional per-cluster overrides of the folder-level +# defaults above. Cluster-level column_types entries win over folder- +# level entries for the same column. clusters: - pattern: '^group_a\d+\.xpt$' tablename: group_a + # column_types: + # INTCOL: TEXT + # all_nullable: true # per-cluster override of the folder-level default # Example of an explicit override. Uncomment to force the group_b cluster to # append instead of replace even though the folder default is "replace": @@ -111,6 +167,10 @@ clusters: # - pattern: '^year2020_regionA_\d+_detail\.sas7bdat$' # tablename: year2020_regionA_detail + # Text file cluster example (when file_type: text): + # - pattern: '^data_group_a\d+\.txt$' + # tablename: data_group_a + # With only the group_a pattern explicit, auto_detect: true will still # bucket group_b1.xpt + group_b2.xpt into a "group_b" cluster and the lone # standalone.xpt into a "standalone" cluster. See generate_sample_folder.py diff --git a/requirements.txt b/requirements.txt index 3422064..c09f442 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,10 @@ pandas>=2.0,<3.0 pyreadstat>=1.2,<2.0 numpy>=2.1,<3.0 +pyarrow>=22.0,<24.0 pyyaml>=6.0,<7.0 psycopg2-binary>=2.9,<3.0 python-dotenv>=1.0,<2.0 boto3>=1.28,<2.0 +openpyxl>=3.1,<4.0 +tqdm>=4.66,<5.0 diff --git a/utils/data_explorer.py b/utils/data_explorer.py index bd25498..d06f740 100644 --- a/utils/data_explorer.py +++ b/utils/data_explorer.py @@ -1,30 +1,39 @@ """Explore S3 directories and categorise them by accessibility. Reads a text file containing one S3 prefix per line (paths within the bucket -configured by the ``S3_BUCKET`` constant), then for each prefix: +configured by the ``S3_BUCKET`` constant or ``--bucket`` CLI argument), then +for each prefix: + - Lists all objects recursively (via ``list_objects_v2`` paginator) -- **Only considers files matching the ``FILE_EXTENSION`` filter** (default - ``.sas7bdat``). All other file types are ignored. +- **Only considers files matching the configured extensions** (default: all + supported extensions — SAS and text). All other file types are ignored. - Tests read permission with ``head_object`` on the first matching file found - If the first file is accessible, tests ALL remaining files individually - Categorises the directory as **Available**, **Blocked**, **Empty**, and tracks individual file **Exceptions** within available directories +Supported file types +-------------------- +* **SAS files**: ``.sas7bdat``, ``.xpt``, ``.xport`` +* **Text / delimited files**: ``.txt``, ``.csv``, ``.tsv`` + A directory is considered *empty* if it contains no files matching the extension filter, even when other file types are present. -Configure the constants below, then run:: +Configure the constants below (or use CLI arguments), then run:: - python3 data_explorer.py + python3 data_explorer.py [OPTIONS] -Python 3.10+ compatible. Requires only ``boto3`` / ``botocore`` and stdlib. +Python 3.10+ compatible. Requires ``boto3`` / ``botocore`` and stdlib. """ from __future__ import annotations +import argparse +import os import sys from dataclasses import dataclass, field -from typing import List, Tuple +from typing import List, Set, Tuple # --------------------------------------------------------------------------- # Dependency check @@ -43,11 +52,25 @@ except ImportError: # --------------------------------------------------------------------------- -# Configuration — edit these before running +# Extension constants # --------------------------------------------------------------------------- -FILE_EXTENSION: str = ".sas7bdat" -"""Only files whose key ends with this extension (case-insensitive) are considered.""" +SAS_EXTENSIONS: Set[str] = {".sas7bdat", ".xpt", ".xport"} +"""File extensions recognised as SAS data files.""" + +TEXT_EXTENSIONS: Set[str] = {".txt", ".csv", ".tsv"} +"""File extensions recognised as delimited text / CSV files.""" + +SUPPORTED_EXTENSIONS: Set[str] = SAS_EXTENSIONS | TEXT_EXTENSIONS +"""Union of all file extensions this tool can work with.""" + + +# --------------------------------------------------------------------------- +# Configuration defaults — edit these or override via CLI arguments +# --------------------------------------------------------------------------- + +FILE_EXTENSIONS: Set[str] = SUPPORTED_EXTENSIONS +"""Set of extensions to filter on (case-insensitive). Defaults to all supported.""" INPUT_FILE: str = "s3_directories.txt" """Path to the text file containing one S3 prefix per line.""" @@ -58,6 +81,57 @@ S3_BUCKET: str = "my-bucket" AWS_PROFILE: str = "default" """AWS CLI profile name used for authentication.""" +# Text-file reading defaults (used when downloading / previewing text files) +DEFAULT_DELIMITER: str = "," +DEFAULT_ENCODING: str = "utf-8" +DEFAULT_QUOTECHAR: str = '"' + + +# --------------------------------------------------------------------------- +# Auto-detection helpers +# --------------------------------------------------------------------------- + + +def detect_file_type(filename: str) -> str: + """Return ``'sas'``, ``'text'``, or ``'unknown'`` based on *filename* extension. + + The check is case-insensitive. For ``.tsv`` files the caller should + default the delimiter to a tab character (``'\\t'``). + + Examples + -------- + >>> detect_file_type("data.sas7bdat") + 'sas' + >>> detect_file_type("report.CSV") + 'text' + >>> detect_file_type("archive.zip") + 'unknown' + """ + ext = os.path.splitext(filename)[1].lower() + if ext in SAS_EXTENSIONS: + return "sas" + if ext in TEXT_EXTENSIONS: + return "text" + return "unknown" + + +def default_delimiter_for(filename: str) -> str: + """Return a sensible default delimiter for *filename*. + + * ``.tsv`` → ``'\\t'`` + * everything else → ``','`` + """ + ext = os.path.splitext(filename)[1].lower() + if ext == ".tsv": + return "\t" + return "," + + +def matches_extensions(key: str, extensions: Set[str]) -> bool: + """Return ``True`` if *key* ends with any extension in *extensions* (case-insensitive).""" + key_lower = key.lower() + return any(key_lower.endswith(ext) for ext in extensions) + # --------------------------------------------------------------------------- # Data structures @@ -149,27 +223,36 @@ def format_size(size_bytes: int) -> str: return f"{size_bytes:,.1f} TB" +def extensions_label(extensions: Set[str]) -> str: + """Return a compact, sorted label for a set of extensions (e.g. ``.csv/.tsv/.txt``).""" + return "/".join(sorted(extensions)) + + def list_objects( s3_client: "botocore.client.S3", bucket: str, prefix: str, + extensions: Set[str] | None = None, ) -> Tuple[List[Tuple[str, int]], int]: """Recursively list all objects under *prefix*. - Only objects whose key ends with ``FILE_EXTENSION`` (case-insensitive) are - counted. All other files are silently skipped. + Only objects whose key ends with one of *extensions* (case-insensitive) are + counted. All other files are silently skipped. When *extensions* is + ``None`` the module-level ``FILE_EXTENSIONS`` set is used. Returns ``(files, total_size)`` where *files* is a list of ``(key, size)`` tuples for every matching object and *total_size* is the sum of their sizes in bytes. """ - ext_lower = FILE_EXTENSION.lower() + if extensions is None: + extensions = FILE_EXTENSIONS + exts_lower = {e.lower() for e in extensions} paginator = s3_client.get_paginator("list_objects_v2") files: List[Tuple[str, int]] = [] total_size: int = 0 for page in paginator.paginate(Bucket=bucket, Prefix=prefix): for obj in page.get("Contents", []): - if not obj["Key"].lower().endswith(ext_lower): + if not any(obj["Key"].lower().endswith(ext) for ext in exts_lower): continue files.append((obj["Key"], obj["Size"])) total_size += obj["Size"] @@ -196,8 +279,26 @@ def check_read_permission( # --------------------------------------------------------------------------- -def explore_directories(prefixes: List[str]) -> Results: - """Explore every prefix in ``S3_BUCKET`` and return categorised *Results*.""" +def explore_directories( + prefixes: List[str], + *, + extensions: Set[str] | None = None, +) -> Results: + """Explore every prefix in ``S3_BUCKET`` and return categorised *Results*. + + Parameters + ---------- + prefixes: + List of S3 key prefixes to explore. + extensions: + Set of file extensions to filter on. Defaults to the module-level + ``FILE_EXTENSIONS`` (which itself defaults to ``SUPPORTED_EXTENSIONS``). + """ + if extensions is None: + extensions = FILE_EXTENSIONS + exts_lower = {e.lower() for e in extensions} + ext_label = extensions_label(extensions) + session = boto3.Session(profile_name=AWS_PROFILE) s3 = session.client("s3") @@ -206,13 +307,13 @@ def explore_directories(prefixes: List[str]) -> Results: for idx, prefix in enumerate(prefixes, start=1): print( - f"[{idx}/{total}] Checking {prefix} (filtering for {FILE_EXTENSION}) ...", + f"[{idx}/{total}] Checking {prefix} (filtering for {ext_label}) ...", file=sys.stderr, ) # --- Recursive listing ------------------------------------------------ try: - files, total_size = list_objects(s3, S3_BUCKET, prefix) + files, total_size = list_objects(s3, S3_BUCKET, prefix, extensions=extensions) except botocore.exceptions.ClientError as exc: code = exc.response.get("Error", {}).get("Code", "Unknown") message = exc.response.get("Error", {}).get("Message", str(exc)) @@ -234,12 +335,13 @@ def explore_directories(prefixes: List[str]) -> Results: # --- Permission check on first file ----------------------------------- # Prefer a real object over a zero-byte directory marker (key ending - # in "/") for the head_object test. + # in "/") for the head_object test. The selected key must also match + # the extension filter. first_key, _ = files[0] test_key = first_key if first_key.endswith("/") and total_size > 0: for key, size in files: - if not (key.endswith("/") and size == 0): + if not (key.endswith("/") and size == 0) and matches_extensions(key, exts_lower): test_key = key break @@ -268,7 +370,7 @@ def explore_directories(prefixes: List[str]) -> Results: if remaining: if len(remaining) > 10: print( - f" Verifying access to {file_count} {FILE_EXTENSION} files in {prefix} ...", + f" Verifying access to {file_count} {ext_label} files in {prefix} ...", file=sys.stderr, ) @@ -306,11 +408,25 @@ def explore_directories(prefixes: List[str]) -> Results: # --------------------------------------------------------------------------- -def print_results(results: Results) -> None: - """Print a clean, human-readable summary to stdout.""" +def print_results(results: Results, *, extensions: Set[str] | None = None) -> None: + """Print a clean, human-readable summary to stdout. + + Parameters + ---------- + results: + The exploration results to display. + extensions: + The set of extensions that were used for filtering. Used only for + labelling in the output. Defaults to ``FILE_EXTENSIONS``. + """ + if extensions is None: + extensions = FILE_EXTENSIONS + ext_label = extensions_label(extensions) + print() print("=== S3 Directory Explorer Results ===") print(f"Bucket: {S3_BUCKET}") + print(f"Extensions: {ext_label}") # --- Available --- print() @@ -319,7 +435,7 @@ def print_results(results: Results) -> None: for d in results.available: print(f" {d.prefix}") print( - f" {FILE_EXTENSION} files: {d.accessible_count}/{d.total_count} accessible" + f" Matching files ({ext_label}): {d.accessible_count}/{d.total_count} accessible" f" | Total Size: {format_size(d.accessible_size)}" ) else: @@ -332,7 +448,7 @@ def print_results(results: Results) -> None: for d in results.blocked: if d.file_count: print(f" {d.prefix}") - print(f" {FILE_EXTENSION} files found: {d.file_count} | Error: {d.error}") + print(f" Matching files ({ext_label}) found: {d.file_count} | Error: {d.error}") else: print(f" {d.prefix}") print(f" Error: {d.error}") @@ -351,7 +467,7 @@ def print_results(results: Results) -> None: # --- Empty --- print() - print(f"--- Empty / no {FILE_EXTENSION} files ({len(results.empty)}) ---") + print(f"--- Empty / no matching files ({len(results.empty)}) ---") if results.empty: for d in results.empty: print(f" {d.prefix}") @@ -361,20 +477,163 @@ def print_results(results: Results) -> None: print() +# --------------------------------------------------------------------------- +# CLI argument parsing +# --------------------------------------------------------------------------- + + +def build_arg_parser() -> argparse.ArgumentParser: + """Build and return the CLI argument parser. + + Supports selecting file-type filters, text-file reading parameters, and + overriding the default bucket / profile / input-file settings. + """ + parser = argparse.ArgumentParser( + description=( + "Explore S3 directories and categorise them by accessibility. " + "Supports SAS files (.sas7bdat, .xpt, .xport) and delimited text " + "files (.txt, .csv, .tsv)." + ), + ) + + # --- File-type / extension selection --- + type_group = parser.add_argument_group("File-type selection") + type_group.add_argument( + "--file-type", + choices=["sas", "text", "all"], + default="all", + help=( + "Restrict the scan to a specific file type. " + "'sas' = .sas7bdat/.xpt/.xport only; " + "'text' = .txt/.csv/.tsv only; " + "'all' = both (default)." + ), + ) + type_group.add_argument( + "--extensions", + nargs="+", + metavar="EXT", + help=( + "Explicit list of extensions to filter on (e.g. --extensions .csv .tsv). " + "Overrides --file-type when provided." + ), + ) + + # --- Text-file reading parameters --- + text_group = parser.add_argument_group( + "Text-file parameters", + description=( + "Parameters used when reading delimited text files. These are " + "stored for downstream consumers and do not affect the S3 scan " + "itself." + ), + ) + text_group.add_argument( + "--delimiter", + default=None, + help=( + "Field delimiter for text files (default: ',' for .csv/.txt, " + "'\\t' for .tsv). Use 'tab' or '\\t' for a tab character." + ), + ) + text_group.add_argument( + "--encoding", + default=DEFAULT_ENCODING, + help=f"Character encoding for text files (default: {DEFAULT_ENCODING}).", + ) + text_group.add_argument( + "--quotechar", + default=DEFAULT_QUOTECHAR, + help=f"Quote character for text files (default: {DEFAULT_QUOTECHAR!r}).", + ) + + # --- S3 / general settings --- + s3_group = parser.add_argument_group("S3 settings") + s3_group.add_argument( + "--bucket", + default=None, + help=f"S3 bucket name (default: {S3_BUCKET}).", + ) + s3_group.add_argument( + "--profile", + default=None, + help=f"AWS CLI profile name (default: {AWS_PROFILE}).", + ) + s3_group.add_argument( + "--input-file", + default=None, + help=f"Path to the text file with S3 prefixes (default: {INPUT_FILE}).", + ) + + return parser + + +def resolve_extensions(args: argparse.Namespace) -> Set[str]: + """Determine the active extension set from parsed CLI *args*. + + If ``--extensions`` is provided it takes precedence. Otherwise + ``--file-type`` is used to select a predefined set. + """ + if args.extensions: + # Normalise: ensure each extension starts with a dot and is lowercase + exts: Set[str] = set() + for ext in args.extensions: + ext = ext.strip().lower() + if not ext.startswith("."): + ext = "." + ext + exts.add(ext) + return exts + + if args.file_type == "sas": + return SAS_EXTENSIONS + if args.file_type == "text": + return TEXT_EXTENSIONS + return SUPPORTED_EXTENSIONS + + +def resolve_delimiter(args: argparse.Namespace) -> str: + """Return the effective delimiter from parsed CLI *args*. + + Handles the special values ``'tab'`` and ``'\\t'`` so users can specify a + tab character on the command line without shell-escaping issues. + """ + if args.delimiter is None: + return DEFAULT_DELIMITER + raw = args.delimiter + if raw.lower() in ("tab", "\\t"): + return "\t" + return raw + + # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- if __name__ == "__main__": - import os + parser = build_arg_parser() + args = parser.parse_args() + + # --- Apply CLI overrides to module-level config --------------------------- + if args.bucket: + S3_BUCKET = args.bucket + if args.profile: + AWS_PROFILE = args.profile + input_file = args.input_file if args.input_file else INPUT_FILE + + active_extensions = resolve_extensions(args) + FILE_EXTENSIONS = active_extensions + + delimiter = resolve_delimiter(args) + encoding = args.encoding + quotechar = args.quotechar # --- Read input file ------------------------------------------------------ - if not os.path.exists(INPUT_FILE): - print(f"ERROR: Input file not found: {INPUT_FILE}", file=sys.stderr) + if not os.path.exists(input_file): + print(f"ERROR: Input file not found: {input_file}", file=sys.stderr) sys.exit(1) try: - prefixes = read_input_file(INPUT_FILE) + prefixes = read_input_file(input_file) except Exception as exc: print(f"ERROR: Could not read input file: {exc}", file=sys.stderr) sys.exit(1) @@ -399,7 +658,17 @@ if __name__ == "__main__": print(f"ERROR: AWS profile validation failed: {exc}", file=sys.stderr) sys.exit(1) + # --- Print active configuration ------------------------------------------- + ext_label = extensions_label(active_extensions) + print(f"Bucket: {S3_BUCKET}", file=sys.stderr) + print(f"Extensions: {ext_label}", file=sys.stderr) + if active_extensions & TEXT_EXTENSIONS: + print( + f"Text opts: delimiter={delimiter!r} encoding={encoding!r} " + f"quotechar={quotechar!r}", + file=sys.stderr, + ) + # --- Explore -------------------------------------------------------------- - print(f"Bucket: {S3_BUCKET}", file=sys.stderr) - results = explore_directories(prefixes) - print_results(results) + results = explore_directories(prefixes, extensions=active_extensions) + print_results(results, extensions=active_extensions) diff --git a/utils/file_viewer.py b/utils/file_viewer.py index 0b3303d..6c6343d 100644 --- a/utils/file_viewer.py +++ b/utils/file_viewer.py @@ -1,15 +1,23 @@ -"""Standalone utility to download a .sas7bdat file from S3 and print a -column-level summary of the first 10 rows. +"""Standalone utility to download a SAS or delimited text file from S3 and +print a column-level summary of the first *N* rows. -Configure the four constants below, then run:: +Supported formats +----------------- +* **SAS** – ``.sas7bdat``, ``.xpt``, ``.xport`` (read via *pyreadstat*) +* **Text** – ``.csv``, ``.tsv``, ``.txt`` (read via *pandas.read_csv*) + +Configure the four constants below **or** use the CLI arguments, then run:: python3 file_viewer.py + python3 file_viewer.py --local path/to/file.csv + python3 file_viewer.py --local path/to/data.tsv --delimiter $'\\t' Python 3.14 compatible. """ from __future__ import annotations +import argparse import os import sys @@ -19,14 +27,28 @@ import pyreadstat # --------------------------------------------------------------------------- -# Configuration — edit these before running +# Supported file extensions +# --------------------------------------------------------------------------- + +SAS_EXTENSIONS: set[str] = {".sas7bdat", ".xpt", ".xport"} +"""File extensions recognised as SAS data files.""" + +TEXT_EXTENSIONS: set[str] = {".txt", ".csv", ".tsv"} +"""File extensions recognised as delimited text files.""" + +SUPPORTED_EXTENSIONS: set[str] = SAS_EXTENSIONS | TEXT_EXTENSIONS +"""Union of all supported file extensions.""" + + +# --------------------------------------------------------------------------- +# Configuration — edit these before running (or use CLI arguments) # --------------------------------------------------------------------------- S3_BUCKET: str = "my-bucket" """S3 bucket name.""" S3_KEY: str = "path/to/file.sas7bdat" -"""Object key (path) within the bucket to the .sas7bdat file.""" +"""Object key (path) within the bucket to a supported data file.""" LOCAL_FOLDER: str = "./downloads" """Local directory to download the file into.""" @@ -45,6 +67,8 @@ def _ensure_local_copy(bucket: str, key: str, local_path: str) -> None: If *local_path* exists and its size matches the S3 object's size, the download is skipped and a message is printed. + + Supports any file whose extension is in :data:`SUPPORTED_EXTENSIONS`. """ session = boto3.Session(profile_name=AWS_PROFILE) s3 = session.client("s3") @@ -69,12 +93,117 @@ def _ensure_local_copy(bucket: str, key: str, local_path: str) -> None: print("Download complete.") +# -- SAS readers ------------------------------------------------------------- + + def _read_sas_head(path: str, row_count: int = 10) -> pd.DataFrame: - """Read the first *row_count* rows of a .sas7bdat file.""" - df, _ = pyreadstat.read_sas7bdat(path, row_offset=0, row_limit=row_count) + """Read the first *row_count* rows of a SAS file (``.sas7bdat``, ``.xpt``, ``.xport``).""" + ext = os.path.splitext(path)[1].lower() + if ext == ".sas7bdat": + df, _ = pyreadstat.read_sas7bdat(path, row_offset=0, row_limit=row_count) + elif ext in {".xpt", ".xport"}: + df, _ = pyreadstat.read_xport(path, row_offset=0, row_limit=row_count) + else: + raise ValueError(f"Unsupported SAS extension: {ext}") return df +# -- Text readers ------------------------------------------------------------ + + +def _read_text_head( + path: str, + row_count: int = 10, + delimiter: str = ",", + encoding: str = "utf-8", + quotechar: str = '"', +) -> pd.DataFrame: + """Read the first *row_count* rows of a delimited text file. + + Parameters + ---------- + path : str + Path to the ``.csv``, ``.tsv``, or ``.txt`` file. + row_count : int, optional + Number of data rows to read (default ``10``). + delimiter : str, optional + Column delimiter (default ``","``). For ``.tsv`` files the caller + should pass ``"\\t"``. + encoding : str, optional + File encoding (default ``"utf-8"``). + quotechar : str, optional + Character used to quote fields (default ``'"'``). + """ + return pd.read_csv( + path, + sep=delimiter, + encoding=encoding, + quotechar=quotechar, + nrows=row_count, + ) + + +# -- Unified reader ---------------------------------------------------------- + + +def _read_head( + path: str, + row_count: int = 10, + delimiter: str | None = None, + encoding: str = "utf-8", + quotechar: str = '"', +) -> pd.DataFrame: + """Read the first *row_count* rows of a supported data file. + + Auto-detects the file type from its extension and delegates to the + appropriate reader. For ``.tsv`` files the delimiter defaults to tab + (``"\\t"``); for other text files it defaults to ``","``. + + Parameters + ---------- + path : str + Path to the data file. + row_count : int, optional + Number of data rows to read (default ``10``). + delimiter : str or None, optional + Column delimiter for text files. ``None`` means *auto-detect* + (tab for ``.tsv``, comma otherwise). + encoding : str, optional + Encoding for text files (default ``"utf-8"``). + quotechar : str, optional + Quote character for text files (default ``'"'``). + + Returns + ------- + pandas.DataFrame + """ + ext = os.path.splitext(path)[1].lower() + + if ext not in SUPPORTED_EXTENSIONS: + raise ValueError( + f"Unsupported file extension '{ext}'. " + f"Supported extensions: {sorted(SUPPORTED_EXTENSIONS)}" + ) + + if ext in SAS_EXTENSIONS: + return _read_sas_head(path, row_count=row_count) + + # --- Text file path --- + if delimiter is None: + delimiter = "\t" if ext == ".tsv" else "," + + return _read_text_head( + path, + row_count=row_count, + delimiter=delimiter, + encoding=encoding, + quotechar=quotechar, + ) + + +# -- Display ----------------------------------------------------------------- + + def _sample_values(series: pd.Series, n: int = 3) -> str: """Return up to *n* non-null sample values as a comma-separated string.""" non_null = series.dropna() @@ -114,26 +243,126 @@ def _print_summary(df: pd.DataFrame) -> None: print() +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def _build_parser() -> argparse.ArgumentParser: + """Build the argument parser for the file-viewer CLI.""" + parser = argparse.ArgumentParser( + description=( + "Download a SAS or delimited text file from S3 (or read a local " + "file) and print a column-level summary of the first N rows.\n\n" + "Supported extensions: " + + ", ".join(sorted(SUPPORTED_EXTENSIONS)) + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + source = parser.add_mutually_exclusive_group() + source.add_argument( + "--local", + metavar="FILE", + default=None, + help=( + "Path to a local data file to summarise (skips S3 download). " + "Supported extensions: " + + ", ".join(sorted(SUPPORTED_EXTENSIONS)) + ), + ) + source.add_argument( + "--s3-key", + metavar="KEY", + default=None, + help="Override the S3_KEY constant with this object key.", + ) + + parser.add_argument( + "--rows", + type=int, + default=10, + metavar="N", + help="Number of rows to read (default: 10).", + ) + + # Text-file-specific options + text_group = parser.add_argument_group( + "text file options", + "These options apply only to .csv / .tsv / .txt files.", + ) + text_group.add_argument( + "--delimiter", + default=None, + help=( + 'Column delimiter for text files (default: "," for .csv/.txt, ' + '"\\t" for .tsv). Use $\'\\t\' in the shell for a literal tab.' + ), + ) + text_group.add_argument( + "--encoding", + default="utf-8", + help='File encoding for text files (default: "utf-8").', + ) + text_group.add_argument( + "--quotechar", + default='"', + help='Quote character for text files (default: \'"\').', + ) + + return parser + + # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- if __name__ == "__main__": - # --- Download ----------------------------------------------------------- - os.makedirs(LOCAL_FOLDER, exist_ok=True) - local_filename = os.path.basename(S3_KEY) - local_path = os.path.join(LOCAL_FOLDER, local_filename) + parser = _build_parser() + args = parser.parse_args() - try: - _ensure_local_copy(S3_BUCKET, S3_KEY, local_path) - except Exception as exc: - print(f"S3 download error: {exc}", file=sys.stderr) - sys.exit(1) + if args.local: + # ---- Local file mode ----------------------------------------------- + local_path = args.local + ext = os.path.splitext(local_path)[1].lower() + if ext not in SUPPORTED_EXTENSIONS: + parser.error( + f"Unsupported file extension '{ext}'. " + f"Supported: {sorted(SUPPORTED_EXTENSIONS)}" + ) + if not os.path.isfile(local_path): + print(f"File not found: {local_path}", file=sys.stderr) + sys.exit(1) + else: + # ---- S3 download mode ---------------------------------------------- + s3_key = args.s3_key or S3_KEY + ext = os.path.splitext(s3_key)[1].lower() + if ext not in SUPPORTED_EXTENSIONS: + parser.error( + f"Unsupported file extension '{ext}' in S3 key. " + f"Supported: {sorted(SUPPORTED_EXTENSIONS)}" + ) - # --- Read & summarize --------------------------------------------------- + os.makedirs(LOCAL_FOLDER, exist_ok=True) + local_filename = os.path.basename(s3_key) + local_path = os.path.join(LOCAL_FOLDER, local_filename) + + try: + _ensure_local_copy(S3_BUCKET, s3_key, local_path) + except Exception as exc: + print(f"S3 download error: {exc}", file=sys.stderr) + sys.exit(1) + + # ---- Read & summarise -------------------------------------------------- try: - df = _read_sas_head(local_path, row_count=10) + df = _read_head( + local_path, + row_count=args.rows, + delimiter=args.delimiter, + encoding=args.encoding, + quotechar=args.quotechar, + ) except Exception as exc: print(f"File read error: {exc}", file=sys.stderr) sys.exit(2) diff --git a/utils/s3_download.py b/utils/s3_download.py index c987db5..d1a2f2e 100644 --- a/utils/s3_download.py +++ b/utils/s3_download.py @@ -5,6 +5,10 @@ under that prefix recursively, groups objects into *clusters* using the same explicit-pattern + auto-detect rules as ``load_folder.py``, and downloads each cluster's files into its own subfolder under a local destination root. +Supported file types: + * SAS data files: ``.sas7bdat``, ``.xpt``, ``.xport`` + * Delimited text files: ``.txt``, ``.csv``, ``.tsv`` + ------------------------------------------------------------------------------- USAGE ------------------------------------------------------------------------------- @@ -19,8 +23,9 @@ USAGE aws_profile: default # optional; default boto3 chain if omitted auto_detect: true # optional; default true - extensions: # optional; default sas7bdat/xpt/xport + extensions: # optional; default sas7bdat/xpt/xport/txt/csv/tsv - .sas7bdat + - .csv on_exists: skip # optional; skip | overwrite | error concurrency: 4 # optional; default 4 @@ -58,7 +63,8 @@ Exit codes: * Listing is recursive (no S3 ``Delimiter``). Regexes are matched against the *basename* of each key (the part after the last ``/``), so a nested object like ``census/2020/raw/nested/group_c1.sas7bdat`` is grouped by - ``group_c1.sas7bdat`` alone. + ``group_c1.sas7bdat`` alone. Text files (e.g. ``data.csv``) are handled + identically — the basename is extracted and matched the same way. * Explicit patterns are tried in order. A key matched by one pattern is removed from the pool before the next pattern runs. Overlap between patterns is flagged as an error at discovery time. @@ -97,7 +103,9 @@ import boto3 import yaml -DEFAULT_EXTENSIONS: Tuple[str, ...] = (".sas7bdat", ".xpt", ".xport") +SAS_EXTENSIONS: Tuple[str, ...] = (".sas7bdat", ".xpt", ".xport") +TEXT_EXTENSIONS: Tuple[str, ...] = (".txt", ".csv", ".tsv") +DEFAULT_EXTENSIONS: Tuple[str, ...] = SAS_EXTENSIONS + TEXT_EXTENSIONS VALID_ON_EXISTS: Tuple[str, ...] = ("skip", "overwrite", "error") DEFAULT_CONCURRENCY: int = 4 @@ -318,7 +326,12 @@ def build_s3_client(cfg: DownloadConfig): def list_s3_objects(s3_client, cfg: DownloadConfig) -> List[S3Object]: - """List all objects under ``cfg.prefix`` recursively, filtered by extension.""" + """List all objects under ``cfg.prefix`` recursively, filtered by extension. + + Supports SAS extensions (``.sas7bdat``, ``.xpt``, ``.xport``) and text + extensions (``.txt``, ``.csv``, ``.tsv``) — whichever are present in + ``cfg.extensions``. + """ paginator = s3_client.get_paginator("list_objects_v2") out: List[S3Object] = [] for page in paginator.paginate(Bucket=cfg.bucket, Prefix=cfg.prefix): @@ -584,8 +597,12 @@ def download_cluster( def _build_argparser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( description=( - "Download S3 objects under a prefix into a local folder, " - "grouping objects into clusters that each become one subfolder." + "Download S3 objects (SAS data files and/or delimited text files) " + "under a prefix into a local folder, grouping objects into " + "clusters that each become one subfolder. " + "Supported extensions: " + + ", ".join(DEFAULT_EXTENSIONS) + + "." ), ) p.add_argument( diff --git a/utils/sample_s3_download_config.yaml b/utils/sample_s3_download_config.yaml index 4ae715e..c48237a 100644 --- a/utils/sample_s3_download_config.yaml +++ b/utils/sample_s3_download_config.yaml @@ -52,11 +52,13 @@ local_folder: ./downloads auto_detect: true # Object extensions to consider. Anything else under the prefix is ignored. -# Default (when this key is omitted): .sas7bdat, .xpt, .xport (matches -# generic_loader/load_folder.py). +# Default (when this key is omitted): .sas7bdat, .xpt, .xport, .txt, .csv, .tsv # extensions: # - .sas7bdat # - .xpt +# - .txt +# - .csv +# - .tsv # --------------------------------------------------------------------------- # Optional: download behavior @@ -103,3 +105,7 @@ clusters: # # - pattern: '^year2020_regionA_\d+_detail\.sas7bdat$' # name: year2020_regionA_detail + + # Text file cluster example (when file_type: text): + # - pattern: '^data_group_a\d+\.txt$' + # name: data_group_a diff --git a/utils/sas_profiler.py b/utils/sas_profiler.py new file mode 100644 index 0000000..f676421 --- /dev/null +++ b/utils/sas_profiler.py @@ -0,0 +1,1274 @@ +"""Standalone utility that profiles a single local SAS or delimited text file +and writes an Excel report with drop, partition, and index candidates plus +type-inference warnings. + +Configure the constants below and run:: + + python3 utils/sas_profiler.py + +Or override any of them from the command line:: + + python3 utils/sas_profiler.py \ + --file ./data/mystate.sas7bdat \ + --out ./reports/mystate_profile.xlsx + + # Profile a CSV file with default comma delimiter: + python3 utils/sas_profiler.py --file ./data/myfile.csv + + # Profile a TSV file (tab delimiter auto-detected from extension): + python3 utils/sas_profiler.py --file ./data/myfile.tsv + + # Profile a pipe-delimited .txt file: + python3 utils/sas_profiler.py --file ./data/myfile.txt --delimiter '|' + +The report is a paste-ready companion to +``generic_loader/load_sas.py`` and ``generic_loader/load_folder.py``: the +"inferred Postgres type" column uses the loader's own ``infer_schema`` so the +drop / partition / index suggestions map one-to-one onto valid YAML config +entries for those scripts. + +Supported inputs: + +- SAS files: ``.sas7bdat`` / ``.xpt`` / ``.xport`` +- Delimited text files: ``.csv`` / ``.tsv`` / ``.txt`` (with headers) + +For text files, SAS-specific metadata (formats, labels) is not available; +those fields show "N/A" in the report. All other profiling (column names, +data types from pandas, value distributions, null counts, etc.) works +identically. + +Python 3.10+ compatible. +""" + +from __future__ import annotations + +import argparse +import collections +import datetime as dt +import math +import os +import re +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +# The loader lives in a sibling directory that is *not* a proper package +# (no __init__.py). Its own modules import each other by bare name, so we +# add the directory to sys.path before importing it here. +_REPO_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(_REPO_ROOT / "generic_loader")) + +import numpy as np # noqa: E402 +import pandas as pd # noqa: E402 +from openpyxl import Workbook # noqa: E402 +from openpyxl.styles import Alignment, Font, PatternFill # noqa: E402 +from openpyxl.utils import get_column_letter # noqa: E402 + +from load_sas import ( # noqa: E402 + NUMERIC_INT_RANGE, + ColumnSpec, + infer_schema, + iter_sas_chunks, + read_sas_preview, +) + + +# --------------------------------------------------------------------------- +# File extension constants +# --------------------------------------------------------------------------- + +TEXT_EXTENSIONS = {".txt", ".csv", ".tsv"} +"""File extensions recognised as delimited text files.""" + +SAS_EXTENSIONS = {".sas7bdat", ".xpt", ".xport"} +"""File extensions recognised as SAS data files.""" + +SUPPORTED_EXTENSIONS = SAS_EXTENSIONS | TEXT_EXTENSIONS +"""All file extensions the profiler can handle.""" + + +def _is_text_file(path: Path) -> bool: + """Return True if ``path`` has a recognised delimited-text extension.""" + return path.suffix.lower() in TEXT_EXTENSIONS + + +def _is_supported_file(path: Path) -> bool: + """Return True if ``path`` has any supported extension.""" + return path.suffix.lower() in SUPPORTED_EXTENSIONS + + +def _auto_delimiter(path: Path, explicit: Optional[str]) -> str: + """Return the effective delimiter for *path*. + + If the caller supplied an explicit delimiter, use it. Otherwise default + to ``"\\t"`` for ``.tsv`` files and ``","`` for everything else. + """ + if explicit is not None: + return explicit + if path.suffix.lower() == ".tsv": + return "\t" + return "," + + +# --------------------------------------------------------------------------- +# Configuration - edit these before running, or override via CLI flags +# --------------------------------------------------------------------------- + +SAS_PATH: str = "./generic_loader/samples/sample_kitchensink.xpt" +"""Local path to the file to profile. Accepts ``.sas7bdat``, ``.xpt``, +``.xport``, ``.csv``, ``.tsv``, or ``.txt``.""" + +OUTPUT_XLSX: str = "./sas_profile.xlsx" +"""Where to write the Excel report.""" + +HIGH_NULL_PCT: float = 95.0 +"""Columns whose null percentage meets or exceeds this threshold are flagged +as drop candidates.""" + +INDEX_UNIQUENESS_PCT: float = 95.0 +"""Columns whose distinct/non-null ratio meets or exceeds this threshold are +flagged as index candidates.""" + +PARTITION_MIN_FILL_PCT: float = 95.0 +"""Name-matched partition candidates must be non-null in at least this +fraction of rows.""" + +PRE_SHARDED_MAX_DISTINCT: int = 3 +"""A name-matched column with <= this many distinct values is treated as +pre-sharded ("this file is one slice; sibling files have the other values") +rather than as a ready-to-partition observed column.""" + +DISTINCT_CAP: int = 10_000 +"""Max size of the per-column distinct-value set. Exceeding this marks the +column as ``distinct_overflow`` and we report ">= CAP" in the xlsx.""" + +TOP_N_VALUES: int = 5 +"""Number of most-frequent values tracked per column.""" + +PREVIEW_ROWS_FOR_INFERENCE: int = 10_000 +"""Rows pulled from the file for the loader's schema inference. Matches +``load_sas.TYPE_INFERENCE_SAMPLE_ROWS`` so suggestions track the loader.""" + +PROFILE_CHUNK_ROWS: int = 5_000_000 +"""Rows per streaming chunk while profiling. Larger chunks amortize +pyreadstat / pandas overhead, and the profiler is typically run on a +beefy box (e.g. a 128 GB EC2) rather than a laptop, so the default is +set aggressively. + +Rough peak-memory estimate while a chunk is in flight: + + peak_bytes ~= chunksize * num_cols * ~50 bytes/cell * 2-3x + +(The 2-3x factor covers pyreadstat's read buffer + pandas frame +construction temporaries.) At 5M rows x 50 cols that's roughly 10-20 GB, +which is comfortable on a 128 GB host but would OOM a laptop. + +If you have lots of RAM and a very wide file, lower this; if you have a +narrow file and want max throughput, bump it higher with ``--chunksize`` +(the profiler will happily take 20M+ per chunk). If ``chunksize`` is +larger than the file, pyreadstat just hands back one chunk.""" + + +PARTITION_NAME_PATTERNS: Tuple[re.Pattern, ...] = ( + # ``state`` or ``state_code`` / ``statecode`` appearing as a full token + # anywhere in the column name. Uses underscore / start / end as token + # boundaries so we catch STATE, STATE_CODE, HOME_STATE, + # ADDR_LINE3_STATE, BIRTH_STATE_CODE, etc. without matching STATUS, + # ESTATE, INTERSTATE, or STATEWIDE. + re.compile(r"(?:^|_)state(?:_?code)?(?:_|$)", re.IGNORECASE), +) +"""Only columns whose name matches one of these patterns are ever considered +partition candidates. This deliberately ignores generic low-cardinality +signals (status flags, boolean columns, etc.) because in practice the only +useful partition key in this codebase is STATE. Add more patterns here if +that ever stops being true.""" + + +INDEX_NAME_PATTERNS: Tuple[re.Pattern, ...] = ( + re.compile(r"^id$", re.IGNORECASE), + re.compile(r"_id$", re.IGNORECASE), + re.compile(r"_key$", re.IGNORECASE), + re.compile(r"^pk_", re.IGNORECASE), +) +"""Name-bonus patterns for index-candidate ranking.""" + + +# --------------------------------------------------------------------------- +# Per-column streaming aggregator +# --------------------------------------------------------------------------- + + +@dataclass +class _ColumnStats: + """Accumulators updated chunk-by-chunk while streaming the file.""" + + name: str + n_total: int = 0 + n_null: int = 0 + n_empty_str: int = 0 + + distinct: set = field(default_factory=set) + distinct_overflow: bool = False + + top_counts: "collections.Counter[Any]" = field(default_factory=collections.Counter) + + min_val: Any = None + max_val: Any = None + + # Numeric running stats (Welford would be nicer but sum/sum-sq is plenty + # here for a "help me pick columns" report). + numeric_sum: float = 0.0 + numeric_sumsq: float = 0.0 + numeric_count: int = 0 + + # String byte-length stats (helps flag oversized TEXT columns). + str_max_bytes: int = 0 + str_sum_bytes: int = 0 + str_count: int = 0 + + samples: List[Any] = field(default_factory=list) + + def update(self, series: pd.Series) -> None: + """Fold one chunk's worth of this column into the accumulator. + + Implementation notes (this method is the dominant per-file cost): + + - All masks are vectorized - no ``Series.map(lambda ...)`` loops. + - Distinct tracking uses ``Series.value_counts`` so we iterate at + most once per *unique* value in the chunk (in C), not once per + row. + - Once ``distinct_overflow`` is set and ``top_counts`` is full, + subsequent chunks skip the value-counts pass entirely - we already + know the column is too varied to be a partition / drop candidate + and we already have the top-N. + """ + n = len(series) + if n == 0: + return + self.n_total += n + + is_object = pd.api.types.is_object_dtype(series) + is_numeric = pd.api.types.is_numeric_dtype(series) + is_datetime = pd.api.types.is_datetime64_any_dtype(series) + + if is_object: + # Vectorized equivalent of load_sas._char_missing_mask: treat + # None / NaN / empty string as missing. ``series == ""`` is + # False for non-string values so we don't need per-element type + # checks. + na_mask = series.isna() + empty_mask = (series == "") & ~na_mask + miss_mask = na_mask | empty_mask + self.n_empty_str += int(empty_mask.sum()) + else: + miss_mask = series.isna() + + self.n_null += int(miss_mask.sum()) + non_null = series[~miss_mask] if miss_mask.any() else series + if non_null.empty: + return + + # -- Numeric stats (C-level) --------------------------------------- + if is_numeric: + arr = non_null.to_numpy(dtype="float64", copy=False, na_value=np.nan) + # NaN-safe aggregates in one pass each (all C-level). + self.numeric_sum += float(np.nansum(arr)) + self.numeric_sumsq += float(np.nansum(arr * arr)) + self.numeric_count += int(arr.size) + cmin = float(np.nanmin(arr)) if arr.size else None + cmax = float(np.nanmax(arr)) if arr.size else None + if cmin is not None and (self.min_val is None or cmin < self.min_val): + self.min_val = cmin + if cmax is not None and (self.max_val is None or cmax > self.max_val): + self.max_val = cmax + + elif is_datetime: + cmin = non_null.min() + cmax = non_null.max() + if self.min_val is None or cmin < self.min_val: + self.min_val = cmin + if self.max_val is None or cmax > self.max_val: + self.max_val = cmax + + # -- String length stats via vectorized str.len -------------------- + # ``.str.len()`` is C-fast; for ASCII-dominated SAS data it matches + # UTF-8 byte length closely enough for the "oversized TEXT" flag. + if is_object: + lens = non_null.astype(str, copy=False).str.len() + lens = lens.dropna() + if not lens.empty: + bmax = int(lens.max()) + if bmax > self.str_max_bytes: + self.str_max_bytes = bmax + self.str_sum_bytes += int(lens.sum()) + self.str_count += int(lens.size) + + # -- Samples (tiny slice; free) ------------------------------------ + if len(self.samples) < 3: + needed = 3 - len(self.samples) + self.samples.extend(non_null.head(needed).tolist()) + + # -- Distinct / top_counts (vectorized via value_counts) ----------- + # Skip altogether once we're saturated: distinct is already known + # to be > DISTINCT_CAP and top_counts has its DISTINCT_CAP slots + # filled, so further value_counts calls can only bump existing + # keys - info we don't need for any of the classifiers. + top_full = len(self.top_counts) >= DISTINCT_CAP + if self.distinct_overflow and top_full: + return + + try: + vc = non_null.value_counts(sort=False) + except TypeError: + # Unhashable values (list/dict). Drop the column from both + # distinct and top-N tracking. + self.distinct_overflow = True + return + + if vc.empty: + return + + if not self.distinct_overflow: + # Only *new* values need to be considered for the distinct set. + for val in vc.index: + if val in self.distinct: + continue + if len(self.distinct) >= DISTINCT_CAP: + self.distinct_overflow = True + break + self.distinct.add(val) + + if not top_full: + # Bulk-merge known keys; cap adds for new keys. + for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()): + if val in self.top_counts: + self.top_counts[val] += int(count) + elif len(self.top_counts) < DISTINCT_CAP: + self.top_counts[val] = int(count) + # else: silently skip - we're past the cap. + else: + # Only existing keys can grow. + tc = self.top_counts + for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()): + if val in tc: + tc[val] += int(count) + + # -- Derived properties ------------------------------------------------ + + @property + def n_non_null(self) -> int: + return self.n_total - self.n_null + + @property + def null_pct(self) -> float: + if self.n_total == 0: + return 0.0 + return 100.0 * self.n_null / self.n_total + + @property + def fill_pct(self) -> float: + return 100.0 - self.null_pct + + @property + def distinct_count(self) -> int: + return len(self.distinct) + + @property + def distinct_display(self) -> str: + if self.distinct_overflow: + return f">= {DISTINCT_CAP:,}" + return f"{self.distinct_count:,}" + + @property + def mean(self) -> Optional[float]: + if self.numeric_count == 0: + return None + return self.numeric_sum / self.numeric_count + + @property + def std(self) -> Optional[float]: + if self.numeric_count < 2: + return None + mean = self.mean + var = self.numeric_sumsq / self.numeric_count - (mean * mean) + # Guard against tiny negative from floating point noise. + if var < 0: + var = 0.0 + return math.sqrt(var) + + @property + def top_value(self) -> Tuple[Any, int]: + if not self.top_counts: + return (None, 0) + return self.top_counts.most_common(1)[0] + + def top_values(self, n: int = TOP_N_VALUES) -> List[Tuple[Any, int]]: + return self.top_counts.most_common(n) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _matches_any(patterns: Tuple[re.Pattern, ...], name: str) -> bool: + return any(p.search(name) for p in patterns) + + +def _format_size(n_bytes: int) -> str: + size = float(n_bytes) + for unit in ("B", "KB", "MB", "GB", "TB"): + if size < 1024.0 or unit == "TB": + return f"{size:,.1f} {unit}" + size /= 1024.0 + return f"{size:,.1f} TB" + + +def _format_value(val: Any) -> str: + """Render a single Python value for display in the spreadsheet.""" + if val is None: + return "" + if isinstance(val, float) and pd.isna(val): + return "" + if isinstance(val, (pd.Timestamp, dt.date, dt.datetime)): + return str(val) + return repr(val) if isinstance(val, str) else str(val) + + +def _format_top_values(pairs: List[Tuple[Any, int]]) -> str: + if not pairs: + return "" + return ", ".join(f"{_format_value(v)} ({c:,})" for v, c in pairs) + + +def _format_samples(samples: List[Any]) -> str: + if not samples: + return "(all null)" + return ", ".join(_format_value(v) for v in samples) + + +# --------------------------------------------------------------------------- +# Streaming profile +# --------------------------------------------------------------------------- + + +def profile_file( + path: Path, + *, + chunksize: Optional[int] = None, + delimiter: Optional[str] = None, + encoding: str = "utf-8", + quotechar: str = '"', +) -> Tuple[Dict[str, _ColumnStats], Dict[str, ColumnSpec], Any, int, bool]: + """Stream ``path`` once, returning *(stats, columns, meta, total_rows, is_text)*. + + ``columns`` is the loader's inferred schema from the first + ``PREVIEW_ROWS_FOR_INFERENCE`` rows - identical to what ``load_sas`` + would use. ``stats`` are the full-file observations we add on top. + + ``is_text`` is True when the file was read as a delimited text file + rather than a SAS file. Callers can use this to adjust display (e.g. + showing "N/A" for SAS-specific metadata fields). + + For text files (``.csv``, ``.tsv``, ``.txt``), the ``delimiter``, + ``encoding``, and ``quotechar`` parameters control parsing. If + ``delimiter`` is ``None``, ``.tsv`` files default to tab and all + others default to comma. + """ + is_text = _is_text_file(path) + effective_delimiter = _auto_delimiter(path, delimiter) + + # Build kwargs for the loader's text-aware read functions. + text_kwargs: Dict[str, Any] = { + "delimiter": effective_delimiter, + "text_encoding": encoding, + "quotechar": quotechar, + } + + preview_df, meta = read_sas_preview( + path, rows=PREVIEW_ROWS_FOR_INFERENCE, **text_kwargs, + ) + total_rows_hint = getattr(meta, "number_rows", None) + columns = infer_schema(preview_df, meta, total_rows=total_rows_hint) + + stats: Dict[str, _ColumnStats] = { + name: _ColumnStats(name=name) for name in columns + } + + total_rows = 0 + effective_chunksize = chunksize if chunksize is not None else PROFILE_CHUNK_ROWS + chunk_kwargs: Dict[str, Any] = {"chunksize": effective_chunksize} + chunk_kwargs.update(text_kwargs) + # pyreadstat + pandas are both C-level; the per-chunk overhead we pay + # is dominated by the value_counts passes in _ColumnStats.update, so + # the profile runs O(total_rows) with a small constant. + import time + started_at = time.monotonic() + last_print_at = started_at + for chunk_df, _chunk_meta in iter_sas_chunks(path, **chunk_kwargs): + total_rows += len(chunk_df) + for name, cs in stats.items(): + if name not in chunk_df.columns: + continue + cs.update(chunk_df[name]) + now = time.monotonic() + # Throttle progress output to ~one line per 2 seconds so huge files + # don't spam stderr but small files still print at least once. + if now - last_print_at >= 2.0: + elapsed = now - started_at + rate = total_rows / elapsed if elapsed > 0 else 0.0 + print( + f" profiling... {total_rows:,} rows " + f"({rate:,.0f} rows/s)", + file=sys.stderr, + ) + last_print_at = now + + elapsed = time.monotonic() - started_at + rate = total_rows / elapsed if elapsed > 0 else 0.0 + print( + f" profiled {total_rows:,} rows in {elapsed:.1f}s " + f"({rate:,.0f} rows/s)", + file=sys.stderr, + ) + + return stats, columns, meta, total_rows, is_text + + +# --------------------------------------------------------------------------- +# Classifiers +# --------------------------------------------------------------------------- + + +@dataclass +class _DropCandidate: + name: str + reason: str + + +@dataclass +class _PartitionCandidate: + name: str + kind: str # "observed" or "pre_sharded" + distinct_count: int + fill_pct: float + top_values: str + observed_values_in_file: str + note: str + score: float + + +@dataclass +class _IndexCandidate: + name: str + uniqueness_pct: float + distinct_count: int + fill_pct: float + name_bonus: bool + note: str + score: float + + +@dataclass +class _TypeWarning: + name: str + severity: str # "info" | "warn" | "error" + message: str + + +def _is_constant_like(cs: _ColumnStats) -> bool: + """True when the column is effectively a single value (possibly with + a handful of nulls / empties mixed in).""" + if cs.n_non_null == 0: + return False + return cs.distinct_count == 1 and not cs.distinct_overflow + + +def classify( + stats: Dict[str, _ColumnStats], + columns: Dict[str, ColumnSpec], + *, + high_null_pct: float, + index_uniqueness_pct: float, + partition_min_fill_pct: float, + pre_sharded_max_distinct: int, +) -> Tuple[ + List[_DropCandidate], + List[_PartitionCandidate], + List[_IndexCandidate], + List[_TypeWarning], +]: + """Turn per-column stats + the loader's schema into four ranked lists. + + Partition candidates are restricted to columns whose name matches + :data:`PARTITION_NAME_PATTERNS` - in practice STATE / STATE_CODE. A + generic "low-cardinality = partition candidate" heuristic produces too + much noise for this codebase, so we only surface columns we're confident + about by name. + """ + + drops: List[_DropCandidate] = [] + partitions: List[_PartitionCandidate] = [] + indexes: List[_IndexCandidate] = [] + warnings: List[_TypeWarning] = [] + + # -- Partition candidates (name-matched only) -------------------------- + # Run this before the drop check so pre-sharded STATE columns don't get + # silently dropped for being "constant". + claimed_by_partition: set = set() + for name, cs in stats.items(): + if not _matches_any(PARTITION_NAME_PATTERNS, name): + continue + if cs.n_total == 0 or cs.n_non_null == 0: + continue + if cs.fill_pct < partition_min_fill_pct: + continue + + is_pre_sharded = ( + not cs.distinct_overflow + and cs.distinct_count <= pre_sharded_max_distinct + ) + kind = "pre_sharded" if is_pre_sharded else "observed" + observed = _format_top_values(cs.top_values(pre_sharded_max_distinct)) + + if is_pre_sharded: + note = ( + f"pre-sharded: this file only contains {cs.distinct_count} " + f"distinct value(s) ({observed}); keep the column and set " + "partition_by at the load_folder level so sibling files merge " + "into separate partitions of one table" + ) + else: + note = ( + f"observed {cs.distinct_display} distinct value(s) across " + f"{cs.fill_pct:.1f}% of rows; LIST partitioning will create " + "one child table per distinct value" + ) + + partitions.append( + _PartitionCandidate( + name=name, + kind=kind, + distinct_count=cs.distinct_count, + fill_pct=cs.fill_pct, + top_values=_format_top_values(cs.top_values()), + observed_values_in_file=observed, + note=note, + # Pre-sharded beats observed as the snippet's top pick. + score=(1_000_000.0 if is_pre_sharded else 500_000.0) + cs.fill_pct, + ) + ) + claimed_by_partition.add(name) + + partitions.sort(key=lambda p: p.score, reverse=True) + + # -- Drop candidates --------------------------------------------------- + for name, cs in stats.items(): + if name in claimed_by_partition: + continue + if cs.n_total == 0: + continue + + reason: Optional[str] = None + if cs.n_null == cs.n_total: + reason = "all-null" + elif ( + cs.n_non_null > 0 + and cs.distinct_count == 0 + and not cs.distinct_overflow + ): + # Non-null but nothing hashable captured - treat as opaque. + reason = "all-empty / unhashable" + elif cs.n_non_null == cs.n_empty_str and cs.n_empty_str > 0: + reason = "all-empty" + elif _is_constant_like(cs): + only_val = next(iter(cs.distinct)) + reason = f"constant={_format_value(only_val)}" + elif cs.null_pct >= high_null_pct: + reason = f"null_pct={cs.null_pct:.1f}%" + + if reason is not None: + drops.append(_DropCandidate(name=name, reason=reason)) + + dropped_names = {d.name for d in drops} + partition_names = {p.name for p in partitions} + + # -- Index candidates -------------------------------------------------- + for name, cs in stats.items(): + if name in dropped_names or name in partition_names: + continue + spec = columns.get(name) + if spec is None: + continue + if cs.n_non_null == 0: + continue + if cs.distinct_overflow: + # Super-high-cardinality → perfect candidate for an index. + uniqueness = 100.0 + distinct_count = DISTINCT_CAP # display sentinel + else: + uniqueness = 100.0 * cs.distinct_count / cs.n_non_null + distinct_count = cs.distinct_count + if uniqueness < index_uniqueness_pct: + continue + + name_bonus = _matches_any(INDEX_NAME_PATTERNS, name) + notes: List[str] = [] + if name_bonus: + notes.append("name matches INDEX_NAME_PATTERNS (ID/KEY-ish)") + if cs.distinct_overflow: + notes.append( + f"distinct tracking capped at {DISTINCT_CAP:,}; " + "treating as high-cardinality" + ) + + # Rank: name match dominates, then raw uniqueness, then fill. + score = (500_000.0 if name_bonus else 0.0) + uniqueness + cs.fill_pct / 100.0 + + indexes.append( + _IndexCandidate( + name=name, + uniqueness_pct=uniqueness, + distinct_count=distinct_count, + fill_pct=cs.fill_pct, + name_bonus=name_bonus, + note="; ".join(notes), + score=score, + ) + ) + + indexes.sort(key=lambda i: i.score, reverse=True) + + # -- Type warnings ----------------------------------------------------- + for name, cs in stats.items(): + spec = columns.get(name) + if spec is None: + continue + + # Re-surface whatever the loader's own inference already flagged in + # notes - these are genuinely useful for the user to see without + # having to dry-run the loader. + for note in spec.notes: + warnings.append( + _TypeWarning(name=name, severity="info", message=note) + ) + if spec.sampled: + warnings.append( + _TypeWarning( + name=name, + severity="info", + message=( + "loader inferred type from a bounded preview; " + "sampled=True" + ), + ) + ) + + pg_type = spec.postgres_type.upper() + + # Preview said NOT NULL but the full file has nulls - loader would + # have emitted NOT NULL and then choked on COPY. + if not spec.nullable and cs.n_null > 0: + warnings.append( + _TypeWarning( + name=name, + severity="error", + message=( + f"preview saw zero nulls (NOT NULL) but full file has " + f"{cs.n_null:,} null(s); COPY would fail under the " + "loader's inferred NOT NULL" + ), + ) + ) + + # INTEGER range check against the full-file observed min/max. + if pg_type == "INTEGER" and cs.numeric_count > 0: + lo, hi = NUMERIC_INT_RANGE + vmin = cs.min_val if cs.min_val is not None else 0 + vmax = cs.max_val if cs.max_val is not None else 0 + try: + if vmin < lo or vmax > hi: + warnings.append( + _TypeWarning( + name=name, + severity="error", + message=( + f"loader inferred INTEGER from the preview but " + f"full-file range [{vmin}, {vmax}] overflows " + f"int4 {NUMERIC_INT_RANGE}; BIGINT required" + ), + ) + ) + except TypeError: + pass + + # Preview said all-null (loader defaults to TEXT) but data exists. + was_all_null_preview = any( + "all-null column" in n for n in spec.notes + ) + if was_all_null_preview and cs.n_non_null > 0: + warnings.append( + _TypeWarning( + name=name, + severity="warn", + message=( + "preview was all-null so loader defaulted to TEXT, " + f"but full file has {cs.n_non_null:,} non-null " + "value(s); consider a tighter include/exclude or " + "re-inferring with TYPE_INFERENCE_SAMPLE_ROWS=None" + ), + ) + ) + + return drops, partitions, indexes, warnings + + +# --------------------------------------------------------------------------- +# YAML snippet +# --------------------------------------------------------------------------- + + +def render_yaml_snippet( + drops: List[_DropCandidate], + partitions: List[_PartitionCandidate], + indexes: List[_IndexCandidate], +) -> str: + """Produce a paste-ready YAML snippet for the loader config.""" + lines: List[str] = ["# Suggested additions to your load_sas.py / load_folder.py config"] + + if drops: + lines.append("exclude:") + for d in drops: + lines.append(f" - {d.name} # {d.reason}") + else: + lines.append("# (no drop candidates found)") + + lines.append("") + + if partitions: + top = partitions[0] + if top.kind == "pre_sharded": + lines.append( + f"# !! PRE-SHARDED: this file only contains " + f"{top.name} = {top.observed_values_in_file}." + ) + lines.append( + "# !! Keep the column in the schema and set partition_by at the " + "load_folder level" + ) + lines.append( + "# !! so sibling files merge into one table under separate " + "partitions." + ) + lines.append("partition_by:") + lines.append(f" - {top.name}") + else: + lines.append( + "# (no partition candidates found - no column matched " + "PARTITION_NAME_PATTERNS)" + ) + + lines.append("") + + if indexes: + lines.append("indexes:") + for i in indexes: + bonus = " (name match)" if i.name_bonus else "" + lines.append( + f" - {i.name} # uniqueness={i.uniqueness_pct:.1f}%{bonus}" + ) + else: + lines.append("# (no index candidates found)") + + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# XLSX writer +# --------------------------------------------------------------------------- + + +_HEADER_FONT = Font(bold=True, color="FFFFFF") +_HEADER_FILL = PatternFill("solid", fgColor="305496") +_WARN_FILL = PatternFill("solid", fgColor="FFE699") +_ERROR_FILL = PatternFill("solid", fgColor="F4B183") + + +def _write_header(ws, headers: List[str]) -> None: + for col_idx, label in enumerate(headers, start=1): + cell = ws.cell(row=1, column=col_idx, value=label) + cell.font = _HEADER_FONT + cell.fill = _HEADER_FILL + cell.alignment = Alignment(vertical="center") + ws.freeze_panes = "A2" + + +def _autosize(ws, *, max_width: int = 60) -> None: + for col_cells in ws.columns: + letter = get_column_letter(col_cells[0].column) + longest = 0 + for cell in col_cells: + if cell.value is None: + continue + text = str(cell.value) + # Only measure the first line so a long YAML cell doesn't push + # everything else ultra-wide. + longest = max(longest, min(len(text.split("\n", 1)[0]), max_width)) + ws.column_dimensions[letter].width = min(max(longest + 2, 10), max_width) + + +def _write_overview( + ws, + *, + path: Path, + size_bytes: int, + total_rows: int, + total_cols: int, + thresholds: Dict[str, Any], +) -> None: + ws.cell(row=1, column=1, value="Field").font = _HEADER_FONT + ws.cell(row=1, column=1).fill = _HEADER_FILL + ws.cell(row=1, column=2, value="Value").font = _HEADER_FONT + ws.cell(row=1, column=2).fill = _HEADER_FILL + ws.freeze_panes = "A2" + + rows = [ + ("File path", str(path)), + ("File size", _format_size(size_bytes)), + ("Extension", path.suffix.lower()), + ("Total rows", f"{total_rows:,}"), + ("Total columns", f"{total_cols:,}"), + ("Generated at", dt.datetime.now().isoformat(timespec="seconds")), + ] + for k, v in thresholds.items(): + rows.append((f"threshold: {k}", str(v))) + + for i, (k, v) in enumerate(rows, start=2): + ws.cell(row=i, column=1, value=k) + ws.cell(row=i, column=2, value=v) + + _autosize(ws) + + +def _write_columns( + ws, + stats: Dict[str, _ColumnStats], + columns: Dict[str, ColumnSpec], + *, + is_text: bool = False, +) -> None: + headers = [ + "column", "sas_format", "source_dtype", "inferred_postgres_type", + "nullable", "n_total", "n_null", "null_pct", "distinct_count", + "min", "max", "mean", "std", + "top_value", "top_count", + "max_str_bytes", "mean_str_bytes", + "sample_values", "notes", + ] + _write_header(ws, headers) + + for row_idx, (name, cs) in enumerate(stats.items(), start=2): + spec = columns.get(name) + top_val, top_count = cs.top_value + mean_bytes = (cs.str_sum_bytes / cs.str_count) if cs.str_count else None + # For text files, SAS-specific metadata is not available. + if is_text: + sas_format_display = "N/A" + else: + sas_format_display = spec.sas_format if spec else "" + values = [ + name, + sas_format_display, + spec.source_dtype if spec else "", + spec.postgres_type if spec else "", + "YES" if (spec and spec.nullable) else "NO", + cs.n_total, + cs.n_null, + round(cs.null_pct, 3), + cs.distinct_display, + _format_value(cs.min_val), + _format_value(cs.max_val), + round(cs.mean, 6) if cs.mean is not None else "", + round(cs.std, 6) if cs.std is not None else "", + _format_value(top_val), + top_count or "", + cs.str_max_bytes or "", + round(mean_bytes, 2) if mean_bytes is not None else "", + _format_samples(cs.samples), + "; ".join(spec.notes) if spec and spec.notes else "", + ] + for col_idx, v in enumerate(values, start=1): + ws.cell(row=row_idx, column=col_idx, value=v) + + _autosize(ws) + + +def _write_drop(ws, drops: List[_DropCandidate]) -> None: + headers = ["column", "reason"] + _write_header(ws, headers) + if not drops: + ws.cell(row=2, column=1, value="(no drop candidates)") + for i, d in enumerate(drops, start=2): + ws.cell(row=i, column=1, value=d.name) + ws.cell(row=i, column=2, value=d.reason) + _autosize(ws) + + +def _write_partition(ws, partitions: List[_PartitionCandidate]) -> None: + headers = [ + "rank", "column", "kind", "distinct_count", "fill_pct", + "observed_values_in_file", "top_values", "score", "note", + ] + _write_header(ws, headers) + if not partitions: + ws.cell(row=2, column=1, value="(no partition candidates)") + for rank, p in enumerate(partitions, start=1): + row = rank + 1 + ws.cell(row=row, column=1, value=rank) + ws.cell(row=row, column=2, value=p.name) + ws.cell(row=row, column=3, value=p.kind) + ws.cell(row=row, column=4, value=p.distinct_count) + ws.cell(row=row, column=5, value=round(p.fill_pct, 3)) + ws.cell(row=row, column=6, value=p.observed_values_in_file) + ws.cell(row=row, column=7, value=p.top_values) + ws.cell(row=row, column=8, value=round(p.score, 3)) + ws.cell(row=row, column=9, value=p.note) + if p.kind == "pre_sharded": + for col in range(1, len(headers) + 1): + ws.cell(row=row, column=col).fill = _WARN_FILL + _autosize(ws) + + +def _write_index(ws, indexes: List[_IndexCandidate]) -> None: + headers = [ + "rank", "column", "uniqueness_pct", "distinct_count", "fill_pct", + "name_bonus", "score", "note", + ] + _write_header(ws, headers) + if not indexes: + ws.cell(row=2, column=1, value="(no index candidates)") + for rank, i in enumerate(indexes, start=1): + row = rank + 1 + ws.cell(row=row, column=1, value=rank) + ws.cell(row=row, column=2, value=i.name) + ws.cell(row=row, column=3, value=round(i.uniqueness_pct, 3)) + ws.cell(row=row, column=4, value=i.distinct_count) + ws.cell(row=row, column=5, value=round(i.fill_pct, 3)) + ws.cell(row=row, column=6, value="YES" if i.name_bonus else "NO") + ws.cell(row=row, column=7, value=round(i.score, 3)) + ws.cell(row=row, column=8, value=i.note) + _autosize(ws) + + +def _write_warnings(ws, warnings: List[_TypeWarning]) -> None: + headers = ["column", "severity", "message"] + _write_header(ws, headers) + if not warnings: + ws.cell(row=2, column=1, value="(no type warnings)") + for i, w in enumerate(warnings, start=2): + ws.cell(row=i, column=1, value=w.name) + ws.cell(row=i, column=2, value=w.severity) + ws.cell(row=i, column=3, value=w.message) + fill = None + if w.severity == "error": + fill = _ERROR_FILL + elif w.severity == "warn": + fill = _WARN_FILL + if fill is not None: + for col in range(1, len(headers) + 1): + ws.cell(row=i, column=col).fill = fill + _autosize(ws) + + +def _write_yaml_sheet(ws, snippet: str) -> None: + ws.cell(row=1, column=1, value="YAML suggestion (paste into your loader config)").font = _HEADER_FONT + ws.cell(row=1, column=1).fill = _HEADER_FILL + cell = ws.cell(row=2, column=1, value=snippet) + cell.alignment = Alignment(wrap_text=True, vertical="top") + # Pick a comfy width for YAML; row height is auto when wrap_text is on. + ws.column_dimensions["A"].width = 100 + + +def write_report( + out_path: Path, + *, + path: Path, + size_bytes: int, + total_rows: int, + stats: Dict[str, _ColumnStats], + columns: Dict[str, ColumnSpec], + drops: List[_DropCandidate], + partitions: List[_PartitionCandidate], + indexes: List[_IndexCandidate], + warnings: List[_TypeWarning], + yaml_snippet: str, + thresholds: Dict[str, Any], + is_text: bool = False, +) -> None: + wb = Workbook() + ws = wb.active + ws.title = "Overview" + _write_overview( + ws, + path=path, + size_bytes=size_bytes, + total_rows=total_rows, + total_cols=len(columns), + thresholds=thresholds, + ) + _write_columns(wb.create_sheet("Columns"), stats, columns, is_text=is_text) + _write_drop(wb.create_sheet("Drop candidates"), drops) + _write_partition(wb.create_sheet("Partition candidates"), partitions) + _write_index(wb.create_sheet("Index candidates"), indexes) + _write_warnings(wb.create_sheet("Type warnings"), warnings) + _write_yaml_sheet(wb.create_sheet("YAML suggestion"), yaml_snippet) + wb.save(out_path) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def _build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + description=( + "Profile a local SAS or delimited text file and write an Excel " + "report with drop, partition_by, and index suggestions for " + "generic_loader/load_sas.py and load_folder.py.\n\n" + "Supported formats: .sas7bdat, .xpt, .xport (SAS); " + ".csv, .tsv, .txt (delimited text with headers)." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument( + "--file", type=Path, default=Path(SAS_PATH), + help=( + "Path to the file to profile. Accepts SAS files " + "(.sas7bdat/.xpt/.xport) and delimited text files " + f"(.csv/.tsv/.txt). Default: {SAS_PATH!r}." + ), + ) + p.add_argument("--out", type=Path, default=Path(OUTPUT_XLSX), + help=f"Where to write the .xlsx report (default: {OUTPUT_XLSX!r}).") + p.add_argument("--high-null-pct", type=float, default=HIGH_NULL_PCT, + help="Null percentage at/above which a column is a drop candidate.") + p.add_argument("--index-uniqueness-pct", type=float, default=INDEX_UNIQUENESS_PCT, + help="Uniqueness (distinct/non-null) at/above which a column is an index candidate.") + p.add_argument("--partition-min-fill-pct", type=float, default=PARTITION_MIN_FILL_PCT) + p.add_argument("--pre-sharded-max-distinct", type=int, default=PRE_SHARDED_MAX_DISTINCT) + p.add_argument( + "--chunksize", type=int, default=None, + help=( + "Rows per streaming read. Bigger chunks amortize pyreadstat / " + "pandas overhead (faster for huge files) but use more peak " + f"memory. Defaults to PROFILE_CHUNK_ROWS ({PROFILE_CHUNK_ROWS:,})." + ), + ) + # -- Text file options ------------------------------------------------- + p.add_argument( + "--delimiter", type=str, default=None, + help=( + "Column delimiter for text files (.csv/.tsv/.txt). " + "Defaults to tab for .tsv, comma for .csv/.txt. " + "Ignored for SAS files. Example: --delimiter '|'" + ), + ) + p.add_argument( + "--encoding", type=str, default="utf-8", + help=( + "Character encoding for text files (default: utf-8). " + "Ignored for SAS files." + ), + ) + p.add_argument( + "--quotechar", type=str, default='"', + help=( + 'Quote character for text files (default: \'"\'). ' + "Ignored for SAS files." + ), + ) + return p + + +def main(argv: Optional[List[str]] = None) -> int: + args = _build_argparser().parse_args(argv) + + path: Path = args.file + out_path: Path = args.out + + if not path.exists(): + print(f"error: file not found: {path}", file=sys.stderr) + return 2 + + if not _is_supported_file(path): + exts = ", ".join(sorted(SUPPORTED_EXTENSIONS)) + print( + f"error: unsupported extension {path.suffix!r}; " + f"expected one of: {exts}", + file=sys.stderr, + ) + return 2 + + file_kind = "text" if _is_text_file(path) else "SAS" + print(f"profiling {path} ({file_kind}) -> {out_path}", file=sys.stderr) + stats, columns, meta, total_rows, is_text = profile_file( + path, + chunksize=args.chunksize, + delimiter=args.delimiter, + encoding=args.encoding, + quotechar=args.quotechar, + ) + + drops, partitions, indexes, warnings = classify( + stats, columns, + high_null_pct=args.high_null_pct, + index_uniqueness_pct=args.index_uniqueness_pct, + partition_min_fill_pct=args.partition_min_fill_pct, + pre_sharded_max_distinct=args.pre_sharded_max_distinct, + ) + + yaml_snippet = render_yaml_snippet(drops, partitions, indexes) + + thresholds = { + "HIGH_NULL_PCT": args.high_null_pct, + "INDEX_UNIQUENESS_PCT": args.index_uniqueness_pct, + "PARTITION_MIN_FILL_PCT": args.partition_min_fill_pct, + "PRE_SHARDED_MAX_DISTINCT": args.pre_sharded_max_distinct, + "PARTITION_NAME_PATTERNS": ", ".join(p.pattern for p in PARTITION_NAME_PATTERNS), + "DISTINCT_CAP": DISTINCT_CAP, + "TOP_N_VALUES": TOP_N_VALUES, + "PREVIEW_ROWS_FOR_INFERENCE": PREVIEW_ROWS_FOR_INFERENCE, + } + + out_path.parent.mkdir(parents=True, exist_ok=True) + write_report( + out_path, + path=path, + size_bytes=os.path.getsize(path), + total_rows=total_rows, + stats=stats, + columns=columns, + drops=drops, + partitions=partitions, + indexes=indexes, + warnings=warnings, + yaml_snippet=yaml_snippet, + thresholds=thresholds, + is_text=is_text, + ) + + print( + f"wrote {out_path} ({len(stats)} columns, {total_rows:,} rows scanned)\n" + f" drops: {len(drops)}\n" + f" partitions: {len(partitions)}\n" + f" indexes: {len(indexes)}\n" + f" warnings: {len(warnings)}", + file=sys.stderr, + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main())