diff --git a/generic_loader/load_folder.py b/generic_loader/load_folder.py index 76b38d9..c9032d2 100644 --- a/generic_loader/load_folder.py +++ b/generic_loader/load_folder.py @@ -32,6 +32,17 @@ USAGE # include: [ID, INTCOL] # exclude: [ALLNULL] + # Optional folder default for explicit column type overrides. These + # win over the cluster-wide auto-union computed during pre-scan; set + # them when a column's SAS-level type varies across files (e.g. phone + # IDs stored as CHAR in some years and NUM in others) and you want to + # pin the Postgres type yourself rather than accept the auto-derived + # one. Per-cluster column_types inside each clusters[*] entry are + # merged on top of this map. + # column_types: + # RESP_PH_PREFIX_ID: TEXT + # SOME_BIGINT_COL: BIGINT + # Optional folder default for LIST partitioning. Omit or set [] for no # partitioning. Accepts a single string or a list of column names. # partition_by: @@ -43,14 +54,16 @@ USAGE # Optional explicit cluster patterns. Each pattern is matched against the # file *basename*. Matched files are pulled out of the auto-detect pool. - # Per-cluster if_exists/include/exclude/partition_by/max_partitions - # override the folder-level defaults. + # Per-cluster if_exists/include/exclude/partition_by/max_partitions/ + # column_types override the folder-level defaults. clusters: - pattern: '^group_a\\d+\\.sas7bdat$' tablename: group_a - pattern: '^group_b\\d+\\.sas7bdat$' tablename: group_b if_exists: replace + column_types: + PHONE_PREFIX: TEXT 2. 
Command-line interface ------------------------- @@ -158,6 +171,7 @@ from load_sas import ( create_indexes, create_table, discover_partition_values_chunked, + extract_union_metadata, infer_schema, iter_sas_chunks, read_sas_metadata, @@ -165,6 +179,7 @@ from load_sas import ( render_create_indexes, render_create_table, render_partition_ddl, + union_column_types, ) @@ -182,7 +197,11 @@ class ClusterSpec: ``partition_by``, ``max_partitions``, and ``indexes`` are resolved from the folder defaults and any per-cluster overrides during - :func:`discover_clusters`. + :func:`discover_clusters`. ``column_types`` holds the effective type + overrides for this cluster: user-supplied YAML entries merged on top + of the auto-union result computed during pre-scan (see :func:`main`). + The same dict is threaded through to workers so every file in the + cluster infers the same schema. """ tablename: str @@ -195,6 +214,7 @@ class ClusterSpec: partition_by: List[str] = field(default_factory=list) max_partitions: int = 10_000 indexes: List[str] = field(default_factory=list) + column_types: Dict[str, str] = field(default_factory=dict) @dataclass @@ -205,6 +225,8 @@ class _ExplicitPattern: An explicit empty list ``[]`` means "disable partitioning for this cluster". ``max_partitions`` defaults to ``None`` meaning "inherit from folder level". ``indexes`` defaults to ``None`` meaning "inherit from folder level". + ``column_types`` defaults to ``None`` meaning "inherit from folder level"; + an explicit ``{}`` means "no user overrides for this cluster". """ pattern: re.Pattern @@ -216,6 +238,7 @@ class _ExplicitPattern: partition_by: Optional[List[str]] = None max_partitions: Optional[int] = None indexes: Optional[List[str]] = None + column_types: Optional[Dict[str, str]] = None @dataclass @@ -224,6 +247,9 @@ class FolderConfig: ``partition_by``, ``max_partitions``, and ``indexes`` serve as defaults for every cluster unless overridden at the cluster level. 
+ ``column_types`` is a ``{column_name: postgres_type_str}`` map of + user-supplied type overrides that win over the auto-union computed + during pre-scan. """ folder: Path @@ -236,6 +262,7 @@ class FolderConfig: partition_by: List[str] = field(default_factory=list) max_partitions: int = 10_000 indexes: List[str] = field(default_factory=list) + column_types: Dict[str, str] = field(default_factory=dict) # --------------------------------------------------------------------------- @@ -396,6 +423,40 @@ def _validate_indexes_vs_columns( ) +def _parse_column_types( + raw_value: Any, where: str, *, allow_none: bool = False +) -> Optional[Dict[str, str]]: + """Parse a ``column_types`` mapping from YAML. + + The value must be a mapping ``{column_name: pg_type_str}``. Keys and + values are whitespace-stripped strings; empty strings raise. When + ``allow_none`` is True (used for per-cluster entries), an omitted key + returns ``None`` to mean "inherit from folder level"; an explicit + empty mapping returns ``{}`` (no overrides for this cluster). + """ + if raw_value is None: + return None if allow_none else {} + if not isinstance(raw_value, dict): + raise ValueError( + f"{where}: 'column_types' must be a mapping of " + f"{{column_name: postgres_type}}." + ) + out: Dict[str, str] = {} + for k, v in raw_value.items(): + key = str(k).strip() + if not key: + raise ValueError( + f"{where}: 'column_types' contains an empty column name." + ) + if not isinstance(v, str) or not v.strip(): + raise ValueError( + f"{where}: 'column_types[{key}]' must be a non-empty " + f"Postgres type string (got {v!r})." + ) + out[key] = v.strip() + return out + + def load_folder_config(path: Path) -> FolderConfig: """Parse and validate the folder-level YAML config at ``path``. 
@@ -438,6 +499,11 @@ def load_folder_config(path: Path) -> FolderConfig: indexes = _parse_indexes(raw.get("indexes"), f"Config {path}") _validate_indexes_vs_columns(indexes, exclude, f"Config {path}") + # -- folder-level column_types overrides -------------------------------- + column_types = _parse_column_types( + raw.get("column_types"), f"Config {path}" + ) + explicit: List[_ExplicitPattern] = [] clusters_raw = raw.get("clusters") or [] if not isinstance(clusters_raw, list): @@ -479,6 +545,11 @@ def load_folder_config(path: Path) -> FolderConfig: effective_idx = c_indexes if c_indexes is not None else indexes _validate_indexes_vs_columns(effective_idx, effective_exclude, where) + # -- per-cluster column_types overrides ----------------------------- + c_column_types = _parse_column_types( + entry.get("column_types"), where, allow_none=True + ) + explicit.append( _ExplicitPattern( pattern=compiled, @@ -490,6 +561,7 @@ def load_folder_config(path: Path) -> FolderConfig: partition_by=c_partition_by, max_partitions=c_max_partitions, indexes=c_indexes, + column_types=c_column_types, ) ) @@ -504,6 +576,7 @@ def load_folder_config(path: Path) -> FolderConfig: partition_by=partition_by, max_partitions=max_partitions, indexes=indexes, + column_types=column_types or {}, ) @@ -601,6 +674,14 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: patt.indexes if patt.indexes is not None else cfg.indexes ) + # Resolve column_types: user overrides only. The auto-union adds + # more entries later (in :func:`main`) after the metadata pre-scan. + # None = inherit folder, {} = no cluster-level overrides, dict = + # cluster-level overrides that win over folder-level entries. 
+ if patt.column_types is None: + resolved_ct: Dict[str, str] = dict(cfg.column_types) + else: + resolved_ct = {**cfg.column_types, **patt.column_types} matched = [f for f in remaining if patt.pattern.search(f.name)] if not matched: @@ -618,6 +699,7 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: partition_by=resolved_pb, max_partitions=resolved_mp, indexes=resolved_idx, + column_types=dict(resolved_ct), ) ) continue @@ -634,6 +716,7 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: partition_by=resolved_pb, max_partitions=resolved_mp, indexes=resolved_idx, + column_types=dict(resolved_ct), ) ) @@ -654,6 +737,7 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: partition_by=cfg.partition_by, max_partitions=cfg.max_partitions, indexes=cfg.indexes, + column_types=dict(cfg.column_types), ) ) @@ -666,19 +750,29 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]: def _infer_cluster_schema( - path: Path, include, exclude + path: Path, + include, + exclude, + *, + column_types: Optional[Dict[str, str]] = None, ) -> Tuple[Dict, Optional[int]]: """Infer the Postgres column schema from a SAS file preview. Returns ``(columns, total_rows)``. ``total_rows`` comes from the pyreadstat metadata (the file's declared row count) and is threaded through to :func:`_stream_file` so the tqdm progress bar has a real - denominator instead of an indeterminate spinner. + denominator instead of an indeterminate spinner. ``column_types`` + lets the caller pin specific columns to a chosen Postgres type + (typically the merged auto-union + YAML overrides for the cluster). 
""" preview_df, meta = read_sas_preview(path) preview_df = apply_column_filter(preview_df, include, exclude) total_rows = getattr(meta, "number_rows", None) - columns = infer_schema(preview_df, meta, total_rows=total_rows) + columns = infer_schema( + preview_df, meta, + total_rows=total_rows, + column_types=column_types, + ) return columns, total_rows @@ -748,7 +842,8 @@ def load_cluster( first, *rest = cluster.files first_columns, first_total_rows = _infer_cluster_schema( - first, cluster.include, cluster.exclude + first, cluster.include, cluster.exclude, + column_types=cluster.column_types, ) # -- Validate index columns early --------------------------------------- @@ -827,6 +922,7 @@ def load_cluster( workers=workers, progress_queue=progress_queue, db_overrides=db_overrides, + column_types=cluster.column_types, ) else: # Serial path: stream the first file on the main connection, then @@ -842,7 +938,8 @@ def load_cluster( conn.commit() for path in rest: columns, path_total_rows = _infer_cluster_schema( - path, cluster.include, cluster.exclude + path, cluster.include, cluster.exclude, + column_types=cluster.column_types, ) # Uses the same check that if_exists=append runs. A type # mismatch or missing column aborts the cluster; because @@ -926,6 +1023,7 @@ def _worker_load_append_file( exclude: Optional[List[str]], progress_queue: Any, db_overrides: Optional[Dict[str, Optional[str]]], + column_types: Optional[Dict[str, str]] = None, ) -> Tuple[str, int, Optional[str]]: """Worker process: load one SAS file in append mode. 
@@ -965,7 +1063,11 @@ def _worker_load_append_file( preview_df, meta = _read_sas_preview(path) preview_df = _apply_column_filter(preview_df, include, exclude) total_rows = getattr(meta, "number_rows", None) - columns = _infer_schema(preview_df, meta, total_rows=total_rows) + columns = _infer_schema( + preview_df, meta, + total_rows=total_rows, + column_types=column_types, + ) # Drop the preview ASAP - on a 2M-row wide file it's hundreds of MB # and we never need it again after schema inference. del preview_df, meta @@ -1031,6 +1133,7 @@ def _load_remaining_files_parallel( workers: int, progress_queue: Any, db_overrides: Optional[Dict[str, Optional[str]]], + column_types: Optional[Dict[str, str]] = None, ) -> int: """Run append-mode loads for ``files`` across a process pool. @@ -1069,6 +1172,7 @@ def _load_remaining_files_parallel( exclude, progress_queue, db_overrides, + column_types, ) for p in files ] @@ -1219,7 +1323,14 @@ def main(argv: Optional[List[str]] = None) -> int: print() for c in loadable: print(f"--- DDL for cluster {c.tablename!r} ---") - columns, _ = _infer_cluster_schema(c.files[0], c.include, c.exclude) + # Dry-run skips the pre-scan (so no auto-union) but user-supplied + # ``column_types`` from YAML are already baked into ``c.column_types`` + # by ``discover_clusters`` - honor them here so the previewed DDL + # matches a ``--no-prescan`` load (a pre-scanned load may pick other types). + columns, _ = _infer_cluster_schema( + c.files[0], c.include, c.exclude, + column_types=c.column_types, + ) # Print parent CREATE TABLE (with PARTITION BY if applicable). print( render_create_table( @@ -1332,40 +1443,58 @@ def main(argv: Optional[List[str]] = None) -> int: # -- Metadata pre-scan ----------------------------------------------------- # Sum ``number_rows`` across every file so the tqdm bar has a real - # denominator.
``read_sas_metadata`` uses pyreadstat's ``metadataonly=True`` - # fast path, but on multi-GB sas7bdat files that still reads tens of MB - # of scattered subheader pages per file - sequentially that's minutes for - # a 52-file folder. pyreadstat releases the GIL during I/O and C decoding, - # so a ThreadPool gives near-linear scaling until the disk saturates. - # ``--no-prescan`` bypasses the scan entirely; the progress bar then runs - # without an ETA - useful when pre-scan itself is expensive (half hour+ - # on very large files) or when debugging iteratively. + # denominator, AND collect the per-column (readstat_type, sas_format) + # tuples so we can union schemas across files in a cluster before any + # CREATE TABLE runs. ``read_sas_metadata`` uses pyreadstat's + # ``metadataonly=True`` fast path, but on multi-GB sas7bdat files + # that still reads tens of MB of scattered subheader pages per file - + # sequentially that's minutes for a 52-file folder. pyreadstat + # releases the GIL during I/O and C decoding, so a ThreadPool gives + # near-linear scaling until the disk saturates. ``--no-prescan`` + # bypasses the scan entirely; the progress bar then runs without an + # ETA *and* the auto-union is skipped (user overrides from YAML + # still apply). all_files: List[Path] = [p for c in loadable for p in c.files] grand_total: Optional[int] = 0 + file_meta_by_path: Dict[str, Dict[str, Tuple[str, Optional[str]]]] = {} if args.no_prescan: grand_total = None print( f"[info] --no-prescan set: skipping row-count pre-scan for " f"{len(all_files)} file(s); progress bar will show rate + " - f"elapsed but no ETA.", + f"elapsed but no ETA. 
Cluster-wide schema auto-union is also " + f"disabled; only user-specified column_types overrides apply.", file=sys.stderr, ) else: prescan_workers = min(16, max(1, len(all_files))) print( - f"pre-scanning row counts for {len(all_files)} file(s) " - f"across {prescan_workers} thread(s)...", + f"pre-scanning row counts + per-column metadata for " + f"{len(all_files)} file(s) across {prescan_workers} thread(s)...", file=sys.stderr, ) - def _scan_one(p: Path) -> Tuple[Path, Optional[int], Optional[str]]: + def _scan_one( + p: Path, + ) -> Tuple[ + Path, + Optional[int], + Optional[Dict[str, Tuple[str, Optional[str]]]], + Optional[str], + ]: try: meta = read_sas_metadata(p) n = getattr(meta, "number_rows", None) - return (p, int(n) if n is not None else None, None) + col_meta = extract_union_metadata(meta) + return ( + p, + int(n) if n is not None else None, + col_meta, + None, + ) except Exception as e: - return (p, None, str(e)) + return (p, None, None, str(e)) unknown_total_files: List[str] = [] running_total = 0 @@ -1378,7 +1507,7 @@ def main(argv: Optional[List[str]] = None) -> int: dynamic_ncols=True, ) try: - for p, n, err in tpool.map(_scan_one, all_files): + for p, n, col_meta, err in tpool.map(_scan_one, all_files): prescan_bar.update(1) if err is not None: unknown_total_files.append(f"{p.name} ({err})") @@ -1386,6 +1515,8 @@ def main(argv: Optional[List[str]] = None) -> int: unknown_total_files.append(p.name) else: running_total += n + if col_meta is not None: + file_meta_by_path[str(p)] = col_meta finally: prescan_bar.close() @@ -1402,6 +1533,59 @@ def main(argv: Optional[List[str]] = None) -> int: ) grand_total = running_total + # -- Cluster-wide schema auto-union --------------------------------------- + # For each cluster, compute ``auto_types`` from the union of every + # file's metadata (see :func:`load_sas.union_column_types`). 
Merge with + # any user-supplied YAML overrides (user wins) and attach the result + # back onto the cluster so every later read - first-file inference, + # worker inference, schema-compat check - sees the same frozen schema. + # With ``--no-prescan`` the file_meta_by_path dict is empty and + # ``auto_types`` resolves to {}, so only the YAML overrides survive. + for c in loadable: + per_file = [ + file_meta_by_path[str(p)] + for p in c.files + if str(p) in file_meta_by_path + ] + auto_types = union_column_types(per_file) if per_file else {} + user_overrides = dict(c.column_types) # already merged folder+cluster + # User-supplied overrides win over the auto-union. + merged = {**auto_types, **user_overrides} + c.column_types = merged + + if auto_types: + # Only call out columns where auto-union *changed* something + # relative to the default "first file wins" inference. We + # don't have the default inference in hand at this point, so + # log the full resolved map at a debug-friendly level - it's + # bounded by column count and the user asked for visibility + # into what got overridden. + shown = auto_types + if user_overrides: + # Distinguish the user-forced entries in the log so it's + # obvious which types came from YAML. 
+ shown = { + col: ( + f"{user_overrides[col]} (user override)" + if col in user_overrides + else pg + ) + for col, pg in merged.items() + } + print( + f"[info] cluster {c.tablename!r}: auto-union derived " + f"{len(auto_types)} column type(s) across " + f"{len(per_file)} file(s): {shown}", + file=sys.stderr, + ) + elif user_overrides and args.no_prescan: + print( + f"[info] cluster {c.tablename!r}: using {len(user_overrides)} " + f"user-supplied column_types override(s); auto-union " + f"disabled by --no-prescan.", + file=sys.stderr, + ) + # -- Shared progress plumbing --------------------------------------------- # The queue crosses process boundaries when workers > 1 (managed proxy) # and is a plain in-process queue otherwise; the put/get contract is diff --git a/generic_loader/load_sas.py b/generic_loader/load_sas.py index 13e59cb..fe52537 100644 --- a/generic_loader/load_sas.py +++ b/generic_loader/load_sas.py @@ -307,6 +307,7 @@ class LoaderConfig: partition_by: List[str] = field(default_factory=list) max_partitions: int = 10_000 indexes: List[str] = field(default_factory=list) + column_types: Dict[str, str] = field(default_factory=dict) @dataclass @@ -517,6 +518,35 @@ def load_config(path: Path) -> LoaderConfig: f"{missing_in_include}" ) + # -- column_types ------------------------------------------------------- + # Optional ``{column_name: pg_type}`` escape hatch that bypasses + # automatic type inference for specific columns. Useful when + # pyreadstat reports a column as NUM but the downstream consumer + # expects TEXT (e.g. phone-id columns), or when a column has drifted + # between CHAR and NUM across file versions and you want to pin + # TEXT up front. See also :func:`infer_schema`. + raw_ct = raw.get("column_types") + column_types: Dict[str, str] = {} + if raw_ct is not None: + if not isinstance(raw_ct, dict): + raise ValueError( + f"Config {path}: 'column_types' must be a mapping of " + f"{{column_name: postgres_type}}." 
+ ) + for k, v in raw_ct.items(): + key = str(k).strip() + if not key: + raise ValueError( + f"Config {path}: 'column_types' contains an empty " + f"column name." + ) + if not isinstance(v, str) or not v.strip(): + raise ValueError( + f"Config {path}: 'column_types[{key}]' must be a " + f"non-empty Postgres type string (got {v!r})." + ) + column_types[key] = v.strip() + return LoaderConfig( filename=filename, schemaname=schemaname, @@ -527,6 +557,7 @@ def load_config(path: Path) -> LoaderConfig: partition_by=partition_by, max_partitions=max_partitions, indexes=indexes, + column_types=column_types, ) @@ -687,6 +718,117 @@ def _format_driven_type(sas_format: Optional[str]) -> Optional[str]: return None +_DECIMAL_FORMAT_RE = re.compile(r"\.(\d+)") + + +def _format_hints_decimal(sas_format: Optional[str]) -> bool: + """True if a numeric SAS format string explicitly carries decimal places. + + SAS numeric formats are ``NAMEw.d``; ``d > 0`` means the variable was + intended to render with ``d`` decimal digits (COMMA10.2, F8.3, ...). + A bare width like ``BEST12.`` or ``F8.`` has no digits after the dot + and is treated as integer-presenting. Used by + :func:`union_column_types` to pick BIGINT vs DOUBLE PRECISION when a + column is numeric in every file of a cluster. + """ + if not sas_format: + return False + m = _DECIMAL_FORMAT_RE.search(sas_format) + if not m: + return False + try: + return int(m.group(1)) > 0 + except ValueError: + return False + + +def extract_union_metadata( + meta: Any, +) -> Dict[str, Tuple[str, Optional[str]]]: + """Pull the (readstat_type, sas_format) pair for every column in ``meta``. + + Returns a plain dict that's safe to pass between processes and to + :func:`union_column_types`. ``readstat_type`` is the simplified type + reported by pyreadstat: ``"string"`` for SAS CHAR, ``"double"`` for + SAS NUM. ``sas_format`` comes from ``meta.original_variable_types`` + and drives date/datetime detection during union. 
+ """ + var_types = dict(getattr(meta, "variable_types", None) or {}) + formats = dict(getattr(meta, "original_variable_types", None) or {}) + names = list( + getattr(meta, "column_names", None) + or list(var_types.keys()) + or list(formats.keys()) + ) + out: Dict[str, Tuple[str, Optional[str]]] = {} + for col in names: + rtype = str(var_types.get(col, "")) if var_types else "" + fmt = formats.get(col) + out[col] = (rtype, fmt if fmt else None) + return out + + +def union_column_types( + per_file_metas: Iterable[Dict[str, Tuple[str, Optional[str]]]], +) -> Dict[str, str]: + """Derive one Postgres type per column that's safe across every file. + + ``per_file_metas`` is an iterable (one entry per file in a cluster) of + ``{column_name: (readstat_type, sas_format)}`` dicts as produced by + :func:`extract_union_metadata`. + + Rules, evaluated per column: + + * **CHAR/NUM drift wins TEXT.** If any file stores the column as CHAR + (``readstat_type != "double"``) the union is ``TEXT``. This covers + the phone-id case where some years stored ``RESP_PH_PREFIX_ID`` as + CHAR and others as NUM. + * **All NUM, format hints DATETIME → TIMESTAMP.** Any file whose + format resolves to ``TIMESTAMP`` (via :func:`_format_driven_type`) + pins the column to ``TIMESTAMP`` even if other files left the + format blank. + * **All NUM, format hints DATE → DATE.** Same idea for date-only + formats. + * **All NUM, any decimal hint → DOUBLE PRECISION.** A ``w.d`` format + with ``d > 0`` in any file implies fractional values somewhere. + * **All NUM, otherwise → BIGINT.** Default to BIGINT per user + preference: integer-presenting NUM columns drift between + INTEGER/BIGINT/DOUBLE across files, and the few extra bytes are + worth not re-failing every load. + + Columns missing from a given file are simply skipped for that file; + the union is computed over whichever files *did* supply the column. + Columns that never appear anywhere are omitted from the result. 
+ """ + per_col: Dict[str, List[Tuple[str, Optional[str]]]] = {} + for meta in per_file_metas: + for col, pair in meta.items(): + per_col.setdefault(col, []).append(pair) + + result: Dict[str, str] = {} + for col, entries in per_col.items(): + any_char = any( + rtype and rtype.lower() != "double" for rtype, _ in entries + ) + if any_char: + result[col] = "TEXT" + continue + formats = [fmt for _, fmt in entries if fmt] + driven = [_format_driven_type(f) for f in formats] + if "TIMESTAMP" in driven: + result[col] = "TIMESTAMP" + elif "DATE" in driven: + result[col] = "DATE" + elif any(_format_hints_decimal(f) for f in formats): + result[col] = "DOUBLE PRECISION" + else: + # Safe default: BIGINT. The user explicitly accepted wasting a + # few bytes here to avoid INTEGER→BIGINT widening failures on + # multi-year clusters. + result[col] = "BIGINT" + return result + + def _all_null(series: pd.Series) -> bool: if pd.api.types.is_object_dtype(series): return bool(series.map(lambda v: v is None or (isinstance(v, str) and v == "") or (isinstance(v, float) and pd.isna(v))).all()) @@ -812,6 +954,7 @@ def infer_schema( *, coerce_chars: bool = COERCE_CHAR_COLUMNS, total_rows: Optional[int] = None, + column_types: Optional[Dict[str, str]] = None, ) -> Dict[str, ColumnSpec]: """Infer a Postgres column spec for each column in ``df``. @@ -827,6 +970,14 @@ def infer_schema( ``total_rows`` lets callers who already sampled the frame (e.g. via :func:`read_sas_preview`) report the real file size in the per-column "inferred from first N of M rows" note. Falls back to ``len(df)``. + + ``column_types`` is an optional map ``{column_name: pg_type_str}`` + whose entries bypass inference entirely - the caller has already + decided the type (e.g. via :func:`union_column_types` across a + cluster, or a YAML ``column_types`` override). Nullability is still + computed from the data. 
Columns in ``column_types`` that don't exist + in ``df`` are ignored so a shared override dict can apply to clusters + with different column sets. """ original_formats: Dict[str, str] = dict(getattr(meta, "original_variable_types", {}) or {}) @@ -846,6 +997,8 @@ def infer_schema( sample_size = df_rows sampled = sample_size < effective_total + overrides: Dict[str, str] = dict(column_types or {}) + # Temporarily flip the module-level flag if the caller asked us to. global COERCE_CHAR_COLUMNS saved = COERCE_CHAR_COLUMNS @@ -858,6 +1011,23 @@ def infer_schema( sas_format = original_formats.get(col) notes: List[str] = [] + if col in overrides: + pg_type = overrides[col] + notes.append( + f"type forced to {pg_type} via column_types override" + ) + nullable = _is_nullable(series) + out[col] = ColumnSpec( + name=col, + postgres_type=pg_type, + nullable=nullable, + sas_format=sas_format, + source_dtype=str(series.dtype), + notes=notes, + sampled=sampled, + ) + continue + pg_type = _format_driven_type(sas_format) if pg_type is None: @@ -1832,7 +2002,33 @@ def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.Da # astype(str) stringifies NaN/None to the literal "nan"/"None", # so we mask those after the fact rather than branching per cell. na_mask = series.isna() - out[name] = series.astype(str).mask(na_mask, "") + if pd.api.types.is_numeric_dtype(series): + # Hit when a column was auto-unioned to TEXT because at + # least one file of the cluster stored it as CHAR but this + # particular file stored it as NUM (typical of SAS phone-id + # columns). Default float formatting would emit "123.0" - + # which doesn't match the plain "123" coming from the CHAR + # files. When the whole chunk is integer-valued, round to + # int before stringifying; when any fractional value is + # present we leave float formatting alone so we don't + # silently drop precision. 
+ nonnull = series.dropna() + int_like = False + if not nonnull.empty: + try: + int_like = bool(((nonnull % 1) == 0).all()) + except TypeError: + int_like = False + if int_like: + # ``Int64`` preserves NA; ``.astype(str)`` renders NA + # as a placeholder string (not ''), so we mask those out + # alongside original NaNs. + as_str = series.astype("Int64").astype(str) + out[name] = as_str.mask(na_mask, "") + else: + out[name] = series.astype(str).mask(na_mask, "") + else: + out[name] = series.astype(str).mask(na_mask, "") elif pg == "BOOLEAN": out[name] = series.astype("boolean") if series.dtype != object else series else: @@ -2064,7 +2260,7 @@ def main(argv: Optional[List[str]] = None) -> int: # on columns whose nulls live past the window. preview_df, meta = read_sas_preview(cfg.filename) preview_df = apply_column_filter(preview_df, cfg.include, cfg.exclude) - columns = infer_schema(preview_df, meta) + columns = infer_schema(preview_df, meta, column_types=cfg.column_types) # Validate partition columns exist in the schema after filtering. if cfg.partition_by: diff --git a/generic_loader/sample_config.yaml b/generic_loader/sample_config.yaml index c487769..9ebfe45 100644 --- a/generic_loader/sample_config.yaml +++ b/generic_loader/sample_config.yaml @@ -38,3 +38,15 @@ if_exists: append # indexes: # - state # - zip + +# column_types: Explicit {column_name: postgres_type} overrides that +# bypass automatic type inference for the listed columns. Useful when +# pyreadstat reports a column as NUM but you want it stored as TEXT +# (phone/ID columns that are conceptually strings), or when a column's +# inferred type is off for any other reason. Columns not listed here +# fall through to the normal inference path. Nullability is always +# computed from the data.
+# +# column_types: +# RESP_PH_PREFIX_ID: TEXT +# SOMELONG_ID: BIGINT diff --git a/generic_loader/sample_folder_config.yaml b/generic_loader/sample_folder_config.yaml index 5740c3f..b1961b9 100644 --- a/generic_loader/sample_folder_config.yaml +++ b/generic_loader/sample_folder_config.yaml @@ -61,15 +61,40 @@ auto_detect: true # - state # - zip +# Folder-level column_types: Explicit {column_name: postgres_type} map that +# bypasses automatic type inference for the listed columns. Applied to +# every cluster; a cluster's own column_types entries are merged on top +# of this map (cluster entries win on conflict). +# +# Unless --no-prescan is set, the pre-scan derives a cluster-wide "auto-union" +# type per column (e.g. any file stores the column as CHAR -> TEXT; all +# NUM with any format hinting decimals -> DOUBLE PRECISION; otherwise +# BIGINT). Entries in column_types here win over that auto-union - use +# them when the auto result is wrong or when --no-prescan disables the +# auto-union and you still need to pin a column. +# +# Valid type strings are anything the CREATE TABLE DDL accepts (TEXT, +# INTEGER, BIGINT, DOUBLE PRECISION, DATE, TIMESTAMP, ...). Columns that +# don't exist in a given file are simply ignored for that file. +# +# column_types: +# RESP_PH_PREFIX_ID: TEXT +# RESP_PH_SUFFIX_ID: TEXT +# SOMELONG_ID: BIGINT + # Explicit cluster patterns. Each pattern is matched against the file # *basename*. Files matched by a pattern are pulled out of the auto-detect # pool, so explicit and auto clusters compose cleanly. # -# `tablename` is required. `if_exists`, `include`, and `exclude` are -# optional per-cluster overrides of the folder-level defaults above. +# `tablename` is required. `if_exists`, `include`, `exclude`, and +# `column_types` are optional per-cluster overrides of the folder-level +# defaults above. Cluster-level column_types entries win over folder- +# level entries for the same column.
clusters: - pattern: '^group_a\d+\.xpt$' tablename: group_a + # column_types: + # INTCOL: TEXT # Example of an explicit override. Uncomment to force the group_b cluster to # append instead of replace even though the folder default is "replace":