"""ReactionDataset — one row per Reaction, with lazy participant joins."""
from __future__ import annotations
from pathlib import Path
from typing import Generic
from ..types import (
EntityMetadata,
Molecule,
Participant,
ParticipantRole,
Properties,
Reaction,
)
from ._base import Dataset, _extract_properties
from ._molecules import resolve_structure
from ._source import DataSource
[docs]
class ReactionDataset(Dataset[Reaction[Properties]], Generic[Properties]):  # noqa: UP046
    """Dataset producing one :class:`Reaction` per row of the primary source.

    Participants are joined lazily: the optional participants and molecules
    sources are only collected and indexed the first time an entry is
    resolved, and the resulting lookup tables are reused afterwards.
    """

    def __init__(
        self,
        source: DataSource,
        *,
        participants_source: DataSource | None = None,
        molecules_source: DataSource | None = None,
        structures_dir: Path | None = None,
        metadata: EntityMetadata | None = None,
    ) -> None:
        """Store the sources; nothing is collected at construction time.

        Args:
            source: Primary source, forwarded to the base ``Dataset``.
            participants_source: Optional source of participant rows. The
                keys ``reaction_id``, ``molecule_id`` and ``role`` are read
                from each row, plus ``step_from``, ``step_to`` and ``label``
                when present.
            molecules_source: Optional source of molecule rows, indexed by
                their ``id`` key.
            structures_dir: Passed to ``resolve_structure`` together with
                each molecule id.
            metadata: Forwarded to the base ``Dataset``.
        """
        super().__init__(source, metadata=metadata)
        self._participants_source = participants_source
        self._molecules_source = molecules_source
        self._structures_dir = structures_dir
        # Lazily-built lookup tables; ``None`` means "not built yet" (or that
        # the corresponding source was never provided).
        self._participants_by_rxn: dict[int, list[dict[str, object]]] | None = None
        self._molecules_index: dict[int, dict[str, object]] | None = None

    def _ensure_relationships(self) -> None:
        """Build the participant and molecule lookup tables if still missing.

        Each table is built at most once per instance, and only when its
        source was provided. Subsequent calls are cheap ``None`` checks.
        """
        if self._participants_by_rxn is None and self._participants_source is not None:
            grouped: dict[int, list[dict[str, object]]] = {}
            # Sorting before grouping means every per-reaction list comes out
            # ordered by ``step_from``.
            # NOTE(review): ``collect()`` presumably returns a polars
            # DataFrame (``sort`` / ``iter_rows(named=True)``) -- confirm.
            for row in (
                self._participants_source.collect()
                .sort("step_from")
                .iter_rows(named=True)
            ):
                grouped.setdefault(int(row["reaction_id"]), []).append(row)  # type: ignore[arg-type]
            self._participants_by_rxn = grouped
        if self._molecules_index is None and self._molecules_source is not None:
            self._molecules_index = {
                int(row["id"]): row  # type: ignore[arg-type]
                for row in self._molecules_source.collect().iter_rows(named=True)
            }

    def _resolve_entry(self, row: dict[str, object]) -> Reaction[Properties]:
        """Turn one primary-source row into a ``Reaction``.

        Args:
            row: Mapping that must contain an ``id`` key convertible to int.

        Returns:
            A ``Reaction`` whose properties come from
            ``_extract_properties(row)`` and whose participants are looked up
            by the reaction id.
        """
        rxn_id = int(row["id"])  # type: ignore[arg-type]
        self._ensure_relationships()
        return Reaction(
            id=rxn_id,
            properties=_extract_properties(row),  # type: ignore[arg-type]
            participants=self._build_participants(rxn_id),
        )

    def _build_participants(self, rxn_id: int) -> list[Participant]:
        """Return the participants recorded for ``rxn_id``.

        An empty list is returned when either lookup table is unavailable
        (both the participants and the molecules index are required) or when
        the reaction has no participant rows.
        """
        if self._participants_by_rxn is None or self._molecules_index is None:
            return []
        return [
            self._row_to_participant(r)
            for r in self._participants_by_rxn.get(rxn_id, [])
        ]

    def _row_to_participant(self, r: dict[str, object]) -> Participant:
        """Convert one participant row into a ``Participant``.

        Callers must make sure the molecules index has been built (see
        ``_build_participants``). A molecule id missing from the index yields
        a ``Molecule`` whose properties are extracted from an empty mapping.
        """
        mol_id = int(r["molecule_id"])  # type: ignore[arg-type]
        mol_props = self._molecules_index.get(mol_id, {})  # type: ignore[union-attr]
        # The ``.get`` default only covers an absent key. When the key is
        # present but holds ``None``, ``str()`` would produce the literal text
        # "None", so map ``None`` to an empty label explicitly.
        raw_label = r.get("label")
        return Participant(
            molecule=Molecule(
                id=mol_id,
                properties=_extract_properties(mol_props),
                structure_path=resolve_structure(self._structures_dir, mol_id),
            ),
            role=ParticipantRole(str(r["role"])),
            step_from=r.get("step_from"),  # type: ignore[arg-type]
            step_to=r.get("step_to"),  # type: ignore[arg-type]
            label="" if raw_label is None else str(raw_label),
        )

    def _with_source(self, source: DataSource) -> ReactionDataset[Properties]:
        """Return a new dataset over ``source`` with the same configuration.

        The lookup tables held by this instance at call time (possibly still
        ``None``) are handed to the new dataset by reference, so tables that
        were already built are not rebuilt.
        """
        ds = ReactionDataset(
            source,
            participants_source=self._participants_source,
            molecules_source=self._molecules_source,
            structures_dir=self._structures_dir,
            metadata=self._metadata,
        )
        ds._participants_by_rxn = self._participants_by_rxn
        ds._molecules_index = self._molecules_index
        return ds