author | Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> | 2025-01-17 11:00:30 +0000
committer | GitHub <noreply@github.com> | 2025-01-17 11:00:30 +0000
commit | 495ffc4777522e40941753e3b1b79c02f84b25b4 (patch)
tree | 5130fcb8676afdcb619a5e5eaef3ac28e135bc08 /r_basicsr/data
parent | febd45814cd41560c5247aacb111d8d013f3a303 (diff)
download | Comfyui-reactor-node-495ffc4777522e40941753e3b1b79c02f84b25b4.tar.gz
Add files via upload
Diffstat (limited to 'r_basicsr/data')
-rw-r--r-- | r_basicsr/data/__init__.py | 101
-rw-r--r-- | r_basicsr/data/data_sampler.py | 48
-rw-r--r-- | r_basicsr/data/data_util.py | 313
-rw-r--r-- | r_basicsr/data/degradations.py | 768
-rw-r--r-- | r_basicsr/data/ffhq_dataset.py | 80
-rw-r--r-- | r_basicsr/data/paired_image_dataset.py | 108
-rw-r--r-- | r_basicsr/data/prefetch_dataloader.py | 125
-rw-r--r-- | r_basicsr/data/realesrgan_dataset.py | 193
-rw-r--r-- | r_basicsr/data/realesrgan_paired_dataset.py | 109
-rw-r--r-- | r_basicsr/data/reds_dataset.py | 360
-rw-r--r-- | r_basicsr/data/single_image_dataset.py | 68
-rw-r--r-- | r_basicsr/data/transforms.py | 179
-rw-r--r-- | r_basicsr/data/video_test_dataset.py | 287
-rw-r--r-- | r_basicsr/data/vimeo90k_dataset.py | 192
14 files changed, 2931 insertions, 0 deletions
diff --git a/r_basicsr/data/__init__.py b/r_basicsr/data/__init__.py
new file mode 100644
index 0000000..c8ae411
--- /dev/null
+++ b/r_basicsr/data/__init__.py
@@ -0,0 +1,101 @@
+import importlib
+import numpy as np
+import random
+import torch
+import torch.utils.data
+from copy import deepcopy
+from functools import partial
+from os import path as osp
+
+from r_basicsr.data.prefetch_dataloader import PrefetchDataLoader
+from r_basicsr.utils import get_root_logger, scandir
+from r_basicsr.utils.dist_util import get_dist_info
+from r_basicsr.utils.registry import DATASET_REGISTRY
+
+__all__ = ['build_dataset', 'build_dataloader']
+
+# automatically scan and import dataset modules for registry
+# scan all the files under the data folder with '_dataset' in file names
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'r_basicsr.data.{file_name}') for file_name in dataset_filenames]
+
+
+def build_dataset(dataset_opt):
+ """Build dataset from options.
+
+ Args:
+ dataset_opt (dict): Configuration for dataset. It must contain:
+ name (str): Dataset name.
+ type (str): Dataset type.
+ """
+ dataset_opt = deepcopy(dataset_opt)
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
+ logger = get_root_logger()
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
+ return dataset
+
+
+def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
+ """Build dataloader.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset.
+ dataset_opt (dict): Dataset options. It contains the following keys:
+ phase (str): 'train' or 'val'.
+ num_worker_per_gpu (int): Number of workers for each GPU.
+ batch_size_per_gpu (int): Training batch size for each GPU.
+ num_gpu (int): Number of GPUs. Used only in the train phase.
+ Default: 1.
+ dist (bool): Whether in distributed training. Used only in the train
+ phase. Default: False.
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
+ seed (int | None): Seed. Default: None
+ """
+ phase = dataset_opt['phase']
+ rank, _ = get_dist_info()
+ if phase == 'train':
+ if dist: # distributed training
+ batch_size = dataset_opt['batch_size_per_gpu']
+ num_workers = dataset_opt['num_worker_per_gpu']
+ else: # non-distributed training
+ multiplier = 1 if num_gpu == 0 else num_gpu
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+ dataloader_args = dict(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ sampler=sampler,
+ drop_last=True)
+ if sampler is None:
+ dataloader_args['shuffle'] = True
+ dataloader_args['worker_init_fn'] = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+ elif phase in ['val', 'test']: # validation
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+ else:
+ raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
+
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+ dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
+
+ prefetch_mode = dataset_opt.get('prefetch_mode')
+ if prefetch_mode == 'cpu': # CPUPrefetcher
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+ logger = get_root_logger()
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+ else:
+ # prefetch_mode=None: Normal dataloader
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+ return torch.utils.data.DataLoader(**dataloader_args)
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ # Set the worker seed to num_workers * rank + worker_id + seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
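+
+
+# Illustrative usage sketch (not part of the upstream file). The option values
+# below are hypothetical; 'type' must name a dataset registered in
+# DATASET_REGISTRY (e.g. SingleImageDataset from this package).
+def _demo_build_pipeline():
+    dataset_opt = {
+        'name': 'demo',
+        'type': 'SingleImageDataset',
+        'phase': 'val',
+        'dataroot_lq': 'datasets/demo',
+        'io_backend': {'type': 'disk'},
+    }
+    dataset = build_dataset(dataset_opt)
+    # phase 'val' yields a serial dataloader with batch_size=1 and num_workers=0
+    return build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, seed=0)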
diff --git a/r_basicsr/data/data_sampler.py b/r_basicsr/data/data_sampler.py
new file mode 100644
index 0000000..5135c7f
--- /dev/null
+++ b/r_basicsr/data/data_sampler.py
@@ -0,0 +1,48 @@
+import math
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EnlargedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ Modified from torch.utils.data.distributed.DistributedSampler
+    Supports enlarging the dataset for iteration-based training, saving
+    time when restarting the dataloader after each epoch.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
+ num_replicas (int | None): Number of processes participating in
+ the training. It is usually the world_size.
+ rank (int | None): Rank of the current process within num_replicas.
+ ratio (int): Enlarging ratio. Default: 1.
+ """
+
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(self.total_size, generator=g).tolist()
+
+ dataset_size = len(self.dataset)
+ indices = [v % dataset_size for v in indices]
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
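+
+
+# Illustrative sketch (not part of the upstream file): with ratio > 1 the sampler
+# yields ratio * len(dataset) / num_replicas indices per epoch, wrapping indices
+# around the dataset, so iteration-based training can run long "epochs" without
+# recreating the dataloader. Call set_epoch() each epoch to reshuffle.
+def _demo_enlarged_sampler(dataset):
+    sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=100)
+    sampler.set_epoch(0)
+    return list(sampler)  # 100x the dataset length, deterministic for this epoch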
diff --git a/r_basicsr/data/data_util.py b/r_basicsr/data/data_util.py
new file mode 100644
index 0000000..244e5ba
--- /dev/null
+++ b/r_basicsr/data/data_util.py
@@ -0,0 +1,313 @@
+import cv2
+import numpy as np
+import torch
+from os import path as osp
+from torch.nn import functional as F
+
+from r_basicsr.data.transforms import mod_crop
+from r_basicsr.utils import img2tensor, scandir
+
+
+def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
+ """Read a sequence of images from a given folder path.
+
+ Args:
+ path (list[str] | str): List of image paths or image folder path.
+ require_mod_crop (bool): Require mod crop for each image.
+ Default: False.
+ scale (int): Scale factor for mod_crop. Default: 1.
+        return_imgname (bool): Whether to return image names. Default: False.
+
+ Returns:
+ Tensor: size (t, c, h, w), RGB, [0, 1].
+ list[str]: Returned image name list.
+ """
+ if isinstance(path, list):
+ img_paths = path
+ else:
+ img_paths = sorted(list(scandir(path, full_path=True)))
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+
+ if require_mod_crop:
+ imgs = [mod_crop(img, scale) for img in imgs]
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+ imgs = torch.stack(imgs, dim=0)
+
+ if return_imgname:
+ imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
+ return imgs, imgnames
+ else:
+ return imgs
+
+
+def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
+ """Generate an index list for reading `num_frames` frames from a sequence
+ of images.
+
+ Args:
+ crt_idx (int): Current center index.
+        max_frame_num (int): Maximum number of frames in the sequence (counting from 1).
+ num_frames (int): Reading num_frames frames.
+ padding (str): Padding mode, one of
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+        Examples: crt_idx = 0, num_frames = 5
+ The generated frame indices under different padding mode:
+ replicate: [0, 0, 0, 1, 2]
+ reflection: [2, 1, 0, 1, 2]
+ reflection_circle: [4, 3, 0, 1, 2]
+ circle: [3, 4, 0, 1, 2]
+
+ Returns:
+ list[int]: A list of indices.
+ """
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
+
+ max_frame_num = max_frame_num - 1 # start from 0
+ num_pad = num_frames // 2
+
+ indices = []
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+ if i < 0:
+ if padding == 'replicate':
+ pad_idx = 0
+ elif padding == 'reflection':
+ pad_idx = -i
+ elif padding == 'reflection_circle':
+ pad_idx = crt_idx + num_pad - i
+ else:
+ pad_idx = num_frames + i
+ elif i > max_frame_num:
+ if padding == 'replicate':
+ pad_idx = max_frame_num
+ elif padding == 'reflection':
+ pad_idx = max_frame_num * 2 - i
+ elif padding == 'reflection_circle':
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+ else:
+ pad_idx = i - num_frames
+ else:
+ pad_idx = i
+ indices.append(pad_idx)
+ return indices
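+
+
+# Executable sanity check of the docstring examples above (illustrative, not
+# part of the upstream file):
+def _demo_generate_frame_indices():
+    assert generate_frame_indices(0, 100, 5, padding='replicate') == [0, 0, 0, 1, 2]
+    assert generate_frame_indices(0, 100, 5, padding='reflection') == [2, 1, 0, 1, 2]
+    assert generate_frame_indices(0, 100, 5, padding='reflection_circle') == [4, 3, 0, 1, 2]
+    assert generate_frame_indices(0, 100, 5, padding='circle') == [3, 4, 0, 1, 2]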
+
+
+def paired_paths_from_lmdb(folders, keys):
+ """Generate paired paths from lmdb files.
+
+ Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
+
+ lq.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records
+    1) image name (with extension),
+    2) image shape,
+    3) compression level, separated by a white space.
+ Example: `baboon.png (120,125,3) 1`
+
+ We use the image name without extension as the lmdb key.
+ Note that we use the same key for the corresponding lq and gt images.
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ Note that this key is different from lmdb keys.
+
+ Returns:
+        list[dict]: Returned path list; each dict contains the input and gt paths.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+        raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
+                         f'format. But received {input_key}: {input_folder}; '
+                         f'{gt_key}: {gt_folder}')
+ # ensure that the two meta_info files are the same
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+ else:
+ paths = []
+ for lmdb_key in sorted(input_lmdb_keys):
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
+ return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+ """Generate paired paths from an meta information file.
+
+ Each line in the meta information file contains the image names and
+ image shape (usually for gt), separated by a white space.
+
+    Example of a meta information file:
+ ```
+ 0001_s001.png (480,480,3)
+ 0001_s002.png (480,480,3)
+ ```
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ meta_info_file (str): Path to the meta information file.
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+        list[dict]: Returned path list; each dict contains the input and gt paths.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ with open(meta_info_file, 'r') as fin:
+ gt_names = [line.strip().split(' ')[0] for line in fin]
+
+ paths = []
+ for gt_name in gt_names:
+ basename, ext = osp.splitext(osp.basename(gt_name))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ gt_path = osp.join(gt_folder, gt_name)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+ """Generate paired paths from folders.
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+        list[dict]: Returned path list; each dict contains the input and gt paths.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ input_paths = list(scandir(input_folder))
+ gt_paths = list(scandir(gt_folder))
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
+ f'{len(input_paths)}, {len(gt_paths)}.')
+ paths = []
+ for gt_path in gt_paths:
+ basename, ext = osp.splitext(osp.basename(gt_path))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
+ gt_path = osp.join(gt_folder, gt_path)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
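+
+
+# Illustrative note (not part of the upstream file): filename_tmpl is applied to
+# the gt basename to derive the input filename, e.g. template '{}x4' pairs
+# 'baboon.png' (gt) with 'baboonx4.png' (input).
+def _demo_filename_tmpl():
+    filename_tmpl = '{}x4'
+    assert f"{filename_tmpl.format('baboon')}.png" == 'baboonx4.png'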
+
+
+def paths_from_folder(folder):
+ """Generate paths from folder.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+
+ paths = list(scandir(folder))
+ paths = [osp.join(folder, path) for path in paths]
+ return paths
+
+
+def paths_from_lmdb(folder):
+ """Generate paths from lmdb.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ if not folder.endswith('.lmdb'):
+        raise ValueError(f'Folder {folder} should be in lmdb format.')
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
+ paths = [line.split('.')[0] for line in fin]
+ return paths
+
+
+def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+ """Generate Gaussian kernel used in `duf_downsample`.
+
+ Args:
+ kernel_size (int): Kernel size. Default: 13.
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+ Returns:
+ np.array: The Gaussian kernel.
+ """
+    from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated
+ kernel = np.zeros((kernel_size, kernel_size))
+ # set element at the middle to one, a dirac delta
+ kernel[kernel_size // 2, kernel_size // 2] = 1
+ # gaussian-smooth the dirac, resulting in a gaussian filter
+    return gaussian_filter(kernel, sigma)
+
+
+def duf_downsample(x, kernel_size=13, scale=4):
+ """Downsamping with Gaussian kernel used in the DUF official code.
+
+ Args:
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+ kernel_size (int): Kernel size. Default: 13.
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+ Default: 4.
+
+ Returns:
+ Tensor: DUF downsampled frames.
+ """
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
+
+ squeeze_flag = False
+ if x.ndim == 4:
+ squeeze_flag = True
+ x = x.unsqueeze(0)
+ b, t, c, h, w = x.size()
+ x = x.view(-1, 1, h, w)
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+ x = F.conv2d(x, gaussian_filter, stride=scale)
+ x = x[:, :, 2:-2, 2:-2]
+ x = x.view(b, t, c, x.size(2), x.size(3))
+ if squeeze_flag:
+ x = x.squeeze(0)
+ return x
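+
+
+# Illustrative sketch (not part of the upstream file): downsample a random
+# 5-frame clip by 4x; a batch dimension is added and removed internally.
+def _demo_duf_downsample():
+    x = torch.rand(5, 3, 64, 64)  # (t, c, h, w)
+    y = duf_downsample(x, kernel_size=13, scale=4)
+    assert y.shape == (5, 3, 16, 16)
+    return y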
diff --git a/r_basicsr/data/degradations.py b/r_basicsr/data/degradations.py
new file mode 100644
index 0000000..c52cd91
--- /dev/null
+++ b/r_basicsr/data/degradations.py
@@ -0,0 +1,768 @@
+import cv2
+import math
+import numpy as np
+import random
+import torch
+from scipy import special
+from scipy.stats import multivariate_normal
+try:
+ from torchvision.transforms.functional_tensor import rgb_to_grayscale
+except ImportError:  # functional_tensor was removed in newer torchvision versions
+ from torchvision.transforms.functional import rgb_to_grayscale
+
+# -------------------------------------------------------------------- #
+# --------------------------- blur kernels --------------------------- #
+# -------------------------------------------------------------------- #
+
+
+# --------------------------- util functions --------------------------- #
+def sigma_matrix2(sig_x, sig_y, theta):
+ """Calculate the rotated sigma matrix (two dimensional matrix).
+
+ Args:
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+
+ Returns:
+ ndarray: Rotated sigma matrix.
+ """
+ d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
+ u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
+ return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
+
+
+def mesh_grid(kernel_size):
+ """Generate the mesh grid, centering at zero.
+
+ Args:
+ kernel_size (int):
+
+ Returns:
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
+ xx (ndarray): with the shape (kernel_size, kernel_size)
+ yy (ndarray): with the shape (kernel_size, kernel_size)
+ """
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+ xx, yy = np.meshgrid(ax, ax)
+ xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
+ 1))).reshape(kernel_size, kernel_size, 2)
+ return xy, xx, yy
+
+
+def pdf2(sigma_matrix, grid):
+ """Calculate PDF of the bivariate Gaussian distribution.
+
+ Args:
+ sigma_matrix (ndarray): with the shape (2, 2)
+ grid (ndarray): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size.
+
+ Returns:
+        kernel (ndarray): un-normalized kernel.
+ """
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
+ return kernel
+
+
+def cdf2(d_matrix, grid):
+ """Calculate the CDF of the standard bivariate Gaussian distribution.
+ Used in skewed Gaussian distribution.
+
+ Args:
+        d_matrix (ndarray): skew matrix.
+ grid (ndarray): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size.
+
+ Returns:
+ cdf (ndarray): skewed cdf.
+ """
+ rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
+ grid = np.dot(grid, d_matrix)
+ cdf = rv.cdf(grid)
+ return cdf
+
+
+def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
+ """Generate a bivariate isotropic or anisotropic Gaussian kernel.
+
+    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
+
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ isotropic (bool):
+
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ if isotropic:
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+ else:
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ kernel = pdf2(sigma_matrix, grid)
+ kernel = kernel / np.sum(kernel)
+ return kernel
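+
+
+# Illustrative sketch (not part of the upstream file): kernels are normalized
+# to sum to 1, so convolving with them preserves image brightness.
+def _demo_bivariate_gaussian():
+    kernel = bivariate_Gaussian(21, sig_x=2.0, sig_y=4.0, theta=np.pi / 4, isotropic=False)
+    assert kernel.shape == (21, 21)
+    assert abs(kernel.sum() - 1.0) < 1e-6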
+
+
+def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
+ """Generate a bivariate generalized Gaussian kernel.
+ Described in `Parameter Estimation For Multivariate Generalized
+ Gaussian Distributions`_
+    by Pascal et al. (2013).
+
+    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
+
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+
+ Returns:
+ kernel (ndarray): normalized kernel.
+
+ .. _Parameter Estimation For Multivariate Generalized Gaussian
+ Distributions: https://arxiv.org/abs/1302.6498
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ if isotropic:
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+ else:
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
+ """Generate a plateau-like anisotropic kernel.
+ 1 / (1+x^(beta))
+
+ Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
+
+    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
+
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ if isotropic:
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+ else:
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def random_bivariate_Gaussian(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ noise_range=None,
+ isotropic=True):
+ """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
+
+    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
+
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+        noise_range (tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ if isotropic is False:
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ else:
+ sigma_y = sigma_x
+ rotation = 0
+
+ kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def random_bivariate_generalized_Gaussian(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=None,
+ isotropic=True):
+ """Randomly generate bivariate generalized Gaussian kernels.
+
+    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
+
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+        beta_range (tuple): [0.5, 8]
+        noise_range (tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ if isotropic is False:
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ else:
+ sigma_y = sigma_x
+ rotation = 0
+
+ # assume beta_range[0] < 1 < beta_range[1]
+ if np.random.uniform() < 0.5:
+ beta = np.random.uniform(beta_range[0], 1)
+ else:
+ beta = np.random.uniform(1, beta_range[1])
+
+ kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def random_bivariate_plateau(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=None,
+ isotropic=True):
+ """Randomly generate bivariate plateau kernels.
+
+    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
+
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi/2, math.pi/2]
+        beta_range (tuple): [1, 4]
+        noise_range (tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ if isotropic is False:
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ else:
+ sigma_y = sigma_x
+ rotation = 0
+
+ # TODO: this may be not proper
+ if np.random.uniform() < 0.5:
+ beta = np.random.uniform(beta_range[0], 1)
+ else:
+ beta = np.random.uniform(1, beta_range[1])
+
+ kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+
+ return kernel
+
+
+def random_mixed_kernels(kernel_list,
+ kernel_prob,
+ kernel_size=21,
+ sigma_x_range=(0.6, 5),
+ sigma_y_range=(0.6, 5),
+ rotation_range=(-math.pi, math.pi),
+ betag_range=(0.5, 8),
+ betap_range=(0.5, 8),
+ noise_range=None):
+ """Randomly generate mixed kernels.
+
+ Args:
+        kernel_list (tuple): a list of kernel type names, supporting
+            ['iso', 'aniso', 'generalized_iso', 'generalized_aniso',
+            'plateau_iso', 'plateau_aniso']
+ kernel_prob (tuple): corresponding kernel probability for each
+ kernel type
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+        betag_range (tuple): [0.5, 8], shape range for generalized Gaussian kernels
+        betap_range (tuple): [0.5, 8], shape range for plateau kernels
+        noise_range (tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ kernel_type = random.choices(kernel_list, kernel_prob)[0]
+ if kernel_type == 'iso':
+ kernel = random_bivariate_Gaussian(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
+ elif kernel_type == 'aniso':
+ kernel = random_bivariate_Gaussian(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
+ elif kernel_type == 'generalized_iso':
+ kernel = random_bivariate_generalized_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ betag_range,
+ noise_range=noise_range,
+ isotropic=True)
+ elif kernel_type == 'generalized_aniso':
+ kernel = random_bivariate_generalized_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ betag_range,
+ noise_range=noise_range,
+ isotropic=False)
+ elif kernel_type == 'plateau_iso':
+ kernel = random_bivariate_plateau(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
+ elif kernel_type == 'plateau_aniso':
+ kernel = random_bivariate_plateau(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
+ return kernel
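+
+
+# Illustrative sketch (not part of the upstream file): sample a random blur
+# kernel the way degradation pipelines typically do, mixing kernel families
+# with the given probabilities and the default parameter ranges.
+def _demo_random_mixed_kernels():
+    kernel = random_mixed_kernels(
+        kernel_list=['iso', 'aniso', 'generalized_iso', 'generalized_aniso'],
+        kernel_prob=[0.5, 0.3, 0.1, 0.1],
+        kernel_size=21)
+    assert kernel.shape == (21, 21)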
+
+
+# silence 0/0 warnings from the sinc kernel below: np.fromfunction evaluates the
+# expression at the kernel center, where it is 0/0, and that value is then
+# overwritten explicitly in `circular_lowpass_kernel`
+np.seterr(divide='ignore', invalid='ignore')
+
+
+def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
+ """2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
+
+ Args:
+ cutoff (float): cutoff frequency in radians (pi is max)
+ kernel_size (int): horizontal and vertical size, must be odd.
+ pad_to (int): pad kernel size to desired size, must be odd or zero.
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ kernel = np.fromfunction(
+ lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
+ kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
+ kernel = kernel / np.sum(kernel)
+ if pad_to > kernel_size:
+ pad_size = (pad_to - kernel_size) // 2
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+ return kernel
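+
+
+# Illustrative sketch (not part of the upstream file): a sinc kernel with a small
+# cutoff acts as a strong low-pass filter; pad_to grows the kernel to a common
+# size (e.g. 21) so kernels of different sizes can be batched together.
+def _demo_sinc_kernel():
+    kernel = circular_lowpass_kernel(cutoff=np.pi / 3, kernel_size=13, pad_to=21)
+    assert kernel.shape == (21, 21)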
+
+
+# ------------------------------------------------------------- #
+# --------------------------- noise --------------------------- #
+# ------------------------------------------------------------- #
+
+# ----------------------- Gaussian Noise ----------------------- #
+
+
+def generate_gaussian_noise(img, sigma=10, gray_noise=False):
+ """Generate Gaussian noise.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+        sigma (float): Noise scale (measured in range 255). Default: 10.
+        gray_noise (bool): Whether to generate grayscale noise. Default: False.
+
+    Returns:
+        (Numpy array): Generated noise, shape (h, w, c), float32.
+ """
+ if gray_noise:
+ noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
+ noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
+ else:
+ noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
+ return noise
+
+
+def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
+ """Add Gaussian noise.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ sigma (float): Noise scale (measured in range 255). Default: 10.
+
+ Returns:
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ noise = generate_gaussian_noise(img, sigma, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
+ """Add Gaussian noise (PyTorch version).
+
+ Args:
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
+ scale (float | Tensor): Noise scale. Default: 1.0.
+
+ Returns:
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+ float32.
+ """
+ b, _, h, w = img.size()
+ if not isinstance(sigma, (float, int)):
+ sigma = sigma.view(img.size(0), 1, 1, 1)
+ if isinstance(gray_noise, (float, int)):
+ cal_gray_noise = gray_noise > 0
+ else:
+ gray_noise = gray_noise.view(b, 1, 1, 1)
+ cal_gray_noise = torch.sum(gray_noise) > 0
+
+ if cal_gray_noise:
+ noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
+ noise_gray = noise_gray.view(b, 1, h, w)
+
+ # always calculate color noise
+ noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
+
+ if cal_gray_noise:
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
+ return noise
+
+
+def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
+ """Add Gaussian noise (PyTorch version).
+
+ Args:
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
+        sigma (float | Tensor): Noise scale (measured in range 255). Default: 10.
+        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
+            0 for False, 1 for True. Default: 0.
+
+ Returns:
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+ float32.
+ """
+ noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+# ----------------------- Random Gaussian Noise ----------------------- #
+def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
+ if np.random.uniform() < gray_prob:
+ gray_noise = True
+ else:
+ gray_noise = False
+ return generate_gaussian_noise(img, sigma, gray_noise)
+
+
+def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+ noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
+ sigma = torch.rand(
+ img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
+ gray_noise = (gray_noise < gray_prob).float()
+ return generate_gaussian_noise_pt(img, sigma, gray_noise)
+
+
+def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+ noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
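+
+
+# Illustrative sketch (not part of the upstream file): add random Gaussian noise
+# to a batch, where each sample has a 40% chance of grayscale noise (identical
+# across the three channels).
+def _demo_random_gaussian_noise_pt():
+    img = torch.rand(4, 3, 32, 32)
+    out = random_add_gaussian_noise_pt(img, sigma_range=(1, 30), gray_prob=0.4, clip=True)
+    assert out.shape == img.shape
+    assert out.min() >= 0 and out.max() <= 1  # clipped back to [0, 1]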
+
+
+# ----------------------- Poisson (Shot) Noise ----------------------- #
+
+
+def generate_poisson_noise(img, scale=1.0, gray_noise=False):
+ """Generate poisson noise.
+
+ Ref: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ scale (float): Noise scale. Default: 1.0.
+        gray_noise (bool): Whether to generate gray noise. Default: False.
+
+    Returns:
+        (Numpy array): Generated noise, shape (h, w, c), float32.
+ """
+ if gray_noise:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ # round and clip image for counting vals correctly
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = len(np.unique(img))
+ vals = 2**np.ceil(np.log2(vals))
+ out = np.float32(np.random.poisson(img * vals) / float(vals))
+ noise = out - img
+ if gray_noise:
+ noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
+ return noise * scale
+
+
+def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
+ """Add poisson noise.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ scale (float): Noise scale. Default: 1.0.
+        gray_noise (bool): Whether to generate gray noise. Default: False.
+
+ Returns:
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ noise = generate_poisson_noise(img, scale, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
+ """Generate a batch of poisson noise (PyTorch version)
+
+ Args:
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
+ Default: 1.0.
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
+ 0 for False, 1 for True. Default: 0.
+
+ Returns:
+        (Tensor): Generated noise, shape (b, c, h, w), float32.
+ """
+ b, _, h, w = img.size()
+ if isinstance(gray_noise, (float, int)):
+ cal_gray_noise = gray_noise > 0
+ else:
+ gray_noise = gray_noise.view(b, 1, 1, 1)
+ cal_gray_noise = torch.sum(gray_noise) > 0
+ if cal_gray_noise:
+ img_gray = rgb_to_grayscale(img, num_output_channels=1)
+ # round and clip image for counting vals correctly
+ img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
+ # use for-loop to get the unique values for each sample
+ vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
+ vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
+ out = torch.poisson(img_gray * vals) / vals
+ noise_gray = out - img_gray
+ noise_gray = noise_gray.expand(b, 3, h, w)
+
+ # always calculate color noise
+ # round and clip image for counting vals correctly
+ img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
+ # use for-loop to get the unique values for each sample
+ vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
+ vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
+ out = torch.poisson(img * vals) / vals
+ noise = out - img
+ if cal_gray_noise:
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
+ if not isinstance(scale, (float, int)):
+ scale = scale.view(b, 1, 1, 1)
+ return noise * scale
+
+
+def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
+ """Add poisson noise to a batch of images (PyTorch version).
+
+ Args:
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
+ Default: 1.0.
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
+ 0 for False, 1 for True. Default: 0.
+
+ Returns:
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+ float32.
+ """
+ noise = generate_poisson_noise_pt(img, scale, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+# ----------------------- Random Poisson (Shot) Noise ----------------------- #
+
+
+def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
+ scale = np.random.uniform(scale_range[0], scale_range[1])
+ if np.random.uniform() < gray_prob:
+ gray_noise = True
+ else:
+ gray_noise = False
+ return generate_poisson_noise(img, scale, gray_noise)
+
+
+def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+ noise = random_generate_poisson_noise(img, scale_range, gray_prob)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
+ scale = torch.rand(
+ img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
+ gray_noise = (gray_noise < gray_prob).float()
+ return generate_poisson_noise_pt(img, scale, gray_noise)
+
+
+def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+ noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+# ------------------------------------------------------------------------ #
+# --------------------------- JPEG compression --------------------------- #
+# ------------------------------------------------------------------------ #
+
+
+def add_jpg_compression(img, quality=90):
+ """Add JPG compression artifacts.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ quality (float): JPG compression quality. 0 for lowest quality, 100 for
+ best quality. Default: 90.
+
+ Returns:
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ img = np.clip(img, 0, 1)
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
+ _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
+ img = np.float32(cv2.imdecode(encimg, 1)) / 255.
+ return img
+
+
+def random_add_jpg_compression(img, quality_range=(90, 100)):
+ """Randomly add JPG compression artifacts.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ quality_range (tuple[float] | list[float]): JPG compression quality
+ range. 0 for lowest quality, 100 for best quality.
+ Default: (90, 100).
+
+ Returns:
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ quality = np.random.uniform(quality_range[0], quality_range[1])
+ return add_jpg_compression(img, quality)
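+
+
+# Illustrative sketch (not part of the upstream file): round-trip a float32
+# image in [0, 1] through JPEG; lower quality means stronger artifacts.
+def _demo_jpg_compression():
+    img = np.random.rand(64, 64, 3).astype(np.float32)
+    out = add_jpg_compression(img, quality=50)
+    assert out.shape == img.shape and out.dtype == np.float32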
diff --git a/r_basicsr/data/ffhq_dataset.py b/r_basicsr/data/ffhq_dataset.py
new file mode 100644
index 0000000..03dc72c
--- /dev/null
+++ b/r_basicsr/data/ffhq_dataset.py
@@ -0,0 +1,80 @@
+import random
+import time
+from os import path as osp
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from r_basicsr.data.transforms import augment
+from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from r_basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class FFHQDataset(data.Dataset):
+ """FFHQ dataset for StyleGAN.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ io_backend (dict): IO backend type and other kwarg.
+ mean (list | tuple): Image mean.
+ std (list | tuple): Image std.
+ use_hflip (bool): Whether to horizontally flip.
+
+ """
+
+ def __init__(self, opt):
+ super(FFHQDataset, self).__init__()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+
+ self.gt_folder = opt['dataroot_gt']
+ self.mean = opt['mean']
+ self.std = opt['std']
+
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = self.gt_folder
+ if not self.gt_folder.endswith('.lmdb'):
+ raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+ self.paths = [line.split('.')[0] for line in fin]
+ else:
+ # FFHQ has 70000 images in total
+ self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # load gt image
+ gt_path = self.paths[index]
+ # avoid errors caused by high latency in reading files
+ retry = 3
+ while retry > 0:
+ try:
+ img_bytes = self.file_client.get(gt_path)
+ except Exception as e:
+ logger = get_root_logger()
+ logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
+                # switch to another random file
+                index = random.randint(0, self.__len__() - 1)  # randint is inclusive on both ends
+ gt_path = self.paths[index]
+ time.sleep(1) # sleep 1s for occasional server congestion
+ else:
+ break
+ finally:
+ retry -= 1
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # random horizontal flip
+ img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
+ # normalize
+ normalize(img_gt, self.mean, self.std, inplace=True)
+ return {'gt': img_gt, 'gt_path': gt_path}
+
+ def __len__(self):
+ return len(self.paths)
diff --git a/r_basicsr/data/paired_image_dataset.py b/r_basicsr/data/paired_image_dataset.py
new file mode 100644
index 0000000..76fcd1d
--- /dev/null
+++ b/r_basicsr/data/paired_image_dataset.py
@@ -0,0 +1,108 @@
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from r_basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
+from r_basicsr.data.transforms import augment, paired_random_crop
+from r_basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
+from r_basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class PairedImageDataset(data.Dataset):
+ """Paired image dataset for image restoration.
+
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
+
+ There are three modes:
+ 1. 'lmdb': Use lmdb files.
+        If opt['io_backend']['type'] == 'lmdb'.
+ 2. 'meta_info_file': Use meta information file to generate paths.
+        If opt['io_backend']['type'] != 'lmdb' and opt['meta_info_file'] is not None.
+ 3. 'folder': Scan folders to generate paths.
+ The rest.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ meta_info_file (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
+ Default: '{}'.
+            gt_size (int): Cropped patch size for gt patches.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+
+            scale (int): Scale, which will be added automatically.
+ phase (str): 'train' or 'val'.
+ """
+
+ def __init__(self, opt):
+ super(PairedImageDataset, self).__init__()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.mean = opt['mean'] if 'mean' in opt else None
+ self.std = opt['std'] if 'std' in opt else None
+
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
+ if 'filename_tmpl' in opt:
+ self.filename_tmpl = opt['filename_tmpl']
+ else:
+ self.filename_tmpl = '{}'
+
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
+ elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
+ self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
+ self.opt['meta_info_file'], self.filename_tmpl)
+ else:
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+ # image range: [0, 1], float32.
+ gt_path = self.paths[index]['gt_path']
+ img_bytes = self.file_client.get(gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+ lq_path = self.paths[index]['lq_path']
+ img_bytes = self.file_client.get(lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+
+ # augmentation for training
+ if self.opt['phase'] == 'train':
+ gt_size = self.opt['gt_size']
+ # random crop
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
+ # flip, rotation
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
+
+ # color space transform
+ if 'color' in self.opt and self.opt['color'] == 'y':
+ img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
+ img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]
+
+ # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
+ # TODO: It is better to update the datasets, rather than force to crop
+ if self.opt['phase'] != 'train':
+ img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+ # normalize
+ if self.mean is not None or self.std is not None:
+ normalize(img_lq, self.mean, self.std, inplace=True)
+ normalize(img_gt, self.mean, self.std, inplace=True)
+
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
+
+ def __len__(self):
+ return len(self.paths)
diff --git a/r_basicsr/data/prefetch_dataloader.py b/r_basicsr/data/prefetch_dataloader.py
new file mode 100644
index 0000000..ce12779
--- /dev/null
+++ b/r_basicsr/data/prefetch_dataloader.py
@@ -0,0 +1,125 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+ """A general prefetch generator.
+
+ Ref:
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+ Args:
+ generator: Python generator.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, generator, num_prefetch_queue):
+ threading.Thread.__init__(self)
+ self.queue = Queue.Queue(num_prefetch_queue)
+ self.generator = generator
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def __next__(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class PrefetchDataLoader(DataLoader):
+ """Prefetch version of dataloader.
+
+ Ref:
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+ TODO:
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
+ ddp.
+
+ Args:
+ num_prefetch_queue (int): Number of prefetch queue.
+ kwargs (dict): Other arguments for dataloader.
+ """
+
+ def __init__(self, num_prefetch_queue, **kwargs):
+ self.num_prefetch_queue = num_prefetch_queue
+ super(PrefetchDataLoader, self).__init__(**kwargs)
+
+ def __iter__(self):
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher():
+ """CPU prefetcher.
+
+ Args:
+ loader: Dataloader.
+ """
+
+ def __init__(self, loader):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+
+ def next(self):
+ try:
+ return next(self.loader)
+ except StopIteration:
+ return None
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+ """CUDA prefetcher.
+
+ Ref:
+ https://github.com/NVIDIA/apex/issues/304#
+
+    It may consume more GPU memory.
+
+ Args:
+ loader: Dataloader.
+ opt (dict): Options.
+ """
+
+ def __init__(self, loader, opt):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+ self.opt = opt
+ self.stream = torch.cuda.Stream()
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+ self.preload()
+
+ def preload(self):
+ try:
+ self.batch = next(self.loader) # self.batch is a dict
+ except StopIteration:
+ self.batch = None
+ return None
+ # put tensors to gpu
+ with torch.cuda.stream(self.stream):
+ for k, v in self.batch.items():
+ if torch.is_tensor(v):
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
+
+ def next(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ self.preload()
+ return batch
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+ self.preload()
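+
+
+# Illustrative sketch (not part of the upstream file): a typical training loop
+# with CUDAPrefetcher, which overlaps host-to-device copies with compute on a
+# side CUDA stream. Requires a CUDA device; opt must contain 'num_gpu'.
+def _demo_cuda_prefetcher(loader, opt):
+    prefetcher = CUDAPrefetcher(loader, opt)
+    batch = prefetcher.next()
+    while batch is not None:
+        # ... forward/backward on batch ...
+        batch = prefetcher.next()
+    prefetcher.reset()  # rewind for the next epoch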
diff --git a/r_basicsr/data/realesrgan_dataset.py b/r_basicsr/data/realesrgan_dataset.py
new file mode 100644
index 0000000..dd9ae11
--- /dev/null
+++ b/r_basicsr/data/realesrgan_dataset.py
@@ -0,0 +1,193 @@
+import cv2
+import math
+import numpy as np
+import os
+import os.path as osp
+import random
+import time
+import torch
+from torch.utils import data as data
+
+from r_basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
+from r_basicsr.data.transforms import augment
+from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from r_basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register(suffix='basicsr')
+class RealESRGANDataset(data.Dataset):
+ """Dataset used for Real-ESRGAN model:
+ Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ It loads gt (Ground-Truth) images, and augments them.
+ It also generates blur kernels and sinc kernels for generating low-quality images.
+    Note that the low-quality images are processed in tensors on GPUs for faster processing.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ meta_info (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+ Please see more options in the codes.
+ """
+
+ def __init__(self, opt):
+ super(RealESRGANDataset, self).__init__()
+ self.opt = opt
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.gt_folder = opt['dataroot_gt']
+
+ # file client (lmdb io backend)
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = [self.gt_folder]
+ self.io_backend_opt['client_keys'] = ['gt']
+ if not self.gt_folder.endswith('.lmdb'):
+ raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+ self.paths = [line.split('.')[0] for line in fin]
+ else:
+ # disk backend with meta_info
+ # Each line in the meta_info describes the relative path to an image
+ with open(self.opt['meta_info']) as fin:
+ paths = [line.strip().split(' ')[0] for line in fin]
+ self.paths = [os.path.join(self.gt_folder, v) for v in paths]
+
+ # blur settings for the first degradation
+ self.blur_kernel_size = opt['blur_kernel_size']
+ self.kernel_list = opt['kernel_list']
+ self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
+ self.blur_sigma = opt['blur_sigma']
+ self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
+ self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
+ self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
+
+ # blur settings for the second degradation
+ self.blur_kernel_size2 = opt['blur_kernel_size2']
+ self.kernel_list2 = opt['kernel_list2']
+ self.kernel_prob2 = opt['kernel_prob2']
+ self.blur_sigma2 = opt['blur_sigma2']
+ self.betag_range2 = opt['betag_range2']
+ self.betap_range2 = opt['betap_range2']
+ self.sinc_prob2 = opt['sinc_prob2']
+
+ # a final sinc filter
+ self.final_sinc_prob = opt['final_sinc_prob']
+
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
+ # TODO: the kernel range is now hard-coded; it should be specified in the config file
+ self.pulse_tensor = torch.zeros(21, 21).float() # convolving with this pulse (identity) kernel leaves the image unchanged
+ self.pulse_tensor[10, 10] = 1
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # -------------------------------- Load gt images -------------------------------- #
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
+ gt_path = self.paths[index]
+ # avoid errors caused by high latency in reading files
+ retry = 3
+ while retry > 0:
+ try:
+ img_bytes = self.file_client.get(gt_path, 'gt')
+ except (IOError, OSError) as e:
+ logger = get_root_logger()
+ logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
+ # switch to another random file to read
+ index = random.randint(0, self.__len__() - 1)  # randint is inclusive at both ends
+ gt_path = self.paths[index]
+ time.sleep(1) # sleep 1s for occasional server congestion
+ else:
+ break
+ finally:
+ retry -= 1
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # -------------------- Do augmentation for training: flip, rotation -------------------- #
+ img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
+
+ # crop or pad to 400
+ # TODO: 400 is hard-coded. You may change it accordingly
+ h, w = img_gt.shape[0:2]
+ crop_pad_size = 400
+ # pad
+ if h < crop_pad_size or w < crop_pad_size:
+ pad_h = max(0, crop_pad_size - h)
+ pad_w = max(0, crop_pad_size - w)
+ img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
+ # crop
+ if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
+ h, w = img_gt.shape[0:2]
+ # randomly choose top and left coordinates
+ top = random.randint(0, h - crop_pad_size)
+ left = random.randint(0, w - crop_pad_size)
+ img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
+
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
+ kernel_size = random.choice(self.kernel_range)
+ if np.random.uniform() < self.opt['sinc_prob']:
+ # this sinc filter setting is for kernels ranging from [7, 21]
+ if kernel_size < 13:
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ else:
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+ else:
+ kernel = random_mixed_kernels(
+ self.kernel_list,
+ self.kernel_prob,
+ kernel_size,
+ self.blur_sigma,
+ self.blur_sigma, [-math.pi, math.pi],
+ self.betag_range,
+ self.betap_range,
+ noise_range=None)
+ # pad kernel
+ pad_size = (21 - kernel_size) // 2
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
+ kernel_size = random.choice(self.kernel_range)
+ if np.random.uniform() < self.opt['sinc_prob2']:
+ if kernel_size < 13:
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ else:
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+ else:
+ kernel2 = random_mixed_kernels(
+ self.kernel_list2,
+ self.kernel_prob2,
+ kernel_size,
+ self.blur_sigma2,
+ self.blur_sigma2, [-math.pi, math.pi],
+ self.betag_range2,
+ self.betap_range2,
+ noise_range=None)
+
+ # pad kernel
+ pad_size = (21 - kernel_size) // 2
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
+
+ # ------------------------------------- the final sinc kernel ------------------------------------- #
+ if np.random.uniform() < self.opt['final_sinc_prob']:
+ kernel_size = random.choice(self.kernel_range)
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
+ sinc_kernel = torch.FloatTensor(sinc_kernel)
+ else:
+ sinc_kernel = self.pulse_tensor
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
+ kernel = torch.FloatTensor(kernel)
+ kernel2 = torch.FloatTensor(kernel2)
+
+ return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
+ return return_d
+
+ def __len__(self):
+ return len(self.paths)
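+
+
+if __name__ == '__main__':
+    # A minimal configuration sketch for RealESRGANDataset. The degradation
+    # values are assumptions modeled on the public Real-ESRGAN x4 training
+    # config, and the data paths are hypothetical placeholders.
+    example_opt = {
+        'dataroot_gt': 'datasets/DF2K',  # hypothetical path
+        'meta_info': 'datasets/DF2K/meta_info.txt',  # hypothetical path
+        'io_backend': {'type': 'disk'},
+        'use_hflip': True,
+        'use_rot': False,
+        'blur_kernel_size': 21,
+        'kernel_list': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
+        'kernel_prob': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
+        'sinc_prob': 0.1,
+        'blur_sigma': [0.2, 3],
+        'betag_range': [0.5, 4],
+        'betap_range': [1, 2],
+        'blur_kernel_size2': 21,
+        'kernel_list2': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
+        'kernel_prob2': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
+        'sinc_prob2': 0.1,
+        'blur_sigma2': [0.2, 1.5],
+        'betag_range2': [0.5, 4],
+        'betap_range2': [1, 2],
+        'final_sinc_prob': 0.8,
+    }
+    print(sorted(example_opt))  # the keys consumed by RealESRGANDataset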
diff --git a/r_basicsr/data/realesrgan_paired_dataset.py b/r_basicsr/data/realesrgan_paired_dataset.py new file mode 100644 index 0000000..a31dad1 --- /dev/null +++ b/r_basicsr/data/realesrgan_paired_dataset.py @@ -0,0 +1,109 @@ +import os
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from r_basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
+from r_basicsr.data.transforms import augment, paired_random_crop
+from r_basicsr.utils import FileClient, imfrombytes, img2tensor
+from r_basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register(suffix='basicsr')
+class RealESRGANPairedDataset(data.Dataset):
+ """Paired image dataset for image restoration.
+
+ Read LQ (Low Quality, e.g. low-resolution, blurry, or noisy) and GT image pairs.
+
+ There are three modes:
+ 1. 'lmdb': Use lmdb files.
+ If opt['io_backend']['type'] == 'lmdb'.
+ 2. 'meta_info': Use a meta information file to generate paths.
+ If opt['io_backend']['type'] != 'lmdb' and opt['meta_info'] is not None.
+ 3. 'folder': Scan folders to generate paths.
+ Otherwise.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ meta_info (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwargs.
+ filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
+ Default: '{}'.
+ gt_size (int): Cropped patch size for gt patches.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h
+ and w for implementation).
+
+ scale (int): Scale, which will be added automatically.
+ phase (str): 'train' or 'val'.
+ """
+
+ def __init__(self, opt):
+ super(RealESRGANPairedDataset, self).__init__()
+ self.opt = opt
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ # mean and std for normalizing the input images
+ self.mean = opt['mean'] if 'mean' in opt else None
+ self.std = opt['std'] if 'std' in opt else None
+
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
+ self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
+
+ # file client (lmdb io backend)
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
+ elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
+ # disk backend with meta_info
+ # Each line in the meta_info describes the relative path to an image
+ with open(self.opt['meta_info']) as fin:
+ paths = [line.strip() for line in fin]
+ self.paths = []
+ for path in paths:
+ gt_path, lq_path = path.split(', ')
+ gt_path = os.path.join(self.gt_folder, gt_path)
+ lq_path = os.path.join(self.lq_folder, lq_path)
+ self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
+ else:
+ # disk backend
+ # it will scan the whole folder to get meta info
+ # it can be time-consuming for folders with too many files; using an extra meta txt file is recommended
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+ # image range: [0, 1], float32.
+ gt_path = self.paths[index]['gt_path']
+ img_bytes = self.file_client.get(gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+ lq_path = self.paths[index]['lq_path']
+ img_bytes = self.file_client.get(lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+
+ # augmentation for training
+ if self.opt['phase'] == 'train':
+ gt_size = self.opt['gt_size']
+ # random crop
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
+ # flip, rotation
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+ # normalize
+ if self.mean is not None or self.std is not None:
+ normalize(img_lq, self.mean, self.std, inplace=True)
+ normalize(img_gt, self.mean, self.std, inplace=True)
+
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
+
+ def __len__(self):
+ return len(self.paths)
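+
+
+if __name__ == '__main__':
+    # A minimal sketch of the 'meta_info' mode above: each line pairs a GT
+    # path and an LQ path separated by ', ', both relative to their data
+    # roots. The entry and roots below are hypothetical placeholders.
+    example_line = 'gt/0001.png, lq/0001.png'
+    gt_rel, lq_rel = example_line.split(', ')
+    print(os.path.join('datasets/paired_gt', gt_rel))  # datasets/paired_gt/gt/0001.png
+    print(os.path.join('datasets/paired_lq', lq_rel))  # datasets/paired_lq/lq/0001.png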
diff --git a/r_basicsr/data/reds_dataset.py b/r_basicsr/data/reds_dataset.py new file mode 100644 index 0000000..fa7df26 --- /dev/null +++ b/r_basicsr/data/reds_dataset.py @@ -0,0 +1,360 @@ +import numpy as np
+import random
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from r_basicsr.data.transforms import augment, paired_random_crop
+from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from r_basicsr.utils.flow_util import dequantize_flow
+from r_basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class REDSDataset(data.Dataset):
+ """REDS dataset for training.
+
+ The keys are generated from a meta info txt file.
+ basicsr/data/meta_info/meta_info_REDS_GT.txt
+
+ Each line contains:
+ 1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
+ a white space.
+ Examples:
+ 000 100 (720,1280,3)
+ 001 100 (720,1280,3)
+ ...
+
+ Key examples: "000/00000000"
+ GT (gt): Ground-Truth;
+ LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ dataroot_flow (str, optional): Data root path for flow.
+ meta_info_file (str): Path for meta information file.
+ val_partition (str): Validation partition types. 'REDS4' or
+ 'official'.
+ io_backend (dict): IO backend type and other kwargs.
+
+ num_frame (int): Window size for input frames.
+ gt_size (int): Cropped patch size for gt patches.
+ interval_list (list): Interval list for temporal augmentation.
+ random_reverse (bool): Random reverse input frames.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h
+ and w for implementation).
+
+ scale (int): Scale, which will be added automatically.
+ """
+
+ def __init__(self, opt):
+ super(REDSDataset, self).__init__()
+ self.opt = opt
+ self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+ self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
+ assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
+ self.num_frame = opt['num_frame']
+ self.num_half_frames = opt['num_frame'] // 2
+
+ self.keys = []
+ with open(opt['meta_info_file'], 'r') as fin:
+ for line in fin:
+ folder, frame_num, _ = line.split(' ')
+ self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
+
+ # remove the video clips used in validation
+ if opt['val_partition'] == 'REDS4':
+ val_partition = ['000', '011', '015', '020']
+ elif opt['val_partition'] == 'official':
+ val_partition = [f'{v:03d}' for v in range(240, 270)]
+ else:
+ raise ValueError(f'Wrong validation partition {opt["val_partition"]}. '
+ f"Supported ones are ['official', 'REDS4'].")
+ self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
+
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.is_lmdb = False
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.is_lmdb = True
+ if self.flow_root is not None:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
+ else:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+ # temporal augmentation configs
+ self.interval_list = opt['interval_list']
+ self.random_reverse = opt['random_reverse']
+ interval_str = ','.join(str(x) for x in opt['interval_list'])
+ logger = get_root_logger()
+ logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
+ f'random reverse is {self.random_reverse}.')
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+ gt_size = self.opt['gt_size']
+ key = self.keys[index]
+ clip_name, frame_name = key.split('/') # key example: 000/00000000
+ center_frame_idx = int(frame_name)
+
+ # determine the neighboring frames
+ interval = random.choice(self.interval_list)
+
+ # ensure not exceeding the borders
+ start_frame_idx = center_frame_idx - self.num_half_frames * interval
+ end_frame_idx = center_frame_idx + self.num_half_frames * interval
+ # each clip has 100 frames starting from 0 to 99
+ while (start_frame_idx < 0) or (end_frame_idx > 99):
+ center_frame_idx = random.randint(0, 99)
+ start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
+ end_frame_idx = center_frame_idx + self.num_half_frames * interval
+ frame_name = f'{center_frame_idx:08d}'
+ neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
+ # random reverse
+ if self.random_reverse and random.random() < 0.5:
+ neighbor_list.reverse()
+
+ assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')
+
+ # get the GT frame (as the center frame)
+ if self.is_lmdb:
+ img_gt_path = f'{clip_name}/{frame_name}'
+ else:
+ img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
+ img_bytes = self.file_client.get(img_gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # get the neighboring LQ frames
+ img_lqs = []
+ for neighbor in neighbor_list:
+ if self.is_lmdb:
+ img_lq_path = f'{clip_name}/{neighbor:08d}'
+ else:
+ img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
+ img_bytes = self.file_client.get(img_lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+ img_lqs.append(img_lq)
+
+ # get flows
+ if self.flow_root is not None:
+ img_flows = []
+ # read previous flows
+ for i in range(self.num_half_frames, 0, -1):
+ if self.is_lmdb:
+ flow_path = f'{clip_name}/{frame_name}_p{i}'
+ else:
+ flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
+ img_bytes = self.file_client.get(flow_path, 'flow')
+ cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
+ dx, dy = np.split(cat_flow, 2, axis=0)
+ flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
+ img_flows.append(flow)
+ # read next flows
+ for i in range(1, self.num_half_frames + 1):
+ if self.is_lmdb:
+ flow_path = f'{clip_name}/{frame_name}_n{i}'
+ else:
+ flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
+ img_bytes = self.file_client.get(flow_path, 'flow')
+ cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
+ dx, dy = np.split(cat_flow, 2, axis=0)
+ flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
+ img_flows.append(flow)
+
+ # for random crop, here, img_flows and img_lqs have the same
+ # spatial size
+ img_lqs.extend(img_flows)
+
+ # randomly crop
+ img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
+ if self.flow_root is not None:
+ img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]
+
+ # augmentation - flip, rotate
+ img_lqs.append(img_gt)
+ if self.flow_root is not None:
+ img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
+ else:
+ img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+ img_results = img2tensor(img_results)
+ img_lqs = torch.stack(img_results[0:-1], dim=0)
+ img_gt = img_results[-1]
+
+ if self.flow_root is not None:
+ img_flows = img2tensor(img_flows)
+ # add the zero center flow
+ img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
+ img_flows = torch.stack(img_flows, dim=0)
+
+ # img_lqs: (t, c, h, w)
+ # img_flows: (t, 2, h, w)
+ # img_gt: (c, h, w)
+ # key: str
+ if self.flow_root is not None:
+ return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
+ else:
+ return {'lq': img_lqs, 'gt': img_gt, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
+
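+if __name__ == '__main__':
+    # A minimal sketch of the temporal sampling in REDSDataset.__getitem__,
+    # with illustrative values: for num_frame=5 (num_half_frames=2) and
+    # interval=2, a center frame too close to a clip border is re-drawn
+    # until the whole window fits inside frames [0, 99].
+    num_half_frames, interval = 2, 2
+    center = 98  # 98 + 2 * 2 = 102 > 99, so it must be re-drawn
+    start, end = center - num_half_frames * interval, center + num_half_frames * interval
+    while start < 0 or end > 99:
+        center = random.randint(0, 99)
+        start, end = center - num_half_frames * interval, center + num_half_frames * interval
+    print(list(range(start, end + 1, interval)))  # 5 evenly spaced, in-range indices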
+
+@DATASET_REGISTRY.register()
+class REDSRecurrentDataset(data.Dataset):
+ """REDS dataset for training recurrent networks.
+
+ The keys are generated from a meta info txt file.
+ basicsr/data/meta_info/meta_info_REDS_GT.txt
+
+ Each line contains:
+ 1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
+ a white space.
+ Examples:
+ 000 100 (720,1280,3)
+ 001 100 (720,1280,3)
+ ...
+
+ Key examples: "000/00000000"
+ GT (gt): Ground-Truth;
+ LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ dataroot_flow (str, optional): Data root path for flow.
+ meta_info_file (str): Path for meta information file.
+ val_partition (str): Validation partition types. 'REDS4' or
+ 'official'.
+ io_backend (dict): IO backend type and other kwargs.
+
+ num_frame (int): Window size for input frames.
+ gt_size (int): Cropped patch size for gt patches.
+ interval_list (list): Interval list for temporal augmentation.
+ random_reverse (bool): Random reverse input frames.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h
+ and w for implementation).
+
+ scale (int): Scale, which will be added automatically.
+ """
+
+ def __init__(self, opt):
+ super(REDSRecurrentDataset, self).__init__()
+ self.opt = opt
+ self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+ self.num_frame = opt['num_frame']
+
+ self.keys = []
+ with open(opt['meta_info_file'], 'r') as fin:
+ for line in fin:
+ folder, frame_num, _ = line.split(' ')
+ self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
+
+ # remove the video clips used in validation
+ if opt['val_partition'] == 'REDS4':
+ val_partition = ['000', '011', '015', '020']
+ elif opt['val_partition'] == 'official':
+ val_partition = [f'{v:03d}' for v in range(240, 270)]
+ else:
+ raise ValueError(f'Wrong validation partition {opt["val_partition"]}. '
+ f"Supported ones are ['official', 'REDS4'].")
+ if opt['test_mode']:
+ self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
+ else:
+ self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
+
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.is_lmdb = False
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.is_lmdb = True
+ if hasattr(self, 'flow_root') and self.flow_root is not None:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
+ else:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+ # temporal augmentation configs
+ self.interval_list = opt.get('interval_list', [1])
+ self.random_reverse = opt.get('random_reverse', False)
+ interval_str = ','.join(str(x) for x in self.interval_list)
+ logger = get_root_logger()
+ logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
+ f'random reverse is {self.random_reverse}.')
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+ gt_size = self.opt['gt_size']
+ key = self.keys[index]
+ clip_name, frame_name = key.split('/') # key example: 000/00000000
+
+ # determine the neighboring frames
+ interval = random.choice(self.interval_list)
+
+ # ensure not exceeding the borders
+ start_frame_idx = int(frame_name)
+ if start_frame_idx > 100 - self.num_frame * interval:
+ start_frame_idx = random.randint(0, 100 - self.num_frame * interval)
+ end_frame_idx = start_frame_idx + self.num_frame * interval
+
+ neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
+
+ # random reverse
+ if self.random_reverse and random.random() < 0.5:
+ neighbor_list.reverse()
+
+ # get the neighboring LQ and GT frames
+ img_lqs = []
+ img_gts = []
+ for neighbor in neighbor_list:
+ if self.is_lmdb:
+ img_lq_path = f'{clip_name}/{neighbor:08d}'
+ img_gt_path = f'{clip_name}/{neighbor:08d}'
+ else:
+ img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
+ img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'
+
+ # get LQ
+ img_bytes = self.file_client.get(img_lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+ img_lqs.append(img_lq)
+
+ # get GT
+ img_bytes = self.file_client.get(img_gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+ img_gts.append(img_gt)
+
+ # randomly crop
+ img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
+
+ # augmentation - flip, rotate
+ img_lqs.extend(img_gts)
+ img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+ img_results = img2tensor(img_results)
+ img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
+ img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)
+
+ # img_lqs: (t, c, h, w)
+ # img_gts: (t, c, h, w)
+ # key: str
+ return {'lq': img_lqs, 'gt': img_gts, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
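+
+
+if __name__ == '__main__':
+    # A minimal sketch of the recurrent sampling above, with illustrative
+    # values: num_frame consecutive indices (spaced by interval) are drawn,
+    # and a start index too close to frame 99 is re-drawn so the whole
+    # sequence fits inside the 100-frame clip.
+    num_frame, interval = 15, 1
+    start = 92  # 92 + 15 > 100, so it must be re-drawn
+    if start > 100 - num_frame * interval:
+        start = random.randint(0, 100 - num_frame * interval)
+    seq = list(range(start, start + num_frame * interval, interval))
+    print(len(seq), seq[0], seq[-1])  # 15 frames, all within [0, 99]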
diff --git a/r_basicsr/data/single_image_dataset.py b/r_basicsr/data/single_image_dataset.py new file mode 100644 index 0000000..91bda89 --- /dev/null +++ b/r_basicsr/data/single_image_dataset.py @@ -0,0 +1,68 @@ +from os import path as osp
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from r_basicsr.data.data_util import paths_from_lmdb
+from r_basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
+from r_basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class SingleImageDataset(data.Dataset):
+ """Read only lq images in the test phase.
+
+ Read LQ (Low Quality, e.g. low-resolution, blurry, or noisy) images.
+
+ There are two modes:
+ 1. 'meta_info_file': Use meta information file to generate paths.
+ 2. 'folder': Scan folders to generate paths.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_lq (str): Data root path for lq.
+ meta_info_file (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwargs.
+ """
+
+ def __init__(self, opt):
+ super(SingleImageDataset, self).__init__()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.mean = opt['mean'] if 'mean' in opt else None
+ self.std = opt['std'] if 'std' in opt else None
+ self.lq_folder = opt['dataroot_lq']
+
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = [self.lq_folder]
+ self.io_backend_opt['client_keys'] = ['lq']
+ self.paths = paths_from_lmdb(self.lq_folder)
+ elif 'meta_info_file' in self.opt:
+ with open(self.opt['meta_info_file'], 'r') as fin:
+ self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
+ else:
+ self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # load lq image
+ lq_path = self.paths[index]
+ img_bytes = self.file_client.get(lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+
+ # color space transform
+ if 'color' in self.opt and self.opt['color'] == 'y':
+ img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
+ # normalize
+ if self.mean is not None or self.std is not None:
+ normalize(img_lq, self.mean, self.std, inplace=True)
+ return {'lq': img_lq, 'lq_path': lq_path}
+
+ def __len__(self):
+ return len(self.paths)
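+
+
+if __name__ == '__main__':
+    # A minimal inference-time opt sketch for SingleImageDataset; the path
+    # is a hypothetical placeholder, and 'mean'/'std'/'color' are optional.
+    example_opt = {
+        'dataroot_lq': 'datasets/test_lq',  # hypothetical path
+        'io_backend': {'type': 'disk'},
+    }
+    print(sorted(example_opt))  # keys read by SingleImageDataset.__init__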
diff --git a/r_basicsr/data/transforms.py b/r_basicsr/data/transforms.py new file mode 100644 index 0000000..85d1bc2 --- /dev/null +++ b/r_basicsr/data/transforms.py @@ -0,0 +1,179 @@ +import cv2
+import random
+import torch
+
+
+def mod_crop(img, scale):
+ """Mod crop images, used during testing.
+
+ Args:
+ img (ndarray): Input image.
+ scale (int): Scale factor.
+
+ Returns:
+ ndarray: Result image.
+ """
+ img = img.copy()
+ if img.ndim in (2, 3):
+ h, w = img.shape[0], img.shape[1]
+ h_remainder, w_remainder = h % scale, w % scale
+ img = img[:h - h_remainder, :w - w_remainder, ...]
+ else:
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
+ return img
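+
+# A minimal illustration of mod_crop, assuming a numpy ndarray input: each
+# spatial dimension is trimmed down to the nearest lower multiple of `scale`:
+#
+#     >>> import numpy as np
+#     >>> mod_crop(np.zeros((103, 130, 3)), 4).shape
+#     (100, 128, 3)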
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
+ """Paired random crop. Support Numpy array and Tensor inputs.
+
+ It crops lists of lq and gt images with corresponding locations.
+
+ Args:
+ img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ gt_patch_size (int): GT patch size.
+ scale (int): Scale factor.
+ gt_path (str): Path to ground-truth. Default: None.
+
+ Returns:
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
+ only have one element, just return ndarray.
+ """
+
+ if not isinstance(img_gts, list):
+ img_gts = [img_gts]
+ if not isinstance(img_lqs, list):
+ img_lqs = [img_lqs]
+
+ # determine input type: Numpy array or Tensor
+ input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
+
+ if input_type == 'Tensor':
+ h_lq, w_lq = img_lqs[0].size()[-2:]
+ h_gt, w_gt = img_gts[0].size()[-2:]
+ else:
+ h_lq, w_lq = img_lqs[0].shape[0:2]
+ h_gt, w_gt = img_gts[0].shape[0:2]
+ lq_patch_size = gt_patch_size // scale
+
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+ f'the size of LQ ({h_lq}, {w_lq}).')
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+ f'({lq_patch_size}, {lq_patch_size}). '
+ f'Please remove {gt_path}.')
+
+ # randomly choose top and left coordinates for lq patch
+ top = random.randint(0, h_lq - lq_patch_size)
+ left = random.randint(0, w_lq - lq_patch_size)
+
+ # crop lq patch
+ if input_type == 'Tensor':
+ img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
+ else:
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+ # crop corresponding gt patch
+ top_gt, left_gt = int(top * scale), int(left * scale)
+ if input_type == 'Tensor':
+ img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
+ else:
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+ if len(img_gts) == 1:
+ img_gts = img_gts[0]
+ if len(img_lqs) == 1:
+ img_lqs = img_lqs[0]
+ return img_gts, img_lqs
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+ We use vertical flip and transpose for rotation implementation.
+ All the images in the list use the same augmentation.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+ is an ndarray, it will be transformed to a list.
+ hflip (bool): Horizontal flip. Default: True.
+ rotation (bool): Rotation. Default: True.
+ flows (list[ndarray]): Flows to be augmented. If the input is an
+ ndarray, it will be transformed to a list.
+ Dimension is (h, w, 2). Default: None.
+ return_status (bool): Return the status of flip and rotation.
+ Default: False.
+
+ Returns:
+ list[ndarray] | ndarray: Augmented images and flows. If returned
+ results only have one element, just return ndarray.
+
+ """
+ hflip = hflip and random.random() < 0.5
+ vflip = rotation and random.random() < 0.5
+ rot90 = rotation and random.random() < 0.5
+
+ def _augment(img):
+ if hflip: # horizontal
+ cv2.flip(img, 1, img)
+ if vflip: # vertical
+ cv2.flip(img, 0, img)
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ def _augment_flow(flow):
+ if hflip: # horizontal
+ cv2.flip(flow, 1, flow)
+ flow[:, :, 0] *= -1
+ if vflip: # vertical
+ cv2.flip(flow, 0, flow)
+ flow[:, :, 1] *= -1
+ if rot90:
+ flow = flow.transpose(1, 0, 2)
+ flow = flow[:, :, [1, 0]]
+ return flow
+
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ imgs = [_augment(img) for img in imgs]
+ if len(imgs) == 1:
+ imgs = imgs[0]
+
+ if flows is not None:
+ if not isinstance(flows, list):
+ flows = [flows]
+ flows = [_augment_flow(flow) for flow in flows]
+ if len(flows) == 1:
+ flows = flows[0]
+ return imgs, flows
+ else:
+ if return_status:
+ return imgs, (hflip, vflip, rot90)
+ else:
+ return imgs
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+ """Rotate image.
+
+ Args:
+ img (ndarray): Image to be rotated.
+ angle (float): Rotation angle in degrees. Positive values mean
+ counter-clockwise rotation.
+ center (tuple[int]): Rotation center. If the center is None,
+ initialize it as the center of the image. Default: None.
+ scale (float): Isotropic scale factor. Default: 1.0.
+ """
+ (h, w) = img.shape[:2]
+
+ if center is None:
+ center = (w // 2, h // 2)
+
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
+ return rotated_img
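+
+
+if __name__ == '__main__':
+    # A minimal demo on synthetic data: paired_random_crop keeps the GT and
+    # LQ patches aligned at the given scale, and augment applies one shared
+    # flip/rotation to every image in the list.
+    import numpy as np
+    gt = np.random.rand(64, 64, 3).astype(np.float32)
+    lq = np.random.rand(16, 16, 3).astype(np.float32)
+    gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size=32, scale=4)
+    print(gt_patch.shape, lq_patch.shape)  # (32, 32, 3) (8, 8, 3)
+    out, status = augment([gt_patch, lq_patch], hflip=True, rotation=True, return_status=True)
+    print([im.shape for im in out], status)  # shared (hflip, vflip, rot90) status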
diff --git a/r_basicsr/data/video_test_dataset.py b/r_basicsr/data/video_test_dataset.py new file mode 100644 index 0000000..7e9db01 --- /dev/null +++ b/r_basicsr/data/video_test_dataset.py @@ -0,0 +1,287 @@ +import glob
+import torch
+from os import path as osp
+from torch.utils import data as data
+
+from r_basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
+from r_basicsr.utils import get_root_logger, scandir
+from r_basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class VideoTestDataset(data.Dataset):
+ """Video test dataset.
+
+ Supported datasets: Vid4, REDS4, REDSofficial.
+ More generally, it supports testing datasets with the following structure:
+
+ dataroot
+ ├── subfolder1
+ ├── frame000
+ ├── frame001
+ ├── ...
+ ├── subfolder2
+ ├── frame000
+ ├── frame001
+ ├── ...
+ ├── ...
+
+ For testing datasets, there is no need to prepare LMDB files.
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ io_backend (dict): IO backend type and other kwargs.
+ cache_data (bool): Whether to cache testing datasets.
+ name (str): Dataset name.
+ meta_info_file (str): The path to the file storing the list of test
+ folders. If not provided, all the folders in the dataroot will
+ be used.
+ num_frame (int): Window size for input frames.
+ padding (str): Padding mode.
+ """
+
+ def __init__(self, opt):
+ super(VideoTestDataset, self).__init__()
+ self.opt = opt
+ self.cache_data = opt['cache_data']
+ self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
+ self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
+
+ logger = get_root_logger()
+ logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
+ self.imgs_lq, self.imgs_gt = {}, {}
+ if 'meta_info_file' in opt:
+ with open(opt['meta_info_file'], 'r') as fin:
+ subfolders = [line.split(' ')[0] for line in fin]
+ subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
+ subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
+ else:
+ subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
+ subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))
+
+ if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
+ for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
+ # get frame list for lq and gt
+ subfolder_name = osp.basename(subfolder_lq)
+ img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
+ img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))
+
+ max_idx = len(img_paths_lq)
+ assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
+ f' and gt folders ({len(img_paths_gt)})')
+
+ self.data_info['lq_path'].extend(img_paths_lq)
+ self.data_info['gt_path'].extend(img_paths_gt)
+ self.data_info['folder'].extend([subfolder_name] * max_idx)
+ for i in range(max_idx):
+ self.data_info['idx'].append(f'{i}/{max_idx}')
+ border_l = [0] * max_idx
+ for i in range(self.opt['num_frame'] // 2):
+ border_l[i] = 1
+ border_l[max_idx - i - 1] = 1
+ self.data_info['border'].extend(border_l)
+
+ # cache data or save the frame list
+ if self.cache_data:
+ logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
+ self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
+ self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
+ else:
+ self.imgs_lq[subfolder_name] = img_paths_lq
+ self.imgs_gt[subfolder_name] = img_paths_gt
+ else:
+ raise ValueError(f'Non-supported video test dataset: {opt["name"]}')
+
+ def __getitem__(self, index):
+ folder = self.data_info['folder'][index]
+ idx, max_idx = self.data_info['idx'][index].split('/')
+ idx, max_idx = int(idx), int(max_idx)
+ border = self.data_info['border'][index]
+ lq_path = self.data_info['lq_path'][index]
+
+ select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
+
+ if self.cache_data:
+ imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
+ img_gt = self.imgs_gt[folder][idx]
+ else:
+ img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
+ imgs_lq = read_img_seq(img_paths_lq)
+ img_gt = read_img_seq([self.imgs_gt[folder][idx]])
+ img_gt.squeeze_(0)
+
+ return {
+ 'lq': imgs_lq, # (t, c, h, w)
+ 'gt': img_gt, # (c, h, w)
+ 'folder': folder, # folder name
+ 'idx': self.data_info['idx'][index], # e.g., 0/99
+ 'border': border, # 1 for border, 0 for non-border
+ 'lq_path': lq_path # center frame
+ }
+
+ def __len__(self):
+ return len(self.data_info['gt_path'])
+
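+if __name__ == '__main__':
+    # A minimal test-time opt sketch for VideoTestDataset (Vid4-style). The
+    # paths are hypothetical placeholders, and the padding value assumes the
+    # modes accepted by generate_frame_indices (e.g. 'reflection_circle').
+    example_opt = {
+        'name': 'Vid4',
+        'dataroot_gt': 'datasets/Vid4/GT',  # hypothetical path
+        'dataroot_lq': 'datasets/Vid4/BIx4',  # hypothetical path
+        'io_backend': {'type': 'disk'},
+        'cache_data': False,
+        'num_frame': 7,
+        'padding': 'reflection_circle',
+    }
+    print(sorted(example_opt))
+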
+
+@DATASET_REGISTRY.register()
+class VideoTestVimeo90KDataset(data.Dataset):
+ """Video test dataset for Vimeo90k-Test dataset.
+
+ It only keeps the center frame for testing.
+ For testing datasets, there is no need to prepare LMDB files.
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ io_backend (dict): IO backend type and other kwargs.
+ cache_data (bool): Whether to cache testing datasets.
+ name (str): Dataset name.
+ meta_info_file (str): The path to the file storing the list of test
+ folders. If not provided, all the folders in the dataroot will
+ be used.
+ num_frame (int): Window size for input frames.
+ padding (str): Padding mode.
+ """
+
+ def __init__(self, opt):
+ super(VideoTestVimeo90KDataset, self).__init__()
+ self.opt = opt
+ self.cache_data = opt['cache_data']
+ if self.cache_data:
+ raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
+ self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
+ self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
+ neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
+
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
+
+ logger = get_root_logger()
+ logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
+ with open(opt['meta_info_file'], 'r') as fin:
+ subfolders = [line.split(' ')[0] for line in fin]
+ for idx, subfolder in enumerate(subfolders):
+ gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
+ self.data_info['gt_path'].append(gt_path)
+ lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
+ self.data_info['lq_path'].append(lq_paths)
+ self.data_info['folder'].append('vimeo90k')
+ self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
+ self.data_info['border'].append(0)
+
+ def __getitem__(self, index):
+ lq_path = self.data_info['lq_path'][index]
+ gt_path = self.data_info['gt_path'][index]
+ imgs_lq = read_img_seq(lq_path)
+ img_gt = read_img_seq([gt_path])
+ img_gt.squeeze_(0)
+
+ return {
+ 'lq': imgs_lq, # (t, c, h, w)
+ 'gt': img_gt, # (c, h, w)
+ 'folder': self.data_info['folder'][index], # folder name
+ 'idx': self.data_info['idx'][index], # e.g., 0/843
+ 'border': self.data_info['border'][index], # 0 for non-border
+ 'lq_path': lq_path[self.opt['num_frame'] // 2] # center frame
+ }
+
+ def __len__(self):
+ return len(self.data_info['gt_path'])
+
+
+@DATASET_REGISTRY.register()
+class VideoTestDUFDataset(VideoTestDataset):
+ """ Video test dataset for DUF dataset.
+
+ Args:
+ opt (dict): Config for train dataset.
+ Most of keys are the same as VideoTestDataset.
+ It has the following extra keys:
+
+ use_duf_downsampling (bool): Whether to use duf downsampling to
+ generate low-resolution frames.
+ scale (int): Scale, which will be added automatically.
+ """
+
+ def __getitem__(self, index):
+ folder = self.data_info['folder'][index]
+ idx, max_idx = self.data_info['idx'][index].split('/')
+ idx, max_idx = int(idx), int(max_idx)
+ border = self.data_info['border'][index]
+ lq_path = self.data_info['lq_path'][index]
+
+ select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
+
+ if self.cache_data:
+ if self.opt['use_duf_downsampling']:
+ # read imgs_gt to generate low-resolution frames
+ imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
+ imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
+ else:
+ imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
+ img_gt = self.imgs_gt[folder][idx]
+ else:
+ if self.opt['use_duf_downsampling']:
+ img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
+ # read imgs_gt to generate low-resolution frames
+ imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
+ imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
+ else:
+ img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
+ imgs_lq = read_img_seq(img_paths_lq)
+ img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
+ img_gt.squeeze_(0)
+
+ return {
+ 'lq': imgs_lq, # (t, c, h, w)
+ 'gt': img_gt, # (c, h, w)
+ 'folder': folder, # folder name
+ 'idx': self.data_info['idx'][index], # e.g., 0/99
+ 'border': border, # 1 for border, 0 for non-border
+ 'lq_path': lq_path # center frame
+ }
+
+
+@DATASET_REGISTRY.register()
+class VideoRecurrentTestDataset(VideoTestDataset):
+ """Video test dataset for recurrent architectures, which takes LR video
+ frames as input and output corresponding HR video frames.
+
+ Args:
+ Same as VideoTestDataset.
+ Unused opt:
+ padding (str): Padding mode.
+
+ """
+
+ def __init__(self, opt):
+ super(VideoRecurrentTestDataset, self).__init__(opt)
+ # Find unique folder strings
+ self.folders = sorted(list(set(self.data_info['folder'])))
+
+ def __getitem__(self, index):
+ folder = self.folders[index]
+
+ if self.cache_data:
+ imgs_lq = self.imgs_lq[folder]
+ imgs_gt = self.imgs_gt[folder]
+ else:
+ raise NotImplementedError('cache_data=False is not implemented for VideoRecurrentTestDataset.')
+
+ return {
+ 'lq': imgs_lq,
+ 'gt': imgs_gt,
+ 'folder': folder,
+ }
+
+ def __len__(self):
+ return len(self.folders)
diff --git a/r_basicsr/data/vimeo90k_dataset.py b/r_basicsr/data/vimeo90k_dataset.py new file mode 100644 index 0000000..f20a4fd --- /dev/null +++ b/r_basicsr/data/vimeo90k_dataset.py @@ -0,0 +1,192 @@ +import random
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from r_basicsr.data.transforms import augment, paired_random_crop
+from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from r_basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class Vimeo90KDataset(data.Dataset):
+ """Vimeo90K dataset for training.
+
+ The keys are generated from a meta info txt file.
+ basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
+
+ Each line contains:
+ 1. clip name; 2. frame number; 3. image shape, separated by a white space.
+ Examples:
+ 00001/0001 7 (256,448,3)
+ 00001/0002 7 (256,448,3)
+
+ Key examples: "00001/0001"
+ GT (gt): Ground-Truth;
+ LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+ The neighboring frame list for different num_frame:
+ num_frame | frame list
+ 1 | 4
+ 3 | 3,4,5
+ 5 | 2,3,4,5,6
+ 7 | 1,2,3,4,5,6,7
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ meta_info_file (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwargs.
+
+ num_frame (int): Window size for input frames.
+ gt_size (int): Cropped patch size for gt patches.
+ random_reverse (bool): Random reverse input frames.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h
+ and w for implementation).
+
+ scale (int): Scale, which will be added automatically.
+ """
+
+ def __init__(self, opt):
+ super(Vimeo90KDataset, self).__init__()
+ self.opt = opt
+ self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+
+ with open(opt['meta_info_file'], 'r') as fin:
+ self.keys = [line.split(' ')[0] for line in fin]
+
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.is_lmdb = False
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.is_lmdb = True
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+ # indices of input images
+ self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
+
+ # temporal augmentation configs
+ self.random_reverse = opt['random_reverse']
+ logger = get_root_logger()
+ logger.info(f'Random reverse is {self.random_reverse}.')
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # random reverse
+ if self.random_reverse and random.random() < 0.5:
+ self.neighbor_list.reverse()
+
+ scale = self.opt['scale']
+ gt_size = self.opt['gt_size']
+ key = self.keys[index]
+ clip, seq = key.split('/') # key example: 00001/0001
+
+ # get the GT frame (im4.png)
+ if self.is_lmdb:
+ img_gt_path = f'{key}/im4'
+ else:
+ img_gt_path = self.gt_root / clip / seq / 'im4.png'
+ img_bytes = self.file_client.get(img_gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # get the neighboring LQ frames
+ img_lqs = []
+ for neighbor in self.neighbor_list:
+ if self.is_lmdb:
+ img_lq_path = f'{clip}/{seq}/im{neighbor}'
+ else:
+ img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
+ img_bytes = self.file_client.get(img_lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+ img_lqs.append(img_lq)
+
+ # randomly crop
+ img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
+
+ # augmentation - flip, rotate
+ img_lqs.append(img_gt)
+ img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+ img_results = img2tensor(img_results)
+ img_lqs = torch.stack(img_results[0:-1], dim=0)
+ img_gt = img_results[-1]
+
+ # img_lqs: (t, c, h, w)
+ # img_gt: (c, h, w)
+ # key: str
+ return {'lq': img_lqs, 'gt': img_gt, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
+
+
+@DATASET_REGISTRY.register()
+class Vimeo90KRecurrentDataset(Vimeo90KDataset):
+
+ def __init__(self, opt):
+ super(Vimeo90KRecurrentDataset, self).__init__(opt)
+
+ self.flip_sequence = opt['flip_sequence']
+ self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # random reverse
+ if self.random_reverse and random.random() < 0.5:
+ self.neighbor_list.reverse()
+
+ scale = self.opt['scale']
+ gt_size = self.opt['gt_size']
+ key = self.keys[index]
+ clip, seq = key.split('/') # key example: 00001/0001
+
+ # get the neighboring LQ and GT frames
+ img_lqs = []
+ img_gts = []
+ for neighbor in self.neighbor_list:
+ if self.is_lmdb:
+ img_lq_path = f'{clip}/{seq}/im{neighbor}'
+ img_gt_path = f'{clip}/{seq}/im{neighbor}'
+ else:
+ img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
+ img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
+ # LQ
+ img_bytes = self.file_client.get(img_lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+ # GT
+ img_bytes = self.file_client.get(img_gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ img_lqs.append(img_lq)
+ img_gts.append(img_gt)
+
+ # randomly crop
+ img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
+
+ # augmentation - flip, rotate
+ img_lqs.extend(img_gts)
+ img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+ img_results = img2tensor(img_results)
+ img_lqs = torch.stack(img_results[:7], dim=0)
+ img_gts = torch.stack(img_results[7:], dim=0)
+
+ if self.flip_sequence: # flip the sequence: 7 frames to 14 frames
+ img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
+ img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)
+
+ # img_lqs: (t, c, h, w)
+ # img_gts: (t, c, h, w)
+ # key: str
+ return {'lq': img_lqs, 'gt': img_gts, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
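+
+
+if __name__ == '__main__':
+    # A minimal sketch of flip_sequence: a (t, c, h, w) clip is mirrored in
+    # time and concatenated, so 7 frames become 14 (0..6 followed by 6..0).
+    seq = torch.arange(7.).view(7, 1, 1, 1)
+    doubled = torch.cat([seq, seq.flip(0)], dim=0)
+    print(doubled.flatten().tolist())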