Enhance SAS profiling performance in sas_profiler.py

Added a PROFILE_CHUNK_ROWS constant (overridable with a new --chunksize option) that sets the streaming chunk size used while profiling, trading peak memory for throughput. Refactored the update method of the _ColumnStats class to handle missing values and compute numeric and string statistics more efficiently, replacing per-row Python loops with vectorized operations for better performance and clarity. Progress output is now throttled to roughly one line every two seconds, followed by a final summary line.
This commit is contained in:
David Peterson 2026-04-20 19:03:40 -05:00
parent 5449a25b44
commit 4fc85081c8

View File

@ -43,6 +43,7 @@ from typing import Any, Dict, List, Optional, Tuple
_REPO_ROOT = Path(__file__).resolve().parent.parent _REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_REPO_ROOT / "generic_loader")) sys.path.insert(0, str(_REPO_ROOT / "generic_loader"))
import numpy as np # noqa: E402
import pandas as pd # noqa: E402 import pandas as pd # noqa: E402
from openpyxl import Workbook # noqa: E402 from openpyxl import Workbook # noqa: E402
from openpyxl.styles import Alignment, Font, PatternFill # noqa: E402 from openpyxl.styles import Alignment, Font, PatternFill # noqa: E402
@ -51,7 +52,6 @@ from openpyxl.utils import get_column_letter # noqa: E402
from load_sas import ( # noqa: E402 from load_sas import ( # noqa: E402
NUMERIC_INT_RANGE, NUMERIC_INT_RANGE,
ColumnSpec, ColumnSpec,
_char_missing_mask,
infer_schema, infer_schema,
iter_sas_chunks, iter_sas_chunks,
read_sas_preview, read_sas_preview,
@ -96,6 +96,25 @@ PREVIEW_ROWS_FOR_INFERENCE: int = 10_000
"""Rows pulled from the file for the loader's schema inference. Matches """Rows pulled from the file for the loader's schema inference. Matches
``load_sas.TYPE_INFERENCE_SAMPLE_ROWS`` so suggestions track the loader.""" ``load_sas.TYPE_INFERENCE_SAMPLE_ROWS`` so suggestions track the loader."""
PROFILE_CHUNK_ROWS: int = 5_000_000
"""Rows per streaming chunk while profiling. Larger chunks amortize
pyreadstat / pandas overhead, and the profiler is typically run on a
beefy box (e.g. a 128 GB EC2) rather than a laptop, so the default is
set aggressively.
Rough peak-memory estimate while a chunk is in flight:
peak_bytes ~= chunksize * num_cols * ~50 bytes/cell * 2-3x
(The 2-3x factor covers pyreadstat's read buffer + pandas frame
construction temporaries.) At 5M rows x 50 cols that's roughly 25-40 GB,
which is comfortable on a 128 GB host but would OOM a laptop.
If you have limited RAM or a very wide file, lower this; if you have a
narrow file and want max throughput, bump it higher with ``--chunksize``
(the profiler will happily take 20M+ per chunk). If ``chunksize`` is
larger than the file, pyreadstat just hands back one chunk."""
PARTITION_NAME_PATTERNS: Tuple[re.Pattern, ...] = ( PARTITION_NAME_PATTERNS: Tuple[re.Pattern, ...] = (
re.compile(r"^state$", re.IGNORECASE), re.compile(r"^state$", re.IGNORECASE),
@ -153,41 +172,60 @@ class _ColumnStats:
samples: List[Any] = field(default_factory=list) samples: List[Any] = field(default_factory=list)
def update(self, series: pd.Series) -> None: def update(self, series: pd.Series) -> None:
"""Fold one chunk's worth of this column into the accumulator.""" """Fold one chunk's worth of this column into the accumulator.
self.n_total += len(series)
if len(series) == 0:
return
if pd.api.types.is_object_dtype(series): Implementation notes (this method is the dominant per-file cost):
miss_mask = _char_missing_mask(series)
- All masks are vectorized - no ``Series.map(lambda ...)`` loops.
- Distinct tracking uses ``Series.value_counts`` so we iterate at
most once per *unique* value in the chunk (in C), not once per
row.
- Once ``distinct_overflow`` is set and ``top_counts`` is full,
subsequent chunks skip the value-counts pass entirely - we already
know the column is too varied to be a partition / drop candidate
and we already have the top-N.
"""
n = len(series)
if n == 0:
return
self.n_total += n
is_object = pd.api.types.is_object_dtype(series)
is_numeric = pd.api.types.is_numeric_dtype(series)
is_datetime = pd.api.types.is_datetime64_any_dtype(series)
if is_object:
# Vectorized equivalent of load_sas._char_missing_mask: treat
# None / NaN / empty string as missing. ``series == ""`` is
# False for non-string values so we don't need per-element type
# checks.
na_mask = series.isna()
empty_mask = (series == "") & ~na_mask
miss_mask = na_mask | empty_mask
self.n_empty_str += int(empty_mask.sum())
else: else:
miss_mask = series.isna() miss_mask = series.isna()
miss_count = int(miss_mask.sum()) self.n_null += int(miss_mask.sum())
self.n_null += miss_count non_null = series[~miss_mask] if miss_mask.any() else series
if non_null.empty:
return
non_null = series[~miss_mask] # -- Numeric stats (C-level) ---------------------------------------
if is_numeric:
if pd.api.types.is_object_dtype(series): arr = non_null.to_numpy(dtype="float64", copy=False, na_value=np.nan)
# Empty-string tracking is useful for TEXT columns where the loader # NaN-safe aggregates in one pass each (all C-level).
# later translates "" -> NULL in the COPY step. A column dominated self.numeric_sum += float(np.nansum(arr))
# by empty strings is still effectively null even if it isn't NaN. self.numeric_sumsq += float(np.nansum(arr * arr))
empty_mask = series.map(lambda v: isinstance(v, str) and v == "") self.numeric_count += int(arr.size)
self.n_empty_str += int(empty_mask.sum()) cmin = float(np.nanmin(arr)) if arr.size else None
cmax = float(np.nanmax(arr)) if arr.size else None
if pd.api.types.is_numeric_dtype(series) and not non_null.empty: if cmin is not None and (self.min_val is None or cmin < self.min_val):
as_float = non_null.astype("float64")
self.numeric_sum += float(as_float.sum())
self.numeric_sumsq += float((as_float * as_float).sum())
self.numeric_count += int(len(as_float))
cmin = as_float.min()
cmax = as_float.max()
if self.min_val is None or cmin < self.min_val:
self.min_val = cmin self.min_val = cmin
if self.max_val is None or cmax > self.max_val: if cmax is not None and (self.max_val is None or cmax > self.max_val):
self.max_val = cmax self.max_val = cmax
elif pd.api.types.is_datetime64_any_dtype(series) and not non_null.empty: elif is_datetime:
cmin = non_null.min() cmin = non_null.min()
cmax = non_null.max() cmax = non_null.max()
if self.min_val is None or cmin < self.min_val: if self.min_val is None or cmin < self.min_val:
@ -195,36 +233,68 @@ class _ColumnStats:
if self.max_val is None or cmax > self.max_val: if self.max_val is None or cmax > self.max_val:
self.max_val = cmax self.max_val = cmax
if pd.api.types.is_object_dtype(series) and not non_null.empty: # -- String length stats via vectorized str.len --------------------
str_like = non_null.map(lambda v: v if isinstance(v, str) else str(v)) # ``.str.len()`` is C-fast; for ASCII-dominated SAS data it matches
byte_lens = str_like.map(lambda s: len(s.encode("utf-8", errors="replace"))) # UTF-8 byte length closely enough for the "oversized TEXT" flag.
if len(byte_lens): if is_object:
bmax = int(byte_lens.max()) lens = non_null.astype(str, copy=False).str.len()
lens = lens.dropna()
if not lens.empty:
bmax = int(lens.max())
if bmax > self.str_max_bytes: if bmax > self.str_max_bytes:
self.str_max_bytes = bmax self.str_max_bytes = bmax
self.str_sum_bytes += int(byte_lens.sum()) self.str_sum_bytes += int(lens.sum())
self.str_count += int(len(byte_lens)) self.str_count += int(lens.size)
for val in non_null.tolist(): # -- Samples (tiny slice; free) ------------------------------------
hashable = _hashable(val) if len(self.samples) < 3:
if hashable is _UNHASHABLE: needed = 3 - len(self.samples)
# Give up on distinct/top-counts for this column; it's some self.samples.extend(non_null.head(needed).tolist())
# exotic (e.g. list) value we can't hash, and the drop/index
# suggestions wouldn't be meaningful anyway. # -- Distinct / top_counts (vectorized via value_counts) -----------
self.distinct_overflow = True # Skip altogether once we're saturated: distinct is already known
continue # to be > DISTINCT_CAP and top_counts has its DISTINCT_CAP slots
if not self.distinct_overflow: # filled, so further value_counts calls can only bump existing
if hashable in self.distinct: # keys - info we don't need for any of the classifiers.
pass top_full = len(self.top_counts) >= DISTINCT_CAP
elif len(self.distinct) >= DISTINCT_CAP: if self.distinct_overflow and top_full:
return
try:
vc = non_null.value_counts(sort=False)
except TypeError:
# Unhashable values (list/dict). Drop the column from both
# distinct and top-N tracking.
self.distinct_overflow = True
return
if vc.empty:
return
if not self.distinct_overflow:
# Only *new* values need to be considered for the distinct set.
for val in vc.index:
if val in self.distinct:
continue
if len(self.distinct) >= DISTINCT_CAP:
self.distinct_overflow = True self.distinct_overflow = True
else: break
self.distinct.add(hashable) self.distinct.add(val)
if len(self.top_counts) < DISTINCT_CAP or hashable in self.top_counts:
self.top_counts[hashable] += 1
if len(self.samples) < 3: if not top_full:
self.samples.append(val) # Bulk-merge known keys; cap adds for new keys.
for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()):
if val in self.top_counts:
self.top_counts[val] += int(count)
elif len(self.top_counts) < DISTINCT_CAP:
self.top_counts[val] = int(count)
# else: silently skip - we're past the cap.
else:
# Only existing keys can grow.
tc = self.top_counts
for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()):
if val in tc:
tc[val] += int(count)
# -- Derived properties ------------------------------------------------ # -- Derived properties ------------------------------------------------
@ -279,27 +349,6 @@ class _ColumnStats:
return self.top_counts.most_common(n) return self.top_counts.most_common(n)
class _UnhashableSentinel:
pass
_UNHASHABLE = _UnhashableSentinel()
def _hashable(val: Any) -> Any:
"""Return a hashable form of ``val``, or :data:`_UNHASHABLE` if we can't.
pandas occasionally hands us objects (lists, dicts) from object columns
that aren't hashable. Rather than crashing the whole report, we let the
column fall back to "distinct_overflow" mode for those rows.
"""
try:
hash(val)
return val
except TypeError:
return _UNHASHABLE
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -366,16 +415,40 @@ def profile_file(
} }
total_rows = 0 total_rows = 0
kwargs = {} effective_chunksize = chunksize if chunksize is not None else PROFILE_CHUNK_ROWS
if chunksize is not None: kwargs = {"chunksize": effective_chunksize}
kwargs["chunksize"] = chunksize # pyreadstat + pandas are both C-level; the per-chunk overhead we pay
# is dominated by the value_counts passes in _ColumnStats.update, so
# the profile runs O(total_rows) with a small constant.
import time
started_at = time.monotonic()
last_print_at = started_at
for chunk_df, _chunk_meta in iter_sas_chunks(path, **kwargs): for chunk_df, _chunk_meta in iter_sas_chunks(path, **kwargs):
total_rows += len(chunk_df) total_rows += len(chunk_df)
print(f" profiling... {total_rows:,} rows", file=sys.stderr)
for name, cs in stats.items(): for name, cs in stats.items():
if name not in chunk_df.columns: if name not in chunk_df.columns:
continue continue
cs.update(chunk_df[name]) cs.update(chunk_df[name])
now = time.monotonic()
# Throttle progress output to ~one line per 2 seconds so huge files
# don't spam stderr; runs that finish sooner are still covered by the
# summary line printed after the loop.
if now - last_print_at >= 2.0:
elapsed = now - started_at
rate = total_rows / elapsed if elapsed > 0 else 0.0
print(
f" profiling... {total_rows:,} rows "
f"({rate:,.0f} rows/s)",
file=sys.stderr,
)
last_print_at = now
elapsed = time.monotonic() - started_at
rate = total_rows / elapsed if elapsed > 0 else 0.0
print(
f" profiled {total_rows:,} rows in {elapsed:.1f}s "
f"({rate:,.0f} rows/s)",
file=sys.stderr,
)
return stats, columns, meta, total_rows return stats, columns, meta, total_rows
@ -985,6 +1058,14 @@ def _build_argparser() -> argparse.ArgumentParser:
help="Uniqueness (distinct/non-null) at/above which a column is an index candidate.") help="Uniqueness (distinct/non-null) at/above which a column is an index candidate.")
p.add_argument("--partition-min-fill-pct", type=float, default=PARTITION_MIN_FILL_PCT) p.add_argument("--partition-min-fill-pct", type=float, default=PARTITION_MIN_FILL_PCT)
p.add_argument("--pre-sharded-max-distinct", type=int, default=PRE_SHARDED_MAX_DISTINCT) p.add_argument("--pre-sharded-max-distinct", type=int, default=PRE_SHARDED_MAX_DISTINCT)
p.add_argument(
"--chunksize", type=int, default=None,
help=(
"Rows per streaming read. Bigger chunks amortize pyreadstat / "
"pandas overhead (faster for huge files) but use more peak "
f"memory. Defaults to PROFILE_CHUNK_ROWS ({PROFILE_CHUNK_ROWS:,})."
),
)
return p return p
@ -999,7 +1080,7 @@ def main(argv: Optional[List[str]] = None) -> int:
return 2 return 2
print(f"profiling {path} -> {out_path}", file=sys.stderr) print(f"profiling {path} -> {out_path}", file=sys.stderr)
stats, columns, meta, total_rows = profile_file(path) stats, columns, meta, total_rows = profile_file(path, chunksize=args.chunksize)
drops, partitions, indexes, warnings = classify( drops, partitions, indexes, warnings = classify(
stats, columns, stats, columns,