"""Hive-partitioned Parquet integration helpers."""
from __future__ import annotations
import inspect
from importlib import import_module
from types import ModuleType
from typing import Any, List, Optional
from ..sdk import ImednetSDK
from ..utils import validate_partition_key
from .export import _record_mapper
from .parquet_engine import PyArrowDatasetPartitionedStorageEngine
def _ensure_pyarrow() -> ModuleType:
try:
return import_module("pyarrow")
except ImportError as error:
raise ImportError(
"PyArrow is required for Hive Parquet export. Install with "
"\"pip install 'imednet[export]'\"."
) from error
def _cached_records_loader() -> Any:
try:
return import_module("imednet_workflows.cached_loader").CachedRecordsLoader
except ModuleNotFoundError as error:
if error.name and error.name.startswith("imednet_workflows"):
raise ImportError(
"Record export requires the optional 'imednet-workflows' package. "
"Install with `pip install imednet-workflows`."
) from error
raise
def _build_record_mapper(sdk: ImednetSDK, *, chunk_size: int) -> Any:
"""Build a workflows ``RecordMapper`` without importing the plugin at module import time."""
mapper_cls = _record_mapper()
parameters = inspect.signature(mapper_cls).parameters
kwargs: dict[str, Any] = {}
if "chunk_size" in parameters:
kwargs["chunk_size"] = chunk_size
if "loader" in parameters:
kwargs["loader"] = _cached_records_loader()(sdk)
return mapper_cls(sdk, **kwargs)
[docs]def export_to_hive_parquet(
sdk: ImednetSDK,
study_key: str,
base_dir: str,
*,
use_labels_as_columns: bool = False,
variable_whitelist: Optional[List[str]] = None,
form_whitelist: Optional[List[int]] = None,
chunk_size: int = 5_000,
) -> None:
"""Export study records to a Hive-partitioned Parquet directory layout."""
pyarrow_module = _ensure_pyarrow()
validate_partition_key(study_key)
mapper = _build_record_mapper(sdk, chunk_size=chunk_size)
storage_engine = PyArrowDatasetPartitionedStorageEngine()
forms = sdk.forms.list(study_key=study_key)
form_filter: dict[str, Any] = {"formIds": form_whitelist} if form_whitelist else {}
all_variables = sdk.variables.list(study_key=study_key, **form_filter)
variables_by_form: dict[int, list[Any]] = {}
for variable in all_variables:
variables_by_form.setdefault(variable.form_id, []).append(variable)
for form in forms:
if form_whitelist is not None and form.form_id not in form_whitelist:
continue
validate_partition_key(form.form_key)
variables = variables_by_form.get(form.form_id, [])
variable_keys = [
variable.variable_name
for variable in variables
if variable_whitelist is None or variable.variable_name in variable_whitelist
]
label_map = {
variable.variable_name: variable.label
for variable in variables
if variable.variable_name in variable_keys
}
record_model = mapper._build_record_model(variable_keys, label_map)
iter_records = getattr(mapper, "_iter_records", None)
iter_parsed_rows = getattr(mapper, "_iter_parsed_rows", None)
if callable(iter_records) and callable(iter_parsed_rows):
records = iter_records(
study_key,
extra_filters={
"formIds": [form.form_id],
**({"variableNames": variable_whitelist} if variable_whitelist else {}),
},
)
for rows, _ in iter_parsed_rows(records, record_model):
df = mapper._build_dataframe(
rows,
variable_keys,
label_map,
use_labels_as_columns,
)
if df.empty:
continue
table = pyarrow_module.Table.from_pandas(df, preserve_index=False)
storage_engine.write_form_table(
table,
base_dir=base_dir,
study_key=study_key,
form_key=form.form_key,
)
continue
records = mapper._fetch_records(
study_key,
extra_filters={
"formId": form.form_id,
**({"variableNames": variable_whitelist} if variable_whitelist else {}),
},
)
rows, _ = mapper._parse_records(records, record_model)
df = mapper._build_dataframe(
rows,
variable_keys,
label_map,
use_labels_as_columns,
)
table = pyarrow_module.Table.from_pandas(df, preserve_index=False)
storage_engine.write_form_table(
table,
base_dir=base_dir,
study_key=study_key,
form_key=form.form_key,
)
[docs]def hive_parquet_query(base_dir: str) -> str:
"""Return the DuckDB read_parquet query string for the given Hive base directory."""
escaped_base_dir = base_dir.replace("'", "''")
return (
f"SELECT * FROM read_parquet('{escaped_base_dir}/**/*.parquet', "
"hive_partitioning = true, union_by_name = true)"
)
__all__ = ["export_to_hive_parquet", "hive_parquet_query"]