from typing import Dict, Optional, List
class Registry:
    """
    The registry that provides name -> object mapping.

    To create a registry:

    .. code-block:: python

        MODEL_REGISTRY = Registry("MODEL")

    To register an object with its ``__name__``:

    .. code-block:: python

        @MODEL_REGISTRY.register()
        class ResNet50:
            pass

        # or

        MODEL_REGISTRY.register(obj=ResNet50)

    To register an object with a given name:

    .. code-block:: python

        @MODEL_REGISTRY.register("resnet")
        class ResNet50:
            pass

        # or

        MODEL_REGISTRY.register("resnet", ResNet50)

    To get a registered object from registry:

    .. code-block:: python

        model_class = MODEL_REGISTRY.get("ResNet50")

        # or

        model_class = MODEL_REGISTRY.get("resnet")
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): name of this registry, used in error messages.
        """
        self._name = name
        self._obj_map: Dict[str, object] = {}

    def __repr__(self) -> str:
        return "{}(name={!r}, names={!r})".format(
            type(self).__name__, self._name, list(self._obj_map)
        )

    def __contains__(self, name: str) -> bool:
        """Support ``name in registry``; equivalent to :meth:`is_registered`."""
        return name in self._obj_map

    def _do_register(self, name: str, obj: object) -> None:
        """Store ``obj`` under ``name``.

        Raises:
            KeyError: if ``name`` is already registered.
        """
        if name in self._obj_map:
            raise KeyError("An object named '{}' was already registered in '{}' registry.".format(name, self._name))
        self._obj_map[name] = obj

    def register(self, name: Optional[str] = None, obj: Optional[object] = None) -> Optional[object]:
        """
        Register the given object with given name.

        If the object is not given, it will act as a decorator.

        Args:
            name (str, optional): if not given, it will use ``obj.__name__`` as the name.
            obj (object, optional): if not given, this method will return a decorator.

        Returns:
            Optional[object]: None when ``obj`` is given, otherwise a decorator
            that registers and returns the decorated object unchanged.

        Raises:
            KeyError: if the resolved name is already registered.
        """
        if obj is None:
            # Use as a decorator.
            def decorator(func_or_class: object) -> object:
                # Resolve the name on every call instead of rebinding the
                # enclosing ``name`` (``nonlocal``): otherwise the first
                # decorated object's name would stick, and reusing the same
                # decorator object would register later objects under it.
                key = name if name is not None else func_or_class.__name__
                self._do_register(key, func_or_class)
                return func_or_class

            return decorator

        # Use as a function call.
        self._do_register(name if name is not None else obj.__name__, obj)
        return None

    def unregister(self, name: str) -> None:
        """
        Remove registered object.

        Args:
            name (str): registered name

        Raises:
            KeyError: if ``name`` is not registered.
        """
        try:
            del self._obj_map[name]
        except KeyError:
            raise KeyError("An object named '{}' isn't registered in '{}' registry.".format(name, self._name)) from None

    def is_registered(self, name: str) -> bool:
        """
        Get whether the given name has been registered.

        Args:
            name (str): name to look up.

        Returns:
            bool: whether the name has been registered.
        """
        return name in self._obj_map

    def get(self, name: str) -> object:
        """
        Get a registered object from registry by its name.

        Args:
            name (str): registered name.

        Returns:
            object: registered object.

        Raises:
            KeyError: if ``name`` is not registered.
        """
        try:
            return self._obj_map[name]
        except KeyError:
            raise KeyError(
                "No object name '{}' found in '{}' registry.".format(
                    name, self._name
                )
            ) from None

    def registered_names(self) -> List[str]:
        """
        Get all registered names.

        Returns:
            list[str]: list of registered names, in registration order.
        """
        return list(self._obj_map.keys())