Source code for imednet.endpoints.base

"""
Base endpoint mix-in for all API resource endpoints.
"""

from typing import Any, Dict, Iterable, Optional
from urllib.parse import quote

from imednet.core.async_client import AsyncClient
from imednet.core.client import Client
from imednet.core.context import Context


def _find_by_attr(items: Iterable[Any], attr: str, value: Any) -> Optional[Any]:
    """Find an item in an iterable by matching a stringified attribute."""
    target = str(value)
    for item in items:
        if str(getattr(item, attr)) == target:
            return item
    return None


[docs]class BaseEndpoint: """ Shared base for endpoint wrappers. Handles context injection and filtering. """ BASE_PATH = "/api/v1/edc/studies" PATH: str # to be set in subclasses def __init__( self, client: Client, ctx: Context, async_client: AsyncClient | None = None, ) -> None: self._client = client self._async_client = async_client self._ctx = ctx cache_name: Optional[str] = getattr(self, "_cache_name", None) if cache_name: if getattr(self, "requires_study_key", True): setattr(self, cache_name, {}) else: setattr(self, cache_name, None) def _auto_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]: # inject default studyKey if missing if "studyKey" not in filters and self._ctx.default_study_key: filters["studyKey"] = self._ctx.default_study_key return filters def _build_path(self, *segments: Any) -> str: """Return an API path joined with :data:`BASE_PATH`.""" parts = [self.BASE_PATH.strip("/")] for seg in segments: text = str(seg).strip("/") if text: # Encode path segments to prevent traversal and injection parts.append(quote(text, safe="")) return "/" + "/".join(parts) # ------------------------------------------------------------------ # Helper methods # ------------------------------------------------------------------ def _fallback_from_list(self, study_key: str, item_id: Any, attr: str): """Return an item from ``list`` when direct ``get`` fails.""" item = _find_by_attr(self.list(study_key), attr, item_id) # type: ignore[attr-defined] if item: return item raise ValueError(f"{attr} {item_id} not found in study {study_key}") async def _async_fallback_from_list(self, study_key: str, item_id: Any, attr: str): items = await self.async_list(study_key) # type: ignore[attr-defined] item = _find_by_attr(items, attr, item_id) if item: return item raise ValueError(f"{attr} {item_id} not found in study {study_key}") def _require_async_client(self) -> AsyncClient: """Return the configured async client or raise if missing.""" if self._async_client is None: raise RuntimeError("Async client not configured") return self._async_client