Source code for xpersist.registry

# Adapted from https://github.com/explosion/thinc
import sys
import typing

import catalogue

# Use typing_extensions for Python versions < 3.8
if sys.version_info < (3, 8):
    from typing_extensions import Protocol
else:
    from typing import Protocol


# Generic placeholder used in ``Decorator.__call__``'s return annotation: the
# callable it returns accepts a value of this type and returns the same type.
_DIn = typing.TypeVar('_DIn')


class Decorator(Protocol):
    """Structural type for a callable that takes a name and returns a decorator.

    Any object whose ``__call__`` matches the signature below satisfies this
    protocol; no explicit subclassing is required.
    """

    def __call__(self, name: str) -> typing.Callable[[_DIn], _DIn]:
        """Return a callable that takes one object and returns the same type.

        Parameters
        ----------
        name : str
            Presumably the name under which the decorated object is
            registered -- confirm against the implementing registry.

        Returns
        -------
        typing.Callable
            A one-argument callable whose return type matches its argument type.
        """
        ...


class registry:
    """xpersist's global registry entrypoint.

    This is used to register serializers and other components that are used
    by xpersist. Each registry is stored as a class attribute holding the
    object returned by ``catalogue.create('xpersist', <name>, ...)``.
    """

    serializers: Decorator = catalogue.create('xpersist', 'serializers', entry_points=True)
    metadata_store: Decorator = catalogue.create('xpersist', 'metadata_store', entry_points=True)

    @classmethod
    def _lookup(cls, registry_name: str) -> typing.Optional[catalogue.Registry]:
        """Return the registry stored under ``registry_name``, or ``None``.

        Only attributes that are actual ``catalogue.Registry`` instances count.
        A plain ``hasattr`` check would also match this class's own methods and
        dunder attributes (``'get'``, ``'__dict__'``, ``'__doc__'``, ...), which
        are not registries.
        """
        candidate = getattr(cls, registry_name, None)
        if isinstance(candidate, catalogue.Registry):
            return candidate
        return None

    @classmethod
    def create(cls, registry_name: str, entry_points: bool = False) -> None:
        """Create a new custom registry.

        Parameters
        ----------
        registry_name : str
            The name of the registry to create. It becomes an attribute of
            this class.
        entry_points : bool
            Forwarded to ``catalogue.create`` as its ``entry_points`` argument.

        Raises
        ------
        ValueError
            If an attribute called ``registry_name`` already exists on this class.
        """
        # Deliberately ``hasattr`` rather than ``_lookup``: a new registry must
        # not overwrite *any* existing attribute, including the methods above.
        if hasattr(cls, registry_name):
            raise ValueError(f"Registry '{registry_name}' already exists")
        reg: Decorator = catalogue.create('xpersist', registry_name, entry_points=entry_points)
        setattr(cls, registry_name, reg)

    @classmethod
    def has(cls, registry_name: str, func_name: str) -> bool:
        """Check whether a function is available in a registry.

        Parameters
        ----------
        registry_name : str
            The name of the registry to check.
        func_name : str
            The name of the function to check.

        Returns
        -------
        bool
            Whether the function is available in the registry. ``False`` when
            ``registry_name`` does not name a registry.
        """
        reg = cls._lookup(registry_name)
        if reg is None:
            return False
        return func_name in reg

    @classmethod
    def get(cls, registry_name: str, func_name: str) -> typing.Callable:
        """Get a registered function from a given registry.

        Parameters
        ----------
        registry_name : str
            The name of the registry to get the function from.
        func_name : str
            The name of the function to get.

        Returns
        -------
        func : typing.Callable
            The function from the registry.

        Raises
        ------
        ValueError
            If ``registry_name`` does not name a registry, or if the registry
            lookup returns ``None`` for ``func_name``.
        """
        reg = cls._lookup(registry_name)
        if reg is None:
            raise ValueError(f"Unknown registry: '{registry_name}'")
        func = reg.get(func_name)
        # NOTE(review): the underlying registry may raise its own error for a
        # missing name before this check is reached -- confirm against catalogue.
        if func is None:
            raise ValueError(f"Could not find '{func_name}' in '{registry_name}'")
        return func