Source code for lcmd_db.dataset._reactions

"""ReactionDataset — one row per Reaction, with lazy participant joins."""

from __future__ import annotations

from pathlib import Path
from typing import Generic

from ..types import (
    EntityMetadata,
    Molecule,
    Participant,
    ParticipantRole,
    Properties,
    Reaction,
)
from ._base import Dataset, _extract_properties
from ._molecules import resolve_structure
from ._source import DataSource


class ReactionDataset(Dataset[Reaction[Properties]], Generic[Properties]):  # noqa: UP046
    """Dataset that resolves each row of the reactions source into a ``Reaction``.

    Participant and molecule tables are optional side sources. They are
    collected the first time an entry is resolved, indexed in memory, and the
    resulting indexes are handed on to datasets created by ``_with_source`` so
    they are not rebuilt for every derived view.
    """

    def __init__(
        self,
        source: DataSource,
        *,
        participants_source: DataSource | None = None,
        molecules_source: DataSource | None = None,
        structures_dir: Path | None = None,
        metadata: EntityMetadata | None = None,
    ) -> None:
        """Store the sources; no side table is read until an entry is resolved.

        Args:
            source: Reaction rows; each row must carry an ``id`` column.
            participants_source: Rows linking reactions to molecules
                (``reaction_id``, ``molecule_id``, ``role``, ``step_from`` and
                optionally ``step_to`` / ``label`` are read).
            molecules_source: Molecule rows keyed by ``id``.
            structures_dir: Directory passed to ``resolve_structure`` together
                with the molecule id.
            metadata: Forwarded unchanged to the base ``Dataset``.
        """
        super().__init__(source, metadata=metadata)
        self._participants_source = participants_source
        self._molecules_source = molecules_source
        self._structures_dir = structures_dir
        # Lazy caches: ``None`` means "not built yet" (or no source to build from).
        self._participants_by_rxn: dict[int, list[dict[str, object]]] | None = None
        self._molecules_index: dict[int, dict[str, object]] | None = None

    def _ensure_relationships(self) -> None:
        """Build the participant and molecule lookup tables if still missing.

        Each table is built at most once; a table whose source is ``None`` is
        left as ``None``.
        """
        if self._participants_by_rxn is None and self._participants_source is not None:
            grouped: dict[int, list[dict[str, object]]] = {}
            # Sorting before grouping keeps each reaction's participants in
            # ``step_from`` order.
            for row in (
                self._participants_source.collect()
                .sort("step_from")
                .iter_rows(named=True)
            ):
                grouped.setdefault(int(row["reaction_id"]), []).append(row)  # type: ignore[arg-type]
            self._participants_by_rxn = grouped

        if self._molecules_index is None and self._molecules_source is not None:
            self._molecules_index = {
                int(row["id"]): row  # type: ignore[arg-type]
                for row in self._molecules_source.collect().iter_rows(named=True)
            }

    def _resolve_entry(self, row: dict[str, object]) -> Reaction[Properties]:
        """Turn one reaction row into a ``Reaction`` with its participants attached."""
        rxn_id = int(row["id"])  # type: ignore[arg-type]
        self._ensure_relationships()
        return Reaction(
            id=rxn_id,
            properties=_extract_properties(row),  # type: ignore[arg-type]
            participants=self._build_participants(rxn_id),
        )

    def _build_participants(self, rxn_id: int) -> list[Participant]:
        """Return the participants of ``rxn_id``.

        The result is empty when either lookup table is unavailable, or when
        the reaction has no participant rows.
        """
        if self._participants_by_rxn is None or self._molecules_index is None:
            return []
        return [
            self._row_to_participant(r)
            for r in self._participants_by_rxn.get(rxn_id, [])
        ]

    def _row_to_participant(self, r: dict[str, object]) -> Participant:
        """Convert one participant row into a ``Participant``.

        A molecule id that is absent from the molecule index yields a molecule
        built from an empty property mapping.
        """
        mol_id = int(r["molecule_id"])  # type: ignore[arg-type]
        mol_props = self._molecules_index.get(mol_id, {})  # type: ignore[union-attr]
        # A null label arrives as ``None``; ``str(None)`` would produce the
        # literal text "None", so treat it exactly like a missing label.
        raw_label = r.get("label")
        label = "" if raw_label is None else str(raw_label)
        return Participant(
            molecule=Molecule(
                id=mol_id,
                properties=_extract_properties(mol_props),
                structure_path=resolve_structure(self._structures_dir, mol_id),
            ),
            role=ParticipantRole(str(r["role"])),
            step_from=r.get("step_from"),  # type: ignore[arg-type]
            step_to=r.get("step_to"),  # type: ignore[arg-type]
            label=label,
        )

    def _with_source(self, source: DataSource) -> ReactionDataset[Properties]:
        """Return a dataset over ``source`` that reuses this one's configuration.

        Already-built lookup tables are shared (not copied) with the new
        dataset so they are not recomputed.
        """
        ds = ReactionDataset(
            source,
            participants_source=self._participants_source,
            molecules_source=self._molecules_source,
            structures_dir=self._structures_dir,
            metadata=self._metadata,
        )
        ds._participants_by_rxn = self._participants_by_rxn
        ds._molecules_index = self._molecules_index
        return ds