"""Snowflake warehouse export sink.
This module implements the **warehouse export path** for a Snowflake
destination. Study records are:
1. Written to Parquet files in a local staging directory (one file per batch).
2. Uploaded to the configured Snowflake internal stage via ``PUT``.
3. Bulk-loaded into the target table with ``COPY INTO ... FROM @<stage>``.
This two-phase approach decouples data preparation from bulk ingestion,
allows the Parquet files to be independently audited or re-uploaded, and
leverages Snowflake's native columnar loader for best throughput.
Manifest
--------
After each successful ``COPY INTO``, a manifest entry is appended to
``SinkConfig.extra["manifest_path"]`` (if provided):
.. code-block:: json
{
"batch_id": "MYSTUDY/FORM1/0",
"stage_path": "@MY_STAGE/imednet/MYSTUDY/FORM1/batch_0.parquet",
"row_count": 500,
"loaded_at": "2024-01-15T12:00:00Z"
}
Optional dependencies
---------------------
* ``snowflake-connector-python`` (``pip install 'imednet[snowflake]'``)
* ``pyarrow`` (included in ``imednet[snowflake]``)
Both are imported lazily at connection / write time.
Idempotency
-----------
When ``SinkConfig.idempotent`` is ``True`` (default) the sink uses
``COPY INTO ... FORCE = FALSE`` so that Snowflake skips files that have
already been loaded, making re-runs safe. Set ``idempotent = False``
to force re-ingestion of previously loaded files.
Usage
-----
.. code-block:: python
from imednet.integrations.warehouse import SnowflakeExportSink, SnowflakeSinkConfig
config = SnowflakeSinkConfig(
account="myorg-myaccount",
user="loader",
**{"password": os.environ["SF_PASS"]}, # keep credentials out of source code
database="IMEDNET_DB",
schema="PUBLIC",
warehouse="COMPUTE_WH",
stage="MY_STAGE",
table="RECORDS",
stage_prefix="imednet",
local_staging_dir="/tmp/imednet_stage",
)
with SnowflakeExportSink(config=config) as sink:
for i, batch in enumerate(batched(records, config.batch_size)):
sink.write_batch(batch, batch_id=f"MYSTUDY/FORM1/{i}")
"""
from __future__ import annotations
import json
import logging
import os
import tempfile
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional, Sequence
from imednet.errors import ExportBatchError, ExportConfigurationError
from imednet.sdk import ImednetSDK
from .sink_base import ExportSink, SinkConfig, _require_optional_dep, iter_batches
logger = logging.getLogger(__name__)
[docs]@dataclass
class SnowflakeSinkConfig(SinkConfig):
"""Configuration for :class:`SnowflakeExportSink`.
Parameters
----------
account:
Snowflake account identifier (``<org>-<account>`` or legacy format).
user:
Snowflake user name.
password:
Snowflake password. Never logged.
database:
Target database.
schema:
Target schema.
warehouse:
Virtual warehouse used for the ``COPY INTO`` command.
stage:
Snowflake internal stage name (e.g. ``"MY_STAGE"``).
table:
Destination table name inside *database*.*schema*.
stage_prefix:
Path prefix inside the stage (default ``"imednet"``).
local_staging_dir:
Local directory used to write Parquet files before ``PUT``.
Defaults to a temporary directory created by :mod:`tempfile`.
manifest_path:
Optional path to a JSON-lines file where each loaded batch is
recorded.
"""
account: str = ""
user: str = ""
password: str = field(default="", repr=False)
database: str = ""
schema: str = "PUBLIC"
warehouse: str = ""
stage: str = ""
table: str = ""
stage_prefix: str = "imednet"
local_staging_dir: Optional[str | os.PathLike[str]] = None
manifest_path: Optional[str | os.PathLike[str]] = None
def _records_to_arrow_table(records: Sequence[Any]) -> Any:
"""Convert *records* to a ``pyarrow.Table``."""
pa = _require_optional_dep("pyarrow", "snowflake")
rows = [
{
"record_id": getattr(r, "record_id", None),
"form_id": getattr(r, "form_id", None),
"visit_id": getattr(r, "visit_id", None),
"subject_key": getattr(r, "subject_key", None),
**dict(getattr(r, "record_data", {}) or {}),
}
for r in records
]
return pa.Table.from_pylist(rows)
[docs]class SnowflakeExportSink(ExportSink):
"""Stage Parquet files and bulk-load them into Snowflake.
Parameters
----------
config:
:class:`SnowflakeSinkConfig` containing all connection details and
staging paths.
Raises
------
~imednet.errors.ExportConfigurationError
When the Snowflake connector cannot be initialised or the required
configuration values are missing.
ImportError
When ``snowflake-connector-python`` or ``pyarrow`` are not installed.
"""
[docs] def __init__(self, config: Optional[SinkConfig] = None) -> None:
cfg = config if isinstance(config, SnowflakeSinkConfig) else SnowflakeSinkConfig()
super().__init__(cfg)
self._cfg: SnowflakeSinkConfig = cfg
self._conn: Any = None
self._tmp_dir: Optional[tempfile.TemporaryDirectory[str]] = None
self._connect()
# ------------------------------------------------------------------
# Connection management
# ------------------------------------------------------------------
def _connect(self) -> None:
cfg = self._cfg
missing = [
name
for name, val in [
("account", cfg.account),
("user", cfg.user),
("password", cfg.password),
("database", cfg.database),
("warehouse", cfg.warehouse),
("stage", cfg.stage),
("table", cfg.table),
]
if not val
]
if missing:
raise ExportConfigurationError(
f"SnowflakeSinkConfig is missing required fields: {missing}"
)
snowflake = _require_optional_dep("snowflake.connector", "snowflake")
logger.debug("Connecting to Snowflake account=%s database=%s", cfg.account, cfg.database)
try:
self._conn = snowflake.connect(
account=cfg.account,
user=cfg.user,
**{"password": cfg.password},
database=cfg.database,
schema=cfg.schema,
warehouse=cfg.warehouse,
)
except Exception as exc:
raise ExportConfigurationError(
f"Cannot connect to Snowflake account '{cfg.account}': {exc}"
) from exc
# Set up local staging directory
if cfg.local_staging_dir:
resolved_staging_dir = os.fspath(cfg.local_staging_dir)
Path(resolved_staging_dir).mkdir(parents=True, exist_ok=True)
self._staging_dir: str = resolved_staging_dir
else:
self._tmp_dir = tempfile.TemporaryDirectory()
self._staging_dir = self._tmp_dir.name
# ------------------------------------------------------------------
# ExportSink interface
# ------------------------------------------------------------------
[docs] def write_batch(self, records: Sequence[Any], *, batch_id: str) -> int:
"""Write *records* to Snowflake via Parquet staging + COPY INTO.
Parameters
----------
records:
Sequence of typed ``Record`` model instances or plain dicts.
batch_id:
Idempotency key (e.g. ``"MYSTUDY/FORM1/0"``).
Returns
-------
int
Number of rows loaded.
"""
if not records:
return 0
# 1. Convert to Parquet
arrow_table = _records_to_arrow_table(records)
safe_batch = batch_id.replace("/", "_").replace(" ", "_")
local_path = Path(self._staging_dir) / f"{safe_batch}.parquet"
pq = _require_optional_dep("pyarrow.parquet", "snowflake")
pq.write_table(arrow_table, str(local_path))
cfg = self._cfg
stage_path = f"@{cfg.stage}/{cfg.stage_prefix}/{safe_batch}.parquet"
last_exc: Optional[Exception] = None
for attempt in range(self.config.max_retries + 1):
try:
cur = self._conn.cursor()
# 2. PUT to stage
cur.execute(f"PUT file://{local_path} @{cfg.stage}/{cfg.stage_prefix}/")
# 3. COPY INTO table
force_clause = "FORCE = FALSE" if self.config.idempotent else "FORCE = TRUE"
cur.execute(
f"COPY INTO {cfg.database}.{cfg.schema}.{cfg.table} "
f"FROM @{cfg.stage}/{cfg.stage_prefix}/{safe_batch}.parquet "
f"FILE_FORMAT = (TYPE = PARQUET) "
f"MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE "
f"{force_clause}"
)
rows_loaded = len(records)
cur.close()
logger.debug(
"Loaded batch %s (%d rows) via stage %s",
batch_id,
rows_loaded,
stage_path,
)
self._append_manifest(batch_id, stage_path, rows_loaded)
return rows_loaded
except Exception as exc: # noqa: BLE001
last_exc = exc
if attempt < self.config.max_retries:
delay = self.config.retry_backoff * (2**attempt)
logger.warning(
"Batch %s attempt %d failed (%s); retrying in %.1fs",
batch_id,
attempt + 1,
exc,
delay,
)
time.sleep(delay)
raise ExportBatchError(
f"Batch {batch_id!r} failed after {self.config.max_retries + 1} "
f"attempts: {last_exc}",
batch_id=batch_id,
)
[docs] def flush(self) -> None:
"""No-op: each batch is committed individually."""
[docs] def close(self) -> None:
"""Close the Snowflake connection and clean up temporary staging files."""
if self._conn is not None:
try:
self._conn.close()
finally:
self._conn = None
if self._tmp_dir is not None:
try:
self._tmp_dir.cleanup()
finally:
self._tmp_dir = None
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _append_manifest(self, batch_id: str, stage_path: str, row_count: int) -> None:
"""Append a manifest entry for the loaded batch (JSON-lines format)."""
manifest_path = self._cfg.manifest_path
if not manifest_path:
return
entry = {
"batch_id": batch_id,
"stage_path": stage_path,
"row_count": row_count,
"loaded_at": datetime.now(tz=timezone.utc).isoformat(),
}
with open(manifest_path, "a", encoding="utf-8") as f:
f.write(json.dumps(entry) + os.linesep)
[docs]def export_to_snowflake(
sdk: ImednetSDK,
study_key: str,
*,
config: SnowflakeSinkConfig,
) -> int:
"""Export study records to Snowflake using :class:`SnowflakeExportSink`."""
records = sdk.records.list(study_key=study_key, record_data_filter=None)
total_written = 0
with SnowflakeExportSink(config=config) as sink:
for index, batch in enumerate(iter_batches(records, config.batch_size)):
total_written += sink.write_batch(batch, batch_id=f"{study_key}/records/{index}")
return total_written
__all__ = ["SnowflakeExportSink", "SnowflakeSinkConfig", "export_to_snowflake"]