Diffstat (limited to 'r_basicsr/data/video_test_dataset.py')
-rw-r--r-- | r_basicsr/data/video_test_dataset.py | 287
1 files changed, 287 insertions, 0 deletions
diff --git a/r_basicsr/data/video_test_dataset.py b/r_basicsr/data/video_test_dataset.py
new file mode 100644
index 0000000..7e9db01
--- /dev/null
+++ b/r_basicsr/data/video_test_dataset.py
@@ -0,0 +1,287 @@
+import glob
+import torch
+from os import path as osp
+from torch.utils import data as data
+
+from r_basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
+from r_basicsr.utils import get_root_logger, scandir
+from r_basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class VideoTestDataset(data.Dataset):
+ """Video test dataset.
+
+ Supported datasets: Vid4, REDS4, REDSofficial.
+ More generally, it supports testing datasets with the following structure:
+
+ dataroot
+ ├── subfolder1
+     ├── frame000
+     ├── frame001
+     ├── ...
+ ├── subfolder2
+     ├── frame000
+     ├── frame001
+     ├── ...
+ ├── ...
+
+ For testing datasets, there is no need to prepare LMDB files.
+
+ Args:
+ opt (dict): Config for the test dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ io_backend (dict): IO backend type and other kwargs.
+ cache_data (bool): Whether to cache testing datasets.
+ name (str): Dataset name.
+ meta_info_file (str): The path to the file storing the list of test
+ folders. If not provided, all the folders in the dataroot will
+ be used.
+ num_frame (int): Window size for input frames.
+ padding (str): Padding mode.
+ """
+
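+ # Illustrative configuration (a sketch only; the dataset paths and values
+ # below are example assumptions, not defaults defined in this file):
+ #
+ #     opt = {
+ #         'name': 'Vid4',
+ #         'dataroot_gt': 'datasets/Vid4/GT',
+ #         'dataroot_lq': 'datasets/Vid4/BIx4',
+ #         'io_backend': {'type': 'disk'},
+ #         'cache_data': False,
+ #         'num_frame': 7,
+ #         'padding': 'reflection_circle',
+ #     }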
+ def __init__(self, opt):
+ super(VideoTestDataset, self).__init__()
+ self.opt = opt
+ self.cache_data = opt['cache_data']
+ self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
+ self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
+
+ logger = get_root_logger()
+ logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
+ self.imgs_lq, self.imgs_gt = {}, {}
+ if 'meta_info_file' in opt:
+ with open(opt['meta_info_file'], 'r') as fin:
+ subfolders = [line.split(' ')[0] for line in fin]
+ subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
+ subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
+ else:
+ subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
+ subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))
+
+ if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
+ for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
+ # get frame list for lq and gt
+ subfolder_name = osp.basename(subfolder_lq)
+ img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
+ img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))
+
+ max_idx = len(img_paths_lq)
+ assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
+ f' and gt folders ({len(img_paths_gt)})')
+
+ self.data_info['lq_path'].extend(img_paths_lq)
+ self.data_info['gt_path'].extend(img_paths_gt)
+ self.data_info['folder'].extend([subfolder_name] * max_idx)
+ for i in range(max_idx):
+ self.data_info['idx'].append(f'{i}/{max_idx}')
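+ # the first and last num_frame // 2 frames of each clip are flagged as
+ # border frames: their temporal windows extend past the clip boundary
+ # and have to be filled by the padding mode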
+ border_l = [0] * max_idx
+ for i in range(self.opt['num_frame'] // 2):
+ border_l[i] = 1
+ border_l[max_idx - i - 1] = 1
+ self.data_info['border'].extend(border_l)
+
+ # cache data or save the frame list
+ if self.cache_data:
+ logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
+ self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
+ self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
+ else:
+ self.imgs_lq[subfolder_name] = img_paths_lq
+ self.imgs_gt[subfolder_name] = img_paths_gt
+ else:
+ raise ValueError(f'Non-supported video test dataset: {opt["name"]}')
+
+ def __getitem__(self, index):
+ folder = self.data_info['folder'][index]
+ idx, max_idx = self.data_info['idx'][index].split('/')
+ idx, max_idx = int(idx), int(max_idx)
+ border = self.data_info['border'][index]
+ lq_path = self.data_info['lq_path'][index]
+
+ select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
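+ # illustration (assuming the standard BasicSR behaviour of
+ # generate_frame_indices): idx=0, num_frame=7, padding='reflection'
+ # yields [3, 2, 1, 0, 1, 2, 3], mirroring out-of-range neighbours
+ # back into the clip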
+
+ if self.cache_data:
+ imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
+ img_gt = self.imgs_gt[folder][idx]
+ else:
+ img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
+ imgs_lq = read_img_seq(img_paths_lq)
+ img_gt = read_img_seq([self.imgs_gt[folder][idx]])
+ img_gt.squeeze_(0)
+
+ return {
+ 'lq': imgs_lq, # (t, c, h, w)
+ 'gt': img_gt, # (c, h, w)
+ 'folder': folder, # folder name
+ 'idx': self.data_info['idx'][index], # e.g., 0/99
+ 'border': border, # 1 for border, 0 for non-border
+ 'lq_path': lq_path # center frame
+ }
+
+ def __len__(self):
+ return len(self.data_info['gt_path'])
+
+
+@DATASET_REGISTRY.register()
+class VideoTestVimeo90KDataset(data.Dataset):
+ """Video test dataset for Vimeo90k-Test dataset.
+
+ It only keeps the center frame for testing.
+ For testing datasets, there is no need to prepare LMDB files.
+
+ Args:
+ opt (dict): Config for the test dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ io_backend (dict): IO backend type and other kwargs.
+ cache_data (bool): Whether to cache testing datasets.
+ name (str): Dataset name.
+ meta_info_file (str): The path to the file storing the list of test
+ clips. Required for this dataset.
+ num_frame (int): Window size for input frames.
+ padding (str): Padding mode.
+ """
+
+ def __init__(self, opt):
+ super(VideoTestVimeo90KDataset, self).__init__()
+ self.opt = opt
+ self.cache_data = opt['cache_data']
+ if self.cache_data:
+ raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
+ self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
+ self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
+ neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
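+ # Vimeo90K clips hold 7 frames (im1.png .. im7.png) centred on im4.png;
+ # e.g. num_frame=7 gives neighbor_list [1, 2, 3, 4, 5, 6, 7], while
+ # num_frame=5 gives [2, 3, 4, 5, 6]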
+
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
+
+ logger = get_root_logger()
+ logger.info(f'Generate data info for VideoTestVimeo90KDataset - {opt["name"]}')
+ with open(opt['meta_info_file'], 'r') as fin:
+ subfolders = [line.split(' ')[0] for line in fin]
+ for idx, subfolder in enumerate(subfolders):
+ gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
+ self.data_info['gt_path'].append(gt_path)
+ lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
+ self.data_info['lq_path'].append(lq_paths)
+ self.data_info['folder'].append('vimeo90k')
+ self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
+ self.data_info['border'].append(0)
+
+ def __getitem__(self, index):
+ lq_path = self.data_info['lq_path'][index]
+ gt_path = self.data_info['gt_path'][index]
+ imgs_lq = read_img_seq(lq_path)
+ img_gt = read_img_seq([gt_path])
+ img_gt.squeeze_(0)
+
+ return {
+ 'lq': imgs_lq, # (t, c, h, w)
+ 'gt': img_gt, # (c, h, w)
+ 'folder': self.data_info['folder'][index], # folder name
+ 'idx': self.data_info['idx'][index], # e.g., 0/843
+ 'border': self.data_info['border'][index], # 0 for non-border
+ 'lq_path': lq_path[self.opt['num_frame'] // 2] # center frame
+ }
+
+ def __len__(self):
+ return len(self.data_info['gt_path'])
+
+
+@DATASET_REGISTRY.register()
+class VideoTestDUFDataset(VideoTestDataset):
+ """ Video test dataset for DUF dataset.
+
+ Args:
+ opt (dict): Config for the test dataset.
+ Most of the keys are the same as in VideoTestDataset.
+ It has the following extra keys:
+
+ use_duf_downsampling (bool): Whether to use duf downsampling to
+ generate low-resolution frames.
+ scale (int): Scale factor, which is added automatically.
+ """
+
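+ # note: duf_downsample (from r_basicsr.data.data_util) blurs each frame
+ # with a Gaussian kernel (kernel_size=13 below) before subsampling by
+ # `scale`, matching the degradation used by the DUF authors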
+ def __getitem__(self, index):
+ folder = self.data_info['folder'][index]
+ idx, max_idx = self.data_info['idx'][index].split('/')
+ idx, max_idx = int(idx), int(max_idx)
+ border = self.data_info['border'][index]
+ lq_path = self.data_info['lq_path'][index]
+
+ select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
+
+ if self.cache_data:
+ if self.opt['use_duf_downsampling']:
+ # read imgs_gt to generate low-resolution frames
+ imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
+ imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
+ else:
+ imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
+ img_gt = self.imgs_gt[folder][idx]
+ else:
+ if self.opt['use_duf_downsampling']:
+ img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
+ # read imgs_gt to generate low-resolution frames
+ imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
+ imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
+ else:
+ img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
+ imgs_lq = read_img_seq(img_paths_lq)
+ img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
+ img_gt.squeeze_(0)
+
+ return {
+ 'lq': imgs_lq, # (t, c, h, w)
+ 'gt': img_gt, # (c, h, w)
+ 'folder': folder, # folder name
+ 'idx': self.data_info['idx'][index], # e.g., 0/99
+ 'border': border, # 1 for border, 0 for non-border
+ 'lq_path': lq_path # center frame
+ }
+
+
+@DATASET_REGISTRY.register()
+class VideoRecurrentTestDataset(VideoTestDataset):
+ """Video test dataset for recurrent architectures, which takes LR video
+ frames as input and output corresponding HR video frames.
+
+ Args:
+ Same as VideoTestDataset.
+ Unused opt:
+ padding (str): Padding mode.
+
+ """
+
+ def __init__(self, opt):
+ super(VideoRecurrentTestDataset, self).__init__(opt)
+ # Find unique folder strings
+ self.folders = sorted(list(set(self.data_info['folder'])))
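+ # each item of this dataset is an entire clip, so __len__ counts
+ # folders rather than individual frames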
+
+ def __getitem__(self, index):
+ folder = self.folders[index]
+
+ if self.cache_data:
+ imgs_lq = self.imgs_lq[folder]
+ imgs_gt = self.imgs_gt[folder]
+ else:
+ raise NotImplementedError('Running without cache_data is not implemented.')
+
+ return {
+ 'lq': imgs_lq,
+ 'gt': imgs_gt,
+ 'folder': folder,
+ }
+
+ def __len__(self):
+ return len(self.folders)
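+
+
+# Usage sketch (illustrative only; `opt` is the example configuration assumed
+# in the VideoTestDataset comment above). Video test sets are usually
+# evaluated one window at a time, hence batch_size=1:
+#
+#     from torch.utils.data import DataLoader
+#
+#     dataset = VideoTestDataset(opt)
+#     loader = DataLoader(dataset, batch_size=1, shuffle=False)
+#     for batch in loader:
+#         lq = batch['lq']  # (1, t, c, h, w)
+#         gt = batch['gt']  # (1, c, h, w)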