Implement parallel processing for partition discovery in load_folder.py and enhance column filtering in load_sas.py #9

Merged
dp merged 1 commit from fix_prescan into main 2026-04-22 15:35:20 +00:00
2 changed files with 209 additions and 49 deletions
Showing only changes of commit dd83f58412 - Show all commits

View File

@ -157,7 +157,6 @@ import threading
from concurrent.futures import ( from concurrent.futures import (
CancelledError, CancelledError,
ProcessPoolExecutor, ProcessPoolExecutor,
ThreadPoolExecutor,
as_completed, as_completed,
) )
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -885,6 +884,94 @@ def discover_clusters(cfg: FolderConfig) -> List[ClusterSpec]:
return clusters return clusters
# ---------------------------------------------------------------------------
# Top-level workers (must be importable for ProcessPoolExecutor pickling)
# ---------------------------------------------------------------------------
def _prescan_worker(
    path_str: str,
    delimiter: str,
    text_encoding: str,
    quotechar: str,
) -> Tuple[
    str,
    Optional[int],
    Optional[Dict[str, Tuple[str, Optional[str]]]],
    Optional[str],
]:
    """Pre-scan the metadata of one file inside a process-pool worker.

    Calls :func:`read_sas_metadata` on ``path_str`` and reports the row
    count together with the per-column metadata produced by
    :func:`extract_union_metadata`.

    The function lives at module level so ``ProcessPoolExecutor`` can
    pickle it. It replaces a closure that used to be driven from a
    ``ThreadPoolExecutor``; per the original author, that version shared
    the parent's GIL and serialised the per-file Python work, which made
    the pre-scan feel slow even when disk reads were fast.

    Args:
        path_str: Path of the file to scan, as a string.
        delimiter: Forwarded to :func:`read_sas_metadata`.
        text_encoding: Forwarded to :func:`read_sas_metadata`.
        quotechar: Forwarded to :func:`read_sas_metadata`.

    Returns:
        ``(path_str, number_rows, per_column_meta, error_message)``.
        On success ``error_message`` is ``None``; ``number_rows`` is
        ``None`` when the metadata object has no (or a ``None``)
        ``number_rows`` attribute. On failure both ``number_rows`` and
        ``per_column_meta`` are ``None`` and ``error_message`` is
        ``"<ExceptionType>: <message>"``.
    """
    text_options = {
        "delimiter": delimiter,
        "text_encoding": text_encoding,
        "quotechar": quotechar,
    }
    try:
        file_meta = read_sas_metadata(Path(path_str), **text_options)
        raw_rows = getattr(file_meta, "number_rows", None)
        per_column = extract_union_metadata(file_meta)
        row_count = None if raw_rows is None else int(raw_rows)
    except Exception as exc:
        # Report the failure as data instead of raising, so the caller can
        # record this file as "unknown" and keep scanning the others.
        return (path_str, None, None, f"{type(exc).__name__}: {exc}")
    return (path_str, row_count, per_column, None)
def _partition_scan_worker(
    path_str: str,
    partition_by: List[str],
    delimiter: str,
    text_encoding: str,
    quotechar: str,
) -> Tuple[str, Optional[dict], Optional[str]]:
    """Discover the partition values of one file inside a process-pool worker.

    Streams ``path_str`` through :func:`iter_sas_chunks` with
    ``usecols`` set to the partition columns, and feeds the resulting
    chunk frames to :func:`discover_partition_values_chunked`. Restricting
    the read to the partition columns is intended to cut I/O on wide
    files (the original author cites a typical 10x+ reduction on a wide
    sas7bdat).

    No ``columns`` schema is passed to
    :func:`discover_partition_values_chunked` here, because the
    cluster-wide schema lives in the parent process.
    NOTE(review): the original docstring states that the merger in
    :func:`_discover_cluster_partitions` re-applies partition-value
    normalisation when folding the per-file trees together - confirm
    that this actually happens, since the serial path passes ``columns``
    and this worker does not.

    Args:
        path_str: Path of the file to scan, as a string.
        partition_by: Names of the partition columns.
        delimiter: Forwarded to :func:`iter_sas_chunks`.
        text_encoding: Forwarded to :func:`iter_sas_chunks`.
        quotechar: Forwarded to :func:`iter_sas_chunks`.

    Returns:
        ``(path_str, partial_partition_tree, error_message)``. On success
        ``error_message`` is ``None``. On failure the tree is ``None``
        and ``error_message`` holds ``"<ExceptionType>: <message>"``
        followed by the formatted traceback.
    """
    try:
        # Lazily drop the per-chunk metadata; only the frames are needed.
        frames = (
            frame
            for frame, _unused_meta in iter_sas_chunks(
                Path(path_str),
                delimiter=delimiter,
                text_encoding=text_encoding,
                quotechar=quotechar,
                usecols=list(partition_by),
            )
        )
        partial_tree = discover_partition_values_chunked(
            frames, list(partition_by),
        )
    except Exception as exc:
        import traceback as _traceback

        # Ship the traceback back as text: the caller only sees this tuple.
        failure = f"{type(exc).__name__}: {exc}\n{_traceback.format_exc()}"
        return (path_str, None, failure)
    return (path_str, partial_tree, None)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Per-cluster load # Per-cluster load
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -929,22 +1016,43 @@ def _infer_cluster_schema(
def _discover_cluster_partitions( def _discover_cluster_partitions(
cluster: ClusterSpec, cluster: ClusterSpec,
columns: Dict, columns: Dict,
*,
workers: int = 1,
) -> dict: ) -> dict:
"""Scan ALL files in ``cluster`` to discover partition values. """Scan ALL files in ``cluster`` to discover partition values.
Returns a nested partition-value tree suitable for passing to Returns a nested partition-value tree suitable for passing to
:func:`load_sas.render_partition_ddl` and :func:`load_sas.create_table`. :func:`load_sas.render_partition_ddl` and :func:`load_sas.create_table`.
Each file is scanned chunk-by-chunk so the full dataset is never
materialized in memory. Each file is read with ``usecols=cluster.partition_by`` so pyreadstat
only decodes the partition columns - on a wide sas7bdat that drops
the bytes-touched-per-file by an order of magnitude vs the old
full-row scan. With ``workers > 1`` the per-file scans run in a
``ProcessPoolExecutor`` and the partial trees are merged as they
complete (true parallelism, visible in ``htop`` as N python procs).
``include`` / ``exclude`` are intentionally not honoured here:
partition columns are validated against the inferred schema before
we ever get called, which already enforces that they survived any
explicit include/exclude filter. The old serial path applied the
filter for symmetry, but it was a no-op once usecols pinned us to
the partition set anyway.
""" """
tkw = _build_text_kw(cluster) tkw = _build_text_kw(cluster)
merged: dict = {} merged: dict = {}
if not cluster.files:
return merged
n_workers = max(1, min(int(workers), len(cluster.files)))
if n_workers <= 1 or len(cluster.files) == 1:
for path in cluster.files: for path in cluster.files:
def _filtered_chunks(p=path): def _filtered_chunks(p=path):
for chunk_df, _chunk_meta in iter_sas_chunks(p, **tkw): for chunk_df, _meta in iter_sas_chunks(
yield apply_column_filter( p, usecols=list(cluster.partition_by), **tkw,
chunk_df, cluster.include, cluster.exclude ):
) yield chunk_df
file_tree = discover_partition_values_chunked( file_tree = discover_partition_values_chunked(
_filtered_chunks(), cluster.partition_by, columns, _filtered_chunks(), cluster.partition_by, columns,
@ -952,6 +1060,41 @@ def _discover_cluster_partitions(
_merge_partition_trees(merged, file_tree) _merge_partition_trees(merged, file_tree)
return merged return merged
with ProcessPoolExecutor(max_workers=n_workers) as ppool:
futures = {
ppool.submit(
_partition_scan_worker,
str(path),
list(cluster.partition_by),
tkw["delimiter"],
tkw["text_encoding"],
tkw["quotechar"],
): path
for path in cluster.files
}
bar = tqdm(
total=len(futures),
unit="file",
desc=" discovering partitions",
file=sys.stderr,
dynamic_ncols=True,
)
try:
for fut in as_completed(futures):
path = futures[fut]
_path_str, file_tree, err = fut.result()
bar.update(1)
if err is not None:
raise RuntimeError(
f"partition discovery failed for {path}: {err}"
)
if file_tree:
_merge_partition_trees(merged, file_tree)
finally:
bar.close()
return merged
def load_cluster( def load_cluster(
conn, conn,
@ -1041,7 +1184,7 @@ def load_cluster(
file=sys.stderr, file=sys.stderr,
) )
partition_values = _discover_cluster_partitions( partition_values = _discover_cluster_partitions(
cluster, first_columns, cluster, first_columns, workers=workers,
) )
total_parts = _count_partitions(partition_values) total_parts = _count_partitions(partition_values)
print( print(
@ -1676,7 +1819,7 @@ def main(argv: Optional[List[str]] = None) -> int:
file=sys.stderr, file=sys.stderr,
) )
partition_values = _discover_cluster_partitions( partition_values = _discover_cluster_partitions(
c, columns, c, columns, workers=max(1, int(args.workers)),
) )
total_parts = _count_partitions(partition_values) total_parts = _count_partitions(partition_values)
print( print(
@ -1787,42 +1930,35 @@ def main(argv: Optional[List[str]] = None) -> int:
file=sys.stderr, file=sys.stderr,
) )
else: else:
prescan_workers = min(16, max(1, len(all_files))) # Use max(workers, 16) processes, capped at 32 and at the file count. Pyreadstat metadata reads are
# mostly C-side I/O + struct decoding, but the per-file Python work
# (extract_union_metadata, dict construction) was being serialised
# by the GIL in the old ThreadPool path - which is why the bar
# crawled even though disk was idle. ProcessPool gives us actual
# parallelism; you'll now see N python procs in htop.
prescan_workers = min(
32, max(1, max(workers, 16)), len(all_files),
)
print( print(
f"pre-scanning row counts + per-column metadata for " f"pre-scanning row counts + per-column metadata for "
f"{len(all_files)} file(s) across {prescan_workers} thread(s)...", f"{len(all_files)} file(s) across {prescan_workers} "
f"process(es)...",
file=sys.stderr, file=sys.stderr,
) )
def _scan_one(
p: Path,
) -> Tuple[
Path,
Optional[int],
Optional[Dict[str, Tuple[str, Optional[str]]]],
Optional[str],
]:
try:
_prescan_tkw = dict(
delimiter=cfg.delimiter,
text_encoding=cfg.text_encoding,
quotechar=cfg.quotechar,
)
meta = read_sas_metadata(p, **_prescan_tkw)
n = getattr(meta, "number_rows", None)
col_meta = extract_union_metadata(meta)
return (
p,
int(n) if n is not None else None,
col_meta,
None,
)
except Exception as e:
return (p, None, None, str(e))
unknown_total_files: List[str] = [] unknown_total_files: List[str] = []
running_total = 0 running_total = 0
with ThreadPoolExecutor(max_workers=prescan_workers) as tpool: with ProcessPoolExecutor(max_workers=prescan_workers) as ppool:
futures = {
ppool.submit(
_prescan_worker,
str(p),
cfg.delimiter,
cfg.text_encoding,
cfg.quotechar,
): p
for p in all_files
}
prescan_bar = tqdm( prescan_bar = tqdm(
total=len(all_files), total=len(all_files),
unit="file", unit="file",
@ -1831,16 +1967,20 @@ def main(argv: Optional[List[str]] = None) -> int:
dynamic_ncols=True, dynamic_ncols=True,
) )
try: try:
for p, n, col_meta, err in tpool.map(_scan_one, all_files): for fut in as_completed(futures):
path_obj = futures[fut]
path_str, n, col_meta, err = fut.result()
prescan_bar.update(1) prescan_bar.update(1)
if err is not None: if err is not None:
unknown_total_files.append(f"{p.name} ({err})") unknown_total_files.append(
f"{path_obj.name} ({err})"
)
elif n is None: elif n is None:
unknown_total_files.append(p.name) unknown_total_files.append(path_obj.name)
else: else:
running_total += n running_total += n
if col_meta is not None: if col_meta is not None:
file_meta_by_path[str(p)] = col_meta file_meta_by_path[path_str] = col_meta
finally: finally:
prescan_bar.close() prescan_bar.close()

View File

@ -811,6 +811,7 @@ def iter_text_chunks(
encoding: str = "utf-8", encoding: str = "utf-8",
quotechar: str = '"', quotechar: str = '"',
chunksize: Optional[int] = None, chunksize: Optional[int] = None,
usecols: Optional[List[str]] = None,
): ):
"""Yield ``(df_chunk, meta)`` tuples for streaming text file loads. """Yield ``(df_chunk, meta)`` tuples for streaming text file loads.
@ -818,6 +819,10 @@ def iter_text_chunks(
iteration. The metadata object is rebuilt for each chunk with the iteration. The metadata object is rebuilt for each chunk with the
chunk's column names and ``number_rows`` set to the total file rows chunk's column names and ``number_rows`` set to the total file rows
(computed once up front). (computed once up front).
When ``usecols`` is provided, only those columns are parsed - useful
for cheap partition-value discovery scans where the rest of the row
would be wasted I/O.
""" """
path = Path(path) path = Path(path)
if chunksize is None: if chunksize is None:
@ -832,8 +837,7 @@ def iter_text_chunks(
total = _count_text_lines(path, encoding) total = _count_text_lines(path, encoding)
reader = pd.read_csv( read_csv_kwargs: Dict[str, Any] = dict(
path,
delimiter=delimiter, delimiter=delimiter,
encoding=encoding, encoding=encoding,
quotechar=quotechar, quotechar=quotechar,
@ -842,6 +846,10 @@ def iter_text_chunks(
keep_default_na=True, keep_default_na=True,
na_values=[""], na_values=[""],
) )
if usecols is not None:
read_csv_kwargs["usecols"] = list(usecols)
reader = pd.read_csv(path, **read_csv_kwargs)
for chunk_df in reader: for chunk_df in reader:
meta = _build_text_metadata(list(chunk_df.columns), number_rows=total) meta = _build_text_metadata(list(chunk_df.columns), number_rows=total)
yield chunk_df, meta yield chunk_df, meta
@ -941,6 +949,7 @@ def iter_sas_chunks(
delimiter: str = ",", delimiter: str = ",",
text_encoding: str = "utf-8", text_encoding: str = "utf-8",
quotechar: str = '"', quotechar: str = '"',
usecols: Optional[List[str]] = None,
): ):
"""Yield ``(df_chunk, meta)`` tuples for streaming loads. """Yield ``(df_chunk, meta)`` tuples for streaming loads.
@ -952,6 +961,13 @@ def iter_sas_chunks(
parseable, otherwise from :data:`DEFAULT_CHUNK_ROWS`. An explicit int parseable, otherwise from :data:`DEFAULT_CHUNK_ROWS`. An explicit int
always wins. always wins.
When ``usecols`` is provided, pyreadstat only decodes the listed
columns. For wide sas7bdat files this is dramatically cheaper than
a full read - the C decoder skips unwanted columns instead of
materializing them. Used by partition-value discovery to avoid
re-reading every byte of every file just to extract a couple of
partition keys.
For text files, delegates to :func:`iter_text_chunks`. For text files, delegates to :func:`iter_text_chunks`.
""" """
if _is_text_file(path): if _is_text_file(path):
@ -961,6 +977,7 @@ def iter_sas_chunks(
encoding=text_encoding, encoding=text_encoding,
quotechar=quotechar, quotechar=quotechar,
chunksize=chunksize, chunksize=chunksize,
usecols=usecols,
) )
return return
if chunksize is None: if chunksize is None:
@ -973,6 +990,9 @@ def iter_sas_chunks(
else: else:
chunksize = DEFAULT_CHUNK_ROWS chunksize = DEFAULT_CHUNK_ROWS
reader, kwargs = _sas_reader(path) reader, kwargs = _sas_reader(path)
if usecols is not None:
kwargs = dict(kwargs)
kwargs["usecols"] = list(usecols)
yield from pyreadstat.read_file_in_chunks( yield from pyreadstat.read_file_in_chunks(
reader, str(Path(path)), chunksize=chunksize, **kwargs reader, str(Path(path)), chunksize=chunksize, **kwargs
) )