# Source code for imednet.integrations.warehouse

"""Snowflake warehouse export sink.

This module implements the **warehouse export path** for a Snowflake
destination.  Study records are:

1. Written to Parquet files in a local staging directory (one file per batch;
   ``/`` and spaces in the batch id are replaced by ``_`` in the file name).
2. Uploaded to the configured Snowflake internal stage via ``PUT``.
3. Bulk-loaded into the target table with ``COPY INTO ... FROM @<stage>``.

This two-phase approach decouples data preparation from bulk ingestion,
allows the Parquet files to be independently audited or re-uploaded, and
leverages Snowflake's native columnar loader for best throughput.

Manifest
--------
After each successful ``COPY INTO``, a manifest entry is appended as one JSON
line to the file at ``SnowflakeSinkConfig.manifest_path`` (if set).  Example
entry, pretty-printed here for readability:

.. code-block:: json

    {
        "batch_id":   "MYSTUDY/FORM1/0",
        "stage_path": "@MY_STAGE/imednet/MYSTUDY_FORM1_0.parquet",
        "row_count":  500,
        "loaded_at":  "2024-01-15T12:00:00.123456+00:00"
    }

Optional dependencies
---------------------
* ``snowflake-connector-python`` (``pip install 'imednet[snowflake]'``)
* ``pyarrow`` (included in ``imednet[snowflake]``)

Both are imported lazily at connection / write time.

Idempotency
-----------
When ``SinkConfig.idempotent`` is ``True`` (default) the sink uses
``COPY INTO ... FORCE = FALSE`` so that Snowflake skips files that have
already been loaded, making re-runs safe.  Set ``idempotent = False``
to force re-ingestion of previously loaded files.  Note that
``write_batch`` reports the number of records in the batch even when
Snowflake skips an already-loaded file.

Usage
-----
.. code-block:: python

    import os
    from itertools import batched  # Python 3.12+; any batching helper works

    from imednet.integrations.warehouse import SnowflakeExportSink, SnowflakeSinkConfig

    config = SnowflakeSinkConfig(
        account="myorg-myaccount",
        user="loader",
        **{"password": os.environ["SF_PASS"]},  # keep credentials out of source code
        database="IMEDNET_DB",
        schema="PUBLIC",
        warehouse="COMPUTE_WH",
        stage="MY_STAGE",
        table="RECORDS",
        stage_prefix="imednet",
        local_staging_dir="/tmp/imednet_stage",
    )
    with SnowflakeExportSink(config=config) as sink:
        for i, batch in enumerate(batched(records, config.batch_size)):
            sink.write_batch(batch, batch_id=f"MYSTUDY/FORM1/{i}")
"""

from __future__ import annotations

import json
import logging
import os
import tempfile
import time
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional, Sequence

from imednet.errors import ExportBatchError, ExportConfigurationError
from imednet.sdk import ImednetSDK

from .sink_base import ExportSink, SinkConfig, _require_optional_dep, iter_batches

logger = logging.getLogger(__name__)


@dataclass
class SnowflakeSinkConfig(SinkConfig):
    """Settings consumed by :class:`SnowflakeExportSink`.

    All string fields default to ``""`` except ``schema`` (``"PUBLIC"``) and
    ``stage_prefix`` (``"imednet"``); both path fields default to ``None``.

    Parameters
    ----------
    account:
        Snowflake account identifier (``<org>-<account>`` or the legacy format).
    user:
        Snowflake login name.
    password:
        Snowflake password.  Excluded from the dataclass ``repr``.
    database:
        Database that holds the destination table.
    schema:
        Schema that holds the destination table.
    warehouse:
        Virtual warehouse that executes ``COPY INTO``.
    stage:
        Name of the Snowflake internal stage, e.g. ``"MY_STAGE"``.
    table:
        Destination table name within *database*.*schema*.
    stage_prefix:
        Path prefix under the stage where Parquet files are uploaded.
    local_staging_dir:
        Local directory where Parquet files are written before ``PUT``.
        When unset, a :mod:`tempfile` temporary directory is used instead.
    manifest_path:
        Optional JSON-lines file that receives one entry per loaded batch.
    """

    # Connection settings.
    account: str = ""
    user: str = ""
    password: str = field(default="", repr=False)  # never shown in repr()
    database: str = ""
    schema: str = "PUBLIC"
    warehouse: str = ""

    # Load target and staging settings.
    stage: str = ""
    table: str = ""
    stage_prefix: str = "imednet"
    local_staging_dir: Optional[str | os.PathLike[str]] = None
    manifest_path: Optional[str | os.PathLike[str]] = None
# Identifier columns emitted ahead of the per-record ``record_data`` fields.
_ID_COLUMNS = ("record_id", "form_id", "visit_id", "subject_key")


def _record_to_row(record: Any) -> dict[str, Any]:
    """Flatten one record into a row dict of identifier columns plus data fields.

    *record* may be a typed model exposing ``record_id`` / ``form_id`` /
    ``visit_id`` / ``subject_key`` / ``record_data`` attributes, or a plain
    mapping with those same keys.  Missing identifiers become ``None``.
    ``record_data`` fields are merged last, so a data field sharing a name with
    an identifier column overrides it (unchanged historical precedence).
    """
    if isinstance(record, Mapping):
        # Plain dicts have no such attributes; getattr() would silently yield
        # an all-None row, so read them by key instead.
        row = {name: record.get(name) for name in _ID_COLUMNS}
        record_data = record.get("record_data")
    else:
        row = {name: getattr(record, name, None) for name in _ID_COLUMNS}
        record_data = getattr(record, "record_data", None)
    row.update(dict(record_data or {}))
    return row


def _records_to_arrow_table(records: Sequence[Any]) -> Any:
    """Convert *records* to a ``pyarrow.Table`` with one row per record.

    The table's columns are the union of every row's keys, in first-seen
    order; rows lacking a column get ``None`` there.  This matters because
    ``pyarrow.Table.from_pylist`` without an explicit schema takes its column
    names from the first row only, so fields present only in later records
    (sparse ``record_data``) would otherwise be dropped silently.

    Raises
    ------
    ImportError
        Via ``_require_optional_dep`` when ``pyarrow`` is not installed.
    """
    pa = _require_optional_dep("pyarrow", "snowflake")
    rows = [_record_to_row(record) for record in records]
    # dict.fromkeys de-duplicates while preserving first-seen column order.
    column_names = dict.fromkeys(key for row in rows for key in row)
    columns = {name: [row.get(name) for row in rows] for name in column_names}
    return pa.Table.from_pydict(columns)
class SnowflakeExportSink(ExportSink):
    """Stage Parquet files and bulk-load them into Snowflake.

    Each :meth:`write_batch` call writes one Parquet file to the local staging
    directory, uploads it to the configured internal stage with ``PUT`` and
    loads it into ``database.schema.table`` with ``COPY INTO``.

    Parameters
    ----------
    config:
        :class:`SnowflakeSinkConfig` containing all connection details and
        staging paths.  Any other value (including ``None``) is replaced by a
        default :class:`SnowflakeSinkConfig`, whose empty required fields make
        construction fail with ``ExportConfigurationError``.

    Raises
    ------
    ~imednet.errors.ExportConfigurationError
        When the Snowflake connector cannot be initialised or the required
        configuration values are missing.
    ImportError
        When ``snowflake-connector-python`` or ``pyarrow`` are not installed.
    """

    def __init__(self, config: Optional[SinkConfig] = None) -> None:
        cfg = config if isinstance(config, SnowflakeSinkConfig) else SnowflakeSinkConfig()
        super().__init__(cfg)
        self._cfg: SnowflakeSinkConfig = cfg
        self._conn: Any = None
        self._tmp_dir: Optional[tempfile.TemporaryDirectory[str]] = None
        self._connect()

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def _connect(self) -> None:
        """Validate the config, prepare the staging directory, then connect."""
        cfg = self._cfg
        required = {
            "account": cfg.account,
            "user": cfg.user,
            "password": cfg.password,
            "database": cfg.database,
            "warehouse": cfg.warehouse,
            "stage": cfg.stage,
            "table": cfg.table,
        }
        missing = [name for name, value in required.items() if not value]
        if missing:
            raise ExportConfigurationError(
                f"SnowflakeSinkConfig is missing required fields: {missing}"
            )

        snowflake = _require_optional_dep("snowflake.connector", "snowflake")

        # Prepare the staging directory *before* connecting: a filesystem error
        # (e.g. an unwritable local_staging_dir) must not leave an open
        # connection behind on a sink whose constructor never completed.
        self._setup_staging_dir()

        logger.debug("Connecting to Snowflake account=%s database=%s", cfg.account, cfg.database)
        try:
            self._conn = snowflake.connect(
                account=cfg.account,
                user=cfg.user,
                # Kept in the original **{...} form; presumably to keep
                # secret-scanning linters quiet — confirm before changing.
                **{"password": cfg.password},
                database=cfg.database,
                schema=cfg.schema,
                warehouse=cfg.warehouse,
            )
        except Exception as exc:
            # The sink is unusable, so release the temporary directory now.
            self._cleanup_tmp_dir()
            raise ExportConfigurationError(
                f"Cannot connect to Snowflake account '{cfg.account}': {exc}"
            ) from exc

    def _setup_staging_dir(self) -> None:
        """Create the configured staging directory, or allocate a temporary one."""
        configured = self._cfg.local_staging_dir
        if configured:
            resolved_staging_dir = os.fspath(configured)
            Path(resolved_staging_dir).mkdir(parents=True, exist_ok=True)
            self._staging_dir: str = resolved_staging_dir
        else:
            self._tmp_dir = tempfile.TemporaryDirectory()
            self._staging_dir = self._tmp_dir.name

    def _cleanup_tmp_dir(self) -> None:
        """Remove the temporary staging directory, if one was created."""
        tmp_dir, self._tmp_dir = self._tmp_dir, None
        if tmp_dir is not None:
            tmp_dir.cleanup()

    # ------------------------------------------------------------------
    # ExportSink interface
    # ------------------------------------------------------------------

    def write_batch(self, records: Sequence[Any], *, batch_id: str) -> int:
        """Write *records* to Snowflake via Parquet staging + COPY INTO.

        Parameters
        ----------
        records:
            Sequence of typed ``Record`` model instances or plain dicts.
        batch_id:
            Idempotency key (e.g. ``"MYSTUDY/FORM1/0"``).  ``/`` and spaces
            are replaced by ``_`` to form the staged file name.

        Returns
        -------
        int
            ``len(records)``; ``0`` for an empty batch.

        Raises
        ------
        ~imednet.errors.ExportBatchError
            When ``PUT``/``COPY INTO`` still fail after ``max_retries``
            retries, or when the manifest entry cannot be written after a
            successful load.
        """
        if not records:
            return 0

        cfg = self._cfg

        # 1. Convert to Parquet in the local staging directory.
        arrow_table = _records_to_arrow_table(records)
        safe_batch = batch_id.replace("/", "_").replace(" ", "_")
        local_path = Path(self._staging_dir) / f"{safe_batch}.parquet"
        pq = _require_optional_dep("pyarrow.parquet", "snowflake")
        pq.write_table(arrow_table, str(local_path))

        stage_dir = f"@{cfg.stage}/{cfg.stage_prefix}/"
        stage_path = f"{stage_dir}{safe_batch}.parquet"
        force_clause = "FORCE = FALSE" if self.config.idempotent else "FORCE = TRUE"
        put_sql = f"PUT {self._put_file_uri(local_path)} {stage_dir}"
        copy_sql = (
            f"COPY INTO {cfg.database}.{cfg.schema}.{cfg.table} "
            f"FROM {stage_path} "
            "FILE_FORMAT = (TYPE = PARQUET) "
            "MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE "
            f"{force_clause}"
        )

        # 2. + 3. Upload and bulk-load, with retries.
        self._load_with_retries(batch_id, put_sql, copy_sql)

        # NOTE(review): the COPY INTO result is not inspected, so this counts
        # every record even when FORCE = FALSE makes Snowflake skip a file
        # that was already loaded.
        rows_loaded = len(records)
        logger.debug(
            "Loaded batch %s (%d rows) via stage %s", batch_id, rows_loaded, stage_path
        )
        # Written outside the retry loop: a manifest I/O failure must not
        # re-run PUT/COPY (which would duplicate rows when FORCE = TRUE).
        try:
            self._append_manifest(batch_id, stage_path, rows_loaded)
        except OSError as exc:
            raise ExportBatchError(
                f"Batch {batch_id!r} was loaded but its manifest entry could not "
                f"be written: {exc}",
                batch_id=batch_id,
            ) from exc
        return rows_loaded

    def flush(self) -> None:
        """No-op: each batch is committed individually."""

    def close(self) -> None:
        """Close the Snowflake connection and clean up temporary staging files.

        The temporary directory is removed even if closing the connection
        raises; calling ``close`` again afterwards is a no-op.
        """
        conn, self._conn = self._conn, None
        try:
            if conn is not None:
                conn.close()
        finally:
            self._cleanup_tmp_dir()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _put_file_uri(path: Path) -> str:
        """Return *path* as a single-quoted ``file://`` URI for a ``PUT`` command.

        Quoting keeps paths containing spaces or special characters valid.
        Inside a quoted literal a backslash acts as an escape, so the path is
        normalised to forward slashes, and any remaining backslash or single
        quote is escaped.
        """
        posix_path = path.resolve().as_posix()
        escaped = posix_path.replace("\\", "\\\\").replace("'", "\\'")
        return f"'file://{escaped}'"

    def _put_and_copy(self, put_sql: str, copy_sql: str) -> None:
        """Run the PUT and COPY INTO statements on a cursor that is always closed."""
        cur = self._conn.cursor()
        try:
            cur.execute(put_sql)
            cur.execute(copy_sql)
        finally:
            cur.close()

    def _load_with_retries(self, batch_id: str, put_sql: str, copy_sql: str) -> None:
        """Run PUT + COPY INTO with exponential backoff; raise ExportBatchError if all attempts fail."""
        max_retries = self.config.max_retries
        last_exc: Optional[Exception] = None
        for attempt in range(max_retries + 1):
            try:
                self._put_and_copy(put_sql, copy_sql)
                return
            except Exception as exc:  # noqa: BLE001
                last_exc = exc
                if attempt < max_retries:
                    delay = self.config.retry_backoff * (2**attempt)
                    logger.warning(
                        "Batch %s attempt %d failed (%s); retrying in %.1fs",
                        batch_id,
                        attempt + 1,
                        exc,
                        delay,
                    )
                    time.sleep(delay)
        raise ExportBatchError(
            f"Batch {batch_id!r} failed after {max_retries + 1} "
            f"attempts: {last_exc}",
            batch_id=batch_id,
        ) from last_exc

    def _append_manifest(self, batch_id: str, stage_path: str, row_count: int) -> None:
        """Append a manifest entry for the loaded batch (JSON-lines format).

        Does nothing when ``manifest_path`` is not configured.
        """
        manifest_path = self._cfg.manifest_path
        if not manifest_path:
            return
        entry = {
            "batch_id": batch_id,
            "stage_path": stage_path,
            "row_count": row_count,
            "loaded_at": datetime.now(tz=timezone.utc).isoformat(),
        }
        # Text mode already translates "\n" to the platform line ending;
        # writing os.linesep here produced "\r\r\n" on Windows.
        with open(manifest_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")
def export_to_snowflake(
    sdk: ImednetSDK,
    study_key: str,
    *,
    config: SnowflakeSinkConfig,
) -> int:
    """Export all records of *study_key* to Snowflake.

    Records are fetched with ``sdk.records.list``, split into chunks of
    ``config.batch_size``, and each chunk is written through a
    :class:`SnowflakeExportSink` under the batch id
    ``"<study_key>/records/<n>"``.

    Returns
    -------
    int
        Sum of the row counts reported by ``write_batch``.
    """
    study_records = sdk.records.list(study_key=study_key, record_data_filter=None)
    with SnowflakeExportSink(config=config) as sink:
        loaded = sum(
            sink.write_batch(chunk, batch_id=f"{study_key}/records/{batch_number}")
            for batch_number, chunk in enumerate(
                iter_batches(study_records, config.batch_size)
            )
        )
    return loaded
# Public API of this module.
__all__ = [
    "SnowflakeExportSink",
    "SnowflakeSinkConfig",
    "export_to_snowflake",
]