Diffstat (limited to 'r_basicsr/models/esrgan_model.py')
-rw-r--r--    r_basicsr/models/esrgan_model.py    83
1 file changed, 83 insertions(+), 0 deletions(-)
diff --git a/r_basicsr/models/esrgan_model.py b/r_basicsr/models/esrgan_model.py
new file mode 100644
index 0000000..8924920
--- /dev/null
+++ b/r_basicsr/models/esrgan_model.py
@@ -0,0 +1,83 @@
+import torch
+from collections import OrderedDict
+
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .srgan_model import SRGANModel
+
+
+@MODEL_REGISTRY.register()
+class ESRGANModel(SRGANModel):
+ """ESRGAN model for single image super-resolution."""
+
+ def optimize_parameters(self, current_iter):
+ # optimize net_g
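+        # freeze net_d so the generator step cannot update its weights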
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
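+        # update net_g only every net_d_iters iterations, and only after the
+        # first net_d_init_iters iterations reserved for discriminator warm-up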
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss (relativistic gan)
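+            # relativistic average GAN: each prediction is judged relative to
+            # the mean prediction of the opposite class, rather than against an
+            # absolute real/fake label; the two terms are averaged below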
+ real_d_pred = self.net_d(self.gt).detach()
+ fake_g_pred = self.net_d(self.output)
+ l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False)
+ l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False)
+ l_g_gan = (l_g_real + l_g_fake) / 2
+
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # gan loss (relativistic gan)
+
+ # In order to avoid the error in distributed training:
+ # "Error detected in CudnnBatchNormBackward: RuntimeError: one of
+ # the variables needed for gradient computation has been modified by
+ # an inplace operation",
+ # we separate the backwards for real and fake, and also detach the
+ # tensor for calculating mean.
+
+ # real
+ fake_d_pred = self.net_d(self.output).detach()
+ real_d_pred = self.net_d(self.gt)
+ l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5
+ l_d_real.backward()
+ # fake
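+        # detach self.output so no gradients flow back into net_g here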
+ fake_d_pred = self.net_d(self.output.detach())
+ l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
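+        # log the raw losses plus the mean discriminator outputs as diagnostics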
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
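+        # keep an exponential moving average of the generator weights for more
+        # stable evaluation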
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
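
For reference, the relativistic generator loss above reduces to the minimal sketch below, assuming cri_gan is the usual vanilla GAN loss (binary cross-entropy with logits); the tensors here are hypothetical stand-ins, not part of r_basicsr:

import torch
import torch.nn.functional as F

# hypothetical discriminator logits for a real and a generated batch;
# real_pred carries no gradient, mirroring the .detach() in the model
real_pred = torch.randn(4, 1)
fake_pred = torch.randn(4, 1, requires_grad=True)

# generator side: fake scores should exceed the average real score
# (label 1), real scores should fall below the average fake score (label 0)
l_g_real = F.binary_cross_entropy_with_logits(
    real_pred - fake_pred.mean(), torch.zeros_like(real_pred))
l_g_fake = F.binary_cross_entropy_with_logits(
    fake_pred - real_pred.mean(), torch.ones_like(fake_pred))
l_g_gan = (l_g_real + l_g_fake) / 2

The discriminator terms follow the same pattern with the labels swapped, each scaled by 0.5 and backpropagated separately, as in optimize_parameters above.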