"""
Base endpoint mix-in for all API resource endpoints.
"""
from __future__ import annotations
import warnings
from typing import Any, Callable, Dict, List, Optional, TypeVar
from imednet.constants import DEFAULT_PAGE_SIZE
from imednet.core.endpoint.abc import EndpointABC
from imednet.core.endpoint.operations import FilterGetOperation, ListOperation
from imednet.core.endpoint.strategies import (
DefaultParamProcessor,
KeepStudyKeyStrategy,
OptionalStudyKeyStrategy,
StudyKeyStrategy,
)
from imednet.core.endpoint.structs import ListRequestState, ParamState
from imednet.core.paginator import AsyncPaginator, Paginator
from imednet.core.parsing import get_model_parser
from imednet.core.protocols import AsyncRequestorProtocol, ParamProcessor, RequestorProtocol
from imednet.models.json_base import JsonModel
from imednet.utils.filters import build_filter_string
from imednet.utils.typing import FilterValue, ItemId
from imednet.utils.url import build_safe_path
# Type variable for the model class an endpoint parses responses into;
# bound to JsonModel so only JsonModel subclasses are accepted.
T = TypeVar("T", bound=JsonModel)
class GenericEndpoint(EndpointABC[T]):
    """
    Generic base for endpoint wrappers.

    Handles client injection and basic path building.
    Does NOT include EDC-specific logic.
    """

    BASE_PATH = ""
    _client: Optional[RequestorProtocol]
    _async_client: Optional[AsyncRequestorProtocol]

    def __init__(
        self,
        client: Optional[RequestorProtocol] = None,
        ctx: object | None = None,
        async_client: Optional[AsyncRequestorProtocol] = None,
    ) -> None:
        """
        Store the requestors used by this endpoint.

        Args:
            client: Sync requestor, or ``None`` when only async calls are used.
            ctx: Deprecated and ignored. Any non-``None`` value only triggers
                a :class:`DeprecationWarning`.
            async_client: Async requestor, or ``None`` when only sync calls
                are used.
        """
        if ctx is not None:
            warnings.warn(
                "The 'ctx' constructor argument is deprecated and ignored. "
                "Pass study_key explicitly or use ImednetSDK.study_context(...).",
                DeprecationWarning,
                # A fixed stacklevel would blame a subclass's ``super().__init__``
                # line inside the library; point at the real call site instead so
                # the default warning filters can surface it to the user.
                stacklevel=self._ctx_warning_stacklevel(),
            )
        self._client = client
        self._async_client = async_client

    def _ctx_warning_stacklevel(self) -> int:
        """
        Return the ``stacklevel`` that attributes the ``ctx`` warning to user code.

        Must be called directly from :meth:`__init__`. Walks outwards over the
        chain of ``__init__`` frames that are constructing this very instance
        (this class plus any subclass constructors delegating via ``super()``)
        and returns the level of the first frame beyond that chain.

        Returns:
            The computed stack level; never less than ``2``, which is also the
            fallback when the interpreter does not expose frames.
        """
        import inspect  # local import: only needed on the deprecated path

        current = inspect.currentframe()
        # ``current`` is this helper's frame; its parent is ``__init__``,
        # which corresponds to stacklevel 1 for the ``warnings.warn`` call.
        frame = current.f_back if current is not None else None
        level = 1
        try:
            while (
                frame is not None
                and frame.f_code.co_name == "__init__"
                and frame.f_locals.get("self") is self
            ):
                level += 1
                frame = frame.f_back
        finally:
            # Drop frame references promptly to avoid reference cycles.
            del current, frame
        return max(level, 2)

    def _auto_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]:
        """Pass-through for filters in generic endpoints."""
        return filters

    def _build_path(self, *segments: Any) -> str:
        """
        Return an API path joined with :data:`BASE_PATH`.

        Args:
            *segments: URL path segments to append.

        Returns:
            The full API path string, always starting with ``/``.
        """
        return "/" + build_safe_path(self.BASE_PATH, *segments)

    def _require_sync_client(self) -> RequestorProtocol:
        """
        Return the configured sync client.

        Raises:
            RuntimeError: If no sync client was supplied at construction.
        """
        if self._client is None:
            raise RuntimeError("Sync client not configured")
        return self._client

    def _require_async_client(self) -> AsyncRequestorProtocol:
        """
        Return the configured async client.

        Raises:
            RuntimeError: If no async client was supplied at construction.
        """
        if self._async_client is None:
            raise RuntimeError("Async client not configured")
        return self._async_client
class _ListGetEndpointBase(GenericEndpoint[T]):
    """
    Generic base for endpoints that provide list and get-by-filter functionality.

    Holds the request-preparation logic shared by the sync and async variants:
    filter processing, study-key resolution and query-parameter assembly.
    """

    PAGE_SIZE: int = DEFAULT_PAGE_SIZE
    PAGINATOR_CLS: type[Paginator] = Paginator
    ASYNC_PAGINATOR_CLS: type[AsyncPaginator] = AsyncPaginator
    # Explicit processor instance; when set it wins over PARAM_PROCESSOR_CLS.
    PARAM_PROCESSOR: Optional[ParamProcessor] = None
    PARAM_PROCESSOR_CLS: type[ParamProcessor] = DefaultParamProcessor
    # Explicit strategy instance; when None one is chosen from ``requires_study_key``.
    STUDY_KEY_STRATEGY: Optional[StudyKeyStrategy] = None

    @property
    def study_key_strategy(self) -> StudyKeyStrategy:
        """
        Return the strategy used to resolve the study key from the filters.

        An explicitly configured :attr:`STUDY_KEY_STRATEGY` is always honoured;
        otherwise a fresh strategy is created based on ``requires_study_key``.
        """
        # Compare against None rather than relying on truthiness so that a
        # configured strategy which happens to be falsy is not discarded.
        if self.STUDY_KEY_STRATEGY is not None:
            return self.STUDY_KEY_STRATEGY
        if self.requires_study_key:
            return KeepStudyKeyStrategy()
        return OptionalStudyKeyStrategy()

    @property
    def param_processor(self) -> ParamProcessor:
        """
        Return the processor that splits filters from special parameters.

        An explicitly configured :attr:`PARAM_PROCESSOR` is always honoured;
        otherwise a new :attr:`PARAM_PROCESSOR_CLS` instance is created.
        """
        # Same reasoning as ``study_key_strategy``: only None means "unset".
        if self.PARAM_PROCESSOR is not None:
            return self.PARAM_PROCESSOR
        return self.PARAM_PROCESSOR_CLS()

    def _parse_item(self, item: Any) -> T:
        """Parse one raw item using the parser looked up for ``self.MODEL``."""
        parse_func = get_model_parser(self.MODEL)
        return parse_func(item)

    def _resolve_parse_func(self) -> Callable[[Any], T]:
        """Return the callable used to parse each item of a list response."""
        return self._parse_item

    def _resolve_params(
        self,
        study_key: Optional[str],
        extra_params: Optional[Dict[str, Any]],
        filters: Dict[str, Any],
    ) -> ParamState:
        """
        Turn caller-supplied filters into query parameters.

        Args:
            study_key: Study key to inject into the filters; ignored when falsy.
            extra_params: Additional query parameters merged into the result.
                The caller's dict is never mutated.
            filters: Raw filter values; a copy is processed, not the original.

        Returns:
            A :class:`ParamState` carrying the resolved study, the final query
            parameters and the filters other than ``studyKey``.
        """
        filters = self._auto_filter(filters.copy())
        processed_filters, special_params = self.param_processor.process_filters(filters)
        if special_params:
            # Work on a copy so the caller's ``extra_params`` stays untouched.
            if extra_params is None:
                extra_params = {}
            else:
                extra_params = extra_params.copy()
            extra_params.update(special_params)
        if study_key:
            processed_filters["studyKey"] = study_key
        # The strategy decides what ``study`` is and which filters remain.
        study, processed_filters = self.study_key_strategy.process(processed_filters)
        self._validate_study_key(study)
        other_filters = {k: v for k, v in processed_filters.items() if k != "studyKey"}
        params: Dict[str, Any] = {}
        if processed_filters:
            params["filter"] = build_filter_string(processed_filters)
        if extra_params:
            params.update(extra_params)
        return ParamState(study=study, params=params, other_filters=other_filters)

    def _prepare_list_request(
        self,
        study_key: Optional[str],
        extra_params: Optional[Dict[str, Any]],
        filters: Dict[str, Any],
    ) -> ListRequestState[T]:
        """
        Resolve the path and query parameters for a list request.

        Args:
            study_key: Study key forwarded to :meth:`_resolve_params`.
            extra_params: Extra query parameters forwarded to
                :meth:`_resolve_params`.
            filters: Raw filter values forwarded to :meth:`_resolve_params`.

        Returns:
            A :class:`ListRequestState` with the endpoint path, query
            parameters and resolved study.
        """
        param_state = self._resolve_params(study_key, extra_params, filters)
        study = param_state.study
        params = param_state.params
        path = self._get_endpoint_path(study)
        return ListRequestState(
            path=path,
            params=params,
            study=study,
        )

    def _validate_get_result(self, items: List[T], study_key: Optional[str], item_id: ItemId) -> T:
        """
        Return the first item of a get-by-filter result.

        An empty result is handed to ``_raise_not_found`` (defined outside this
        class; presumably it raises — confirm against the base class).
        """
        if not items:
            self._raise_not_found(study_key, item_id)
        return items[0]

    @staticmethod
    def _require_item_id(item_id: ItemId) -> None:
        """
        Ensure an item identifier was supplied.

        Raises:
            TypeError: If ``item_id`` is ``None``.
        """
        if item_id is None:
            raise TypeError("Missing required argument: item_id")
class SyncListGetEndpoint(_ListGetEndpointBase[T]):
    """List/get endpoint whose public methods run on a sync requestor."""

    def __init__(
        self,
        client: RequestorProtocol,
        ctx: object | None = None,
    ) -> None:
        """
        Bind the endpoint to a sync requestor.

        Args:
            client: Requestor used for every blocking call.
            ctx: Forwarded unchanged to the base constructor.
        """
        super().__init__(client=client, ctx=ctx)

    def _list_sync(
        self,
        client: RequestorProtocol,
        paginator_cls: type[Paginator],
        *,
        study_key: Optional[str] = None,
        extra_params: Optional[Dict[str, Any]] = None,
        **filters: Any,
    ) -> List[T]:
        """
        Run a list operation with an explicit client and paginator class.

        Args:
            client: Requestor that performs the calls.
            paginator_cls: Paginator class handed to the operation.
            study_key: Optional study key for the request.
            extra_params: Optional additional query parameters.
            **filters: Filter values for the request.

        Returns:
            The parsed items.
        """
        request = self._prepare_list_request(study_key, extra_params, filters)
        operation = ListOperation[T](
            path=request.path,
            params=request.params,
            page_size=self.PAGE_SIZE,
            parse_func=self._resolve_parse_func(),
        )
        return operation.execute_sync(client, paginator_cls)

    def list(self, study_key: Optional[str] = None, **filters: FilterValue) -> List[T]:
        """
        List items, optionally restricted by study key and filters.

        Args:
            study_key: Optional study key for the request.
            **filters: Filter values for the request.

        Returns:
            The parsed items.

        Raises:
            RuntimeError: If no sync client is configured.
        """
        # Widen FilterValue to Any here: mypy treats dict value types as
        # invariant, and ``_list_sync`` declares ``**filters: Any``.
        widened: Dict[str, Any] = dict(filters)
        requestor = self._require_sync_client()
        return self._list_sync(
            requestor,
            self.PAGINATOR_CLS,
            study_key=study_key,
            **widened,
        )

    def _get_sync(
        self,
        client: RequestorProtocol,
        paginator_cls: type[Paginator],
        *,
        study_key: Optional[str],
        item_id: ItemId,
    ) -> T:
        """
        Fetch one item by filtering the list endpoint on ``self._id_param``.

        Args:
            client: Requestor that performs the calls.
            paginator_cls: Paginator class handed to the operation.
            study_key: Optional study key for the request.
            item_id: Identifier value to filter on.

        Returns:
            The item selected by ``_validate_get_result``.
        """
        lookup = FilterGetOperation[T](
            study_key=study_key,
            item_id=item_id,
            filters={self._id_param: item_id},
            validate_func=self._validate_get_result,
            list_sync_func=self._list_sync,
        )
        return lookup.execute_sync(client, paginator_cls)

    def get(self, study_key: Optional[str], item_id: ItemId) -> T:
        """
        Return the single item identified by ``item_id``.

        Args:
            study_key: Optional study key for the request.
            item_id: Identifier of the item; must not be ``None``.

        Raises:
            TypeError: If ``item_id`` is ``None``.
            RuntimeError: If no sync client is configured.
        """
        self._require_item_id(item_id)
        requestor = self._require_sync_client()
        return self._get_sync(
            requestor,
            self.PAGINATOR_CLS,
            study_key=study_key,
            item_id=item_id,
        )
class AsyncListGetEndpoint(_ListGetEndpointBase[T]):
    """List/get endpoint whose public methods run on an async requestor."""

    def __init__(
        self,
        async_client: AsyncRequestorProtocol,
        ctx: object | None = None,
    ) -> None:
        """
        Bind the endpoint to an async requestor.

        Args:
            async_client: Requestor used for every awaited call.
            ctx: Forwarded unchanged to the base constructor.
        """
        super().__init__(ctx=ctx, async_client=async_client)

    async def _list_async(
        self,
        client: AsyncRequestorProtocol,
        paginator_cls: type[AsyncPaginator],
        *,
        study_key: Optional[str] = None,
        extra_params: Optional[Dict[str, Any]] = None,
        **filters: Any,
    ) -> List[T]:
        """
        Run a list operation with an explicit async client and paginator class.

        Args:
            client: Requestor that performs the calls.
            paginator_cls: Async paginator class handed to the operation.
            study_key: Optional study key for the request.
            extra_params: Optional additional query parameters.
            **filters: Filter values for the request.

        Returns:
            The parsed items.
        """
        request = self._prepare_list_request(study_key, extra_params, filters)
        operation = ListOperation[T](
            path=request.path,
            params=request.params,
            page_size=self.PAGE_SIZE,
            parse_func=self._resolve_parse_func(),
        )
        return await operation.execute_async(client, paginator_cls)

    async def async_list(self, study_key: Optional[str] = None, **filters: FilterValue) -> List[T]:
        """
        List items, optionally restricted by study key and filters.

        Args:
            study_key: Optional study key for the request.
            **filters: Filter values for the request.

        Returns:
            The parsed items.

        Raises:
            RuntimeError: If no async client is configured.
        """
        # Widen FilterValue to Any at the public/internal boundary.
        widened: Dict[str, Any] = dict(filters)
        requestor = self._require_async_client()
        return await self._list_async(
            requestor,
            self.ASYNC_PAGINATOR_CLS,
            study_key=study_key,
            **widened,
        )

    async def _get_async(
        self,
        client: AsyncRequestorProtocol,
        paginator_cls: type[AsyncPaginator],
        *,
        study_key: Optional[str],
        item_id: ItemId,
    ) -> T:
        """
        Fetch one item by filtering the list endpoint on ``self._id_param``.

        Args:
            client: Requestor that performs the calls.
            paginator_cls: Async paginator class handed to the operation.
            study_key: Optional study key for the request.
            item_id: Identifier value to filter on.

        Returns:
            The item selected by ``_validate_get_result``.
        """
        lookup = FilterGetOperation[T](
            study_key=study_key,
            item_id=item_id,
            filters={self._id_param: item_id},
            validate_func=self._validate_get_result,
            list_async_func=self._list_async,
        )
        return await lookup.execute_async(client, paginator_cls)

    async def async_get(self, study_key: Optional[str], item_id: ItemId) -> T:
        """
        Return the single item identified by ``item_id``.

        Args:
            study_key: Optional study key for the request.
            item_id: Identifier of the item; must not be ``None``.

        Raises:
            TypeError: If ``item_id`` is ``None``.
            RuntimeError: If no async client is configured.
        """
        self._require_item_id(item_id)
        requestor = self._require_async_client()
        return await self._get_async(
            requestor,
            self.ASYNC_PAGINATOR_CLS,
            study_key=study_key,
            item_id=item_id,
        )
# Backward-compatible alias. New code should use SyncListGetEndpoint / AsyncListGetEndpoint.
# NOTE: the alias resolves to the sync variant only, so it exposes ``list``/``get``
# but not ``async_list``/``async_get``.
GenericListGetEndpoint = SyncListGetEndpoint