From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:00:30 +0000 Subject: Add files via upload --- r_basicsr/metrics/__init__.py | 20 +++ r_basicsr/metrics/fid.py | 93 +++++++++++++ r_basicsr/metrics/metric_util.py | 45 +++++++ r_basicsr/metrics/niqe.py | 197 ++++++++++++++++++++++++++++ r_basicsr/metrics/niqe_pris_params.npz | Bin 0 -> 11850 bytes r_basicsr/metrics/psnr_ssim.py | 233 +++++++++++++++++++++++++++++++++ 6 files changed, 588 insertions(+) create mode 100644 r_basicsr/metrics/__init__.py create mode 100644 r_basicsr/metrics/fid.py create mode 100644 r_basicsr/metrics/metric_util.py create mode 100644 r_basicsr/metrics/niqe.py create mode 100644 r_basicsr/metrics/niqe_pris_params.npz create mode 100644 r_basicsr/metrics/psnr_ssim.py (limited to 'r_basicsr/metrics') diff --git a/r_basicsr/metrics/__init__.py b/r_basicsr/metrics/__init__.py new file mode 100644 index 0000000..46fcd61 --- /dev/null +++ b/r_basicsr/metrics/__init__.py @@ -0,0 +1,20 @@ +from copy import deepcopy + +from r_basicsr.utils.registry import METRIC_REGISTRY +from .niqe import calculate_niqe +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. + """ + opt = deepcopy(opt) + metric_type = opt.pop('type') + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/r_basicsr/metrics/fid.py b/r_basicsr/metrics/fid.py new file mode 100644 index 0000000..dd594d1 --- /dev/null +++ b/r_basicsr/metrics/fid.py @@ -0,0 +1,93 @@ +import numpy as np +import torch +import torch.nn as nn +from scipy import linalg +from tqdm import tqdm + +from r_basicsr.archs.inception import InceptionV3 + + +def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): + # we may not resize the input, but in [rosinality/stylegan2-pytorch] it + # does resize the input. + inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) + inception = nn.DataParallel(inception).eval().to(device) + return inception + + +@torch.no_grad() +def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): + """Extract inception features. + + Args: + data_generator (generator): A data generator. + inception (nn.Module): Inception model. + len_generator (int): Length of the data_generator to show the + progressbar. Default: None. + device (str): Device. Default: cuda. + + Returns: + Tensor: Extracted features. + """ + if len_generator is not None: + pbar = tqdm(total=len_generator, unit='batch', desc='Extract') + else: + pbar = None + features = [] + + for data in data_generator: + if pbar: + pbar.update(1) + data = data.to(device) + feature = inception(data)[0].view(data.shape[0], -1) + features.append(feature.to('cpu')) + if pbar: + pbar.close() + features = torch.cat(features, 0) + return features + + +def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + + Args: + mu1 (np.array): The sample mean over activations. + sigma1 (np.array): The covariance matrix over activations for + generated samples. + mu2 (np.array): The sample mean over activations, precalculated on an + representative data set. + sigma2 (np.array): The covariance matrix over activations, + precalculated on an representative data set. + + Returns: + float: The Frechet Distance. + """ + assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions') + + cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) + + # Product might be almost singular + if not np.isfinite(cov_sqrt).all(): + print('Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates') + offset = np.eye(sigma1.shape[0]) * eps + cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(cov_sqrt): + if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): + m = np.max(np.abs(cov_sqrt.imag)) + raise ValueError(f'Imaginary component {m}') + cov_sqrt = cov_sqrt.real + + mean_diff = mu1 - mu2 + mean_norm = mean_diff @ mean_diff + trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) + fid = mean_norm + trace + + return fid diff --git a/r_basicsr/metrics/metric_util.py b/r_basicsr/metrics/metric_util.py new file mode 100644 index 0000000..0b45354 --- /dev/null +++ b/r_basicsr/metrics/metric_util.py @@ -0,0 +1,45 @@ +import numpy as np + +from r_basicsr.utils import bgr2ycbcr + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. diff --git a/r_basicsr/metrics/niqe.py b/r_basicsr/metrics/niqe.py new file mode 100644 index 0000000..eb3f877 --- /dev/null +++ b/r_basicsr/metrics/niqe.py @@ -0,0 +1,197 @@ +import cv2 +import math +import numpy as np +import os +from scipy.ndimage.filters import convolve +from scipy.special import gamma + +from r_basicsr.metrics.metric_util import reorder_image, to_y_channel +from r_basicsr.utils.matlab_functions import imresize +from r_basicsr.utils.registry import METRIC_REGISTRY + + +def estimate_aggd_param(block): + """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters. + + Args: + block (ndarray): 2D Image block. + + Returns: + tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD + distribution (Estimating the parames in Equation 7 in the paper). + """ + block = block.flatten() + gam = np.arange(0.2, 10.001, 0.001) # len = 9801 + gam_reciprocal = np.reciprocal(gam) + r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3)) + + left_std = np.sqrt(np.mean(block[block < 0]**2)) + right_std = np.sqrt(np.mean(block[block > 0]**2)) + gammahat = left_std / right_std + rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2) + rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2) + array_position = np.argmin((r_gam - rhatnorm)**2) + + alpha = gam[array_position] + beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + return (alpha, beta_l, beta_r) + + +def compute_feature(block): + """Compute features. + + Args: + block (ndarray): 2D Image block. + + Returns: + list: Features with length of 18. + """ + feat = [] + alpha, beta_l, beta_r = estimate_aggd_param(block) + feat.extend([alpha, (beta_l + beta_r) / 2]) + + # distortions disturb the fairly regular structure of natural images. + # This deviation can be captured by analyzing the sample distribution of + # the products of pairs of adjacent coefficients computed along + # horizontal, vertical and diagonal orientations. + shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] + for i in range(len(shifts)): + shifted_block = np.roll(block, shifts[i], axis=(0, 1)) + alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block) + # Eq. 8 + mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha)) + feat.extend([alpha, mean, beta_l, beta_r]) + return feat + + +def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + Ref: Making a "Completely Blind" Image Quality Analyzer. + This implementation could produce almost the same results as the official + MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + + Note that we do not include block overlap height and width, since they are + always 0 in the official implementation. + + For good performance, it is advisable by the official implementation to + divide the distorted image in to the same size patched as used for the + construction of multivariate Gaussian model. + + Args: + img (ndarray): Input image whose quality needs to be computed. The + image must be a gray or Y (of YCbCr) image with shape (h, w). + Range [0, 255] with float type. + mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian + model calculated on the pristine dataset. + cov_pris_param (ndarray): Covariance of a pre-defined multivariate + Gaussian model calculated on the pristine dataset. + gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the + image. + block_size_h (int): Height of the blocks in to which image is divided. + Default: 96 (the official recommended value). + block_size_w (int): Width of the blocks in to which image is divided. + Default: 96 (the official recommended value). + """ + assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).') + # crop image + h, w = img.shape + num_block_h = math.floor(h / block_size_h) + num_block_w = math.floor(w / block_size_w) + img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w] + + distparam = [] # dist param is actually the multiscale features + for scale in (1, 2): # perform on two scales (1, 2) + mu = convolve(img, gaussian_window, mode='nearest') + sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu))) + # normalize, as in Eq. 1 in the paper + img_nomalized = (img - mu) / (sigma + 1) + + feat = [] + for idx_w in range(num_block_w): + for idx_h in range(num_block_h): + # process ecah block + block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale, + idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale] + feat.append(compute_feature(block)) + + distparam.append(np.array(feat)) + + if scale == 1: + img = imresize(img / 255., scale=0.5, antialiasing=True) + img = img * 255. + + distparam = np.concatenate(distparam, axis=1) + + # fit a MVG (multivariate Gaussian) model to distorted patch features + mu_distparam = np.nanmean(distparam, axis=0) + # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html + distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)] + cov_distparam = np.cov(distparam_no_nan, rowvar=False) + + # compute niqe quality, Eq. 10 in the paper + invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2) + quality = np.matmul( + np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam))) + + quality = np.sqrt(quality) + quality = float(np.squeeze(quality)) + return quality + + +@METRIC_REGISTRY.register() +def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + Ref: Making a "Completely Blind" Image Quality Analyzer. + This implementation could produce almost the same results as the official + MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + + > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296) + > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296) + + We use the official params estimated from the pristine dataset. + We use the recommended block size (96, 96) without overlaps. + + Args: + img (ndarray): Input image whose quality needs to be computed. + The input image must be in range [0, 255] with float/int type. + The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order) + If the input order is 'HWC' or 'CHW', it will be converted to gray + or Y (of YCbCr) image according to the ``convert_to`` argument. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the metric calculation. + input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'. + Default: 'HWC'. + convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'. + Default: 'y'. + + Returns: + float: NIQE result. + """ + ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + # we use the official params estimated from the pristine dataset. + niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz')) + mu_pris_param = niqe_pris_params['mu_pris_param'] + cov_pris_param = niqe_pris_params['cov_pris_param'] + gaussian_window = niqe_pris_params['gaussian_window'] + + img = img.astype(np.float32) + if input_order != 'HW': + img = reorder_image(img, input_order=input_order) + if convert_to == 'y': + img = to_y_channel(img) + elif convert_to == 'gray': + img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255. + img = np.squeeze(img) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border] + + # round is necessary for being consistent with MATLAB's result + img = img.round() + + niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window) + + return niqe_result diff --git a/r_basicsr/metrics/niqe_pris_params.npz b/r_basicsr/metrics/niqe_pris_params.npz new file mode 100644 index 0000000..204ddce Binary files /dev/null and b/r_basicsr/metrics/niqe_pris_params.npz differ diff --git a/r_basicsr/metrics/psnr_ssim.py b/r_basicsr/metrics/psnr_ssim.py new file mode 100644 index 0000000..938de12 --- /dev/null +++ b/r_basicsr/metrics/psnr_ssim.py @@ -0,0 +1,233 @@ +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + +from r_basicsr.metrics.metric_util import reorder_image, to_y_channel +from r_basicsr.utils.color_util import rgb2ycbcr_pt +from r_basicsr.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) + + +@METRIC_REGISTRY.register() +def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + mse = torch.mean((img - img2)**2, dim=[1, 2, 3]) + return 10. * torch.log10(1. / (mse + 1e-8)) + + +@METRIC_REGISTRY.register() +def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity) (PyTorch version). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + ssim = _ssim_pth(img * 255., img2 * 255.) + return ssim + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: SSIM result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) + return ssim_map.mean() + + +def _ssim_pth(img, img2): + """Calculate SSIM (structural similarity) (PyTorch version). + + It is called by func:`calculate_ssim_pt`. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + + Returns: + float: SSIM result. + """ + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device) + + mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode + mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq + sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2 + + cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2) + ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map + return ssim_map.mean([1, 2, 3]) -- cgit v1.2.3