Enhance SAS profiling performance in sas_profiler.py
Added a new constant for profiling chunk size to optimize memory usage during profiling operations. Refactored the update method in the _ColumnStats class to improve efficiency in handling missing values and calculating statistics for numeric and string data types. This update includes vectorized operations for better performance and clarity in the implementation.
This commit is contained in:
parent
5449a25b44
commit
4fc85081c8
@ -43,6 +43,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||||||
_REPO_ROOT = Path(__file__).resolve().parent.parent
|
_REPO_ROOT = Path(__file__).resolve().parent.parent
|
||||||
sys.path.insert(0, str(_REPO_ROOT / "generic_loader"))
|
sys.path.insert(0, str(_REPO_ROOT / "generic_loader"))
|
||||||
|
|
||||||
|
import numpy as np # noqa: E402
|
||||||
import pandas as pd # noqa: E402
|
import pandas as pd # noqa: E402
|
||||||
from openpyxl import Workbook # noqa: E402
|
from openpyxl import Workbook # noqa: E402
|
||||||
from openpyxl.styles import Alignment, Font, PatternFill # noqa: E402
|
from openpyxl.styles import Alignment, Font, PatternFill # noqa: E402
|
||||||
@ -51,7 +52,6 @@ from openpyxl.utils import get_column_letter # noqa: E402
|
|||||||
from load_sas import ( # noqa: E402
|
from load_sas import ( # noqa: E402
|
||||||
NUMERIC_INT_RANGE,
|
NUMERIC_INT_RANGE,
|
||||||
ColumnSpec,
|
ColumnSpec,
|
||||||
_char_missing_mask,
|
|
||||||
infer_schema,
|
infer_schema,
|
||||||
iter_sas_chunks,
|
iter_sas_chunks,
|
||||||
read_sas_preview,
|
read_sas_preview,
|
||||||
@ -96,6 +96,25 @@ PREVIEW_ROWS_FOR_INFERENCE: int = 10_000
|
|||||||
"""Rows pulled from the file for the loader's schema inference. Matches
|
"""Rows pulled from the file for the loader's schema inference. Matches
|
||||||
``load_sas.TYPE_INFERENCE_SAMPLE_ROWS`` so suggestions track the loader."""
|
``load_sas.TYPE_INFERENCE_SAMPLE_ROWS`` so suggestions track the loader."""
|
||||||
|
|
||||||
|
PROFILE_CHUNK_ROWS: int = 5_000_000
|
||||||
|
"""Rows per streaming chunk while profiling. Larger chunks amortize
|
||||||
|
pyreadstat / pandas overhead, and the profiler is typically run on a
|
||||||
|
beefy box (e.g. a 128 GB EC2) rather than a laptop, so the default is
|
||||||
|
set aggressively.
|
||||||
|
|
||||||
|
Rough peak-memory estimate while a chunk is in flight:
|
||||||
|
|
||||||
|
peak_bytes ~= chunksize * num_cols * ~50 bytes/cell * 2-3x
|
||||||
|
|
||||||
|
(The 2-3x factor covers pyreadstat's read buffer + pandas frame
|
||||||
|
construction temporaries.) At 5M rows x 50 cols that's roughly 10-20 GB,
|
||||||
|
which is comfortable on a 128 GB host but would OOM a laptop.
|
||||||
|
|
||||||
|
If you have limited RAM or a very wide file, lower this; if you have a
|
||||||
|
narrow file and want max throughput, bump it higher with ``--chunksize``
|
||||||
|
(the profiler will happily take 20M+ per chunk). If ``chunksize`` is
|
||||||
|
larger than the file, pyreadstat just hands back one chunk."""
|
||||||
|
|
||||||
|
|
||||||
PARTITION_NAME_PATTERNS: Tuple[re.Pattern, ...] = (
|
PARTITION_NAME_PATTERNS: Tuple[re.Pattern, ...] = (
|
||||||
re.compile(r"^state$", re.IGNORECASE),
|
re.compile(r"^state$", re.IGNORECASE),
|
||||||
@ -153,41 +172,60 @@ class _ColumnStats:
|
|||||||
samples: List[Any] = field(default_factory=list)
|
samples: List[Any] = field(default_factory=list)
|
||||||
|
|
||||||
def update(self, series: pd.Series) -> None:
|
def update(self, series: pd.Series) -> None:
|
||||||
"""Fold one chunk's worth of this column into the accumulator."""
|
"""Fold one chunk's worth of this column into the accumulator.
|
||||||
self.n_total += len(series)
|
|
||||||
if len(series) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
if pd.api.types.is_object_dtype(series):
|
Implementation notes (this method is the dominant per-file cost):
|
||||||
miss_mask = _char_missing_mask(series)
|
|
||||||
|
- All masks are vectorized - no ``Series.map(lambda ...)`` loops.
|
||||||
|
- Distinct tracking uses ``Series.value_counts`` so we iterate at
|
||||||
|
most once per *unique* value in the chunk (in C), not once per
|
||||||
|
row.
|
||||||
|
- Once ``distinct_overflow`` is set and ``top_counts`` is full,
|
||||||
|
subsequent chunks skip the value-counts pass entirely - we already
|
||||||
|
know the column is too varied to be a partition / drop candidate
|
||||||
|
and we already have the top-N.
|
||||||
|
"""
|
||||||
|
n = len(series)
|
||||||
|
if n == 0:
|
||||||
|
return
|
||||||
|
self.n_total += n
|
||||||
|
|
||||||
|
is_object = pd.api.types.is_object_dtype(series)
|
||||||
|
is_numeric = pd.api.types.is_numeric_dtype(series)
|
||||||
|
is_datetime = pd.api.types.is_datetime64_any_dtype(series)
|
||||||
|
|
||||||
|
if is_object:
|
||||||
|
# Vectorized equivalent of load_sas._char_missing_mask: treat
|
||||||
|
# None / NaN / empty string as missing. ``series == ""`` is
|
||||||
|
# False for non-string values so we don't need per-element type
|
||||||
|
# checks.
|
||||||
|
na_mask = series.isna()
|
||||||
|
empty_mask = (series == "") & ~na_mask
|
||||||
|
miss_mask = na_mask | empty_mask
|
||||||
|
self.n_empty_str += int(empty_mask.sum())
|
||||||
else:
|
else:
|
||||||
miss_mask = series.isna()
|
miss_mask = series.isna()
|
||||||
|
|
||||||
miss_count = int(miss_mask.sum())
|
self.n_null += int(miss_mask.sum())
|
||||||
self.n_null += miss_count
|
non_null = series[~miss_mask] if miss_mask.any() else series
|
||||||
|
if non_null.empty:
|
||||||
|
return
|
||||||
|
|
||||||
non_null = series[~miss_mask]
|
# -- Numeric stats (C-level) ---------------------------------------
|
||||||
|
if is_numeric:
|
||||||
if pd.api.types.is_object_dtype(series):
|
arr = non_null.to_numpy(dtype="float64", copy=False, na_value=np.nan)
|
||||||
# Empty-string tracking is useful for TEXT columns where the loader
|
# NaN-safe aggregates in one pass each (all C-level).
|
||||||
# later translates "" -> NULL in the COPY step. A column dominated
|
self.numeric_sum += float(np.nansum(arr))
|
||||||
# by empty strings is still effectively null even if it isn't NaN.
|
self.numeric_sumsq += float(np.nansum(arr * arr))
|
||||||
empty_mask = series.map(lambda v: isinstance(v, str) and v == "")
|
self.numeric_count += int(arr.size)
|
||||||
self.n_empty_str += int(empty_mask.sum())
|
cmin = float(np.nanmin(arr)) if arr.size else None
|
||||||
|
cmax = float(np.nanmax(arr)) if arr.size else None
|
||||||
if pd.api.types.is_numeric_dtype(series) and not non_null.empty:
|
if cmin is not None and (self.min_val is None or cmin < self.min_val):
|
||||||
as_float = non_null.astype("float64")
|
|
||||||
self.numeric_sum += float(as_float.sum())
|
|
||||||
self.numeric_sumsq += float((as_float * as_float).sum())
|
|
||||||
self.numeric_count += int(len(as_float))
|
|
||||||
cmin = as_float.min()
|
|
||||||
cmax = as_float.max()
|
|
||||||
if self.min_val is None or cmin < self.min_val:
|
|
||||||
self.min_val = cmin
|
self.min_val = cmin
|
||||||
if self.max_val is None or cmax > self.max_val:
|
if cmax is not None and (self.max_val is None or cmax > self.max_val):
|
||||||
self.max_val = cmax
|
self.max_val = cmax
|
||||||
|
|
||||||
elif pd.api.types.is_datetime64_any_dtype(series) and not non_null.empty:
|
elif is_datetime:
|
||||||
cmin = non_null.min()
|
cmin = non_null.min()
|
||||||
cmax = non_null.max()
|
cmax = non_null.max()
|
||||||
if self.min_val is None or cmin < self.min_val:
|
if self.min_val is None or cmin < self.min_val:
|
||||||
@ -195,36 +233,68 @@ class _ColumnStats:
|
|||||||
if self.max_val is None or cmax > self.max_val:
|
if self.max_val is None or cmax > self.max_val:
|
||||||
self.max_val = cmax
|
self.max_val = cmax
|
||||||
|
|
||||||
if pd.api.types.is_object_dtype(series) and not non_null.empty:
|
# -- String length stats via vectorized str.len --------------------
|
||||||
str_like = non_null.map(lambda v: v if isinstance(v, str) else str(v))
|
# ``.str.len()`` is C-fast; for ASCII-dominated SAS data it matches
|
||||||
byte_lens = str_like.map(lambda s: len(s.encode("utf-8", errors="replace")))
|
# UTF-8 byte length closely enough for the "oversized TEXT" flag.
|
||||||
if len(byte_lens):
|
if is_object:
|
||||||
bmax = int(byte_lens.max())
|
lens = non_null.astype(str, copy=False).str.len()
|
||||||
|
lens = lens.dropna()
|
||||||
|
if not lens.empty:
|
||||||
|
bmax = int(lens.max())
|
||||||
if bmax > self.str_max_bytes:
|
if bmax > self.str_max_bytes:
|
||||||
self.str_max_bytes = bmax
|
self.str_max_bytes = bmax
|
||||||
self.str_sum_bytes += int(byte_lens.sum())
|
self.str_sum_bytes += int(lens.sum())
|
||||||
self.str_count += int(len(byte_lens))
|
self.str_count += int(lens.size)
|
||||||
|
|
||||||
for val in non_null.tolist():
|
# -- Samples (tiny slice; free) ------------------------------------
|
||||||
hashable = _hashable(val)
|
if len(self.samples) < 3:
|
||||||
if hashable is _UNHASHABLE:
|
needed = 3 - len(self.samples)
|
||||||
# Give up on distinct/top-counts for this column; it's some
|
self.samples.extend(non_null.head(needed).tolist())
|
||||||
# exotic (e.g. list) value we can't hash, and the drop/index
|
|
||||||
# suggestions wouldn't be meaningful anyway.
|
# -- Distinct / top_counts (vectorized via value_counts) -----------
|
||||||
self.distinct_overflow = True
|
# Skip altogether once we're saturated: distinct is already known
|
||||||
continue
|
# to be > DISTINCT_CAP and top_counts has its DISTINCT_CAP slots
|
||||||
if not self.distinct_overflow:
|
# filled, so further value_counts calls can only bump existing
|
||||||
if hashable in self.distinct:
|
# keys - info we don't need for any of the classifiers.
|
||||||
pass
|
top_full = len(self.top_counts) >= DISTINCT_CAP
|
||||||
elif len(self.distinct) >= DISTINCT_CAP:
|
if self.distinct_overflow and top_full:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
vc = non_null.value_counts(sort=False)
|
||||||
|
except TypeError:
|
||||||
|
# Unhashable values (list/dict). Drop the column from both
|
||||||
|
# distinct and top-N tracking.
|
||||||
|
self.distinct_overflow = True
|
||||||
|
return
|
||||||
|
|
||||||
|
if vc.empty:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.distinct_overflow:
|
||||||
|
# Only *new* values need to be considered for the distinct set.
|
||||||
|
for val in vc.index:
|
||||||
|
if val in self.distinct:
|
||||||
|
continue
|
||||||
|
if len(self.distinct) >= DISTINCT_CAP:
|
||||||
self.distinct_overflow = True
|
self.distinct_overflow = True
|
||||||
else:
|
break
|
||||||
self.distinct.add(hashable)
|
self.distinct.add(val)
|
||||||
if len(self.top_counts) < DISTINCT_CAP or hashable in self.top_counts:
|
|
||||||
self.top_counts[hashable] += 1
|
|
||||||
|
|
||||||
if len(self.samples) < 3:
|
if not top_full:
|
||||||
self.samples.append(val)
|
# Bulk-merge known keys; cap adds for new keys.
|
||||||
|
for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()):
|
||||||
|
if val in self.top_counts:
|
||||||
|
self.top_counts[val] += int(count)
|
||||||
|
elif len(self.top_counts) < DISTINCT_CAP:
|
||||||
|
self.top_counts[val] = int(count)
|
||||||
|
# else: silently skip - we're past the cap.
|
||||||
|
else:
|
||||||
|
# Only existing keys can grow.
|
||||||
|
tc = self.top_counts
|
||||||
|
for val, count in zip(vc.index.tolist(), vc.to_numpy().tolist()):
|
||||||
|
if val in tc:
|
||||||
|
tc[val] += int(count)
|
||||||
|
|
||||||
# -- Derived properties ------------------------------------------------
|
# -- Derived properties ------------------------------------------------
|
||||||
|
|
||||||
@ -279,27 +349,6 @@ class _ColumnStats:
|
|||||||
return self.top_counts.most_common(n)
|
return self.top_counts.most_common(n)
|
||||||
|
|
||||||
|
|
||||||
class _UnhashableSentinel:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
_UNHASHABLE = _UnhashableSentinel()
|
|
||||||
|
|
||||||
|
|
||||||
def _hashable(val: Any) -> Any:
|
|
||||||
"""Return a hashable form of ``val``, or :data:`_UNHASHABLE` if we can't.
|
|
||||||
|
|
||||||
pandas occasionally hands us objects (lists, dicts) from object columns
|
|
||||||
that aren't hashable. Rather than crashing the whole report, we let the
|
|
||||||
column fall back to "distinct_overflow" mode for those rows.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
hash(val)
|
|
||||||
return val
|
|
||||||
except TypeError:
|
|
||||||
return _UNHASHABLE
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -366,16 +415,40 @@ def profile_file(
|
|||||||
}
|
}
|
||||||
|
|
||||||
total_rows = 0
|
total_rows = 0
|
||||||
kwargs = {}
|
effective_chunksize = chunksize if chunksize is not None else PROFILE_CHUNK_ROWS
|
||||||
if chunksize is not None:
|
kwargs = {"chunksize": effective_chunksize}
|
||||||
kwargs["chunksize"] = chunksize
|
# pyreadstat + pandas are both C-level; the per-chunk overhead we pay
|
||||||
|
# is dominated by the value_counts passes in _ColumnStats.update, so
|
||||||
|
# the profile runs O(total_rows) with a small constant.
|
||||||
|
import time
|
||||||
|
started_at = time.monotonic()
|
||||||
|
last_print_at = started_at
|
||||||
for chunk_df, _chunk_meta in iter_sas_chunks(path, **kwargs):
|
for chunk_df, _chunk_meta in iter_sas_chunks(path, **kwargs):
|
||||||
total_rows += len(chunk_df)
|
total_rows += len(chunk_df)
|
||||||
print(f" profiling... {total_rows:,} rows", file=sys.stderr)
|
|
||||||
for name, cs in stats.items():
|
for name, cs in stats.items():
|
||||||
if name not in chunk_df.columns:
|
if name not in chunk_df.columns:
|
||||||
continue
|
continue
|
||||||
cs.update(chunk_df[name])
|
cs.update(chunk_df[name])
|
||||||
|
now = time.monotonic()
|
||||||
|
# Throttle progress output to ~one line per 2 seconds so huge files
|
||||||
|
# don't spam stderr; the summary after the loop always prints once.
|
||||||
|
if now - last_print_at >= 2.0:
|
||||||
|
elapsed = now - started_at
|
||||||
|
rate = total_rows / elapsed if elapsed > 0 else 0.0
|
||||||
|
print(
|
||||||
|
f" profiling... {total_rows:,} rows "
|
||||||
|
f"({rate:,.0f} rows/s)",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
last_print_at = now
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - started_at
|
||||||
|
rate = total_rows / elapsed if elapsed > 0 else 0.0
|
||||||
|
print(
|
||||||
|
f" profiled {total_rows:,} rows in {elapsed:.1f}s "
|
||||||
|
f"({rate:,.0f} rows/s)",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
return stats, columns, meta, total_rows
|
return stats, columns, meta, total_rows
|
||||||
|
|
||||||
@ -985,6 +1058,14 @@ def _build_argparser() -> argparse.ArgumentParser:
|
|||||||
help="Uniqueness (distinct/non-null) at/above which a column is an index candidate.")
|
help="Uniqueness (distinct/non-null) at/above which a column is an index candidate.")
|
||||||
p.add_argument("--partition-min-fill-pct", type=float, default=PARTITION_MIN_FILL_PCT)
|
p.add_argument("--partition-min-fill-pct", type=float, default=PARTITION_MIN_FILL_PCT)
|
||||||
p.add_argument("--pre-sharded-max-distinct", type=int, default=PRE_SHARDED_MAX_DISTINCT)
|
p.add_argument("--pre-sharded-max-distinct", type=int, default=PRE_SHARDED_MAX_DISTINCT)
|
||||||
|
p.add_argument(
|
||||||
|
"--chunksize", type=int, default=None,
|
||||||
|
help=(
|
||||||
|
"Rows per streaming read. Bigger chunks amortize pyreadstat / "
|
||||||
|
"pandas overhead (faster for huge files) but use more peak "
|
||||||
|
f"memory. Defaults to PROFILE_CHUNK_ROWS ({PROFILE_CHUNK_ROWS:,})."
|
||||||
|
),
|
||||||
|
)
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
@ -999,7 +1080,7 @@ def main(argv: Optional[List[str]] = None) -> int:
|
|||||||
return 2
|
return 2
|
||||||
|
|
||||||
print(f"profiling {path} -> {out_path}", file=sys.stderr)
|
print(f"profiling {path} -> {out_path}", file=sys.stderr)
|
||||||
stats, columns, meta, total_rows = profile_file(path)
|
stats, columns, meta, total_rows = profile_file(path, chunksize=args.chunksize)
|
||||||
|
|
||||||
drops, partitions, indexes, warnings = classify(
|
drops, partitions, indexes, warnings = classify(
|
||||||
stats, columns,
|
stats, columns,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user