Add pyarrow dependency and optimize DataFrame serialization in load_sas.py

Included pyarrow as a new dependency in requirements.txt for improved CSV serialization performance. Refactored the _prepare_for_copy function to utilize vectorized operations for date and timestamp conversions, reducing CPU overhead. Introduced a new _serialize_chunk_csv function leveraging pyarrow for faster CSV writing, enhancing efficiency during data copying to Postgres.
This commit is contained in:
David Peterson 2026-04-20 21:32:56 -05:00
parent 5e347f50ef
commit 7beb44ac4d
2 changed files with 58 additions and 55 deletions

View File

@ -234,6 +234,8 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
import pandas as pd import pandas as pd
import psycopg2 import psycopg2
import psycopg2.extensions import psycopg2.extensions
import pyarrow as pa
import pyarrow.csv as pa_csv
import pyreadstat import pyreadstat
import yaml import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
@ -599,7 +601,13 @@ def apply_column_filter(
exclude: Optional[List[str]], exclude: Optional[List[str]],
) -> pd.DataFrame: ) -> pd.DataFrame:
"""Restrict ``df`` to the requested columns. Names missing from the frame """Restrict ``df`` to the requested columns. Names missing from the frame
raise a clear error rather than silently dropping.""" raise a clear error rather than silently dropping.
Returns the input frame (or a column-sliced view / drop result) without
an extra ``.copy()``: the downstream consumer (:func:`_prepare_for_copy`)
reads the frame into a freshly built output and never mutates its input,
so the copies were pure overhead on every streamed chunk.
"""
if include is not None and exclude is not None: if include is not None and exclude is not None:
raise ValueError("include and exclude are mutually exclusive.") raise ValueError("include and exclude are mutually exclusive.")
@ -607,15 +615,15 @@ def apply_column_filter(
missing = [c for c in include if c not in df.columns] missing = [c for c in include if c not in df.columns]
if missing: if missing:
raise ValueError(f"include references unknown columns: {missing}") raise ValueError(f"include references unknown columns: {missing}")
return df.loc[:, list(include)].copy() return df.loc[:, list(include)]
if exclude is not None: if exclude is not None:
missing = [c for c in exclude if c not in df.columns] missing = [c for c in exclude if c not in df.columns]
if missing: if missing:
raise ValueError(f"exclude references unknown columns: {missing}") raise ValueError(f"exclude references unknown columns: {missing}")
return df.drop(columns=list(exclude)).copy() return df.drop(columns=list(exclude))
return df.copy() return df
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -1712,6 +1720,15 @@ def _seconds_to_time(v: Any) -> Optional[dt.time]:
def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame: def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.DataFrame:
"""Materialize a copy of ``df`` with each column in the right shape for """Materialize a copy of ``df`` with each column in the right shape for
``to_csv`` so the CSV lands as valid input for the target Postgres type. ``to_csv`` so the CSV lands as valid input for the target Postgres type.
Per-column conversions are vectorized (``.astype`` / ``pd.to_datetime`` /
``.mask`` / ``.fillna``) instead of the element-wise ``.map(func)``
loops this function used to run. That was the single largest per-chunk
CPU cost on text-heavy loads - a 40-column × 100k-row chunk was issuing
~4M Python-level function calls just to cast strings. TIME columns
still take the ``.map`` path because SAS TIME8 is stored as seconds and
the clamp-to-24h logic doesn't fit cleanly in vector form; they're also
rare in practice.
""" """
out = pd.DataFrame(index=df.index) out = pd.DataFrame(index=df.index)
for name, spec in columns.items(): for name, spec in columns.items():
@ -1734,59 +1751,33 @@ def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.Da
if pd.api.types.is_datetime64_any_dtype(series): if pd.api.types.is_datetime64_any_dtype(series):
out[name] = series.dt.date out[name] = series.dt.date
elif pd.api.types.is_object_dtype(series): elif pd.api.types.is_object_dtype(series):
def _to_date(v: Any) -> Optional[dt.date]: # Vectorized parse: empty strings / None / unparseable -> NaT,
if v is None or (isinstance(v, float) and pd.isna(v)): # then .dt.date yields date objects or NaT. NaT serializes as
return None # an empty CSV field (matching ``NULL ''`` in COPY).
if isinstance(v, dt.datetime): parsed = pd.to_datetime(
return v.date() series.replace({"": None}), errors="coerce"
if isinstance(v, dt.date): )
return v out[name] = parsed.dt.date
if isinstance(v, str):
if v == "":
return None
try:
return dt.date.fromisoformat(v)
except ValueError:
return None
return None
out[name] = series.map(_to_date)
else: else:
out[name] = series out[name] = series
elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"): elif pg in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE", "TIMESTAMP WITH TIME ZONE"):
if pd.api.types.is_datetime64_any_dtype(series): if pd.api.types.is_datetime64_any_dtype(series):
out[name] = series out[name] = series
elif pd.api.types.is_object_dtype(series): elif pd.api.types.is_object_dtype(series):
def _to_dt(v: Any) -> Optional[dt.datetime]: out[name] = pd.to_datetime(
if v is None or (isinstance(v, float) and pd.isna(v)): series.replace({"": None}), errors="coerce"
return None )
if isinstance(v, dt.datetime):
return v
if isinstance(v, dt.date):
return dt.datetime(v.year, v.month, v.day)
if isinstance(v, pd.Timestamp):
return v.to_pydatetime() if not pd.isna(v) else None
if isinstance(v, str):
if v == "":
return None
try:
return dt.datetime.fromisoformat(v)
except ValueError:
return None
return None
out[name] = series.map(_to_dt)
else: else:
out[name] = series out[name] = series
elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"): elif pg in ("TIME", "TIME WITHOUT TIME ZONE", "TIME WITH TIME ZONE"):
out[name] = series.map(_seconds_to_time) out[name] = series.map(_seconds_to_time)
elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"): elif pg in ("TEXT", "VARCHAR", "CHARACTER VARYING", "CHAR", "CHARACTER"):
# Leave empty strings as "" so `NULL ''` in COPY turns them into NULL. # Render every cell as a string and blank out nulls. ``NULL ''``
def _to_str(v: Any) -> Any: # in the COPY statement turns the blanks back into SQL NULL.
if v is None: # astype(str) stringifies NaN/None to the literal "nan"/"None",
return "" # so we mask those after the fact rather than branching per cell.
if isinstance(v, float) and pd.isna(v): na_mask = series.isna()
return "" out[name] = series.astype(str).mask(na_mask, "")
return str(v)
out[name] = series.map(_to_str)
elif pg == "BOOLEAN": elif pg == "BOOLEAN":
out[name] = series.astype("boolean") if series.dtype != object else series out[name] = series.astype("boolean") if series.dtype != object else series
else: else:
@ -1794,6 +1785,25 @@ def _prepare_for_copy(df: pd.DataFrame, columns: Dict[str, ColumnSpec]) -> pd.Da
return out return out
def _serialize_chunk_csv(prepared: pd.DataFrame) -> io.BytesIO:
    """Serialize a prepared frame into a CSV buffer for ``COPY FROM STDIN``.

    The frame is converted to an Arrow table and written with
    ``pyarrow.csv.write_csv`` (no header row, index dropped).

    Empty strings in string columns are converted to Arrow nulls before
    writing. pyarrow's default ``quoting_style="needed"`` encloses every
    string value in quotes, so an empty string would be emitted as ``""``.
    Under ``FORMAT csv, NULL ''`` Postgres only treats an *unquoted* empty
    field as NULL; a quoted ``""`` is loaded as an empty string. Arrow nulls
    are written as unquoted empty fields, which is what the COPY statement
    needs in order to produce SQL NULL.

    Args:
        prepared: Frame whose columns are already in a CSV-ready shape.

    Returns:
        A ``BytesIO`` positioned at offset 0 containing the CSV bytes.
    """
    # Function-scope import: only this function needs the compute kernels.
    import pyarrow.compute as pc

    table = pa.Table.from_pandas(prepared, preserve_index=False)

    for position, field in enumerate(table.schema):
        is_text = pa.types.is_string(field.type) or pa.types.is_large_string(
            field.type
        )
        if not is_text:
            continue
        column = table.column(position)
        # Existing nulls stay null: a null comparison result propagates
        # through if_else as a null output.
        is_blank = pc.equal(column, pa.scalar("", type=field.type))
        blank_as_null = pc.if_else(
            is_blank, pa.scalar(None, type=field.type), column
        )
        table = table.set_column(position, field, blank_as_null)

    buf = io.BytesIO()
    pa_csv.write_csv(
        table,
        buf,
        write_options=pa_csv.WriteOptions(include_header=False),
    )
    # Rewind so the consumer reads from the start of the payload.
    buf.seek(0)
    return buf
def copy_dataframes( def copy_dataframes(
conn, conn,
schema_name: str, schema_name: str,
@ -1821,15 +1831,7 @@ def copy_dataframes(
if df.empty: if df.empty:
continue continue
prepared = _prepare_for_copy(df, columns) prepared = _prepare_for_copy(df, columns)
buf = io.StringIO() buf = _serialize_chunk_csv(prepared)
prepared.to_csv(
buf,
index=False,
header=False,
na_rep="",
date_format="%Y-%m-%d %H:%M:%S",
)
buf.seek(0)
cur.copy_expert(sql, buf) cur.copy_expert(sql, buf)
conn.commit() conn.commit()
total += len(prepared) total += len(prepared)

View File

@ -1,6 +1,7 @@
pandas>=2.0,<3.0 pandas>=2.0,<3.0
pyreadstat>=1.2,<2.0 pyreadstat>=1.2,<2.0
numpy>=2.1,<3.0 numpy>=2.1,<3.0
pyarrow>=16.0,<21.0
pyyaml>=6.0,<7.0 pyyaml>=6.0,<7.0
psycopg2-binary>=2.9,<3.0 psycopg2-binary>=2.9,<3.0
python-dotenv>=1.0,<2.0 python-dotenv>=1.0,<2.0