summaryrefslogtreecommitdiffstats
path: root/r_basicsr/losses/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'r_basicsr/losses/__init__.py')
-rw-r--r--r_basicsr/losses/__init__.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/r_basicsr/losses/__init__.py b/r_basicsr/losses/__init__.py
new file mode 100644
index 0000000..1248657
--- /dev/null
+++ b/r_basicsr/losses/__init__.py
@@ -0,0 +1,31 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from r_basicsr.utils import get_root_logger, scandir
+from r_basicsr.utils.registry import LOSS_REGISTRY
+from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty
+
+__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']
+
+# automatically scan and import loss modules for registry
+# scan all the files under the 'losses' folder and collect files ending with '_loss.py'
+loss_folder = osp.dirname(osp.abspath(__file__))
+loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
+# import all the loss modules
+_model_modules = [importlib.import_module(f'r_basicsr.losses.{file_name}') for file_name in loss_filenames]
+
+
+def build_loss(opt):
+ """Build loss from options.
+
+ Args:
+ opt (dict): Configuration. It must contain:
+ type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ loss_type = opt.pop('type')
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Loss [{loss.__class__.__name__}] is created.')
+ return loss