729 lines
24 KiB
Python
729 lines
24 KiB
Python
|
|
"""S3-source counterpart to ``generic_loader/load_folder.py``.
|
||
|
|
|
||
|
|
Reads a YAML config that points at an S3 bucket + prefix, lists every object
|
||
|
|
under that prefix recursively, groups objects into *clusters* using the same
|
||
|
|
explicit-pattern + auto-detect rules as ``load_folder.py``, and downloads each
|
||
|
|
cluster's files into its own subfolder under a local destination root.
|
||
|
|
|
||
|
|
-------------------------------------------------------------------------------
|
||
|
|
USAGE
|
||
|
|
-------------------------------------------------------------------------------
|
||
|
|
|
||
|
|
1. YAML config
|
||
|
|
--------------
|
||
|
|
::
|
||
|
|
|
||
|
|
bucket: my-bucket # required
|
||
|
|
prefix: census/2020/raw/ # required; recursive scan under it
|
||
|
|
local_folder: ./downloads # required; one subfolder per cluster
|
||
|
|
aws_profile: default # optional; default boto3 chain if omitted
|
||
|
|
|
||
|
|
auto_detect: true # optional; default true
|
||
|
|
extensions: # optional; default sas7bdat/xpt/xport
|
||
|
|
- .sas7bdat
|
||
|
|
on_exists: skip # optional; skip | overwrite | error
|
||
|
|
concurrency: 4 # optional; default 4
|
||
|
|
|
||
|
|
clusters:
|
||
|
|
- pattern: '^group_a\\d+\\.sas7bdat$'
|
||
|
|
name: group_a
|
||
|
|
- pattern: '^group_b\\d+\\.sas7bdat$'
|
||
|
|
name: group_b
|
||
|
|
|
||
|
|
2. Command-line interface
|
||
|
|
-------------------------
|
||
|
|
::
|
||
|
|
|
||
|
|
python s3_download.py --config download_config.yaml [--dry-run]
|
||
|
|
[--overwrite] [--fail-fast]
|
||
|
|
|
||
|
|
Flags:
|
||
|
|
--config PATH Required. Path to the YAML config above.
|
||
|
|
--dry-run List discovered clusters and the objects each would
|
||
|
|
download (with sizes). No GET requests are issued
|
||
|
|
beyond the initial LIST.
|
||
|
|
--overwrite Force-redownload every matched object regardless of
|
||
|
|
local cache state. Equivalent to ``on_exists: overwrite``
|
||
|
|
for this run.
|
||
|
|
--fail-fast Abort the whole run on the first per-file download
|
||
|
|
failure. Default is to log the failure and keep going.
|
||
|
|
|
||
|
|
Exit codes:
|
||
|
|
0 - every file downloaded (or skipped) successfully (or dry-run completed)
|
||
|
|
1 - at least one file failed (details on stderr)
|
||
|
|
2 - the LIST returned no objects matching the configured extensions
|
||
|
|
|
||
|
|
3. Discovery rules
|
||
|
|
------------------
|
||
|
|
* Listing is recursive (no S3 ``Delimiter``). Regexes are matched against
|
||
|
|
the *basename* of each key (the part after the last ``/``), so a nested
|
||
|
|
object like ``census/2020/raw/nested/group_c1.sas7bdat`` is grouped by
|
||
|
|
``group_c1.sas7bdat`` alone.
|
||
|
|
* Explicit patterns are tried in order. A key matched by one pattern is
|
||
|
|
removed from the pool before the next pattern runs. Overlap between
|
||
|
|
patterns is flagged as an error at discovery time.
|
||
|
|
* Auto-detect groups remaining keys by ``re.sub(r'\\d+$', '', stem)`` with
|
||
|
|
any trailing ``_`` / ``-`` stripped afterward, mirroring
|
||
|
|
``load_folder.py``. Stems without trailing digits become singleton
|
||
|
|
clusters named after the stem.
|
||
|
|
* Within a cluster, files are sorted numerically by the LAST digit group
|
||
|
|
in the stem so ``_9_`` sorts before ``_40_`` regardless of zero-padding.
|
||
|
|
|
||
|
|
4. Library usage
|
||
|
|
----------------
|
||
|
|
::
|
||
|
|
|
||
|
|
from s3_download import load_download_config, list_s3_objects, \
|
||
|
|
discover_clusters, download_cluster, build_s3_client
|
||
|
|
|
||
|
|
cfg = load_download_config("download_config.yaml")
|
||
|
|
s3 = build_s3_client(cfg)
|
||
|
|
objects = list_s3_objects(s3, cfg)
|
||
|
|
for cluster in discover_clusters(cfg, objects):
|
||
|
|
download_cluster(s3, cfg, cluster)
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import re
|
||
|
|
import sys
|
||
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Dict, List, Optional, Tuple
|
||
|
|
|
||
|
|
import boto3
|
||
|
|
import yaml
|
||
|
|
|
||
|
|
|
||
|
|
# Extensions selected when the config omits ``extensions``.
DEFAULT_EXTENSIONS: Tuple[str, ...] = (
    ".sas7bdat",
    ".xpt",
    ".xport",
)

# Accepted values for the ``on_exists`` config key.
VALID_ON_EXISTS: Tuple[str, ...] = (
    "skip",
    "overwrite",
    "error",
)

# Download worker count used when the config omits ``concurrency``.
DEFAULT_CONCURRENCY: int = 4
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Dataclasses
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class _ExplicitPattern:
    """Parsed form of a single ``clusters[*]`` YAML entry.

    Built by ``load_download_config``; consumed by ``discover_clusters``.
    """

    # Compiled regex. discover_clusters calls ``.search`` on each object's
    # basename, so the match is unanchored unless the pattern uses ^ / $.
    pattern: re.Pattern
    # Pattern source text as written in the config; shown in overlap errors
    # and copied to ClusterSpec.pattern.
    raw_pattern: str
    # Cluster name (unique among explicit entries); also the name of the
    # local subfolder the cluster is downloaded into.
    name: str
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class DownloadConfig:
    """Validated settings for one download run.

    Normally produced by ``load_download_config``. Only the first three
    fields are required; the remaining ones default to the values
    documented for the YAML keys of the same name (``explicit`` holds the
    parsed ``clusters`` list).
    """

    # --- required -----------------------------------------------------
    bucket: str
    # As normalized by load_download_config: no leading "/", and a
    # trailing "/" unless empty (empty means "scan the whole bucket").
    prefix: str
    # Destination root; each cluster gets its own subfolder beneath it.
    local_folder: Path

    # --- optional -----------------------------------------------------
    aws_profile: Optional[str] = field(default=None)
    auto_detect: bool = field(default=True)
    extensions: Tuple[str, ...] = field(default=DEFAULT_EXTENSIONS)
    on_exists: str = field(default="skip")
    concurrency: int = field(default=DEFAULT_CONCURRENCY)
    explicit: List[_ExplicitPattern] = field(default_factory=list)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class S3Object:
    """A single S3 object selected from the LIST response.

    Built by ``list_s3_objects`` from a ``list_objects_v2`` ``Contents``
    entry that passed the extension filter.
    """

    # Full object key, including the configured prefix.
    key: str
    # Part of ``key`` after the last "/"; cluster regexes match against this.
    basename: str
    # Object size in bytes as reported by LIST (``entry["Size"]``); compared
    # against an existing local file to decide whether it can be skipped.
    size: int
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class ClusterSpec:
    """One group of S3 objects that is downloaded into a single subfolder."""

    # Subfolder name under DownloadConfig.local_folder.
    name: str
    # Member objects; discover_clusters sorts them with _cluster_sort_key.
    # May be empty for an explicit pattern that matched nothing.
    objects: List[S3Object]
    # "explicit" (from a config pattern) or "auto" (auto-detected).
    source: str
    # Raw regex text for explicit clusters; None for auto clusters.
    pattern: Optional[str] = field(default=None)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Config loading
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
def _validate_on_exists(value: Any, where: str) -> str:
    """Return *value* lower-cased if it names a valid ``on_exists`` mode.

    *where* prefixes the error message to locate the offending config.

    Raises:
        ValueError: if the lower-cased text is not in VALID_ON_EXISTS.
    """
    normalized = str(value).lower()
    if normalized in VALID_ON_EXISTS:
        return normalized
    raise ValueError(
        f"{where}: on_exists={value!r} is not one of {list(VALID_ON_EXISTS)}"
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_extensions(raw_value: Any, where: str) -> Tuple[str, ...]:
    """Normalize the optional ``extensions`` config value.

    ``None`` (key absent or null) yields DEFAULT_EXTENSIONS. Otherwise a
    single string or a list of strings is accepted. Each entry is stripped,
    lower-cased, and given a leading ``.`` if missing (``"XPT"`` ->
    ``".xpt"``).

    Args:
        raw_value: Value loaded from YAML.
        where: Prefix for error messages (identifies the config file).

    Returns:
        Tuple of normalized extensions, in the order given.

    Raises:
        ValueError: if the value is neither a string nor a list, the list is
            empty, an entry is not a non-empty string, or two entries
            normalize to the same extension.
    """
    if raw_value is None:
        return DEFAULT_EXTENSIONS
    if isinstance(raw_value, str):
        items = [raw_value]
    elif isinstance(raw_value, list):
        items = list(raw_value)
    else:
        raise ValueError(
            f"{where}: 'extensions' must be a string or list of strings."
        )
    if not items:
        # An empty list would filter out every object and only surface later
        # as a misleading "no objects found" exit instead of a config error.
        raise ValueError(f"{where}: 'extensions' must not be empty.")

    out: List[str] = []
    for i, item in enumerate(items):
        if not isinstance(item, str) or not item.strip():
            raise ValueError(
                f"{where}: 'extensions[{i}]' must be a non-empty string."
            )
        ext = item.strip().lower()
        if not ext.startswith("."):
            ext = "." + ext
        out.append(ext)
    # Checked after normalization so "xpt" and ".XPT" count as duplicates.
    if len(out) != len(set(out)):
        raise ValueError(f"{where}: 'extensions' contains duplicates.")
    return tuple(out)
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_concurrency(raw_value: Any, where: str) -> int:
    """Parse the optional ``concurrency`` config value into a positive int.

    ``None`` (key absent or null) yields DEFAULT_CONCURRENCY. Integers and
    integer-valued strings/floats (``4``, ``"4"``, ``4.0``) are accepted.
    Booleans and non-integral floats (``2.5``, ``.inf``) are rejected
    instead of being silently coerced by ``int()``.

    Args:
        raw_value: Value loaded from YAML.
        where: Prefix for error messages (identifies the config file).

    Raises:
        ValueError: if the value is not a positive integer.
    """
    if raw_value is None:
        return DEFAULT_CONCURRENCY
    not_an_int = ValueError(
        f"{where}: 'concurrency' must be a positive integer, "
        f"got {raw_value!r}"
    )
    # bool is an int subclass, so ``concurrency: true`` would otherwise
    # quietly become 1.
    if isinstance(raw_value, bool):
        raise not_an_int
    # int() truncates 2.5 -> 2 and raises OverflowError for inf; reject both.
    if isinstance(raw_value, float) and not raw_value.is_integer():
        raise not_an_int
    try:
        value = int(raw_value)
    except (TypeError, ValueError):
        raise not_an_int from None
    if value <= 0:
        raise ValueError(
            f"{where}: 'concurrency' must be a positive integer, got {value}"
        )
    return value
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_local_folder(raw_value: Any, config_path: Path) -> Path:
    """Turn the ``local_folder`` config value into a destination path.

    Relative paths are interpreted relative to the directory containing the
    config file (not the current working directory) and resolved to an
    absolute path; absolute paths are kept as written. The folder need not
    exist yet: it is created (``mkdir(parents=True)``) before downloading.

    Raises:
        ValueError: if the value is null.
    """
    if raw_value is None:
        raise ValueError(
            f"Config {config_path}: 'local_folder' must not be empty."
        )
    folder = Path(raw_value)
    if folder.is_absolute():
        return folder
    return (config_path.parent / folder).resolve()


def _parse_explicit_patterns(
    clusters_raw: Any, path: Path
) -> List[_ExplicitPattern]:
    """Validate the ``clusters`` list into compiled explicit patterns.

    Raises:
        ValueError: if ``clusters`` is not a list, an entry is not a mapping
            or lacks ``pattern``/``name``, a regex does not compile, a name
            is blank, or a name is repeated.
    """
    if not isinstance(clusters_raw, list):
        raise ValueError(f"Config {path}: 'clusters' must be a list if present.")
    explicit: List[_ExplicitPattern] = []
    seen_names: set = set()
    for i, entry in enumerate(clusters_raw):
        where = f"Config {path} clusters[{i}]"
        if not isinstance(entry, dict):
            raise ValueError(f"{where} must be a mapping.")
        if "pattern" not in entry or "name" not in entry:
            raise ValueError(f"{where} must include 'pattern' and 'name'.")
        raw_pat = str(entry["pattern"])
        try:
            compiled = re.compile(raw_pat)
        except re.error as e:
            raise ValueError(f"{where}: invalid regex {raw_pat!r}: {e}") from e
        name = str(entry["name"]).strip()
        if not name:
            raise ValueError(f"{where}: 'name' must be a non-empty string.")
        # Names become subfolder names, so duplicates would merge clusters.
        if name in seen_names:
            raise ValueError(f"{where}: duplicate cluster name {name!r}")
        seen_names.add(name)
        explicit.append(
            _ExplicitPattern(pattern=compiled, raw_pattern=raw_pat, name=name)
        )
    return explicit


def load_download_config(path: Path) -> DownloadConfig:
    """Parse and validate the YAML download config at ``path``.

    Required keys: ``bucket``, ``prefix``, ``local_folder``. Optional keys
    left null fall back to their documented defaults.

    Args:
        path: Config file path (str or Path).

    Returns:
        A fully validated DownloadConfig.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if a required key is missing/null or any value is
            invalid.
    """
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        # safe_load: the config is plain data; never construct objects.
        raw = yaml.safe_load(f)

    if not isinstance(raw, dict):
        raise ValueError(
            f"Config at {path} must be a YAML mapping at the top level."
        )

    missing = [k for k in ("bucket", "prefix", "local_folder") if k not in raw]
    if missing:
        raise ValueError(
            f"Config {path} missing required keys: {', '.join(missing)}"
        )

    # A bare ``bucket:`` loads as None; str(None) would produce a bucket
    # literally named "None", so map it to "" and let the check reject it.
    bucket_raw = raw["bucket"]
    bucket = "" if bucket_raw is None else str(bucket_raw).strip()
    if not bucket:
        raise ValueError(f"Config {path}: 'bucket' must be a non-empty string.")

    # Same hazard for prefix: a null must not silently become "None/".
    if raw["prefix"] is None:
        raise ValueError(
            f"Config {path}: 'prefix' must be a string; "
            f"use '' to scan the whole bucket."
        )
    # Normalize: strip leading slashes and append a trailing slash unless the
    # prefix is empty (an explicit whole-bucket scan).
    prefix = str(raw["prefix"]).lstrip("/")
    if prefix and not prefix.endswith("/"):
        prefix += "/"

    local_folder = _resolve_local_folder(raw["local_folder"], path)

    aws_profile = raw.get("aws_profile")
    if aws_profile is not None:
        # A blank profile means "use boto3's default credential chain".
        aws_profile = str(aws_profile).strip() or None

    # A null ``auto_detect:`` falls back to the documented default (True),
    # just as null ``extensions``/``concurrency`` fall back to theirs.
    auto_detect_raw = raw.get("auto_detect")
    auto_detect = True if auto_detect_raw is None else bool(auto_detect_raw)

    where = f"Config {path}"
    extensions = _parse_extensions(raw.get("extensions"), where)
    on_exists = _validate_on_exists(raw.get("on_exists", "skip"), where)
    concurrency = _parse_concurrency(raw.get("concurrency"), where)
    explicit = _parse_explicit_patterns(raw.get("clusters") or [], path)

    return DownloadConfig(
        bucket=bucket,
        prefix=prefix,
        local_folder=local_folder,
        aws_profile=aws_profile,
        auto_detect=auto_detect,
        extensions=extensions,
        on_exists=on_exists,
        concurrency=concurrency,
        explicit=explicit,
    )
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# S3 listing
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
def build_s3_client(cfg: DownloadConfig):
    """Return a boto3 S3 client for *cfg*.

    The named profile ``cfg.aws_profile`` is used when it is set
    (non-empty); otherwise a plain ``boto3.Session()`` is created, i.e. the
    default credential/region resolution applies.
    """
    session_kwargs = {"profile_name": cfg.aws_profile} if cfg.aws_profile else {}
    return boto3.Session(**session_kwargs).client("s3")
|
||
|
|
|
||
|
|
|
||
|
|
def list_s3_objects(s3_client, cfg: DownloadConfig) -> List[S3Object]:
    """List all objects under ``cfg.prefix`` recursively, filtered by extension.

    No ``Delimiter`` is passed, so nested "subfolders" are included. An
    object is kept when its basename, lower-cased, ends with any entry of
    ``cfg.extensions``. Entries are expected to be lower-case with a
    leading ".", as produced by ``_parse_extensions``. Multi-part
    extensions such as ``".tar.gz"`` are therefore matched too.

    Args:
        s3_client: A boto3 S3 client (needs ``get_paginator``).
        cfg: Supplies ``bucket``, ``prefix`` and ``extensions``.

    Returns:
        Matching objects sorted by key.
    """
    # str.endswith needs a tuple; tuple() also tolerates a list supplied by
    # a caller that built DownloadConfig by hand. Hoisted out of the loop.
    wanted = tuple(cfg.extensions)
    paginator = s3_client.get_paginator("list_objects_v2")
    out: List[S3Object] = []
    for page in paginator.paginate(Bucket=cfg.bucket, Prefix=cfg.prefix):
        # A page may carry no "Contents" at all (e.g. nothing under prefix).
        for entry in page.get("Contents", []):
            key = entry["Key"]
            # Keys ending in "/" are folder placeholders, not files.
            if key.endswith("/"):
                continue
            basename = key.rsplit("/", 1)[-1]
            # For single-dot extensions this is identical to comparing the
            # text after the last "."; it additionally supports ".tar.gz".
            if not basename.lower().endswith(wanted):
                continue
            out.append(
                S3Object(key=key, basename=basename, size=int(entry["Size"]))
            )
    out.sort(key=lambda o: o.key)
    return out
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Cluster discovery (mirrors generic_loader/load_folder.py)
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
# Trailing run of digits on a stem; _auto_prefix strips it to form the
# auto-detect cluster key (e.g. "group_a12" -> "group_a").
_TRAILING_DIGIT_RE = re.compile(r"\d+$")
# Any run of digits; _cluster_sort_key orders files by the last such run.
_DIGIT_GROUP_RE = re.compile(r"\d+")
|
||
|
|
|
||
|
|
|
||
|
|
def _auto_prefix(stem: str) -> str:
    """Compute the auto-detect cluster key for *stem*.

    Drops a trailing run of digits, then any trailing ``_``/``-`` characters
    (``"group_b07"`` -> ``"group_b"``). If nothing is left (for example an
    all-digit stem), the stem itself is returned unchanged.
    """
    candidate = _TRAILING_DIGIT_RE.sub("", stem).rstrip("_-")
    if candidate:
        return candidate
    return stem
|
||
|
|
|
||
|
|
|
||
|
|
def _basename_stem(basename: str) -> str:
    """Return *basename* without its last ``.suffix``.

    A name with no dot is returned unchanged; a leading-dot name such as
    ``".xpt"`` yields ``""``.
    """
    head, dot, _suffix = basename.rpartition(".")
    return head if dot else basename
|
||
|
|
|
||
|
|
|
||
|
|
def _cluster_sort_key(obj: S3Object) -> Tuple[int, str]:
    """Order key for files inside one cluster.

    The primary key is the integer value of the LAST digit run in the stem,
    so ``x_9`` precedes ``x_40`` whatever the zero-padding. Stems with no
    digits get -1 and sort first. The stem itself breaks ties.
    """
    stem = _basename_stem(obj.basename)
    digit_runs = _DIGIT_GROUP_RE.findall(stem)
    if not digit_runs:
        return (-1, stem)
    return (int(digit_runs[-1]), stem)
|
||
|
|
|
||
|
|
|
||
|
|
def discover_clusters(
    cfg: DownloadConfig, objects: List[S3Object]
) -> List[ClusterSpec]:
    """Bucket *objects* into clusters using explicit patterns then auto-detect.

    1. Every object may match at most one explicit pattern (``.search`` on
       the basename); otherwise ValueError is raised before any grouping.
    2. Explicit patterns are applied in config order. Each produces one
       cluster, possibly empty, and its matches leave the pool.
    3. If ``cfg.auto_detect`` is set, the remaining objects are grouped by
       ``_auto_prefix(stem)``. Each group becomes an "auto" cluster, emitted
       in sorted name order.

    Objects within each cluster are ordered by ``_cluster_sort_key``.

    Raises:
        ValueError: if an object matches more than one explicit pattern.
    """
    explicit = cfg.explicit

    # One pass per object (O(P*N)) instead of one pass per pattern pair.
    for obj in objects:
        hits = [p for p in explicit if p.pattern.search(obj.basename)]
        if len(hits) > 1:
            raise ValueError(
                f"Object {obj.basename!r} matches multiple explicit "
                f"patterns: {hits[0].raw_pattern!r} and "
                f"{hits[1].raw_pattern!r}"
            )

    clusters: List[ClusterSpec] = []
    remaining = list(objects)
    for patt in explicit:
        # Single-pass partition; avoids the O(N*M) ``o not in matched`` scan.
        matched: List[S3Object] = []
        unmatched: List[S3Object] = []
        for obj in remaining:
            if patt.pattern.search(obj.basename):
                matched.append(obj)
            else:
                unmatched.append(obj)
        remaining = unmatched
        # Patterns that match nothing still yield an empty cluster so callers
        # can report them; they are expected to skip empty clusters.
        clusters.append(
            ClusterSpec(
                name=patt.name,
                objects=sorted(matched, key=_cluster_sort_key),
                source="explicit",
                pattern=patt.raw_pattern,
            )
        )

    if cfg.auto_detect and remaining:
        buckets: Dict[str, List[S3Object]] = {}
        for obj in remaining:
            group = _auto_prefix(_basename_stem(obj.basename))
            buckets.setdefault(group, []).append(obj)
        for group in sorted(buckets):
            clusters.append(
                ClusterSpec(
                    name=group,
                    objects=sorted(buckets[group], key=_cluster_sort_key),
                    source="auto",
                )
            )

    return clusters
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Download
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
def _local_path(
    cfg: DownloadConfig,
    cluster: ClusterSpec,
    obj: S3Object,
    basename_collisions: set,
) -> Path:
    """Return where *obj* should be written on disk.

    Normally ``<local_folder>/<cluster name>/<basename>``. When *obj*'s
    basename is in *basename_collisions* (two keys in the cluster share it,
    which a recursive listing allows), the full key with every "/" replaced
    by "__" is used as the filename instead, so the files do not overwrite
    each other.
    """
    target_dir = cfg.local_folder / cluster.name
    collides = obj.basename in basename_collisions
    filename = obj.key.replace("/", "__") if collides else obj.basename
    return target_dir / filename
|
||
|
|
|
||
|
|
|
||
|
|
def _basename_collisions(cluster: ClusterSpec) -> set:
    """Return the basenames that occur on more than one object in *cluster*."""
    seen: set = set()
    duplicated: set = set()
    for obj in cluster.objects:
        if obj.basename in seen:
            duplicated.add(obj.basename)
        else:
            seen.add(obj.basename)
    return duplicated
|
||
|
|
|
||
|
|
|
||
|
|
def _decide_action(
    local_path: Path, obj: S3Object, on_exists: str
) -> Tuple[str, Optional[str]]:
    """Decide whether *obj* must be fetched to *local_path*.

    Rules, in order:

    * ``on_exists == "overwrite"``: always download.
    * No local file: download.
    * Local file whose size equals ``obj.size``: skip (size is the only
      freshness check; contents are not compared).
    * Size mismatch: raise under ``"error"``, otherwise re-download.

    Returns:
        ``(action, message)``. *action* is ``"download"`` or ``"skip"``;
        *message* is a log line for the caller to print, or None.

    Raises:
        RuntimeError: on a size mismatch when ``on_exists == "error"``.
    """
    if on_exists == "overwrite" or not local_path.exists():
        return ("download", None)

    local_size = local_path.stat().st_size
    if local_size == obj.size:
        return (
            "skip",
            f" skip {obj.key} -> {local_path} (size {local_size} matches)",
        )

    if on_exists == "error":
        # The bucket is not known here, so name the key itself instead of
        # emitting a malformed "s3://<key>" URI that reads as a bucket name.
        raise RuntimeError(
            f"Local file {local_path} exists with size {local_size} but S3 "
            f"object {obj.key!r} has size {obj.size} (on_exists=error)"
        )

    return (
        "download",
        f" re-download {obj.key} -> {local_path} "
        f"(local size {local_size} != S3 {obj.size})",
    )
|
||
|
|
|
||
|
|
|
||
|
|
def download_cluster(
    s3_client,
    cfg: DownloadConfig,
    cluster: ClusterSpec,
    *,
    on_exists_override: Optional[str] = None,
    fail_fast: bool = False,
) -> Tuple[int, int, int, List[Tuple[str, Exception]]]:
    """Download every object in *cluster* into ``cfg.local_folder/cluster.name``.

    Every object is first classified with ``_decide_action``, using
    ``on_exists_override`` when given and ``cfg.on_exists`` otherwise. This
    planning happens before any download, so an ``on_exists=error``
    conflict aborts the cluster before anything is fetched. The remaining
    downloads then run either sequentially or on a thread pool of
    ``cfg.concurrency`` workers. Per-file failures are logged to stderr and
    collected rather than raised.

    With ``fail_fast`` the first failure stops further downloads from
    starting. On the threaded path, downloads already in flight cannot be
    cancelled. They are still waited for and reported, so the returned
    counts match what was written to disk.

    Returns:
        ``(downloaded, skipped, bytes_downloaded, failures)`` where
        *failures* is a list of ``(key, exception)`` tuples.

    Raises:
        RuntimeError: propagated from ``_decide_action`` when
            ``on_exists == "error"`` and a local file's size differs.
    """
    if not cluster.objects:
        return (0, 0, 0, [])

    on_exists = on_exists_override or cfg.on_exists
    cluster_dir = cfg.local_folder / cluster.name
    cluster_dir.mkdir(parents=True, exist_ok=True)
    collisions = _basename_collisions(cluster)

    # --- plan: decide per object, no network I/O yet ----------------------
    plans: List[Tuple[S3Object, Path]] = []
    skipped = 0
    for obj in cluster.objects:
        local_path = _local_path(cfg, cluster, obj, collisions)
        action, message = _decide_action(local_path, obj, on_exists)
        if message:
            print(message, file=sys.stderr)
        if action == "skip":
            skipped += 1
        else:
            plans.append((obj, local_path))

    downloaded = 0
    bytes_downloaded = 0
    failures: List[Tuple[str, Exception]] = []

    if not plans:
        return (0, skipped, 0, failures)

    def _fetch(obj: S3Object, local_path: Path) -> None:
        """Fetch one object to disk (runs on a worker thread when pooled)."""
        local_path.parent.mkdir(parents=True, exist_ok=True)
        s3_client.download_file(cfg.bucket, obj.key, str(local_path))

    def _report(obj: S3Object, local_path: Path, exc: Optional[Exception]) -> None:
        """Update counters and log one finished download (main thread only)."""
        nonlocal downloaded, bytes_downloaded
        if exc is None:
            downloaded += 1
            bytes_downloaded += obj.size
            print(
                f" ok {obj.key} -> {local_path} ({obj.size:,} bytes)",
                file=sys.stderr,
            )
        else:
            failures.append((obj.key, exc))
            print(f" FAIL {obj.key}: {exc}", file=sys.stderr)

    # --- execute: sequential -----------------------------------------------
    if cfg.concurrency <= 1 or len(plans) == 1:
        for obj, local_path in plans:
            try:
                _fetch(obj, local_path)
            except Exception as exc:
                _report(obj, local_path, exc)
                if fail_fast:
                    break
            else:
                _report(obj, local_path, None)
        return (downloaded, skipped, bytes_downloaded, failures)

    # --- execute: thread pool ----------------------------------------------
    with ThreadPoolExecutor(max_workers=cfg.concurrency) as pool:
        future_to_plan = {
            pool.submit(_fetch, obj, local_path): (obj, local_path)
            for obj, local_path in plans
        }
        stopping = False
        # Keep draining after a fail_fast trigger instead of breaking: the
        # pool's __exit__ would wait for in-flight downloads anyway, and
        # breaking early would drop their outcome from the counts.
        for fut in as_completed(future_to_plan):
            if fut.cancelled():
                # Cancelled by fail_fast below before it started; never ran.
                continue
            obj, local_path = future_to_plan[fut]
            try:
                fut.result()
            except Exception as exc:
                _report(obj, local_path, exc)
                if fail_fast and not stopping:
                    stopping = True
                    # Only queued futures can be cancelled; running ones
                    # finish and are reported by this loop.
                    for other in future_to_plan:
                        other.cancel()
            else:
                _report(obj, local_path, None)

    return (downloaded, skipped, bytes_downloaded, failures)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# CLI
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
def _build_argparser() -> argparse.ArgumentParser:
    """Build the CLI parser: a required ``--config`` plus three on/off switches."""
    parser = argparse.ArgumentParser(
        description=(
            "Download S3 objects under a prefix into a local folder, "
            "grouping objects into clusters that each become one subfolder."
        ),
    )
    parser.add_argument(
        "--config", required=True, type=Path, help="Path to YAML config",
    )

    # (flag, help) for every boolean switch, registered in this order.
    switches = (
        (
            "--dry-run",
            "List discovered clusters and the objects each would download. "
            "No GET requests are issued beyond the initial LIST.",
        ),
        (
            "--overwrite",
            "Force-redownload every matched object regardless of local "
            "cache state. Equivalent to on_exists=overwrite.",
        ),
        (
            "--fail-fast",
            "Abort on the first per-file download failure. Default is to "
            "log the failure and keep going.",
        ),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)
    return parser
|
||
|
|
|
||
|
|
|
||
|
|
def _format_bytes(n: int) -> str:
    """Render a byte count with binary units.

    Values under 1024 are shown as whole bytes with thousands separators
    (``"1,023 B"``). Larger values are divided by 1024 until they drop
    below 1024 or reach TiB, then shown with one decimal
    (``"1.5 KiB"``).
    """
    value = float(n)
    *smaller_units, largest_unit = ("B", "KiB", "MiB", "GiB", "TiB")
    for unit in smaller_units:
        if value < 1024:
            return f"{int(value):,} B" if unit == "B" else f"{value:,.1f} {unit}"
        value /= 1024
    # TiB is the ceiling: show whatever is left, however large.
    return f"{value:,.1f} {largest_unit}"
|
||
|
|
|
||
|
|
|
||
|
|
def _describe_cluster(cluster: ClusterSpec) -> str:
    """Render a multi-line, human-readable summary of *cluster*.

    The header shows the name, the source, and the regex (explicit clusters
    only). A cluster with objects adds a count, a total size, and one line
    per object.
    """
    label = cluster.source
    if cluster.pattern:
        label = f"{label} pattern={cluster.pattern!r}"
    header = f"cluster {cluster.name!r} [{label}]"

    members = cluster.objects
    if not members:
        return f"{header}\n objects: (no matching keys)"

    total_bytes = sum(o.size for o in members)
    header = f"{header} {len(members)} object(s), {_format_bytes(total_bytes)}"
    detail = [f" - {o.key} ({_format_bytes(o.size)})" for o in members]
    return "\n".join([header, *detail])
|
||
|
|
|
||
|
|
|
||
|
|
def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point; returns the process exit code.

    Lists the configured prefix, prints the discovered clusters and (unless
    ``--dry-run``) downloads every non-empty cluster, then prints a summary.

    Returns:
        0 on success or dry run, 1 if any file or cluster failed, 2 if the
        LIST matched no objects with the configured extensions.
    """
    args = _build_argparser().parse_args(argv)
    cfg = load_download_config(args.config)
    s3 = build_s3_client(cfg)

    objects = list_s3_objects(s3, cfg)
    location = f"s3://{cfg.bucket}/{cfg.prefix}"
    if not objects:
        print(
            f"error: no objects matching extensions "
            f"{list(cfg.extensions)} found under {location}",
            file=sys.stderr,
        )
        return 2

    clusters = discover_clusters(cfg, objects)
    # Explicit patterns that matched nothing are listed but not downloaded.
    loadable = [c for c in clusters if c.objects]
    object_count = sum(len(c.objects) for c in loadable)
    print(
        f"discovered {len(loadable)} cluster(s) "
        f"({object_count} object(s)) under {location}:"
    )
    for cluster in clusters:
        print(_describe_cluster(cluster))

    if args.dry_run:
        return 0

    cfg.local_folder.mkdir(parents=True, exist_ok=True)
    on_exists_override = "overwrite" if args.overwrite else None

    totals: List[Tuple[str, int, int, int]] = []  # (name, dl, skip, bytes)
    failures: List[Tuple[str, str, Exception]] = []  # (cluster, key, exc)
    for cluster in loadable:
        print(
            f"\n>>> downloading cluster {cluster.name!r} "
            f"({len(cluster.objects)} object(s))"
        )
        try:
            dl, sk, by, fails = download_cluster(
                s3, cfg, cluster,
                on_exists_override=on_exists_override,
                fail_fast=args.fail_fast,
            )
        except Exception as exc:
            # A cluster-level error (e.g. on_exists=error) is recorded and,
            # unless --fail-fast, the next cluster is attempted.
            failures.append((cluster.name, "<cluster>", exc))
            print(
                f" !! cluster {cluster.name!r} aborted: {exc}",
                file=sys.stderr,
            )
            if args.fail_fast:
                break
            continue

        totals.append((cluster.name, dl, sk, by))
        failures.extend((cluster.name, key, err) for key, err in fails)
        if fails and args.fail_fast:
            break

    print("\n=== summary ===")
    for name, dl, sk, by in totals:
        print(f" {name}: downloaded {dl}, skipped {sk}, {_format_bytes(by)}")
    for cname, key, exc in failures:
        print(f" FAIL {cname} {key}: {exc}", file=sys.stderr)

    return 1 if failures else 0
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # SystemExit(code) is exactly what sys.exit(code) raises.
    raise SystemExit(main())
|