From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:00:30 +0000 Subject: Add files via upload --- r_basicsr/losses/gan_loss.py | 208 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 208 insertions(+) create mode 100644 r_basicsr/losses/gan_loss.py (limited to 'r_basicsr/losses/gan_loss.py') diff --git a/r_basicsr/losses/gan_loss.py b/r_basicsr/losses/gan_loss.py new file mode 100644 index 0000000..6c2a199 --- /dev/null +++ b/r_basicsr/losses/gan_loss.py @@ -0,0 +1,208 @@ +import math +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import LOSS_REGISTRY + + +@LOSS_REGISTRY.register() +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + target_label = self.get_target_label(input, target_is_real) + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +@LOSS_REGISTRY.register() +class MultiScaleGANLoss(GANLoss): + """ + MultiScaleGANLoss accepts a list of predictions + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) + + def forward(self, input, target_is_real, is_disc=False): + """ + The input is a list of tensors, or a list of (a list of tensors) + """ + if isinstance(input, list): + loss = 0 + for pred_i in input: + if isinstance(pred_i, list): + # Only compute GAN loss for the last layer + # in case of multiscale feature matching + pred_i = pred_i[-1] + # Safe operation: 0-dim tensor calling self.mean() does nothing + loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() + loss += loss_tensor + return loss / len(input) + else: + return super().forward(input, target_is_real, is_disc) + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. + """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. + + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. + """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1. - alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty -- cgit v1.2.3