diff --git a/utils/sas_profiler.py b/utils/sas_profiler.py index 88456c2..2d70964 100644 --- a/utils/sas_profiler.py +++ b/utils/sas_profiler.py @@ -43,6 +43,7 @@ from typing import Any, Dict, List, Optional, Tuple _REPO_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(_REPO_ROOT / "generic_loader")) +import numpy as np # noqa: E402 import pandas as pd # noqa: E402 from openpyxl import Workbook # noqa: E402 from openpyxl.styles import Alignment, Font, PatternFill # noqa: E402 @@ -51,7 +52,6 @@ from openpyxl.utils import get_column_letter # noqa: E402 from load_sas import ( # noqa: E402 NUMERIC_INT_RANGE, ColumnSpec, - _char_missing_mask, infer_schema, iter_sas_chunks, read_sas_preview, @@ -96,6 +96,25 @@ PREVIEW_ROWS_FOR_INFERENCE: int = 10_000 """Rows pulled from the file for the loader's schema inference. Matches ``load_sas.TYPE_INFERENCE_SAMPLE_ROWS`` so suggestions track the loader.""" +PROFILE_CHUNK_ROWS: int = 5_000_000 +"""Rows per streaming chunk while profiling. Larger chunks amortize +pyreadstat / pandas overhead, and the profiler is typically run on a +beefy box (e.g. a 128 GB EC2) rather than a laptop, so the default is +set aggressively. + +Rough peak-memory estimate while a chunk is in flight: + + peak_bytes ~= chunksize * num_cols * ~50 bytes/cell * 2-3x + +(The 2-3x factor covers pyreadstat's read buffer + pandas frame +construction temporaries.) At 5M rows x 50 cols that's roughly 10-20 GB, +which is comfortable on a 128 GB host but would OOM a laptop. + +If you have lots of RAM and a very wide file, lower this; if you have a +narrow file and want max throughput, bump it higher with ``--chunksize`` +(the profiler will happily take 20M+ per chunk). If ``chunksize`` is +larger than the file, pyreadstat just hands back one chunk.""" + PARTITION_NAME_PATTERNS: Tuple[re.Pattern, ...] = ( re.compile(r"^state$", re.IGNORECASE), @@ -153,41 +172,60 @@ class _ColumnStats: samples: List[Any] = field(default_factory=list) def update(self, series: pd.Series) -> None: - """Fold one chunk's worth of this column into the accumulator.""" - self.n_total += len(series) - if len(series) == 0: - return + """Fold one chunk's worth of this column into the accumulator. - if pd.api.types.is_object_dtype(series): - miss_mask = _char_missing_mask(series) + Implementation notes (this method is the dominant per-file cost): + + - All masks are vectorized - no ``Series.map(lambda ...)`` loops. + - Distinct tracking uses ``Series.value_counts`` so we iterate at + most once per *unique* value in the chunk (in C), not once per + row. + - Once ``distinct_overflow`` is set and ``top_counts`` is full, + subsequent chunks skip the value-counts pass entirely - we already + know the column is too varied to be a partition / drop candidate + and we already have the top-N. + """ + n = len(series) + if n == 0: + return + self.n_total += n + + is_object = pd.api.types.is_object_dtype(series) + is_numeric = pd.api.types.is_numeric_dtype(series) + is_datetime = pd.api.types.is_datetime64_any_dtype(series) + + if is_object: + # Vectorized equivalent of load_sas._char_missing_mask: treat + # None / NaN / empty string as missing. ``series == ""`` is + # False for non-string values so we don't need per-element type + # checks. + na_mask = series.isna() + empty_mask = (series == "") & ~na_mask + miss_mask = na_mask | empty_mask + self.n_empty_str += int(empty_mask.sum()) else: miss_mask = series.isna() - miss_count = int(miss_mask.sum()) - self.n_null += miss_count + self.n_null += int(miss_mask.sum()) + non_null = series[~miss_mask] if miss_mask.any() else series + if non_null.empty: + return - non_null = series[~miss_mask] - - if pd.api.types.is_object_dtype(series): - # Empty-string tracking is useful for TEXT columns where the loader - # later translates "" -> NULL in the COPY step. A column dominated - # by empty strings is still effectively null even if it isn't NaN. - empty_mask = series.map(lambda v: isinstance(v, str) and v == "") - self.n_empty_str += int(empty_mask.sum()) - - if pd.api.types.is_numeric_dtype(series) and not non_null.empty: - as_float = non_null.astype("float64") - self.numeric_sum += float(as_float.sum()) - self.numeric_sumsq += float((as_float * as_float).sum()) - self.numeric_count += int(len(as_float)) - cmin = as_float.min() - cmax = as_float.max() - if self.min_val is None or cmin < self.min_val: + # -- Numeric stats (C-level) --------------------------------------- + if is_numeric: + arr = non_null.to_numpy(dtype="float64", copy=False, na_value=np.nan) + # NaN-safe aggregates in one pass each (all C-level). + self.numeric_sum += float(np.nansum(arr)) + self.numeric_sumsq += float(np.nansum(arr * arr)) + self.numeric_count += int(arr.size) + cmin = float(np.nanmin(arr)) if arr.size else None + cmax = float(np.nanmax(arr)) if arr.size else None + if cmin is not None and (self.min_val is None or cmin < self.min_val): self.min_val = cmin - if self.max_val is None or cmax > self.max_val: + if cmax is not None and (self.max_val is None or cmax > self.max_val): self.max_val = cmax - elif pd.api.types.is_datetime64_any_dtype(series) and not non_null.empty: + elif is_datetime: cmin = non_null.min() cmax = non_null.max() if self.min_val is None or cmin < self.min_val: @@ -195,36 +233,68 @@ class _ColumnStats: if self.max_val is None or cmax > self.max_val: self.max_val = cmax - if pd.api.types.is_object_dtype(series) and not non_null.empty: - str_like = non_null.map(lambda v: v if isinstance(v, str) else str(v)) - byte_lens = str_like.map(lambda s: len(s.encode("utf-8", errors="replace"))) - if len(byte_lens): - bmax = int(byte_lens.max()) + # -- String length stats via vectorized str.len -------------------- + # ``.str.len()`` is C-fast; for ASCII-dominated SAS data it matches + # UTF-8 byte length closely enough for the "oversized TEXT" flag. + if is_object: + lens = non_null.astype(str, copy=False).str.len() + lens = lens.dropna() + if not lens.empty: + bmax = int(lens.max()) if bmax > self.str_max_bytes: self.str_max_bytes = bmax - self.str_sum_bytes += int(byte_lens.sum()) - self.str_count += int(len(byte_lens)) + self.str_sum_bytes += int(lens.sum()) + self.str_count += int(lens.size) - for val in non_null.tolist(): - hashable = _hashable(val) - if hashable is _UNHASHABLE: - # Give up on distinct/top-counts for this column; it's some - # exotic (e.g. list) value we can't hash, and the drop/index - # suggestions wouldn't be meaningful anyway. - self.distinct_overflow = True - continue - if not self.distinct_overflow: - if hashable in self.distinct: - pass - elif len(self.distinct) >= DISTINCT_CAP: + # -- Samples (tiny slice; free) ------------------------------------ + if len(self.samples) < 3: + needed = 3 - len(self.samples) + self.samples.extend(non_null.head(needed).tolist()) + + # -- Distinct / top_counts (vectorized via value_counts) ----------- + # Skip altogether once we're saturated: distinct is already known + # to be > DISTINCT_CAP and top_counts has its DISTINCT_CAP slots + # filled, so further value_counts calls can only bump existing + # keys - info we don't need for any of the classifiers. + top_full = len(self.top_counts) >= DISTINCT_CAP + if self.distinct_overflow and top_full: + return + + try: + vc = non_null.value_counts(sort=False) + except TypeError: + # Unhashable values (list/dict). Drop the column from both + # distinct and top-N tracking. + self.distinct_overflow = True + return + + if vc.empty: + return + + if not self.distinct_overflow: + # Only *new* values need to be considered for the distinct set. + for val in vc.index: + if val in self.distinct: + continue + if len(self.distinct) >= DISTINCT_CAP: self.distinct_overflow = True - else: - self.distinct.add(hashable) - if len(self.top_counts) < DISTINCT_CAP or hashable in self.top_counts: - self.top_counts[hashable] += 1 + break + self.distinct.add(val) - if len(self.samples) < 3: - self.samples.append(val) + if not top_full: + # Bulk-merge known keys; cap adds for new keys. + for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()): + if val in self.top_counts: + self.top_counts[val] += int(count) + elif len(self.top_counts) < DISTINCT_CAP: + self.top_counts[val] = int(count) + # else: silently skip - we're past the cap. + else: + # Only existing keys can grow. + tc = self.top_counts + for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()): + if val in tc: + tc[val] += int(count) # -- Derived properties ------------------------------------------------ @@ -279,27 +349,6 @@ class _ColumnStats: return self.top_counts.most_common(n) -class _UnhashableSentinel: - pass - - -_UNHASHABLE = _UnhashableSentinel() - - -def _hashable(val: Any) -> Any: - """Return a hashable form of ``val``, or :data:`_UNHASHABLE` if we can't. - - pandas occasionally hands us objects (lists, dicts) from object columns - that aren't hashable. Rather than crashing the whole report, we let the - column fall back to "distinct_overflow" mode for those rows. - """ - try: - hash(val) - return val - except TypeError: - return _UNHASHABLE - - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -366,16 +415,40 @@ def profile_file( } total_rows = 0 - kwargs = {} - if chunksize is not None: - kwargs["chunksize"] = chunksize + effective_chunksize = chunksize if chunksize is not None else PROFILE_CHUNK_ROWS + kwargs = {"chunksize": effective_chunksize} + # pyreadstat + pandas are both C-level; the per-chunk overhead we pay + # is dominated by the value_counts passes in _ColumnStats.update, so + # the profile runs O(total_rows) with a small constant. + import time + started_at = time.monotonic() + last_print_at = started_at for chunk_df, _chunk_meta in iter_sas_chunks(path, **kwargs): total_rows += len(chunk_df) - print(f" profiling... {total_rows:,} rows", file=sys.stderr) for name, cs in stats.items(): if name not in chunk_df.columns: continue cs.update(chunk_df[name]) + now = time.monotonic() + # Throttle progress output to ~one line per 2 seconds so huge files + # don't spam stderr but small files still print at least once. + if now - last_print_at >= 2.0: + elapsed = now - started_at + rate = total_rows / elapsed if elapsed > 0 else 0.0 + print( + f" profiling... {total_rows:,} rows " + f"({rate:,.0f} rows/s)", + file=sys.stderr, + ) + last_print_at = now + + elapsed = time.monotonic() - started_at + rate = total_rows / elapsed if elapsed > 0 else 0.0 + print( + f" profiled {total_rows:,} rows in {elapsed:.1f}s " + f"({rate:,.0f} rows/s)", + file=sys.stderr, + ) return stats, columns, meta, total_rows @@ -985,6 +1058,14 @@ def _build_argparser() -> argparse.ArgumentParser: help="Uniqueness (distinct/non-null) at/above which a column is an index candidate.") p.add_argument("--partition-min-fill-pct", type=float, default=PARTITION_MIN_FILL_PCT) p.add_argument("--pre-sharded-max-distinct", type=int, default=PRE_SHARDED_MAX_DISTINCT) + p.add_argument( + "--chunksize", type=int, default=None, + help=( + "Rows per streaming read. Bigger chunks amortize pyreadstat / " + "pandas overhead (faster for huge files) but use more peak " + f"memory. Defaults to PROFILE_CHUNK_ROWS ({PROFILE_CHUNK_ROWS:,})." + ), + ) return p @@ -999,7 +1080,7 @@ def main(argv: Optional[List[str]] = None) -> int: return 2 print(f"profiling {path} -> {out_path}", file=sys.stderr) - stats, columns, meta, total_rows = profile_file(path) + stats, columns, meta, total_rows = profile_file(path, chunksize=args.chunksize) drops, partitions, indexes, warnings = classify( stats, columns,