"""Core dataset loading implementation."""
from __future__ import annotations
import json
import logging
from collections.abc import Sequence
from pathlib import Path
from urllib.parse import urlencode
import polars as pl
from ._cache import DownloadCache
from .exceptions import DatasetNotFoundError, DownloadError
from .settings import settings
from .types import (
DataFormat,
Include,
SubsetData,
SubsetMetadata,
)
logger = logging.getLogger(__name__)


def _load_dataset_impl(
    subset: str,
    *,
    include: Sequence[Include] = ("molecules", "reactions", "fragments"),
    data_format: DataFormat | str = DataFormat.PARQUET,
    include_structures: bool = False,
    molecule_properties: Sequence[str] | None = None,
    reaction_properties: Sequence[str] | None = None,
    fragment_properties: Sequence[str] | None = None,
    cache_dir: str | Path | None = None,
    force_download: bool = False,
) -> SubsetData:
    """Fetch a subset (from the local cache when possible) and load it.

    Args:
        subset: Slug of the subset to fetch.
        include: Entity types requested from the server.
        data_format: File format used for the entity tables; plain strings are
            converted with ``DataFormat(...)``.
        include_structures: Also request the XYZ structure files (this implies
            molecules) and add a ``structure_path`` column to the molecules.
        molecule_properties: Optional column restriction for molecules.
        reaction_properties: Optional column restriction for reactions.
        fragment_properties: Optional column restriction for fragments.
        cache_dir: Alternative cache directory; ``None`` uses the default.
        force_download: Ignore any cached copy and download again.

    Returns:
        A ``SubsetData`` built from the downloaded directory. Entity tables that
        were not requested (or whose file is missing) are ``None``.
    """
    fmt = data_format if not isinstance(data_format, str) else DataFormat(data_format)
    wanted = _build_include_set(include, include_structures)

    download_url = _build_download_url(
        subset,
        wanted,
        fmt,
        molecule_properties,
        reaction_properties,
        fragment_properties,
    )
    cache_root = None if cache_dir is None else Path(cache_dir)
    cache = DownloadCache(cache_root)
    key = DownloadCache.compute_cache_key(
        subset,
        sorted(wanted),
        fmt.value,
        molecule_properties,
        reaction_properties,
        fragment_properties,
    )
    data_dir = _download(cache, download_url, subset, key, force_download)

    # Tables are read in the same order as before: molecules, reactions,
    # fragments, then metadata.
    frames = {}
    for entity in ("molecules", "reactions", "fragments"):
        if entity in wanted:
            frames[entity] = _read_entity(data_dir, entity, fmt)
        else:
            frames[entity] = None
    metadata = _read_metadata(data_dir)

    result = SubsetData._from_download(
        data_dir=data_dir,
        data_format=fmt,
        **frames,
        metadata=metadata,
    )

    wants_structures = "structures" in wanted
    if wants_structures and result.molecules is not None:
        structures_dir = data_dir / "structures"
        result.molecules = _attach_structure_paths(result.molecules, structures_dir)

    _trigger_stub_sync(cache)
    return result
def clear_cache(
    subset: str | None = None,
    cache_dir: str | Path | None = None,
) -> None:
    """Remove cached data files.

    Args:
        subset: If given, only clear the cache for this subset. Otherwise clear all.
        cache_dir: Override the default cache directory. ``None`` lets
            ``DownloadCache`` use its default location.
    """
    # Normalise str paths to Path before handing them to the cache.
    root = Path(cache_dir) if cache_dir is not None else None
    DownloadCache(root).clear(subset)
_VALID_INCLUDE: frozenset[str] = frozenset(
{"molecules", "reactions", "fragments", "structures"}
)
def _build_include_set(include: Sequence[str], include_structures: bool) -> set[str]:
invalid = set(include) - _VALID_INCLUDE
if invalid:
raise ValueError(
f"Invalid include values: {sorted(invalid)}. "
f"Valid options: {sorted(_VALID_INCLUDE)}"
)
result = set(include)
if include_structures:
result |= {"structures", "molecules"}
return result
def _build_download_url(
    subset: str,
    include_set: set[str],
    fmt: DataFormat,
    molecule_properties: Sequence[str] | None,
    reaction_properties: Sequence[str] | None,
    fragment_properties: Sequence[str] | None,
) -> str:
    """Build the server download URL for *subset*.

    Args:
        subset: Subset slug. It is percent-encoded as a single path segment.
        include_set: Entity names, sent sorted and comma-joined as ``include``.
        fmt: Requested data format, sent as ``data_format``.
        molecule_properties: If not ``None``, sent comma-joined as
            ``molecule_properties``.
        reaction_properties: Same, for ``reaction_properties``.
        fragment_properties: Same, for ``fragment_properties``.

    Returns:
        ``{base_url}/api/v1/subsets/{subset}/download/?{query}``.
    """
    from urllib.parse import quote  # only needed here

    params: dict[str, str] = {
        "include": ",".join(sorted(include_set)),
        "data_format": fmt.value,
    }
    for key, props in (
        ("molecule_properties", molecule_properties),
        ("reaction_properties", reaction_properties),
        ("fragment_properties", fragment_properties),
    ):
        if props is not None:
            params[key] = ",".join(props)

    # Percent-encode the slug so characters such as "/", "?" or "#" cannot
    # change the path or leak into the query string.
    slug = quote(subset, safe="")
    # Strip a trailing slash from the base URL to avoid "//api" in the result.
    base = str(settings.base_url).rstrip("/")
    return f"{base}/api/v1/subsets/{slug}/download/?{urlencode(params)}"
def _download(
    cache: DownloadCache,
    url: str,
    subset: str,
    cache_key: str,
    force: bool,
) -> Path:
    """Return the local data directory for *url*, downloading it if needed.

    Args:
        cache: Cache that performs the lookup/download.
        url: Download URL.
        subset: Subset slug, used in the not-found error message.
        cache_key: Key identifying this request in the cache.
        force: Re-download even if a cached copy exists.

    Returns:
        The path returned by ``cache.get_or_download``.

    Raises:
        DatasetNotFoundError: The cache raised ``FileNotFoundError``, or the
            failure looks like an HTTP 404 (see ``_is_not_found``).
        DownloadError: Any other failure.
    """
    try:
        return cache.get_or_download(url, subset, cache_key, force=force)
    except (DatasetNotFoundError, DownloadError):
        # Already a package-level error. Re-raise it unchanged instead of
        # letting the catch-all below wrap it again, which could also turn a
        # not-found into a generic DownloadError.
        raise
    except FileNotFoundError as exc:
        raise DatasetNotFoundError(str(exc)) from exc
    except Exception as exc:
        if _is_not_found(exc):
            raise DatasetNotFoundError(f"Subset '{subset}' not found") from exc
        raise DownloadError(f"Failed to download dataset: {exc}") from exc
def _is_not_found(exc: BaseException) -> bool:
"""Check if the exception represents an HTTP 404."""
response = getattr(exc, "response", None)
if response is not None and getattr(response, "status_code", None) == 404:
return True
code = getattr(exc, "code", None)
if code == 404:
return True
return False
def _read_tsv(path: Path) -> pl.DataFrame:
    """Read a tab-separated file with ``pl.read_csv``."""
    return pl.read_csv(path, separator="\t")


# Maps each DataFormat to the polars function that reads a file in that format.
_READERS = {
    DataFormat.PARQUET: pl.read_parquet,
    DataFormat.CSV: pl.read_csv,
    DataFormat.TSV: _read_tsv,
    DataFormat.XLSX: pl.read_excel,
    DataFormat.JSON: pl.read_json,
}


def _read_entity(data_dir: Path, entity: str, fmt: DataFormat) -> pl.DataFrame | None:
    """Load ``<data_dir>/<entity>.<fmt.value>`` into a DataFrame.

    Returns:
        The DataFrame, or ``None`` when the file does not exist.
    """
    path = data_dir / f"{entity}.{fmt.value}"
    if not path.exists():
        return None
    reader = _READERS[fmt]
    return reader(path)
def _attach_structure_paths(df: pl.DataFrame, structures_dir: Path) -> pl.DataFrame:
    """Add a ``structure_path`` column pointing at each row's XYZ file.

    The path for each row is ``<structures_dir>/<id>.xyz``. Rows with a null
    ``id`` get a null path. The individual files are not checked for
    existence.

    Args:
        df: Table with an ``id`` column.
        structures_dir: Directory holding the ``.xyz`` files.

    Returns:
        *df* unchanged if *structures_dir* does not exist, otherwise a new frame
        with an extra Utf8 ``structure_path`` column.
    """
    if not structures_dir.exists():
        return df
    import os  # only needed for the platform path separator

    # Build the path with native polars expressions instead of a per-row
    # Python callback (map_elements), which is much slower on large tables.
    # concat_str yields null when any input is null, matching map_elements'
    # default of skipping nulls.
    prefix = f"{structures_dir}{os.sep}"
    return df.with_columns(
        pl.concat_str(
            [pl.lit(prefix), pl.col("id").cast(pl.Utf8), pl.lit(".xyz")]
        ).alias("structure_path")
    )
def _read_metadata(data_dir: Path) -> SubsetMetadata | None:
meta_file = data_dir / "metadata.json"
if not meta_file.exists():
return None
return SubsetMetadata.model_validate(json.loads(meta_file.read_bytes()))
def _trigger_stub_sync(cache: DownloadCache) -> None:
    """Call ``sync_stubs(cache)`` when ``settings.auto_sync_stubs`` is enabled.

    The ``._sync`` module is imported only when it is actually needed.
    """
    if settings.auto_sync_stubs:
        from ._sync import sync_stubs

        sync_stubs(cache)