# Source code for imednet_workflows.extraction_engine

from __future__ import annotations

from typing import Any

from pydantic import BaseModel, Field, ValidationError

from imednet.models import Record
from imednet.models.reporting import AdverseEvent, DeviceDeficiency, ProtocolDeviation
from imednet.models.study_config import MappingRule, StudyConfiguration

# Upper-cased reporting-domain code -> model class instantiated for that
# domain. Mapping rules whose domain is not a key here are skipped when the
# study mappings are grouped.
_DOMAIN_MODEL_MAP = {
    "AE": AdverseEvent,
    "PD": ProtocolDeviation,
    "DD": DeviceDeficiency,
}


class ExtractionResult(BaseModel):
    """Canonical extraction output grouped by reporting domain.

    Every field defaults to its own empty list (``default_factory=list``), so
    a freshly constructed result can be appended to directly without sharing
    state between instances.
    """

    # Adverse event models that were built successfully.
    adverse_events: list[AdverseEvent] = Field(default_factory=list)
    # Protocol deviation models that were built successfully.
    protocol_deviations: list[ProtocolDeviation] = Field(default_factory=list)
    # Device deficiency models that were built successfully.
    device_deficiencies: list[DeviceDeficiency] = Field(default_factory=list)
    # One free-form dict per payload that could not be turned into a model.
    validation_errors: list[dict[str, Any]] = Field(default_factory=list)
def _get_from_path(value: Any, path: str) -> Any:
    """Follow a dot-separated ``path`` through nested dicts.

    Returns the value found at the end of the path, or ``None`` as soon as a
    segment is missing or the current node is not a dict.
    """
    node = value
    for segment in path.split("."):
        if not isinstance(node, dict) or segment not in node:
            return None
        node = node[segment]
    return node


def _extract_rule_value_from_payload(
    record: Record, rule: MappingRule, top_level_payload: dict[str, Any]
) -> Any:
    """Resolve the source value a mapping rule points at for one record.

    Lookup order:

    1. A path prefixed with ``recordData.`` or ``record_data.`` is resolved
       only inside ``record.record_data``.
    2. Otherwise the path is resolved against ``top_level_payload``.
    3. If that yields ``None`` and the path has no dot, the name is looked up
       directly in ``record.record_data`` (when that is a dict).

    Returns ``None`` when nothing is found.
    """
    source_path = rule.source_variable_name

    # Explicitly namespaced paths never fall back to the top-level payload.
    for prefix in ("recordData.", "record_data."):
        if source_path.startswith(prefix):
            return _get_from_path(record.record_data, source_path[len(prefix):])

    found = _get_from_path(top_level_payload, source_path)
    if found is not None:
        return found

    # Only bare (undotted) names get the record-data fallback.
    is_bare_name = "." not in source_path
    if is_bare_name and isinstance(record.record_data, dict):
        return record.record_data.get(source_path)
    return None


def _is_missing_value(value: Any) -> bool:
    """Return ``True`` for ``None`` and for the empty string, else ``False``."""
    if value is None:
        return True
    return isinstance(value, str) and value == ""


def _group_mappings_by_domain_and_form(
    study_configuration: StudyConfiguration,
) -> dict[str, dict[str, list[MappingRule]]]:
    """Index the study's mapping rules as ``{domain: {form_key: [rules]}}``.

    Domains are upper-cased before lookup; rules whose domain is not a key of
    ``_DOMAIN_MODEL_MAP`` are dropped. Rule order within a form is preserved.
    """
    by_domain: dict[str, dict[str, list[MappingRule]]] = {}
    for mapping in study_configuration.mappings:
        domain = mapping.domain.upper()
        if domain in _DOMAIN_MODEL_MAP:
            by_form = by_domain.setdefault(domain, {})
            by_form.setdefault(mapping.source_form_key, []).append(mapping)
    return by_domain
def extract_canonical_records(
    records: list[Record], study_configuration: StudyConfiguration
) -> ExtractionResult:
    """Extract canonical AE/PD/DD models from raw records using study mappings.

    For each record, every reporting domain that has mapping rules for the
    record's ``form_key`` gets one payload built from those rules and is then
    validated by instantiating the domain's model class.

    Args:
        records: Raw records to process, in order.
        study_configuration: Supplies the mapping rules (``.mappings``).

    Returns:
        An ``ExtractionResult`` holding the successfully built models per
        domain, plus one ``validation_errors`` entry (``recordId``,
        ``formKey``, ``domain``, ``payload``, ``errors``) for every payload
        the domain model rejected. Validation failures are collected, never
        raised.
    """
    result = ExtractionResult()
    grouped_mappings = _group_mappings_by_domain_and_form(study_configuration)

    for record in records:
        # Find the domains that actually target this record's form first, so
        # records no mapping refers to are skipped without being serialised.
        form_key = record.form_key
        applicable: list[tuple[str, list[MappingRule]]] = []
        for domain, by_form in grouped_mappings.items():
            rules = by_form.get(form_key)
            if rules:
                applicable.append((domain, rules))
        if not applicable:
            continue

        # Expose both field-name and alias spellings so a rule may use either;
        # alias keys win where the two dumps share a key.
        top_level_payload = {
            **record.model_dump(by_alias=False),
            **record.model_dump(by_alias=True),
        }

        for domain, rules in applicable:
            payload: dict[str, Any] = {}
            for rule in rules:
                value = _extract_rule_value_from_payload(record, rule, top_level_payload)
                # The fallback only replaces None / "" and only when one is configured.
                if _is_missing_value(value) and rule.fallback_value is not None:
                    value = rule.fallback_value
                payload[rule.target_field] = value

            model_type = _DOMAIN_MODEL_MAP[domain]
            try:
                model_instance = model_type(**payload)
            except ValidationError as exc:
                # Keep going: one bad record must not abort the whole extraction.
                result.validation_errors.append(
                    {
                        "recordId": record.record_id,
                        "formKey": record.form_key,
                        "domain": domain,
                        "payload": payload,
                        "errors": exc.errors(),
                    }
                )
                continue

            if domain == "AE":
                result.adverse_events.append(model_instance)
            elif domain == "PD":
                result.protocol_deviations.append(model_instance)
            else:
                result.device_deficiencies.append(model_instance)

    return result
# Names exported by a star-import of this module; the underscore-prefixed
# helpers are intentionally left out.
__all__ = ["ExtractionResult", "extract_canonical_records"]