summaryrefslogtreecommitdiffstats
path: root/r_basicsr/utils/registry.py
diff options
context:
space:
mode:
Diffstat (limited to 'r_basicsr/utils/registry.py')
-rw-r--r--r_basicsr/utils/registry.py88
1 files changed, 88 insertions, 0 deletions
diff --git a/r_basicsr/utils/registry.py b/r_basicsr/utils/registry.py
new file mode 100644
index 0000000..1745e94
--- /dev/null
+++ b/r_basicsr/utils/registry.py
@@ -0,0 +1,88 @@
+# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
+
+
+class Registry():
+ """
+ The registry that provides name -> object mapping, to support third-party
+ users' custom modules.
+
+ To create a registry (e.g. a backbone registry):
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY = Registry('BACKBONE')
+
+ To register an object:
+
+ .. code-block:: python
+
+ @BACKBONE_REGISTRY.register()
+ class MyBackbone():
+ ...
+
+ Or:
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY.register(MyBackbone)
+ """
+
+ def __init__(self, name):
+ """
+ Args:
+ name (str): the name of this registry
+ """
+ self._name = name
+ self._obj_map = {}
+
+ def _do_register(self, name, obj, suffix=None):
+ if isinstance(suffix, str):
+ name = name + '_' + suffix
+
+ assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
+ f"in '{self._name}' registry!")
+ self._obj_map[name] = obj
+
+ def register(self, obj=None, suffix=None):
+ """
+ Register the given object under the the name `obj.__name__`.
+ Can be used as either a decorator or not.
+ See docstring of this class for usage.
+ """
+ if obj is None:
+ # used as a decorator
+ def deco(func_or_class):
+ name = func_or_class.__name__
+ self._do_register(name, func_or_class, suffix)
+ return func_or_class
+
+ return deco
+
+ # used as a function call
+ name = obj.__name__
+ self._do_register(name, obj, suffix)
+
+ def get(self, name, suffix='basicsr'):
+ ret = self._obj_map.get(name)
+ if ret is None:
+ ret = self._obj_map.get(name + '_' + suffix)
+ print(f'Name {name} is not found, use name: {name}_{suffix}!')
+ if ret is None:
+ raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
+ return ret
+
+ def __contains__(self, name):
+ return name in self._obj_map
+
+ def __iter__(self):
+ return iter(self._obj_map.items())
+
+ def keys(self):
+ return self._obj_map.keys()
+
+
+DATASET_REGISTRY = Registry('dataset')
+ARCH_REGISTRY = Registry('arch')
+MODEL_REGISTRY = Registry('model')
+LOSS_REGISTRY = Registry('loss')
+METRIC_REGISTRY = Registry('metric')