author     Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com>  2025-01-17 11:00:30 +0000
committer  GitHub <noreply@github.com>  2025-01-17 11:00:30 +0000
commit     495ffc4777522e40941753e3b1b79c02f84b25b4 (patch)
tree       5130fcb8676afdcb619a5e5eaef3ac28e135bc08 /r_basicsr/utils
parent     febd45814cd41560c5247aacb111d8d013f3a303 (diff)
download   Comfyui-reactor-node-495ffc4777522e40941753e3b1b79c02f84b25b4.tar.gz
Add files via upload
Diffstat (limited to 'r_basicsr/utils')
-rw-r--r--  r_basicsr/utils/__init__.py           44
-rw-r--r--  r_basicsr/utils/color_util.py        208
-rw-r--r--  r_basicsr/utils/diffjpeg.py          515
-rw-r--r--  r_basicsr/utils/dist_util.py          82
-rw-r--r--  r_basicsr/utils/download_util.py      99
-rw-r--r--  r_basicsr/utils/file_client.py       167
-rw-r--r--  r_basicsr/utils/flow_util.py         170
-rw-r--r--  r_basicsr/utils/img_process_util.py   83
-rw-r--r--  r_basicsr/utils/img_util.py          172
-rw-r--r--  r_basicsr/utils/lmdb_util.py         196
-rw-r--r--  r_basicsr/utils/logger.py            213
-rw-r--r--  r_basicsr/utils/matlab_functions.py  178
-rw-r--r--  r_basicsr/utils/misc.py              141
-rw-r--r--  r_basicsr/utils/options.py           194
-rw-r--r--  r_basicsr/utils/plot_util.py          84
-rw-r--r--  r_basicsr/utils/registry.py           88
16 files changed, 2634 insertions, 0 deletions
diff --git a/r_basicsr/utils/__init__.py b/r_basicsr/utils/__init__.py
new file mode 100644
index 0000000..57c730a
--- /dev/null
+++ b/r_basicsr/utils/__init__.py
@@ -0,0 +1,44 @@
+from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
+from .diffjpeg import DiffJPEG
+from .file_client import FileClient
+from .img_process_util import USMSharp, usm_sharp
+from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
+from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
+from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
+
+__all__ = [
+ # color_util.py
+ 'bgr2ycbcr',
+ 'rgb2ycbcr',
+ 'rgb2ycbcr_pt',
+ 'ycbcr2bgr',
+ 'ycbcr2rgb',
+ # file_client.py
+ 'FileClient',
+ # img_util.py
+ 'img2tensor',
+ 'tensor2img',
+ 'imfrombytes',
+ 'imwrite',
+ 'crop_border',
+ # logger.py
+ 'MessageLogger',
+ 'AvgTimer',
+ 'init_tb_logger',
+ 'init_wandb_logger',
+ 'get_root_logger',
+ 'get_env_info',
+ # misc.py
+ 'set_random_seed',
+ 'get_time_str',
+ 'mkdir_and_rename',
+ 'make_exp_dirs',
+ 'scandir',
+ 'check_resume',
+ 'sizeof_fmt',
+ # diffjpeg
+ 'DiffJPEG',
+ # img_process_util
+ 'USMSharp',
+ 'usm_sharp'
+]
diff --git a/r_basicsr/utils/color_util.py b/r_basicsr/utils/color_util.py
new file mode 100644
index 0000000..8b7676f
--- /dev/null
+++ b/r_basicsr/utils/color_util.py
@@ -0,0 +1,208 @@
+import numpy as np
+import torch
+
+
+def rgb2ycbcr(img, y_only=False):
+ """Convert a RGB image to YCbCr image.
+
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def bgr2ycbcr(img, y_only=False):
+ """Convert a BGR image to YCbCr image.
+
+ The bgr version of rgb2ycbcr.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2rgb(img):
+ """Convert a YCbCr image to RGB image.
+
+ This function produces the same results as Matlab's ycbcr2rgb function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted RGB image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2bgr(img):
+ """Convert a YCbCr image to BGR image.
+
+ The bgr version of ycbcr2rgb.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted BGR image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def _convert_input_type_range(img):
+ """Convert the type and range of the input image.
+
+ It converts the input image to np.float32 type and range of [0, 1].
+ It is mainly used for pre-processing the input image in colorspace
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with type of np.float32 and range of
+ [0, 1].
+ """
+ img_type = img.dtype
+ img = img.astype(np.float32)
+ if img_type == np.float32:
+ pass
+ elif img_type == np.uint8:
+ img /= 255.
+ else:
+ raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
+ return img
+
+
+def _convert_output_type_range(img, dst_type):
+ """Convert the type and range of the image according to dst_type.
+
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
+ images will be converted to np.uint8 type with range [0, 255]. If
+ `dst_type` is np.float32, it converts the image to np.float32 type with
+ range [0, 1].
+ It is mainly used for post-processing images in colorspace conversion
+ functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The image to be converted with np.float32 type and
+ range [0, 255].
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+ converts the image to np.uint8 type with range [0, 255]. If
+ dst_type is np.float32, it converts the image to np.float32 type
+ with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with desired type and range.
+ """
+ if dst_type not in (np.uint8, np.float32):
+ raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
+ if dst_type == np.uint8:
+ img = img.round()
+ else:
+ img /= 255.
+ return img.astype(dst_type)
+
+
+def rgb2ycbcr_pt(img, y_only=False):
+ """Convert RGB images to YCbCr images (PyTorch version).
+
+ It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ Args:
+ img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
+ """
+ if y_only:
+ weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
+ out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
+ else:
+ weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
+ bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
+ out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
+
+ out_img = out_img / 255.
+ return out_img
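
A minimal usage sketch for the color conversions above (illustrative, not part of the diff): both uint8 [0, 255] and float32 [0, 1] inputs are accepted, and the output keeps the input's type and range.

    import numpy as np
    from r_basicsr.utils import bgr2ycbcr

    img_u8 = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)  # BGR, [0, 255]
    y_u8 = bgr2ycbcr(img_u8, y_only=True)        # uint8, values in [16, 235]

    img_f32 = img_u8.astype(np.float32) / 255.   # BGR, [0, 1]
    y_f32 = bgr2ycbcr(img_f32, y_only=True)      # float32, values in [16/255, 235/255]
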
diff --git a/r_basicsr/utils/diffjpeg.py b/r_basicsr/utils/diffjpeg.py
new file mode 100644
index 0000000..c055c1b
--- /dev/null
+++ b/r_basicsr/utils/diffjpeg.py
@@ -0,0 +1,515 @@
+"""
+Modified from https://github.com/mlomnitz/DiffJPEG
+
+For images not divisible by 8
+https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
+"""
+import itertools
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+# ------------------------ utils ------------------------#
+y_table = np.array(
+ [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
+ [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
+ [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
+ dtype=np.float32).T
+y_table = nn.Parameter(torch.from_numpy(y_table))
+c_table = np.empty((8, 8), dtype=np.float32)
+c_table.fill(99)
+c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
+c_table = nn.Parameter(torch.from_numpy(c_table))
+
+
+def diff_round(x):
+ """ Differentiable rounding function
+ """
+ return torch.round(x) + (x - torch.round(x))**3
+
+
+def quality_to_factor(quality):
+ """ Calculate factor corresponding to quality
+
+ Args:
+ quality(float): Quality for jpeg compression.
+
+ Returns:
+ float: Compression factor.
+ """
+ if quality < 50:
+ quality = 5000. / quality
+ else:
+ quality = 200. - quality * 2
+ return quality / 100.
+
+
+# ------------------------ compression ------------------------#
+class RGB2YCbCrJpeg(nn.Module):
+ """ Converts RGB image to YCbCr
+ """
+
+ def __init__(self):
+ super(RGB2YCbCrJpeg, self).__init__()
+ matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
+ dtype=np.float32).T
+ self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+ def forward(self, image):
+ """
+ Args:
+ image(Tensor): batch x 3 x height x width
+
+ Returns:
+ Tensor: batch x height x width x 3
+ """
+ image = image.permute(0, 2, 3, 1)
+ result = torch.tensordot(image, self.matrix, dims=1) + self.shift
+ return result.view(image.shape)
+
+
+class ChromaSubsampling(nn.Module):
+ """ Chroma subsampling on CbCr channels
+ """
+
+ def __init__(self):
+ super(ChromaSubsampling, self).__init__()
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width x 3
+
+ Returns:
+ y(tensor): batch x height x width
+ cb(tensor): batch x height/2 x width/2
+ cr(tensor): batch x height/2 x width/2
+ """
+ image_2 = image.permute(0, 3, 1, 2).clone()
+ cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
+ cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
+ cb = cb.permute(0, 2, 3, 1)
+ cr = cr.permute(0, 2, 3, 1)
+ return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
+
+
+class BlockSplitting(nn.Module):
+ """ Splitting image into patches
+ """
+
+ def __init__(self):
+ super(BlockSplitting, self).__init__()
+ self.k = 8
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+            Tensor: batch x h*w/64 x 8 x 8
+ """
+ height, _ = image.shape[1:3]
+ batch_size = image.shape[0]
+ image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+ return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
+
+
+class DCT8x8(nn.Module):
+ """ Discrete Cosine Transformation
+ """
+
+ def __init__(self):
+ super(DCT8x8, self).__init__()
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+ for x, y, u, v in itertools.product(range(8), repeat=4):
+ tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+ self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ image = image - 128
+ result = self.scale * torch.tensordot(image, self.tensor, dims=2)
+        result = result.view(image.shape)
+        return result
+
+
+class YQuantize(nn.Module):
+ """ JPEG Quantization for Y channel
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding):
+ super(YQuantize, self).__init__()
+ self.rounding = rounding
+ self.y_table = y_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ image = image.float() / (self.y_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ image = image.float() / table
+ image = self.rounding(image)
+ return image
+
+
+class CQuantize(nn.Module):
+ """ JPEG Quantization for CbCr channels
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding):
+ super(CQuantize, self).__init__()
+ self.rounding = rounding
+ self.c_table = c_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ image = image.float() / (self.c_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ image = image.float() / table
+ image = self.rounding(image)
+ return image
+
+
+class CompressJpeg(nn.Module):
+ """Full JPEG compression algorithm
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding=torch.round):
+ super(CompressJpeg, self).__init__()
+ self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
+ self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
+ self.c_quantize = CQuantize(rounding=rounding)
+ self.y_quantize = YQuantize(rounding=rounding)
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x 3 x height x width
+
+ Returns:
+            tuple(Tensor): Compressed y, cb, cr tensors, each batch x n_blocks x 8 x 8.
+ """
+ y, cb, cr = self.l1(image * 255)
+ components = {'y': y, 'cb': cb, 'cr': cr}
+ for k in components.keys():
+ comp = self.l2(components[k])
+ if k in ('cb', 'cr'):
+ comp = self.c_quantize(comp, factor=factor)
+ else:
+ comp = self.y_quantize(comp, factor=factor)
+
+ components[k] = comp
+
+ return components['y'], components['cb'], components['cr']
+
+
+# ------------------------ decompression ------------------------#
+
+
+class YDequantize(nn.Module):
+ """Dequantize Y channel
+ """
+
+ def __init__(self):
+ super(YDequantize, self).__init__()
+ self.y_table = y_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ out = image * (self.y_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ out = image * table
+ return out
+
+
+class CDequantize(nn.Module):
+ """Dequantize CbCr channel
+ """
+
+ def __init__(self):
+ super(CDequantize, self).__init__()
+ self.c_table = c_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ out = image * (self.c_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ out = image * table
+ return out
+
+
+class iDCT8x8(nn.Module):
+ """Inverse discrete Cosine Transformation
+ """
+
+ def __init__(self):
+ super(iDCT8x8, self).__init__()
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+ self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+ for x, y, u, v in itertools.product(range(8), repeat=4):
+ tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ image = image * self.alpha
+ result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
+        result = result.view(image.shape)
+        return result
+
+
+class BlockMerging(nn.Module):
+ """Merge patches into image
+ """
+
+ def __init__(self):
+ super(BlockMerging, self).__init__()
+
+ def forward(self, patches, height, width):
+ """
+ Args:
+            patches(tensor): batch x height*width/64 x 8 x 8
+ height(int)
+ width(int)
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ k = 8
+ batch_size = patches.shape[0]
+ image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+ return image_transposed.contiguous().view(batch_size, height, width)
+
+
+class ChromaUpsampling(nn.Module):
+ """Upsample chroma layers
+ """
+
+ def __init__(self):
+ super(ChromaUpsampling, self).__init__()
+
+ def forward(self, y, cb, cr):
+ """
+ Args:
+ y(tensor): y channel image
+ cb(tensor): cb channel
+ cr(tensor): cr channel
+
+ Returns:
+ Tensor: batch x height x width x 3
+ """
+
+ def repeat(x, k=2):
+ height, width = x.shape[1:3]
+ x = x.unsqueeze(-1)
+ x = x.repeat(1, 1, k, k)
+ x = x.view(-1, height * k, width * k)
+ return x
+
+ cb = repeat(cb)
+ cr = repeat(cr)
+ return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
+
+
+class YCbCr2RGBJpeg(nn.Module):
+ """Converts YCbCr image to RGB JPEG
+ """
+
+ def __init__(self):
+ super(YCbCr2RGBJpeg, self).__init__()
+
+ matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
+ self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width x 3
+
+ Returns:
+ Tensor: batch x 3 x height x width
+ """
+ result = torch.tensordot(image + self.shift, self.matrix, dims=1)
+ return result.view(image.shape).permute(0, 3, 1, 2)
+
+
+class DeCompressJpeg(nn.Module):
+ """Full JPEG decompression algorithm
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding=torch.round):
+ super(DeCompressJpeg, self).__init__()
+ self.c_dequantize = CDequantize()
+ self.y_dequantize = YDequantize()
+ self.idct = iDCT8x8()
+ self.merging = BlockMerging()
+ self.chroma = ChromaUpsampling()
+ self.colors = YCbCr2RGBJpeg()
+
+ def forward(self, y, cb, cr, imgh, imgw, factor=1):
+ """
+ Args:
+            y, cb, cr (tensor): Compressed components, batch x n_blocks x 8 x 8.
+ imgh(int)
+ imgw(int)
+ factor(float)
+
+ Returns:
+ Tensor: batch x 3 x height x width
+ """
+ components = {'y': y, 'cb': cb, 'cr': cr}
+ for k in components.keys():
+ if k in ('cb', 'cr'):
+ comp = self.c_dequantize(components[k], factor=factor)
+ height, width = int(imgh / 2), int(imgw / 2)
+ else:
+ comp = self.y_dequantize(components[k], factor=factor)
+ height, width = imgh, imgw
+ comp = self.idct(comp)
+ components[k] = self.merging(comp, height, width)
+ #
+ image = self.chroma(components['y'], components['cb'], components['cr'])
+ image = self.colors(image)
+
+ image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
+ return image / 255
+
+
+# ------------------------ main DiffJPEG ------------------------ #
+
+
+class DiffJPEG(nn.Module):
+ """This JPEG algorithm result is slightly different from cv2.
+ DiffJPEG supports batch processing.
+
+ Args:
+        differentiable(bool): If True, use the custom differentiable rounding function; if False, use standard torch.round
+ """
+
+ def __init__(self, differentiable=True):
+ super(DiffJPEG, self).__init__()
+ if differentiable:
+ rounding = diff_round
+ else:
+ rounding = torch.round
+
+ self.compress = CompressJpeg(rounding=rounding)
+ self.decompress = DeCompressJpeg(rounding=rounding)
+
+ def forward(self, x, quality):
+ """
+ Args:
+ x (Tensor): Input image, bchw, rgb, [0, 1]
+ quality(float): Quality factor for jpeg compression scheme.
+ """
+ factor = quality
+ if isinstance(factor, (int, float)):
+ factor = quality_to_factor(factor)
+ else:
+ for i in range(factor.size(0)):
+ factor[i] = quality_to_factor(factor[i])
+ h, w = x.size()[-2:]
+ h_pad, w_pad = 0, 0
+        # pad to multiples of 16: chroma is subsampled by 2 and DCT blocks are 8x8
+ if h % 16 != 0:
+ h_pad = 16 - h % 16
+ if w % 16 != 0:
+ w_pad = 16 - w % 16
+ x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
+
+ y, cb, cr = self.compress(x, factor=factor)
+ recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
+ recovered = recovered[:, :, 0:h, 0:w]
+ return recovered
+
+
+if __name__ == '__main__':
+ import cv2
+
+ from r_basicsr.utils import img2tensor, tensor2img
+
+ img_gt = cv2.imread('test.png') / 255.
+
+ # -------------- cv2 -------------- #
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
+ _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
+ img_lq = np.float32(cv2.imdecode(encimg, 1))
+ cv2.imwrite('cv2_JPEG_20.png', img_lq)
+
+ # -------------- DiffJPEG -------------- #
+ jpeger = DiffJPEG(differentiable=False).cuda()
+ img_gt = img2tensor(img_gt)
+ img_gt = torch.stack([img_gt, img_gt]).cuda()
+ quality = img_gt.new_tensor([20, 40])
+ out = jpeger(img_gt, quality=quality)
+
+ cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
+ cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))
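
To make the quality mapping above concrete: quality_to_factor(10) = (5000/10)/100 = 5.0, quality_to_factor(50) = (200 - 100)/100 = 1.0, and quality_to_factor(95) = (200 - 190)/100 = 0.1, so lower quality scales the quantization tables up and compresses harder. A minimal CPU sketch (illustrative; the demo above uses .cuda(), but the module runs on CPU as well):

    import torch
    from r_basicsr.utils import DiffJPEG

    jpeger = DiffJPEG(differentiable=True)  # keeps gradients via diff_round
    x = torch.rand(1, 3, 70, 70)            # 70 is not a multiple of 16; padded internally
    out = jpeger(x, quality=30)             # same shape as x, values in [0, 1]
    print(out.shape)                        # torch.Size([1, 3, 70, 70])
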
diff --git a/r_basicsr/utils/dist_util.py b/r_basicsr/utils/dist_util.py
new file mode 100644
index 0000000..380f155
--- /dev/null
+++ b/r_basicsr/utils/dist_util.py
@@ -0,0 +1,82 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
+import functools
+import os
+import subprocess
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+
+ If argument ``port`` is not specified, then the master port will be system
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+ environment variable, then a default port ``29500`` will be used.
+
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ if dist.is_available():
+ initialized = dist.is_initialized()
+ else:
+ initialized = False
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def master_only(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
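
A minimal sketch of how these helpers compose (illustrative): without an initialized process group, get_dist_info() falls back to rank 0 / world size 1, so @master_only functions still run in single-process scripts and silently return None on non-zero ranks.

    from r_basicsr.utils.dist_util import get_dist_info, master_only

    @master_only
    def log_once(msg):
        print(msg)  # executes on rank 0 only; other ranks get None

    rank, world_size = get_dist_info()  # (0, 1) when torch.distributed is uninitialized
    log_once(f'world_size={world_size}')
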
diff --git a/r_basicsr/utils/download_util.py b/r_basicsr/utils/download_util.py
new file mode 100644
index 0000000..59b2621
--- /dev/null
+++ b/r_basicsr/utils/download_util.py
@@ -0,0 +1,99 @@
+import math
+import os
+import requests
+from torch.hub import download_url_to_file, get_dir
+from tqdm import tqdm
+from urllib.parse import urlparse
+
+from .misc import sizeof_fmt
+
+
+def download_file_from_google_drive(file_id, save_path):
+ """Download files from google drive.
+
+ Ref:
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
+
+ Args:
+ file_id (str): File id.
+ save_path (str): Save path.
+ """
+
+ session = requests.Session()
+ URL = 'https://docs.google.com/uc?export=download'
+ params = {'id': file_id}
+
+ response = session.get(URL, params=params, stream=True)
+ token = get_confirm_token(response)
+ if token:
+ params['confirm'] = token
+ response = session.get(URL, params=params, stream=True)
+
+ # get file size
+ response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
+ if 'Content-Range' in response_file_size.headers:
+ file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
+ else:
+ file_size = None
+
+ save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+ for key, value in response.cookies.items():
+ if key.startswith('download_warning'):
+ return value
+ return None
+
+
+def save_response_content(response, destination, file_size=None, chunk_size=32768):
+ if file_size is not None:
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
+ readable_file_size = sizeof_fmt(file_size)
+ else:
+ pbar = None
+
+ with open(destination, 'wb') as f:
+ downloaded_size = 0
+ for chunk in response.iter_content(chunk_size):
+            downloaded_size += len(chunk)  # count actual bytes; the last chunk may be short
+ if pbar is not None:
+ pbar.update(1)
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ if pbar is not None:
+ pbar.close()
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+
+ Returns:
+ str: The path to the downloaded file.
+ """
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
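
A minimal sketch of load_file_from_url (illustrative; the URL and directory below are hypothetical placeholders): the file is cached, so a second call returns the same path without re-downloading.

    from r_basicsr.utils.download_util import load_file_from_url

    path = load_file_from_url(
        'https://example.com/models/weights.pth',  # hypothetical URL
        model_dir='./weights')
    print(path)  # absolute path to ./weights/weights.pth
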
diff --git a/r_basicsr/utils/file_client.py b/r_basicsr/utils/file_client.py
new file mode 100644
index 0000000..8f6340e
--- /dev/null
+++ b/r_basicsr/utils/file_client.py
@@ -0,0 +1,167 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
+from abc import ABCMeta, abstractmethod
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+ as texts.
+ """
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError('Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
+        # mc.pyvector serves as a pointer to a memory cache buffer
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'r') as f:
+ value_buf = f.read()
+ return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_paths (str | list[str]): Lmdb database paths.
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_paths (list): Lmdb database path.
+ _client (list): A list of several lmdb envs.
+ """
+
+ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ if isinstance(client_keys, str):
+ client_keys = [client_keys]
+
+ if isinstance(db_paths, list):
+ self.db_paths = [str(v) for v in db_paths]
+ elif isinstance(db_paths, str):
+ self.db_paths = [str(db_paths)]
+ assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
+
+ self._client = {}
+ for client, path in zip(client_keys, self.db_paths):
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
+
+ def get(self, filepath, client_key):
+ """Get values according to the filepath from one lmdb named client_key.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+ client_key (str): Used for distinguishing different lmdb envs.
+ """
+ filepath = str(filepath)
+ assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
+ client = self._client[client_key]
+ with client.begin(write=False) as txn:
+ value_buf = txn.get(filepath.encode('ascii'))
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class FileClient(object):
+ """A general file client to access files in different backend.
+
+    The client loads a file or text from a specified backend given its path
+    and returns it as a binary file. It can also register other backend
+    accessors with a given name and backend class.
+
+ Attributes:
+ backend (str): The storage backend type. Options are "disk",
+ "memcached" and "lmdb".
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ }
+
+ def __init__(self, backend='disk', **kwargs):
+ if backend not in self._backends:
+ raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(self._backends.keys())}')
+ self.backend = backend
+ self.client = self._backends[backend](**kwargs)
+
+ def get(self, filepath, client_key='default'):
+ # client_key is used only for lmdb, where different fileclients have
+ # different lmdb environments.
+ if self.backend == 'lmdb':
+ return self.client.get(filepath, client_key)
+ else:
+ return self.client.get(filepath)
+
+ def get_text(self, filepath):
+ return self.client.get_text(filepath)
diff --git a/r_basicsr/utils/flow_util.py b/r_basicsr/utils/flow_util.py
new file mode 100644
index 0000000..d133012
--- /dev/null
+++ b/r_basicsr/utils/flow_util.py
@@ -0,0 +1,170 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
+import cv2
+import numpy as np
+import os
+
+
+def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
+ """Read an optical flow map.
+
+ Args:
+ flow_path (ndarray or str): Flow path.
+ quantize (bool): whether to read quantized pair, if set to True,
+ remaining args will be passed to :func:`dequantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+
+ Returns:
+ ndarray: Optical flow represented as a (h, w, 2) numpy array
+ """
+ if quantize:
+ assert concat_axis in [0, 1]
+ cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
+ if cat_flow.ndim != 2:
+ raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
+ assert cat_flow.shape[concat_axis] % 2 == 0
+ dx, dy = np.split(cat_flow, 2, axis=concat_axis)
+ flow = dequantize_flow(dx, dy, *args, **kwargs)
+ else:
+ with open(flow_path, 'rb') as f:
+ try:
+ header = f.read(4).decode('utf-8')
+ except Exception:
+ raise IOError(f'Invalid flow file: {flow_path}')
+ else:
+ if header != 'PIEH':
+ raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')
+
+ w = np.fromfile(f, np.int32, 1).squeeze()
+ h = np.fromfile(f, np.int32, 1).squeeze()
+ flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
+
+ return flow.astype(np.float32)
+
+
+def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
+ """Write optical flow to file.
+
+ If the flow is not quantized, it will be saved as a .flo file losslessly,
+    otherwise as a jpeg image which is lossy but of much smaller size (dx and dy
+    are concatenated along ``concat_axis`` into a single image before saving).
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ filename (str): Output filepath.
+        quantize (bool): Whether to quantize the flow and save it as a jpeg
+            image. If set to True, remaining args will be passed to
+ :func:`quantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+ """
+ if not quantize:
+ with open(filename, 'wb') as f:
+ f.write('PIEH'.encode('utf-8'))
+ np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
+ flow = flow.astype(np.float32)
+ flow.tofile(f)
+ f.flush()
+ else:
+ assert concat_axis in [0, 1]
+ dx, dy = quantize_flow(flow, *args, **kwargs)
+ dxdy = np.concatenate((dx, dy), axis=concat_axis)
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ cv2.imwrite(filename, dxdy)
+
+
+def quantize_flow(flow, max_val=0.02, norm=True):
+ """Quantize flow to [0, 255].
+
+ After this step, the size of flow will be much smaller, and can be
+ dumped as jpeg images.
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ max_val (float): Maximum value of flow, values beyond
+ [-max_val, max_val] will be truncated.
+ norm (bool): Whether to divide flow values by image width/height.
+
+ Returns:
+ tuple[ndarray]: Quantized dx and dy.
+ """
+ h, w, _ = flow.shape
+ dx = flow[..., 0]
+ dy = flow[..., 1]
+ if norm:
+ dx = dx / w # avoid inplace operations
+ dy = dy / h
+ # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
+ flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
+ return tuple(flow_comps)
+
+
+def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
+ """Recover from quantized flow.
+
+ Args:
+ dx (ndarray): Quantized dx.
+ dy (ndarray): Quantized dy.
+ max_val (float): Maximum value used when quantizing.
+ denorm (bool): Whether to multiply flow values with width/height.
+
+ Returns:
+ ndarray: Dequantized flow.
+ """
+ assert dx.shape == dy.shape
+ assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
+
+ dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+
+ if denorm:
+ dx *= dx.shape[1]
+ dy *= dx.shape[0]
+ flow = np.dstack((dx, dy))
+ return flow
+
+
+def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+ """Quantize an array of (-inf, inf) to [0, levels-1].
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the quantized array.
+
+ Returns:
+        ndarray: Quantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ arr = np.clip(arr, min_val, max_val) - min_val
+ quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+ return quantized_arr
+
+
+def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+ """Dequantize an array.
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the dequantized array.
+
+ Returns:
+        ndarray: Dequantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
+
+ return dequantized_arr
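
A quick round-trip check of quantize/dequantize (illustrative): values are clipped to [-max_val, max_val] and reconstructed at bin centers, so the error is at most one bin width, (max_val - min_val) / levels.

    import numpy as np
    from r_basicsr.utils.flow_util import dequantize, quantize

    arr = np.linspace(-0.03, 0.03, 7)              # values beyond +-0.02 get clipped
    q = quantize(arr, -0.02, 0.02, 255, np.uint8)  # integer levels in [0, 254]
    dq = dequantize(q, -0.02, 0.02, 255)
    err = np.abs(dq - np.clip(arr, -0.02, 0.02)).max()
    print(err <= 0.04 / 255)                       # True: within one bin
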
diff --git a/r_basicsr/utils/img_process_util.py b/r_basicsr/utils/img_process_util.py
new file mode 100644
index 0000000..fb5fbc9
--- /dev/null
+++ b/r_basicsr/utils/img_process_util.py
@@ -0,0 +1,83 @@
+import cv2
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+
+def filter2D(img, kernel):
+ """PyTorch version of cv2.filter2D
+
+ Args:
+ img (Tensor): (b, c, h, w)
+ kernel (Tensor): (b, k, k)
+ """
+ k = kernel.size(-1)
+ b, c, h, w = img.size()
+ if k % 2 == 1:
+ img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
+ else:
+ raise ValueError('Wrong kernel size')
+
+ ph, pw = img.size()[-2:]
+
+ if kernel.size(0) == 1:
+ # apply the same kernel to all batch images
+ img = img.view(b * c, 1, ph, pw)
+ kernel = kernel.view(1, 1, k, k)
+ return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
+ else:
+ img = img.view(1, b * c, ph, pw)
+ kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
+ return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
+
+
+def usm_sharp(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening.
+
+ Input image: I; Blurry image: B.
+ 1. sharp = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+    3. Blur the mask to obtain a soft mask
+ 4. Out = Mask * sharp + (1 - Mask) * I
+
+
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharp weight. Default: 0.5.
+        radius (int): Kernel size of Gaussian blur; bumped to the next odd number if even. Default: 50.
+        threshold (int): Mask threshold on the residual, on the [0, 255] scale. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ sharp = img + weight * residual
+ sharp = np.clip(sharp, 0, 1)
+ return soft_mask * sharp + (1 - soft_mask) * img
+
+
+class USMSharp(torch.nn.Module):
+
+ def __init__(self, radius=50, sigma=0):
+ super(USMSharp, self).__init__()
+ if radius % 2 == 0:
+ radius += 1
+ self.radius = radius
+ kernel = cv2.getGaussianKernel(radius, sigma)
+ kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
+ self.register_buffer('kernel', kernel)
+
+ def forward(self, img, weight=0.5, threshold=10):
+ blur = filter2D(img, self.kernel)
+ residual = img - blur
+
+ mask = torch.abs(residual) * 255 > threshold
+ mask = mask.float()
+ soft_mask = filter2D(mask, self.kernel)
+ sharp = img + weight * residual
+ sharp = torch.clip(sharp, 0, 1)
+ return soft_mask * sharp + (1 - soft_mask) * img
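
A minimal sketch of the numpy path (illustrative): usm_sharp keeps the output in [0, 1] and only sharpens where the residual exceeds the threshold.

    import numpy as np
    from r_basicsr.utils import usm_sharp

    img = np.random.rand(64, 64, 3).astype(np.float32)  # HWC, BGR, [0, 1]
    sharp = usm_sharp(img, weight=0.5, radius=50, threshold=10)
    print(sharp.shape, sharp.dtype, float(sharp.min()) >= 0.0, float(sharp.max()) <= 1.0)
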
diff --git a/r_basicsr/utils/img_util.py b/r_basicsr/utils/img_util.py
new file mode 100644
index 0000000..3ad2be2
--- /dev/null
+++ b/r_basicsr/utils/img_util.py
@@ -0,0 +1,172 @@
+import cv2
+import math
+import numpy as np
+import os
+import torch
+from torchvision.utils import make_grid
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+ """Convert torch Tensors into image numpy arrays.
+
+ After clamping to [min, max], values will be normalized to [0, 1].
+
+ Args:
+ tensor (Tensor or list[Tensor]): Accept shapes:
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+ 2) 3D Tensor of shape (3/1 x H x W);
+ 3) 2D Tensor of shape (H x W).
+ Tensor channel should be in RGB order.
+ rgb2bgr (bool): Whether to change rgb to bgr.
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
+ to uint8 type with range [0, 255]; otherwise, float type with
+ range [0, 1]. Default: ``np.uint8``.
+ min_max (tuple[int]): min and max values for clamp.
+
+ Returns:
+        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D ndarray
+        of shape (H x W). The channel order is BGR.
+ """
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+ if torch.is_tensor(tensor):
+ tensor = [tensor]
+ result = []
+ for _tensor in tensor:
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+ n_dim = _tensor.dim()
+ if n_dim == 4:
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 3:
+ img_np = _tensor.numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image
+ img_np = np.squeeze(img_np, axis=2)
+ else:
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 2:
+ img_np = _tensor.numpy()
+ else:
+ raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
+ if out_type == np.uint8:
+            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+ img_np = (img_np * 255.0).round()
+ img_np = img_np.astype(out_type)
+ result.append(img_np)
+ if len(result) == 1:
+ result = result[0]
+ return result
+
+
+def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
+ """This implementation is slightly faster than tensor2img.
+ It now only supports torch tensor with shape (1, c, h, w).
+
+ Args:
+ tensor (Tensor): Now only support torch tensor with (1, c, h, w).
+ rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
+ min_max (tuple[int]): min and max values for clamp.
+ """
+ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
+ output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
+ output = output.type(torch.uint8).cpu().numpy()
+ if rgb2bgr:
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+ return output
+
+
+def imfrombytes(content, flag='color', float32=False):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale` and `unchanged`.
+        float32 (bool): Whether to change to float32. If True, will also norm
+ to [0, 1]. Default: False.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+ img_np = np.frombuffer(content, np.uint8)
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
+ img = cv2.imdecode(img_np, imread_flags[flag])
+ if float32:
+ img = img.astype(np.float32) / 255.
+ return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+    Raises:
+        IOError: If the image fails to be written.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ ok = cv2.imwrite(file_path, img, params)
+ if not ok:
+ raise IOError('Failed in writing images.')
+
+
+def crop_border(imgs, crop_border):
+ """Crop borders of images.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+        crop_border (int): Crop border for each end of height and width.
+
+ Returns:
+ list[ndarray]: Cropped images.
+ """
+ if crop_border == 0:
+ return imgs
+ else:
+ if isinstance(imgs, list):
+ return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
+ else:
+ return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
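
A minimal round trip through img2tensor and tensor2img (illustrative): the pair converts an HWC BGR float image to a CHW RGB tensor and back to uint8 BGR.

    import numpy as np
    from r_basicsr.utils import img2tensor, tensor2img

    img = np.random.rand(32, 32, 3).astype(np.float32)  # HWC, BGR, [0, 1]
    t = img2tensor(img, bgr2rgb=True, float32=True)     # (3, 32, 32), RGB
    back = tensor2img(t, rgb2bgr=True, min_max=(0, 1))  # (32, 32, 3), BGR, uint8
    print(t.shape, back.shape, back.dtype)
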
diff --git a/r_basicsr/utils/lmdb_util.py b/r_basicsr/utils/lmdb_util.py
new file mode 100644
index 0000000..97774ca
--- /dev/null
+++ b/r_basicsr/utils/lmdb_util.py
@@ -0,0 +1,196 @@
+import cv2
+import lmdb
+import sys
+from multiprocessing import Pool
+from os import path as osp
+from tqdm import tqdm
+
+
+def make_lmdb_from_imgs(data_path,
+ lmdb_path,
+ img_path_list,
+ keys,
+ batch=5000,
+ compress_level=1,
+ multiprocessing_read=False,
+ n_thread=40,
+ map_size=None):
+ """Make lmdb from images.
+
+ Contents of lmdb. The file structure is:
+ example.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records 1)image name (with extension),
+ 2)image shape, and 3)compression level, separated by a white space.
+
+ For example, the meta information could be:
+ `000_00000000.png (720,1280,3) 1`, which means:
+ 1) image name (with extension): 000_00000000.png;
+ 2) image shape: (720,1280,3);
+ 3) compression level: 1
+
+ We use the image name without extension as the lmdb key.
+
+ If `multiprocessing_read` is True, it will read all the images to memory
+ using multiprocessing. Thus, your server needs to have enough memory.
+
+ Args:
+ data_path (str): Data path for reading images.
+ lmdb_path (str): Lmdb save path.
+ img_path_list (str): Image path list.
+ keys (str): Used for lmdb keys.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ multiprocessing_read (bool): Whether use multiprocessing to read all
+ the images to memory. Default: False.
+ n_thread (int): For multiprocessing.
+ map_size (int | None): Map size for lmdb env. If None, use the
+ estimated size from images. Default: None
+ """
+
+ assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
+ f'but got {len(img_path_list)} and {len(keys)}')
+ print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
+    print(f'Total images: {len(img_path_list)}')
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ if multiprocessing_read:
+ # read all the images to memory (multiprocessing)
+ dataset = {} # use dict to keep the order for multiprocessing
+ shapes = {}
+ print(f'Read images with multiprocessing, #thread: {n_thread} ...')
+ pbar = tqdm(total=len(img_path_list), unit='image')
+
+ def callback(arg):
+ """get the image data and update pbar."""
+ key, dataset[key], shapes[key] = arg
+ pbar.update(1)
+ pbar.set_description(f'Read {key}')
+
+ pool = Pool(n_thread)
+ for path, key in zip(img_path_list, keys):
+ pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
+ pool.close()
+ pool.join()
+ pbar.close()
+ print(f'Finish reading {len(img_path_list)} images.')
+
+ # create lmdb environment
+ if map_size is None:
+ # obtain data size for one image
+ img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ data_size_per_img = img_byte.nbytes
+ print('Data size per image is: ', data_size_per_img)
+ data_size = data_size_per_img * len(img_path_list)
+ map_size = data_size * 10
+
+ env = lmdb.open(lmdb_path, map_size=map_size)
+
+ # write data to lmdb
+ pbar = tqdm(total=len(img_path_list), unit='chunk')
+ txn = env.begin(write=True)
+ txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ for idx, (path, key) in enumerate(zip(img_path_list, keys)):
+ pbar.update(1)
+ pbar.set_description(f'Write {key}')
+ key_byte = key.encode('ascii')
+ if multiprocessing_read:
+ img_byte = dataset[key]
+ h, w, c = shapes[key]
+ else:
+ _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
+ h, w, c = img_shape
+
+ txn.put(key_byte, img_byte)
+ # write meta information
+ txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
+ if idx % batch == 0:
+ txn.commit()
+ txn = env.begin(write=True)
+ pbar.close()
+ txn.commit()
+ env.close()
+ txt_file.close()
+ print('\nFinish writing lmdb.')
+
+
+def read_img_worker(path, key, compress_level):
+ """Read image worker.
+
+ Args:
+ path (str): Image path.
+ key (str): Image key.
+ compress_level (int): Compress level when encoding images.
+
+ Returns:
+ str: Image key.
+ byte: Image byte.
+ tuple[int]: Image shape.
+ """
+
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ if img.ndim == 2:
+ h, w = img.shape
+ c = 1
+ else:
+ h, w, c = img.shape
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ return (key, img_byte, (h, w, c))
+
+
+class LmdbMaker():
+ """LMDB Maker.
+
+ Args:
+ lmdb_path (str): Lmdb save path.
+ map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ """
+
+ def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ self.lmdb_path = lmdb_path
+ self.batch = batch
+ self.compress_level = compress_level
+ self.env = lmdb.open(lmdb_path, map_size=map_size)
+ self.txn = self.env.begin(write=True)
+ self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ self.counter = 0
+
+ def put(self, img_byte, key, img_shape):
+ self.counter += 1
+ key_byte = key.encode('ascii')
+ self.txn.put(key_byte, img_byte)
+ # write meta information
+ h, w, c = img_shape
+ self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
+ if self.counter % self.batch == 0:
+ self.txn.commit()
+ self.txn = self.env.begin(write=True)
+
+ def close(self):
+ self.txn.commit()
+ self.env.close()
+ self.txt_file.close()
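
A minimal sketch of the incremental LmdbMaker API (illustrative; the path and key are hypothetical, and lmdb_path must end with '.lmdb' and not exist yet):

    import cv2
    import numpy as np
    from r_basicsr.utils.lmdb_util import LmdbMaker

    maker = LmdbMaker('demo.lmdb')
    img = np.zeros((8, 8, 3), dtype=np.uint8)
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, 1])
    maker.put(img_byte.tobytes(), '000_00000000', (8, 8, 3))
    maker.close()  # commits and writes meta_info.txt inside demo.lmdb
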
diff --git a/r_basicsr/utils/logger.py b/r_basicsr/utils/logger.py
new file mode 100644
index 0000000..2a8a868
--- /dev/null
+++ b/r_basicsr/utils/logger.py
@@ -0,0 +1,213 @@
+import datetime
+import logging
+import time
+
+from .dist_util import get_dist_info, master_only
+
+initialized_logger = {}
+
+
+class AvgTimer():
+
+ def __init__(self, window=200):
+ self.window = window # average window
+ self.current_time = 0
+ self.total_time = 0
+ self.count = 0
+ self.avg_time = 0
+ self.start()
+
+ def start(self):
+ self.start_time = self.tic = time.time()
+
+ def record(self):
+ self.count += 1
+ self.toc = time.time()
+ self.current_time = self.toc - self.tic
+ self.total_time += self.current_time
+ # calculate average time
+ self.avg_time = self.total_time / self.count
+
+ # reset
+ if self.count > self.window:
+ self.count = 0
+ self.total_time = 0
+
+ self.tic = time.time()
+
+ def get_current_time(self):
+ return self.current_time
+
+ def get_avg_time(self):
+ return self.avg_time
+
+
+class MessageLogger():
+ """Message logger for printing.
+
+ Args:
+ opt (dict): Config. It contains the following keys:
+ name (str): Exp name.
+ logger (dict): Contains 'print_freq' (str) for logger interval.
+ train (dict): Contains 'total_iter' (int) for total iters.
+ use_tb_logger (bool): Use tensorboard logger.
+ start_iter (int): Start iter. Default: 1.
+ tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
+ """
+
+ def __init__(self, opt, start_iter=1, tb_logger=None):
+ self.exp_name = opt['name']
+ self.interval = opt['logger']['print_freq']
+ self.start_iter = start_iter
+ self.max_iters = opt['train']['total_iter']
+ self.use_tb_logger = opt['logger']['use_tb_logger']
+ self.tb_logger = tb_logger
+ self.start_time = time.time()
+ self.logger = get_root_logger()
+
+ def reset_start_time(self):
+ self.start_time = time.time()
+
+ @master_only
+ def __call__(self, log_vars):
+ """Format logging message.
+
+ Args:
+ log_vars (dict): It contains the following keys:
+ epoch (int): Epoch number.
+ iter (int): Current iter.
+ lrs (list): List for learning rates.
+
+ time (float): Iter time.
+ data_time (float): Data time for each iter.
+ """
+ # epoch, iter, learning rates
+ epoch = log_vars.pop('epoch')
+ current_iter = log_vars.pop('iter')
+ lrs = log_vars.pop('lrs')
+
+ message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
+ for v in lrs:
+ message += f'{v:.3e},'
+ message += ')] '
+
+ # time and estimated time
+ if 'time' in log_vars.keys():
+ iter_time = log_vars.pop('time')
+ data_time = log_vars.pop('data_time')
+
+ total_time = time.time() - self.start_time
+ time_sec_avg = total_time / (current_iter - self.start_iter + 1)
+ eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+ message += f'[eta: {eta_str}, '
+ message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
+
+ # other items, especially losses
+ for k, v in log_vars.items():
+ message += f'{k}: {v:.4e} '
+ # tensorboard logger
+ if self.use_tb_logger and 'debug' not in self.exp_name:
+ if k.startswith('l_'):
+ self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
+ else:
+ self.tb_logger.add_scalar(k, v, current_iter)
+ self.logger.info(message)
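+
+
+# Usage sketch (editor's illustration) with a minimal, hypothetical config;
+# only the keys read by MessageLogger are filled in:
+#
+#     opt = {'name': 'demo_exp',
+#            'logger': {'print_freq': 100, 'use_tb_logger': False},
+#            'train': {'total_iter': 1000}}
+#     msg_logger = MessageLogger(opt, start_iter=1)
+#     msg_logger({'epoch': 1, 'iter': 100, 'lrs': [1e-4], 'l_pix': 0.01})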
+
+
+@master_only
+def init_tb_logger(log_dir):
+ from torch.utils.tensorboard import SummaryWriter
+ tb_logger = SummaryWriter(log_dir=log_dir)
+ return tb_logger
+
+
+@master_only
+def init_wandb_logger(opt):
+ """We now only use wandb to sync tensorboard log."""
+ import wandb
+ logger = get_root_logger()
+
+ project = opt['logger']['wandb']['project']
+ resume_id = opt['logger']['wandb'].get('resume_id')
+ if resume_id:
+ wandb_id = resume_id
+ resume = 'allow'
+ logger.warning(f'Resume wandb logger with id={wandb_id}.')
+ else:
+ wandb_id = wandb.util.generate_id()
+ resume = 'never'
+
+ wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
+
+ logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
+
+
+def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
+ """Get the root logger.
+
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added.
+
+ Args:
+        logger_name (str): Root logger name. Default: 'basicsr'.
+        log_level (int): The root logger level. Note that only the process of
+            rank 0 is affected, while other processes set the level to
+            "ERROR" and stay silent most of the time.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the root logger.
+
+ Returns:
+ logging.Logger: The root logger.
+ """
+ logger = logging.getLogger(logger_name)
+ # if the logger has been initialized, just return it
+ if logger_name in initialized_logger:
+ return logger
+
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(logging.Formatter(format_str))
+ logger.addHandler(stream_handler)
+ logger.propagate = False
+ rank, _ = get_dist_info()
+ if rank != 0:
+ logger.setLevel('ERROR')
+ elif log_file is not None:
+ logger.setLevel(log_level)
+ # add file handler
+ file_handler = logging.FileHandler(log_file, 'w')
+ file_handler.setFormatter(logging.Formatter(format_str))
+ file_handler.setLevel(log_level)
+ logger.addHandler(file_handler)
+ initialized_logger[logger_name] = True
+ return logger
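+
+
+# Usage sketch (editor's illustration; 'train.log' is a placeholder path):
+#
+#     logger = get_root_logger(log_level=logging.INFO, log_file='train.log')
+#     logger.info('Training started.')
+#     assert get_root_logger() is logger  # later calls reuse the same logger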
+
+
+def get_env_info():
+ """Get environment information.
+
+    Currently, it only logs software versions.
+ """
+ import torch
+ import torchvision
+
+ from r_basicsr.version import __version__
+ msg = r"""
+ ____ _ _____ ____
+ / __ ) ____ _ _____ (_)_____/ ___/ / __ \
+ / __ |/ __ `// ___// // ___/\__ \ / /_/ /
+ / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
+ /_____/ \__,_//____//_/ \___//____//_/ |_|
+ ______ __ __ __ __
+ / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
+ / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
+ / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
+ \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
+ """
+ msg += ('\nVersion Information: '
+ f'\n\tBasicSR: {__version__}'
+ f'\n\tPyTorch: {torch.__version__}'
+ f'\n\tTorchVision: {torchvision.__version__}')
+ return msg
diff --git a/r_basicsr/utils/matlab_functions.py b/r_basicsr/utils/matlab_functions.py
new file mode 100644
index 0000000..6d0b8cd
--- /dev/null
+++ b/r_basicsr/utils/matlab_functions.py
@@ -0,0 +1,178 @@
+import math
+import numpy as np
+import torch
+
+
+def cubic(x):
+ """cubic function used for calculate_weights_indices."""
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (
+ (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
+ (absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ """Calculate weights and indices, used for imresize function.
+
+    Args:
+        in_length (int): Input length.
+        out_length (int): Output length.
+        scale (float): Scale factor.
+        kernel (str): Kernel type. Currently only 'cubic' is used.
+        kernel_width (int): Kernel width.
+        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+
+    Returns:
+        tuple: Interpolation weights, input indices, and the symmetric padding
+            lengths needed at the start and end (sym_len_s, sym_len_e).
+    """
+
+ if (scale < 1) and antialiasing:
+ # Use a modified kernel (larger kernel width) to simultaneously
+ # interpolate and antialias
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ p = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
+ out_length, p)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
+
+ # apply cubic kernel
+ if (scale < 1) and antialiasing:
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, p)
+
+ # If a column in weights is all zero, get rid of it. only consider the
+ # first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, p - 2)
+ weights = weights.narrow(1, 1, p - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, p - 2)
+ weights = weights.narrow(1, 0, p - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
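+
+
+# Usage sketch (editor's illustration): weights and indices for a 2x bicubic
+# downscale of a length-8 signal. Each of the 4 output pixels gets one row of
+# interpolation weights plus the matching input indices.
+#
+#     weights, indices, sym_s, sym_e = calculate_weights_indices(
+#         in_length=8, out_length=4, scale=0.5, kernel='cubic',
+#         kernel_width=4, antialiasing=True)
+#     print(weights.shape, indices.shape, sym_s, sym_e)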
+
+
+@torch.no_grad()
+def imresize(img, scale, antialiasing=True):
+ """imresize function same as MATLAB.
+
+ It now only supports bicubic.
+ The same scale applies for both height and width.
+
+ Args:
+ img (Tensor | Numpy array):
+ Tensor: Input image with shape (c, h, w), [0, 1] range.
+ Numpy: Input image with shape (h, w, c), [0, 1] range.
+ scale (float): Scale factor. The same scale applies for both height
+ and width.
+        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+            Default: True.
+
+    Returns:
+        Tensor: Output image with shape (c, h, w), [0, 1] range, without rounding.
+ """
+ squeeze_flag = False
+ if type(img).__module__ == np.__name__: # numpy type
+ numpy_type = True
+ if img.ndim == 2:
+ img = img[:, :, None]
+ squeeze_flag = True
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
+ else:
+ numpy_type = False
+ if img.ndim == 2:
+ img = img.unsqueeze(0)
+ squeeze_flag = True
+
+ in_c, in_h, in_w = img.size()
+ out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # get weights and indices
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
+ antialiasing)
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
+ antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
+
+ sym_patch = img[:, :sym_len_hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_he:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
+ kernel_width = weights_h.size(1)
+ for i in range(out_h):
+ idx = int(indices_h[i][0])
+ for j in range(in_c):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_we:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
+ kernel_width = weights_w.size(1)
+ for i in range(out_w):
+ idx = int(indices_w[i][0])
+ for j in range(in_c):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
+
+ if squeeze_flag:
+ out_2 = out_2.squeeze(0)
+ if numpy_type:
+ out_2 = out_2.numpy()
+ if not squeeze_flag:
+ out_2 = out_2.transpose(1, 2, 0)
+
+ return out_2
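+
+
+# Usage sketch (editor's illustration): 0.5x bicubic resize of a random
+# (h, w, c) image, mirroring MATLAB's imresize semantics.
+#
+#     img = np.random.rand(32, 32, 3).astype(np.float32)
+#     out = imresize(img, 0.5)  # ndarray in, ndarray out: shape (16, 16, 3)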
diff --git a/r_basicsr/utils/misc.py b/r_basicsr/utils/misc.py
new file mode 100644
index 0000000..a43f878
--- /dev/null
+++ b/r_basicsr/utils/misc.py
@@ -0,0 +1,141 @@
+import numpy as np
+import os
+import random
+import time
+import torch
+from os import path as osp
+
+from .dist_util import master_only
+
+
+def set_random_seed(seed):
+ """Set random seeds."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def mkdir_and_rename(path):
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
+
+ Args:
+ path (str): Folder path.
+ """
+ if osp.exists(path):
+ new_name = path + '_archived_' + get_time_str()
+        print(f'Path already exists. Renaming it to {new_name}', flush=True)
+ os.rename(path, new_name)
+ os.makedirs(path, exist_ok=True)
+
+
+@master_only
+def make_exp_dirs(opt):
+ """Make dirs for experiments."""
+ path_opt = opt['path'].copy()
+ if opt['is_train']:
+ mkdir_and_rename(path_opt.pop('experiments_root'))
+ else:
+ mkdir_and_rename(path_opt.pop('results_root'))
+ for key, path in path_opt.items():
+ if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
+ continue
+ else:
+ os.makedirs(path, exist_ok=True)
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
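+
+
+# Usage sketch (editor's illustration; 'datasets/gt' is a placeholder folder):
+#
+#     for img_path in scandir('datasets/gt', suffix=('.png', '.jpg'), recursive=True):
+#         print(img_path)  # paths relative to 'datasets/gt'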
+
+
+def check_resume(opt, resume_iter):
+ """Check resume states and pretrain_network paths.
+
+ Args:
+ opt (dict): Options.
+ resume_iter (int): Resume iteration.
+ """
+ if opt['path']['resume_state']:
+ # get all the networks
+ networks = [key for key in opt.keys() if key.startswith('network_')]
+ flag_pretrain = False
+ for network in networks:
+ if opt['path'].get(f'pretrain_{network}') is not None:
+ flag_pretrain = True
+ if flag_pretrain:
+ print('pretrain_network path will be ignored during resuming.')
+ # set pretrained model paths
+ for network in networks:
+ name = f'pretrain_{network}'
+ basename = network.replace('network_', '')
+ if opt['path'].get('ignore_resume_networks') is None or (network
+ not in opt['path']['ignore_resume_networks']):
+ opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
+ print(f"Set {name} to {opt['path'][name]}")
+
+ # change param_key to params in resume
+ param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
+ for param_key in param_keys:
+ if opt['path'][param_key] == 'params_ema':
+ opt['path'][param_key] = 'params'
+ print(f'Set {param_key} to params')
+
+
+def sizeof_fmt(size, suffix='B'):
+ """Get human readable file size.
+
+ Args:
+ size (int): File size.
+ suffix (str): Suffix. Default: 'B'.
+
+ Return:
+ str: Formatted file size.
+ """
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+ if abs(size) < 1024.0:
+ return f'{size:3.1f} {unit}{suffix}'
+ size /= 1024.0
+ return f'{size:3.1f} Y{suffix}'
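+
+
+# Behaviour sketch (editor's illustration):
+#
+#     sizeof_fmt(1023)       # -> '1023.0 B'
+#     sizeof_fmt(123456789)  # -> '117.7 MB'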
diff --git a/r_basicsr/utils/options.py b/r_basicsr/utils/options.py
new file mode 100644
index 0000000..cf90df4
--- /dev/null
+++ b/r_basicsr/utils/options.py
@@ -0,0 +1,194 @@
+import argparse
+import random
+import torch
+import yaml
+from collections import OrderedDict
+from os import path as osp
+
+from r_basicsr.utils import set_random_seed
+from r_basicsr.utils.dist_util import get_dist_info, init_dist, master_only
+
+
+def ordered_yaml():
+ """Support OrderedDict for yaml.
+
+ Returns:
+ yaml Loader and Dumper.
+ """
+ try:
+ from yaml import CDumper as Dumper
+ from yaml import CLoader as Loader
+ except ImportError:
+ from yaml import Dumper, Loader
+
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+ def dict_representer(dumper, data):
+ return dumper.represent_dict(data.items())
+
+ def dict_constructor(loader, node):
+ return OrderedDict(loader.construct_pairs(node))
+
+ Dumper.add_representer(OrderedDict, dict_representer)
+ Loader.add_constructor(_mapping_tag, dict_constructor)
+ return Loader, Dumper
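+
+
+# Usage sketch (editor's illustration): load a config string while preserving
+# key order.
+#
+#     Loader, Dumper = ordered_yaml()
+#     opt = yaml.load('name: demo\nscale: 4\n', Loader=Loader)
+#     # -> OrderedDict([('name', 'demo'), ('scale', 4)])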
+
+
+def dict2str(opt, indent_level=1):
+ """dict to string for printing options.
+
+ Args:
+ opt (dict): Option dict.
+ indent_level (int): Indent level. Default: 1.
+
+ Return:
+ (str): Option string for printing.
+ """
+ msg = '\n'
+ for k, v in opt.items():
+ if isinstance(v, dict):
+ msg += ' ' * (indent_level * 2) + k + ':['
+ msg += dict2str(v, indent_level + 1)
+ msg += ' ' * (indent_level * 2) + ']\n'
+ else:
+ msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
+ return msg
+
+
+def _postprocess_yml_value(value):
+ # None
+ if value == '~' or value.lower() == 'none':
+ return None
+ # bool
+ if value.lower() == 'true':
+ return True
+ elif value.lower() == 'false':
+ return False
+ # !!float number
+ if value.startswith('!!float'):
+ return float(value.replace('!!float', ''))
+ # number
+ if value.isdigit():
+ return int(value)
+ elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
+ return float(value)
+ # list
+ if value.startswith('['):
+ return eval(value)
+ # str
+ return value
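+
+
+# Behaviour sketch (editor's illustration) of the coercion applied to
+# --force_yml values:
+#
+#     _postprocess_yml_value('none')          # -> None
+#     _postprocess_yml_value('true')          # -> True
+#     _postprocess_yml_value('!!float 1e-4')  # -> 0.0001
+#     _postprocess_yml_value('250000')        # -> 250000
+#     _postprocess_yml_value('[1, 2, 3]')     # -> [1, 2, 3]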
+
+
+def parse_options(root_path, is_train=True):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
+ parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
+ parser.add_argument('--auto_resume', action='store_true')
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--local_rank', type=int, default=0)
+ parser.add_argument(
+ '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
+ args = parser.parse_args()
+
+ # parse yml to dict
+ with open(args.opt, mode='r') as f:
+ opt = yaml.load(f, Loader=ordered_yaml()[0])
+
+ # distributed settings
+ if args.launcher == 'none':
+ opt['dist'] = False
+ print('Disable distributed.', flush=True)
+ else:
+ opt['dist'] = True
+ if args.launcher == 'slurm' and 'dist_params' in opt:
+ init_dist(args.launcher, **opt['dist_params'])
+ else:
+ init_dist(args.launcher)
+ opt['rank'], opt['world_size'] = get_dist_info()
+
+ # random seed
+ seed = opt.get('manual_seed')
+ if seed is None:
+ seed = random.randint(1, 10000)
+ opt['manual_seed'] = seed
+ set_random_seed(seed + opt['rank'])
+
+ # force to update yml options
+ if args.force_yml is not None:
+ for entry in args.force_yml:
+            # creating new keys is not supported for now
+ keys, value = entry.split('=')
+ keys, value = keys.strip(), value.strip()
+ value = _postprocess_yml_value(value)
+ eval_str = 'opt'
+ for key in keys.split(':'):
+ eval_str += f'["{key}"]'
+ eval_str += '=value'
+ # using exec function
+ exec(eval_str)
+
+ opt['auto_resume'] = args.auto_resume
+ opt['is_train'] = is_train
+
+ # debug setting
+ if args.debug and not opt['name'].startswith('debug'):
+ opt['name'] = 'debug_' + opt['name']
+
+ if opt['num_gpu'] == 'auto':
+ opt['num_gpu'] = torch.cuda.device_count()
+
+ # datasets
+ for phase, dataset in opt['datasets'].items():
+ # for multiple datasets, e.g., val_1, val_2; test_1, test_2
+ phase = phase.split('_')[0]
+ dataset['phase'] = phase
+ if 'scale' in opt:
+ dataset['scale'] = opt['scale']
+ if dataset.get('dataroot_gt') is not None:
+ dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
+ if dataset.get('dataroot_lq') is not None:
+ dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
+
+ # paths
+ for key, val in opt['path'].items():
+ if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
+ opt['path'][key] = osp.expanduser(val)
+
+ if is_train:
+ experiments_root = osp.join(root_path, 'experiments', opt['name'])
+ opt['path']['experiments_root'] = experiments_root
+ opt['path']['models'] = osp.join(experiments_root, 'models')
+ opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
+ opt['path']['log'] = experiments_root
+ opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
+
+ # change some options for debug mode
+ if 'debug' in opt['name']:
+ if 'val' in opt:
+ opt['val']['val_freq'] = 8
+ opt['logger']['print_freq'] = 1
+ opt['logger']['save_checkpoint_freq'] = 8
+ else: # test
+ results_root = osp.join(root_path, 'results', opt['name'])
+ opt['path']['results_root'] = results_root
+ opt['path']['log'] = results_root
+ opt['path']['visualization'] = osp.join(results_root, 'visualization')
+
+ return opt, args
+
+
+@master_only
+def copy_opt_file(opt_file, experiments_root):
+ # copy the yml file to the experiment root
+ import sys
+ import time
+ from shutil import copyfile
+ cmd = ' '.join(sys.argv)
+ filename = osp.join(experiments_root, osp.basename(opt_file))
+ copyfile(opt_file, filename)
+
+ with open(filename, 'r+') as f:
+ lines = f.readlines()
+ lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
+ f.seek(0)
+ f.writelines(lines)
diff --git a/r_basicsr/utils/plot_util.py b/r_basicsr/utils/plot_util.py
new file mode 100644
index 0000000..3bd950c
--- /dev/null
+++ b/r_basicsr/utils/plot_util.py
@@ -0,0 +1,84 @@
+import re
+
+
+def read_data_from_tensorboard(log_path, tag):
+ """Get raw data (steps and values) from tensorboard events.
+
+ Args:
+ log_path (str): Path to the tensorboard log.
+ tag (str): tag to be read.
+ """
+ from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+
+ # tensorboard event
+ event_acc = EventAccumulator(log_path)
+ event_acc.Reload()
+ scalar_list = event_acc.Tags()['scalars']
+ print('tag list: ', scalar_list)
+ steps = [int(s.step) for s in event_acc.Scalars(tag)]
+ values = [s.value for s in event_acc.Scalars(tag)]
+ return steps, values
+
+
+def read_data_from_txt_2v(path, pattern, step_one=False):
+ """Read data from txt with 2 returned values (usually [step, value]).
+
+ Args:
+ path (str): path to the txt file.
+ pattern (str): re (regular expression) pattern.
+ step_one (bool): add 1 to steps. Default: False.
+ """
+ with open(path) as f:
+ lines = f.readlines()
+ lines = [line.strip() for line in lines]
+ steps = []
+ values = []
+
+ pattern = re.compile(pattern)
+ for line in lines:
+ match = pattern.match(line)
+ if match:
+ steps.append(int(match.group(1)))
+ values.append(float(match.group(2)))
+ if step_one:
+ steps = [v + 1 for v in steps]
+ return steps, values
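+
+
+# Usage sketch (editor's illustration): pull (iter, psnr) pairs out of a log
+# whose lines look like 'iter: 1000, psnr: 28.53'; path and pattern below are
+# hypothetical.
+#
+#     steps, values = read_data_from_txt_2v(
+#         'train.log', r'iter: (\d+), psnr: (\d+\.\d+)')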
+
+
+def read_data_from_txt_1v(path, pattern):
+ """Read data from txt with 1 returned values.
+
+ Args:
+ path (str): path to the txt file.
+ pattern (str): re (regular expression) pattern.
+ """
+ with open(path) as f:
+ lines = f.readlines()
+ lines = [line.strip() for line in lines]
+ data = []
+
+ pattern = re.compile(pattern)
+ for line in lines:
+ match = pattern.match(line)
+ if match:
+ data.append(float(match.group(1)))
+ return data
+
+
+def smooth_data(values, smooth_weight):
+ """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does).
+
+ Ref: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/\
+ tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704
+
+ Args:
+ values (list): A list of values to be smoothed.
+ smooth_weight (float): Smooth weight.
+ """
+ values_sm = []
+ last_sm_value = values[0]
+ for value in values:
+ value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value
+ values_sm.append(value_sm)
+ last_sm_value = value_sm
+ return values_sm
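+
+
+# Behaviour sketch (editor's illustration): higher weights give smoother (and
+# laggier) curves, matching TensorBoard's smoothing slider.
+#
+#     smooth_data([0.0, 1.0, 0.0, 1.0, 0.0], smooth_weight=0.6)
+#     # -> [0.0, 0.4, 0.24, 0.544, 0.3264]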
diff --git a/r_basicsr/utils/registry.py b/r_basicsr/utils/registry.py
new file mode 100644
index 0000000..1745e94
--- /dev/null
+++ b/r_basicsr/utils/registry.py
@@ -0,0 +1,88 @@
+# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
+
+
+class Registry():
+ """
+ The registry that provides name -> object mapping, to support third-party
+ users' custom modules.
+
+ To create a registry (e.g. a backbone registry):
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY = Registry('BACKBONE')
+
+ To register an object:
+
+ .. code-block:: python
+
+ @BACKBONE_REGISTRY.register()
+ class MyBackbone():
+ ...
+
+ Or:
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY.register(MyBackbone)
+ """
+
+ def __init__(self, name):
+ """
+ Args:
+ name (str): the name of this registry
+ """
+ self._name = name
+ self._obj_map = {}
+
+ def _do_register(self, name, obj, suffix=None):
+ if isinstance(suffix, str):
+ name = name + '_' + suffix
+
+ assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
+ f"in '{self._name}' registry!")
+ self._obj_map[name] = obj
+
+ def register(self, obj=None, suffix=None):
+ """
+        Register the given object under the name `obj.__name__`.
+        Can be used either as a decorator or as a plain function call.
+        See the docstring of this class for usage.
+ """
+ if obj is None:
+ # used as a decorator
+ def deco(func_or_class):
+ name = func_or_class.__name__
+ self._do_register(name, func_or_class, suffix)
+ return func_or_class
+
+ return deco
+
+ # used as a function call
+ name = obj.__name__
+ self._do_register(name, obj, suffix)
+
+ def get(self, name, suffix='basicsr'):
+ ret = self._obj_map.get(name)
+ if ret is None:
+ ret = self._obj_map.get(name + '_' + suffix)
+            print(f'Name {name} is not found, falling back to {name}_{suffix}!')
+ if ret is None:
+ raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
+ return ret
+
+ def __contains__(self, name):
+ return name in self._obj_map
+
+ def __iter__(self):
+ return iter(self._obj_map.items())
+
+ def keys(self):
+ return self._obj_map.keys()
+
+
+DATASET_REGISTRY = Registry('dataset')
+ARCH_REGISTRY = Registry('arch')
+MODEL_REGISTRY = Registry('model')
+LOSS_REGISTRY = Registry('loss')
+METRIC_REGISTRY = Registry('metric')
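+
+
+# Usage sketch (editor's illustration; `MyModel` is a hypothetical class):
+#
+#     @MODEL_REGISTRY.register()
+#     class MyModel():
+#         pass
+#
+#     model_cls = MODEL_REGISTRY.get('MyModel')  # returns the class by name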