Source code for xpersist.cache

import datetime
import enum
import json
import tempfile
import typing

import fsspec
import pydantic

from .registry import registry
from .serializers import pick_serializer


class Artifact(pydantic.BaseModel):
    """A pydantic model for representing an artifact in the cache."""

    key: str
    serializer: str
    load_kwargs: typing.Optional[typing.Dict] = pydantic.Field(default_factory=dict)
    dump_kwargs: typing.Optional[typing.Dict] = pydantic.Field(default_factory=dict)
    additional_metadata: typing.Optional[typing.Dict] = pydantic.Field(default_factory=dict)
    created_at: typing.Optional[datetime.datetime] = pydantic.Field(
        default_factory=datetime.datetime.utcnow
    )
    _value: typing.Any = pydantic.PrivateAttr(default=None)

    class Config:
        validate_assignment = True
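
# A hedged sketch (values are illustrative, not from a real store): only ``key``
# and ``serializer`` are required; ``created_at`` defaults to the UTC time of
# creation, and ``.json()`` renders the sidecar that gets written to the
# metadata store:
#
#     >>> artifact = Artifact(key='foo', serializer='xarray.zarr')
#     >>> print(artifact.json(indent=2))  # doctest: +SKIP
#     {
#       "key": "foo",
#       "serializer": "xarray.zarr",
#       ...
#     }
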
class DuplicateKeyEnum(str, enum.Enum):
    skip = 'skip'
    overwrite = 'overwrite'
    raise_error = 'raise_error'

@pydantic.dataclasses.dataclass
class CacheStore:
    """Implements caching functionality using fsspec backends (local, s3fs, gcsfs, etc.).

    Some backends may require additional dependencies. For example, working with
    an S3 cache store requires s3fs.

    Parameters
    ----------
    path : str
        The path to the cache store. This can be a local directory or a cloud
        storage bucket. By default, the path is set to the local temporary
        directory.
    storage_options : dict
        fsspec parameters passed to the backend file system, such as Google Cloud
        Storage or Amazon Web Services S3.
    readonly : bool
        If True, the cache store is read-only. If False, the cache store is
        writable.
    on_duplicate_key : DuplicateKeyEnum
        The behavior when a key is duplicated in the cache store. Valid options are:

        - 'skip' (default): do nothing
        - 'overwrite': overwrite the existing artifact
        - 'raise_error': raise an error if the key is already in the cache store
    """

    path: str = tempfile.gettempdir()
    readonly: bool = False
    on_duplicate_key: DuplicateKeyEnum = 'skip'
    storage_options: typing.Dict[typing.Any, typing.Any] = None

    def __post_init_post_parse__(self):
        self.storage_options = {} if self.storage_options is None else self.storage_options
        self.mapper = fsspec.get_mapper(self.path, **self.storage_options)
        self.raw_path = self.mapper.fs._strip_protocol(self.path)
        self.protocol = self.mapper.fs.protocol
        self._ensure_dir(self.raw_path)
        self._suffix = '.artifact.json'
        self._metadata_store_prefix = 'xpersist_metadata_store'
        self._metadata_store_path = self._construct_item_path(self._metadata_store_prefix)
        self._ensure_dir(self._metadata_store_path)

    def _ensure_dir(self, key: str) -> None:
        if not self.mapper.fs.exists(key):
            self.mapper.fs.makedirs(key, exist_ok=True)

    def _construct_item_path(self, key) -> str:
        return f'{self.path}/{key}'

    def _artifact_meta_relative_path(self, key: str) -> str:
        return f'{self._metadata_store_prefix}/{key}{self._suffix}'

    def _artifact_meta_full_path(self, key: str) -> str:
        return f'{self._metadata_store_path}/{key}{self._suffix}'

    def __contains__(self, key: str) -> bool:
        """Returns True if the key is in the cache store."""
        return self._artifact_meta_relative_path(key) in self.mapper
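
    # A minimal usage sketch; the paths and storage options below are hypothetical.
    # ``storage_options`` is forwarded verbatim to ``fsspec.get_mapper``, and the
    # ``in`` operator checks for the key's metadata sidecar:
    #
    #     local_store = CacheStore('/tmp/my-cache')
    #     s3_store = CacheStore('s3://my-bucket/my-cache',
    #                           storage_options={'anon': False})
    #     'foo' in local_store  # True once put('foo', ...) has run
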
    def keys(self) -> typing.List[str]:
        """Returns a list of keys in the cache store."""
        keys = self.mapper.fs.ls(self._metadata_store_path)
        return [
            key.split(f'{self._metadata_store_prefix}/')[-1].split(self._suffix)[0]
            for key in keys
        ]
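
    # Sketch of the path-to-key round trip that ``keys`` performs; the path is
    # hypothetical:
    #
    #     ('/tmp/my-cache/xpersist_metadata_store/foo.artifact.json'
    #      .split('xpersist_metadata_store/')[-1]
    #      .split('.artifact.json')[0])   # -> 'foo'
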
    def delete(self, key: str, dry_run: bool = True) -> None:
        """Deletes the key and corresponding artifact from the cache store.

        Parameters
        ----------
        key : str
            Key to delete from the cache store.
        dry_run : bool
            If True, the key is not deleted from the cache store. This is useful
            for debugging.
        """
        if key not in self:
            raise KeyError(f'Key `{key}` not found in cache store.')
        paths = [self._artifact_meta_full_path(key), self._construct_item_path(key)]
        if not dry_run:
            for path in paths:
                self.mapper.fs.delete(path, recursive=True)
        else:
            print('DRY RUN: would delete items with the following paths:\n')
            for path in paths:
                print(f'* {path}')
            print('\nTo delete these items, call `delete(key, dry_run=False)`')
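
    # Hedged usage sketch (the key is illustrative): ``delete`` defaults to a dry
    # run, which only lists the sidecar and data paths it would remove:
    #
    #     store.delete('foo')                 # prints the would-be deletions
    #     store.delete('foo', dry_run=False)  # actually removes both paths
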
    def __getitem__(self, key: str) -> typing.Any:
        """Returns the artifact corresponding to the key."""
        return self.get(key)

    def __setitem__(self, key: str, value: typing.Any) -> None:
        """Sets the key and corresponding artifact in the cache store."""
        self.put(key, value)

    def __delitem__(self, key: str) -> None:
        """Deletes the key and corresponding artifact from the cache store."""
        self.delete(key, dry_run=False)
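
    # The dunder methods above give the store a dict-like surface; a minimal
    # sketch, assuming a writable store:
    #
    #     store['foo'] = [1, 2, 3]  # put('foo', [1, 2, 3]) with the 'auto' serializer
    #     store['foo']              # get('foo') -> [1, 2, 3]
    #     del store['foo']          # delete('foo', dry_run=False)
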
    @pydantic.validate_arguments
    def get_artifact(self, key: str) -> Artifact:
        """Returns the artifact corresponding to the key.

        Parameters
        ----------
        key : str
            Key to get from the cache store.

        Returns
        -------
        artifact : Artifact
            The artifact corresponding to the key.

        Raises
        ------
        KeyError
            If the key is not in the cache store.
        """
        metadata_file = self._artifact_meta_relative_path(key)
        message = f'{key} not found in cache store: {self._metadata_store_path}'
        if key not in self:
            raise KeyError(message)
        try:
            return Artifact(**json.loads(self.mapper[metadata_file]))
        except Exception as exc:
            raise KeyError(
                f'Unable to load artifact sidecar file {metadata_file} for key: {key}'
            ) from exc
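
    # Illustrative sketch: ``get_artifact`` exposes the sidecar metadata without
    # loading the value itself (the field values here are hypothetical):
    #
    #     artifact = store.get_artifact('foo')
    #     artifact.serializer           # e.g. 'xarray.zarr'
    #     artifact.additional_metadata  # dict recorded at put() time
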
    @pydantic.validate_arguments
    def get(
        self,
        key: str,
        serializer: str = None,
        load_kwargs: typing.Dict[typing.Any, typing.Any] = None,
    ) -> typing.Any:
        """Returns the value for the key if the key is in the cache store.

        Parameters
        ----------
        key : str
            Key to get from the cache store.
        serializer : str
            The name of the serializer you want to use. The built-in serializers are:

            - 'auto' (default): automatically choose the serializer based on the
              type of the value
            - 'xarray.netcdf': requires xarray and netCDF4
            - 'xarray.zarr': requires xarray and zarr
            - 'pandas.csv': requires pandas
            - 'pandas.parquet': requires pandas and pyarrow or fastparquet

            You can also register your own serializer via the
            @xpersist.registry.serializers.register decorator.
        load_kwargs : dict
            Additional keyword arguments to pass to the serializer when loading an
            artifact from the cache store.

        Returns
        -------
        value : typing.Any
            The value for the key if the key is in the cache store.

        Examples
        --------
        >>> from xpersist import CacheStore
        >>> store = CacheStore("/tmp/my-cache")
        >>> store.keys()
        ['foo']
        >>> store.get("foo")
        [1, 2, 3]
        """
        artifact = self.get_artifact(key)
        try:
            serializer_name = serializer or artifact.serializer
            load_kwargs = load_kwargs or artifact.load_kwargs
            serializer = registry.serializers.get(serializer_name)()
            return serializer.load(self._construct_item_path(artifact.key), **load_kwargs)
        except Exception as exc:
            raise ValueError(f'Unable to load artifact {artifact.key} from cache store') from exc
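
    # Hedged sketch: the serializer and load kwargs recorded at put() time can be
    # overridden at read time; whether a given kwarg is accepted depends on the
    # underlying serializer, so ``chunks`` below is an assumption:
    #
    #     ds = store.get('my-dataset', serializer='xarray.zarr',
    #                    load_kwargs={'chunks': {}})
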
    @pydantic.validate_arguments
    def put(
        self,
        key: str,
        value: typing.Any,
        serializer: str = 'auto',
        dump_kwargs: typing.Dict[typing.Any, typing.Any] = None,
        additional_metadata: typing.Dict[typing.Any, typing.Any] = None,
    ) -> typing.Any:
        """Records and serializes key with its corresponding value in the cache store.

        Parameters
        ----------
        key : str
            Key to put in the cache store.
        value : typing.Any
            Value to put in the cache store.
        serializer : str
            The name of the serializer you want to use. The built-in serializers are:

            - 'auto' (default): automatically choose the serializer based on the
              type of the value
            - 'xarray.netcdf': requires xarray and netCDF4
            - 'xarray.zarr': requires xarray and zarr
            - 'pandas.csv': requires pandas
            - 'pandas.parquet': requires pandas and pyarrow or fastparquet

            You can also register your own serializer via the
            @xpersist.registry.serializers.register decorator.
        dump_kwargs : dict
            Additional keyword arguments to pass to the serializer when dumping an
            artifact to the cache store.
        additional_metadata : dict
            A dict with types that serialize to json. These fields can be used for
            searching artifacts in the metadata store.

        Returns
        -------
        value : typing.Any
            Reference to the value that was put in the cache store.

        Examples
        --------
        >>> from xpersist import CacheStore
        >>> store = CacheStore("/tmp/my-cache")
        >>> store.keys()
        []
        >>> store.put("foo", [1, 2, 3])
        [1, 2, 3]
        >>> store.keys()
        ['foo']
        """
        dump_kwargs = dump_kwargs or {}
        additional_metadata = additional_metadata or {}
        if not self.readonly:
            method = getattr(self, f'_put_{self.on_duplicate_key.value}')
            serializer_name = pick_serializer(value) if serializer == 'auto' else serializer
            artifact = Artifact(
                key=key,
                serializer=serializer_name,
                dump_kwargs=dump_kwargs,
                additional_metadata=additional_metadata,
            )
            artifact._value = value
            method(artifact)
            return artifact._value
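
    # Hedged sketch: recording searchable, JSON-serializable metadata alongside
    # the value; the metadata keys and dump kwargs are made up for illustration:
    #
    #     store.put('my-dataset', ds, serializer='xarray.zarr',
    #               dump_kwargs={'mode': 'w'},
    #               additional_metadata={'experiment': 'piControl'})
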
    def _put_raise_error(self, artifact: Artifact) -> None:
        """Raises an error if the key is already in the cache store."""
        if artifact.key in self:
            raise ValueError(f'Key {artifact.key} already in cache store')
        else:
            self._put_overwrite(artifact)

    def _put_skip(self, artifact: Artifact) -> None:
        """Does nothing if the key is already in the cache store."""
        if artifact.key not in self:
            self._put_overwrite(artifact)

    def _put_overwrite(self, artifact: Artifact) -> None:
        """Writes the artifact's value and metadata sidecar, overwriting any existing entry."""
        serializer = registry.serializers.get(artifact.serializer)()
        with self.mapper.fs.transaction:
            serializer.dump(
                artifact._value, self._construct_item_path(artifact.key), **artifact.dump_kwargs
            )
            with self.mapper.fs.open(self._artifact_meta_full_path(artifact.key), 'w') as fobj:
                fobj.write(artifact.json(indent=2))
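
# Hedged sketch of the duplicate-key strategies above: ``put`` dispatches to
# _put_skip / _put_overwrite / _put_raise_error based on ``on_duplicate_key``:
#
#     store = CacheStore('/tmp/my-cache', on_duplicate_key='raise_error')
#     store.put('foo', [1, 2, 3])
#     store.put('foo', [4, 5, 6])  # ValueError: Key foo already in cache store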