summaryrefslogtreecommitdiffstats
path: root/r_basicsr/utils/img_process_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'r_basicsr/utils/img_process_util.py')
-rw-r--r--r_basicsr/utils/img_process_util.py83
1 files changed, 83 insertions, 0 deletions
diff --git a/r_basicsr/utils/img_process_util.py b/r_basicsr/utils/img_process_util.py
new file mode 100644
index 0000000..fb5fbc9
--- /dev/null
+++ b/r_basicsr/utils/img_process_util.py
@@ -0,0 +1,83 @@
+import cv2
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+
+def filter2D(img, kernel):
+ """PyTorch version of cv2.filter2D
+
+ Args:
+ img (Tensor): (b, c, h, w)
+ kernel (Tensor): (b, k, k)
+ """
+ k = kernel.size(-1)
+ b, c, h, w = img.size()
+ if k % 2 == 1:
+ img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
+ else:
+ raise ValueError('Wrong kernel size')
+
+ ph, pw = img.size()[-2:]
+
+ if kernel.size(0) == 1:
+ # apply the same kernel to all batch images
+ img = img.view(b * c, 1, ph, pw)
+ kernel = kernel.view(1, 1, k, k)
+ return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
+ else:
+ img = img.view(1, b * c, ph, pw)
+ kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
+ return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
+
+
+def usm_sharp(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening.
+
+ Input image: I; Blurry image: B.
+ 1. sharp = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * sharp + (1 - Mask) * I
+
+
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 1.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int):
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ sharp = img + weight * residual
+ sharp = np.clip(sharp, 0, 1)
+ return soft_mask * sharp + (1 - soft_mask) * img
+
+
+class USMSharp(torch.nn.Module):
+
+ def __init__(self, radius=50, sigma=0):
+ super(USMSharp, self).__init__()
+ if radius % 2 == 0:
+ radius += 1
+ self.radius = radius
+ kernel = cv2.getGaussianKernel(radius, sigma)
+ kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
+ self.register_buffer('kernel', kernel)
+
+ def forward(self, img, weight=0.5, threshold=10):
+ blur = filter2D(img, self.kernel)
+ residual = img - blur
+
+ mask = torch.abs(residual) * 255 > threshold
+ mask = mask.float()
+ soft_mask = filter2D(mask, self.kernel)
+ sharp = img + weight * residual
+ sharp = torch.clip(sharp, 0, 1)
+ return soft_mask * sharp + (1 - soft_mask) * img