summaryrefslogtreecommitdiffstats
path: root/r_basicsr/models/video_recurrent_gan_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'r_basicsr/models/video_recurrent_gan_model.py')
-rw-r--r--r_basicsr/models/video_recurrent_gan_model.py180
1 files changed, 180 insertions, 0 deletions
diff --git a/r_basicsr/models/video_recurrent_gan_model.py b/r_basicsr/models/video_recurrent_gan_model.py
new file mode 100644
index 0000000..2800e27
--- /dev/null
+++ b/r_basicsr/models/video_recurrent_gan_model.py
@@ -0,0 +1,180 @@
+import torch
+from collections import OrderedDict
+
+from r_basicsr.archs import build_network
+from r_basicsr.losses import build_loss
+from r_basicsr.utils import get_root_logger
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .video_recurrent_model import VideoRecurrentModel
+
+
+@MODEL_REGISTRY.register()
+class VideoRecurrentGANModel(VideoRecurrentModel):
+
+ def init_training_settings(self):
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger = get_root_logger()
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # build network net_g with Exponential Moving Average (EMA)
+ # net_g_ema only used for testing on one GPU and saving.
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define network net_d
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_d', 'params')
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
+
+ self.net_g.train()
+ self.net_d.train()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ if train_opt['fix_flow']:
+ normal_params = []
+ flow_params = []
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name: # The fix_flow now only works for spynet.
+ flow_params.append(param)
+ else:
+ normal_params.append(param)
+
+ optim_params = [
+ { # add flow params first
+ 'params': flow_params,
+ 'lr': train_opt['lr_flow']
+ },
+ {
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ },
+ ]
+ else:
+ optim_params = self.net_g.parameters()
+
+ # optimizer g
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+ # optimizer d
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+ def optimize_parameters(self, current_iter):
+ logger = get_root_logger()
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ if self.fix_flow_iter:
+ if current_iter == 1:
+ logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name or 'edvr' in name:
+ param.requires_grad_(False)
+ elif current_iter == self.fix_flow_iter:
+ logger.warning('Train all the parameters.')
+ self.net_g.requires_grad_(True)
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ _, _, c, h, w = self.output.size()
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), self.gt.view(-1, c, h, w))
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss
+ fake_g_pred = self.net_d(self.output.view(-1, c, h, w))
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ # reshape to (b*n, c, h, w)
+ real_d_pred = self.net_d(self.gt.view(-1, c, h, w))
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ # reshape to (b*n, c, h, w)
+ fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach())
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ def save(self, epoch, current_iter):
+ if self.ema_decay > 0:
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)