From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:00:30 +0000 Subject: Add files via upload --- r_basicsr/models/__init__.py | 29 ++ r_basicsr/models/base_model.py | 380 ++++++++++++++++++++++++++ r_basicsr/models/edvr_model.py | 62 +++++ r_basicsr/models/esrgan_model.py | 83 ++++++ r_basicsr/models/hifacegan_model.py | 288 +++++++++++++++++++ r_basicsr/models/lr_scheduler.py | 96 +++++++ r_basicsr/models/realesrgan_model.py | 267 ++++++++++++++++++ r_basicsr/models/realesrnet_model.py | 189 +++++++++++++ r_basicsr/models/sr_model.py | 231 ++++++++++++++++ r_basicsr/models/srgan_model.py | 149 ++++++++++ r_basicsr/models/stylegan2_model.py | 283 +++++++++++++++++++ r_basicsr/models/swinir_model.py | 33 +++ r_basicsr/models/video_base_model.py | 160 +++++++++++ r_basicsr/models/video_gan_model.py | 17 ++ r_basicsr/models/video_recurrent_gan_model.py | 180 ++++++++++++ r_basicsr/models/video_recurrent_model.py | 197 +++++++++++++ 16 files changed, 2644 insertions(+) create mode 100644 r_basicsr/models/__init__.py create mode 100644 r_basicsr/models/base_model.py create mode 100644 r_basicsr/models/edvr_model.py create mode 100644 r_basicsr/models/esrgan_model.py create mode 100644 r_basicsr/models/hifacegan_model.py create mode 100644 r_basicsr/models/lr_scheduler.py create mode 100644 r_basicsr/models/realesrgan_model.py create mode 100644 r_basicsr/models/realesrnet_model.py create mode 100644 r_basicsr/models/sr_model.py create mode 100644 r_basicsr/models/srgan_model.py create mode 100644 r_basicsr/models/stylegan2_model.py create mode 100644 r_basicsr/models/swinir_model.py create mode 100644 r_basicsr/models/video_base_model.py create mode 100644 r_basicsr/models/video_gan_model.py create mode 100644 r_basicsr/models/video_recurrent_gan_model.py create mode 100644 r_basicsr/models/video_recurrent_model.py (limited to 'r_basicsr/models') diff --git a/r_basicsr/models/__init__.py b/r_basicsr/models/__init__.py new file mode 100644 index 0000000..b01cdba --- /dev/null +++ b/r_basicsr/models/__init__.py @@ -0,0 +1,29 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from r_basicsr.utils import get_root_logger, scandir +from r_basicsr.utils.registry import MODEL_REGISTRY + +__all__ = ['build_model'] + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'r_basicsr.models.{file_name}') for file_name in model_filenames] + + +def build_model(opt): + """Build model from options. + + Args: + opt (dict): Configuration. It must contain: + model_type (str): Model type. + """ + opt = deepcopy(opt) + model = MODEL_REGISTRY.get(opt['model_type'])(opt) + logger = get_root_logger() + logger.info(f'Model [{model.__class__.__name__}] is created.') + return model diff --git a/r_basicsr/models/base_model.py b/r_basicsr/models/base_model.py new file mode 100644 index 0000000..bd2faad --- /dev/null +++ b/r_basicsr/models/base_model.py @@ -0,0 +1,380 @@ +import os +import time +import torch +from collections import OrderedDict +from copy import deepcopy +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from r_basicsr.models import lr_scheduler as lr_scheduler +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.dist_util import master_only + + +class BaseModel(): + """Base model.""" + + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def save(self, epoch, current_iter): + """Save networks and training state.""" + pass + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """Validation function. + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + """ + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def _initialize_best_metric_results(self, dataset_name): + """Initialize the best metric results dict for recording the best metric value and iteration.""" + if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results: + return + elif not hasattr(self, 'best_metric_results'): + self.best_metric_results = dict() + + # add a dataset record + record = dict() + for metric, content in self.opt['val']['metrics'].items(): + better = content.get('better', 'higher') + init_val = float('-inf') if better == 'higher' else float('inf') + record[metric] = dict(better=better, val=init_val, iter=-1) + self.best_metric_results[dataset_name] = record + + def _update_best_metric_result(self, dataset_name, metric, val, current_iter): + if self.best_metric_results[dataset_name][metric]['better'] == 'higher': + if val >= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + else: + if val <= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.net_g) + + net_g_params = dict(net_g.named_parameters()) + net_g_ema_params = dict(self.net_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. + + Args: + net (nn.Module) + """ + net = net.to(self.device) + if self.opt['dist']: + find_unused_parameters = self.opt.get('find_unused_parameters', False) + net = DistributedDataParallel( + net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) + elif self.opt['num_gpu'] > 1: + net = DataParallel(net) + return net + + def get_optimizer(self, optim_type, params, lr, **kwargs): + if optim_type == 'Adam': + optimizer = torch.optim.Adam(params, lr, **kwargs) + else: + raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.') + return optimizer + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler'])) + else: + raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}' + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger = get_root_logger() + logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warm-up. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. + """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. + + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warm-up iter numbers. -1 for no warm-up. + Default: -1. + """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. + + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(save_dict, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with different name or different size when loading models. + + 1. Print keys with different names. + 2. If strict=False, print the same key but with different tensor size. + It also ignore these keys with different sizes (not load). + + Args: + crt_net (torch model): Current network. + load_net (dict): Loaded network. + strict (bool): Whether strictly loaded. Default: True. + """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + logger = get_root_logger() + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning(f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + logger = get_root_logger() + net = self.get_bare_model(net) + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], save_filename) + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(state, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. + """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict diff --git a/r_basicsr/models/edvr_model.py b/r_basicsr/models/edvr_model.py new file mode 100644 index 0000000..1475033 --- /dev/null +++ b/r_basicsr/models/edvr_model.py @@ -0,0 +1,62 @@ +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import MODEL_REGISTRY +from .video_base_model import VideoBaseModel + + +@MODEL_REGISTRY.register() +class EDVRModel(VideoBaseModel): + """EDVR Model. + + Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. # noqa: E501 + """ + + def __init__(self, opt): + super(EDVRModel, self).__init__(opt) + if self.is_train: + self.train_tsa_iter = opt['train'].get('tsa_iter') + + def setup_optimizers(self): + train_opt = self.opt['train'] + dcn_lr_mul = train_opt.get('dcn_lr_mul', 1) + logger = get_root_logger() + logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.') + if dcn_lr_mul == 1: + optim_params = self.net_g.parameters() + else: # separate dcn params and normal params for different lr + normal_params = [] + dcn_params = [] + for name, param in self.net_g.named_parameters(): + if 'dcn' in name: + dcn_params.append(param) + else: + normal_params.append(param) + optim_params = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + { + 'params': dcn_params, + 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul + }, + ] + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def optimize_parameters(self, current_iter): + if self.train_tsa_iter: + if current_iter == 1: + logger = get_root_logger() + logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.') + for name, param in self.net_g.named_parameters(): + if 'fusion' not in name: + param.requires_grad = False + elif current_iter == self.train_tsa_iter: + logger = get_root_logger() + logger.warning('Train all the parameters.') + for param in self.net_g.parameters(): + param.requires_grad = True + + super(EDVRModel, self).optimize_parameters(current_iter) diff --git a/r_basicsr/models/esrgan_model.py b/r_basicsr/models/esrgan_model.py new file mode 100644 index 0000000..8924920 --- /dev/null +++ b/r_basicsr/models/esrgan_model.py @@ -0,0 +1,83 @@ +import torch +from collections import OrderedDict + +from r_basicsr.utils.registry import MODEL_REGISTRY +from .srgan_model import SRGANModel + + +@MODEL_REGISTRY.register() +class ESRGANModel(SRGANModel): + """ESRGAN model for single image super-resolution.""" + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss (relativistic gan) + real_d_pred = self.net_d(self.gt).detach() + fake_g_pred = self.net_d(self.output) + l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False) + l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False) + l_g_gan = (l_g_real + l_g_fake) / 2 + + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # gan loss (relativistic gan) + + # In order to avoid the error in distributed training: + # "Error detected in CudnnBatchNormBackward: RuntimeError: one of + # the variables needed for gradient computation has been modified by + # an inplace operation", + # we separate the backwards for real and fake, and also detach the + # tensor for calculating mean. + + # real + fake_d_pred = self.net_d(self.output).detach() + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5 + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5 + l_d_fake.backward() + self.optimizer_d.step() + + loss_dict['l_d_real'] = l_d_real + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) diff --git a/r_basicsr/models/hifacegan_model.py b/r_basicsr/models/hifacegan_model.py new file mode 100644 index 0000000..fd67d11 --- /dev/null +++ b/r_basicsr/models/hifacegan_model.py @@ -0,0 +1,288 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from r_basicsr.archs import build_network +from r_basicsr.losses import build_loss +from r_basicsr.metrics import calculate_metric +from r_basicsr.utils import imwrite, tensor2img +from r_basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class HiFaceGANModel(SRModel): + """HiFaceGAN model for generic-purpose face restoration. + No prior modeling required, works for any degradations. + Currently doesn't support EMA for inference. + """ + + def init_training_settings(self): + + train_opt = self.opt['train'] + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + raise (NotImplementedError('HiFaceGAN does not support EMA now. Pass')) + + self.net_g.train() + + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # define losses + # HiFaceGAN does not use pixel loss by default + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('feature_matching_opt'): + self.cri_feat = build_loss(train_opt['feature_matching_opt']).to(self.device) + else: + self.cri_feat = None + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def discriminate(self, input_lq, output, ground_truth): + """ + This is a conditional (on the input) discriminator + In Batch Normalization, the fake and real images are + recommended to be in the same batch to avoid disparate + statistics in fake and real images. + So both fake and real images are fed to D all at once. + """ + h, w = output.shape[-2:] + if output.shape[-2:] != input_lq.shape[-2:]: + lq = torch.nn.functional.interpolate(input_lq, (h, w)) + real = torch.nn.functional.interpolate(ground_truth, (h, w)) + fake_concat = torch.cat([lq, output], dim=1) + real_concat = torch.cat([lq, real], dim=1) + else: + fake_concat = torch.cat([input_lq, output], dim=1) + real_concat = torch.cat([input_lq, ground_truth], dim=1) + + fake_and_real = torch.cat([fake_concat, real_concat], dim=0) + discriminator_out = self.net_d(fake_and_real) + pred_fake, pred_real = self._divide_pred(discriminator_out) + return pred_fake, pred_real + + @staticmethod + def _divide_pred(pred): + """ + Take the prediction of fake and real images from the combined batch. + The prediction contains the intermediate outputs of multiscale GAN, + so it's usually a list + """ + if type(pred) == list: + fake = [] + real = [] + for p in pred: + fake.append([tensor[:tensor.size(0) // 2] for tensor in p]) + real.append([tensor[tensor.size(0) // 2:] for tensor in p]) + else: + fake = pred[:pred.size(0) // 2] + real = pred[pred.size(0) // 2:] + + return fake, real + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + + # Requires real prediction for feature matching loss + pred_fake, pred_real = self.discriminate(self.lq, self.output, self.gt) + l_g_gan = self.cri_gan(pred_fake, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + # feature matching loss + if self.cri_feat: + l_g_feat = self.cri_feat(pred_fake, pred_real) + l_g_total += l_g_feat + loss_dict['l_g_feat'] = l_g_feat + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # TODO: Benchmark test between HiFaceGAN and SRGAN implementation: + # SRGAN use the same fake output for discriminator update + # while HiFaceGAN regenerate a new output using updated net_g + # This should not make too much difference though. Stick to SRGAN now. + # ------------------------------------------------------------------- + # ---------- Below are original HiFaceGAN code snippet -------------- + # ------------------------------------------------------------------- + # with torch.no_grad(): + # fake_image = self.net_g(self.lq) + # fake_image = fake_image.detach() + # fake_image.requires_grad_() + # pred_fake, pred_real = self.discriminate(self.lq, fake_image, self.gt) + + # real + pred_fake, pred_real = self.discriminate(self.lq, self.output.detach(), self.gt) + l_d_real = self.cri_gan(pred_real, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + # fake + l_d_fake = self.cri_gan(pred_fake, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + + l_d_total = (l_d_real + l_d_fake) / 2 + l_d_total.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + print('HiFaceGAN does not support EMA now. pass') + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """ + Warning: HiFaceGAN requires train() mode even for validation + For more info, see https://github.com/Lotayou/Face-Renovation/issues/31 + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + """ + + if self.opt['network_g']['type'] in ('HiFaceGAN', 'SPADEGenerator'): + self.net_g.train() + + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + print('In HiFaceGANModel: The new metrics package is under development.' + + 'Using super method now (Only PSNR & SSIM are supported)') + super().nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + """ + TODO: Validation using updated metric system + The metrics are now evaluated after all images have been tested + This allows batch processing, and also allows evaluation of + distributional metrics, such as: + + @ Frechet Inception Distance: FID + @ Maximum Mean Discrepancy: MMD + + Warning: + Need careful batch management for different inference settings. + + """ + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = dict() # {metric: 0 for metric in self.opt['val']['metrics'].keys()} + sr_tensors = [] + gt_tensors = [] + + pbar = tqdm(total=len(dataloader), unit='image') + for val_data in dataloader: + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() # detached cpu tensor, non-squeeze + sr_tensors.append(visuals['result']) + if 'gt' in visuals: + gt_tensors.append(visuals['gt']) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + + imwrite(tensor2img(visuals['result']), save_img_path) + + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + sr_pack = torch.cat(sr_tensors, dim=0) + gt_pack = torch.cat(gt_tensors, dim=0) + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + # The new metric caller automatically returns mean value + # FIXME: ERROR: calculate_metric only supports two arguments. Now the codes cannot be successfully run + self.metric_results[name] = calculate_metric(dict(sr_pack=sr_pack, gt_pack=gt_pack), opt_) + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + print('HiFaceGAN does not support EMA now. Fallback to normal mode.') + + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/r_basicsr/models/lr_scheduler.py b/r_basicsr/models/lr_scheduler.py new file mode 100644 index 0000000..084122d --- /dev/null +++ b/r_basicsr/models/lr_scheduler.py @@ -0,0 +1,96 @@ +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The minimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len( + self.restart_weights)), 'periods and restart_weights should have the same length.' + self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/r_basicsr/models/realesrgan_model.py b/r_basicsr/models/realesrgan_model.py new file mode 100644 index 0000000..b05dafd --- /dev/null +++ b/r_basicsr/models/realesrgan_model.py @@ -0,0 +1,267 @@ +import numpy as np +import random +import torch +from collections import OrderedDict +from torch.nn import functional as F + +from r_basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from r_basicsr.data.transforms import paired_random_crop +from r_basicsr.losses.loss_util import get_refined_artifact_map +from r_basicsr.models.srgan_model import SRGANModel +from r_basicsr.utils import DiffJPEG, USMSharp +from r_basicsr.utils.img_process_util import filter2D +from r_basicsr.utils.registry import MODEL_REGISTRY + + +@MODEL_REGISTRY.register(suffix='basicsr') +class RealESRGANModel(SRGANModel): + """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. + """ + + def __init__(self, opt): + super(RealESRGANModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.queue_size = opt.get('queue_size', 180) + + @torch.no_grad() + def _dequeue_and_enqueue(self): + """It is the training pair pool for increasing the diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch could not have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. + """ + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # the pool is full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get first b samples + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update the queue + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + """Accept data from dataloader, and then add two-order degradations to obtain LQ images. + """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt_usm, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size, + self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue + self.gt_usm = self.usm_sharpener(self.gt) + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True + + def optimize_parameters(self, current_iter): + # usm sharpening + l1_gt = self.gt_usm + percep_gt = self.gt_usm + gan_gt = self.gt_usm + if self.opt['l1_gt_usm'] is False: + l1_gt = self.gt + if self.opt['percep_gt_usm'] is False: + percep_gt = self.gt + if self.opt['gan_gt_usm'] is False: + gan_gt = self.gt + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + if self.cri_ldl: + self.output_ema = self.net_g_ema(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, l1_gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + if self.cri_ldl: + pixel_weight = get_refined_artifact_map(self.gt, self.output, self.output_ema, 7) + l_g_ldl = self.cri_ldl(torch.mul(pixel_weight, self.output), torch.mul(pixel_weight, self.gt)) + l_g_total += l_g_ldl + loss_dict['l_g_ldl'] = l_g_ldl + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(gan_gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9 + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + self.log_dict = self.reduce_loss_dict(loss_dict) diff --git a/r_basicsr/models/realesrnet_model.py b/r_basicsr/models/realesrnet_model.py new file mode 100644 index 0000000..2e8dc65 --- /dev/null +++ b/r_basicsr/models/realesrnet_model.py @@ -0,0 +1,189 @@ +import numpy as np +import random +import torch +from torch.nn import functional as F + +from r_basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from r_basicsr.data.transforms import paired_random_crop +from r_basicsr.models.sr_model import SRModel +from r_basicsr.utils import DiffJPEG, USMSharp +from r_basicsr.utils.img_process_util import filter2D +from r_basicsr.utils.registry import MODEL_REGISTRY + + +@MODEL_REGISTRY.register(suffix='basicsr') +class RealESRNetModel(SRModel): + """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It is trained without GAN losses. + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. + """ + + def __init__(self, opt): + super(RealESRNetModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.queue_size = opt.get('queue_size', 180) + + @torch.no_grad() + def _dequeue_and_enqueue(self): + """It is the training pair pool for increasing the diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch could not have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. + """ + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # the pool is full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get first b samples + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update the queue + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + """Accept data from dataloader, and then add two-order degradations to obtain LQ images. + """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + # USM sharpen the GT images + if self.opt['gt_usm'] is True: + self.gt = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True diff --git a/r_basicsr/models/sr_model.py b/r_basicsr/models/sr_model.py new file mode 100644 index 0000000..f6e37e9 --- /dev/null +++ b/r_basicsr/models/sr_model.py @@ -0,0 +1,231 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from r_basicsr.archs import build_network +from r_basicsr.losses import build_loss +from r_basicsr.metrics import calculate_metric +from r_basicsr.utils import get_root_logger, imwrite, tensor2img +from r_basicsr.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + + +@MODEL_REGISTRY.register() +class SRModel(BaseModel): + """Base SR model for single image super-resolution.""" + + def __init__(self, opt): + super(SRModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + optim_params = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + def optimize_parameters(self, current_iter): + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + l_pix = self.cri_pix(self.output, self.gt) + l_total += l_pix + loss_dict['l_pix'] = l_pix + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + l_total.backward() + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def test(self): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(self.lq) + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', False) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + # initialize the best metric results for each dataset_name (supporting multiple validation datasets) + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + if with_metrics: + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + metric_data['img'] = sr_img + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + self.metric_results[name] += calculate_metric(metric_data, opt_) + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['lq'] = self.lq.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + if hasattr(self, 'gt'): + out_dict['gt'] = self.gt.detach().cpu() + return out_dict + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/r_basicsr/models/srgan_model.py b/r_basicsr/models/srgan_model.py new file mode 100644 index 0000000..a562a7d --- /dev/null +++ b/r_basicsr/models/srgan_model.py @@ -0,0 +1,149 @@ +import torch +from collections import OrderedDict + +from r_basicsr.archs import build_network +from r_basicsr.losses import build_loss +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class SRGANModel(SRModel): + """SRGAN model for single image super-resolution.""" + + def init_training_settings(self): + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_d', 'params') + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) + + self.net_g.train() + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('ldl_opt'): + self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device) + else: + self.cri_ldl = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/r_basicsr/models/stylegan2_model.py b/r_basicsr/models/stylegan2_model.py new file mode 100644 index 0000000..58a38ae --- /dev/null +++ b/r_basicsr/models/stylegan2_model.py @@ -0,0 +1,283 @@ +import cv2 +import math +import numpy as np +import random +import torch +from collections import OrderedDict +from os import path as osp + +from r_basicsr.archs import build_network +from r_basicsr.losses import build_loss +from r_basicsr.losses.gan_loss import g_path_regularize, r1_penalty +from r_basicsr.utils import imwrite, tensor2img +from r_basicsr.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + + +@MODEL_REGISTRY.register() +class StyleGAN2Model(BaseModel): + """StyleGAN2 model.""" + + def __init__(self, opt): + super(StyleGAN2Model, self).__init__(opt) + + # define network net_g + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + # latent dimension: self.num_style_feat + self.num_style_feat = opt['network_g']['num_style_feat'] + num_val_samples = self.opt['val'].get('num_val_samples', 16) + self.fixed_sample = torch.randn(num_val_samples, self.num_style_feat, device=self.device) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + train_opt = self.opt['train'] + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_d', 'params') + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) + + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema only used for testing on one GPU and saving, do not need to + # wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + + self.net_g.train() + self.net_d.train() + self.net_g_ema.eval() + + # define losses + # gan loss (wgan) + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + # regularization weights + self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator + self.path_reg_weight = train_opt['path_reg_weight'] # for generator + + self.net_g_reg_every = train_opt['net_g_reg_every'] + self.net_d_reg_every = train_opt['net_d_reg_every'] + self.mixing_prob = train_opt['mixing_prob'] + + self.mean_path_length = 0 + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + net_g_reg_ratio = self.net_g_reg_every / (self.net_g_reg_every + 1) + if self.opt['network_g']['type'] == 'StyleGAN2GeneratorC': + normal_params = [] + style_mlp_params = [] + modulation_conv_params = [] + for name, param in self.net_g.named_parameters(): + if 'modulation' in name: + normal_params.append(param) + elif 'style_mlp' in name: + style_mlp_params.append(param) + elif 'modulated_conv' in name: + modulation_conv_params.append(param) + else: + normal_params.append(param) + optim_params_g = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + { + 'params': style_mlp_params, + 'lr': train_opt['optim_g']['lr'] * 0.01 + }, + { + 'params': modulation_conv_params, + 'lr': train_opt['optim_g']['lr'] / 3 + } + ] + else: + normal_params = [] + for name, param in self.net_g.named_parameters(): + normal_params.append(param) + optim_params_g = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }] + + optim_type = train_opt['optim_g'].pop('type') + lr = train_opt['optim_g']['lr'] * net_g_reg_ratio + betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio) + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas) + self.optimizers.append(self.optimizer_g) + + # optimizer d + net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1) + if self.opt['network_d']['type'] == 'StyleGAN2DiscriminatorC': + normal_params = [] + linear_params = [] + for name, param in self.net_d.named_parameters(): + if 'final_linear' in name: + linear_params.append(param) + else: + normal_params.append(param) + optim_params_d = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_d']['lr'] + }, + { + 'params': linear_params, + 'lr': train_opt['optim_d']['lr'] * (1 / math.sqrt(512)) + } + ] + else: + normal_params = [] + for name, param in self.net_d.named_parameters(): + normal_params.append(param) + optim_params_d = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_d']['lr'] + }] + + optim_type = train_opt['optim_d'].pop('type') + lr = train_opt['optim_d']['lr'] * net_d_reg_ratio + betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio) + self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas) + self.optimizers.append(self.optimizer_d) + + def feed_data(self, data): + self.real_img = data['gt'].to(self.device) + + def make_noise(self, batch, num_noise): + if num_noise == 1: + noises = torch.randn(batch, self.num_style_feat, device=self.device) + else: + noises = torch.randn(num_noise, batch, self.num_style_feat, device=self.device).unbind(0) + return noises + + def mixing_noise(self, batch, prob): + if random.random() < prob: + return self.make_noise(batch, 2) + else: + return [self.make_noise(batch, 1)] + + def optimize_parameters(self, current_iter): + loss_dict = OrderedDict() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + self.optimizer_d.zero_grad() + + batch = self.real_img.size(0) + noise = self.mixing_noise(batch, self.mixing_prob) + fake_img, _ = self.net_g(noise) + fake_pred = self.net_d(fake_img.detach()) + + real_pred = self.net_d(self.real_img) + # wgan loss with softplus (logistic loss) for discriminator + l_d = self.cri_gan(real_pred, True, is_disc=True) + self.cri_gan(fake_pred, False, is_disc=True) + loss_dict['l_d'] = l_d + # In wgan, real_score should be positive and fake_score should be + # negative + loss_dict['real_score'] = real_pred.detach().mean() + loss_dict['fake_score'] = fake_pred.detach().mean() + l_d.backward() + + if current_iter % self.net_d_reg_every == 0: + self.real_img.requires_grad = True + real_pred = self.net_d(self.real_img) + l_d_r1 = r1_penalty(real_pred, self.real_img) + l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0]) + # TODO: why do we need to add 0 * real_pred, otherwise, a runtime + # error will arise: RuntimeError: Expected to have finished + # reduction in the prior iteration before starting a new one. + # This error indicates that your module has parameters that were + # not used in producing loss. + loss_dict['l_d_r1'] = l_d_r1.detach().mean() + l_d_r1.backward() + + self.optimizer_d.step() + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + self.optimizer_g.zero_grad() + + noise = self.mixing_noise(batch, self.mixing_prob) + fake_img, _ = self.net_g(noise) + fake_pred = self.net_d(fake_img) + + # wgan loss with softplus (non-saturating loss) for generator + l_g = self.cri_gan(fake_pred, True, is_disc=False) + loss_dict['l_g'] = l_g + l_g.backward() + + if current_iter % self.net_g_reg_every == 0: + path_batch_size = max(1, batch // self.opt['train']['path_batch_shrink']) + noise = self.mixing_noise(path_batch_size, self.mixing_prob) + fake_img, latents = self.net_g(noise, return_latents=True) + l_g_path, path_lengths, self.mean_path_length = g_path_regularize(fake_img, latents, self.mean_path_length) + + l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0]) + # TODO: why do we need to add 0 * fake_img[0, 0, 0, 0] + l_g_path.backward() + loss_dict['l_g_path'] = l_g_path.detach().mean() + loss_dict['path_length'] = path_lengths + + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + # EMA + self.model_ema(decay=0.5**(32 / (10 * 1000))) + + def test(self): + with torch.no_grad(): + self.net_g_ema.eval() + self.output, _ = self.net_g_ema([self.fixed_sample]) + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + assert dataloader is None, 'Validation dataloader should be None.' + self.test() + result = tensor2img(self.output, min_max=(-1, 1)) + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], 'train', f'train_{current_iter}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png') + imwrite(result, save_img_path) + # add sample images to tb_logger + result = (result / 255.).astype(np.float32) + result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB) + if tb_logger is not None: + tb_logger.add_image('samples', result, global_step=current_iter, dataformats='HWC') + + def save(self, epoch, current_iter): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/r_basicsr/models/swinir_model.py b/r_basicsr/models/swinir_model.py new file mode 100644 index 0000000..18e5550 --- /dev/null +++ b/r_basicsr/models/swinir_model.py @@ -0,0 +1,33 @@ +import torch +from torch.nn import functional as F + +from r_basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class SwinIRModel(SRModel): + + def test(self): + # pad to multiplication of window_size + window_size = self.opt['network_g']['window_size'] + scale = self.opt.get('scale', 1) + mod_pad_h, mod_pad_w = 0, 0 + _, _, h, w = self.lq.size() + if h % window_size != 0: + mod_pad_h = window_size - h % window_size + if w % window_size != 0: + mod_pad_w = window_size - w % window_size + img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(img) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(img) + self.net_g.train() + + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] diff --git a/r_basicsr/models/video_base_model.py b/r_basicsr/models/video_base_model.py new file mode 100644 index 0000000..31ea37d --- /dev/null +++ b/r_basicsr/models/video_base_model.py @@ -0,0 +1,160 @@ +import torch +from collections import Counter +from os import path as osp +from torch import distributed as dist +from tqdm import tqdm + +from r_basicsr.metrics import calculate_metric +from r_basicsr.utils import get_root_logger, imwrite, tensor2img +from r_basicsr.utils.dist_util import get_dist_info +from r_basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class VideoBaseModel(SRModel): + """Base video SR model.""" + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset = dataloader.dataset + dataset_name = dataset.opt['name'] + with_metrics = self.opt['val']['metrics'] is not None + # initialize self.metric_results + # It is a dict: { + # 'folder1': tensor (num_frame x len(metrics)), + # 'folder2': tensor (num_frame x len(metrics)) + # } + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {} + num_frame_each_folder = Counter(dataset.data_info['folder']) + for folder, num_frame in num_frame_each_folder.items(): + self.metric_results[folder] = torch.zeros( + num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') + # initialize the best metric results + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + rank, world_size = get_dist_info() + if with_metrics: + for _, tensor in self.metric_results.items(): + tensor.zero_() + + metric_data = dict() + # record all frames (border and center frames) + if rank == 0: + pbar = tqdm(total=len(dataset), unit='frame') + for idx in range(rank, len(dataset), world_size): + val_data = dataset[idx] + val_data['lq'].unsqueeze_(0) + val_data['gt'].unsqueeze_(0) + folder = val_data['folder'] + frame_idx, max_idx = val_data['idx'].split('/') + lq_path = val_data['lq_path'] + + self.feed_data(val_data) + self.test() + visuals = self.get_current_visuals() + result_img = tensor2img([visuals['result']]) + metric_data['img'] = result_img + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + raise NotImplementedError('saving image is not supported during training.') + else: + if 'vimeo' in dataset_name.lower(): # vimeo90k dataset + split_result = lq_path.split('/') + img_name = f'{split_result[-3]}_{split_result[-2]}_{split_result[-1].split(".")[0]}' + else: # other datasets, e.g., REDS, Vid4 + img_name = osp.splitext(osp.basename(lq_path))[0] + + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + f'{img_name}_{self.opt["name"]}.png') + imwrite(result_img, save_img_path) + + if with_metrics: + # calculate metrics + for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()): + result = calculate_metric(metric_data, opt_) + self.metric_results[folder][int(frame_idx), metric_idx] += result + + # progress bar + if rank == 0: + for _ in range(world_size): + pbar.update(1) + pbar.set_description(f'Test {folder}: {int(frame_idx) + world_size}/{max_idx}') + if rank == 0: + pbar.close() + + if with_metrics: + if self.opt['dist']: + # collect data among GPUs + for _, tensor in self.metric_results.items(): + dist.reduce(tensor, 0) + dist.barrier() + else: + pass # assume use one gpu in non-dist testing + + if rank == 0: + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + logger = get_root_logger() + logger.warning('nondist_validation is not implemented. Run dist_validation.') + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + # ----------------- calculate the average values for each folder, and for each metric ----------------- # + # average all frames for each sub-folder + # metric_results_avg is a dict:{ + # 'folder1': tensor (len(metrics)), + # 'folder2': tensor (len(metrics)) + # } + metric_results_avg = { + folder: torch.mean(tensor, dim=0).cpu() + for (folder, tensor) in self.metric_results.items() + } + # total_avg_results is a dict: { + # 'metric1': float, + # 'metric2': float + # } + total_avg_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + for folder, tensor in metric_results_avg.items(): + for idx, metric in enumerate(total_avg_results.keys()): + total_avg_results[metric] += metric_results_avg[folder][idx].item() + # average among folders + for metric in total_avg_results.keys(): + total_avg_results[metric] /= len(metric_results_avg) + # update the best metric result + self._update_best_metric_result(dataset_name, metric, total_avg_results[metric], current_iter) + + # ------------------------------------------ log the metric ------------------------------------------ # + log_str = f'Validation {dataset_name}\n' + for metric_idx, (metric, value) in enumerate(total_avg_results.items()): + log_str += f'\t # {metric}: {value:.4f}' + for folder, tensor in metric_results_avg.items(): + log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric_idx, (metric, value) in enumerate(total_avg_results.items()): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + for folder, tensor in metric_results_avg.items(): + tb_logger.add_scalar(f'metrics/{metric}/{folder}', tensor[metric_idx].item(), current_iter) diff --git a/r_basicsr/models/video_gan_model.py b/r_basicsr/models/video_gan_model.py new file mode 100644 index 0000000..cc44476 --- /dev/null +++ b/r_basicsr/models/video_gan_model.py @@ -0,0 +1,17 @@ +from r_basicsr.utils.registry import MODEL_REGISTRY +from .srgan_model import SRGANModel +from .video_base_model import VideoBaseModel + + +@MODEL_REGISTRY.register() +class VideoGANModel(SRGANModel, VideoBaseModel): + """Video GAN model. + + Use multiple inheritance. + It will first use the functions of SRGANModel: + init_training_settings + setup_optimizers + optimize_parameters + save + Then find functions in VideoBaseModel. + """ diff --git a/r_basicsr/models/video_recurrent_gan_model.py b/r_basicsr/models/video_recurrent_gan_model.py new file mode 100644 index 0000000..2800e27 --- /dev/null +++ b/r_basicsr/models/video_recurrent_gan_model.py @@ -0,0 +1,180 @@ +import torch +from collections import OrderedDict + +from r_basicsr.archs import build_network +from r_basicsr.losses import build_loss +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import MODEL_REGISTRY +from .video_recurrent_model import VideoRecurrentModel + + +@MODEL_REGISTRY.register() +class VideoRecurrentGANModel(VideoRecurrentModel): + + def init_training_settings(self): + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # build network net_g with Exponential Moving Average (EMA) + # net_g_ema only used for testing on one GPU and saving. + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_d', 'params') + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) + + self.net_g.train() + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + if train_opt['fix_flow']: + normal_params = [] + flow_params = [] + for name, param in self.net_g.named_parameters(): + if 'spynet' in name: # The fix_flow now only works for spynet. + flow_params.append(param) + else: + normal_params.append(param) + + optim_params = [ + { # add flow params first + 'params': flow_params, + 'lr': train_opt['lr_flow'] + }, + { + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + ] + else: + optim_params = self.net_g.parameters() + + # optimizer g + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def optimize_parameters(self, current_iter): + logger = get_root_logger() + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + if self.fix_flow_iter: + if current_iter == 1: + logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.') + for name, param in self.net_g.named_parameters(): + if 'spynet' in name or 'edvr' in name: + param.requires_grad_(False) + elif current_iter == self.fix_flow_iter: + logger.warning('Train all the parameters.') + self.net_g.requires_grad_(True) + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + _, _, c, h, w = self.output.size() + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), self.gt.view(-1, c, h, w)) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output.view(-1, c, h, w)) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + # reshape to (b*n, c, h, w) + real_d_pred = self.net_d(self.gt.view(-1, c, h, w)) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + # reshape to (b*n, c, h, w) + fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def save(self, epoch, current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/r_basicsr/models/video_recurrent_model.py b/r_basicsr/models/video_recurrent_model.py new file mode 100644 index 0000000..ea3a4c5 --- /dev/null +++ b/r_basicsr/models/video_recurrent_model.py @@ -0,0 +1,197 @@ +import torch +from collections import Counter +from os import path as osp +from torch import distributed as dist +from tqdm import tqdm + +from r_basicsr.metrics import calculate_metric +from r_basicsr.utils import get_root_logger, imwrite, tensor2img +from r_basicsr.utils.dist_util import get_dist_info +from r_basicsr.utils.registry import MODEL_REGISTRY +from .video_base_model import VideoBaseModel + + +@MODEL_REGISTRY.register() +class VideoRecurrentModel(VideoBaseModel): + + def __init__(self, opt): + super(VideoRecurrentModel, self).__init__(opt) + if self.is_train: + self.fix_flow_iter = opt['train'].get('fix_flow') + + def setup_optimizers(self): + train_opt = self.opt['train'] + flow_lr_mul = train_opt.get('flow_lr_mul', 1) + logger = get_root_logger() + logger.info(f'Multiple the learning rate for flow network with {flow_lr_mul}.') + if flow_lr_mul == 1: + optim_params = self.net_g.parameters() + else: # separate flow params and normal params for different lr + normal_params = [] + flow_params = [] + for name, param in self.net_g.named_parameters(): + if 'spynet' in name: + flow_params.append(param) + else: + normal_params.append(param) + optim_params = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + { + 'params': flow_params, + 'lr': train_opt['optim_g']['lr'] * flow_lr_mul + }, + ] + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def optimize_parameters(self, current_iter): + if self.fix_flow_iter: + logger = get_root_logger() + if current_iter == 1: + logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.') + for name, param in self.net_g.named_parameters(): + if 'spynet' in name or 'edvr' in name: + param.requires_grad_(False) + elif current_iter == self.fix_flow_iter: + logger.warning('Train all the parameters.') + self.net_g.requires_grad_(True) + + super(VideoRecurrentModel, self).optimize_parameters(current_iter) + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset = dataloader.dataset + dataset_name = dataset.opt['name'] + with_metrics = self.opt['val']['metrics'] is not None + # initialize self.metric_results + # It is a dict: { + # 'folder1': tensor (num_frame x len(metrics)), + # 'folder2': tensor (num_frame x len(metrics)) + # } + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {} + num_frame_each_folder = Counter(dataset.data_info['folder']) + for folder, num_frame in num_frame_each_folder.items(): + self.metric_results[folder] = torch.zeros( + num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') + # initialize the best metric results + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + rank, world_size = get_dist_info() + if with_metrics: + for _, tensor in self.metric_results.items(): + tensor.zero_() + + metric_data = dict() + num_folders = len(dataset) + num_pad = (world_size - (num_folders % world_size)) % world_size + if rank == 0: + pbar = tqdm(total=len(dataset), unit='folder') + # Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded. + # (To avoid wait-dead) + for i in range(rank, num_folders + num_pad, world_size): + idx = min(i, num_folders - 1) + val_data = dataset[idx] + folder = val_data['folder'] + + # compute outputs + val_data['lq'].unsqueeze_(0) + val_data['gt'].unsqueeze_(0) + self.feed_data(val_data) + val_data['lq'].squeeze_(0) + val_data['gt'].squeeze_(0) + + self.test() + visuals = self.get_current_visuals() + + # tentative for out of GPU memory + del self.lq + del self.output + if 'gt' in visuals: + del self.gt + torch.cuda.empty_cache() + + if self.center_frame_only: + visuals['result'] = visuals['result'].unsqueeze(1) + if 'gt' in visuals: + visuals['gt'] = visuals['gt'].unsqueeze(1) + + # evaluate + if i < num_folders: + for idx in range(visuals['result'].size(1)): + result = visuals['result'][0, idx, :, :, :] + result_img = tensor2img([result]) # uint8, bgr + metric_data['img'] = result_img + if 'gt' in visuals: + gt = visuals['gt'][0, idx, :, :, :] + gt_img = tensor2img([gt]) # uint8, bgr + metric_data['img2'] = gt_img + + if save_img: + if self.opt['is_train']: + raise NotImplementedError('saving image is not supported during training.') + else: + if self.center_frame_only: # vimeo-90k + clip_ = val_data['lq_path'].split('/')[-3] + seq_ = val_data['lq_path'].split('/')[-2] + name_ = f'{clip_}_{seq_}' + img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + f"{name_}_{self.opt['name']}.png") + else: # others + img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + f"{idx:08d}_{self.opt['name']}.png") + # image name only for REDS dataset + imwrite(result_img, img_path) + + # calculate metrics + if with_metrics: + for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()): + result = calculate_metric(metric_data, opt_) + self.metric_results[folder][idx, metric_idx] += result + + # progress bar + if rank == 0: + for _ in range(world_size): + pbar.update(1) + pbar.set_description(f'Folder: {folder}') + + if rank == 0: + pbar.close() + + if with_metrics: + if self.opt['dist']: + # collect data among GPUs + for _, tensor in self.metric_results.items(): + dist.reduce(tensor, 0) + dist.barrier() + + if rank == 0: + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def test(self): + n = self.lq.size(1) + self.net_g.eval() + + flip_seq = self.opt['val'].get('flip_seq', False) + self.center_frame_only = self.opt['val'].get('center_frame_only', False) + + if flip_seq: + self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1) + + with torch.no_grad(): + self.output = self.net_g(self.lq) + + if flip_seq: + output_1 = self.output[:, :n, :, :, :] + output_2 = self.output[:, n:, :, :, :].flip(1) + self.output = 0.5 * (output_1 + output_2) + + if self.center_frame_only: + self.output = self.output[:, n // 2, :, :, :] + + self.net_g.train() -- cgit v1.2.3