From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:00:30 +0000 Subject: Add files via upload --- r_basicsr/archs/__init__.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 r_basicsr/archs/__init__.py (limited to 'r_basicsr/archs/__init__.py') diff --git a/r_basicsr/archs/__init__.py b/r_basicsr/archs/__init__.py new file mode 100644 index 0000000..4a3f3c4 --- /dev/null +++ b/r_basicsr/archs/__init__.py @@ -0,0 +1,25 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from r_basicsr.utils import get_root_logger, scandir +from r_basicsr.utils.registry import ARCH_REGISTRY + +__all__ = ['build_network'] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'r_basicsr.archs.{file_name}') for file_name in arch_filenames] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop('type') + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f'Network [{net.__class__.__name__}] is created.') + return net -- cgit v1.2.3