"""Airflow hook for retrieving an :class:`ImednetSDK` instance."""
from __future__ import annotations
import json
from datetime import date, datetime
from typing import Any, Dict, List, Mapping, MutableMapping, TypeAlias, Union, cast
from airflow.hooks.base import BaseHook
from imednet.config import Config, load_config
from imednet.sdk import ImednetSDK
Primitive = Union[str, int, float, bool, None]
# Primitive-only payload contract for discovery helpers that feed Airflow mapping/XCom.
PrimitiveContainer: TypeAlias = Union[
Primitive, List["PrimitiveContainer"], Dict[str, "PrimitiveContainer"]
]
# Key names whose values are masked rather than emitted. Entries are
# lower-case because lookups in this module lower-case the key first.
_SENSITIVE_KEYS = {
    # Generic credential fields.
    "authorization",
    "token",
    # Connection/config field names.
    "api_key",
    "security_key",
    # HTTP header spellings.
    "x-api-key",
    "x-imn-security-key",
}
class ImednetHook(BaseHook):
    """Retrieve an :class:`ImednetSDK` instance from an Airflow connection.

    Credentials are read from the connection's extras (``api_key``,
    ``security_key``, ``base_url``) with the connection ``login`` /
    ``password`` fields as fallbacks for the two keys. The collected values
    are handed to :func:`load_config`, which produces the final
    :class:`Config`.
    """

    def __init__(self, imednet_conn_id: str = "imednet_default") -> None:
        """Store the Airflow connection id used to resolve credentials.

        :param imednet_conn_id: Id of the Airflow connection to look up.
        """
        super().__init__()
        self.imednet_conn_id = imednet_conn_id

    @staticmethod
    def _string_or_none(value: object) -> str | None:
        """Return a stripped string or ``None`` for non-string/blank values.

        :param value: Candidate value of any type.
        :returns: ``value`` with surrounding whitespace removed, or ``None``
            when ``value`` is not a ``str`` or is empty after stripping.
        """
        if not isinstance(value, str):
            return None
        cleaned = value.strip()
        return cleaned or None

    @staticmethod
    def _parsed_extras(value: object) -> MutableMapping[str, object] | None:
        """Return parsed connection extras when value is a dict-like payload.

        :param value: A mapping, a JSON string, or anything else.
        :returns: A new ``dict`` copy for mapping input, the decoded object
            for a JSON string that decodes to a ``dict``, otherwise ``None``
            (invalid JSON, non-object JSON, or an unsupported type).
        """
        if isinstance(value, Mapping):
            # Copy so callers never mutate the connection's own mapping.
            return cast(MutableMapping[str, object], dict(value))
        if isinstance(value, str):
            try:
                parsed = json.loads(value)
            except json.JSONDecodeError:
                return None
            if isinstance(parsed, dict):
                return cast(MutableMapping[str, object], parsed)
        return None

    @classmethod
    def _connection_extras(cls, conn: object) -> MutableMapping[str, object]:
        """Resolve extras from Airflow connection objects across API versions.

        Sources are tried in order and the first one that yields a mapping
        wins: the ``extra_dejson`` attribute, the result of ``get_extra()``,
        then the raw ``extra`` attribute.

        :param conn: Connection-like object; every attribute is optional.
        :returns: The parsed extras, or an empty ``dict`` when no source
            produced a mapping.
        """
        extras = cls._parsed_extras(getattr(conn, "extra_dejson", None))
        if extras is not None:
            return extras

        get_extra = getattr(conn, "get_extra", None)
        if callable(get_extra):
            try:
                raw_extra = get_extra()
            except (AttributeError, TypeError, ValueError):
                # Best effort: a failing accessor is treated as "no extras".
                raw_extra = None
            extras = cls._parsed_extras(raw_extra)
            if extras is not None:
                return extras

        extras = cls._parsed_extras(getattr(conn, "extra", None))
        if extras is not None:
            return extras
        return {}

    def _resolved_config(self) -> Config:
        """Resolve hook configuration from Airflow connection fields and env fallback.

        Extras take precedence over the connection's ``login`` / ``password``
        fields; blank or non-string values are passed to :func:`load_config`
        as ``None``.

        :returns: The :class:`Config` produced by :func:`load_config`.
        """
        # NOTE(review): ``BaseHook`` is already imported at module level; the
        # function-scope import is kept as in the original, presumably so the
        # lookup happens at call time (e.g. for test patching) - confirm
        # before removing.
        from airflow.hooks.base import BaseHook

        conn = BaseHook.get_connection(self.imednet_conn_id)
        extras_dict = self._connection_extras(conn)
        config = load_config(
            api_key=self._string_or_none(extras_dict.get("api_key"))
            or self._string_or_none(getattr(conn, "login", None)),
            security_key=self._string_or_none(extras_dict.get("security_key"))
            or self._string_or_none(getattr(conn, "password", None)),
            base_url=self._string_or_none(extras_dict.get("base_url")),
        )
        return config

    @classmethod
    def _to_primitive(cls, value: Any) -> PrimitiveContainer:
        """Recursively normalize values to primitive containers with credential redaction.

        Pydantic-style objects are first converted via
        ``model_dump(mode="json", by_alias=True)``. Dictionaries are traversed
        recursively and sensitive keys are masked. Lists and tuples keep their
        order; sets and frozensets are emitted in a deterministic order.
        Unknown object types fall back to ``str(value)`` so discovery outputs
        remain serializable.

        :param value: Arbitrary object to normalize.
        :returns: A scalar, list, or string-keyed dict built only from
            primitives.
        """
        if value is None or isinstance(value, (str, int, float, bool)):
            return cast(Primitive, value)
        if isinstance(value, (date, datetime)):
            return value.isoformat()

        # Only dump instances: a model *class* also exposes ``model_dump`` but
        # calling it unbound would raise ``TypeError``.
        dump = getattr(value, "model_dump", None)
        if callable(dump) and not isinstance(value, type):
            value = cast(Any, dump(mode="json", by_alias=True))
            # A dump may itself be a scalar; return it unchanged instead of
            # letting it reach the ``str(value)`` fallback below.
            if value is None or isinstance(value, (str, int, float, bool)):
                return cast(Primitive, value)

        if isinstance(value, Mapping):
            output: Dict[str, PrimitiveContainer] = {}
            for key, item in value.items():
                key_str = str(key)
                if key_str.lower() in _SENSITIVE_KEYS:
                    output[key_str] = "***"
                else:
                    output[key_str] = cls._to_primitive(item)
            return output
        if isinstance(value, (list, tuple)):
            return [cls._to_primitive(item) for item in value]
        if isinstance(value, (set, frozenset)):
            # Set iteration order can differ between processes; sort by the
            # JSON text of each normalized item so the output is stable.
            items = [cls._to_primitive(item) for item in value]
            return sorted(items, key=lambda item: json.dumps(item, sort_keys=True))
        return str(value)

    def get_sdk_client(self) -> ImednetSDK:
        """Return an SDK client for use within task execution context.

        :returns: An :class:`ImednetSDK` built from the resolved configuration.
        """
        config = self._resolved_config()
        return ImednetSDK(
            api_key=config.api_key,
            security_key=config.security_key,
            base_url=config.base_url,
        )

    def get_conn(self) -> ImednetSDK:  # type: ignore[override]
        """Backward compatible alias for :meth:`get_sdk_client`."""
        return self.get_sdk_client()

    def describe_connection(self) -> Dict[str, PrimitiveContainer]:
        """Return redacted primitive metadata about resolved hook configuration.

        The credential values are always reported as ``"***"``; the
        ``*_configured`` flags tell whether a non-empty value was resolved.

        :returns: A dict of primitives describing the connection.
        """
        config = self._resolved_config()
        return {
            "imednet_conn_id": self.imednet_conn_id,
            "base_url": self._to_primitive(config.base_url),
            "api_key": "***",
            "security_key": "***",
            "api_key_configured": bool(config.api_key),
            "security_key_configured": bool(config.security_key),
        }

    def list_study_keys(self) -> List[str]:
        """Return primitive study keys for mapped Airflow task expansion.

        Each metadata entry is probed for ``"studyKey"`` first and
        ``"study_key"`` second; entries without a non-empty string key are
        skipped.

        :returns: Study keys in the order the metadata entries are yielded.
        """
        # NOTE(review): ``list_studies_metadata`` is not defined in the class
        # body visible here; it must be provided elsewhere (e.g. a subclass or
        # an omitted method) - confirm, otherwise this raises AttributeError.
        keys: List[str] = []
        for study in self.list_studies_metadata():
            study_key = study.get("studyKey") or study.get("study_key")
            if isinstance(study_key, str) and study_key:
                keys.append(study_key)
        return keys
# Public API: only the hook class is exported by ``from <module> import *``.
__all__ = ["ImednetHook"]