Diffstat (limited to 'r_basicsr/losses/gan_loss.py')
-rw-r--r-- | r_basicsr/losses/gan_loss.py | 208
1 file changed, 208 insertions, 0 deletions
diff --git a/r_basicsr/losses/gan_loss.py b/r_basicsr/losses/gan_loss.py
new file mode 100644
index 0000000..6c2a199
--- /dev/null
+++ b/r_basicsr/losses/gan_loss.py
@@ -0,0 +1,208 @@
+import math
+import torch
+from torch import autograd as autograd
+from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import LOSS_REGISTRY
+
+
+@LOSS_REGISTRY.register()
+class GANLoss(nn.Module):
+ """Define GAN loss.
+
+ Args:
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
+ real_label_val (float): The value for real label. Default: 1.0.
+ fake_label_val (float): The value for fake label. Default: 0.0.
+ loss_weight (float): Loss weight. Default: 1.0.
+ Note that loss_weight is only for generators; and it is always 1.0
+ for discriminators.
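+
+    Example (a minimal sketch; ``d_pred_real`` and ``d_pred_fake`` are
+    hypothetical discriminator outputs):
+        >>> criterion = GANLoss(gan_type='vanilla', loss_weight=0.1)
+        >>> # discriminator step: loss_weight is ignored when is_disc=True
+        >>> l_d = (criterion(d_pred_real, True, is_disc=True) +
+        ...        criterion(d_pred_fake, False, is_disc=True))
+        >>> # generator step: fake predictions should be classified as real
+        >>> l_g = criterion(d_pred_fake, True, is_disc=False)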
+ """
+
+    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+        super(GANLoss, self).__init__()
+        self.gan_type = gan_type
+        self.loss_weight = loss_weight
+        self.real_label_val = real_label_val
+        self.fake_label_val = fake_label_val
+
+        if self.gan_type == 'vanilla':
+            self.loss = nn.BCEWithLogitsLoss()
+        elif self.gan_type == 'lsgan':
+            self.loss = nn.MSELoss()
+        elif self.gan_type == 'wgan':
+            self.loss = self._wgan_loss
+        elif self.gan_type == 'wgan_softplus':
+            self.loss = self._wgan_softplus_loss
+        elif self.gan_type == 'hinge':
+            self.loss = nn.ReLU()
+        else:
+            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
+
+    def _wgan_loss(self, input, target):
+        """wgan loss.
+
+        Args:
+            input (Tensor): Input tensor.
+            target (bool): Target label.
+
+        Returns:
+            Tensor: wgan loss.
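+
+        Equivalently: ``-E[D(x)]`` when the target is real and
+        ``E[D(x)]`` when it is fake.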
+ """
+ return -input.mean() if target else input.mean()
+
+    def _wgan_softplus_loss(self, input, target):
+        """wgan loss with softplus. Softplus is a smooth approximation
+        to the ReLU function.
+
+        In StyleGAN2, it is called:
+            Logistic loss for discriminator;
+            Non-saturating loss for generator.
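+
+        ``softplus(x) = log(1 + exp(x))``, so this is
+        ``E[softplus(-D(x))]`` for real targets and
+        ``E[softplus(D(x))]`` for fake targets.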
+
+        Args:
+            input (Tensor): Input tensor.
+            target (bool): Target label.
+
+        Returns:
+            Tensor: wgan loss.
+        """
+        return F.softplus(-input).mean() if target else F.softplus(input).mean()
+
+    def get_target_label(self, input, target_is_real):
+        """Get target label.
+
+        Args:
+            input (Tensor): Input tensor.
+            target_is_real (bool): Whether the target is real or fake.
+
+        Returns:
+            (bool | Tensor): Target label. Returns bool for wgan,
+            otherwise returns a Tensor.
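+
+        Example:
+            >>> loss = GANLoss('lsgan')
+            >>> loss.get_target_label(torch.zeros(2, 1), True)
+            tensor([[1.],
+                    [1.]])
+            >>> GANLoss('wgan').get_target_label(torch.zeros(2, 1), True)
+            True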
+ """
+
+ if self.gan_type in ['wgan', 'wgan_softplus']:
+ return target_is_real
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
+ return input.new_ones(input.size()) * target_val
+
+    def forward(self, input, target_is_real, is_disc=False):
+        """
+        Args:
+            input (Tensor): The input for the loss module, i.e., the network
+                prediction.
+            target_is_real (bool): Whether the target is real or fake.
+            is_disc (bool): Whether the loss is computed for discriminators.
+                Default: False.
+
+        Returns:
+            Tensor: GAN loss value.
+        """
+        target_label = self.get_target_label(input, target_is_real)
+        if self.gan_type == 'hinge':
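+            # hinge loss: D minimizes E[relu(1 - D(real))] + E[relu(1 + D(fake))];
+            # G minimizes -E[D(fake)]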
+            if is_disc:  # for discriminators in hinge-gan
+                input = -input if target_is_real else input
+                loss = self.loss(1 + input).mean()
+            else:  # for generators in hinge-gan
+                loss = -input.mean()
+        else:  # other gan types
+            loss = self.loss(input, target_label)
+
+        # loss_weight is always 1.0 for discriminators
+        return loss if is_disc else loss * self.loss_weight
+
+
+@LOSS_REGISTRY.register()
+class MultiScaleGANLoss(GANLoss):
+ """
+ MultiScaleGANLoss accepts a list of predictions
+ """
+
+    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+        super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)
+
+    def forward(self, input, target_is_real, is_disc=False):
+        """
+        The input is a list of tensors, or a list of lists of tensors
+        (multiscale feature matching).
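+
+        Example (sketch): a two-scale discriminator might return
+        ``[[f0, f1, pred_a], [g0, g1, pred_b]]``; only ``pred_a`` and
+        ``pred_b`` enter the GAN loss, and the two losses are averaged.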
+ """
+ if isinstance(input, list):
+ loss = 0
+ for pred_i in input:
+ if isinstance(pred_i, list):
+ # Only compute GAN loss for the last layer
+ # in case of multiscale feature matching
+ pred_i = pred_i[-1]
+ # Safe operation: 0-dim tensor calling self.mean() does nothing
+ loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
+ loss += loss_tensor
+ return loss / len(input)
+ else:
+ return super().forward(input, target_is_real, is_disc)
+
+
+def r1_penalty(real_pred, real_img):
+ """R1 regularization for discriminator. The core idea is to
+ penalize the gradient on real data alone: when the
+ generator distribution produces the true data distribution
+ and the discriminator is equal to 0 on the data manifold, the
+ gradient penalty ensures that the discriminator cannot create
+ a non-zero gradient orthogonal to the data manifold without
+ suffering a loss in the GAN game.
+
+ Ref:
+ Eq. 9 in Which training methods for GANs do actually converge.
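+
+    In formula form: ``R1 = E_{x ~ p_data}[||grad D(x)||^2]`` (the
+    ``gamma / 2`` factor from the paper is left to the caller).
+
+    Note:
+        ``real_img`` must have ``requires_grad=True`` before the
+        discriminator forward pass, otherwise ``autograd.grad`` raises.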
+ """
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
+ return grad_penalty
+
+
+def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
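+    """Path length regularization for the generator (StyleGAN2,
+    Karras et al., 2020). A sketch of the idea: a fixed-size step in
+    latent space should produce a fixed-magnitude change in the image.
+
+    ``mean_path_length`` is an exponential moving average carried
+    across calls and updated with rate ``decay``.
+
+    Returns:
+        tuple: (path_penalty, detached mean of the current path
+        lengths, updated ``mean_path_length`` for the next call).
+    """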
+    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
+    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
+    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
+
+    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
+
+    path_penalty = (path_lengths - path_mean).pow(2).mean()
+
+    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
+
+
+def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
+ """Calculate gradient penalty for wgan-gp.
+
+ Args:
+ discriminator (nn.Module): Network for the discriminator.
+ real_data (Tensor): Real input data.
+ fake_data (Tensor): Fake input data.
+ weight (Tensor): Weight tensor. Default: None.
+
+ Returns:
+ Tensor: A tensor for gradient penalty.
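+
+    The penalty follows WGAN-GP (Gulrajani et al., 2017):
+    ``E[(||grad D(x_hat)||_2 - 1)^2]``, where ``x_hat`` is a per-sample
+    random interpolation between real and fake data. Note that this
+    implementation takes the norm over ``dim=1`` rather than over all
+    non-batch dimensions.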
+ """
+
+ batch_size = real_data.size(0)
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
+
+ # interpolate between real_data and fake_data
+ interpolates = alpha * real_data + (1. - alpha) * fake_data
+ interpolates = autograd.Variable(interpolates, requires_grad=True)
+
+ disc_interpolates = discriminator(interpolates)
+ gradients = autograd.grad(
+ outputs=disc_interpolates,
+ inputs=interpolates,
+ grad_outputs=torch.ones_like(disc_interpolates),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True)[0]
+
+ if weight is not None:
+ gradients = gradients * weight
+
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
+ if weight is not None:
+ gradients_penalty /= torch.mean(weight)
+
+ return gradients_penalty