Diffstat (limited to 'r_basicsr/losses/loss_util.py')
-rw-r--r--    r_basicsr/losses/loss_util.py    145
1 file changed, 145 insertions, 0 deletions
diff --git a/r_basicsr/losses/loss_util.py b/r_basicsr/losses/loss_util.py
new file mode 100644
index 0000000..bcb80f3
--- /dev/null
+++ b/r_basicsr/losses/loss_util.py
@@ -0,0 +1,145 @@
+import functools
+import torch
+from torch.nn import functional as F
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are 'none', 'mean' and 'sum'.
+
+ Returns:
+ Tensor: Reduced loss tensor.
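+
+    :Example: (minimal sketch; the toy tensor below is assumed only to make the
+        three reduction modes concrete)
+
+    >>> import torch
+    >>> x = torch.tensor([1., 2., 3.])  # toy per-element loss, assumed values
+    >>> reduce_loss(x, 'none')
+    tensor([1., 2., 3.])
+    >>> reduce_loss(x, 'mean')
+    tensor(2.)
+    >>> reduce_loss(x, 'sum')
+    tensor(6.)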
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ else:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean'):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights. Default: None.
+ reduction (str): Same as built-in losses of PyTorch. Options are
+ 'none', 'mean' and 'sum'. Default: 'mean'.
+
+ Returns:
+ Tensor: Loss values.
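+
+    :Example: (minimal sketch; 2-D toy tensors are assumed here because the
+        checks below index the channel dimension via `weight.size(1)`)
+
+    >>> import torch
+    >>> loss = torch.tensor([[1., 2.], [3., 4.]])    # assumed element-wise loss
+    >>> weight = torch.tensor([[1., 0.], [1., 0.]])  # assumed binary mask
+    >>> weight_reduce_loss(loss, weight, reduction='mean')
+    tensor(2.)
+    >>> weight_reduce_loss(loss, weight, reduction='sum')
+    tensor(4.)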
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ assert weight.dim() == loss.dim()
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+ loss = loss * weight
+
+ # if weight is not specified or reduction is sum, just reduce the loss
+ if weight is None or reduction == 'sum':
+ loss = reduce_loss(loss, reduction)
+ # if reduction is mean, then compute mean over weight region
+ elif reduction == 'mean':
+ if weight.size(1) > 1:
+ weight = weight.sum()
+ else:
+ weight = weight.sum() * loss.size(1)
+ loss = loss.sum() / weight
+
+ return loss
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.5000)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, reduction='sum')
+ tensor(3.)
+ """
+
+    @functools.wraps(loss_func)
+    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
+        # get element-wise loss
+        loss = loss_func(pred, target, **kwargs)
+        loss = weight_reduce_loss(loss, weight, reduction)
+        return loss
+
+    return wrapper
+
+
+def get_local_weights(residual, ksize):
+ """Get local weights for generating the artifact map of LDL.
+
+ It is only called by the `get_refined_artifact_map` function.
+
+ Args:
+ residual (Tensor): Residual between predicted and ground truth images.
+ ksize (Int): size of the local window.
+
+ Returns:
+ Tensor: weight for each pixel to be discriminated as an artifact pixel
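+
+    :Example: (minimal sketch; the random 1x1x8x8 residual is an assumed toy
+        input, shown only to illustrate that the output keeps the spatial shape)
+
+    >>> import torch
+    >>> residual = torch.rand(1, 1, 8, 8)  # assumed (N, 1, H, W) residual
+    >>> get_local_weights(residual, 7).shape
+    torch.Size([1, 1, 8, 8])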
+ """
+
+ pad = (ksize - 1) // 2
+ residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')
+
+ unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
+ pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1)
+
+ return pixel_level_weight
+
+
+def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
+ """Calculate the artifact map of LDL
+ (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022)
+
+ Args:
+ img_gt (Tensor): ground truth images.
+ img_output (Tensor): output images given by the optimizing model.
+ img_ema (Tensor): output images given by the ema model.
+ ksize (Int): size of the local window.
+
+ Returns:
+ overall_weight: weight for each pixel to be discriminated as an artifact pixel
+ (calculated based on both local and global observations).
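+
+    :Example: (minimal sketch; the random 1x3x16x16 images are assumed inputs,
+        shown only to illustrate the expected (N, C, H, W) shapes and the
+        single-channel output map)
+
+    >>> import torch
+    >>> img_gt = torch.rand(1, 3, 16, 16)      # assumed ground-truth batch
+    >>> img_output = torch.rand(1, 3, 16, 16)  # assumed model output
+    >>> img_ema = torch.rand(1, 3, 16, 16)     # assumed EMA-model output
+    >>> get_refined_artifact_map(img_gt, img_output, img_ema, 7).shape
+    torch.Size([1, 1, 16, 16])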
+ """
+
+ residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
+ residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)
+
+ patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
+ pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
+ overall_weight = patch_level_weight * pixel_level_weight
+
+ overall_weight[residual_sr < residual_ema] = 0
+
+ return overall_weight