foxtrot/generic_loader/load_folder.py
2026-04-20 08:38:38 -05:00

577 lines
20 KiB
Python

"""Folder-level SAS-to-Postgres loader.
Wraps :mod:`load_sas` so an entire directory of SAS files can be ingested in
one invocation. A directory often contains several *clusters* of files that
share a schema (e.g. ``group_a1.sas7bdat``, ``group_a2.sas7bdat``, ...). Each
cluster becomes one Postgres table; files inside a cluster are appended to it.
-------------------------------------------------------------------------------
USAGE
-------------------------------------------------------------------------------
1. YAML config
--------------
::
folder: samples/folder_test # required; relative paths resolve against
# the config file's directory
schemaname: public # required
# Optional. One of: fail | replace | append. Default: fail.
# Applied to the first file of each cluster (subsequent files in the
# cluster always run through the append-mode compatibility check).
if_exists: fail
# Optional. Default: true. When true, files that don't match any explicit
# pattern below are grouped by their common prefix (trailing digits, and
# optional trailing separators, are stripped from each file stem).
auto_detect: true
# Optional. Columns to force-include or force-exclude across every file.
# include and exclude are mutually exclusive.
# include: [ID, INTCOL]
# exclude: [ALLNULL]
# Optional explicit cluster patterns. Each pattern is matched against the
# file *basename*. Matched files are pulled out of the auto-detect pool.
# Per-cluster if_exists/include/exclude override the folder-level defaults.
clusters:
- pattern: '^group_a\\d+\\.sas7bdat$'
tablename: group_a
- pattern: '^group_b\\d+\\.sas7bdat$'
tablename: group_b
if_exists: replace
2. Command-line interface
-------------------------
::
python load_folder.py --config folder_config.yaml [--dry-run] [--fail-fast]
[--dbcreds]
Flags:
--config PATH Required. Path to the YAML config above.
--dry-run Print the discovered clusters and the inferred CREATE
TABLE for each (schema from the first file of the
cluster). The database is never touched.
--fail-fast Abort the whole run on the first cluster failure.
Default is to log the failure, roll back that cluster's
uncommitted work, and keep going. Chunks that were already
committed for the failed cluster stay in its table.
--dbcreds Prompt interactively for the database username and
password instead of reading ``PGUSER`` / ``PGPASSWORD``
from the environment or ``.env`` file. The password
prompt does not echo. Has no effect with ``--dry-run``
(no connection is opened).
Exit codes:
0 - every cluster loaded successfully (or dry-run completed)
1 - at least one cluster failed (details on stderr)
2 - folder does not exist / contains no SAS files
3. Discovery rules
------------------
* Supported extensions: ``.sas7bdat``, ``.xpt``, ``.xport`` (matches
:mod:`load_sas`). The folder is not scanned recursively.
* Explicit patterns are tried in config order, and a file matched by one
pattern is removed from the pool before the next pattern runs. A file
matching two or more patterns is rejected with an error during cluster
discovery, before anything is loaded (such overlap is almost always a
config bug), so pattern order never silently decides ownership.
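* Auto-detect groups remaining files by ``re.sub(r'\\d+$', '', stem)`` with
any trailing ``_`` / ``-`` stripped afterward; if that leaves an empty
string, the full stem is used. A stem without trailing digits is its own
key, so ``foo.sas7bdat`` shares a cluster with ``foo1.sas7bdat`` (and with
``foo.xpt``); it is a singleton only when no other file maps to that key.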
4. Library usage
----------------
::
from load_folder import load_folder_config, discover_clusters, load_cluster
from load_sas import connect
cfg = load_folder_config("folder_config.yaml")
clusters = discover_clusters(cfg)
conn = connect()
try:
for cluster in clusters:
load_cluster(conn, cluster, cfg.schemaname)
conn.commit()
finally:
conn.close()
"""
from __future__ import annotations
import argparse
import getpass
import re
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import yaml
from dotenv import load_dotenv
from load_sas import (
VALID_IF_EXISTS,
apply_column_filter,
assert_schema_compatible,
connect,
copy_dataframes,
create_table,
infer_schema,
iter_sas_chunks,
read_sas_preview,
render_create_table,
)
# File suffixes treated as SAS data files when scanning a folder. Compared
# against the lowercased ``Path.suffix`` in _list_sas_files, and listed in the
# CLI's "no SAS files found" error message.
SAS_EXTENSIONS = (".sas7bdat", ".xpt", ".xport")
# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------
@dataclass
class ClusterSpec:
    """One target table plus the SAS files that will be loaded into it.

    Built by ``discover_clusters`` and consumed by ``load_cluster`` and the
    CLI. Field order is the positional constructor signature.
    """

    # Unqualified table name; the schema is passed separately to load_cluster.
    tablename: str
    # Files to load, in order. load_cluster creates the table from the first
    # file's inferred schema. Empty for an explicit pattern that matched nothing.
    files: List[Path]
    # Mode handed to create_table for the first file of the cluster.
    if_exists: str
    # Column filter handed to apply_column_filter (see _parse_columns_filter).
    include: Optional[List[str]]
    exclude: Optional[List[str]]
    source: str  # "explicit" or "auto"
    # Regex text for explicit clusters; None for auto-detected clusters.
    pattern: Optional[str] = None
@dataclass
class _ExplicitPattern:
    """Parsed form of a single ``clusters[*]`` YAML entry."""

    # Compiled regex; discover_clusters applies it with .search() to basenames.
    pattern: re.Pattern
    # The regex exactly as written in the config, kept for messages/reporting.
    raw_pattern: str
    tablename: str
    # Per-cluster overrides. None means "use the folder-level value".
    if_exists: Optional[str] = None
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
@dataclass
class FolderConfig:
    """Validated folder-level settings, as returned by ``load_folder_config``."""

    # Directory scanned (non-recursively) for SAS files.
    folder: Path
    # Schema in which every cluster's table is created.
    schemaname: str
    # Default if_exists mode for clusters that don't override it.
    if_exists: str = "fail"
    # Whether files not claimed by an explicit pattern are grouped by prefix.
    auto_detect: bool = True
    # Folder-level column filter; at most one of the two is set.
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
    # Explicit cluster patterns, in config order.
    explicit: List[_ExplicitPattern] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Config loading
# ---------------------------------------------------------------------------
def _validate_if_exists(value: Any, where: str) -> str:
    """Normalise an ``if_exists`` setting and check it against the allowed modes.

    Args:
        value: Raw config value; it is stringified and lowercased.
        where: Location prefix used in the error message.

    Returns:
        The lowercased mode string.

    Raises:
        ValueError: the normalised value is not in ``VALID_IF_EXISTS``.
    """
    mode = str(value).lower()
    if mode in VALID_IF_EXISTS:
        return mode
    raise ValueError(
        f"{where}: if_exists={value!r} is not one of {VALID_IF_EXISTS}"
    )
def _parse_columns_filter(
    raw: Dict[str, Any], where: str
) -> Tuple[Optional[List[str]], Optional[List[str]]]:
    """Extract the optional ``include`` / ``exclude`` column lists from ``raw``.

    A key that is absent or null yields ``None``. List elements are
    stringified.

    Raises:
        ValueError: both keys are set, or either one is not a list.
    """
    inc = raw.get("include")
    exc = raw.get("exclude")
    if inc is not None and exc is not None:
        raise ValueError(f"{where}: 'include' and 'exclude' are mutually exclusive.")

    # The include check runs before the exclude check.
    type_checks = (
        (inc, f"{where}: 'include' must be a list of column names."),
        (exc, f"{where}: 'exclude' must be a list of column names."),
    )
    for value, message in type_checks:
        if value is not None and not isinstance(value, list):
            raise ValueError(message)

    def _as_names(value):
        return None if value is None else [str(c) for c in value]

    return _as_names(inc), _as_names(exc)
def _parse_explicit_pattern(entry: Any, where: str) -> _ExplicitPattern:
    """Validate one ``clusters[i]`` config entry and compile its regex.

    Raises:
        ValueError: ``entry`` is not a mapping, lacks ``pattern`` or
            ``tablename``, has an invalid regex, or has an invalid
            ``if_exists`` / ``include`` / ``exclude`` value.
    """
    if not isinstance(entry, dict):
        raise ValueError(f"{where} must be a mapping.")
    if "pattern" not in entry or "tablename" not in entry:
        raise ValueError(f"{where} must include 'pattern' and 'tablename'.")
    raw_pat = str(entry["pattern"])
    try:
        compiled = re.compile(raw_pat)
    except re.error as e:
        raise ValueError(f"{where}: invalid regex {raw_pat!r}: {e}") from e
    # An absent key stays None so the folder-level default applies later.
    c_if_exists = (
        _validate_if_exists(entry["if_exists"], where)
        if "if_exists" in entry
        else None
    )
    c_include, c_exclude = _parse_columns_filter(entry, where)
    return _ExplicitPattern(
        pattern=compiled,
        raw_pattern=raw_pat,
        tablename=str(entry["tablename"]),
        if_exists=c_if_exists,
        include=c_include,
        exclude=c_exclude,
    )


def load_folder_config(path: Path) -> FolderConfig:
    """Parse and validate the folder-level YAML config at ``path``.

    The file is read with ``yaml.safe_load``. A relative ``folder`` is
    resolved against the config file's directory when that location exists.
    Otherwise the path is kept as written, which makes it relative to the
    current working directory.

    Raises:
        OSError: the config file cannot be opened.
        ValueError: the YAML is not a mapping, a required key is missing,
            or any value fails validation (including a non-boolean
            ``auto_detect``).
    """
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    if not isinstance(raw, dict):
        raise ValueError(f"Config at {path} must be a YAML mapping at the top level.")
    missing = [k for k in ("folder", "schemaname") if k not in raw]
    if missing:
        raise ValueError(f"Config {path} missing required keys: {', '.join(missing)}")

    folder = Path(raw["folder"])
    if not folder.is_absolute():
        candidate = (path.parent / folder).resolve()
        folder = candidate if candidate.exists() else folder

    schemaname = str(raw["schemaname"])
    if_exists = _validate_if_exists(raw.get("if_exists", "fail"), f"Config {path}")

    # bool() would turn a quoted "false" (or any non-empty string) into True,
    # so require an actual YAML boolean.
    auto_detect = raw.get("auto_detect", True)
    if not isinstance(auto_detect, bool):
        raise ValueError(
            f"Config {path}: 'auto_detect' must be true or false, "
            f"got {auto_detect!r}."
        )

    include, exclude = _parse_columns_filter(raw, f"Config {path}")

    clusters_raw = raw.get("clusters") or []
    if not isinstance(clusters_raw, list):
        raise ValueError(f"Config {path}: 'clusters' must be a list if present.")
    explicit: List[_ExplicitPattern] = [
        _parse_explicit_pattern(entry, f"Config {path} clusters[{i}]")
        for i, entry in enumerate(clusters_raw)
    ]

    return FolderConfig(
        folder=folder,
        schemaname=schemaname,
        if_exists=if_exists,
        auto_detect=auto_detect,
        include=include,
        exclude=exclude,
        explicit=explicit,
    )
# ---------------------------------------------------------------------------
# Cluster discovery
# ---------------------------------------------------------------------------
# The run of digits (if any) at the very end of a file stem.
_TRAILING_DIGIT_RE = re.compile(r"\d+$")


def _auto_prefix(stem: str) -> str:
    """Return the auto-detect cluster key for a file stem.

    Trailing digits are removed first, then any trailing ``_`` / ``-``
    characters. As a result ``group_a1``, ``group_a_2`` and ``group_a-3``
    all map to ``group_a``. When this leaves an empty string (e.g. ``2024``
    or ``_2``), the original stem is returned unchanged.
    """
    key = _TRAILING_DIGIT_RE.sub("", stem).rstrip("_-")
    return key if key else stem
def _list_sas_files(folder: Path) -> List[Path]:
    """Return the SAS data files directly inside ``folder``, sorted by path.

    The scan is non-recursive. An entry qualifies when it is a file and its
    lowercased suffix is one of ``SAS_EXTENSIONS``.
    """
    return [
        entry
        for entry in sorted(folder.iterdir())
        if entry.is_file() and entry.suffix.lower() in SAS_EXTENSIONS
    ]
def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
    """Enumerate ``cfg.folder`` and bucket its SAS files into ``ClusterSpec``s.

    The only filesystem access is listing ``cfg.folder``; no SAS file is
    opened.

    Explicit patterns are matched (``re.search`` on the basename) in config
    order. Each pattern always yields a ClusterSpec, with an empty ``files``
    list when nothing matched, so the CLI can report it. When
    ``cfg.auto_detect`` is true, the remaining files are grouped by
    :func:`_auto_prefix`.

    A cluster's column filter comes either entirely from its pattern (when
    the pattern sets ``include`` or ``exclude``) or entirely from the
    folder defaults. The two are never mixed, because ``include`` and
    ``exclude`` are mutually exclusive.

    Raises:
        FileNotFoundError: ``cfg.folder`` is missing or not a directory.
        ValueError: a file matches more than one explicit pattern.
    """
    if not cfg.folder.exists() or not cfg.folder.is_dir():
        raise FileNotFoundError(f"Folder not found or not a directory: {cfg.folder}")
    pool = _list_sas_files(cfg.folder)

    # Reject cross-pattern overlap up front. Once it is ruled out, config
    # order never silently decides which table a file lands in.
    for i, p_i in enumerate(cfg.explicit):
        for p_j in cfg.explicit[i + 1:]:
            for f in pool:
                if p_i.pattern.search(f.name) and p_j.pattern.search(f.name):
                    raise ValueError(
                        f"File {f.name!r} matches multiple explicit patterns: "
                        f"{p_i.raw_pattern!r} and {p_j.raw_pattern!r}"
                    )

    clusters: List[ClusterSpec] = []
    remaining = list(pool)
    for patt in cfg.explicit:
        matched = [f for f in remaining if patt.pattern.search(f.name)]
        if matched:
            taken = set(matched)
            remaining = [f for f in remaining if f not in taken]
        # Override the column filter as a pair. Mixing a cluster-level
        # exclude with a folder-level include would set both at once.
        if patt.include is not None or patt.exclude is not None:
            include, exclude = patt.include, patt.exclude
        else:
            include, exclude = cfg.include, cfg.exclude
        clusters.append(
            ClusterSpec(
                tablename=patt.tablename,
                files=sorted(matched),
                if_exists=patt.if_exists or cfg.if_exists,
                include=include,
                exclude=exclude,
                source="explicit",
                pattern=patt.raw_pattern,
            )
        )

    if cfg.auto_detect and remaining:
        buckets: Dict[str, List[Path]] = {}
        for f in remaining:
            buckets.setdefault(_auto_prefix(f.stem), []).append(f)
        for key in sorted(buckets):
            clusters.append(
                ClusterSpec(
                    tablename=key,
                    files=sorted(buckets[key]),
                    if_exists=cfg.if_exists,
                    include=cfg.include,
                    exclude=cfg.exclude,
                    source="auto",
                )
            )
    return clusters
# ---------------------------------------------------------------------------
# Per-cluster load
# ---------------------------------------------------------------------------
def _infer_cluster_schema(path: Path, include, exclude):
    """Infer the column definitions for one SAS file.

    The file is previewed with ``read_sas_preview``, then the column filter
    is applied. The result is passed to ``infer_schema``, together with
    ``meta.number_rows`` as ``total_rows`` (``None`` if the metadata has no
    such attribute).
    """
    preview, meta = read_sas_preview(path)
    return infer_schema(
        apply_column_filter(preview, include, exclude),
        meta,
        total_rows=getattr(meta, "number_rows", None),
    )
def load_cluster(conn, cluster: ClusterSpec, schemaname: str) -> int:
    """Load every file in ``cluster`` into one table and return the total rows.

    The first file's inferred schema is used to create the table, using
    ``cluster.if_exists``. Each later file must pass
    ``assert_schema_compatible`` (the same check ``if_exists=append`` runs)
    before it is streamed. An empty cluster loads nothing and returns 0.

    Commits happen per chunk inside :func:`load_sas.copy_dataframes`. If a
    file mid-cluster fails, earlier chunks stay committed, including chunks
    from earlier files in the cluster. Only the in-flight chunk is rolled
    back by :func:`main`.
    """
    loaded = 0
    for position, src in enumerate(cluster.files):
        cols = _infer_cluster_schema(src, cluster.include, cluster.exclude)
        if position == 0:
            create_table(
                conn, schemaname, cluster.tablename, cols, cluster.if_exists
            )
        else:
            # A type mismatch or missing column aborts the rest of the
            # cluster. Chunks already committed remain in the table.
            assert_schema_compatible(conn, schemaname, cluster.tablename, cols)
        loaded += _stream_file(
            conn, schemaname, cluster.tablename, src, cols,
            cluster.include, cluster.exclude,
        )
    return loaded
def _stream_file(
    conn,
    schemaname: str,
    tablename: str,
    path: Path,
    columns,
    include,
    exclude,
) -> int:
    """Stream one SAS file into ``schemaname.tablename`` via ``copy_dataframes``.

    Chunks from ``iter_sas_chunks`` are column-filtered lazily. After each
    chunk, the running row count is printed to stderr. The return value is
    whatever ``copy_dataframes`` returns; load_cluster sums it as rows loaded.
    """
    def _filtered_chunks():
        running = 0
        for frame, _meta in iter_sas_chunks(path):
            frame = apply_column_filter(frame, include, exclude)
            running += len(frame)
            print(f" {path.name}: streaming... {running:,} rows", file=sys.stderr)
            yield frame

    return copy_dataframes(conn, schemaname, tablename, _filtered_chunks(), columns)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _build_argparser() -> argparse.ArgumentParser:
    """Build the CLI parser: ``--config`` (required), ``--dry-run``,
    ``--fail-fast`` and ``--dbcreds``.

    The ``--fail-fast`` help states that only uncommitted work is rolled
    back, because rows are committed per chunk while a cluster loads.
    """
    p = argparse.ArgumentParser(
        description=(
            "Load every SAS file in a folder into Postgres, grouping files "
            "into clusters that each become one table."
        ),
    )
    p.add_argument("--config", required=True, type=Path, help="Path to YAML config")
    p.add_argument(
        "--dry-run",
        action="store_true",
        help=(
            "Print discovered clusters and the inferred CREATE TABLE for "
            "each; don't touch Postgres."
        ),
    )
    p.add_argument(
        "--fail-fast",
        action="store_true",
        help=(
            "Abort on the first cluster failure. Default is to roll back "
            "that cluster's uncommitted work and continue with the next one."
        ),
    )
    p.add_argument(
        "--dbcreds",
        action="store_true",
        help=(
            "Prompt for database username and password instead of reading "
            "PGUSER / PGPASSWORD from the environment or .env file."
        ),
    )
    return p
def _describe_cluster(cluster: ClusterSpec) -> str:
    """Render a two-line, human-readable summary of ``cluster`` for the CLI.

    The first line shows the table name, origin (plus regex for explicit
    clusters) and if_exists mode. The second line lists file basenames, or
    "(no matching files)" when the list renders empty.
    """
    origin = f"{cluster.source}"
    if cluster.pattern:
        origin = f"{origin} pattern={cluster.pattern!r}"
    listing = ", ".join(p.name for p in cluster.files)
    if not listing:
        listing = "(no matching files)"
    header = f"cluster {cluster.tablename!r} [{origin}] if_exists={cluster.if_exists}"
    return f"{header}\n files: {listing}"
def _print_dry_run(schemaname: str, loadable: List[ClusterSpec]) -> None:
    """Print the inferred CREATE TABLE for each cluster without touching Postgres.

    Each schema is inferred from the cluster's first file only.
    """
    print()
    for c in loadable:
        print(f"--- CREATE TABLE for cluster {c.tablename!r} ---")
        columns = _infer_cluster_schema(c.files[0], c.include, c.exclude)
        print(render_create_table(schemaname, c.tablename, columns))
        print()


def _prompt_db_credentials() -> Tuple[str, str]:
    """Interactively read a database username and a non-echoed password."""
    db_user = input("Database username: ")
    db_password = getpass.getpass("Database password: ")
    return db_user, db_password


def _load_clusters(
    conn, loadable: List[ClusterSpec], schemaname: str, fail_fast: bool
) -> Tuple[List[Tuple[str, int, int]], List[Tuple[str, Exception]], List[str]]:
    """Load each cluster, committing on success and rolling back on failure.

    Returns ``(totals, failures, skipped)``:

    - ``totals`` holds ``(tablename, file_count, rows)`` for each cluster
      that succeeded.
    - ``failures`` holds ``(tablename, exception)`` for each cluster that
      failed.
    - ``skipped`` lists the tablenames never attempted because
      ``fail_fast`` stopped the run.
    """
    totals: List[Tuple[str, int, int]] = []
    failures: List[Tuple[str, Exception]] = []
    skipped: List[str] = []
    for idx, cluster in enumerate(loadable):
        print(
            f"\n>>> loading cluster {cluster.tablename!r} "
            f"({len(cluster.files)} file(s))"
        )
        try:
            rows = load_cluster(conn, cluster, schemaname)
            conn.commit()
            totals.append((cluster.tablename, len(cluster.files), rows))
            print(
                f" -> loaded {rows:,} row(s) into "
                f"{schemaname}.{cluster.tablename}"
            )
        except Exception as e:
            # Top-level boundary: record the failure, discard uncommitted
            # work, and either stop or move on to the next cluster.
            conn.rollback()
            failures.append((cluster.tablename, e))
            print(
                f" !! cluster {cluster.tablename!r} failed: {e}",
                file=sys.stderr,
            )
            if fail_fast:
                skipped = [c.tablename for c in loadable[idx + 1:]]
                break
    return totals, failures, skipped


def _print_summary(
    totals: List[Tuple[str, int, int]],
    failures: List[Tuple[str, Exception]],
    skipped: List[str],
) -> None:
    """Print per-cluster outcomes. Successes go to stdout; failures and
    skips go to stderr."""
    print("\n=== summary ===")
    for name, fcount, rows in totals:
        print(f" ok {name}: {fcount} file(s), {rows:,} row(s)")
    for name, err in failures:
        print(f" FAIL {name}: {err}", file=sys.stderr)
    for name in skipped:
        print(f" SKIP {name}: not attempted (--fail-fast)", file=sys.stderr)


def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point. Returns the process exit code.

    Exit codes:

    - 0: every cluster loaded, or the dry run completed.
    - 1: at least one cluster failed.
    - 2: the folder is missing, or no cluster has any files.
    """
    args = _build_argparser().parse_args(argv)
    load_dotenv()
    cfg = load_folder_config(args.config)
    if not cfg.folder.exists() or not cfg.folder.is_dir():
        print(f"error: folder not found: {cfg.folder}", file=sys.stderr)
        return 2

    clusters = discover_clusters(cfg)
    loadable = [c for c in clusters if c.files]
    if not loadable:
        print(
            f"error: no SAS files found in {cfg.folder} "
            f"(looked for {', '.join(SAS_EXTENSIONS)})",
            file=sys.stderr,
        )
        return 2

    print(f"discovered {len(loadable)} cluster(s) in {cfg.folder}:")
    for c in clusters:
        print(_describe_cluster(c))

    if args.dry_run:
        _print_dry_run(cfg.schemaname, loadable)
        return 0

    db_user = db_password = None
    if args.dbcreds:
        db_user, db_password = _prompt_db_credentials()
    conn = connect(user=db_user, password=db_password)
    conn.autocommit = False
    try:
        totals, failures, skipped = _load_clusters(
            conn, loadable, cfg.schemaname, args.fail_fast
        )
    finally:
        conn.close()

    _print_summary(totals, failures, skipped)
    return 1 if failures else 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())