diff --git a/generic_loader/load_folder.py b/generic_loader/load_folder.py index c9032d2..cfb5aa9 100644 --- a/generic_loader/load_folder.py +++ b/generic_loader/load_folder.py @@ -151,7 +151,12 @@ import queue as _queue_mod import re import sys import threading -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed +from concurrent.futures import ( + CancelledError, + ProcessPoolExecutor, + ThreadPoolExecutor, + as_completed, +) from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -810,6 +815,7 @@ def load_cluster( workers: int = 1, progress_queue: Any = None, db_overrides: Optional[Dict[str, Optional[str]]] = None, + abort_on_first_failure: bool = False, ) -> int: """Load every file in ``cluster`` into one table. Returns total rows loaded. @@ -923,6 +929,7 @@ def load_cluster( progress_queue=progress_queue, db_overrides=db_overrides, column_types=cluster.column_types, + abort_on_first_failure=abort_on_first_failure, ) else: # Serial path: stream the first file on the main connection, then @@ -1134,6 +1141,7 @@ def _load_remaining_files_parallel( progress_queue: Any, db_overrides: Optional[Dict[str, Optional[str]]], column_types: Optional[Dict[str, str]] = None, + abort_on_first_failure: bool = False, ) -> int: """Run append-mode loads for ``files`` across a process pool. @@ -1142,12 +1150,27 @@ def _load_remaining_files_parallel( and stream via COPY just like the serial path. The table itself must already exist (and be committed) before this is called - the worker schema-compat probes read ``information_schema``, which won't see an - uncommitted ``CREATE TABLE``. Failures are collected and re-raised as - a single ``RuntimeError`` at the end so that all other workers' rows - still count toward the committed total. + uncommitted ``CREATE TABLE``. Failures are surfaced two ways: + + 1. **Live stderr feed.** Every worker's outcome - success or failure + - is written to stderr the moment ``as_completed`` hands it back, + via ``tqdm.write`` so the active progress bar stays intact. This + turns the previous "wait 30 minutes for the last worker to drain + before you find out the first 31 failed at minute 2" footgun into + an immediate notification, which matters most when one bad file + (e.g. schema drift, OOM, bad data) torpedoes most of the pool. + 2. **Final aggregate raise.** Errors are still collected and raised as + one ``RuntimeError`` after the pool drains so the per-cluster + success/fail summary in :func:`main` stays accurate. On + ``KeyboardInterrupt`` we re-raise after wrapping the partial + error list into the same ``RuntimeError`` shape, so Ctrl-C still + prints what failed up to that point instead of silently + discarding it. """ total = 0 errors: List[Tuple[str, str]] = [] + completed = 0 + n_files = len(files) # ``max_tasks_per_child=1`` recycles each worker process after every # file. Without this, glibc/pyarrow/pyreadstat all retain peak-water @@ -1176,12 +1199,79 @@ def _load_remaining_files_parallel( ) for p in files ] - for fut in as_completed(futures): - path_str, rows, err = fut.result() - if err is not None: - errors.append((path_str, err)) - else: - total += rows + aborted = False + try: + for fut in as_completed(futures): + # ``--abort-on-first-failure`` cancels still-pending + # futures; ``as_completed`` yields them anyway and + # ``.result()`` raises ``CancelledError``. Skip those + # quietly - the abort log already accounted for the + # cancellation count. + try: + path_str, rows, err = fut.result() + except CancelledError: + continue + completed += 1 + name = Path(path_str).name + if err is not None: + errors.append((path_str, err)) + # ``tqdm.write`` writes through the bar without + # garbling it; bare ``print`` would interleave with + # the rendered progress line. + tqdm.write( + f"[FAIL {completed}/{n_files}] {name}: {err}", + file=sys.stderr, + ) + if abort_on_first_failure and not aborted: + # Cancel anything still queued. Already-running + # workers can't be interrupted portably, so they + # keep going - we just stop dispatching new files + # and stop counting their results toward total. + # Set the flag once so we don't spam a cancel + # storm if multiple workers fail at the same time. + aborted = True + cancelled = 0 + for f in futures: + if not f.done() and f.cancel(): + cancelled += 1 + tqdm.write( + f"[abort] --abort-on-first-failure: cancelled " + f"{cancelled} pending file(s); waiting for " + f"{n_files - completed - cancelled} in-flight " + f"worker(s) to finish.", + file=sys.stderr, + ) + else: + tqdm.write( + f"[done {completed}/{n_files}] {name}: " + f"{rows:,} row(s)", + file=sys.stderr, + ) + total += rows + except KeyboardInterrupt: + # Cancel anything still queued so we exit promptly. Already- + # running workers will run to completion (Python can't kill + # a child mid-syscall portably) but at least pending tasks + # don't keep firing. We re-raise as the same RuntimeError + # shape used below so the caller's per-cluster summary path + # still sees the partial failure list instead of an opaque + # KeyboardInterrupt with no context about which workers + # had already failed. + for f in futures: + if not f.done(): + f.cancel() + tqdm.write( + f"[interrupt] Ctrl-C after {completed}/{n_files} file(s); " + f"{len(errors)} failure(s) collected so far.", + file=sys.stderr, + ) + if errors: + joined = "\n".join(f" {p}: {e}" for p, e in errors) + raise RuntimeError( + f"interrupted after {len(errors)} worker(s) failed " + f"while appending to {schemaname}.{tablename}:\n{joined}" + ) + raise if errors: joined = "\n".join(f" {p}: {e}" for p, e in errors) @@ -1224,6 +1314,20 @@ def _build_argparser() -> argparse.ArgumentParser: "cluster back and continue with the next one." ), ) + p.add_argument( + "--abort-on-first-failure", + action="store_true", + help=( + "Within a single cluster's parallel append, cancel every " + "still-pending worker the instant one worker fails. Use when " + "you know the failure is systemic (schema drift, bad creds, " + "OOM) and don't want to wait for the slow files to drain " + "before getting your prompt back. Already-running workers " + "still finish their current file - Python can't kill children " + "mid-syscall - but new files won't be dispatched. Orthogonal " + "to --fail-fast, which controls what happens between clusters." + ), + ) p.add_argument( "--dbcreds", action="store_true", @@ -1645,6 +1749,7 @@ def main(argv: Optional[List[str]] = None) -> int: workers=workers, progress_queue=progress_queue, db_overrides=db_overrides, + abort_on_first_failure=args.abort_on_first_failure, ) conn.commit() totals.append((cluster.tablename, len(cluster.files), rows))