foxtrot/utils/s3_download.py

746 lines
24 KiB
Python
Raw Normal View History

"""S3-source counterpart to ``generic_loader/load_folder.py``.
Reads a YAML config that points at an S3 bucket + prefix, lists every object
under that prefix recursively, groups objects into *clusters* using the same
explicit-pattern + auto-detect rules as ``load_folder.py``, and downloads each
cluster's files into its own subfolder under a local destination root.

Supported file types:
* SAS data files: ``.sas7bdat``, ``.xpt``, ``.xport``
* Delimited text files: ``.txt``, ``.csv``, ``.tsv``
-------------------------------------------------------------------------------
USAGE
-------------------------------------------------------------------------------
1. YAML config
--------------
::

    bucket: my-bucket            # required
    prefix: census/2020/raw/     # required; recursive scan under it
    local_folder: ./downloads    # required; relative to this file; one subfolder per cluster
    aws_profile: default         # optional; default boto3 chain if omitted
    auto_detect: true            # optional; default true
    extensions:                  # optional; default sas7bdat/xpt/xport/txt/csv/tsv
      - .sas7bdat
      - .csv
    on_exists: skip              # optional; skip | overwrite | error
    concurrency: 4               # optional; default 4
    clusters:
      - pattern: '^group_a\\d+\\.sas7bdat$'
        name: group_a
      - pattern: '^group_b\\d+\\.sas7bdat$'
        name: group_b

``on_exists`` only matters when the local file already exists. ``skip``
keeps the file if its size equals the S3 object's and re-downloads it
otherwise. ``error`` also keeps a same-size file but fails the cluster on a
size mismatch. ``overwrite`` always re-downloads.
2. Command-line interface
-------------------------
::

    python s3_download.py --config download_config.yaml [--dry-run]
                          [--overwrite] [--fail-fast]
Flags:

--config PATH   Required. Path to the YAML config above.
--dry-run       List discovered clusters and the objects each would
                download (with sizes). Only the LIST requests are issued;
                no objects are downloaded.
--overwrite     Force-redownload every matched object regardless of
                local cache state. Equivalent to ``on_exists: overwrite``
                for this run.
--fail-fast     Stop at the first per-file download failure; later
                clusters are not started (downloads already in flight may
                still finish). Default is to log the failure and keep going.

Exit codes:

* 0 - every file downloaded (or skipped) successfully, or dry-run completed
* 1 - at least one file or cluster failed (details on stderr); an uncaught
  error such as an invalid config also exits with 1
* 2 - the LIST returned no objects matching the configured extensions
3. Discovery rules
------------------
* Listing is recursive (no S3 ``Delimiter``). Regexes are applied with
  ``re.search`` to the *basename* of each key (the part after the last
  ``/``), so anchor them with ``^``/``$`` for an exact match. A nested
  object like ``census/2020/raw/nested/group_c1.sas7bdat`` is grouped by
  ``group_c1.sas7bdat`` alone. Text files (e.g. ``data.csv``) are handled
  identically — the basename is extracted and matched the same way.
* Explicit patterns are tried in order. A key matched by one pattern is
removed from the pool before the next pattern runs. Overlap between
patterns is flagged as an error at discovery time.
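* Auto-detect groups remaining keys by ``re.sub(r'\\d+$', '', stem)`` with
  any trailing ``_`` / ``-`` stripped afterward, mirroring
  ``load_folder.py``. A stem without trailing digits is keyed by itself
  (minus any trailing ``_`` / ``-``), so it usually forms a singleton
  cluster. Objects that share a key still land together, e.g. ``data.csv``
  and ``data.txt``, or the same basename under different sub-prefixes.
  Colliding basenames are saved under their key with ``/`` replaced by ``__``.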
* Within a cluster, files are sorted numerically by the LAST digit group
in the stem so ``_9_`` sorts before ``_40_`` regardless of zero-padding.
4. Library usage
----------------
::

    from s3_download import (
        build_s3_client, discover_clusters, download_cluster,
        list_s3_objects, load_download_config,
    )
    cfg = load_download_config("download_config.yaml")
    s3 = build_s3_client(cfg)
    objects = list_s3_objects(s3, cfg)
    for cluster in discover_clusters(cfg, objects):
        download_cluster(s3, cfg, cluster)
"""
from __future__ import annotations
import argparse
import re
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import boto3
import yaml
# Suffixes recognised as SAS data files.
SAS_EXTENSIONS: Tuple[str, ...] = (".sas7bdat", ".xpt", ".xport")
# Suffixes recognised as delimited text files.
TEXT_EXTENSIONS: Tuple[str, ...] = (".txt", ".csv", ".tsv")
# Used when the config omits ``extensions``.
DEFAULT_EXTENSIONS: Tuple[str, ...] = SAS_EXTENSIONS + TEXT_EXTENSIONS
# Accepted values for the ``on_exists`` setting.
VALID_ON_EXISTS: Tuple[str, ...] = ("skip", "overwrite", "error")
# Used when the config omits ``concurrency``.
DEFAULT_CONCURRENCY: int = 4
# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------
@dataclass
class _ExplicitPattern:
    """Parsed form of a single ``clusters[*]`` YAML entry.

    Built by ``load_download_config``; consumed by ``discover_clusters``.
    """

    # Compiled regex; discover_clusters applies it with ``.search`` to each
    # object's basename.
    pattern: re.Pattern
    # The regex source exactly as written in the config (used in messages
    # and copied into ClusterSpec.pattern).
    raw_pattern: str
    # Cluster name; becomes ClusterSpec.name, i.e. the local subfolder name.
    name: str
@dataclass
class DownloadConfig:
    """Top-level configuration parsed from YAML.

    Normally produced by ``load_download_config``, which validates and
    normalizes every field before constructing it.
    """

    # S3 bucket to list and download from.
    bucket: str
    # Key prefix scanned recursively; "" scans the whole bucket. The loader
    # strips a leading "/" and appends a trailing "/" when non-empty.
    prefix: str
    # Destination root; each cluster is downloaded into ``local_folder/<name>``.
    local_folder: Path
    # Named AWS profile for boto3; None uses boto3's default session.
    aws_profile: Optional[str] = None
    # Whether objects no explicit pattern claims are grouped automatically.
    auto_detect: bool = True
    # Suffixes an object's basename must end with to be listed (the loader
    # lowercases them and ensures a leading ".").
    extensions: Tuple[str, ...] = DEFAULT_EXTENSIONS
    # Policy for existing local files: "skip", "overwrite" or "error".
    on_exists: str = "skip"
    # Thread-pool size used when downloading one cluster.
    concurrency: int = DEFAULT_CONCURRENCY
    # Explicit cluster patterns, in config order.
    explicit: List[_ExplicitPattern] = field(default_factory=list)
@dataclass
class S3Object:
    """A single S3 object selected from the LIST response."""

    # Full object key as returned by LIST (begins with the configured prefix).
    key: str
    # Last key segment (text after the final "/"); cluster regexes match this.
    basename: str
    # Object size in bytes, from the LIST entry's "Size".
    size: int
@dataclass
class ClusterSpec:
    """Resolved per-cluster download settings."""

    # Cluster name; also the subfolder under local_folder it downloads into.
    name: str
    # Objects belonging to the cluster, already in download/display order.
    objects: List[S3Object]
    source: str  # "explicit" or "auto"
    # Regex source for explicit clusters; None for auto-detected clusters.
    pattern: Optional[str] = None
# ---------------------------------------------------------------------------
# Config loading
# ---------------------------------------------------------------------------
def _validate_on_exists(value: Any, where: str) -> str:
    """Lowercase an ``on_exists`` setting and check it is an allowed choice.

    Args:
        value: Raw config value; any type is accepted and ``str()``-ed.
        where: Location prefix for the error message.

    Returns:
        The lowercased value, one of ``VALID_ON_EXISTS``.

    Raises:
        ValueError: if the normalized value is not an allowed choice.
    """
    normalized = str(value).lower()
    if normalized in VALID_ON_EXISTS:
        return normalized
    raise ValueError(
        f"{where}: on_exists={value!r} is not one of {list(VALID_ON_EXISTS)}"
    )
def _parse_extensions(raw_value: Any, where: str) -> Tuple[str, ...]:
if raw_value is None:
return DEFAULT_EXTENSIONS
if isinstance(raw_value, str):
items = [raw_value]
elif isinstance(raw_value, list):
items = list(raw_value)
else:
raise ValueError(
f"{where}: 'extensions' must be a string or list of strings."
)
out: List[str] = []
for i, item in enumerate(items):
if not isinstance(item, str) or not item.strip():
raise ValueError(
f"{where}: 'extensions[{i}]' must be a non-empty string."
)
ext = item.strip().lower()
if not ext.startswith("."):
ext = "." + ext
out.append(ext)
if len(out) != len(set(out)):
raise ValueError(f"{where}: 'extensions' contains duplicates.")
return tuple(out)
def _parse_concurrency(raw_value: Any, where: str) -> int:
if raw_value is None:
return DEFAULT_CONCURRENCY
try:
value = int(raw_value)
except (TypeError, ValueError):
raise ValueError(
f"{where}: 'concurrency' must be a positive integer, "
f"got {raw_value!r}"
)
if value <= 0:
raise ValueError(
f"{where}: 'concurrency' must be a positive integer, got {value}"
)
return value
def _parse_bool_flag(raw_value: Any, key: str, where: str) -> bool:
    """Coerce the YAML value for *key* to ``bool``.

    YAML already turns unquoted ``true``/``false`` into booleans. A quoted
    ``"false"``, however, arrives as a non-empty string that plain ``bool()``
    treats as ``True``. Common string spellings are therefore mapped
    explicitly and any other string is rejected. Non-strings fall back to
    ``bool()``, so ``null`` and ``0`` mean ``False``.
    """
    if isinstance(raw_value, str):
        lowered = raw_value.strip().lower()
        if lowered in ("true", "yes", "on", "1"):
            return True
        if lowered in ("false", "no", "off", "0"):
            return False
        raise ValueError(
            f"{where}: '{key}' must be a boolean, got {raw_value!r}"
        )
    return bool(raw_value)


def load_download_config(path: Path) -> DownloadConfig:
    """Parse and validate the YAML download config at ``path``.

    Args:
        path: Location of the YAML file (``str`` or ``Path``).

    Returns:
        A validated :class:`DownloadConfig`. ``prefix`` has its leading
        ``/`` removed and a trailing ``/`` added when non-empty. A relative
        ``local_folder`` is resolved against the config file's directory.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: on any structural or value error in the config,
            including required keys that are present but null.
    """
    path = Path(path)
    where = f"Config {path}"
    with path.open("r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    if not isinstance(raw, dict):
        raise ValueError(
            f"Config at {path} must be a YAML mapping at the top level."
        )
    missing = [
        k for k in ("bucket", "prefix", "local_folder") if k not in raw
    ]
    if missing:
        raise ValueError(
            f"Config {path} missing required keys: {', '.join(missing)}"
        )

    # A bare ``bucket:`` is YAML null; str(None) would produce the bogus
    # bucket name "None", so treat null as empty and reject it.
    bucket = "" if raw["bucket"] is None else str(raw["bucket"]).strip()
    if not bucket:
        raise ValueError(f"{where}: 'bucket' must be a non-empty string.")

    # Likewise a bare ``prefix:`` would become the literal prefix "None/".
    # A whole-bucket scan must be requested explicitly with "".
    if raw["prefix"] is None:
        raise ValueError(
            f"{where}: 'prefix' must be a string; use \"\" to scan the "
            f"whole bucket."
        )
    # Normalize: strip leading slashes and add a trailing slash unless the
    # user explicitly asked for an empty prefix (whole bucket scan).
    prefix = str(raw["prefix"]).lstrip("/")
    if prefix and not prefix.endswith("/"):
        prefix += "/"

    if raw["local_folder"] is None:
        raise ValueError(f"{where}: 'local_folder' must be a path.")
    local_folder = Path(str(raw["local_folder"]))
    if not local_folder.is_absolute():
        # Relative paths are anchored at the config file's directory (not
        # the CWD), so a config behaves the same wherever it is run from.
        # The folder need not exist yet; it is created before downloading.
        local_folder = (path.parent / local_folder).resolve()

    aws_profile = raw.get("aws_profile")
    if aws_profile is not None:
        # A blank profile counts as "not set" (boto3 default session).
        aws_profile = str(aws_profile).strip() or None
    auto_detect = _parse_bool_flag(
        raw.get("auto_detect", True), "auto_detect", where
    )
    extensions = _parse_extensions(raw.get("extensions"), where)
    on_exists = _validate_on_exists(raw.get("on_exists", "skip"), where)
    concurrency = _parse_concurrency(raw.get("concurrency"), where)

    explicit: List[_ExplicitPattern] = []
    seen_names: set = set()
    clusters_raw = raw.get("clusters") or []
    if not isinstance(clusters_raw, list):
        raise ValueError(f"{where}: 'clusters' must be a list if present.")
    for i, entry in enumerate(clusters_raw):
        entry_where = f"{where} clusters[{i}]"
        if not isinstance(entry, dict):
            raise ValueError(f"{entry_where} must be a mapping.")
        # Null values count as missing; str(None) would otherwise yield a
        # pattern or cluster name of "None".
        if entry.get("pattern") is None or entry.get("name") is None:
            raise ValueError(
                f"{entry_where} must include 'pattern' and 'name'."
            )
        raw_pat = str(entry["pattern"])
        try:
            compiled = re.compile(raw_pat)
        except re.error as e:
            raise ValueError(
                f"{entry_where}: invalid regex {raw_pat!r}: {e}"
            ) from e
        name = str(entry["name"]).strip()
        if not name:
            raise ValueError(
                f"{entry_where}: 'name' must be a non-empty string."
            )
        if name in seen_names:
            raise ValueError(
                f"{entry_where}: duplicate cluster name {name!r}"
            )
        seen_names.add(name)
        explicit.append(
            _ExplicitPattern(pattern=compiled, raw_pattern=raw_pat, name=name)
        )

    return DownloadConfig(
        bucket=bucket,
        prefix=prefix,
        local_folder=local_folder,
        aws_profile=aws_profile,
        auto_detect=auto_detect,
        extensions=extensions,
        on_exists=on_exists,
        concurrency=concurrency,
        explicit=explicit,
    )
# ---------------------------------------------------------------------------
# S3 listing
# ---------------------------------------------------------------------------
def build_s3_client(cfg: DownloadConfig):
    """Create a boto3 S3 client.

    When ``cfg.aws_profile`` is set (truthy), the session uses that named
    profile. Otherwise a default ``boto3.Session()`` is used.
    """
    session_kwargs = (
        {"profile_name": cfg.aws_profile} if cfg.aws_profile else {}
    )
    return boto3.Session(**session_kwargs).client("s3")
def list_s3_objects(s3_client, cfg: DownloadConfig) -> List[S3Object]:
    """List all objects under ``cfg.prefix`` recursively, filtered by extension.

    Supports SAS extensions (``.sas7bdat``, ``.xpt``, ``.xport``) and text
    extensions (``.txt``, ``.csv``, ``.tsv``) — whichever are present in
    ``cfg.extensions``. Only the final suffix of each basename is compared,
    case-insensitively.

    Args:
        s3_client: boto3 S3 client (must provide the ``list_objects_v2``
            paginator).
        cfg: Supplies ``bucket``, ``prefix`` and ``extensions``.

    Returns:
        Matching objects sorted by key. Keys ending in ``/`` (folder
        placeholders) are skipped.
    """
    # Build the lookup once: O(1) membership per key. Lowercasing also makes
    # a DownloadConfig built directly (bypassing the loader) behave the same.
    wanted = {ext.lower() for ext in cfg.extensions}
    paginator = s3_client.get_paginator("list_objects_v2")
    out: List[S3Object] = []
    for page in paginator.paginate(Bucket=cfg.bucket, Prefix=cfg.prefix):
        for entry in page.get("Contents", []):
            key = entry["Key"]
            if key.endswith("/"):
                continue
            basename = key.rsplit("/", 1)[-1]
            ext = (
                "." + basename.rsplit(".", 1)[-1].lower()
                if "." in basename
                else ""
            )
            if ext not in wanted:
                continue
            out.append(
                S3Object(key=key, basename=basename, size=int(entry["Size"]))
            )
    out.sort(key=lambda o: o.key)
    return out
# ---------------------------------------------------------------------------
# Cluster discovery (mirrors generic_loader/load_folder.py)
# ---------------------------------------------------------------------------
# Matches the digit run (if any) at the very end of a stem; _auto_prefix
# removes it to derive the auto-detect cluster key.
_TRAILING_DIGIT_RE = re.compile(r"\d+$")
# Matches every digit run in a stem; _cluster_sort_key uses the last one.
_DIGIT_GROUP_RE = re.compile(r"\d+")
def _auto_prefix(stem: str) -> str:
    """Return the auto-detect cluster key for *stem*.

    Trailing digits are removed, then any trailing ``_``/``-``. If nothing
    is left, *stem* itself is the key. For example ``"group_a12"`` gives
    ``"group_a"``, ``"part-3"`` gives ``"part"``, and ``"2020"`` stays
    ``"2020"``.
    """
    key = _TRAILING_DIGIT_RE.sub("", stem).rstrip("_-")
    if not key:
        return stem
    return key
def _basename_stem(basename: str) -> str:
if "." in basename:
return basename.rsplit(".", 1)[0]
return basename
def _cluster_sort_key(obj: S3Object) -> Tuple[int, str]:
    """Sort key ``(last digit run as int, stem)`` for objects in a cluster.

    Comparing numerically puts ``part9`` before ``part40`` without needing
    zero-padding. Stems with no digits get ``-1`` and therefore sort first.
    """
    stem = _basename_stem(obj.basename)
    digit_runs = _DIGIT_GROUP_RE.findall(stem)
    if not digit_runs:
        return (-1, stem)
    return (int(digit_runs[-1]), stem)
def discover_clusters(
    cfg: DownloadConfig, objects: List[S3Object]
) -> List[ClusterSpec]:
    """Bucket *objects* into clusters using explicit patterns then auto-detect.

    Explicit patterns (``cfg.explicit``) are applied in config order with
    ``re.search`` on each object's basename. Every explicit pattern yields a
    cluster, even one that matched nothing (``objects=[]``). If
    ``cfg.auto_detect`` is true, the objects no pattern claimed are grouped
    by ``_auto_prefix`` of their stem, one cluster per key, in sorted key
    order. Otherwise those objects appear in no cluster. Objects inside each
    cluster are ordered by ``_cluster_sort_key``.

    Raises:
        ValueError: if any object's basename matches two explicit patterns.
    """
    # Reject overlap up front so the assignment below cannot depend on the
    # order of patterns in the config.
    for i, p_i in enumerate(cfg.explicit):
        for p_j in cfg.explicit[i + 1:]:
            for obj in objects:
                if (p_i.pattern.search(obj.basename)
                        and p_j.pattern.search(obj.basename)):
                    raise ValueError(
                        f"Object {obj.basename!r} matches multiple explicit "
                        f"patterns: {p_i.raw_pattern!r} and "
                        f"{p_j.raw_pattern!r}"
                    )

    clusters: List[ClusterSpec] = []
    remaining = list(objects)
    for patt in cfg.explicit:
        # Single-pass partition: O(n) per pattern. Filtering with
        # ``o not in matched`` would rescan ``matched`` for every object.
        matched: List[S3Object] = []
        unmatched: List[S3Object] = []
        for obj in remaining:
            if patt.pattern.search(obj.basename):
                matched.append(obj)
            else:
                unmatched.append(obj)
        remaining = unmatched
        # Patterns that matched nothing still produce an (empty) cluster so
        # callers can report them.
        clusters.append(
            ClusterSpec(
                name=patt.name,
                objects=sorted(matched, key=_cluster_sort_key),
                source="explicit",
                pattern=patt.raw_pattern,
            )
        )

    if cfg.auto_detect and remaining:
        buckets: Dict[str, List[S3Object]] = {}
        for obj in remaining:
            key = _auto_prefix(_basename_stem(obj.basename))
            buckets.setdefault(key, []).append(obj)
        # NOTE(review): an auto key can equal an explicit cluster name, in
        # which case both clusters download into the same local subfolder.
        for key in sorted(buckets):
            clusters.append(
                ClusterSpec(
                    name=key,
                    objects=sorted(buckets[key], key=_cluster_sort_key),
                    source="auto",
                )
            )
    return clusters
# ---------------------------------------------------------------------------
# Download
# ---------------------------------------------------------------------------
def _local_path(
cfg: DownloadConfig,
cluster: ClusterSpec,
obj: S3Object,
basename_collisions: set,
) -> Path:
"""Resolve the on-disk destination path for *obj*.
Falls back to a key-derived filename when two objects in the same cluster
share a basename (possible under recursive scan).
"""
cluster_dir = cfg.local_folder / cluster.name
if obj.basename in basename_collisions:
safe = obj.key.replace("/", "__")
return cluster_dir / safe
return cluster_dir / obj.basename
def _basename_collisions(cluster: ClusterSpec) -> set:
"""Return the set of basenames that appear more than once in *cluster*."""
seen: Dict[str, int] = {}
for obj in cluster.objects:
seen[obj.basename] = seen.get(obj.basename, 0) + 1
return {name for name, count in seen.items() if count > 1}
def _decide_action(
local_path: Path, obj: S3Object, on_exists: str
) -> Tuple[str, Optional[str]]:
"""Return ``(action, message)`` where action is 'download' or 'skip'."""
if on_exists == "overwrite":
return ("download", None)
if not local_path.exists():
return ("download", None)
local_size = local_path.stat().st_size
if local_size == obj.size:
return (
"skip",
f" skip {obj.key} -> {local_path} (size {local_size} matches)",
)
if on_exists == "error":
raise RuntimeError(
f"Local file {local_path} exists with size {local_size} but S3 "
f"object s3://{obj.key} has size {obj.size} (on_exists=error)"
)
return (
"download",
f" re-download {obj.key} -> {local_path} "
f"(local size {local_size} != S3 {obj.size})",
)
def _plan_downloads(
    cfg: DownloadConfig, cluster: ClusterSpec, on_exists: str
) -> Tuple[List[Tuple[S3Object, Path]], int]:
    """Return ``(plans, skipped)``: objects still to fetch, and skip count.

    Prints each decision message from ``_decide_action`` to stderr and lets
    its ``RuntimeError`` (``on_exists="error"``) propagate.
    """
    collisions = _basename_collisions(cluster)
    plans: List[Tuple[S3Object, Path]] = []
    skipped = 0
    for obj in cluster.objects:
        local_path = _local_path(cfg, cluster, obj, collisions)
        action, message = _decide_action(local_path, obj, on_exists)
        if message:
            print(message, file=sys.stderr)
        if action == "skip":
            skipped += 1
        else:
            plans.append((obj, local_path))
    return plans, skipped


def download_cluster(
    s3_client,
    cfg: DownloadConfig,
    cluster: ClusterSpec,
    *,
    on_exists_override: Optional[str] = None,
    fail_fast: bool = False,
) -> Tuple[int, int, int, List[Tuple[str, Exception]]]:
    """Download every object in *cluster* into ``cfg.local_folder/cluster.name``.

    Every object is checked first (sequentially), so an
    ``on_exists="error"`` conflict raises before any download starts. The
    remaining objects are fetched with ``s3_client.download_file``. When
    there is more than one of them and ``cfg.concurrency > 1``, up to
    ``cfg.concurrency`` threads are used.

    Args:
        s3_client: boto3 S3 client.
        cfg: Supplies bucket, destination root, ``on_exists`` and
            concurrency.
        cluster: The cluster to download.
        on_exists_override: When truthy, used instead of ``cfg.on_exists``.
        fail_fast: Stop starting new downloads after the first failure.
            Downloads already running still finish, and their outcome is
            counted and reported.

    Returns:
        ``(downloaded, skipped, bytes_downloaded, failures)`` where
        *failures* is a list of ``(key, exception)`` tuples.

    Raises:
        RuntimeError: propagated from ``_decide_action`` for
            ``on_exists="error"`` size conflicts.
    """
    if not cluster.objects:
        return (0, 0, 0, [])
    on_exists = on_exists_override or cfg.on_exists
    cluster_dir = cfg.local_folder / cluster.name
    cluster_dir.mkdir(parents=True, exist_ok=True)

    plans, skipped = _plan_downloads(cfg, cluster, on_exists)
    failures: List[Tuple[str, Exception]] = []
    if not plans:
        return (0, skipped, 0, failures)

    downloaded = 0
    bytes_downloaded = 0

    def _fetch(obj: S3Object, local_path: Path) -> None:
        local_path.parent.mkdir(parents=True, exist_ok=True)
        s3_client.download_file(cfg.bucket, obj.key, str(local_path))

    def _record(
        obj: S3Object, local_path: Path, exc: Optional[Exception]
    ) -> None:
        # Tally and log one finished download; exc is None on success.
        nonlocal downloaded, bytes_downloaded
        if exc is None:
            downloaded += 1
            bytes_downloaded += obj.size
            print(
                f" ok {obj.key} -> {local_path} ({obj.size:,} bytes)",
                file=sys.stderr,
            )
        else:
            failures.append((obj.key, exc))
            print(f" FAIL {obj.key}: {exc}", file=sys.stderr)

    if cfg.concurrency <= 1 or len(plans) == 1:
        for obj, local_path in plans:
            try:
                _fetch(obj, local_path)
            except Exception as exc:
                _record(obj, local_path, exc)
                if fail_fast:
                    break
            else:
                _record(obj, local_path, None)
        return (downloaded, skipped, bytes_downloaded, failures)

    with ThreadPoolExecutor(max_workers=cfg.concurrency) as pool:
        future_to_plan = {
            pool.submit(_fetch, obj, local_path): (obj, local_path)
            for obj, local_path in plans
        }
        cancelling = False
        for fut in as_completed(future_to_plan):
            if fut.cancelled():
                # Cancelled by fail-fast before it started; it never ran.
                continue
            obj, local_path = future_to_plan[fut]
            try:
                fut.result()
            except Exception as exc:
                _record(obj, local_path, exc)
                if fail_fast and not cancelling:
                    cancelling = True
                    # Only queued work can be cancelled. Keep iterating so
                    # downloads already in flight are still counted (or
                    # their errors recorded) instead of silently dropped.
                    for pending in future_to_plan:
                        pending.cancel()
            else:
                _record(obj, local_path, None)
    return (downloaded, skipped, bytes_downloaded, failures)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser used by ``main``.

    Defines ``--config`` (required path) plus the boolean switches
    ``--dry-run``, ``--overwrite`` and ``--fail-fast``.
    """
    p = argparse.ArgumentParser(
        description=(
            "Download S3 objects (SAS data files and/or delimited text files) "
            "under a prefix into a local folder, grouping objects into "
            "clusters that each become one subfolder. "
            "Supported extensions: "
            + ", ".join(DEFAULT_EXTENSIONS)
            + "."
        ),
    )
    p.add_argument(
        "--config", required=True, type=Path, help="Path to YAML config",
    )
    p.add_argument(
        "--dry-run",
        action="store_true",
        help=(
            "List discovered clusters and the objects each would download. "
            "No GET requests are issued beyond the initial LIST."
        ),
    )
    p.add_argument(
        "--overwrite",
        action="store_true",
        help=(
            "Force-redownload every matched object regardless of local "
            "cache state. Equivalent to on_exists=overwrite."
        ),
    )
    p.add_argument(
        "--fail-fast",
        action="store_true",
        help=(
            "Abort on the first per-file download failure. Default is to "
            "log the failure and keep going."
        ),
    )
    return p
def _format_bytes(n: int) -> str:
units = ("B", "KiB", "MiB", "GiB", "TiB")
size = float(n)
for unit in units:
if size < 1024 or unit == units[-1]:
return f"{size:,.1f} {unit}" if unit != "B" else f"{int(size):,} B"
size /= 1024
return f"{n} B"
def _describe_cluster(cluster: ClusterSpec) -> str:
src = cluster.source
if cluster.pattern:
src += f" pattern={cluster.pattern!r}"
if not cluster.objects:
return (
f"cluster {cluster.name!r} [{src}]\n objects: (no matching keys)"
)
total = sum(o.size for o in cluster.objects)
lines = [
f"cluster {cluster.name!r} [{src}] "
f"{len(cluster.objects)} object(s), {_format_bytes(total)}"
]
for obj in cluster.objects:
lines.append(f" - {obj.key} ({_format_bytes(obj.size)})")
return "\n".join(lines)
def main(argv: Optional[List[str]] = None) -> int:
    """Run the command-line tool and return the process exit code.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        0 on success or after a dry run, 1 if any file or whole cluster
        failed, and 2 if the listing found no objects with a configured
        extension. Errors raised while loading the config, listing, or
        discovering clusters are not caught here.
    """
    args = _build_argparser().parse_args(argv)
    cfg = load_download_config(args.config)
    s3 = build_s3_client(cfg)
    objects = list_s3_objects(s3, cfg)
    if not objects:
        print(
            f"error: no objects matching extensions "
            f"{list(cfg.extensions)} found under "
            f"s3://{cfg.bucket}/{cfg.prefix}",
            file=sys.stderr,
        )
        return 2

    clusters = discover_clusters(cfg, objects)
    loadable = [c for c in clusters if c.objects]
    assigned = sum(len(c.objects) for c in loadable)
    print(
        f"discovered {len(loadable)} cluster(s) "
        f"({assigned} object(s)) under "
        f"s3://{cfg.bucket}/{cfg.prefix}:"
    )
    for c in clusters:
        print(_describe_cluster(c))
    # Objects that landed in no cluster (possible when auto_detect is off)
    # would otherwise be skipped without any trace.
    unassigned = len(objects) - assigned
    if unassigned > 0:
        print(
            f"warning: {unassigned} listed object(s) matched no cluster and "
            f"will not be downloaded (enable auto_detect or add a pattern)",
            file=sys.stderr,
        )
    if args.dry_run:
        return 0

    cfg.local_folder.mkdir(parents=True, exist_ok=True)
    on_exists_override = "overwrite" if args.overwrite else None
    totals: List[Tuple[str, int, int, int]] = []  # (name, dl, skip, bytes)
    failures: List[Tuple[str, str, Exception]] = []  # (cluster, key, exc)
    aborted = False
    for cluster in loadable:
        print(
            f"\n>>> downloading cluster {cluster.name!r} "
            f"({len(cluster.objects)} object(s))"
        )
        try:
            dl, sk, by, fails = download_cluster(
                s3, cfg, cluster,
                on_exists_override=on_exists_override,
                fail_fast=args.fail_fast,
            )
        except Exception as exc:
            # Whole-cluster failure, e.g. an on_exists=error conflict or a
            # local filesystem error before any download was attempted.
            failures.append((cluster.name, "<cluster>", exc))
            print(
                f" !! cluster {cluster.name!r} aborted: {exc}",
                file=sys.stderr,
            )
            if args.fail_fast:
                aborted = True
                break
            continue
        totals.append((cluster.name, dl, sk, by))
        failures.extend((cluster.name, key, exc) for key, exc in fails)
        if fails and args.fail_fast:
            aborted = True
            break

    print("\n=== summary ===")
    for name, dl, sk, by in totals:
        print(
            f" {name}: downloaded {dl}, skipped {sk}, "
            f"{_format_bytes(by)}"
        )
    for cname, key, exc in failures:
        print(f" FAIL {cname} {key}: {exc}", file=sys.stderr)
    if aborted:
        print(
            " stopped early because of --fail-fast; later clusters were "
            "not attempted",
            file=sys.stderr,
        )
    return 1 if failures else 0
if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the interpreter.
    raise SystemExit(main())