author     Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com>  2025-01-17 11:00:30 +0000
committer  GitHub <noreply@github.com>  2025-01-17 11:00:30 +0000
commit     495ffc4777522e40941753e3b1b79c02f84b25b4 (patch)
tree       5130fcb8676afdcb619a5e5eaef3ac28e135bc08 /r_basicsr/models
parent     febd45814cd41560c5247aacb111d8d013f3a303 (diff)
download   Comfyui-reactor-node-495ffc4777522e40941753e3b1b79c02f84b25b4.tar.gz
Add files via upload
Diffstat (limited to 'r_basicsr/models')
-rw-r--r--  r_basicsr/models/__init__.py                  |  29
-rw-r--r--  r_basicsr/models/base_model.py                | 380
-rw-r--r--  r_basicsr/models/edvr_model.py                |  62
-rw-r--r--  r_basicsr/models/esrgan_model.py              |  83
-rw-r--r--  r_basicsr/models/hifacegan_model.py           | 288
-rw-r--r--  r_basicsr/models/lr_scheduler.py              |  96
-rw-r--r--  r_basicsr/models/realesrgan_model.py          | 267
-rw-r--r--  r_basicsr/models/realesrnet_model.py          | 189
-rw-r--r--  r_basicsr/models/sr_model.py                  | 231
-rw-r--r--  r_basicsr/models/srgan_model.py               | 149
-rw-r--r--  r_basicsr/models/stylegan2_model.py           | 283
-rw-r--r--  r_basicsr/models/swinir_model.py              |  33
-rw-r--r--  r_basicsr/models/video_base_model.py          | 160
-rw-r--r--  r_basicsr/models/video_gan_model.py           |  17
-rw-r--r--  r_basicsr/models/video_recurrent_gan_model.py | 180
-rw-r--r--  r_basicsr/models/video_recurrent_model.py     | 197
16 files changed, 2644 insertions, 0 deletions
diff --git a/r_basicsr/models/__init__.py b/r_basicsr/models/__init__.py
new file mode 100644
index 0000000..b01cdba
--- /dev/null
+++ b/r_basicsr/models/__init__.py
@@ -0,0 +1,29 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from r_basicsr.utils import get_root_logger, scandir
+from r_basicsr.utils.registry import MODEL_REGISTRY
+
+__all__ = ['build_model']
+
+# automatically scan and import model modules for registry
+# scan all the files under the 'models' folder and collect files ending with '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
+# import all the model modules
+_model_modules = [importlib.import_module(f'r_basicsr.models.{file_name}') for file_name in model_filenames]
+
+
+def build_model(opt):
+ """Build model from options.
+
+ Args:
+ opt (dict): Configuration. It must contain:
+ model_type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ model = MODEL_REGISTRY.get(opt['model_type'])(opt)
+ logger = get_root_logger()
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
+ return model
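For reference, a hedged sketch of how `build_model` is typically driven. The option values below are illustrative, not a complete config: a real options dict also carries the train/val/path settings that the concrete model class consumes, and `MSRResNet` is used here only as an example of a registered generator architecture.

```python
from r_basicsr.models import build_model

# Hypothetical minimal options; 'SRModel' is one of the classes this folder
# registers into MODEL_REGISTRY via the automatic scan above.
opt = {
    'model_type': 'SRModel',
    'is_train': False,
    'num_gpu': 1,          # BaseModel picks cuda when num_gpu != 0
    'dist': False,
    'network_g': {'type': 'MSRResNet', 'num_in_ch': 3, 'num_out_ch': 3, 'upscale': 4},
    'path': {'pretrain_network_g': None},
}
model = build_model(opt)   # logs: Model [SRModel] is created.
```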
diff --git a/r_basicsr/models/base_model.py b/r_basicsr/models/base_model.py
new file mode 100644
index 0000000..bd2faad
--- /dev/null
+++ b/r_basicsr/models/base_model.py
@@ -0,0 +1,380 @@
+import os
+import time
+import torch
+from collections import OrderedDict
+from copy import deepcopy
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from r_basicsr.models import lr_scheduler as lr_scheduler
+from r_basicsr.utils import get_root_logger
+from r_basicsr.utils.dist_util import master_only
+
+
+class BaseModel():
+ """Base model."""
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+ self.is_train = opt['is_train']
+ self.schedulers = []
+ self.optimizers = []
+
+ def feed_data(self, data):
+ pass
+
+ def optimize_parameters(self):
+ pass
+
+ def get_current_visuals(self):
+ pass
+
+ def save(self, epoch, current_iter):
+ """Save networks and training state."""
+ pass
+
+ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
+ """Validation function.
+
+ Args:
+ dataloader (torch.utils.data.DataLoader): Validation dataloader.
+ current_iter (int): Current iteration.
+ tb_logger (tensorboard logger): Tensorboard logger.
+ save_img (bool): Whether to save images. Default: False.
+ """
+ if self.opt['dist']:
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+ else:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def _initialize_best_metric_results(self, dataset_name):
+ """Initialize the best metric results dict for recording the best metric value and iteration."""
+ if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
+ return
+ elif not hasattr(self, 'best_metric_results'):
+ self.best_metric_results = dict()
+
+ # add a dataset record
+ record = dict()
+ for metric, content in self.opt['val']['metrics'].items():
+ better = content.get('better', 'higher')
+ init_val = float('-inf') if better == 'higher' else float('inf')
+ record[metric] = dict(better=better, val=init_val, iter=-1)
+ self.best_metric_results[dataset_name] = record
+
+ def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
+ if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
+ if val >= self.best_metric_results[dataset_name][metric]['val']:
+ self.best_metric_results[dataset_name][metric]['val'] = val
+ self.best_metric_results[dataset_name][metric]['iter'] = current_iter
+ else:
+ if val <= self.best_metric_results[dataset_name][metric]['val']:
+ self.best_metric_results[dataset_name][metric]['val'] = val
+ self.best_metric_results[dataset_name][metric]['iter'] = current_iter
+
+ def model_ema(self, decay=0.999):
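+ # EMA update: each shadow parameter in net_g_ema is moved toward the
+ # live net_g parameter: ema = decay * ema + (1 - decay) * param.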
+ net_g = self.get_bare_model(self.net_g)
+
+ net_g_params = dict(net_g.named_parameters())
+ net_g_ema_params = dict(self.net_g_ema.named_parameters())
+
+ for k in net_g_ema_params.keys():
+ net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
+
+ def get_current_log(self):
+ return self.log_dict
+
+ def model_to_device(self, net):
+ """Model to device. It also warps models with DistributedDataParallel
+ or DataParallel.
+
+ Args:
+ net (nn.Module)
+ """
+ net = net.to(self.device)
+ if self.opt['dist']:
+ find_unused_parameters = self.opt.get('find_unused_parameters', False)
+ net = DistributedDataParallel(
+ net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
+ elif self.opt['num_gpu'] > 1:
+ net = DataParallel(net)
+ return net
+
+ def get_optimizer(self, optim_type, params, lr, **kwargs):
+ if optim_type == 'Adam':
+ optimizer = torch.optim.Adam(params, lr, **kwargs)
+ else:
+ raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
+ return optimizer
+
+ def setup_schedulers(self):
+ """Set up schedulers."""
+ train_opt = self.opt['train']
+ scheduler_type = train_opt['scheduler'].pop('type')
+ if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
+ for optimizer in self.optimizers:
+ self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
+ elif scheduler_type == 'CosineAnnealingRestartLR':
+ for optimizer in self.optimizers:
+ self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
+ else:
+ raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
+
+ def get_bare_model(self, net):
+ """Get bare model, especially under wrapping with
+ DistributedDataParallel or DataParallel.
+ """
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
+ net = net.module
+ return net
+
+ @master_only
+ def print_network(self, net):
+ """Print the str and parameter number of a network.
+
+ Args:
+ net (nn.Module)
+ """
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
+ net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
+ else:
+ net_cls_str = f'{net.__class__.__name__}'
+
+ net = self.get_bare_model(net)
+ net_str = str(net)
+ net_params = sum(map(lambda x: x.numel(), net.parameters()))
+
+ logger = get_root_logger()
+ logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
+ logger.info(net_str)
+
+ def _set_lr(self, lr_groups_l):
+ """Set learning rate for warm-up.
+
+ Args:
+ lr_groups_l (list): List for lr_groups, each for an optimizer.
+ """
+ for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
+ for param_group, lr in zip(optimizer.param_groups, lr_groups):
+ param_group['lr'] = lr
+
+ def _get_init_lr(self):
+ """Get the initial lr, which is set by the scheduler.
+ """
+ init_lr_groups_l = []
+ for optimizer in self.optimizers:
+ init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
+ return init_lr_groups_l
+
+ def update_learning_rate(self, current_iter, warmup_iter=-1):
+ """Update learning rate.
+
+ Args:
+ current_iter (int): Current iteration.
+ warmup_iter (int): Warm-up iter numbers. -1 for no warm-up.
+ Default: -1.
+ """
+ if current_iter > 1:
+ for scheduler in self.schedulers:
+ scheduler.step()
+ # set up warm-up learning rate
+ if current_iter < warmup_iter:
+ # get initial lr for each group
+ init_lr_g_l = self._get_init_lr()
+ # modify warming-up learning rates
+ # currently only support linearly warm up
+ warm_up_lr_l = []
+ for init_lr_g in init_lr_g_l:
+ warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
+ # set learning rate
+ self._set_lr(warm_up_lr_l)
+
+ def get_current_learning_rate(self):
+ return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
+
+ @master_only
+ def save_network(self, net, net_label, current_iter, param_key='params'):
+ """Save networks.
+
+ Args:
+ net (nn.Module | list[nn.Module]): Network(s) to be saved.
+ net_label (str): Network label.
+ current_iter (int): Current iter number.
+ param_key (str | list[str]): The parameter key(s) to save network.
+ Default: 'params'.
+ """
+ if current_iter == -1:
+ current_iter = 'latest'
+ save_filename = f'{net_label}_{current_iter}.pth'
+ save_path = os.path.join(self.opt['path']['models'], save_filename)
+
+ net = net if isinstance(net, list) else [net]
+ param_key = param_key if isinstance(param_key, list) else [param_key]
+ assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
+
+ save_dict = {}
+ for net_, param_key_ in zip(net, param_key):
+ net_ = self.get_bare_model(net_)
+ state_dict = net_.state_dict()
+ for key, param in state_dict.items():
+ if key.startswith('module.'): # remove unnecessary 'module.'
+ key = key[7:]
+ state_dict[key] = param.cpu()
+ save_dict[param_key_] = state_dict
+
+ # avoid occasional writing errors
+ retry = 3
+ while retry > 0:
+ try:
+ torch.save(save_dict, save_path)
+ except Exception as e:
+ logger = get_root_logger()
+ logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}')
+ time.sleep(1)
+ else:
+ break
+ finally:
+ retry -= 1
+ if retry == 0:
+ logger.warning(f'Still cannot save {save_path}. Just ignore it.')
+ # raise IOError(f'Cannot save {save_path}.')
+
+ def _print_different_keys_loading(self, crt_net, load_net, strict=True):
+ """Print keys with different name or different size when loading models.
+
+ 1. Print keys with different names.
+ 2. If strict=False, print the same key but with different tensor size.
+ It also ignores keys with different sizes (they are not loaded).
+
+ Args:
+ crt_net (torch model): Current network.
+ load_net (dict): Loaded network.
+ strict (bool): Whether strictly loaded. Default: True.
+ """
+ crt_net = self.get_bare_model(crt_net)
+ crt_net = crt_net.state_dict()
+ crt_net_keys = set(crt_net.keys())
+ load_net_keys = set(load_net.keys())
+
+ logger = get_root_logger()
+ if crt_net_keys != load_net_keys:
+ logger.warning('Current net - loaded net:')
+ for v in sorted(list(crt_net_keys - load_net_keys)):
+ logger.warning(f' {v}')
+ logger.warning('Loaded net - current net:')
+ for v in sorted(list(load_net_keys - crt_net_keys)):
+ logger.warning(f' {v}')
+
+ # check the size for the same keys
+ if not strict:
+ common_keys = crt_net_keys & load_net_keys
+ for k in common_keys:
+ if crt_net[k].size() != load_net[k].size():
+ logger.warning(f'Size different, ignore [{k}]: crt_net: '
+ f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
+ load_net[k + '.ignore'] = load_net.pop(k)
+
+ def load_network(self, net, load_path, strict=True, param_key='params'):
+ """Load network.
+
+ Args:
+ load_path (str): The path of networks to be loaded.
+ net (nn.Module): Network.
+ strict (bool): Whether strictly loaded.
+ param_key (str): The parameter key of loaded network. If set to
+ None, the loaded dict is used directly.
+ Default: 'params'.
+ """
+ logger = get_root_logger()
+ net = self.get_bare_model(net)
+ load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
+ if param_key is not None:
+ if param_key not in load_net and 'params' in load_net:
+ param_key = 'params'
+ logger.info('Loading: params_ema does not exist, use params.')
+ load_net = load_net[param_key]
+ logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
+ # remove unnecessary 'module.'
+ for k, v in deepcopy(load_net).items():
+ if k.startswith('module.'):
+ load_net[k[7:]] = v
+ load_net.pop(k)
+ self._print_different_keys_loading(net, load_net, strict)
+ net.load_state_dict(load_net, strict=strict)
+
+ @master_only
+ def save_training_state(self, epoch, current_iter):
+ """Save training states during training, which will be used for
+ resuming.
+
+ Args:
+ epoch (int): Current epoch.
+ current_iter (int): Current iteration.
+ """
+ if current_iter != -1:
+ state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
+ for o in self.optimizers:
+ state['optimizers'].append(o.state_dict())
+ for s in self.schedulers:
+ state['schedulers'].append(s.state_dict())
+ save_filename = f'{current_iter}.state'
+ save_path = os.path.join(self.opt['path']['training_states'], save_filename)
+
+ # avoid occasional writing errors
+ retry = 3
+ while retry > 0:
+ try:
+ torch.save(state, save_path)
+ except Exception as e:
+ logger = get_root_logger()
+ logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
+ time.sleep(1)
+ else:
+ break
+ finally:
+ retry -= 1
+ if retry == 0:
+ logger.warning(f'Still cannot save {save_path}. Just ignore it.')
+ # raise IOError(f'Cannot save {save_path}.')
+
+ def resume_training(self, resume_state):
+ """Reload the optimizers and schedulers for resumed training.
+
+ Args:
+ resume_state (dict): Resume state.
+ """
+ resume_optimizers = resume_state['optimizers']
+ resume_schedulers = resume_state['schedulers']
+ assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
+ assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
+ for i, o in enumerate(resume_optimizers):
+ self.optimizers[i].load_state_dict(o)
+ for i, s in enumerate(resume_schedulers):
+ self.schedulers[i].load_state_dict(s)
+
+ def reduce_loss_dict(self, loss_dict):
+ """reduce loss dict.
+
+ In distributed training, it averages the losses among different GPUs .
+
+ Args:
+ loss_dict (OrderedDict): Loss dict.
+ """
+ with torch.no_grad():
+ if self.opt['dist']:
+ keys = []
+ losses = []
+ for name, value in loss_dict.items():
+ keys.append(name)
+ losses.append(value)
+ losses = torch.stack(losses, 0)
+ torch.distributed.reduce(losses, dst=0)
+ if self.opt['rank'] == 0:
+ losses /= self.opt['world_size']
+ loss_dict = {key: loss for key, loss in zip(keys, losses)}
+
+ log_dict = OrderedDict()
+ for name, value in loss_dict.items():
+ log_dict[name] = value.mean().item()
+
+ return log_dict
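The warm-up branch of `update_learning_rate` is plain linear ramping: below `warmup_iter`, each group's learning rate scales in proportion to the current iteration. A standalone trace of what `_set_lr` would receive (the values are illustrative):

```python
# Linear warm-up as computed in update_learning_rate (illustrative values).
init_lr, warmup_iter = 2e-4, 1000

for current_iter in (1, 250, 500, 1000):
    warm_lr = init_lr / warmup_iter * current_iter
    print(current_iter, warm_lr)
# 1 -> 2e-07, 250 -> 5e-05, 500 -> 1e-04, 1000 -> 2e-04 (back at the scheduler lr)
```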
diff --git a/r_basicsr/models/edvr_model.py b/r_basicsr/models/edvr_model.py
new file mode 100644
index 0000000..1475033
--- /dev/null
+++ b/r_basicsr/models/edvr_model.py
@@ -0,0 +1,62 @@
+from r_basicsr.utils import get_root_logger
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .video_base_model import VideoBaseModel
+
+
+@MODEL_REGISTRY.register()
+class EDVRModel(VideoBaseModel):
+ """EDVR Model.
+
+ Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. # noqa: E501
+ """
+
+ def __init__(self, opt):
+ super(EDVRModel, self).__init__(opt)
+ if self.is_train:
+ self.train_tsa_iter = opt['train'].get('tsa_iter')
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
+ logger = get_root_logger()
+ logger.info(f'Multiply the learning rate for dcn by {dcn_lr_mul}.')
+ if dcn_lr_mul == 1:
+ optim_params = self.net_g.parameters()
+ else: # separate dcn params and normal params for different lr
+ normal_params = []
+ dcn_params = []
+ for name, param in self.net_g.named_parameters():
+ if 'dcn' in name:
+ dcn_params.append(param)
+ else:
+ normal_params.append(param)
+ optim_params = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ },
+ {
+ 'params': dcn_params,
+ 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul
+ },
+ ]
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ def optimize_parameters(self, current_iter):
+ if self.train_tsa_iter:
+ if current_iter == 1:
+ logger = get_root_logger()
+ logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
+ for name, param in self.net_g.named_parameters():
+ if 'fusion' not in name:
+ param.requires_grad = False
+ elif current_iter == self.train_tsa_iter:
+ logger = get_root_logger()
+ logger.warning('Train all the parameters.')
+ for param in self.net_g.parameters():
+ param.requires_grad = True
+
+ super(EDVRModel, self).optimize_parameters(current_iter)
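The TSA warm-up above is the usual freeze/unfreeze pattern: everything outside the fusion module has `requires_grad` switched off until `train_tsa_iter` is reached. A self-contained sketch of the same pattern on a toy module (the submodule names are illustrative, not EDVR's real ones):

```python
import torch.nn as nn

net = nn.ModuleDict({'fusion': nn.Linear(4, 4), 'recon': nn.Linear(4, 4)})

# Phase 1 (current_iter == 1): train only the fusion branch.
for name, param in net.named_parameters():
    param.requires_grad = 'fusion' in name
print([n for n, p in net.named_parameters() if p.requires_grad])
# ['fusion.weight', 'fusion.bias']

# Phase 2 (current_iter == train_tsa_iter): unfreeze everything.
for param in net.parameters():
    param.requires_grad = True
```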
diff --git a/r_basicsr/models/esrgan_model.py b/r_basicsr/models/esrgan_model.py
new file mode 100644
index 0000000..8924920
--- /dev/null
+++ b/r_basicsr/models/esrgan_model.py
@@ -0,0 +1,83 @@
+import torch
+from collections import OrderedDict
+
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .srgan_model import SRGANModel
+
+
+@MODEL_REGISTRY.register()
+class ESRGANModel(SRGANModel):
+ """ESRGAN model for single image super-resolution."""
+
+ def optimize_parameters(self, current_iter):
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss (relativistic gan)
+ real_d_pred = self.net_d(self.gt).detach()
+ fake_g_pred = self.net_d(self.output)
+ l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False)
+ l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False)
+ l_g_gan = (l_g_real + l_g_fake) / 2
+
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # gan loss (relativistic gan)
+
+ # In order to avoid the error in distributed training:
+ # "Error detected in CudnnBatchNormBackward: RuntimeError: one of
+ # the variables needed for gradient computation has been modified by
+ # an inplace operation",
+ # we separate the backwards for real and fake, and also detach the
+ # tensor for calculating mean.
+
+ # real
+ fake_d_pred = self.net_d(self.output).detach()
+ real_d_pred = self.net_d(self.gt)
+ l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach())
+ l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
diff --git a/r_basicsr/models/hifacegan_model.py b/r_basicsr/models/hifacegan_model.py new file mode 100644 index 0000000..fd67d11 --- /dev/null +++ b/r_basicsr/models/hifacegan_model.py @@ -0,0 +1,288 @@ +import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from r_basicsr.archs import build_network
+from r_basicsr.losses import build_loss
+from r_basicsr.metrics import calculate_metric
+from r_basicsr.utils import imwrite, tensor2img
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class HiFaceGANModel(SRModel):
+ """HiFaceGAN model for generic-purpose face restoration.
+ No prior modeling required, works for any degradations.
+ Currently doesn't support EMA for inference.
+ """
+
+ def init_training_settings(self):
+
+ train_opt = self.opt['train']
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ raise NotImplementedError('HiFaceGAN does not support EMA now.')
+
+ self.net_g.train()
+
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # define losses
+ # HiFaceGAN does not use pixel loss by default
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('feature_matching_opt'):
+ self.cri_feat = build_loss(train_opt['feature_matching_opt']).to(self.device)
+ else:
+ self.cri_feat = None
+
+ if self.cri_pix is None and self.cri_perceptual is None:
+ raise ValueError('Both pixel and perceptual losses are None.')
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ # optimizer g
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+ # optimizer d
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+ def discriminate(self, input_lq, output, ground_truth):
+ """
+ This is a conditional (on the input) discriminator.
+ With Batch Normalization, the fake and real images are
+ recommended to be in the same batch to avoid disparate
+ statistics in fake and real images.
+ So both fake and real images are fed to D all at once.
+ """
+ h, w = output.shape[-2:]
+ if output.shape[-2:] != input_lq.shape[-2:]:
+ lq = torch.nn.functional.interpolate(input_lq, (h, w))
+ real = torch.nn.functional.interpolate(ground_truth, (h, w))
+ fake_concat = torch.cat([lq, output], dim=1)
+ real_concat = torch.cat([lq, real], dim=1)
+ else:
+ fake_concat = torch.cat([input_lq, output], dim=1)
+ real_concat = torch.cat([input_lq, ground_truth], dim=1)
+
+ fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
+ discriminator_out = self.net_d(fake_and_real)
+ pred_fake, pred_real = self._divide_pred(discriminator_out)
+ return pred_fake, pred_real
+
+ @staticmethod
+ def _divide_pred(pred):
+ """
+ Take the prediction of fake and real images from the combined batch.
+ The prediction contains the intermediate outputs of multiscale GAN,
+ so it's usually a list
+ """
+ if isinstance(pred, list):
+ fake = []
+ real = []
+ for p in pred:
+ fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
+ real.append([tensor[tensor.size(0) // 2:] for tensor in p])
+ else:
+ fake = pred[:pred.size(0) // 2]
+ real = pred[pred.size(0) // 2:]
+
+ return fake, real
+
+ def optimize_parameters(self, current_iter):
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+
+ # Requires real prediction for feature matching loss
+ pred_fake, pred_real = self.discriminate(self.lq, self.output, self.gt)
+ l_g_gan = self.cri_gan(pred_fake, True, is_disc=False)
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ # feature matching loss
+ if self.cri_feat:
+ l_g_feat = self.cri_feat(pred_fake, pred_real)
+ l_g_total += l_g_feat
+ loss_dict['l_g_feat'] = l_g_feat
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # TODO: Benchmark test between HiFaceGAN and SRGAN implementation:
+ # SRGAN use the same fake output for discriminator update
+ # while HiFaceGAN regenerate a new output using updated net_g
+ # This should not make too much difference though. Stick to SRGAN now.
+ # -------------------------------------------------------------------
+ # ---------- Below are original HiFaceGAN code snippet --------------
+ # -------------------------------------------------------------------
+ # with torch.no_grad():
+ # fake_image = self.net_g(self.lq)
+ # fake_image = fake_image.detach()
+ # fake_image.requires_grad_()
+ # pred_fake, pred_real = self.discriminate(self.lq, fake_image, self.gt)
+
+ # real
+ pred_fake, pred_real = self.discriminate(self.lq, self.output.detach(), self.gt)
+ l_d_real = self.cri_gan(pred_real, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ # fake
+ l_d_fake = self.cri_gan(pred_fake, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+
+ l_d_total = (l_d_real + l_d_fake) / 2
+ l_d_total.backward()
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ print('HiFaceGAN does not support EMA now. pass')
+
+ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
+ """
+ Warning: HiFaceGAN requires train() mode even for validation
+ For more info, see https://github.com/Lotayou/Face-Renovation/issues/31
+
+ Args:
+ dataloader (torch.utils.data.DataLoader): Validation dataloader.
+ current_iter (int): Current iteration.
+ tb_logger (tensorboard logger): Tensorboard logger.
+ save_img (bool): Whether to save images. Default: False.
+ """
+
+ if self.opt['network_g']['type'] in ('HiFaceGAN', 'SPADEGenerator'):
+ self.net_g.train()
+
+ if self.opt['dist']:
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+ else:
+ print('In HiFaceGANModel: The new metrics package is under development. ' +
+ 'Using the super method for now (only PSNR & SSIM are supported).')
+ super().nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ """
+ TODO: Validation using updated metric system
+ The metrics are now evaluated after all images have been tested.
+ This allows batch processing, and also allows evaluation of
+ distributional metrics, such as:
+
+ @ Frechet Inception Distance: FID
+ @ Maximum Mean Discrepancy: MMD
+
+ Warning:
+ Need careful batch management for different inference settings.
+
+ """
+ dataset_name = dataloader.dataset.opt['name']
+ with_metrics = self.opt['val'].get('metrics') is not None
+ if with_metrics:
+ self.metric_results = dict() # {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ sr_tensors = []
+ gt_tensors = []
+
+ pbar = tqdm(total=len(dataloader), unit='image')
+ for val_data in dataloader:
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+ self.feed_data(val_data)
+ self.test()
+
+ visuals = self.get_current_visuals() # detached cpu tensor, non-squeeze
+ sr_tensors.append(visuals['result'])
+ if 'gt' in visuals:
+ gt_tensors.append(visuals['gt'])
+ del self.gt
+
+ # tentative for out of GPU memory
+ del self.lq
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+ f'{img_name}_{current_iter}.png')
+ else:
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["name"]}.png')
+
+ imwrite(tensor2img(visuals['result']), save_img_path)
+
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ pbar.close()
+
+ if with_metrics:
+ sr_pack = torch.cat(sr_tensors, dim=0)
+ gt_pack = torch.cat(gt_tensors, dim=0)
+ # calculate metrics
+ for name, opt_ in self.opt['val']['metrics'].items():
+ # The new metric caller automatically returns mean value
+ # FIXME: ERROR: calculate_metric only supports two arguments. Now the codes cannot be successfully run
+ self.metric_results[name] = calculate_metric(dict(sr_pack=sr_pack, gt_pack=gt_pack), opt_)
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+ def save(self, epoch, current_iter):
+ if hasattr(self, 'net_g_ema'):
+ print('HiFaceGAN does not support EMA now. Fallback to normal mode.')
+
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
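Because `discriminate` runs the fake and real halves through `net_d` as one batch, `_divide_pred` only has to split every tensor along dim 0. A toy check for the list-of-lists case the docstring mentions (one scale, one intermediate output, batch of 4):

```python
import torch

pred = [[torch.arange(8.).reshape(4, 2)]]  # first 2 rows fake, last 2 rows real
fake = [[t[:t.size(0) // 2] for t in p] for p in pred]
real = [[t[t.size(0) // 2:] for t in p] for p in pred]
print(fake[0][0].tolist())  # [[0.0, 1.0], [2.0, 3.0]]
print(real[0][0].tolist())  # [[4.0, 5.0], [6.0, 7.0]]
```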
diff --git a/r_basicsr/models/lr_scheduler.py b/r_basicsr/models/lr_scheduler.py
new file mode 100644
index 0000000..084122d
--- /dev/null
+++ b/r_basicsr/models/lr_scheduler.py
@@ -0,0 +1,96 @@
+import math
+from collections import Counter
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class MultiStepRestartLR(_LRScheduler):
+ """ MultiStep with restarts learning rate scheme.
+
+ Args:
+ optimizer (torch.nn.optimizer): Torch optimizer.
+ milestones (list): Iterations that will decrease learning rate.
+ gamma (float): Decrease ratio. Default: 0.1.
+ restarts (list): Restart iterations. Default: [0].
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+
+ def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
+ self.milestones = Counter(milestones)
+ self.gamma = gamma
+ self.restarts = restarts
+ self.restart_weights = restart_weights
+ assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
+ super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch in self.restarts:
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+ return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+ if self.last_epoch not in self.milestones:
+ return [group['lr'] for group in self.optimizer.param_groups]
+ return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
+
+
+def get_position_from_periods(iteration, cumulative_period):
+ """Get the position from a period list.
+
+ It will return the index of the right-closest number in the period list.
+ For example, the cumulative_period = [100, 200, 300, 400],
+ if iteration == 50, return 0;
+ if iteration == 210, return 2;
+ if iteration == 300, return 2.
+
+ Args:
+ iteration (int): Current iteration.
+ cumulative_period (list[int]): Cumulative period list.
+
+ Returns:
+ int: The position of the right-closest number in the period list.
+ """
+ for i, period in enumerate(cumulative_period):
+ if iteration <= period:
+ return i
+
+
+class CosineAnnealingRestartLR(_LRScheduler):
+ """ Cosine annealing with restarts learning rate scheme.
+
+ An example of config:
+ periods = [10, 10, 10, 10]
+ restart_weights = [1, 0.5, 0.5, 0.5]
+ eta_min=1e-7
+
+ It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
+ scheduler will restart with the weights in restart_weights.
+
+ Args:
+ optimizer (torch.nn.optimizer): Torch optimizer.
+ periods (list): Period for each cosine annealing cycle.
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ eta_min (float): The minimum lr. Default: 0.
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+
+ def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
+ self.periods = periods
+ self.restart_weights = restart_weights
+ self.eta_min = eta_min
+ assert (len(self.periods) == len(
+ self.restart_weights)), 'periods and restart_weights should have the same length.'
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
+ super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
+ current_weight = self.restart_weights[idx]
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
+ current_period = self.periods[idx]
+
+ return [
+ self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
+ (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
+ for base_lr in self.base_lrs
+ ]
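The `get_lr` expression above evaluates, within each cycle, a cosine that starts at `current_weight * base_lr` and decays to `eta_min`. A quick trace of the docstring's example schedule, assuming a base learning rate of 1e-4 (the helper mirrors `get_position_from_periods`):

```python
import math

periods = [10, 10, 10, 10]
restart_weights = [1, 0.5, 0.5, 0.5]
eta_min, base_lr = 1e-7, 1e-4
cumulative = [sum(periods[:i + 1]) for i in range(len(periods))]  # [10, 20, 30, 40]

def lr_at(last_epoch):
    idx = next(i for i, p in enumerate(cumulative) if last_epoch <= p)
    nearest_restart = 0 if idx == 0 else cumulative[idx - 1]
    w = restart_weights[idx]
    return eta_min + w * 0.5 * (base_lr - eta_min) * (
        1 + math.cos(math.pi * (last_epoch - nearest_restart) / periods[idx]))

print(lr_at(0))   # ~1.0e-4: start of cycle 1, full weight
print(lr_at(10))  # ~1.0e-7: end of cycle 1 bottoms out at eta_min
print(lr_at(15))  # ~2.5e-5: midpoint of cycle 2, restart weight 0.5
```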
diff --git a/r_basicsr/models/realesrgan_model.py b/r_basicsr/models/realesrgan_model.py
new file mode 100644
index 0000000..b05dafd
--- /dev/null
+++ b/r_basicsr/models/realesrgan_model.py
@@ -0,0 +1,267 @@
+import numpy as np
+import random
+import torch
+from collections import OrderedDict
+from torch.nn import functional as F
+
+from r_basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
+from r_basicsr.data.transforms import paired_random_crop
+from r_basicsr.losses.loss_util import get_refined_artifact_map
+from r_basicsr.models.srgan_model import SRGANModel
+from r_basicsr.utils import DiffJPEG, USMSharp
+from r_basicsr.utils.img_process_util import filter2D
+from r_basicsr.utils.registry import MODEL_REGISTRY
+
+
+@MODEL_REGISTRY.register(suffix='basicsr')
+class RealESRGANModel(SRGANModel):
+ """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ It mainly performs:
+ 1. randomly synthesizes LQ images in GPU tensors
+ 2. optimizes the networks with GAN training.
+ """
+
+ def __init__(self, opt):
+ super(RealESRGANModel, self).__init__(opt)
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
+ self.queue_size = opt.get('queue_size', 180)
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self):
+ """It is the training pair pool for increasing the diversity in a batch.
+
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+ batch cannot have different resize scaling factors. Therefore, we employ this training pair pool
+ to increase the degradation diversity in a batch.
+ """
+ # initialize
+ b, c, h, w = self.lq.size()
+ if not hasattr(self, 'queue_lr'):
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
+ _, c, h, w = self.gt.size()
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
+ self.queue_ptr = 0
+ if self.queue_ptr == self.queue_size: # the pool is full
+ # do dequeue and enqueue
+ # shuffle
+ idx = torch.randperm(self.queue_size)
+ self.queue_lr = self.queue_lr[idx]
+ self.queue_gt = self.queue_gt[idx]
+ # get first b samples
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
+ # update the queue
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
+
+ self.lq = lq_dequeue
+ self.gt = gt_dequeue
+ else:
+ # only do enqueue
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+ self.queue_ptr = self.queue_ptr + b
+
+ @torch.no_grad()
+ def feed_data(self, data):
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+ """
+ if self.is_train and self.opt.get('high_order_degradation', True):
+ # training data synthesis
+ self.gt = data['gt'].to(self.device)
+ self.gt_usm = self.usm_sharpener(self.gt)
+
+ self.kernel1 = data['kernel1'].to(self.device)
+ self.kernel2 = data['kernel2'].to(self.device)
+ self.sinc_kernel = data['sinc_kernel'].to(self.device)
+
+ ori_h, ori_w = self.gt.size()[2:4]
+
+ # ----------------------- The first degradation process ----------------------- #
+ # blur
+ out = filter2D(self.gt_usm, self.kernel1)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob']
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+ out = self.jpeger(out, quality=jpeg_p)
+
+ # ----------------------- The second degradation process ----------------------- #
+ # blur
+ if np.random.uniform() < self.opt['second_blur_prob']:
+ out = filter2D(out, self.kernel2)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob2']
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range2'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+
+ # JPEG compression + the final sinc filter
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+ # as one operation.
+ # We consider two orders:
+ # 1. [resize back + sinc filter] + JPEG compression
+ # 2. JPEG compression + [resize back + sinc filter]
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
+ if np.random.uniform() < 0.5:
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ else:
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+
+ # clamp and round
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+
+ # random crop
+ gt_size = self.opt['gt_size']
+ (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
+ self.opt['scale'])
+
+ # training pair pool
+ self._dequeue_and_enqueue()
+ # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
+ self.gt_usm = self.usm_sharpener(self.gt)
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
+ else:
+ # for paired training or validation
+ self.lq = data['lq'].to(self.device)
+ if 'gt' in data:
+ self.gt = data['gt'].to(self.device)
+ self.gt_usm = self.usm_sharpener(self.gt)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ # do not use the synthetic process during validation
+ self.is_train = False
+ super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
+ self.is_train = True
+
+ def optimize_parameters(self, current_iter):
+ # usm sharpening
+ l1_gt = self.gt_usm
+ percep_gt = self.gt_usm
+ gan_gt = self.gt_usm
+ if self.opt['l1_gt_usm'] is False:
+ l1_gt = self.gt
+ if self.opt['percep_gt_usm'] is False:
+ percep_gt = self.gt
+ if self.opt['gan_gt_usm'] is False:
+ gan_gt = self.gt
+
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+ if self.cri_ldl:
+ self.output_ema = self.net_g_ema(self.lq)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, l1_gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ if self.cri_ldl:
+ pixel_weight = get_refined_artifact_map(self.gt, self.output, self.output_ema, 7)
+ l_g_ldl = self.cri_ldl(torch.mul(pixel_weight, self.output), torch.mul(pixel_weight, self.gt))
+ l_g_total += l_g_ldl
+ loss_dict['l_g_ldl'] = l_g_ldl
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss
+ fake_g_pred = self.net_d(self.output)
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ real_d_pred = self.net_d(gan_gt)
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
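The `_dequeue_and_enqueue` pool above is a fixed-size buffer: batches are enqueued until it fills, after which every new batch is swapped against a random slice of the shuffled pool. A scaled-down sketch of the same logic with `queue_size=4` and batch size 2; the real code keeps `lq` and `gt` queues permuted with the same indices so pairs stay aligned, which this single-queue toy omits:

```python
import torch

class PairPool:
    """Scaled-down sketch of _dequeue_and_enqueue (queue_size=4, 1-d samples)."""

    def __init__(self, queue_size=4):
        self.queue = torch.zeros(queue_size, 1)
        self.ptr = 0

    def __call__(self, batch):
        b = batch.size(0)
        if self.ptr == self.queue.size(0):         # pool full: shuffle and swap
            idx = torch.randperm(self.queue.size(0))
            self.queue = self.queue[idx]
            out = self.queue[:b].clone()
            self.queue[:b] = batch
            return out
        self.queue[self.ptr:self.ptr + b] = batch  # still filling: enqueue only
        self.ptr += b
        return batch

pool = PairPool()
for step in range(4):
    batch = torch.full((2, 1), float(step))
    print(step, pool(batch).flatten().tolist())
# steps 0-1 pass through unchanged; from step 2 on, shuffled earlier samples return
```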
diff --git a/r_basicsr/models/realesrnet_model.py b/r_basicsr/models/realesrnet_model.py
new file mode 100644
index 0000000..2e8dc65
--- /dev/null
+++ b/r_basicsr/models/realesrnet_model.py
@@ -0,0 +1,189 @@
+import numpy as np
+import random
+import torch
+from torch.nn import functional as F
+
+from r_basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
+from r_basicsr.data.transforms import paired_random_crop
+from r_basicsr.models.sr_model import SRModel
+from r_basicsr.utils import DiffJPEG, USMSharp
+from r_basicsr.utils.img_process_util import filter2D
+from r_basicsr.utils.registry import MODEL_REGISTRY
+
+
+@MODEL_REGISTRY.register(suffix='basicsr')
+class RealESRNetModel(SRModel):
+ """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ It is trained without GAN losses.
+ It mainly performs:
+ 1. randomly synthesize LQ images in GPU tensors
+ 2. optimize the networks with GAN training.
+ """
+
+ def __init__(self, opt):
+ super(RealESRNetModel, self).__init__(opt)
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
+ self.queue_size = opt.get('queue_size', 180)
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self):
+ """It is the training pair pool for increasing the diversity in a batch.
+
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+ batch cannot have different resize scaling factors. Therefore, we employ this training pair pool
+ to increase the degradation diversity in a batch.
+ """
+ # initialize
+ b, c, h, w = self.lq.size()
+ if not hasattr(self, 'queue_lr'):
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
+ _, c, h, w = self.gt.size()
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
+ self.queue_ptr = 0
+ if self.queue_ptr == self.queue_size: # the pool is full
+ # do dequeue and enqueue
+ # shuffle
+ idx = torch.randperm(self.queue_size)
+ self.queue_lr = self.queue_lr[idx]
+ self.queue_gt = self.queue_gt[idx]
+ # get first b samples
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
+ # update the queue
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
+
+ self.lq = lq_dequeue
+ self.gt = gt_dequeue
+ else:
+ # only do enqueue
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+ self.queue_ptr = self.queue_ptr + b
+
+ @torch.no_grad()
+ def feed_data(self, data):
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+ """
+ if self.is_train and self.opt.get('high_order_degradation', True):
+ # training data synthesis
+ self.gt = data['gt'].to(self.device)
+ # USM sharpen the GT images
+ if self.opt['gt_usm'] is True:
+ self.gt = self.usm_sharpener(self.gt)
+
+ self.kernel1 = data['kernel1'].to(self.device)
+ self.kernel2 = data['kernel2'].to(self.device)
+ self.sinc_kernel = data['sinc_kernel'].to(self.device)
+
+ ori_h, ori_w = self.gt.size()[2:4]
+
+ # ----------------------- The first degradation process ----------------------- #
+ # blur
+ out = filter2D(self.gt, self.kernel1)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob']
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+ out = self.jpeger(out, quality=jpeg_p)
+
+ # ----------------------- The second degradation process ----------------------- #
+ # blur
+ if np.random.uniform() < self.opt['second_blur_prob']:
+ out = filter2D(out, self.kernel2)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob2']
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range2'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+
+ # JPEG compression + the final sinc filter
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+ # as one operation.
+ # We consider two orders:
+ # 1. [resize back + sinc filter] + JPEG compression
+ # 2. JPEG compression + [resize back + sinc filter]
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
+ if np.random.uniform() < 0.5:
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ else:
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+
+ # clamp and round
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+
+ # random crop
+ gt_size = self.opt['gt_size']
+ self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
+
+ # training pair pool
+ self._dequeue_and_enqueue()
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
+ else:
+ # for paired training or validation
+ self.lq = data['lq'].to(self.device)
+ if 'gt' in data:
+ self.gt = data['gt'].to(self.device)
+ self.gt_usm = self.usm_sharpener(self.gt)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ # do not use the synthetic process during validation
+ self.is_train = False
+ super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
+ self.is_train = True
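One detail of the synthesis pipeline worth making explicit: the intermediate random resizes may land on odd sizes, but the final [resize back + sinc] stage always pins the LQ to exactly 1/scale of the GT, which is what `paired_random_crop` relies on. A sketch with illustrative numbers:

```python
# Illustrative size bookkeeping for scale = 4 and a 256x256 GT.
ori_h, ori_w, scale = 256, 256, 4
s = 0.8  # a random intermediate scale drawn from resize_range2

intermediate = (int(ori_h / scale * s), int(ori_w / scale * s))  # (51, 51), transient
final = (ori_h // scale, ori_w // scale)                         # (64, 64), the LQ size

# The clamp-and-round step then snaps values to the 8-bit grid, mimicking an
# image that was actually saved and reloaded:
#     lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
```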
diff --git a/r_basicsr/models/sr_model.py b/r_basicsr/models/sr_model.py
new file mode 100644
index 0000000..f6e37e9
--- /dev/null
+++ b/r_basicsr/models/sr_model.py
@@ -0,0 +1,231 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from r_basicsr.archs import build_network
+from r_basicsr.losses import build_loss
+from r_basicsr.metrics import calculate_metric
+from r_basicsr.utils import get_root_logger, imwrite, tensor2img
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .base_model import BaseModel
+
+
+@MODEL_REGISTRY.register()
+class SRModel(BaseModel):
+ """Base SR model for single image super-resolution."""
+
+ def __init__(self, opt):
+ super(SRModel, self).__init__(opt)
+
+ # define network
+ self.net_g = build_network(opt['network_g'])
+ self.net_g = self.model_to_device(self.net_g)
+ self.print_network(self.net_g)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_g', 'params')
+ self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
+
+ if self.is_train:
+ self.init_training_settings()
+
+ def init_training_settings(self):
+ self.net_g.train()
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger = get_root_logger()
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if self.cri_pix is None and self.cri_perceptual is None:
+ raise ValueError('Both pixel and perceptual losses are None.')
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ optim_params = []
+ for k, v in self.net_g.named_parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ else:
+ logger = get_root_logger()
+ logger.warning(f'Params {k} will not be optimized.')
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ def feed_data(self, data):
+ self.lq = data['lq'].to(self.device)
+ if 'gt' in data:
+ self.gt = data['gt'].to(self.device)
+
+ def optimize_parameters(self, current_iter):
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_total = 0
+ loss_dict = OrderedDict()
+ # pixel loss
+ if self.cri_pix:
+ l_pix = self.cri_pix(self.output, self.gt)
+ l_total += l_pix
+ loss_dict['l_pix'] = l_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_percep, l_style = self.cri_perceptual(self.output, self.gt)
+ if l_percep is not None:
+ l_total += l_percep
+ loss_dict['l_percep'] = l_percep
+ if l_style is not None:
+ l_total += l_style
+ loss_dict['l_style'] = l_style
+
+ l_total.backward()
+ self.optimizer_g.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ def test(self):
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ with torch.no_grad():
+ self.output = self.net_g_ema(self.lq)
+ else:
+ self.net_g.eval()
+ with torch.no_grad():
+ self.output = self.net_g(self.lq)
+ self.net_g.train()
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ if self.opt['rank'] == 0:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset_name = dataloader.dataset.opt['name']
+ with_metrics = self.opt['val'].get('metrics') is not None
+ use_pbar = self.opt['val'].get('pbar', False)
+
+ if with_metrics:
+ if not hasattr(self, 'metric_results'): # only execute in the first run
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
+ self._initialize_best_metric_results(dataset_name)
+ # zero self.metric_results
+ if with_metrics:
+ self.metric_results = {metric: 0 for metric in self.metric_results}
+
+ metric_data = dict()
+ if use_pbar:
+ pbar = tqdm(total=len(dataloader), unit='image')
+
+ for idx, val_data in enumerate(dataloader):
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+ self.feed_data(val_data)
+ self.test()
+
+ visuals = self.get_current_visuals()
+ sr_img = tensor2img([visuals['result']])
+ metric_data['img'] = sr_img
+ if 'gt' in visuals:
+ gt_img = tensor2img([visuals['gt']])
+ metric_data['img2'] = gt_img
+ del self.gt
+
+ # tentative fix for running out of GPU memory
+ del self.lq
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+ f'{img_name}_{current_iter}.png')
+ else:
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["name"]}.png')
+ imwrite(sr_img, save_img_path)
+
+ if with_metrics:
+ # calculate metrics
+ for name, opt_ in self.opt['val']['metrics'].items():
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
+ if use_pbar:
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ if use_pbar:
+ pbar.close()
+
+ if with_metrics:
+ for metric in self.metric_results.keys():
+ self.metric_results[metric] /= (idx + 1)
+ # update the best metric result
+ self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
+
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ log_str = f'Validation {dataset_name}\n'
+ for metric, value in self.metric_results.items():
+ log_str += f'\t # {metric}: {value:.4f}'
+ if hasattr(self, 'best_metric_results'):
+ log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+ f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
+ log_str += '\n'
+
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric, value in self.metric_results.items():
+ tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
+
+ def get_current_visuals(self):
+ out_dict = OrderedDict()
+ out_dict['lq'] = self.lq.detach().cpu()
+ out_dict['result'] = self.output.detach().cpu()
+ if hasattr(self, 'gt'):
+ out_dict['gt'] = self.gt.detach().cpu()
+ return out_dict
+
+ def save(self, epoch, current_iter):
+ if hasattr(self, 'net_g_ema'):
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_training_state(epoch, current_iter)
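The `model_ema` calls above (for example `self.model_ema(0)` to copy weights) come from the shared base model rather than this file. A minimal sketch of that style of EMA update, assuming `net_g` and `net_g_ema` expose the same parameter names; this is an illustration, not the verbatim BaseModel code:

import torch

def model_ema_sketch(net_g, net_g_ema, decay=0.999):
    """ema_param = decay * ema_param + (1 - decay) * current_param."""
    net_g_params = dict(net_g.named_parameters())
    ema_params = dict(net_g_ema.named_parameters())
    with torch.no_grad():
        for k, ema_p in ema_params.items():
            # decay=0 copies net_g's weights outright, matching model_ema(0)
            ema_p.mul_(decay).add_(net_g_params[k], alpha=1 - decay)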
diff --git a/r_basicsr/models/srgan_model.py b/r_basicsr/models/srgan_model.py new file mode 100644 index 0000000..a562a7d --- /dev/null +++ b/r_basicsr/models/srgan_model.py @@ -0,0 +1,149 @@ +import torch
+from collections import OrderedDict
+
+from r_basicsr.archs import build_network
+from r_basicsr.losses import build_loss
+from r_basicsr.utils import get_root_logger
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class SRGANModel(SRModel):
+ """SRGAN model for single image super-resolution."""
+
+ def init_training_settings(self):
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger = get_root_logger()
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define network net_d
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_d', 'params')
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
+
+ self.net_g.train()
+ self.net_d.train()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('ldl_opt'):
+ self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device)
+ else:
+ self.cri_ldl = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ # optimizer g
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+ # optimizer d
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+ def optimize_parameters(self, current_iter):
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss
+ fake_g_pred = self.net_d(self.output)
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ real_d_pred = self.net_d(self.gt)
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach())
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ def save(self, epoch, current_iter):
+ if hasattr(self, 'net_g_ema'):
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
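The two-phase update in `optimize_parameters` above is the standard GAN alternation: freeze `net_d` while stepping the generator, then detach the generator output while stepping the discriminator so gradients never flow back into `net_g`. A condensed, self-contained sketch; BCE-with-logits stands in here for whatever loss `gan_opt` actually builds, and all names are illustrative:

import torch
import torch.nn.functional as F

def gan_step(net_g, net_d, opt_g, opt_d, lq, gt):
    # generator step: D frozen, so only G accumulates gradients
    for p in net_d.parameters():
        p.requires_grad = False
    opt_g.zero_grad()
    output = net_g(lq)
    fake_pred = net_d(output)
    l_g = F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))
    l_g.backward()
    opt_g.step()

    # discriminator step: detach() cuts the graph back to G
    for p in net_d.parameters():
        p.requires_grad = True
    opt_d.zero_grad()
    real_pred = net_d(gt)
    fake_pred = net_d(output.detach())
    l_d = (F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred))
           + F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred)))
    l_d.backward()
    opt_d.step()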
diff --git a/r_basicsr/models/stylegan2_model.py b/r_basicsr/models/stylegan2_model.py new file mode 100644 index 0000000..58a38ae --- /dev/null +++ b/r_basicsr/models/stylegan2_model.py @@ -0,0 +1,283 @@ +import cv2
+import math
+import numpy as np
+import random
+import torch
+from collections import OrderedDict
+from os import path as osp
+
+from r_basicsr.archs import build_network
+from r_basicsr.losses import build_loss
+from r_basicsr.losses.gan_loss import g_path_regularize, r1_penalty
+from r_basicsr.utils import imwrite, tensor2img
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .base_model import BaseModel
+
+
+@MODEL_REGISTRY.register()
+class StyleGAN2Model(BaseModel):
+ """StyleGAN2 model."""
+
+ def __init__(self, opt):
+ super(StyleGAN2Model, self).__init__(opt)
+
+ # define network net_g
+ self.net_g = build_network(opt['network_g'])
+ self.net_g = self.model_to_device(self.net_g)
+ self.print_network(self.net_g)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_g', 'params')
+ self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
+
+ # latent dimension: self.num_style_feat
+ self.num_style_feat = opt['network_g']['num_style_feat']
+ num_val_samples = self.opt['val'].get('num_val_samples', 16)
+ self.fixed_sample = torch.randn(num_val_samples, self.num_style_feat, device=self.device)
+
+ if self.is_train:
+ self.init_training_settings()
+
+ def init_training_settings(self):
+ train_opt = self.opt['train']
+
+ # define network net_d
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_d', 'params')
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
+
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema only used for testing on one GPU and saving, do not need to
+ # wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+
+ self.net_g.train()
+ self.net_d.train()
+ self.net_g_ema.eval()
+
+ # define losses
+ # gan loss (wgan)
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+ # regularization weights
+ self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator
+ self.path_reg_weight = train_opt['path_reg_weight'] # for generator
+
+ self.net_g_reg_every = train_opt['net_g_reg_every']
+ self.net_d_reg_every = train_opt['net_d_reg_every']
+ self.mixing_prob = train_opt['mixing_prob']
+
+ self.mean_path_length = 0
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ # optimizer g
+ net_g_reg_ratio = self.net_g_reg_every / (self.net_g_reg_every + 1)
+ if self.opt['network_g']['type'] == 'StyleGAN2GeneratorC':
+ normal_params = []
+ style_mlp_params = []
+ modulation_conv_params = []
+ for name, param in self.net_g.named_parameters():
+ if 'modulation' in name:
+ normal_params.append(param)
+ elif 'style_mlp' in name:
+ style_mlp_params.append(param)
+ elif 'modulated_conv' in name:
+ modulation_conv_params.append(param)
+ else:
+ normal_params.append(param)
+ optim_params_g = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ },
+ {
+ 'params': style_mlp_params,
+ 'lr': train_opt['optim_g']['lr'] * 0.01
+ },
+ {
+ 'params': modulation_conv_params,
+ 'lr': train_opt['optim_g']['lr'] / 3
+ }
+ ]
+ else:
+ normal_params = []
+ for name, param in self.net_g.named_parameters():
+ normal_params.append(param)
+ optim_params_g = [{ # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ }]
+
+ optim_type = train_opt['optim_g'].pop('type')
+ lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
+ betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
+ self.optimizers.append(self.optimizer_g)
+
+ # optimizer d
+ net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
+ if self.opt['network_d']['type'] == 'StyleGAN2DiscriminatorC':
+ normal_params = []
+ linear_params = []
+ for name, param in self.net_d.named_parameters():
+ if 'final_linear' in name:
+ linear_params.append(param)
+ else:
+ normal_params.append(param)
+ optim_params_d = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_d']['lr']
+ },
+ {
+ 'params': linear_params,
+ 'lr': train_opt['optim_d']['lr'] * (1 / math.sqrt(512))
+ }
+ ]
+ else:
+ normal_params = []
+ for name, param in self.net_d.named_parameters():
+ normal_params.append(param)
+ optim_params_d = [{ # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_d']['lr']
+ }]
+
+ optim_type = train_opt['optim_d'].pop('type')
+ lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
+ betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
+ self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
+ self.optimizers.append(self.optimizer_d)
+
+ def feed_data(self, data):
+ self.real_img = data['gt'].to(self.device)
+
+ def make_noise(self, batch, num_noise):
+ if num_noise == 1:
+ noises = torch.randn(batch, self.num_style_feat, device=self.device)
+ else:
+ noises = torch.randn(num_noise, batch, self.num_style_feat, device=self.device).unbind(0)
+ return noises
+
+ def mixing_noise(self, batch, prob):
+ if random.random() < prob:
+ return self.make_noise(batch, 2)
+ else:
+ return [self.make_noise(batch, 1)]
+
+ def optimize_parameters(self, current_iter):
+ loss_dict = OrderedDict()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+ self.optimizer_d.zero_grad()
+
+ batch = self.real_img.size(0)
+ noise = self.mixing_noise(batch, self.mixing_prob)
+ fake_img, _ = self.net_g(noise)
+ fake_pred = self.net_d(fake_img.detach())
+
+ real_pred = self.net_d(self.real_img)
+ # wgan loss with softplus (logistic loss) for discriminator
+ l_d = self.cri_gan(real_pred, True, is_disc=True) + self.cri_gan(fake_pred, False, is_disc=True)
+ loss_dict['l_d'] = l_d
+ # In wgan, real_score should be positive and fake_score should be
+ # negative
+ loss_dict['real_score'] = real_pred.detach().mean()
+ loss_dict['fake_score'] = fake_pred.detach().mean()
+ l_d.backward()
+
+ if current_iter % self.net_d_reg_every == 0:
+ self.real_img.requires_grad = True
+ real_pred = self.net_d(self.real_img)
+ l_d_r1 = r1_penalty(real_pred, self.real_img)
+ l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
+ # TODO: figure out why the 0 * real_pred term is needed. Without it,
+ # DistributedDataParallel raises: RuntimeError: Expected to have finished
+ # reduction in the prior iteration before starting a new one. This error
+ # indicates that your module has parameters that were not used in
+ # producing loss.
+ loss_dict['l_d_r1'] = l_d_r1.detach().mean()
+ l_d_r1.backward()
+
+ self.optimizer_d.step()
+
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+ self.optimizer_g.zero_grad()
+
+ noise = self.mixing_noise(batch, self.mixing_prob)
+ fake_img, _ = self.net_g(noise)
+ fake_pred = self.net_d(fake_img)
+
+ # wgan loss with softplus (non-saturating loss) for generator
+ l_g = self.cri_gan(fake_pred, True, is_disc=False)
+ loss_dict['l_g'] = l_g
+ l_g.backward()
+
+ if current_iter % self.net_g_reg_every == 0:
+ path_batch_size = max(1, batch // self.opt['train']['path_batch_shrink'])
+ noise = self.mixing_noise(path_batch_size, self.mixing_prob)
+ fake_img, latents = self.net_g(noise, return_latents=True)
+ l_g_path, path_lengths, self.mean_path_length = g_path_regularize(fake_img, latents, self.mean_path_length)
+
+ l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0])
+ # TODO: why the 0 * fake_img[0, 0, 0, 0] term is needed; see the analogous 0 * real_pred workaround above
+ l_g_path.backward()
+ loss_dict['l_g_path'] = l_g_path.detach().mean()
+ loss_dict['path_length'] = path_lengths
+
+ self.optimizer_g.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ # EMA
+ self.model_ema(decay=0.5**(32 / (10 * 1000)))
+
+ def test(self):
+ with torch.no_grad():
+ self.net_g_ema.eval()
+ self.output, _ = self.net_g_ema([self.fixed_sample])
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ if self.opt['rank'] == 0:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ assert dataloader is None, 'Validation dataloader should be None.'
+ self.test()
+ result = tensor2img(self.output, min_max=(-1, 1))
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], 'train', f'train_{current_iter}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png')
+ imwrite(result, save_img_path)
+ # add sample images to tb_logger
+ result = (result / 255.).astype(np.float32)
+ result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
+ if tb_logger is not None:
+ tb_logger.add_image('samples', result, global_step=current_iter, dataformats='HWC')
+
+ def save(self, epoch, current_iter):
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
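Two details in `setup_optimizers` and `optimize_parameters` above deserve a note. The R1 penalty is a gradient penalty on real images (the actual implementation lives in `r_basicsr.losses.gan_loss`; below is the general form as a sketch). And because the penalties only run every `net_g_reg_every`/`net_d_reg_every` iterations (lazy regularization), the learning rate and Adam betas are pre-scaled by `reg_every / (reg_every + 1)`:

import torch
from torch import autograd

def r1_penalty_sketch(real_pred, real_img):
    # mean squared gradient norm of D's output w.r.t. the real images;
    # real_img must have requires_grad=True, as in optimize_parameters above
    grad = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
    return grad.pow(2).reshape(grad.size(0), -1).sum(1).mean()

# lazy-regularization rescaling, e.g. with net_d_reg_every = 16:
ratio = 16 / (16 + 1)                # ~0.941
betas = (0 ** ratio, 0.99 ** ratio)  # ~(0.0, 0.991)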
diff --git a/r_basicsr/models/swinir_model.py b/r_basicsr/models/swinir_model.py new file mode 100644 index 0000000..18e5550 --- /dev/null +++ b/r_basicsr/models/swinir_model.py @@ -0,0 +1,33 @@ +import torch
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class SwinIRModel(SRModel):
+
+ def test(self):
+ # pad to a multiple of window_size
+ window_size = self.opt['network_g']['window_size']
+ scale = self.opt.get('scale', 1)
+ mod_pad_h, mod_pad_w = 0, 0
+ _, _, h, w = self.lq.size()
+ if h % window_size != 0:
+ mod_pad_h = window_size - h % window_size
+ if w % window_size != 0:
+ mod_pad_w = window_size - w % window_size
+ img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ with torch.no_grad():
+ self.output = self.net_g_ema(img)
+ else:
+ self.net_g.eval()
+ with torch.no_grad():
+ self.output = self.net_g(img)
+ self.net_g.train()
+
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
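A worked example of the padding arithmetic in `test()` above: the input is reflect-padded up to a multiple of `window_size`, and the upscaled output is cropped back to `scale` times the original size. `F.interpolate` stands in for the network here; the shapes are the point:

import torch
import torch.nn.functional as F

window_size, scale = 8, 4
lq = torch.rand(1, 3, 30, 45)
# equivalent to the if-blocks above: pad H and W up to multiples of 8
mod_pad_h = (window_size - lq.size(2) % window_size) % window_size  # 2 -> H 32
mod_pad_w = (window_size - lq.size(3) % window_size) % window_size  # 3 -> W 48
img = F.pad(lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
out = F.interpolate(img, scale_factor=scale)  # stand-in for net_g: (1, 3, 128, 192)
out = out[:, :, :out.size(2) - mod_pad_h * scale, :out.size(3) - mod_pad_w * scale]
print(out.shape)  # torch.Size([1, 3, 120, 180]) == scale * (30, 45)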
diff --git a/r_basicsr/models/video_base_model.py b/r_basicsr/models/video_base_model.py new file mode 100644 index 0000000..31ea37d --- /dev/null +++ b/r_basicsr/models/video_base_model.py @@ -0,0 +1,160 @@ +import torch
+from collections import Counter
+from os import path as osp
+from torch import distributed as dist
+from tqdm import tqdm
+
+from r_basicsr.metrics import calculate_metric
+from r_basicsr.utils import get_root_logger, imwrite, tensor2img
+from r_basicsr.utils.dist_util import get_dist_info
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class VideoBaseModel(SRModel):
+ """Base video SR model."""
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset = dataloader.dataset
+ dataset_name = dataset.opt['name']
+ with_metrics = self.opt['val']['metrics'] is not None
+ # initialize self.metric_results
+ # It is a dict: {
+ # 'folder1': tensor (num_frame x len(metrics)),
+ # 'folder2': tensor (num_frame x len(metrics))
+ # }
+ if with_metrics:
+ if not hasattr(self, 'metric_results'): # only execute in the first run
+ self.metric_results = {}
+ num_frame_each_folder = Counter(dataset.data_info['folder'])
+ for folder, num_frame in num_frame_each_folder.items():
+ self.metric_results[folder] = torch.zeros(
+ num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
+ # initialize the best metric results
+ self._initialize_best_metric_results(dataset_name)
+ # zero self.metric_results
+ rank, world_size = get_dist_info()
+ if with_metrics:
+ for _, tensor in self.metric_results.items():
+ tensor.zero_()
+
+ metric_data = dict()
+ # record all frames (border and center frames)
+ if rank == 0:
+ pbar = tqdm(total=len(dataset), unit='frame')
+ for idx in range(rank, len(dataset), world_size):
+ val_data = dataset[idx]
+ val_data['lq'].unsqueeze_(0)
+ val_data['gt'].unsqueeze_(0)
+ folder = val_data['folder']
+ frame_idx, max_idx = val_data['idx'].split('/')
+ lq_path = val_data['lq_path']
+
+ self.feed_data(val_data)
+ self.test()
+ visuals = self.get_current_visuals()
+ result_img = tensor2img([visuals['result']])
+ metric_data['img'] = result_img
+ if 'gt' in visuals:
+ gt_img = tensor2img([visuals['gt']])
+ metric_data['img2'] = gt_img
+ del self.gt
+
+ # tentative fix for running out of GPU memory
+ del self.lq
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ raise NotImplementedError('saving images is not supported during training.')
+ else:
+ if 'vimeo' in dataset_name.lower(): # vimeo90k dataset
+ split_result = lq_path.split('/')
+ img_name = f'{split_result[-3]}_{split_result[-2]}_{split_result[-1].split(".")[0]}'
+ else: # other datasets, e.g., REDS, Vid4
+ img_name = osp.splitext(osp.basename(lq_path))[0]
+
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+ f'{img_name}_{self.opt["name"]}.png')
+ imwrite(result_img, save_img_path)
+
+ if with_metrics:
+ # calculate metrics
+ for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
+ result = calculate_metric(metric_data, opt_)
+ self.metric_results[folder][int(frame_idx), metric_idx] += result
+
+ # progress bar
+ if rank == 0:
+ for _ in range(world_size):
+ pbar.update(1)
+ pbar.set_description(f'Test {folder}: {int(frame_idx) + world_size}/{max_idx}')
+ if rank == 0:
+ pbar.close()
+
+ if with_metrics:
+ if self.opt['dist']:
+ # collect data among GPUs
+ for _, tensor in self.metric_results.items():
+ dist.reduce(tensor, 0)
+ dist.barrier()
+ else:
+ pass # assume a single GPU in non-distributed testing
+
+ if rank == 0:
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ logger = get_root_logger()
+ logger.warning('nondist_validation is not implemented; running dist_validation instead.')
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ # ----------------- calculate the average values for each folder, and for each metric ----------------- #
+ # average all frames for each sub-folder
+ # metric_results_avg is a dict:{
+ # 'folder1': tensor (len(metrics)),
+ # 'folder2': tensor (len(metrics))
+ # }
+ metric_results_avg = {
+ folder: torch.mean(tensor, dim=0).cpu()
+ for (folder, tensor) in self.metric_results.items()
+ }
+ # total_avg_results is a dict: {
+ # 'metric1': float,
+ # 'metric2': float
+ # }
+ total_avg_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ for folder, tensor in metric_results_avg.items():
+ for idx, metric in enumerate(total_avg_results.keys()):
+ total_avg_results[metric] += metric_results_avg[folder][idx].item()
+ # average among folders
+ for metric in total_avg_results.keys():
+ total_avg_results[metric] /= len(metric_results_avg)
+ # update the best metric result
+ self._update_best_metric_result(dataset_name, metric, total_avg_results[metric], current_iter)
+
+ # ------------------------------------------ log the metric ------------------------------------------ #
+ log_str = f'Validation {dataset_name}\n'
+ for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
+ log_str += f'\t # {metric}: {value:.4f}'
+ for folder, tensor in metric_results_avg.items():
+ log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}'
+ if hasattr(self, 'best_metric_results'):
+ log_str += (f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+ f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
+ log_str += '\n'
+
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+ for folder, tensor in metric_results_avg.items():
+ tb_logger.add_scalar(f'metrics/{metric}/{folder}', tensor[metric_idx].item(), current_iter)
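The `VideoGANModel` defined in the next file has no body of its own: it relies on Python's method resolution order, which looks attributes up on `SRGANModel` before `VideoBaseModel`. A minimal stand-in with plain classes (all names here are illustrative):

class A:                 # plays the role of SRGANModel
    def save(self):
        return 'A.save'

class B:                 # plays the role of VideoBaseModel
    def save(self):
        return 'B.save'
    def test(self):
        return 'B.test'

class C(A, B):           # plays the role of VideoGANModel
    pass

print([cls.__name__ for cls in C.__mro__])  # ['C', 'A', 'B', 'object']
print(C().save())   # 'A.save'  -- SRGANModel wins for shared methods
print(C().test())   # 'B.test'  -- VideoBaseModel fills in the rest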
diff --git a/r_basicsr/models/video_gan_model.py b/r_basicsr/models/video_gan_model.py new file mode 100644 index 0000000..cc44476 --- /dev/null +++ b/r_basicsr/models/video_gan_model.py @@ -0,0 +1,17 @@ +from r_basicsr.utils.registry import MODEL_REGISTRY
+from .srgan_model import SRGANModel
+from .video_base_model import VideoBaseModel
+
+
+@MODEL_REGISTRY.register()
+class VideoGANModel(SRGANModel, VideoBaseModel):
+ """Video GAN model.
+
+ Use multiple inheritance.
+ It will first use the functions of SRGANModel:
+ init_training_settings
+ setup_optimizers
+ optimize_parameters
+ save
+ Then find functions in VideoBaseModel.
+ """
diff --git a/r_basicsr/models/video_recurrent_gan_model.py b/r_basicsr/models/video_recurrent_gan_model.py new file mode 100644 index 0000000..2800e27 --- /dev/null +++ b/r_basicsr/models/video_recurrent_gan_model.py @@ -0,0 +1,180 @@ +import torch
+from collections import OrderedDict
+
+from r_basicsr.archs import build_network
+from r_basicsr.losses import build_loss
+from r_basicsr.utils import get_root_logger
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .video_recurrent_model import VideoRecurrentModel
+
+
+@MODEL_REGISTRY.register()
+class VideoRecurrentGANModel(VideoRecurrentModel):
+
+ def init_training_settings(self):
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger = get_root_logger()
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # build network net_g with Exponential Moving Average (EMA)
+ # net_g_ema only used for testing on one GPU and saving.
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define network net_d
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_d', 'params')
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
+
+ self.net_g.train()
+ self.net_d.train()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ if train_opt['fix_flow']:
+ normal_params = []
+ flow_params = []
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name: # fix_flow currently only applies to the spynet module
+ flow_params.append(param)
+ else:
+ normal_params.append(param)
+
+ optim_params = [
+ { # add flow params first
+ 'params': flow_params,
+ 'lr': train_opt['lr_flow']
+ },
+ {
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ },
+ ]
+ else:
+ optim_params = self.net_g.parameters()
+
+ # optimizer g
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+ # optimizer d
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+ def optimize_parameters(self, current_iter):
+ logger = get_root_logger()
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ if self.fix_flow_iter:
+ if current_iter == 1:
+ logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name or 'edvr' in name:
+ param.requires_grad_(False)
+ elif current_iter == self.fix_flow_iter:
+ logger.warning('Now training all the parameters.')
+ self.net_g.requires_grad_(True)
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ _, _, c, h, w = self.output.size()
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), self.gt.view(-1, c, h, w))
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss
+ fake_g_pred = self.net_d(self.output.view(-1, c, h, w))
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ # reshape to (b*n, c, h, w)
+ real_d_pred = self.net_d(self.gt.view(-1, c, h, w))
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ # reshape to (b*n, c, h, w)
+ fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach())
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ def save(self, epoch, current_iter):
+ if self.ema_decay > 0:
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
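The repeated `.view(-1, c, h, w)` calls above exist because the recurrent generator emits 5-D clips of shape `(b, n, c, h, w)` while `net_d` is an image discriminator expecting 4-D batches, so frames are folded into the batch dimension. A quick shape check with made-up sizes:

import torch

b, n, c, h, w = 2, 5, 3, 64, 64
output = torch.rand(b, n, c, h, w)   # generator output: a batch of clips
frames = output.view(-1, c, h, w)    # (b * n, c, h, w)
print(frames.shape)                  # torch.Size([10, 3, 64, 64])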
diff --git a/r_basicsr/models/video_recurrent_model.py b/r_basicsr/models/video_recurrent_model.py new file mode 100644 index 0000000..ea3a4c5 --- /dev/null +++ b/r_basicsr/models/video_recurrent_model.py @@ -0,0 +1,197 @@ +import torch
+from collections import Counter
+from os import path as osp
+from torch import distributed as dist
+from tqdm import tqdm
+
+from r_basicsr.metrics import calculate_metric
+from r_basicsr.utils import get_root_logger, imwrite, tensor2img
+from r_basicsr.utils.dist_util import get_dist_info
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .video_base_model import VideoBaseModel
+
+
+@MODEL_REGISTRY.register()
+class VideoRecurrentModel(VideoBaseModel):
+
+ def __init__(self, opt):
+ super(VideoRecurrentModel, self).__init__(opt)
+ if self.is_train:
+ self.fix_flow_iter = opt['train'].get('fix_flow')
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ flow_lr_mul = train_opt.get('flow_lr_mul', 1)
+ logger = get_root_logger()
+ logger.info(f'Multiply the learning rate of the flow network by {flow_lr_mul}.')
+ if flow_lr_mul == 1:
+ optim_params = self.net_g.parameters()
+ else: # separate flow params and normal params for different lr
+ normal_params = []
+ flow_params = []
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name:
+ flow_params.append(param)
+ else:
+ normal_params.append(param)
+ optim_params = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ },
+ {
+ 'params': flow_params,
+ 'lr': train_opt['optim_g']['lr'] * flow_lr_mul
+ },
+ ]
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ def optimize_parameters(self, current_iter):
+ if self.fix_flow_iter:
+ logger = get_root_logger()
+ if current_iter == 1:
+ logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name or 'edvr' in name:
+ param.requires_grad_(False)
+ elif current_iter == self.fix_flow_iter:
+ logger.warning('Now training all the parameters.')
+ self.net_g.requires_grad_(True)
+
+ super(VideoRecurrentModel, self).optimize_parameters(current_iter)
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset = dataloader.dataset
+ dataset_name = dataset.opt['name']
+ with_metrics = self.opt['val']['metrics'] is not None
+ # initialize self.metric_results
+ # It is a dict: {
+ # 'folder1': tensor (num_frame x len(metrics)),
+ # 'folder2': tensor (num_frame x len(metrics))
+ # }
+ if with_metrics:
+ if not hasattr(self, 'metric_results'): # only execute in the first run
+ self.metric_results = {}
+ num_frame_each_folder = Counter(dataset.data_info['folder'])
+ for folder, num_frame in num_frame_each_folder.items():
+ self.metric_results[folder] = torch.zeros(
+ num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
+ # initialize the best metric results
+ self._initialize_best_metric_results(dataset_name)
+ # zero self.metric_results
+ rank, world_size = get_dist_info()
+ if with_metrics:
+ for _, tensor in self.metric_results.items():
+ tensor.zero_()
+
+ metric_data = dict()
+ num_folders = len(dataset)
+ num_pad = (world_size - (num_folders % world_size)) % world_size
+ if rank == 0:
+ pbar = tqdm(total=len(dataset), unit='folder')
+ # Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded.
+ # (padding avoids deadlock: every rank runs the same number of iterations)
+ for i in range(rank, num_folders + num_pad, world_size):
+ idx = min(i, num_folders - 1)
+ val_data = dataset[idx]
+ folder = val_data['folder']
+
+ # compute outputs
+ val_data['lq'].unsqueeze_(0)
+ val_data['gt'].unsqueeze_(0)
+ self.feed_data(val_data)
+ val_data['lq'].squeeze_(0)
+ val_data['gt'].squeeze_(0)
+
+ self.test()
+ visuals = self.get_current_visuals()
+
+ # tentative fix for running out of GPU memory
+ del self.lq
+ del self.output
+ if 'gt' in visuals:
+ del self.gt
+ torch.cuda.empty_cache()
+
+ if self.center_frame_only:
+ visuals['result'] = visuals['result'].unsqueeze(1)
+ if 'gt' in visuals:
+ visuals['gt'] = visuals['gt'].unsqueeze(1)
+
+ # evaluate
+ if i < num_folders:
+ for idx in range(visuals['result'].size(1)):
+ result = visuals['result'][0, idx, :, :, :]
+ result_img = tensor2img([result]) # uint8, bgr
+ metric_data['img'] = result_img
+ if 'gt' in visuals:
+ gt = visuals['gt'][0, idx, :, :, :]
+ gt_img = tensor2img([gt]) # uint8, bgr
+ metric_data['img2'] = gt_img
+
+ if save_img:
+ if self.opt['is_train']:
+ raise NotImplementedError('saving images is not supported during training.')
+ else:
+ if self.center_frame_only: # vimeo-90k
+ clip_ = val_data['lq_path'].split('/')[-3]
+ seq_ = val_data['lq_path'].split('/')[-2]
+ name_ = f'{clip_}_{seq_}'
+ img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+ f"{name_}_{self.opt['name']}.png")
+ else: # others
+ img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+ f"{idx:08d}_{self.opt['name']}.png")
+ # the frame-index naming above is used for datasets like REDS
+ imwrite(result_img, img_path)
+
+ # calculate metrics
+ if with_metrics:
+ for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
+ result = calculate_metric(metric_data, opt_)
+ self.metric_results[folder][idx, metric_idx] += result
+
+ # progress bar
+ if rank == 0:
+ for _ in range(world_size):
+ pbar.update(1)
+ pbar.set_description(f'Folder: {folder}')
+
+ if rank == 0:
+ pbar.close()
+
+ if with_metrics:
+ if self.opt['dist']:
+ # collect data among GPUs
+ for _, tensor in self.metric_results.items():
+ dist.reduce(tensor, 0)
+ dist.barrier()
+
+ if rank == 0:
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+ def test(self):
+ n = self.lq.size(1)
+ self.net_g.eval()
+
+ flip_seq = self.opt['val'].get('flip_seq', False)
+ self.center_frame_only = self.opt['val'].get('center_frame_only', False)
+
+ if flip_seq:
+ self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1)
+
+ with torch.no_grad():
+ self.output = self.net_g(self.lq)
+
+ if flip_seq:
+ output_1 = self.output[:, :n, :, :, :]
+ output_2 = self.output[:, n:, :, :, :].flip(1)
+ self.output = 0.5 * (output_1 + output_2)
+
+ if self.center_frame_only:
+ self.output = self.output[:, n // 2, :, :, :]
+
+ self.net_g.train()
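The `flip_seq` branch of `test()` above is a temporal self-ensemble: the clip is concatenated with its time-reversed copy, and the second half of the output is flipped back and averaged with the first. A sketch with an identity function standing in for `net_g`:

import torch

n = 4
lq = torch.rand(1, n, 3, 16, 16)
net_g = lambda x: x                       # identity stand-in for the network
inp = torch.cat([lq, lq.flip(1)], dim=1)  # (1, 2n, c, h, w)
out = net_g(inp)
out = 0.5 * (out[:, :n] + out[:, n:].flip(1))
assert torch.allclose(out, lq)            # identity net: the ensemble is a no-op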