From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:00:30 +0000 Subject: Add files via upload --- r_basicsr/utils/registry.py | 88 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 r_basicsr/utils/registry.py (limited to 'r_basicsr/utils/registry.py') diff --git a/r_basicsr/utils/registry.py b/r_basicsr/utils/registry.py new file mode 100644 index 0000000..1745e94 --- /dev/null +++ b/r_basicsr/utils/registry.py @@ -0,0 +1,88 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj, suffix=None): + if isinstance(suffix, str): + name = name + '_' + suffix + + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None, suffix=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class, suffix) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj, suffix) + + def get(self, name, suffix='basicsr'): + ret = self._obj_map.get(name) + if ret is None: + ret = self._obj_map.get(name + '_' + suffix) + print(f'Name {name} is not found, use name: {name}_{suffix}!') + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') -- cgit v1.2.3