Add column type overrides in load_folder.py and load_sas.py for enhanced schema control

Implemented a new feature allowing users to specify explicit column type mappings via a `column_types` configuration in both `load_folder.py` and `load_sas.py`. This addition enables users to bypass automatic type inference for specific columns, ensuring correct data types are used when loading datasets. Updated the YAML configuration files to include examples of the new `column_types` option, enhancing usability and flexibility in handling varying data formats across files.
This commit is contained in:
David Peterson 2026-04-21 12:14:44 -05:00
parent 0c5e6e31f0
commit ae65140390
4 changed files with 446 additions and 29 deletions

View File

@ -32,6 +32,17 @@ USAGE
# include: [ID, INTCOL]
# exclude: [ALLNULL]
# Optional folder default for explicit column type overrides. These
# win over the cluster-wide auto-union computed during pre-scan; set
# them when a column's SAS-level type varies across files (e.g. phone
# IDs stored as CHAR in some years and NUM in others) and you want to
# pin the Postgres type yourself rather than accept the auto-derived
# one. Per-cluster column_types inside each clusters[*] entry are
# merged on top of this map.
# column_types:
# RESP_PH_PREFIX_ID: TEXT
# SOME_BIGINT_COL: BIGINT
# Optional folder default for LIST partitioning. Omit or set [] for no
# partitioning. Accepts a single string or a list of column names.
# partition_by:
@ -43,14 +54,16 @@ USAGE
# Optional explicit cluster patterns. Each pattern is matched against the
# file *basename*. Matched files are pulled out of the auto-detect pool.
# Per-cluster if_exists/include/exclude/partition_by/max_partitions
# override the folder-level defaults.
# Per-cluster if_exists/include/exclude/partition_by/max_partitions/
# column_types override the folder-level defaults.
clusters:
- pattern: '^group_a\\d+\\.sas7bdat$'
tablename: group_a
- pattern: '^group_b\\d+\\.sas7bdat$'
tablename: group_b
if_exists: replace
column_types:
PHONE_PREFIX: TEXT
2. Command-line interface
-------------------------
@ -158,6 +171,7 @@ from load_sas import (
create_indexes,
create_table,
discover_partition_values_chunked,
extract_union_metadata,
infer_schema,
iter_sas_chunks,
read_sas_metadata,
@ -165,6 +179,7 @@ from load_sas import (
render_create_indexes,
render_create_table,
render_partition_ddl,
union_column_types,
)
@ -182,7 +197,11 @@ class ClusterSpec:
``partition_by``, ``max_partitions``, and ``indexes`` are resolved from
the folder defaults and any per-cluster overrides during
:func:`discover_clusters`.
:func:`discover_clusters`. ``column_types`` holds the effective type
overrides for this cluster: user-supplied YAML entries merged on top
of the auto-union result computed during pre-scan (see :func:`main`).
The same dict is threaded through to workers so every file in the
cluster infers the same schema.
"""
tablename: str
@ -195,6 +214,7 @@ class ClusterSpec:
partition_by: List[str] = field(default_factory=list)
max_partitions: int = 10_000
indexes: List[str] = field(default_factory=list)
column_types: Dict[str, str] = field(default_factory=dict)
@dataclass
@ -205,6 +225,8 @@ class _ExplicitPattern:
An explicit empty list ``[]`` means "disable partitioning for this cluster".
``max_partitions`` defaults to ``None`` meaning "inherit from folder level".
``indexes`` defaults to ``None`` meaning "inherit from folder level".
``column_types`` defaults to ``None`` meaning "inherit from folder level";
an explicit ``{}`` means "no user overrides for this cluster".
"""
pattern: re.Pattern
@ -216,6 +238,7 @@ class _ExplicitPattern:
partition_by: Optional[List[str]] = None
max_partitions: Optional[int] = None
indexes: Optional[List[str]] = None
column_types: Optional[Dict[str, str]] = None
@dataclass
@ -224,6 +247,9 @@ class FolderConfig:
``partition_by``, ``max_partitions``, and ``indexes`` serve as defaults
for every cluster unless overridden at the cluster level.
``column_types`` is a ``{column_name: postgres_type_str}`` map of
user-supplied type overrides that win over the auto-union computed
during pre-scan.
"""
folder: Path
@ -236,6 +262,7 @@ class FolderConfig:
partition_by: List[str] = field(default_factory=list)
max_partitions: int = 10_000
indexes: List[str] = field(default_factory=list)
column_types: Dict[str, str] = field(default_factory=dict)
# ---------------------------------------------------------------------------
@ -396,6 +423,40 @@ def _validate_indexes_vs_columns(
)
def _parse_column_types(
raw_value: Any, where: str, *, allow_none: bool = False
) -> Optional[Dict[str, str]]:
"""Parse a ``column_types`` mapping from YAML.
The value must be a mapping ``{column_name: pg_type_str}``. Keys and
values are whitespace-stripped strings; empty strings raise. When
``allow_none`` is True (used for per-cluster entries), an omitted key
returns ``None`` to mean "inherit from folder level"; an explicit
empty mapping returns ``{}`` (no overrides for this cluster).
"""
if raw_value is None:
return None if allow_none else {}
if not isinstance(raw_value, dict):
raise ValueError(
f"{where}: 'column_types' must be a mapping of "
f"{{column_name: postgres_type}}."
)
out: Dict[str, str] = {}
for k, v in raw_value.items():
key = str(k).strip()
if not key:
raise ValueError(
f"{where}: 'column_types' contains an empty column name."
)
if not isinstance(v, str) or not v.strip():
raise ValueError(
f"{where}: 'column_types[{key}]' must be a non-empty "
f"Postgres type string (got {v!r})."
)
out[key] = v.strip()
return out
def load_folder_config(path: Path) -> FolderConfig:
"""Parse and validate the folder-level YAML config at ``path``.
@ -438,6 +499,11 @@ def load_folder_config(path: Path) -> FolderConfig:
indexes = _parse_indexes(raw.get("indexes"), f"Config {path}")
_validate_indexes_vs_columns(indexes, exclude, f"Config {path}")
# -- folder-level column_types overrides --------------------------------
column_types = _parse_column_types(
raw.get("column_types"), f"Config {path}"
)
explicit: List[_ExplicitPattern] = []
clusters_raw = raw.get("clusters") or []
if not isinstance(clusters_raw, list):
@ -479,6 +545,11 @@ def load_folder_config(path: Path) -> FolderConfig:
effective_idx = c_indexes if c_indexes is not None else indexes
_validate_indexes_vs_columns(effective_idx, effective_exclude, where)
# -- per-cluster column_types overrides -----------------------------
c_column_types = _parse_column_types(
entry.get("column_types"), where, allow_none=True
)
explicit.append(
_ExplicitPattern(
pattern=compiled,
@ -490,6 +561,7 @@ def load_folder_config(path: Path) -> FolderConfig:
partition_by=c_partition_by,
max_partitions=c_max_partitions,
indexes=c_indexes,
column_types=c_column_types,
)
)
@ -504,6 +576,7 @@ def load_folder_config(path: Path) -> FolderConfig:
partition_by=partition_by,
max_partitions=max_partitions,
indexes=indexes,
column_types=column_types or {},
)
@ -601,6 +674,14 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
patt.indexes if patt.indexes is not None
else cfg.indexes
)
# Resolve column_types: user overrides only. The auto-union adds
# more entries later (in :func:`main`) after the metadata pre-scan.
# None = inherit folder, {} = no cluster-level overrides, dict =
# cluster-level overrides that win over folder-level entries.
if patt.column_types is None:
resolved_ct: Dict[str, str] = dict(cfg.column_types)
else:
resolved_ct = {**cfg.column_types, **patt.column_types}
matched = [f for f in remaining if patt.pattern.search(f.name)]
if not matched:
@ -618,6 +699,7 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
partition_by=resolved_pb,
max_partitions=resolved_mp,
indexes=resolved_idx,
column_types=dict(resolved_ct),
)
)
continue
@ -634,6 +716,7 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
partition_by=resolved_pb,
max_partitions=resolved_mp,
indexes=resolved_idx,
column_types=dict(resolved_ct),
)
)
@ -654,6 +737,7 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
partition_by=cfg.partition_by,
max_partitions=cfg.max_partitions,
indexes=cfg.indexes,
column_types=dict(cfg.column_types),
)
)
@ -666,19 +750,29 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
def _infer_cluster_schema(
path: Path, include, exclude
path: Path,
include,
exclude,
*,
column_types: Optional[Dict[str, str]] = None,
) -> Tuple[Dict, Optional[int]]:
"""Infer the Postgres column schema from a SAS file preview.
Returns ``(columns, total_rows)``. ``total_rows`` comes from the
pyreadstat metadata (the file's declared row count) and is threaded
through to :func:`_stream_file` so the tqdm progress bar has a real
denominator instead of an indeterminate spinner.
denominator instead of an indeterminate spinner. ``column_types``
lets the caller pin specific columns to a chosen Postgres type
(typically the merged auto-union + YAML overrides for the cluster).
"""
preview_df, meta = read_sas_preview(path)
preview_df = apply_column_filter(preview_df, include, exclude)
total_rows = getattr(meta, "number_rows", None)
columns = infer_schema(preview_df, meta, total_rows=total_rows)
columns = infer_schema(
preview_df, meta,
total_rows=total_rows,
column_types=column_types,
)
return columns, total_rows
@ -748,7 +842,8 @@ def load_cluster(
first, *rest = cluster.files
first_columns, first_total_rows = _infer_cluster_schema(
first, cluster.include, cluster.exclude
first, cluster.include, cluster.exclude,
column_types=cluster.column_types,
)
# -- Validate index columns early ---------------------------------------
@ -827,6 +922,7 @@ def load_cluster(
workers=workers,
progress_queue=progress_queue,
db_overrides=db_overrides,
column_types=cluster.column_types,
)
else:
# Serial path: stream the first file on the main connection, then
@ -842,7 +938,8 @@ def load_cluster(
conn.commit()
for path in rest:
columns, path_total_rows = _infer_cluster_schema(
path, cluster.include, cluster.exclude
path, cluster.include, cluster.exclude,
column_types=cluster.column_types,
)
# Uses the same check that if_exists=append runs. A type
# mismatch or missing column aborts the cluster; because
@ -926,6 +1023,7 @@ def _worker_load_append_file(
exclude: Optional[List[str]],
progress_queue: Any,
db_overrides: Optional[Dict[str, Optional[str]]],
column_types: Optional[Dict[str, str]] = None,
) -> Tuple[str, int, Optional[str]]:
"""Worker process: load one SAS file in append mode.
@ -965,7 +1063,11 @@ def _worker_load_append_file(
preview_df, meta = _read_sas_preview(path)
preview_df = _apply_column_filter(preview_df, include, exclude)
total_rows = getattr(meta, "number_rows", None)
columns = _infer_schema(preview_df, meta, total_rows=total_rows)
columns = _infer_schema(
preview_df, meta,
total_rows=total_rows,
column_types=column_types,
)
# Drop the preview ASAP - on a 2M-row wide file it's hundreds of MB
# and we never need it again after schema inference.
del preview_df, meta
@ -1031,6 +1133,7 @@ def _load_remaining_files_parallel(
workers: int,
progress_queue: Any,
db_overrides: Optional[Dict[str, Optional[str]]],
column_types: Optional[Dict[str, str]] = None,
) -> int:
"""Run append-mode loads for ``files`` across a process pool.
@ -1069,6 +1172,7 @@ def _load_remaining_files_parallel(
exclude,
progress_queue,
db_overrides,
column_types,
)
for p in files
]
@ -1219,7 +1323,14 @@ def main(argv: Optional[List[str]] = None) -> int:
print()
for c in loadable:
print(f"--- DDL for cluster {c.tablename!r} ---")
columns, _ = _infer_cluster_schema(c.files[0], c.include, c.exclude)
# Dry-run skips the pre-scan (so no auto-union) but user-supplied
# ``column_types`` from YAML are already baked into ``c.column_types``
# by ``discover_clusters`` - honor them here so the previewed DDL
# matches what a real load would produce on a single-file cluster.
columns, _ = _infer_cluster_schema(
c.files[0], c.include, c.exclude,
column_types=c.column_types,
)
# Print parent CREATE TABLE (with PARTITION BY if applicable).
print(
render_create_table(
@ -1332,40 +1443,58 @@ def main(argv: Optional[List[str]] = None) -> int:
# -- Metadata pre-scan -----------------------------------------------------
# Sum ``number_rows`` across every file so the tqdm bar has a real
# denominator. ``read_sas_metadata`` uses pyreadstat's ``metadataonly=True``
# fast path, but on multi-GB sas7bdat files that still reads tens of MB
# of scattered subheader pages per file - sequentially that's minutes for
# a 52-file folder. pyreadstat releases the GIL during I/O and C decoding,
# so a ThreadPool gives near-linear scaling until the disk saturates.
# ``--no-prescan`` bypasses the scan entirely; the progress bar then runs
# without an ETA - useful when pre-scan itself is expensive (half hour+
# on very large files) or when debugging iteratively.
# denominator, AND collect the per-column (readstat_type, sas_format)
# tuples so we can union schemas across files in a cluster before any
# CREATE TABLE runs. ``read_sas_metadata`` uses pyreadstat's
# ``metadataonly=True`` fast path, but on multi-GB sas7bdat files
# that still reads tens of MB of scattered subheader pages per file -
# sequentially that's minutes for a 52-file folder. pyreadstat
# releases the GIL during I/O and C decoding, so a ThreadPool gives
# near-linear scaling until the disk saturates. ``--no-prescan``
# bypasses the scan entirely; the progress bar then runs without an
# ETA *and* the auto-union is skipped (user overrides from YAML
# still apply).
all_files: List[Path] = [p for c in loadable for p in c.files]
grand_total: Optional[int] = 0
file_meta_by_path: Dict[str, Dict[str, Tuple[str, Optional[str]]]] = {}
if args.no_prescan:
grand_total = None
print(
f"[info] --no-prescan set: skipping row-count pre-scan for "
f"{len(all_files)} file(s); progress bar will show rate + "
f"elapsed but no ETA.",
f"elapsed but no ETA. Cluster-wide schema auto-union is also "
f"disabled; only user-specified column_types overrides apply.",
file=sys.stderr,
)
else:
prescan_workers = min(16, max(1, len(all_files)))
print(
f"pre-scanning row counts for {len(all_files)} file(s) "
f"across {prescan_workers} thread(s)...",
f"pre-scanning row counts + per-column metadata for "
f"{len(all_files)} file(s) across {prescan_workers} thread(s)...",
file=sys.stderr,
)
def _scan_one(p: Path) -> Tuple[Path, Optional[int], Optional[str]]:
def _scan_one(
p: Path,
) -> Tuple[
Path,
Optional[int],
Optional[Dict[str, Tuple[str, Optional[str]]]],
Optional[str],
]:
try:
meta = read_sas_metadata(p)
n = getattr(meta, "number_rows", None)
return (p, int(n) if n is not None else None, None)
col_meta = extract_union_metadata(meta)
return (
p,
int(n) if n is not None else None,
col_meta,
None,
)
except Exception as e:
return (p, None, str(e))
return (p, None, None, str(e))
unknown_total_files: List[str] = []
running_total = 0
@ -1378,7 +1507,7 @@ def main(argv: Optional[List[str]] = None) -> int:
dynamic_ncols=True,
)
try:
for p, n, err in tpool.map(_scan_one, all_files):
for p, n, col_meta, err in tpool.map(_scan_one, all_files):
prescan_bar.update(1)
if err is not None:
unknown_total_files.append(f"{p.name} ({err})")
@ -1386,6 +1515,8 @@ def main(argv: Optional[List[str]] = None) -> int:
unknown_total_files.append(p.name)
else:
running_total += n
if col_meta is not None:
file_meta_by_path[str(p)] = col_meta
finally:
prescan_bar.close()
@ -1402,6 +1533,59 @@ def main(argv: Optional[List[str]] = None) -> int:
)
grand_total = running_total
# -- Cluster-wide schema auto-union ---------------------------------------
# For each cluster, compute ``auto_types`` from the union of every
# file's metadata (see :func:`load_sas.union_column_types`). Merge with
# any user-supplied YAML overrides (user wins) and attach the result
# back onto the cluster so every later read - first-file inference,
# worker inference, schema-compat check - sees the same frozen schema.
# With ``--no-prescan`` the file_meta_by_path dict is empty and
# ``auto_types`` resolves to {}, so only the YAML overrides survive.
for c in loadable:
per_file = [
file_meta_by_path[str(p)]
for p in c.files
if str(p) in file_meta_by_path
]
auto_types = union_column_types(per_file) if per_file else {}
user_overrides = dict(c.column_types) # already merged folder+cluster
# User-supplied overrides win over the auto-union.
merged = {**auto_types, **user_overrides}
c.column_types = merged
if auto_types:
# Only call out columns where auto-union *changed* something
# relative to the default "first file wins" inference. We
# don't have the default inference in hand at this point, so
# log the full resolved map at a debug-friendly level - it's
# bounded by column count and the user asked for visibility
# into what got overridden.
shown = auto_types
if user_overrides:
# Distinguish the user-forced entries in the log so it's
# obvious which types came from YAML.
shown = {
col: (
f"{user_overrides[col]} (user override)"
if col in user_overrides
else pg
)
for col, pg in merged.items()
}
print(
f"[info] cluster {c.tablename!r}: auto-union derived "
f"{len(auto_types)} column type(s) across "
f"{len(per_file)} file(s): {shown}",
file=sys.stderr,
)
elif user_overrides and args.no_prescan:
print(
f"[info] cluster {c.tablename!r}: using {len(user_overrides)} "
f"user-supplied column_types override(s); auto-union "
f"disabled by --no-prescan.",
file=sys.stderr,
)
# -- Shared progress plumbing ---------------------------------------------
# The queue crosses process boundaries when workers > 1 (managed proxy)
# and is a plain in-process queue otherwise; the put/get contract is

View File

@ -307,6 +307,7 @@ class LoaderConfig:
partition_by: List[str] = field(default_factory=list)
max_partitions: int = 10_000
indexes: List[str] = field(default_factory=list)
column_types: Dict[str, str] = field(default_factory=dict)
@dataclass
@ -517,6 +518,35 @@ def load_config(path: Path) -> LoaderConfig:
f"{missing_in_include}"
)
# -- column_types -------------------------------------------------------
# Optional ``{column_name: pg_type}`` escape hatch that bypasses
# automatic type inference for specific columns. Useful when
# pyreadstat reports a column as NUM but the downstream consumer
# expects TEXT (e.g. phone-id columns), or when a column has drifted
# between CHAR and NUM across file versions and you want to pin
# TEXT up front. See also :func:`infer_schema`.
raw_ct = raw.get("column_types")
column_types: Dict[str, str] = {}
if raw_ct is not None:
if not isinstance(raw_ct, dict):
raise ValueError(
f"Config {path}: 'column_types' must be a mapping of "
f"{{column_name: postgres_type}}."
)
for k, v in raw_ct.items():
key = str(k).strip()
if not key:
raise ValueError(
f"Config {path}: 'column_types' contains an empty "
f"column name."
)
if not isinstance(v, str) or not v.strip():
raise ValueError(
f"Config {path}: 'column_types[{key}]' must be a "
f"non-empty Postgres type string (got {v!r})."
)
column_types[key] = v.strip()
return LoaderConfig(
filename=filename,
schemaname=schemaname,
@ -527,6 +557,7 @@ def load_config(path: Path) -> LoaderConfig:
partition_by=partition_by,
max_partitions=max_partitions,
indexes=indexes,
column_types=column_types,
)
@ -687,6 +718,117 @@ def _format_driven_type(sas_format: Optional[str]) -> Optional[str]:
return None
_DECIMAL_FORMAT_RE = re.compile(r"\.(\d+)")
def _format_hints_decimal(sas_format: Optional[str]) -> bool:
"""True if a numeric SAS format string explicitly carries decimal places.
SAS numeric formats are ``NAMEw.d``; ``d > 0`` means the variable was
intended to render with ``d`` decimal digits (COMMA10.2, F8.3, ...).
A bare width like ``BEST12.`` or ``F8.`` has no digits after the dot
and is treated as integer-presenting. Used by
:func:`union_column_types` to pick BIGINT vs DOUBLE PRECISION when a
column is numeric in every file of a cluster.
"""
if not sas_format:
return False
m = _DECIMAL_FORMAT_RE.search(sas_format)
if not m:
return False
try:
return int(m.group(1)) > 0
except ValueError:
return False
def extract_union_metadata(
    meta: Any,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Pull the (readstat_type, sas_format) pair for every column in ``meta``.

    Returns a plain ``{column_name: (readstat_type, sas_format)}`` dict
    built only from strings/``None``, suitable as input to
    :func:`union_column_types`.

    ``readstat_type`` is read from ``meta.readstat_variable_types`` - the
    attribute pyreadstat's metadata container uses for the per-column
    storage type (``"string"`` for SAS CHAR, ``"double"`` for SAS NUM).
    ``meta.variable_types`` is still consulted as a fallback so callers
    passing a custom metadata object with that attribute keep working.
    A column with no type entry gets ``""``.

    ``sas_format`` comes from ``meta.original_variable_types``; an empty
    or missing format is normalized to ``None``.

    Column order follows ``meta.column_names`` when present, otherwise
    the keys of the type map, otherwise the keys of the format map.
    """
    # pyreadstat publishes storage types as ``readstat_variable_types``.
    # Reading only ``variable_types`` left every rtype empty, which made
    # CHAR columns indistinguishable from NUM during the union.
    var_types = dict(
        getattr(meta, "readstat_variable_types", None)
        or getattr(meta, "variable_types", None)
        or {}
    )
    formats = dict(getattr(meta, "original_variable_types", None) or {})
    names = list(
        getattr(meta, "column_names", None)
        or list(var_types.keys())
        or list(formats.keys())
    )
    out: Dict[str, Tuple[str, Optional[str]]] = {}
    for col in names:
        rtype = str(var_types.get(col, ""))
        fmt = formats.get(col)
        out[col] = (rtype, fmt if fmt else None)
    return out
def union_column_types(
    per_file_metas: Iterable[Dict[str, Tuple[str, Optional[str]]]],
) -> Dict[str, str]:
    """Derive one Postgres type per column that's safe across every file.

    ``per_file_metas`` is an iterable (one entry per file in a cluster) of
    ``{column_name: (readstat_type, sas_format)}`` dicts as produced by
    :func:`extract_union_metadata`.

    Rules, evaluated per column:

    * **Unknown type -> omitted.** If no file reports a non-empty
      ``readstat_type`` for the column there is nothing to union, so the
      column is left out of the result and the caller's normal inference
      decides. (Pinning it to a numeric type would be a guess that breaks
      CHAR columns.)
    * **CHAR/NUM drift -> TEXT.** If any file stores the column as CHAR
      (a known ``readstat_type`` other than ``"double"``,
      case-insensitive) the union is ``TEXT``. This covers the phone-id
      case where some years stored ``RESP_PH_PREFIX_ID`` as CHAR and
      others as NUM.
    * **All NUM, format hints DATETIME -> TIMESTAMP.** Any file whose
      format resolves to ``TIMESTAMP`` (via :func:`_format_driven_type`)
      pins the column to ``TIMESTAMP`` even if other files left the
      format blank.
    * **All NUM, format hints DATE -> DATE.** Same idea for date-only
      formats.
    * **All NUM, any decimal hint -> DOUBLE PRECISION.** A ``w.d`` format
      with ``d > 0`` in any file implies fractional values somewhere.
    * **All NUM, otherwise -> BIGINT.** Default to BIGINT per user
      preference: integer-presenting NUM columns drift between
      INTEGER/BIGINT/DOUBLE across files, and the few extra bytes are
      worth not re-failing every load.

    Columns missing from a given file are simply skipped for that file;
    the union is computed over whichever files *did* supply the column.
    Columns that never appear anywhere are omitted from the result.
    """
    # Regroup the per-file dicts into per-column lists of
    # (readstat_type, sas_format) observations.
    per_col: Dict[str, List[Tuple[str, Optional[str]]]] = {}
    for meta in per_file_metas:
        for col, pair in meta.items():
            per_col.setdefault(col, []).append(pair)

    result: Dict[str, str] = {}
    for col, entries in per_col.items():
        known_types = [rtype.lower() for rtype, _ in entries if rtype]
        if not known_types:
            # No file told us CHAR vs NUM; don't force a type on a guess.
            continue
        if any(rtype != "double" for rtype in known_types):
            result[col] = "TEXT"
            continue

        formats = [fmt for _, fmt in entries if fmt]
        driven = [_format_driven_type(f) for f in formats]
        if "TIMESTAMP" in driven:
            result[col] = "TIMESTAMP"
        elif "DATE" in driven:
            result[col] = "DATE"
        elif any(_format_hints_decimal(f) for f in formats):
            result[col] = "DOUBLE PRECISION"
        else:
            # Safe default: BIGINT. The user explicitly accepted wasting a
            # few bytes here to avoid INTEGER→BIGINT widening failures on
            # multi-year clusters.
            result[col] = "BIGINT"
    return result
def _all_null(series: pd.Series) -> bool:
if pd.api.types.is_object_dtype(series):
return bool(series.map(lambda v: v is None or (isinstance(v, str) and v == "") or (isinstance(v, float) and pd.isna(v))).all())
@ -812,6 +954,7 @@ def infer_schema(
*,
coerce_chars: bool = COERCE_CHAR_COLUMNS,
total_rows: Optional[int] = None,
column_types: Optional[Dict[str, str]] = None,
) -> Dict[str, ColumnSpec]:
"""Infer a Postgres column spec for each column in ``df``.
@ -827,6 +970,14 @@ def infer_schema(
``total_rows`` lets callers who already sampled the frame (e.g. via
:func:`read_sas_preview`) report the real file size in the per-column
"inferred from first N of M rows" note. Falls back to ``len(df)``.
``column_types`` is an optional map ``{column_name: pg_type_str}``
whose entries bypass inference entirely - the caller has already
decided the type (e.g. via :func:`union_column_types` across a
cluster, or a YAML ``column_types`` override). Nullability is still
computed from the data. Columns in ``column_types`` that don't exist
in ``df`` are ignored so a shared override dict can apply to clusters
with different column sets.
"""
original_formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {})
@ -846,6 +997,8 @@ def infer_schema(
sample_size = df_rows
sampled = sample_size < effective_total
overrides: Dict[str, str] = dict(column_types or {})
# Temporarily flip the module-level flag if the caller asked us to.
global COERCE_CHAR_COLUMNS
saved = COERCE_CHAR_COLUMNS
@ -858,6 +1011,23 @@ def infer_schema(
sas_format = original_formats.get(col)
notes: List[str] = []
if col in overrides:
pg_type = overrides[col]
notes.append(
f"type forced to {pg_type} via column_types override"
)
nullable = _is_nullable(series)
out[col] = ColumnSpec(
name=col,
postgres_type=pg_type,
nullable=nullable,
sas_format=sas_format,
source_dtype=str(series.dtype),
notes=notes,
sampled=sampled,
)
continue
pg_type = _format_driven_type(sas_format)
if pg_type is None:
@ -1832,7 +2002,33 @@ def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.Da
# astype(str) stringifies NaN/None to the literal "nan"/"None",
# so we mask those after the fact rather than branching per cell.
na_mask = series.isna()
out[name] = series.astype(str).mask(na_mask, "")
if pd.api.types.is_numeric_dtype(series):
# Hit when a column was auto-unioned to TEXT because at
# least one file of the cluster stored it as CHAR but this
# particular file stored it as NUM (typical of SAS phone-id
# columns). Default float formatting would emit "123.0" -
# which doesn't match the plain "123" coming from the CHAR
# files. When the whole chunk is integer-valued, round to
# int before stringifying; when any fractional value is
# present we leave float formatting alone so we don't
# silently drop precision.
nonnull = series.dropna()
int_like = False
if not nonnull.empty:
try:
int_like = bool(((nonnull % 1) == 0).all())
except TypeError:
int_like = False
if int_like:
# ``Int64`` preserves NA; ``.astype(str)`` renders NA
# as '<NA>', which we then mask out alongside original
# NaNs.
as_str = series.astype("Int64").astype(str)
out[name] = as_str.mask(na_mask, "")
else:
out[name] = series.astype(str).mask(na_mask, "")
else:
out[name] = series.astype(str).mask(na_mask, "")
elif pg == "BOOLEAN":
out[name] = series.astype("boolean") if series.dtype != object else series
else:
@ -2064,7 +2260,7 @@ def main(argv: Optional[List[str]] = None) -> int:
# on columns whose nulls live past the window.
preview_df, meta = read_sas_preview(cfg.filename)
preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude)
columns = infer_schema(preview_df, meta)
columns = infer_schema(preview_df, meta, column_types=cfg.column_types)
# Validate partition columns exist in the schema after filtering.
if cfg.partition_by:

View File

@ -38,3 +38,15 @@ if_exists: append
# indexes:
# - state
# - zip
# column_types: Explicit {column_name: postgres_type} overrides that
# bypass automatic type inference for the listed columns. Useful when
# pyreadstat reports a column as NUM but you want it stored as TEXT
# (phone/ID columns that are conceptually strings), or when a column's
# inferred type is off for any other reason. Columns not listed here
# fall through to the normal inference path. Nullability is always
# computed from the data.
#
# column_types:
# RESP_PH_PREFIX_ID: TEXT
# SOMELONG_ID: BIGINT

View File

@ -61,15 +61,40 @@ auto_detect: true
# - state
# - zip
# Folder-level column_types: Explicit {column_name: postgres_type} map that
# bypasses automatic type inference for the listed columns. Applied to
# every cluster unless a cluster supplies its own column_types, which are
# merged on top (cluster entries win on conflict).
#
# Unless --no-prescan is set, the pre-scan derives a cluster-wide "auto-union"
# type per column (e.g. any file stores the column as CHAR -> TEXT; all
# NUM with any format hinting decimals -> DOUBLE PRECISION; otherwise
# BIGINT). Entries in column_types here win over that auto-union - use
# them when the auto result is wrong or when --no-prescan disables the
# auto-union and you still need to pin a column.
#
# Valid type strings are anything the CREATE TABLE DDL accepts (TEXT,
# INTEGER, BIGINT, DOUBLE PRECISION, DATE, TIMESTAMP, ...). Columns that
# don't exist in a given file are simply ignored for that file.
#
# column_types:
# RESP_PH_PREFIX_ID: TEXT
# RESP_PH_SUFFIX_ID: TEXT
# SOMELONG_ID: BIGINT
# Explicit cluster patterns. Each pattern is matched against the file
# *basename*. Files matched by a pattern are pulled out of the auto-detect
# pool, so explicit and auto clusters compose cleanly.
#
# `tablename` is required. `if_exists`, `include`, and `exclude` are
# optional per-cluster overrides of the folder-level defaults above.
# `tablename` is required. `if_exists`, `include`, `exclude`, and
# `column_types` are optional per-cluster overrides of the folder-level
# defaults above. Cluster-level column_types entries win over folder-
# level entries for the same column.
clusters:
- pattern: '^group_a\d+\.xpt$'
tablename: group_a
# column_types:
# INTCOL: TEXT
# Example of an explicit override. Uncomment to force the group_b cluster to
# append instead of replace even though the folder default is "replace":