# Source code for lcmd_db.client

"""Core dataset loading implementation."""

from __future__ import annotations

import json
import logging
from collections.abc import Sequence
from pathlib import Path
from urllib.parse import urlencode

import polars as pl

from ._cache import DownloadCache
from .exceptions import DatasetNotFoundError, DownloadError
from .settings import settings
from .types import (
    DataFormat,
    Include,
    SubsetData,
    SubsetMetadata,
)

logger = logging.getLogger(__name__)


def _load_dataset_impl(
    subset: str,
    *,
    include: Sequence[Include] = ("molecules", "reactions", "fragments"),
    data_format: DataFormat | str = DataFormat.PARQUET,
    include_structures: bool = False,
    molecule_properties: Sequence[str] | None = None,
    reaction_properties: Sequence[str] | None = None,
    fragment_properties: Sequence[str] | None = None,
    cache_dir: str | Path | None = None,
    force_download: bool = False,
) -> SubsetData:
    """Fetch a subset (from cache when possible) and assemble its tables.

    Args:
        subset: Subset slug to download.
        include: Entity types to include in the download.
        data_format: Serialisation format for entity data files; a plain
            string is converted with ``DataFormat(...)``.
        include_structures: Whether to download XYZ structure files.
        molecule_properties: Restrict molecule columns to these properties.
        reaction_properties: Restrict reaction columns to these properties.
        fragment_properties: Restrict fragment columns to these properties.
        cache_dir: Override the default cache directory.
        force_download: Re-download even if cached data exists.

    Returns:
        The ``SubsetData`` built by ``SubsetData._from_download``. Entity
        tables that were not requested are ``None``.
    """
    # Normalise the format first so every later step sees a DataFormat.
    if isinstance(data_format, str):
        resolved_format = DataFormat(data_format)
    else:
        resolved_format = data_format
    wanted = _build_include_set(include, include_structures)

    download_url = _build_download_url(
        subset,
        wanted,
        resolved_format,
        molecule_properties,
        reaction_properties,
        fragment_properties,
    )

    cache_root = None if cache_dir is None else Path(cache_dir)
    cache = DownloadCache(cache_root)
    # The key is derived from the same request parameters as the URL.
    cache_key = DownloadCache.compute_cache_key(
        subset,
        sorted(wanted),
        resolved_format.value,
        molecule_properties,
        reaction_properties,
        fragment_properties,
    )

    data_dir = _download(cache, download_url, subset, cache_key, force_download)

    def _table(entity: str) -> pl.DataFrame | None:
        # Only touch the filesystem for entity types that were requested.
        if entity not in wanted:
            return None
        return _read_entity(data_dir, entity, resolved_format)

    molecules = _table("molecules")
    reactions = _table("reactions")
    fragments = _table("fragments")
    metadata = _read_metadata(data_dir)

    result = SubsetData._from_download(
        data_dir=data_dir,
        data_format=resolved_format,
        molecules=molecules,
        reactions=reactions,
        fragments=fragments,
        metadata=metadata,
    )

    # Structure paths are attached to the molecules table, so both must be
    # present for this step to apply.
    if result.molecules is not None and "structures" in wanted:
        result.molecules = _attach_structure_paths(
            result.molecules, data_dir / "structures"
        )

    _trigger_stub_sync(cache)

    return result


def clear_cache(
    subset: str | None = None,
    cache_dir: str | Path | None = None,
) -> None:
    """Remove cached data files.

    Args:
        subset: If given, only clear the cache for this subset. Otherwise
            clear all.
        cache_dir: Override the default cache directory.
    """
    # ``None`` is passed through so DownloadCache can pick its own default.
    cache_root = Path(cache_dir) if cache_dir is not None else None
    DownloadCache(cache_root).clear(subset)
# Entity types that may be requested from the download endpoint.
_VALID_INCLUDE: frozenset[str] = frozenset(
    {"molecules", "reactions", "fragments", "structures"}
)


def _build_include_set(include: Sequence[str], include_structures: bool) -> set[str]:
    """Validate the requested entity types and return them as a set.

    Args:
        include: Entity types requested by the caller. A single non-empty
            string is treated as one entity type instead of being split
            into its characters.
        include_structures: When true, ``"structures"`` and ``"molecules"``
            are added to the result.

    Returns:
        The set of entity types to request.

    Raises:
        ValueError: If any requested value is not in ``_VALID_INCLUDE``.
    """
    # ``str`` is itself a ``Sequence[str]``: without this guard
    # ``include="molecules"`` would be exploded into single characters and
    # rejected with a confusing error message. An empty string keeps its
    # previous meaning (nothing requested).
    if isinstance(include, str) and include:
        requested = {include}
    else:
        requested = set(include)

    invalid = requested - _VALID_INCLUDE
    if invalid:
        raise ValueError(
            f"Invalid include values: {sorted(invalid)}. "
            f"Valid options: {sorted(_VALID_INCLUDE)}"
        )

    if include_structures:
        requested |= {"structures", "molecules"}
    return requested


def _build_download_url(
    subset: str,
    include_set: set[str],
    fmt: DataFormat,
    molecule_properties: Sequence[str] | None,
    reaction_properties: Sequence[str] | None,
    fragment_properties: Sequence[str] | None,
) -> str:
    """Build the download URL for *subset* under ``settings.base_url``.

    The query string always carries ``include`` (sorted, comma-joined) and
    ``data_format``; each ``*_properties`` argument that is not ``None`` is
    added as a comma-joined value.
    """
    params: dict[str, str] = {
        "include": ",".join(sorted(include_set)),
        "data_format": fmt.value,
    }
    for key, props in (
        ("molecule_properties", molecule_properties),
        ("reaction_properties", reaction_properties),
        ("fragment_properties", fragment_properties),
    ):
        if props is None:
            continue
        # A bare string would otherwise be joined character by character
        # ("energy" -> "e,n,e,r,g,y"); pass it through unchanged instead.
        params[key] = props if isinstance(props, str) else ",".join(props)
    return f"{settings.base_url}/api/v1/subsets/{subset}/download/?{urlencode(params)}"


def _download(
    cache: DownloadCache,
    url: str,
    subset: str,
    cache_key: str,
    force: bool,
) -> Path:
    """Fetch the subset through *cache* and translate failures.

    Returns:
        Whatever ``cache.get_or_download`` returns (annotated as a path).

    Raises:
        DatasetNotFoundError: On ``FileNotFoundError`` or when the failure
            looks like an HTTP 404 (see ``_is_not_found``).
        DownloadError: On any other exception.
    """
    try:
        return cache.get_or_download(url, subset, cache_key, force=force)
    except FileNotFoundError as exc:
        raise DatasetNotFoundError(str(exc)) from exc
    except Exception as exc:
        # Boundary handler: every remaining failure is mapped onto the
        # package's own exception types, keeping the cause chained.
        if _is_not_found(exc):
            raise DatasetNotFoundError(f"Subset '{subset}' not found") from exc
        raise DownloadError(f"Failed to download dataset: {exc}") from exc


def _is_not_found(exc: BaseException) -> bool:
    """Check if the exception represents an HTTP 404."""
    # Exceptions exposing ``.response.status_code``.
    response = getattr(exc, "response", None)
    if response is not None and getattr(response, "status_code", None) == 404:
        return True
    # Exceptions exposing the status directly as ``.code``.
    return getattr(exc, "code", None) == 404


# Maps each supported data format to the polars function that reads it.
_READERS = {
    DataFormat.PARQUET: pl.read_parquet,
    DataFormat.CSV: pl.read_csv,
    DataFormat.TSV: lambda p: pl.read_csv(p, separator="\t"),
    DataFormat.XLSX: pl.read_excel,
    DataFormat.JSON: pl.read_json,
}


def _read_entity(data_dir: Path, entity: str, fmt: DataFormat) -> pl.DataFrame | None:
    """Read ``<entity>.<fmt.value>`` from *data_dir*, or return ``None`` if absent."""
    data_file = data_dir / f"{entity}.{fmt.value}"
    if not data_file.exists():
        return None
    return _READERS[fmt](data_file)


def _attach_structure_paths(df: pl.DataFrame, structures_dir: Path) -> pl.DataFrame:
    """Add a ``structure_path`` column pointing at ``<structures_dir>/<id>.xyz``.

    The frame is returned unchanged when *structures_dir* does not exist.
    Individual files are not checked; the path is derived purely from the
    ``id`` column.
    """
    if not structures_dir.exists():
        return df
    return df.with_columns(
        pl.col("id")
        .map_elements(
            lambda mid: str(structures_dir / f"{mid}.xyz"), return_dtype=pl.Utf8
        )
        .alias("structure_path")
    )


def _read_metadata(data_dir: Path) -> SubsetMetadata | None:
    """Load ``metadata.json`` from *data_dir*, or return ``None`` if it is missing."""
    meta_file = data_dir / "metadata.json"
    if not meta_file.exists():
        return None
    return SubsetMetadata.model_validate(json.loads(meta_file.read_bytes()))


def _trigger_stub_sync(cache: DownloadCache) -> None:
    """Run ``sync_stubs(cache)`` when ``settings.auto_sync_stubs`` is enabled."""
    if not settings.auto_sync_stubs:
        return
    # Imported lazily so ``._sync`` is only loaded when the feature is on.
    from ._sync import sync_stubs

    sync_stubs(cache)