from __future__ import annotations
import json
import sqlite3
import threading
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, cast
from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_exponential
from imednet.models.records import Record
from imednet.utils.filters import build_filter_string
from .chunked_pipeline import DEFAULT_CHUNK_SIZE
if TYPE_CHECKING:
from imednet.sdk import ImednetSDK
# Default directory for cache databases; ``CachedRecordsLoader`` falls back to
# it when no ``cache_dir`` is supplied.
DEFAULT_CACHE_DIR = Path.home() / ".imednet" / "cache"
# Per-DB-path locks that serialise _initialise_cache across threads within the
# same process. Switching an SQLite database to WAL journal mode requires a
# brief exclusive lock; if multiple threads attempt the switch simultaneously
# they all race for that lock and the losers immediately raise
# ``sqlite3.OperationalError: database is locked``. Serialising initialisation
# eliminates the race entirely. The dict grows at most one entry per distinct
# database file, which is negligible.
_db_init_locks: dict[str, threading.Lock] = {}
_db_init_locks_guard = threading.Lock()


def _get_db_init_lock(resolved_path: Path) -> threading.Lock:
    """Return the initialisation lock associated with ``resolved_path``.

    Every call with the same path string receives the same lock object. The
    registry is only read or extended while ``_db_init_locks_guard`` is held.
    """
    path_key = str(resolved_path)
    with _db_init_locks_guard:
        lock = _db_init_locks.get(path_key)
        if lock is None:
            # First time this database path is seen: register a fresh lock.
            lock = threading.Lock()
            _db_init_locks[path_key] = lock
        return lock
def get_cache_connection(db_path: str | Path) -> sqlite3.Connection:
    """Return a SQLite connection configured for concurrent cache access.

    The parent directory of ``db_path`` is created if missing. The returned
    connection yields ``sqlite3.Row`` rows and has ``busy_timeout=30000``,
    ``journal_mode=WAL`` and ``synchronous=NORMAL`` applied.

    Args:
        db_path: Location of the database file; a leading ``~`` is expanded.

    Returns:
        An open connection that the caller is responsible for closing.

    Raises:
        sqlite3.Error: If the database cannot be opened or configured. The
            partially configured connection is closed before the error
            propagates.
    """
    resolved_path = Path(db_path).expanduser()
    resolved_path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(resolved_path, timeout=30.0)
    try:
        conn.row_factory = sqlite3.Row
        # busy_timeout instructs SQLite to retry at the C level on SQLITE_BUSY
        # (e.g. during the WAL transition); this complements Python's connect
        # timeout and is also effective in cross-process scenarios.
        conn.execute("PRAGMA busy_timeout=30000;")
        conn.execute("PRAGMA journal_mode=WAL;")
        conn.execute("PRAGMA synchronous=NORMAL;")
    except BaseException:
        # Do not leak the handle when configuration fails (locked database,
        # file that is not a database, interruption, ...).
        conn.close()
        raise
    return conn
class CachedRecordsLoader:
    """Load study records through a local SQLite cache with incremental sync.

    Records are stored one row per ``(study_key, record_id)`` in the
    ``record_cache`` table. A sync asks the API only for records modified at
    or after the newest cached ``date_modified`` and can optionally prune
    cached rows that the API no longer lists.
    """

    def __init__(
        self,
        sdk: "ImednetSDK",
        *,
        cache_dir: str | Path | None = None,
        database_name: str = "records_cache.sqlite3",
        retry_attempts: int = 3,
    ) -> None:
        """Bind the loader to ``sdk`` and make sure the cache schema exists.

        Args:
            sdk: SDK whose ``records`` endpoint is used for all API calls.
            cache_dir: Directory holding the cache database (``~`` is
                expanded). ``DEFAULT_CACHE_DIR`` is used when ``None``.
            database_name: File name of the SQLite database inside the
                cache directory.
            retry_attempts: Maximum number of attempts for each API listing
                call before the last error is re-raised.
        """
        self._sdk = sdk
        base_dir = DEFAULT_CACHE_DIR if cache_dir is None else Path(cache_dir).expanduser()
        self.db_path = base_dir / database_name
        self._retry_attempts = retry_attempts
        self._initialise_cache()

    def load_records(self, study_key: str, *, reconcile: bool = True) -> list[Record]:
        """Synchronise the cache for ``study_key`` and return cached records."""
        self.sync_records(study_key, reconcile=reconcile)
        return self.get_cached_records(study_key)

    def sync_records(self, study_key: str, *, reconcile: bool = True) -> None:
        """Synchronise the cache for ``study_key`` without materialising cached rows.

        Args:
            study_key: Study whose records are refreshed.
            reconcile: When true, additionally list the non-deleted records
                and delete cached rows whose id is not among them.
        """
        conn = get_cache_connection(self.db_path)
        try:
            high_water_mark = self._get_high_water_mark(conn, study_key)
            delta_records = self._fetch_delta_records(study_key, high_water_mark)
            self._upsert_records(conn, delta_records)
            if reconcile:
                active_record_ids = self._fetch_active_record_ids(study_key)
                self.reconcile_cache(conn, study_key, active_record_ids)
        finally:
            conn.close()

    def get_cached_records(
        self, study_key: str, *, conn: sqlite3.Connection | None = None
    ) -> list[Record]:
        """Return cached records for ``study_key`` without contacting the API."""
        return list(self.iter_cached_records(study_key, conn=conn))

    def iter_cached_records(
        self,
        study_key: str,
        *,
        conn: sqlite3.Connection | None = None,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
    ) -> Iterator[Record]:
        """Yield cached records for ``study_key`` in bounded chunks.

        Args:
            study_key: Study whose cached rows are read, ordered by record id.
            conn: Existing connection to reuse. Rows are read by column name,
                so it must produce ``sqlite3.Row`` rows as
                ``get_cache_connection`` does. When ``None`` a connection is
                opened on first iteration and closed when iteration ends.
            chunk_size: Number of rows fetched from the cursor per batch.

        Raises:
            ValueError: Immediately, if ``chunk_size`` is not positive.
        """
        # Validate here rather than inside the generator so a bad argument
        # fails at the call site instead of on the first ``next()``.
        if chunk_size <= 0:
            raise ValueError("chunk_size must be greater than zero")
        return self._stream_cached_records(study_key, conn, chunk_size)

    def _stream_cached_records(
        self, study_key: str, conn: sqlite3.Connection | None, chunk_size: int
    ) -> Iterator[Record]:
        """Generator behind ``iter_cached_records``; expects a valid ``chunk_size``."""
        owns_connection = conn is None
        if conn is None:
            conn = get_cache_connection(self.db_path)
        try:
            cursor = conn.execute(
                """
                SELECT payload
                FROM record_cache
                WHERE study_key = ?
                ORDER BY record_id
                """,
                (study_key,),
            )
            while rows := cursor.fetchmany(chunk_size):
                for row in rows:
                    yield Record.from_json(json.loads(cast(str, row["payload"])))
        finally:
            # Only close connections opened here; a caller-supplied one is
            # left open for the caller to manage.
            if owns_connection:
                conn.close()

    def reconcile_cache(
        self, conn: sqlite3.Connection, study_key: str, active_record_ids: set[int]
    ) -> None:
        """Prune records removed from the upstream EDC backend.

        Every cached row of ``study_key`` whose ``record_id`` is not in
        ``active_record_ids`` is deleted in a single transaction.

        Args:
            conn: Connection producing ``sqlite3.Row`` rows.
            study_key: Study whose cached rows are checked.
            active_record_ids: Ids that must be kept.
        """
        local_rows = conn.execute(
            "SELECT record_id FROM record_cache WHERE study_key = ?",
            (study_key,),
        ).fetchall()
        local_ids = {cast(int, row["record_id"]) for row in local_rows}
        # ``difference`` accepts any iterable, so a list or frozenset of ids
        # works as well as the annotated set.
        orphaned_ids = local_ids.difference(active_record_ids)
        if not orphaned_ids:
            return
        with conn:
            conn.executemany(
                "DELETE FROM record_cache WHERE study_key = ? AND record_id = ?",
                [(study_key, orphaned_id) for orphaned_id in orphaned_ids],
            )

    def _initialise_cache(self) -> None:
        """Create the ``record_cache`` table and its index if they are missing.

        Runs under the per-database initialisation lock so that threads of
        this process do not race while the database is first opened.
        """
        resolved = Path(self.db_path).expanduser().resolve()
        with _get_db_init_lock(resolved):
            conn = get_cache_connection(self.db_path)
            try:
                conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS record_cache (
                        study_key TEXT NOT NULL,
                        record_id INTEGER NOT NULL,
                        form_key TEXT NOT NULL,
                        date_modified TEXT NOT NULL,
                        payload TEXT NOT NULL,
                        PRIMARY KEY (study_key, record_id)
                    )
                    """
                )
                conn.execute(
                    """
                    CREATE INDEX IF NOT EXISTS idx_record_cache_study_modified
                    ON record_cache (study_key, date_modified)
                    """
                )
                conn.commit()
            finally:
                conn.close()

    def _get_high_water_mark(self, conn: sqlite3.Connection, study_key: str) -> str | None:
        """Return the greatest cached ``date_modified`` for ``study_key``, or ``None``."""
        row = conn.execute(
            "SELECT MAX(date_modified) AS max_date_modified FROM record_cache WHERE study_key = ?",
            (study_key,),
        ).fetchone()
        if row is None:
            return None
        # The cast target is a string because ``str | None`` as a runtime
        # expression fails on Python versions older than 3.10.
        return cast("str | None", row["max_date_modified"])

    def _fetch_delta_records(self, study_key: str, high_water_mark: str | None) -> list[Record]:
        """Fetch records changed since ``high_water_mark`` (all records if it is empty)."""
        if not high_water_mark:
            return self._list_records(study_key=study_key, record_data_filter=None)
        # Use >= to avoid missing updates that share the high-water-mark timestamp.
        # _upsert_records keeps refresh idempotent by deduplicating on (study_key, record_id).
        delta_filter = build_filter_string({"date_modified": (">=", high_water_mark)})
        return self._list_records_with_filter_override(
            study_key=study_key,
            filter_string=delta_filter,
        )

    def _fetch_active_record_ids(self, study_key: str) -> set[int]:
        """Return the ids of the records listed with ``deleted=False``."""
        records = self._list_records(study_key=study_key, record_data_filter=None, deleted=False)
        return {record.record_id for record in records}

    def _build_retryer(self) -> Retrying:
        """Return the retry policy shared by every API listing call.

        Any ``Exception`` triggers a retry with exponential back-off between
        1 and 8 seconds; after ``retry_attempts`` attempts the last error is
        re-raised unchanged.
        """
        return Retrying(
            stop=stop_after_attempt(self._retry_attempts),
            wait=wait_exponential(multiplier=1, min=1, max=8),
            retry=retry_if_exception_type(Exception),
            reraise=True,
        )

    def _list_records(self, **filters: Any) -> list[Record]:
        """Call ``sdk.records.list(**filters)`` under the retry policy."""
        retryer = self._build_retryer()
        return cast(list[Record], retryer(self._sdk.records.list, **filters))

    def _list_records_with_filter_override(
        self, *, study_key: str, filter_string: str
    ) -> list[Record]:
        """List records using an explicit raw ``filter`` query parameter.

        This bypasses automatic filter construction so incremental sync can
        send only the timestamp predicate (without ``studyKey``).

        Raises:
            TypeError: If the records endpoint lacks ``_list_sync``,
                ``_require_sync_client`` or ``PAGINATOR_CLS``.
        """
        endpoint = self._sdk.records
        list_sync = getattr(endpoint, "_list_sync", None)
        require_sync_client = getattr(endpoint, "_require_sync_client", None)
        paginator_cls = getattr(endpoint, "PAGINATOR_CLS", None)
        if not callable(list_sync) or not callable(require_sync_client) or paginator_cls is None:
            raise TypeError(
                "Records endpoint does not support raw filter overrides: "
                "_list_sync, _require_sync_client, and PAGINATOR_CLS are required"
            )
        retryer = self._build_retryer()
        return cast(
            list[Record],
            retryer(
                # _list_sync signature:
                # (client, paginator_cls, *, study_key, extra_params, **filters)
                list_sync,
                require_sync_client(),
                paginator_cls,
                study_key=study_key,
                extra_params={"filter": filter_string},
                record_data_filter=None,
            ),
        )

    def _upsert_records(self, conn: sqlite3.Connection, records: Iterable[Record]) -> None:
        """Insert ``records`` into the cache, overwriting rows with the same key.

        All rows are written in one transaction; nothing is executed when
        ``records`` is empty.
        """
        payloads = [
            (
                record.study_key,
                record.record_id,
                record.form_key,
                record.date_modified.isoformat(),
                json.dumps(record.model_dump(mode="json", by_alias=True), sort_keys=True),
            )
            for record in records
        ]
        if not payloads:
            return
        with conn:
            conn.executemany(
                """
                INSERT INTO record_cache (study_key, record_id, form_key, date_modified, payload)
                VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(study_key, record_id) DO UPDATE SET
                    form_key = excluded.form_key,
                    date_modified = excluded.date_modified,
                    payload = excluded.payload
                """,
                payloads,
            )