from __future__ import annotations
from typing import Any, Iterable, List, Optional, cast
from imednet.core.endpoint.abc import EndpointABC
from imednet.core.endpoint.operations.filter_get import FilterGetOperation
from imednet.core.endpoint.operations.get import PathGetOperation
from imednet.core.paginator import AsyncPaginator, Paginator
from imednet.core.protocols import AsyncRequestorProtocol, RequestorProtocol
from imednet.errors import ClientError, NotFoundError
from ..protocols import ListEndpointProtocol
from .parsing import ParsingMixin, T
class FilterGetEndpointMixin(EndpointABC[T]):
    """Mixin that implements ``get`` / ``async_get`` through a filtered list call.

    A lookup by ID is performed by listing with a single filter
    ``{self._id_param: item_id}`` and taking the first result. The concrete
    endpoint must also mix in a list implementation that provides
    ``_list_sync`` and ``_list_async`` (for example ``ListEndpointMixin``).
    """

    # MODEL and _id_param are provided by EndpointABC (abstract members or
    # properties); they are not defined here.

    # Paginator classes handed to the filtered list call by ``get`` and
    # ``async_get``; subclasses may override them.
    PAGINATOR_CLS: type[Paginator] = Paginator
    ASYNC_PAGINATOR_CLS: type[AsyncPaginator] = AsyncPaginator

    def _validate_get_result(
        self, items: List[T], study_key: Optional[str], item_id: Any
    ) -> T:
        """Return the first matching item, or raise if the filter matched nothing.

        Args:
            items: Items returned by the filtered list call.
            study_key: Study the lookup was scoped to (used only in the message).
            item_id: The ID that was looked up (used only in the message).

        Returns:
            ``items[0]``; any further matches are ignored.

        Raises:
            NotFoundError: If ``items`` is empty. The message names the study
                when ``self.requires_study_key`` is true.
        """
        if items:
            return items[0]
        if self.requires_study_key:
            message = (
                f"{self.MODEL.__name__} {item_id} not found in study {study_key}"
            )
        else:
            message = f"{self.MODEL.__name__} {item_id} not found"
        raise NotFoundError(message)

    def _get_sync(
        self,
        client: RequestorProtocol,
        paginator_cls: type[Paginator],
        *,
        study_key: Optional[str],
        item_id: Any,
    ) -> T:
        """Synchronously fetch one item by ID using a filtered list request.

        Args:
            client: Synchronous requestor used for the list call.
            paginator_cls: Paginator class used to walk the list results.
            study_key: Study to scope the lookup to, if any.
            item_id: Value matched against ``self._id_param``.

        Returns:
            The item chosen by ``_validate_get_result``.
        """
        # The list implementation comes from another mixin; the cast only
        # informs the type checker that ``_list_sync`` is present.
        lister = cast(ListEndpointProtocol[T], self)
        op = FilterGetOperation[T](
            study_key=study_key,
            item_id=item_id,
            filters={self._id_param: item_id},
            validate_func=self._validate_get_result,
            list_sync_func=lister._list_sync,
        )
        return op.execute_sync(client, paginator_cls)

    async def _get_async(
        self,
        client: AsyncRequestorProtocol,
        paginator_cls: type[AsyncPaginator],
        *,
        study_key: Optional[str],
        item_id: Any,
    ) -> T:
        """Asynchronously fetch one item by ID using a filtered list request.

        Args:
            client: Asynchronous requestor used for the list call.
            paginator_cls: Async paginator class used to walk the list results.
            study_key: Study to scope the lookup to, if any.
            item_id: Value matched against ``self._id_param``.

        Returns:
            The item chosen by ``_validate_get_result``.
        """
        # The list implementation comes from another mixin; the cast only
        # informs the type checker that ``_list_async`` is present.
        lister = cast(ListEndpointProtocol[T], self)
        op = FilterGetOperation[T](
            study_key=study_key,
            item_id=item_id,
            filters={self._id_param: item_id},
            validate_func=self._validate_get_result,
            list_async_func=lister._list_async,
        )
        return await op.execute_async(client, paginator_cls)

    def get(self, study_key: Optional[str], item_id: Any) -> T:
        """Get an item by ID using filtering.

        Args:
            study_key: Study to scope the lookup to, if any.
            item_id: ID of the item to fetch.

        Returns:
            The first item whose ``_id_param`` matches ``item_id``.

        Raises:
            NotFoundError: If no item matches.
        """
        sync_client = self._require_sync_client()
        return self._get_sync(
            sync_client,
            self.PAGINATOR_CLS,
            study_key=study_key,
            item_id=item_id,
        )

    async def async_get(self, study_key: Optional[str], item_id: Any) -> T:
        """Asynchronously get an item by ID using filtering.

        Args:
            study_key: Study to scope the lookup to, if any.
            item_id: ID of the item to fetch.

        Returns:
            The first item whose ``_id_param`` matches ``item_id``.

        Raises:
            NotFoundError: If no item matches.
        """
        async_client = self._require_async_client()
        return await self._get_async(
            async_client,
            self.ASYNC_PAGINATOR_CLS,
            study_key=study_key,
            item_id=item_id,
        )
class PathGetEndpointMixin(ParsingMixin[T], EndpointABC[T]):
    """Mixin implementing ``get`` / ``async_get`` via a direct item path.

    The request path is built from ``(study_key, PATH, item_id)`` when the
    endpoint requires a study key, otherwise from ``(PATH, item_id)`` (or just
    ``(item_id,)`` when ``PATH`` is falsy), and joined by ``_build_path``. The
    response is parsed with ``_parse_item``.
    """

    # PATH is inherited from EndpointABC as abstract.

    def _get_path_for_id(self, study_key: Optional[str], item_id: Any) -> str:
        """Build the request path for a single item.

        Args:
            study_key: Study the item belongs to; required (non-empty) when
                ``self.requires_study_key`` is true.
            item_id: ID of the item.

        Returns:
            The path produced by ``self._build_path``.

        Raises:
            ClientError: If a study key is required but missing or empty.
        """
        segments: Iterable[Any]
        if self.requires_study_key:
            if not study_key:
                raise ClientError("Study key must be provided")
            segments = (study_key, self.PATH, item_id)
        else:
            # Skip PATH entirely when it is falsy (e.g. an empty string).
            segments = (self.PATH, item_id) if self.PATH else (item_id,)
        # No cast needed as we inherit EndpointABC which defines _build_path.
        return self._build_path(*segments)

    def _raise_not_found(self, study_key: Optional[str], item_id: Any) -> None:
        """Raise ``NotFoundError`` for ``item_id``; this never returns normally.

        The message names the model and the missing ID, and also the study
        when ``self.requires_study_key`` is true, so the caller can tell which
        lookup failed.

        Args:
            study_key: Study the lookup was scoped to.
            item_id: The ID that could not be found.

        Raises:
            NotFoundError: Always.
        """
        model_name = self.MODEL.__name__
        if self.requires_study_key:
            raise NotFoundError(
                f"{model_name} {item_id} not found in study {study_key}"
            )
        raise NotFoundError(f"{model_name} {item_id} not found")

    def _get_path_sync(
        self,
        client: RequestorProtocol,
        *,
        study_key: Optional[str],
        item_id: Any,
    ) -> T:
        """Synchronously fetch and parse one item from its direct path.

        Args:
            client: Synchronous requestor used for the request.
            study_key: Study the item belongs to, if required.
            item_id: ID of the item.

        Returns:
            The item parsed by ``_parse_item``.

        Raises:
            ClientError: If a required study key is missing.
        """
        path = self._get_path_for_id(study_key, item_id)
        operation = PathGetOperation[T](
            path=path,
            parse_func=self._parse_item,
            not_found_func=lambda: self._raise_not_found(study_key, item_id),
        )
        return operation.execute_sync(client)

    async def _get_path_async(
        self,
        client: AsyncRequestorProtocol,
        *,
        study_key: Optional[str],
        item_id: Any,
    ) -> T:
        """Asynchronously fetch and parse one item from its direct path.

        Args:
            client: Asynchronous requestor used for the request.
            study_key: Study the item belongs to, if required.
            item_id: ID of the item.

        Returns:
            The item parsed by ``_parse_item``.

        Raises:
            ClientError: If a required study key is missing.
        """
        path = self._get_path_for_id(study_key, item_id)
        operation = PathGetOperation[T](
            path=path,
            parse_func=self._parse_item,
            not_found_func=lambda: self._raise_not_found(study_key, item_id),
        )
        return await operation.execute_async(client)

    def get(self, study_key: Optional[str], item_id: Any) -> T:
        """Get an item by ID using direct path.

        Args:
            study_key: Study the item belongs to, if required.
            item_id: ID of the item to fetch.

        Returns:
            The parsed item.
        """
        return self._get_path_sync(
            self._require_sync_client(),
            study_key=study_key,
            item_id=item_id,
        )

    async def async_get(self, study_key: Optional[str], item_id: Any) -> T:
        """Asynchronously get an item by ID using direct path.

        Args:
            study_key: Study the item belongs to, if required.
            item_id: ID of the item to fetch.

        Returns:
            The parsed item.
        """
        return await self._get_path_async(
            self._require_async_client(),
            study_key=study_key,
            item_id=item_id,
        )