From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:00:30 +0000 Subject: Add files via upload --- r_basicsr/metrics/psnr_ssim.py | 233 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 233 insertions(+) create mode 100644 r_basicsr/metrics/psnr_ssim.py (limited to 'r_basicsr/metrics/psnr_ssim.py') diff --git a/r_basicsr/metrics/psnr_ssim.py b/r_basicsr/metrics/psnr_ssim.py new file mode 100644 index 0000000..938de12 --- /dev/null +++ b/r_basicsr/metrics/psnr_ssim.py @@ -0,0 +1,233 @@ +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + +from r_basicsr.metrics.metric_util import reorder_image, to_y_channel +from r_basicsr.utils.color_util import rgb2ycbcr_pt +from r_basicsr.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) + + +@METRIC_REGISTRY.register() +def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + mse = torch.mean((img - img2)**2, dim=[1, 2, 3]) + return 10. * torch.log10(1. / (mse + 1e-8)) + + +@METRIC_REGISTRY.register() +def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity) (PyTorch version). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + ssim = _ssim_pth(img * 255., img2 * 255.) + return ssim + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: SSIM result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) + return ssim_map.mean() + + +def _ssim_pth(img, img2): + """Calculate SSIM (structural similarity) (PyTorch version). + + It is called by func:`calculate_ssim_pt`. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + + Returns: + float: SSIM result. + """ + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device) + + mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode + mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq + sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2 + + cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2) + ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map + return ssim_map.mean([1, 2, 3]) -- cgit v1.2.3