# Source code for lcmd_db.dataset._base

"""Abstract Dataset base class with lazy loading and rich indexing."""

from __future__ import annotations

import random
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import TYPE_CHECKING, Generic, TypeVar, overload

import polars as pl

if TYPE_CHECKING:
    import ase
    import pandas as pd

from ..types import EntityMetadata, PropertyInfo
from ._source import DataSource, EagerSource, LazySource

E = TypeVar("E")

_EXCLUDE_KEYS: frozenset[str] = frozenset(("id",))


def _extract_properties(
    row: dict[str, object], exclude: frozenset[str] = _EXCLUDE_KEYS
) -> dict[str, object]:
    """Return a new dict holding every entry of ``row`` whose key is not in ``exclude``.

    The input is not modified; the surviving keys keep their original order.
    """
    kept: dict[str, object] = {}
    for key, value in row.items():
        if key in exclude:
            continue
        kept[key] = value
    return kept


class Dataset(ABC, Generic[E]):  # noqa: UP046 — PEP 695 syntax requires 3.12+, we support 3.10
    """Generic lazy dataset backed by a tabular data source.

    Provides lazy loading, integer/slice indexing, polars-based filtering,
    column selection, train/test splitting, and export to polars, pandas,
    or ASE formats.

    Subclasses must implement :meth:`_resolve_entry` (turn one named row into
    an entry of type ``E``) and :meth:`_with_source` (wrap a derived source in
    a new dataset). Deriving from :class:`abc.ABC` makes these
    ``@abstractmethod`` declarations enforced at instantiation time.
    """

    def __init__(
        self, source: DataSource, *, metadata: EntityMetadata | None = None
    ) -> None:
        """Create a dataset over ``source``.

        Args:
            source: Tabular data source; this class uses its ``lazy``,
                ``columns`` and ``collect()`` members.
            metadata: Optional entity metadata, exposed via :attr:`properties`.
        """
        self._source = source
        self._metadata = metadata
        # Cached materialised frame; populated on first access of ``df``.
        self._df: pl.DataFrame | None = None

    def __len__(self) -> int:
        """Return the number of rows.

        Uses the cached frame if it is already materialised; otherwise runs a
        row-count query on the lazy frame without caching the data.
        """
        if self._df is not None:
            return self._df.height
        return self._source.lazy.select(pl.len()).collect().item()

    @overload
    def __getitem__(self, idx: int) -> E: ...

    @overload
    def __getitem__(self, idx: slice | list[int]) -> Dataset[E]: ...

    def __getitem__(self, idx: int | slice | list[int]) -> E | Dataset[E]:
        """Return one resolved entry for an ``int``, or a sub-dataset otherwise.

        Both forms materialise the full frame through :attr:`df`; sub-datasets
        are backed by an in-memory ``EagerSource``.
        """
        if isinstance(idx, int):
            return self._resolve_entry(self.df.row(idx, named=True))
        return self._with_source(EagerSource(self.df[idx]))

    def __iter__(self) -> Iterator[E]:
        """Yield one resolved entry per row of the materialised frame."""
        return (self._resolve_entry(row) for row in self.df.iter_rows(named=True))

    def __repr__(self) -> str:
        # NOTE: len(self) may run a row-count query on the lazy source.
        name = type(self).__name__
        return f"{name}(columns={self._source.columns}, len={len(self)})"

    @abstractmethod
    def _resolve_entry(self, row: dict[str, object]) -> E:
        """Convert one named row (column name -> value) into an entry."""

    @abstractmethod
    def _with_source(self, source: DataSource) -> Dataset[E]:
        """Return a new dataset of the concrete subclass backed by ``source``."""

    @property
    def df(self) -> pl.DataFrame:
        """Materialised polars DataFrame (collected on first access, then cached)."""
        if self._df is None:
            self._df = self._source.collect()
        return self._df

    @property
    def lazy(self) -> pl.LazyFrame:
        """Underlying polars LazyFrame for deferred computation."""
        return self._source.lazy

    @property
    def columns(self) -> list[str]:
        """Column names available in the dataset."""
        return self._source.columns

    @property
    def properties(self) -> list[PropertyInfo] | None:
        """Property metadata from the API, or ``None`` when no metadata was given."""
        return self._metadata.properties if self._metadata is not None else None

    def select(self, *columns: str) -> Dataset[E]:
        """Return a new dataset restricted to the given columns (``id`` is always kept).

        Duplicate names are collapsed, and the selected columns (including
        ``id``) are ordered alphabetically rather than in argument order.

        Args:
            *columns: Column names to keep.
        """
        new_lf = self._source.lazy.select(sorted({"id", *columns}))
        return self._with_source(LazySource(new_lf))

    def filter(self, expr: pl.Expr) -> Dataset[E]:
        """Return a new dataset containing only rows matching the expression.

        Args:
            expr: A polars expression, e.g. ``pl.col("weight") > 100``.
        """
        return self._with_source(LazySource(self._source.lazy.filter(expr)))

    def train_test_split(
        self, test_size: float = 0.2, *, seed: int = 42
    ) -> tuple[Dataset[E], Dataset[E]]:
        """Split into train and test datasets by random shuffling.

        Within each split, rows keep their original relative order.

        Args:
            test_size: Fraction of data to use for the test set, in ``[0, 1]``.
            seed: Random seed for reproducibility.

        Returns:
            A ``(train, test)`` tuple of datasets.

        Raises:
            ValueError: If ``test_size`` is not within ``[0, 1]`` (this also
                rejects NaN).
        """
        if not 0.0 <= test_size <= 1.0:
            raise ValueError(f"test_size must be in [0, 1], got {test_size!r}")
        n = len(self)
        indices = list(range(n))
        rng = random.Random(seed)
        rng.shuffle(indices)
        # Number of training rows; the remaining shuffled indices form the test set.
        split = int(n * (1 - test_size))
        train_idx, test_idx = sorted(indices[:split]), sorted(indices[split:])
        return self[train_idx], self[test_idx]

    def to_polars(self) -> pl.DataFrame:
        """Return the dataset as a polars DataFrame (the cached :attr:`df`)."""
        return self.df

    def to_pandas(self) -> pd.DataFrame:
        """Return the dataset as a pandas DataFrame."""
        return self.df.to_pandas()

    def to_ase(self) -> list[ase.Atoms]:
        """Convert each molecule entry to an ``ase.Atoms`` object.

        Requires ``ase`` to be installed and structure files to be available.
        """
        # Local import keeps the optional ASE adapter out of module import time.
        from ..adapters.ase import convert

        return convert(self)