# Source code for lcmd_db.dataset._base
"""Abstract Dataset base class with lazy loading and rich indexing."""
from __future__ import annotations
import random
from abc import abstractmethod
from collections.abc import Iterator
from typing import TYPE_CHECKING, Generic, TypeVar, overload
import polars as pl
if TYPE_CHECKING:
import ase
import pandas as pd
from ..types import EntityMetadata, PropertyInfo
from ._source import DataSource, EagerSource, LazySource
# Entry type of a ``Dataset``: what ``_resolve_entry`` builds from one row
# and what integer indexing / iteration hand back to the caller.
E = TypeVar("E")
# Keys dropped by default when extracting properties from a row.
_EXCLUDE_KEYS: frozenset[str] = frozenset({"id"})


def _extract_properties(
    row: dict[str, object], exclude: frozenset[str] = _EXCLUDE_KEYS
) -> dict[str, object]:
    """Return a new dict holding every item of *row* whose key is not excluded.

    Args:
        row: Mapping of key to value for a single row.
        exclude: Keys to leave out; defaults to ``_EXCLUDE_KEYS``.

    Returns:
        A fresh dict; *row* itself is not modified and key order is kept.
    """
    properties: dict[str, object] = {}
    for key, value in row.items():
        if key in exclude:
            continue
        properties[key] = value
    return properties
[docs]
class Dataset(Generic[E]):  # noqa: UP046 — PEP 695 syntax requires 3.12+, we support 3.10
    """Generic lazy dataset backed by a tabular data source.

    Provides lazy loading, integer/slice indexing, polars-based filtering,
    column selection, train/test splitting, and export to polars, pandas,
    or ASE formats.

    Subclasses supply ``_resolve_entry`` (named row -> entry of type ``E``)
    and ``_with_source`` (build a dataset of the same kind over a new source).
    """

    # NOTE(review): the class does not use ``abc.ABCMeta`` (it only derives
    # from ``Generic``), so the ``@abstractmethod`` markers below are not
    # enforced at instantiation time. Left as-is to avoid changing the
    # metaclass for existing subclasses — confirm whether ``ABC`` is wanted.

    def __init__(
        self, source: DataSource, *, metadata: EntityMetadata | None = None
    ) -> None:
        """Wrap *source* without collecting it.

        Args:
            source: Tabular data source the dataset reads from.
            metadata: Optional entity metadata; read by ``properties``.
        """
        self._source = source
        self._metadata = metadata
        # Cache for the materialised frame; filled on first ``df`` access.
        self._df: pl.DataFrame | None = None

    def __len__(self) -> int:
        """Return the number of rows."""
        if self._df is not None:
            return self._df.height
        # Not materialised yet: collect only a row count from the lazy frame.
        return self._source.lazy.select(pl.len()).collect().item()

    @overload
    def __getitem__(self, idx: int) -> E: ...

    @overload
    def __getitem__(self, idx: slice | list[int]) -> Dataset[E]: ...

    def __getitem__(self, idx: int | slice | list[int]) -> E | Dataset[E]:
        """Return one entry for an ``int``, or a new dataset for a slice/list.

        Accessing an item materialises the frame (see ``df``).
        """
        if isinstance(idx, int):
            return self._resolve_entry(self.df.row(idx, named=True))
        return self._with_source(EagerSource(self.df[idx]))

    def __iter__(self) -> Iterator[E]:
        """Yield one resolved entry per row of the materialised frame."""
        return (self._resolve_entry(row) for row in self.df.iter_rows(named=True))

    def __repr__(self) -> str:
        # ``len(self)`` may collect a row count from the lazy source.
        return f"{type(self).__name__}(columns={self._source.columns}, len={len(self)})"

    @abstractmethod
    def _resolve_entry(self, row: dict[str, object]) -> E:
        """Build an entry of type ``E`` from one named row."""
        ...

    @abstractmethod
    def _with_source(self, source: DataSource) -> Dataset[E]:
        """Return a dataset over *source*; used by every derived-view method."""
        ...

    @property
    def df(self) -> pl.DataFrame:
        """Materialised polars DataFrame (collected on first access)."""
        if self._df is None:
            self._df = self._source.collect()
        return self._df

    @property
    def lazy(self) -> pl.LazyFrame:
        """Underlying polars LazyFrame for deferred computation."""
        return self._source.lazy

    @property
    def columns(self) -> list[str]:
        """Column names available in the dataset."""
        return self._source.columns

    @property
    def properties(self) -> list[PropertyInfo] | None:
        """Property metadata from the API, if available."""
        return self._metadata.properties if self._metadata else None

    def select(self, *columns: str) -> Dataset[E]:
        """Return a new dataset restricted to the given columns (``id`` is always kept).

        The selected columns are de-duplicated and emitted in sorted name order.

        Args:
            *columns: Column names to keep.
        """
        new_lf = self._source.lazy.select(sorted({"id", *columns}))
        return self._with_source(LazySource(new_lf))

    def filter(self, expr: pl.Expr) -> Dataset[E]:
        """Return a new dataset containing only rows matching the expression.

        Args:
            expr: A polars expression, e.g. ``pl.col("weight") > 100``.
        """
        return self._with_source(LazySource(self._source.lazy.filter(expr)))

    def train_test_split(
        self, test_size: float = 0.2, *, seed: int = 42
    ) -> tuple[Dataset[E], Dataset[E]]:
        """Split into train and test datasets by random shuffling.

        Args:
            test_size: Fraction of data to use for the test set; must lie
                in ``[0, 1]``.
            seed: Random seed for reproducibility.

        Returns:
            A ``(train, test)`` pair. Row indices are shuffled to decide
            membership and then sorted again before each subset is selected.

        Raises:
            ValueError: If ``test_size`` is outside ``[0, 1]`` (including NaN).
        """
        from fractions import Fraction

        # Out-of-range values used to produce a negative or oversized split
        # index and therefore a silently meaningless partition.
        if not 0 <= test_size <= 1:
            raise ValueError(
                f"test_size must be between 0 and 1, got {test_size!r}"
            )

        n = len(self)
        indices = list(range(n))
        rng = random.Random(seed)
        rng.shuffle(indices)

        # Use exact rational arithmetic on the float's shortest repr. In
        # binary floats ``1 - 0.8`` is 0.19999999999999996, so the naive
        # ``int(5 * (1 - 0.8))`` truncates to 0 train rows instead of 1.
        train_fraction = 1 - Fraction(str(float(test_size)))
        split = int(n * train_fraction)

        train_idx, test_idx = sorted(indices[:split]), sorted(indices[split:])
        return self[train_idx], self[test_idx]

    def to_polars(self) -> pl.DataFrame:
        """Return the dataset as a polars DataFrame."""
        return self.df

    def to_pandas(self) -> pd.DataFrame:
        """Return the dataset as a pandas DataFrame."""
        return self.df.to_pandas()

    def to_ase(self) -> list[ase.Atoms]:
        """Convert each molecule entry to an ``ase.Atoms`` object.

        Requires ``ase`` to be installed and structure files to be available.
        """
        # Imported lazily so the adapter is only loaded when this is called.
        from ..adapters.ase import convert

        return convert(self)