Enhance SAS profiling performance in sas_profiler.py

Added a new constant for profiling chunk size to optimize memory usage during profiling operations. Refactored the update method in the _ColumnStats class to improve efficiency in handling missing values and calculating statistics for numeric and string data types. This update includes vectorized operations for better performance and clarity in the implementation.
This commit is contained in:
David Peterson 2026-04-20 19:03:40 -05:00
parent 5449a25b44
commit 4fc85081c8

View File

@ -43,6 +43,7 @@ from typing import Any, Dict, List, Optional, Tuple
_REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_REPO_ROOT / "generic_loader"))
import numpy as np # noqa: E402
import pandas as pd # noqa: E402
from openpyxl import Workbook # noqa: E402
from openpyxl.styles import Alignment, Font, PatternFill # noqa: E402
@ -51,7 +52,6 @@ from openpyxl.utils import get_column_letter # noqa: E402
from load_sas import ( # noqa: E402
NUMERIC_INT_RANGE,
ColumnSpec,
_char_missing_mask,
infer_schema,
iter_sas_chunks,
read_sas_preview,
@ -96,6 +96,25 @@ PREVIEW_ROWS_FOR_INFERENCE: int = 10_000
"""Rows pulled from the file for the loader's schema inference. Matches
``load_sas.TYPE_INFERENCE_SAMPLE_ROWS`` so suggestions track the loader."""
PROFILE_CHUNK_ROWS: int = 5_000_000
"""Rows per streaming chunk while profiling. Larger chunks amortize
pyreadstat / pandas overhead, and the profiler is typically run on a
beefy box (e.g. a 128 GB EC2) rather than a laptop, so the default is
set aggressively.
Rough peak-memory estimate while a chunk is in flight:
peak_bytes ~= chunksize * num_cols * ~50 bytes/cell * 2-3x
(The 2-3x factor covers pyreadstat's read buffer + pandas frame
construction temporaries.) At 5M rows x 50 cols that's roughly 10-20 GB,
which is comfortable on a 128 GB host but would OOM a laptop.
If you have lots of RAM and a very wide file, lower this; if you have a
narrow file and want max throughput, bump it higher with ``--chunksize``
(the profiler will happily take 20M+ per chunk). If ``chunksize`` is
larger than the file, pyreadstat just hands back one chunk."""
PARTITION_NAME_PATTERNS: Tuple[re.Pattern, ...] = (
re.compile(r"^state$", re.IGNORECASE),
@ -153,41 +172,60 @@ class _ColumnStats:
samples: List[Any] = field(default_factory=list)
def update(self, series: pd.Series) -> None:
"""Fold one chunk's worth of this column into the accumulator."""
self.n_total += len(series)
if len(series) == 0:
return
"""Fold one chunk's worth of this column into the accumulator.
if pd.api.types.is_object_dtype(series):
miss_mask = _char_missing_mask(series)
Implementation notes (this method is the dominant per-file cost):
- All masks are vectorized - no ``Series.map(lambda ...)`` loops.
- Distinct tracking uses ``Series.value_counts`` so we iterate at
most once per *unique* value in the chunk (in C), not once per
row.
- Once ``distinct_overflow`` is set and ``top_counts`` is full,
subsequent chunks skip the value-counts pass entirely - we already
know the column is too varied to be a partition / drop candidate
and we already have the top-N.
"""
n = len(series)
if n == 0:
return
self.n_total += n
is_object = pd.api.types.is_object_dtype(series)
is_numeric = pd.api.types.is_numeric_dtype(series)
is_datetime = pd.api.types.is_datetime64_any_dtype(series)
if is_object:
# Vectorized equivalent of load_sas._char_missing_mask: treat
# None / NaN / empty string as missing. ``series == ""`` is
# False for non-string values so we don't need per-element type
# checks.
na_mask = series.isna()
empty_mask = (series == "") & ~na_mask
miss_mask = na_mask | empty_mask
self.n_empty_str += int(empty_mask.sum())
else:
miss_mask = series.isna()
miss_count = int(miss_mask.sum())
self.n_null += miss_count
self.n_null += int(miss_mask.sum())
non_null = series[~miss_mask] if miss_mask.any() else series
if non_null.empty:
return
non_null = series[~miss_mask]
if pd.api.types.is_object_dtype(series):
# Empty-string tracking is useful for TEXT columns where the loader
# later translates "" -> NULL in the COPY step. A column dominated
# by empty strings is still effectively null even if it isn't NaN.
empty_mask = series.map(lambda v: isinstance(v, str) and v == "")
self.n_empty_str += int(empty_mask.sum())
if pd.api.types.is_numeric_dtype(series) and not non_null.empty:
as_float = non_null.astype("float64")
self.numeric_sum += float(as_float.sum())
self.numeric_sumsq += float((as_float * as_float).sum())
self.numeric_count += int(len(as_float))
cmin = as_float.min()
cmax = as_float.max()
if self.min_val is None or cmin < self.min_val:
# -- Numeric stats (C-level) ---------------------------------------
if is_numeric:
arr = non_null.to_numpy(dtype="float64", copy=False, na_value=np.nan)
# NaN-safe aggregates in one pass each (all C-level).
self.numeric_sum += float(np.nansum(arr))
self.numeric_sumsq += float(np.nansum(arr * arr))
self.numeric_count += int(arr.size)
cmin = float(np.nanmin(arr)) if arr.size else None
cmax = float(np.nanmax(arr)) if arr.size else None
if cmin is not None and (self.min_val is None or cmin < self.min_val):
self.min_val = cmin
if self.max_val is None or cmax > self.max_val:
if cmax is not None and (self.max_val is None or cmax > self.max_val):
self.max_val = cmax
elif pd.api.types.is_datetime64_any_dtype(series) and not non_null.empty:
elif is_datetime:
cmin = non_null.min()
cmax = non_null.max()
if self.min_val is None or cmin < self.min_val:
@ -195,36 +233,68 @@ class _ColumnStats:
if self.max_val is None or cmax > self.max_val:
self.max_val = cmax
if pd.api.types.is_object_dtype(series) and not non_null.empty:
str_like = non_null.map(lambda v: v if isinstance(v, str) else str(v))
byte_lens = str_like.map(lambda s: len(s.encode("utf-8", errors="replace")))
if len(byte_lens):
bmax = int(byte_lens.max())
# -- String length stats via vectorized str.len --------------------
# ``.str.len()`` is C-fast; for ASCII-dominated SAS data it matches
# UTF-8 byte length closely enough for the "oversized TEXT" flag.
if is_object:
lens = non_null.astype(str, copy=False).str.len()
lens = lens.dropna()
if not lens.empty:
bmax = int(lens.max())
if bmax > self.str_max_bytes:
self.str_max_bytes = bmax
self.str_sum_bytes += int(byte_lens.sum())
self.str_count += int(len(byte_lens))
for val in non_null.tolist():
hashable = _hashable(val)
if hashable is _UNHASHABLE:
# Give up on distinct/top-counts for this column; it's some
# exotic (e.g. list) value we can't hash, and the drop/index
# suggestions wouldn't be meaningful anyway.
self.distinct_overflow = True
continue
if not self.distinct_overflow:
if hashable in self.distinct:
pass
elif len(self.distinct) >= DISTINCT_CAP:
self.distinct_overflow = True
else:
self.distinct.add(hashable)
if len(self.top_counts) < DISTINCT_CAP or hashable in self.top_counts:
self.top_counts[hashable] += 1
self.str_sum_bytes += int(lens.sum())
self.str_count += int(lens.size)
# -- Samples (tiny slice; free) ------------------------------------
if len(self.samples) < 3:
self.samples.append(val)
needed = 3 - len(self.samples)
self.samples.extend(non_null.head(needed).tolist())
# -- Distinct / top_counts (vectorized via value_counts) -----------
# Skip altogether once we're saturated: distinct is already known
# to be > DISTINCT_CAP and top_counts has its DISTINCT_CAP slots
# filled, so further value_counts calls can only bump existing
# keys - info we don't need for any of the classifiers.
top_full = len(self.top_counts) >= DISTINCT_CAP
if self.distinct_overflow and top_full:
return
try:
vc = non_null.value_counts(sort=False)
except TypeError:
# Unhashable values (list/dict). Drop the column from both
# distinct and top-N tracking.
self.distinct_overflow = True
return
if vc.empty:
return
if not self.distinct_overflow:
# Only *new* values need to be considered for the distinct set.
for val in vc.index:
if val in self.distinct:
continue
if len(self.distinct) >= DISTINCT_CAP:
self.distinct_overflow = True
break
self.distinct.add(val)
if not top_full:
# Bulk-merge known keys; cap adds for new keys.
for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()):
if val in self.top_counts:
self.top_counts[val] += int(count)
elif len(self.top_counts) < DISTINCT_CAP:
self.top_counts[val] = int(count)
# else: silently skip - we're past the cap.
else:
# Only existing keys can grow.
tc = self.top_counts
for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()):
if val in tc:
tc[val] += int(count)
# -- Derived properties ------------------------------------------------
@ -279,27 +349,6 @@ class _ColumnStats:
return self.top_counts.most_common(n)
class _UnhashableSentinel:
pass
_UNHASHABLE = _UnhashableSentinel()
def _hashable(val: Any) -> Any:
"""Return a hashable form of ``val``, or :data:`_UNHASHABLE` if we can't.
pandas occasionally hands us objects (lists, dicts) from object columns
that aren't hashable. Rather than crashing the whole report, we let the
column fall back to "distinct_overflow" mode for those rows.
"""
try:
hash(val)
return val
except TypeError:
return _UNHASHABLE
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@ -366,16 +415,40 @@ def profile_file(
}
total_rows = 0
kwargs = {}
if chunksize is not None:
kwargs["chunksize"] = chunksize
effective_chunksize = chunksize if chunksize is not None else PROFILE_CHUNK_ROWS
kwargs = {"chunksize": effective_chunksize}
# pyreadstat + pandas are both C-level; the per-chunk overhead we pay
# is dominated by the value_counts passes in _ColumnStats.update, so
# the profile runs O(total_rows) with a small constant.
import time
started_at = time.monotonic()
last_print_at = started_at
for chunk_df, _chunk_meta in iter_sas_chunks(path, **kwargs):
total_rows += len(chunk_df)
print(f" profiling... {total_rows:,} rows", file=sys.stderr)
for name, cs in stats.items():
if name not in chunk_df.columns:
continue
cs.update(chunk_df[name])
now = time.monotonic()
# Throttle progress output to ~one line per 2 seconds so huge files
# don't spam stderr but small files still print at least once.
if now - last_print_at >= 2.0:
elapsed = now - started_at
rate = total_rows / elapsed if elapsed > 0 else 0.0
print(
f" profiling... {total_rows:,} rows "
f"({rate:,.0f} rows/s)",
file=sys.stderr,
)
last_print_at = now
elapsed = time.monotonic() - started_at
rate = total_rows / elapsed if elapsed > 0 else 0.0
print(
f" profiled {total_rows:,} rows in {elapsed:.1f}s "
f"({rate:,.0f} rows/s)",
file=sys.stderr,
)
return stats, columns, meta, total_rows
@ -985,6 +1058,14 @@ def _build_argparser() -> argparse.ArgumentParser:
help="Uniqueness (distinct/non-null) at/above which a column is an index candidate.")
p.add_argument("--partition-min-fill-pct", type=float, default=PARTITION_MIN_FILL_PCT)
p.add_argument("--pre-sharded-max-distinct", type=int, default=PRE_SHARDED_MAX_DISTINCT)
p.add_argument(
"--chunksize", type=int, default=None,
help=(
"Rows per streaming read. Bigger chunks amortize pyreadstat / "
"pandas overhead (faster for huge files) but use more peak "
f"memory. Defaults to PROFILE_CHUNK_ROWS ({PROFILE_CHUNK_ROWS:,})."
),
)
return p
@ -999,7 +1080,7 @@ def main(argv: Optional[List[str]] = None) -> int:
return 2
print(f"profiling {path} -> {out_path}", file=sys.stderr)
stats, columns, meta, total_rows = profile_file(path)
stats, columns, meta, total_rows = profile_file(path, chunksize=args.chunksize)
drops, partitions, indexes, warnings = classify(
stats, columns,