From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:00:30 +0000 Subject: Add files via upload --- r_basicsr/utils/img_process_util.py | 83 +++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 r_basicsr/utils/img_process_util.py (limited to 'r_basicsr/utils/img_process_util.py') diff --git a/r_basicsr/utils/img_process_util.py b/r_basicsr/utils/img_process_util.py new file mode 100644 index 0000000..fb5fbc9 --- /dev/null +++ b/r_basicsr/utils/img_process_util.py @@ -0,0 +1,83 @@ +import cv2 +import numpy as np +import torch +from torch.nn import functional as F + + +def filter2D(img, kernel): + """PyTorch version of cv2.filter2D + + Args: + img (Tensor): (b, c, h, w) + kernel (Tensor): (b, k, k) + """ + k = kernel.size(-1) + b, c, h, w = img.size() + if k % 2 == 1: + img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') + else: + raise ValueError('Wrong kernel size') + + ph, pw = img.size()[-2:] + + if kernel.size(0) == 1: + # apply the same kernel to all batch images + img = img.view(b * c, 1, ph, pw) + kernel = kernel.view(1, 1, k, k) + return F.conv2d(img, kernel, padding=0).view(b, c, h, w) + else: + img = img.view(1, b * c, ph, pw) + kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) + return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) + + +def usm_sharp(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. + + Input image: I; Blurry image: B. + 1. sharp = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * sharp + (1 - Mask) * I + + + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + sharp = img + weight * residual + sharp = np.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img + + +class USMSharp(torch.nn.Module): + + def __init__(self, radius=50, sigma=0): + super(USMSharp, self).__init__() + if radius % 2 == 0: + radius += 1 + self.radius = radius + kernel = cv2.getGaussianKernel(radius, sigma) + kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) + self.register_buffer('kernel', kernel) + + def forward(self, img, weight=0.5, threshold=10): + blur = filter2D(img, self.kernel) + residual = img - blur + + mask = torch.abs(residual) * 255 > threshold + mask = mask.float() + soft_mask = filter2D(mask, self.kernel) + sharp = img + weight * residual + sharp = torch.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img -- cgit v1.2.3