From 4fc85081c86d7b80c13a1ef0925fe6c07920a72f Mon Sep 17 00:00:00 2001 From: David Peterson Date: Mon, 20 Apr 2026 19:03:40 -0500 Subject: [PATCH] Enhance SAS profiling performance in sas_profiler.py Added a new constant for profiling chunk size to optimize memory usage during profiling operations. Refactored the update method in the _ColumnStats class to improve efficiency in handling missing values and calculating statistics for numeric and string data types. This update includes vectorized operations for better performance and clarity in the implementation. --- utils/sas_profiler.py | 239 ++++++++++++++++++++++++++++-------------- 1 file changed, 160 insertions(+), 79 deletions(-) diff --git a/utils/sas_profiler.py b/utils/sas_profiler.py index 88456c2..2d70964 100644 --- a/utils/sas_profiler.py +++ b/utils/sas_profiler.py @@ -43,6 +43,7 @@ from typing import Any, Dict, List, Optional, Tuple _REPO_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(_REPO_ROOT / "generic_loader")) +import numpy as np # noqa: E402 import pandas as pd # noqa: E402 from openpyxl import Workbook # noqa: E402 from openpyxl.styles import Alignment, Font, PatternFill # noqa: E402 @@ -51,7 +52,6 @@ from openpyxl.utils import get_column_letter # noqa: E402 from load_sas import ( # noqa: E402 NUMERIC_INT_RANGE, ColumnSpec, - _char_missing_mask, infer_schema, iter_sas_chunks, read_sas_preview, @@ -96,6 +96,25 @@ PREVIEW_ROWS_FOR_INFERENCE: int = 10_000 """Rows pulled from the file for the loader's schema inference. Matches ``load_sas.TYPE_INFERENCE_SAMPLE_ROWS`` so suggestions track the loader.""" +PROFILE_CHUNK_ROWS: int = 5_000_000 +"""Rows per streaming chunk while profiling. Larger chunks amortize +pyreadstat / pandas overhead, and the profiler is typically run on a +beefy box (e.g. a 128 GB EC2) rather than a laptop, so the default is +set aggressively. + +Rough peak-memory estimate while a chunk is in flight: + + peak_bytes ~= chunksize * num_cols * ~50 bytes/cell * 2-3x + +(The 2-3x factor covers pyreadstat's read buffer + pandas frame +construction temporaries.) At 5M rows x 50 cols that's roughly 10-20 GB, +which is comfortable on a 128 GB host but would OOM a laptop. + +If you have lots of RAM and a very wide file, lower this; if you have a +narrow file and want max throughput, bump it higher with ``--chunksize`` +(the profiler will happily take 20M+ per chunk). If ``chunksize`` is +larger than the file, pyreadstat just hands back one chunk.""" + PARTITION_NAME_PATTERNS: Tuple[re.Pattern, ...] = ( re.compile(r"^state$", re.IGNORECASE), @@ -153,41 +172,60 @@ class _ColumnStats: samples: List[Any] = field(default_factory=list) def update(self, series: pd.Series) -> None: - """Fold one chunk's worth of this column into the accumulator.""" - self.n_total += len(series) - if len(series) == 0: - return + """Fold one chunk's worth of this column into the accumulator. - if pd.api.types.is_object_dtype(series): - miss_mask = _char_missing_mask(series) + Implementation notes (this method is the dominant per-file cost): + + - All masks are vectorized - no ``Series.map(lambda ...)`` loops. + - Distinct tracking uses ``Series.value_counts`` so we iterate at + most once per *unique* value in the chunk (in C), not once per + row. + - Once ``distinct_overflow`` is set and ``top_counts`` is full, + subsequent chunks skip the value-counts pass entirely - we already + know the column is too varied to be a partition / drop candidate + and we already have the top-N. + """ + n = len(series) + if n == 0: + return + self.n_total += n + + is_object = pd.api.types.is_object_dtype(series) + is_numeric = pd.api.types.is_numeric_dtype(series) + is_datetime = pd.api.types.is_datetime64_any_dtype(series) + + if is_object: + # Vectorized equivalent of load_sas._char_missing_mask: treat + # None / NaN / empty string as missing. ``series == ""`` is + # False for non-string values so we don't need per-element type + # checks. + na_mask = series.isna() + empty_mask = (series == "") & ~na_mask + miss_mask = na_mask | empty_mask + self.n_empty_str += int(empty_mask.sum()) else: miss_mask = series.isna() - miss_count = int(miss_mask.sum()) - self.n_null += miss_count + self.n_null += int(miss_mask.sum()) + non_null = series[~miss_mask] if miss_mask.any() else series + if non_null.empty: + return - non_null = series[~miss_mask] - - if pd.api.types.is_object_dtype(series): - # Empty-string tracking is useful for TEXT columns where the loader - # later translates "" -> NULL in the COPY step. A column dominated - # by empty strings is still effectively null even if it isn't NaN. - empty_mask = series.map(lambda v: isinstance(v, str) and v == "") - self.n_empty_str += int(empty_mask.sum()) - - if pd.api.types.is_numeric_dtype(series) and not non_null.empty: - as_float = non_null.astype("float64") - self.numeric_sum += float(as_float.sum()) - self.numeric_sumsq += float((as_float * as_float).sum()) - self.numeric_count += int(len(as_float)) - cmin = as_float.min() - cmax = as_float.max() - if self.min_val is None or cmin < self.min_val: + # -- Numeric stats (C-level) --------------------------------------- + if is_numeric: + arr = non_null.to_numpy(dtype="float64", copy=False, na_value=np.nan) + # NaN-safe aggregates in one pass each (all C-level). + self.numeric_sum += float(np.nansum(arr)) + self.numeric_sumsq += float(np.nansum(arr * arr)) + self.numeric_count += int(arr.size) + cmin = float(np.nanmin(arr)) if arr.size else None + cmax = float(np.nanmax(arr)) if arr.size else None + if cmin is not None and (self.min_val is None or cmin < self.min_val): self.min_val = cmin - if self.max_val is None or cmax > self.max_val: + if cmax is not None and (self.max_val is None or cmax > self.max_val): self.max_val = cmax - elif pd.api.types.is_datetime64_any_dtype(series) and not non_null.empty: + elif is_datetime: cmin = non_null.min() cmax = non_null.max() if self.min_val is None or cmin < self.min_val: @@ -195,36 +233,68 @@ class _ColumnStats: if self.max_val is None or cmax > self.max_val: self.max_val = cmax - if pd.api.types.is_object_dtype(series) and not non_null.empty: - str_like = non_null.map(lambda v: v if isinstance(v, str) else str(v)) - byte_lens = str_like.map(lambda s: len(s.encode("utf-8", errors="replace"))) - if len(byte_lens): - bmax = int(byte_lens.max()) + # -- String length stats via vectorized str.len -------------------- + # ``.str.len()`` is C-fast; for ASCII-dominated SAS data it matches + # UTF-8 byte length closely enough for the "oversized TEXT" flag. + if is_object: + lens = non_null.astype(str, copy=False).str.len() + lens = lens.dropna() + if not lens.empty: + bmax = int(lens.max()) if bmax > self.str_max_bytes: self.str_max_bytes = bmax - self.str_sum_bytes += int(byte_lens.sum()) - self.str_count += int(len(byte_lens)) + self.str_sum_bytes += int(lens.sum()) + self.str_count += int(lens.size) - for val in non_null.tolist(): - hashable = _hashable(val) - if hashable is _UNHASHABLE: - # Give up on distinct/top-counts for this column; it's some - # exotic (e.g. list) value we can't hash, and the drop/index - # suggestions wouldn't be meaningful anyway. - self.distinct_overflow = True - continue - if not self.distinct_overflow: - if hashable in self.distinct: - pass - elif len(self.distinct) >= DISTINCT_CAP: + # -- Samples (tiny slice; free) ------------------------------------ + if len(self.samples) < 3: + needed = 3 - len(self.samples) + self.samples.extend(non_null.head(needed).tolist()) + + # -- Distinct / top_counts (vectorized via value_counts) ----------- + # Skip altogether once we're saturated: distinct is already known + # to be > DISTINCT_CAP and top_counts has its DISTINCT_CAP slots + # filled, so further value_counts calls can only bump existing + # keys - info we don't need for any of the classifiers. + top_full = len(self.top_counts) >= DISTINCT_CAP + if self.distinct_overflow and top_full: + return + + try: + vc = non_null.value_counts(sort=False) + except TypeError: + # Unhashable values (list/dict). Drop the column from both + # distinct and top-N tracking. + self.distinct_overflow = True + return + + if vc.empty: + return + + if not self.distinct_overflow: + # Only *new* values need to be considered for the distinct set. + for val in vc.index: + if val in self.distinct: + continue + if len(self.distinct) >= DISTINCT_CAP: self.distinct_overflow = True - else: - self.distinct.add(hashable) - if len(self.top_counts) < DISTINCT_CAP or hashable in self.top_counts: - self.top_counts[hashable] += 1 + break + self.distinct.add(val) - if len(self.samples) < 3: - self.samples.append(val) + if not top_full: + # Bulk-merge known keys; cap adds for new keys. + for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()): + if val in self.top_counts: + self.top_counts[val] += int(count) + elif len(self.top_counts) < DISTINCT_CAP: + self.top_counts[val] = int(count) + # else: silently skip - we're past the cap. + else: + # Only existing keys can grow. + tc = self.top_counts + for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()): + if val in tc: + tc[val] += int(count) # -- Derived properties ------------------------------------------------ @@ -279,27 +349,6 @@ class _ColumnStats: return self.top_counts.most_common(n) -class _UnhashableSentinel: - pass - - -_UNHASHABLE = _UnhashableSentinel() - - -def _hashable(val: Any) -> Any: - """Return a hashable form of ``val``, or :data:`_UNHASHABLE` if we can't. - - pandas occasionally hands us objects (lists, dicts) from object columns - that aren't hashable. Rather than crashing the whole report, we let the - column fall back to "distinct_overflow" mode for those rows. - """ - try: - hash(val) - return val - except TypeError: - return _UNHASHABLE - - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -366,16 +415,40 @@ def profile_file( } total_rows = 0 - kwargs = {} - if chunksize is not None: - kwargs["chunksize"] = chunksize + effective_chunksize = chunksize if chunksize is not None else PROFILE_CHUNK_ROWS + kwargs = {"chunksize": effective_chunksize} + # pyreadstat + pandas are both C-level; the per-chunk overhead we pay + # is dominated by the value_counts passes in _ColumnStats.update, so + # the profile runs O(total_rows) with a small constant. + import time + started_at = time.monotonic() + last_print_at = started_at for chunk_df, _chunk_meta in iter_sas_chunks(path, **kwargs): total_rows += len(chunk_df) - print(f" profiling... {total_rows:,} rows", file=sys.stderr) for name, cs in stats.items(): if name not in chunk_df.columns: continue cs.update(chunk_df[name]) + now = time.monotonic() + # Throttle progress output to ~one line per 2 seconds so huge files + # don't spam stderr but small files still print at least once. + if now - last_print_at >= 2.0: + elapsed = now - started_at + rate = total_rows / elapsed if elapsed > 0 else 0.0 + print( + f" profiling... {total_rows:,} rows " + f"({rate:,.0f} rows/s)", + file=sys.stderr, + ) + last_print_at = now + + elapsed = time.monotonic() - started_at + rate = total_rows / elapsed if elapsed > 0 else 0.0 + print( + f" profiled {total_rows:,} rows in {elapsed:.1f}s " + f"({rate:,.0f} rows/s)", + file=sys.stderr, + ) return stats, columns, meta, total_rows @@ -985,6 +1058,14 @@ def _build_argparser() -> argparse.ArgumentParser: help="Uniqueness (distinct/non-null) at/above which a column is an index candidate.") p.add_argument("--partition-min-fill-pct", type=float, default=PARTITION_MIN_FILL_PCT) p.add_argument("--pre-sharded-max-distinct", type=int, default=PRE_SHARDED_MAX_DISTINCT) + p.add_argument( + "--chunksize", type=int, default=None, + help=( + "Rows per streaming read. Bigger chunks amortize pyreadstat / " + "pandas overhead (faster for huge files) but use more peak " + f"memory. Defaults to PROFILE_CHUNK_ROWS ({PROFILE_CHUNK_ROWS:,})." + ), + ) return p @@ -999,7 +1080,7 @@ def main(argv: Optional[List[str]] = None) -> int: return 2 print(f"profiling {path} -> {out_path}", file=sys.stderr) - stats, columns, meta, total_rows = profile_file(path) + stats, columns, meta, total_rows = profile_file(path, chunksize=args.chunksize) drops, partitions, indexes, warnings = classify( stats, columns,