"""S3-source counterpart to ``generic_loader/load_folder.py``. Reads a YAML config that points at an S3 bucket + prefix, lists every object under that prefix recursively, groups objects into *clusters* using the same explicit-pattern + auto-detect rules as ``load_folder.py``, and downloads each cluster's files into its own subfolder under a local destination root. ------------------------------------------------------------------------------- USAGE ------------------------------------------------------------------------------- 1. YAML config -------------- :: bucket: my-bucket # required prefix: census/2020/raw/ # required; recursive scan under it local_folder: ./downloads # required; one subfolder per cluster aws_profile: default # optional; default boto3 chain if omitted auto_detect: true # optional; default true extensions: # optional; default sas7bdat/xpt/xport - .sas7bdat on_exists: skip # optional; skip | overwrite | error concurrency: 4 # optional; default 4 clusters: - pattern: '^group_a\\d+\\.sas7bdat$' name: group_a - pattern: '^group_b\\d+\\.sas7bdat$' name: group_b 2. Command-line interface ------------------------- :: python s3_download.py --config download_config.yaml [--dry-run] [--overwrite] [--fail-fast] Flags: --config PATH Required. Path to the YAML config above. --dry-run List discovered clusters and the objects each would download (with sizes). No GET requests are issued beyond the initial LIST. --overwrite Force-redownload every matched object regardless of local cache state. Equivalent to ``on_exists: overwrite`` for this run. --fail-fast Abort the whole run on the first per-file download failure. Default is to log the failure and keep going. Exit codes: 0 - every file downloaded (or skipped) successfully (or dry-run completed) 1 - at least one file failed (details on stderr) 2 - the LIST returned no objects matching the configured extensions 3. Discovery rules ------------------ * Listing is recursive (no S3 ``Delimiter``). Regexes are matched against the *basename* of each key (the part after the last ``/``), so a nested object like ``census/2020/raw/nested/group_c1.sas7bdat`` is grouped by ``group_c1.sas7bdat`` alone. * Explicit patterns are tried in order. A key matched by one pattern is removed from the pool before the next pattern runs. Overlap between patterns is flagged as an error at discovery time. * Auto-detect groups remaining keys by ``re.sub(r'\\d+$', '', stem)`` with any trailing ``_`` / ``-`` stripped afterward, mirroring ``load_folder.py``. Stems without trailing digits become singleton clusters named after the stem. * Within a cluster, files are sorted numerically by the LAST digit group in the stem so ``_9_`` sorts before ``_40_`` regardless of zero-padding. 4. Library usage ---------------- :: from s3_download import load_download_config, list_s3_objects, \ discover_clusters, download_cluster, build_s3_client cfg = load_download_config("download_config.yaml") s3 = build_s3_client(cfg) objects = list_s3_objects(s3, cfg) for cluster in discover_clusters(cfg, objects): download_cluster(s3, cfg, cluster) """ from __future__ import annotations import argparse import re import sys from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import boto3 import yaml DEFAULT_EXTENSIONS: Tuple[str, ...] = (".sas7bdat", ".xpt", ".xport") VALID_ON_EXISTS: Tuple[str, ...] = ("skip", "overwrite", "error") DEFAULT_CONCURRENCY: int = 4 # --------------------------------------------------------------------------- # Dataclasses # --------------------------------------------------------------------------- @dataclass class _ExplicitPattern: """Parsed form of a single ``clusters[*]`` YAML entry.""" pattern: re.Pattern raw_pattern: str name: str @dataclass class DownloadConfig: """Top-level configuration parsed from YAML.""" bucket: str prefix: str local_folder: Path aws_profile: Optional[str] = None auto_detect: bool = True extensions: Tuple[str, ...] = DEFAULT_EXTENSIONS on_exists: str = "skip" concurrency: int = DEFAULT_CONCURRENCY explicit: List[_ExplicitPattern] = field(default_factory=list) @dataclass class S3Object: """A single S3 object selected from the LIST response.""" key: str basename: str size: int @dataclass class ClusterSpec: """Resolved per-cluster download settings.""" name: str objects: List[S3Object] source: str # "explicit" or "auto" pattern: Optional[str] = None # --------------------------------------------------------------------------- # Config loading # --------------------------------------------------------------------------- def _validate_on_exists(value: Any, where: str) -> str: s = str(value).lower() if s not in VALID_ON_EXISTS: raise ValueError( f"{where}: on_exists={value!r} is not one of {list(VALID_ON_EXISTS)}" ) return s def _parse_extensions(raw_value: Any, where: str) -> Tuple[str, ...]: if raw_value is None: return DEFAULT_EXTENSIONS if isinstance(raw_value, str): items = [raw_value] elif isinstance(raw_value, list): items = list(raw_value) else: raise ValueError( f"{where}: 'extensions' must be a string or list of strings." ) out: List[str] = [] for i, item in enumerate(items): if not isinstance(item, str) or not item.strip(): raise ValueError( f"{where}: 'extensions[{i}]' must be a non-empty string." ) ext = item.strip().lower() if not ext.startswith("."): ext = "." + ext out.append(ext) if len(out) != len(set(out)): raise ValueError(f"{where}: 'extensions' contains duplicates.") return tuple(out) def _parse_concurrency(raw_value: Any, where: str) -> int: if raw_value is None: return DEFAULT_CONCURRENCY try: value = int(raw_value) except (TypeError, ValueError): raise ValueError( f"{where}: 'concurrency' must be a positive integer, " f"got {raw_value!r}" ) if value <= 0: raise ValueError( f"{where}: 'concurrency' must be a positive integer, got {value}" ) return value def load_download_config(path: Path) -> DownloadConfig: """Parse and validate the YAML download config at ``path``.""" path = Path(path) with path.open("r", encoding="utf-8") as f: raw = yaml.safe_load(f) if not isinstance(raw, dict): raise ValueError( f"Config at {path} must be a YAML mapping at the top level." ) missing = [ k for k in ("bucket", "prefix", "local_folder") if k not in raw ] if missing: raise ValueError( f"Config {path} missing required keys: {', '.join(missing)}" ) bucket = str(raw["bucket"]).strip() if not bucket: raise ValueError(f"Config {path}: 'bucket' must be a non-empty string.") prefix = str(raw["prefix"]) # Normalize: strip leading slash, ensure exactly one trailing slash unless # the user explicitly asked for an empty prefix (whole bucket scan). prefix = prefix.lstrip("/") if prefix and not prefix.endswith("/"): prefix = prefix + "/" local_folder = Path(raw["local_folder"]) if not local_folder.is_absolute(): candidate = (path.parent / local_folder).resolve() # Mirror load_folder's behavior: prefer the config-relative path when # it exists, otherwise keep what the user wrote. Either way, we'll # mkdir(parents=True) before downloading so non-existence is fine. local_folder = candidate if candidate.parent.exists() else candidate aws_profile = raw.get("aws_profile") if aws_profile is not None: aws_profile = str(aws_profile).strip() or None auto_detect = bool(raw.get("auto_detect", True)) extensions = _parse_extensions(raw.get("extensions"), f"Config {path}") on_exists = _validate_on_exists( raw.get("on_exists", "skip"), f"Config {path}" ) concurrency = _parse_concurrency( raw.get("concurrency"), f"Config {path}" ) explicit: List[_ExplicitPattern] = [] seen_names: set = set() clusters_raw = raw.get("clusters") or [] if not isinstance(clusters_raw, list): raise ValueError(f"Config {path}: 'clusters' must be a list if present.") for i, entry in enumerate(clusters_raw): where = f"Config {path} clusters[{i}]" if not isinstance(entry, dict): raise ValueError(f"{where} must be a mapping.") if "pattern" not in entry or "name" not in entry: raise ValueError(f"{where} must include 'pattern' and 'name'.") raw_pat = str(entry["pattern"]) try: compiled = re.compile(raw_pat) except re.error as e: raise ValueError( f"{where}: invalid regex {raw_pat!r}: {e}" ) from e name = str(entry["name"]).strip() if not name: raise ValueError(f"{where}: 'name' must be a non-empty string.") if name in seen_names: raise ValueError( f"{where}: duplicate cluster name {name!r}" ) seen_names.add(name) explicit.append( _ExplicitPattern( pattern=compiled, raw_pattern=raw_pat, name=name, ) ) return DownloadConfig( bucket=bucket, prefix=prefix, local_folder=local_folder, aws_profile=aws_profile, auto_detect=auto_detect, extensions=extensions, on_exists=on_exists, concurrency=concurrency, explicit=explicit, ) # --------------------------------------------------------------------------- # S3 listing # --------------------------------------------------------------------------- def build_s3_client(cfg: DownloadConfig): """Build a boto3 S3 client honoring ``cfg.aws_profile`` if set.""" if cfg.aws_profile: session = boto3.Session(profile_name=cfg.aws_profile) else: session = boto3.Session() return session.client("s3") def list_s3_objects(s3_client, cfg: DownloadConfig) -> List[S3Object]: """List all objects under ``cfg.prefix`` recursively, filtered by extension.""" paginator = s3_client.get_paginator("list_objects_v2") out: List[S3Object] = [] for page in paginator.paginate(Bucket=cfg.bucket, Prefix=cfg.prefix): for entry in page.get("Contents", []): key = entry["Key"] if key.endswith("/"): continue basename = key.rsplit("/", 1)[-1] ext = ("." + basename.rsplit(".", 1)[-1].lower() if "." in basename else "") if ext not in cfg.extensions: continue out.append( S3Object(key=key, basename=basename, size=int(entry["Size"])) ) out.sort(key=lambda o: o.key) return out # --------------------------------------------------------------------------- # Cluster discovery (mirrors generic_loader/load_folder.py) # --------------------------------------------------------------------------- _TRAILING_DIGIT_RE = re.compile(r"\d+$") _DIGIT_GROUP_RE = re.compile(r"\d+") def _auto_prefix(stem: str) -> str: """Cluster key for *stem*: strip trailing digits and any trailing _/-.""" stripped = _TRAILING_DIGIT_RE.sub("", stem) stripped = stripped.rstrip("_-") return stripped or stem def _basename_stem(basename: str) -> str: if "." in basename: return basename.rsplit(".", 1)[0] return basename def _cluster_sort_key(obj: S3Object) -> Tuple[int, str]: """Sort key for ordering objects within a cluster by trailing digits.""" stem = _basename_stem(obj.basename) digits = _DIGIT_GROUP_RE.findall(stem) n = int(digits[-1]) if digits else -1 return (n, stem) def discover_clusters( cfg: DownloadConfig, objects: List[S3Object] ) -> List[ClusterSpec]: """Bucket *objects* into clusters using explicit patterns then auto-detect.""" clusters: List[ClusterSpec] = [] for i, p_i in enumerate(cfg.explicit): for j in range(i + 1, len(cfg.explicit)): p_j = cfg.explicit[j] for obj in objects: if (p_i.pattern.search(obj.basename) and p_j.pattern.search(obj.basename)): raise ValueError( f"Object {obj.basename!r} matches multiple explicit " f"patterns: {p_i.raw_pattern!r} and " f"{p_j.raw_pattern!r}" ) remaining = list(objects) for patt in cfg.explicit: matched = [o for o in remaining if patt.pattern.search(o.basename)] if not matched: clusters.append( ClusterSpec( name=patt.name, objects=[], source="explicit", pattern=patt.raw_pattern, ) ) continue remaining = [o for o in remaining if o not in matched] clusters.append( ClusterSpec( name=patt.name, objects=sorted(matched, key=_cluster_sort_key), source="explicit", pattern=patt.raw_pattern, ) ) if cfg.auto_detect and remaining: buckets: Dict[str, List[S3Object]] = {} for obj in remaining: key = _auto_prefix(_basename_stem(obj.basename)) buckets.setdefault(key, []).append(obj) for key in sorted(buckets): clusters.append( ClusterSpec( name=key, objects=sorted(buckets[key], key=_cluster_sort_key), source="auto", ) ) return clusters # --------------------------------------------------------------------------- # Download # --------------------------------------------------------------------------- def _local_path( cfg: DownloadConfig, cluster: ClusterSpec, obj: S3Object, basename_collisions: set, ) -> Path: """Resolve the on-disk destination path for *obj*. Falls back to a key-derived filename when two objects in the same cluster share a basename (possible under recursive scan). """ cluster_dir = cfg.local_folder / cluster.name if obj.basename in basename_collisions: safe = obj.key.replace("/", "__") return cluster_dir / safe return cluster_dir / obj.basename def _basename_collisions(cluster: ClusterSpec) -> set: """Return the set of basenames that appear more than once in *cluster*.""" seen: Dict[str, int] = {} for obj in cluster.objects: seen[obj.basename] = seen.get(obj.basename, 0) + 1 return {name for name, count in seen.items() if count > 1} def _decide_action( local_path: Path, obj: S3Object, on_exists: str ) -> Tuple[str, Optional[str]]: """Return ``(action, message)`` where action is 'download' or 'skip'.""" if on_exists == "overwrite": return ("download", None) if not local_path.exists(): return ("download", None) local_size = local_path.stat().st_size if local_size == obj.size: return ( "skip", f" skip {obj.key} -> {local_path} (size {local_size} matches)", ) if on_exists == "error": raise RuntimeError( f"Local file {local_path} exists with size {local_size} but S3 " f"object s3://{obj.key} has size {obj.size} (on_exists=error)" ) return ( "download", f" re-download {obj.key} -> {local_path} " f"(local size {local_size} != S3 {obj.size})", ) def download_cluster( s3_client, cfg: DownloadConfig, cluster: ClusterSpec, *, on_exists_override: Optional[str] = None, fail_fast: bool = False, ) -> Tuple[int, int, int, List[Tuple[str, Exception]]]: """Download every object in *cluster* into ``cfg.local_folder/cluster.name``. Returns ``(downloaded, skipped, bytes_downloaded, failures)`` where *failures* is a list of ``(key, exception)`` tuples. """ if not cluster.objects: return (0, 0, 0, []) on_exists = on_exists_override or cfg.on_exists cluster_dir = cfg.local_folder / cluster.name cluster_dir.mkdir(parents=True, exist_ok=True) collisions = _basename_collisions(cluster) plans: List[Tuple[S3Object, Path, str]] = [] skipped = 0 for obj in cluster.objects: local_path = _local_path(cfg, cluster, obj, collisions) action, message = _decide_action(local_path, obj, on_exists) if message: print(message, file=sys.stderr) if action == "skip": skipped += 1 continue plans.append((obj, local_path, action)) downloaded = 0 bytes_downloaded = 0 failures: List[Tuple[str, Exception]] = [] if not plans: return (0, skipped, 0, failures) def _do_one(item): obj, local_path, _action = item local_path.parent.mkdir(parents=True, exist_ok=True) s3_client.download_file(cfg.bucket, obj.key, str(local_path)) return obj if cfg.concurrency <= 1 or len(plans) == 1: for plan in plans: obj = plan[0] try: _do_one(plan) downloaded += 1 bytes_downloaded += obj.size print( f" ok {obj.key} -> {plan[1]} ({obj.size:,} bytes)", file=sys.stderr, ) except Exception as exc: failures.append((obj.key, exc)) print( f" FAIL {obj.key}: {exc}", file=sys.stderr, ) if fail_fast: break else: with ThreadPoolExecutor(max_workers=cfg.concurrency) as pool: future_to_plan = {pool.submit(_do_one, p): p for p in plans} for fut in as_completed(future_to_plan): plan = future_to_plan[fut] obj = plan[0] try: fut.result() downloaded += 1 bytes_downloaded += obj.size print( f" ok {obj.key} -> {plan[1]} " f"({obj.size:,} bytes)", file=sys.stderr, ) except Exception as exc: failures.append((obj.key, exc)) print( f" FAIL {obj.key}: {exc}", file=sys.stderr, ) if fail_fast: for other in future_to_plan: other.cancel() break return (downloaded, skipped, bytes_downloaded, failures) # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def _build_argparser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( description=( "Download S3 objects under a prefix into a local folder, " "grouping objects into clusters that each become one subfolder." ), ) p.add_argument( "--config", required=True, type=Path, help="Path to YAML config", ) p.add_argument( "--dry-run", action="store_true", help=( "List discovered clusters and the objects each would download. " "No GET requests are issued beyond the initial LIST." ), ) p.add_argument( "--overwrite", action="store_true", help=( "Force-redownload every matched object regardless of local " "cache state. Equivalent to on_exists=overwrite." ), ) p.add_argument( "--fail-fast", action="store_true", help=( "Abort on the first per-file download failure. Default is to " "log the failure and keep going." ), ) return p def _format_bytes(n: int) -> str: units = ("B", "KiB", "MiB", "GiB", "TiB") size = float(n) for unit in units: if size < 1024 or unit == units[-1]: return f"{size:,.1f} {unit}" if unit != "B" else f"{int(size):,} B" size /= 1024 return f"{n} B" def _describe_cluster(cluster: ClusterSpec) -> str: src = cluster.source if cluster.pattern: src += f" pattern={cluster.pattern!r}" if not cluster.objects: return ( f"cluster {cluster.name!r} [{src}]\n objects: (no matching keys)" ) total = sum(o.size for o in cluster.objects) lines = [ f"cluster {cluster.name!r} [{src}] " f"{len(cluster.objects)} object(s), {_format_bytes(total)}" ] for obj in cluster.objects: lines.append(f" - {obj.key} ({_format_bytes(obj.size)})") return "\n".join(lines) def main(argv: Optional[List[str]] = None) -> int: args = _build_argparser().parse_args(argv) cfg = load_download_config(args.config) s3 = build_s3_client(cfg) objects = list_s3_objects(s3, cfg) if not objects: print( f"error: no objects matching extensions " f"{list(cfg.extensions)} found under " f"s3://{cfg.bucket}/{cfg.prefix}", file=sys.stderr, ) return 2 clusters = discover_clusters(cfg, objects) loadable = [c for c in clusters if c.objects] print( f"discovered {len(loadable)} cluster(s) " f"({sum(len(c.objects) for c in loadable)} object(s)) under " f"s3://{cfg.bucket}/{cfg.prefix}:" ) for c in clusters: print(_describe_cluster(c)) if args.dry_run: return 0 cfg.local_folder.mkdir(parents=True, exist_ok=True) on_exists_override = "overwrite" if args.overwrite else None totals: List[Tuple[str, int, int, int]] = [] # (name, dl, skip, bytes) failures: List[Tuple[str, str, Exception]] = [] # (cluster, key, exc) aborted = False for cluster in loadable: if aborted: break print( f"\n>>> downloading cluster {cluster.name!r} " f"({len(cluster.objects)} object(s))" ) try: dl, sk, by, fails = download_cluster( s3, cfg, cluster, on_exists_override=on_exists_override, fail_fast=args.fail_fast, ) except Exception as exc: failures.append((cluster.name, "", exc)) print( f" !! cluster {cluster.name!r} aborted: {exc}", file=sys.stderr, ) if args.fail_fast: aborted = True continue totals.append((cluster.name, dl, sk, by)) for key, exc in fails: failures.append((cluster.name, key, exc)) if fails and args.fail_fast: aborted = True print("\n=== summary ===") for name, dl, sk, by in totals: print( f" {name}: downloaded {dl}, skipped {sk}, " f"{_format_bytes(by)}" ) for cname, key, exc in failures: print(f" FAIL {cname} {key}: {exc}", file=sys.stderr) return 1 if failures else 0 if __name__ == "__main__": sys.exit(main())