author     Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com>   2025-01-17 11:00:30 +0000
committer  GitHub <noreply@github.com>   2025-01-17 11:00:30 +0000
commit     495ffc4777522e40941753e3b1b79c02f84b25b4 (patch)
tree       5130fcb8676afdcb619a5e5eaef3ac28e135bc08 /r_basicsr/metrics
parent     febd45814cd41560c5247aacb111d8d013f3a303 (diff)
download   Comfyui-reactor-node-495ffc4777522e40941753e3b1b79c02f84b25b4.tar.gz
Add files via upload
Diffstat (limited to 'r_basicsr/metrics')
-rw-r--r--  r_basicsr/metrics/__init__.py           |  20
-rw-r--r--  r_basicsr/metrics/fid.py                |  93
-rw-r--r--  r_basicsr/metrics/metric_util.py        |  45
-rw-r--r--  r_basicsr/metrics/niqe.py               | 197
-rw-r--r--  r_basicsr/metrics/niqe_pris_params.npz  | bin 0 -> 11850 bytes
-rw-r--r--  r_basicsr/metrics/psnr_ssim.py          | 233
6 files changed, 588 insertions, 0 deletions
diff --git a/r_basicsr/metrics/__init__.py b/r_basicsr/metrics/__init__.py
new file mode 100644
index 0000000..46fcd61
--- /dev/null
+++ b/r_basicsr/metrics/__init__.py
@@ -0,0 +1,20 @@
+from copy import deepcopy
+
+from r_basicsr.utils.registry import METRIC_REGISTRY
+from .niqe import calculate_niqe
+from .psnr_ssim import calculate_psnr, calculate_ssim
+
+__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
+
+
+def calculate_metric(data, opt):
+ """Calculate metric from data and options.
+
+ Args:
+ opt (dict): Configuration. It must contain:
+ type (str): Metric type.
+ """
+ opt = deepcopy(opt)
+ metric_type = opt.pop('type')
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
+ return metric
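A quick usage sketch of the registry dispatch (illustrative only, not part of this commit; the random arrays stand in for real outputs):

    import numpy as np
    from r_basicsr.metrics import calculate_metric

    restored = np.random.rand(64, 64, 3) * 255.  # stand-in for a restored image, HWC, [0, 255]
    gt = np.random.rand(64, 64, 3) * 255.        # stand-in for the ground truth
    data = dict(img=restored, img2=gt)           # keys match calculate_psnr's signature
    opt = dict(type='calculate_psnr', crop_border=4, test_y_channel=True)
    print(calculate_metric(data, opt))           # 'type' is popped; the rest is forwarded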
diff --git a/r_basicsr/metrics/fid.py b/r_basicsr/metrics/fid.py
new file mode 100644
index 0000000..dd594d1
--- /dev/null
+++ b/r_basicsr/metrics/fid.py
@@ -0,0 +1,93 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from scipy import linalg
+from tqdm import tqdm
+
+from r_basicsr.archs.inception import InceptionV3
+
+
+def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False):
+ # resizing the input is optional here; [rosinality/stylegan2-pytorch]
+ # resizes it, hence the default resize_input=True.
+ inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input)
+ inception = nn.DataParallel(inception).eval().to(device)
+ return inception
+
+
+@torch.no_grad()
+def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'):
+ """Extract inception features.
+
+ Args:
+ data_generator (generator): A data generator.
+ inception (nn.Module): Inception model.
+ len_generator (int): Length of the data_generator to show the
+ progressbar. Default: None.
+ device (str): Device. Default: cuda.
+
+ Returns:
+ Tensor: Extracted features.
+ """
+ if len_generator is not None:
+ pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
+ else:
+ pbar = None
+ features = []
+
+ for data in data_generator:
+ if pbar:
+ pbar.update(1)
+ data = data.to(device)
+ feature = inception(data)[0].view(data.shape[0], -1)
+ features.append(feature.to('cpu'))
+ if pbar:
+ pbar.close()
+ features = torch.cat(features, 0)
+ return features
+
+
+def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
+ """Numpy implementation of the Frechet Distance.
+
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+ and X_2 ~ N(mu_2, C_2) is
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+ Stable version by Dougal J. Sutherland.
+
+ Args:
+ mu1 (np.array): The sample mean over activations.
+ sigma1 (np.array): The covariance matrix over activations for
+ generated samples.
+ mu2 (np.array): The sample mean over activations, precalculated on a
+ representative data set.
+ sigma2 (np.array): The covariance matrix over activations,
+ precalculated on a representative data set.
+
+ Returns:
+ float: The Frechet Distance.
+ """
+ assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
+ assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions')
+
+ cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
+
+ # Product might be almost singular
+ if not np.isfinite(cov_sqrt).all():
+ print(f'Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates')
+ offset = np.eye(sigma1.shape[0]) * eps
+ cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
+
+ # Numerical error might give slight imaginary component
+ if np.iscomplexobj(cov_sqrt):
+ if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
+ m = np.max(np.abs(cov_sqrt.imag))
+ raise ValueError(f'Imaginary component {m}')
+ cov_sqrt = cov_sqrt.real
+
+ mean_diff = mu1 - mu2
+ mean_norm = mean_diff @ mean_diff
+ trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
+ fid = mean_norm + trace
+
+ return fid
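In practice the two (mu, sigma) pairs come from Inception features of the generated and reference sets. A hedged glue sketch (the data loaders are assumptions, not defined in this commit):

    import numpy as np

    def feature_stats(loader, inception):
        # extract_inception_features returns an (N, 2048) CPU tensor
        feats = extract_inception_features(loader, inception).numpy()
        return np.mean(feats, axis=0), np.cov(feats, rowvar=False)

    inception = load_patched_inception_v3()              # assumes a CUDA device
    mu1, sigma1 = feature_stats(fake_loader, inception)  # fake_loader: assumed batch generator
    mu2, sigma2 = feature_stats(real_loader, inception)  # real_loader: assumed batch generator
    print('FID:', calculate_fid(mu1, sigma1, mu2, sigma2))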
diff --git a/r_basicsr/metrics/metric_util.py b/r_basicsr/metrics/metric_util.py
new file mode 100644
index 0000000..0b45354
--- /dev/null
+++ b/r_basicsr/metrics/metric_util.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from r_basicsr.utils import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+ """Reorder images to 'HWC' order.
+
+ If the input_order is (h, w), return (h, w, 1);
+ If the input_order is (c, h, w), return (h, w, c);
+ If the input_order is (h, w, c), return as it is.
+
+ Args:
+ img (ndarray): Input image.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ If the input image shape is (h, w), input_order will not have
+ effects. Default: 'HWC'.
+
+ Returns:
+ ndarray: reordered image.
+ """
+
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
+ if len(img.shape) == 2:
+ img = img[..., None]
+ if input_order == 'CHW':
+ img = img.transpose(1, 2, 0)
+ return img
+
+
+def to_y_channel(img):
+ """Change to Y channel of YCbCr.
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+
+ Returns:
+ (ndarray): Images with range [0, 255] (float type) without rounding.
+ """
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 3 and img.shape[2] == 3:
+ img = bgr2ycbcr(img, y_only=True)
+ img = img[..., None]
+ return img * 255.
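Both helpers are layout/color-space shims used by the metrics below; a small shape sketch (illustrative):

    import numpy as np
    from r_basicsr.metrics.metric_util import reorder_image, to_y_channel

    img_chw = np.random.rand(3, 64, 64) * 255.
    img_hwc = reorder_image(img_chw, input_order='CHW')  # -> (64, 64, 3)
    img_y = to_y_channel(img_hwc)                        # -> (64, 64, 1), float Y in [0, 255]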
diff --git a/r_basicsr/metrics/niqe.py b/r_basicsr/metrics/niqe.py
new file mode 100644
index 0000000..eb3f877
--- /dev/null
+++ b/r_basicsr/metrics/niqe.py
@@ -0,0 +1,197 @@
+import cv2
+import math
+import numpy as np
+import os
+from scipy.ndimage import convolve
+from scipy.special import gamma
+
+from r_basicsr.metrics.metric_util import reorder_image, to_y_channel
+from r_basicsr.utils.matlab_functions import imresize
+from r_basicsr.utils.registry import METRIC_REGISTRY
+
+
+def estimate_aggd_param(block):
+ """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.
+
+ Args:
+ block (ndarray): 2D Image block.
+
+ Returns:
+ tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
+ distribution (estimating the parameters of Equation 7 in the paper).
+ """
+ block = block.flatten()
+ gam = np.arange(0.2, 10.001, 0.001) # len = 9801
+ gam_reciprocal = np.reciprocal(gam)
+ r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
+
+ left_std = np.sqrt(np.mean(block[block < 0]**2))
+ right_std = np.sqrt(np.mean(block[block > 0]**2))
+ gammahat = left_std / right_std
+ rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
+ rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)
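+ # moment matching: pick the alpha whose theoretical ratio r_gam is closest to the sample estimate rhatnorm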
+ array_position = np.argmin((r_gam - rhatnorm)**2)
+
+ alpha = gam[array_position]
+ beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
+ beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
+ return (alpha, beta_l, beta_r)
+
+
+def compute_feature(block):
+ """Compute features.
+
+ Args:
+ block (ndarray): 2D Image block.
+
+ Returns:
+ list: Features with length of 18.
+ """
+ feat = []
+ alpha, beta_l, beta_r = estimate_aggd_param(block)
+ feat.extend([alpha, (beta_l + beta_r) / 2])
+
+ # distortions disturb the fairly regular structure of natural images.
+ # This deviation can be captured by analyzing the sample distribution of
+ # the products of pairs of adjacent coefficients computed along
+ # horizontal, vertical and diagonal orientations.
+ shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
+ for shift in shifts:
+ shifted_block = np.roll(block, shift, axis=(0, 1))
+ alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
+ # Eq. 8
+ mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
+ feat.extend([alpha, mean, beta_l, beta_r])
+ return feat
+
+
+def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96):
+ """Calculate NIQE (Natural Image Quality Evaluator) metric.
+
+ Ref: Making a "Completely Blind" Image Quality Analyzer.
+ This implementation could produce almost the same results as the official
+ MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
+
+ Note that we do not include block overlap height and width, since they are
+ always 0 in the official implementation.
+
+ For good performance, the official implementation advises dividing the
+ distorted image into patches of the same size as those used for the
+ construction of the multivariate Gaussian model.
+
+ Args:
+ img (ndarray): Input image whose quality needs to be computed. The
+ image must be a gray or Y (of YCbCr) image with shape (h, w).
+ Range [0, 255] with float type.
+ mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
+ model calculated on the pristine dataset.
+ cov_pris_param (ndarray): Covariance of a pre-defined multivariate
+ Gaussian model calculated on the pristine dataset.
+ gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
+ image.
+ block_size_h (int): Height of the blocks into which the image is divided.
+ Default: 96 (the official recommended value).
+ block_size_w (int): Width of the blocks into which the image is divided.
+ Default: 96 (the official recommended value).
+ """
+ assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
+ # crop image
+ h, w = img.shape
+ num_block_h = math.floor(h / block_size_h)
+ num_block_w = math.floor(w / block_size_w)
+ img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
+
+ distparam = [] # distparam holds the multiscale features
+ for scale in (1, 2): # perform on two scales (1, 2)
+ mu = convolve(img, gaussian_window, mode='nearest')
+ sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu)))
+ # normalize, as in Eq. 1 in the paper
+ img_normalized = (img - mu) / (sigma + 1)
+
+ feat = []
+ for idx_w in range(num_block_w):
+ for idx_h in range(num_block_h):
+ # process each block
+ block = img_normalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
+ idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
+ feat.append(compute_feature(block))
+
+ distparam.append(np.array(feat))
+
+ if scale == 1:
+ img = imresize(img / 255., scale=0.5, antialiasing=True)
+ img = img * 255.
+
+ distparam = np.concatenate(distparam, axis=1)
+
+ # fit an MVG (multivariate Gaussian) model to distorted patch features
+ mu_distparam = np.nanmean(distparam, axis=0)
+ # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
+ distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
+ cov_distparam = np.cov(distparam_no_nan, rowvar=False)
+
+ # compute niqe quality, Eq. 10 in the paper
+ invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
+ quality = np.matmul(
+ np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))
+
+ quality = np.sqrt(quality)
+ quality = float(np.squeeze(quality))
+ return quality
+
+
+@METRIC_REGISTRY.register()
+def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs):
+ """Calculate NIQE (Natural Image Quality Evaluator) metric.
+
+ Ref: Making a "Completely Blind" Image Quality Analyzer.
+ This implementation could produce almost the same results as the official
+ MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
+
+ > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296)
+ > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296)
+
+ We use the official params estimated from the pristine dataset.
+ We use the recommended block size (96, 96) without overlaps.
+
+ Args:
+ img (ndarray): Input image whose quality needs to be computed.
+ The input image must be in range [0, 255] with float/int type.
+ The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
+ If the input order is 'HWC' or 'CHW', it will be converted to gray
+ or Y (of YCbCr) image according to the ``convert_to`` argument.
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the metric calculation.
+ input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
+ Default: 'y'.
+
+ Returns:
+ float: NIQE result.
+ """
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ # we use the official params estimated from the pristine dataset.
+ niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz'))
+ mu_pris_param = niqe_pris_params['mu_pris_param']
+ cov_pris_param = niqe_pris_params['cov_pris_param']
+ gaussian_window = niqe_pris_params['gaussian_window']
+
+ img = img.astype(np.float32)
+ if input_order != 'HW':
+ img = reorder_image(img, input_order=input_order)
+ if convert_to == 'y':
+ img = to_y_channel(img)
+ elif convert_to == 'gray':
+ img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
+ img = np.squeeze(img)
+
+ if crop_border != 0:
+ img = img[crop_border:-crop_border, crop_border:-crop_border]
+
+ # round is necessary for being consistent with MATLAB's result
+ img = img.round()
+
+ niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
+
+ return niqe_result
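An end-to-end call sketch (illustrative; lower NIQE indicates a more natural-looking image):

    import cv2
    from r_basicsr.metrics.niqe import calculate_niqe

    img = cv2.imread('tests/data/baboon.png')  # HWC, BGR, uint8
    score = calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y')
    print(f'NIQE: {score:.4f}')  # the docstring reports ~5.7296 for this image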
diff --git a/r_basicsr/metrics/niqe_pris_params.npz b/r_basicsr/metrics/niqe_pris_params.npz
new file mode 100644
index 0000000..204ddce
Binary files /dev/null and b/r_basicsr/metrics/niqe_pris_params.npz differ
diff --git a/r_basicsr/metrics/psnr_ssim.py b/r_basicsr/metrics/psnr_ssim.py
new file mode 100644
index 0000000..938de12
--- /dev/null
+++ b/r_basicsr/metrics/psnr_ssim.py
@@ -0,0 +1,233 @@
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from r_basicsr.metrics.metric_util import reorder_image, to_y_channel
+from r_basicsr.utils.color_util import rgb2ycbcr_pt
+from r_basicsr.utils.registry import METRIC_REGISTRY
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: PSNR result.
+ """
+
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
+ img = reorder_image(img, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+
+ if crop_border != 0:
+ img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img = to_y_channel(img)
+ img2 = to_y_channel(img2)
+
+ img = img.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ mse = np.mean((img - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 10. * np.log10(255. * 255. / mse)
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version).
+
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+ Args:
+ img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+ img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+ crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: PSNR result.
+ """
+
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+
+ if crop_border != 0:
+ img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
+ img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]
+
+ if test_y_channel:
+ img = rgb2ycbcr_pt(img, y_only=True)
+ img2 = rgb2ycbcr_pt(img2, y_only=True)
+
+ img = img.to(torch.float64)
+ img2 = img2.to(torch.float64)
+
+ mse = torch.mean((img - img2)**2, dim=[1, 2, 3])
+ return 10. * torch.log10(1. / (mse + 1e-8))
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
+ """Calculate SSIM (structural similarity).
+
+ Ref:
+ Image quality assessment: From error visibility to structural similarity
+
+ The results are the same as that of the official released MATLAB code in
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+ For three-channel images, SSIM is calculated for each channel and then
+ averaged.
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: SSIM result.
+ """
+
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
+ img = reorder_image(img, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+
+ if crop_border != 0:
+ img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img = to_y_channel(img)
+ img2 = to_y_channel(img2)
+
+ img = img.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ ssims = []
+ for i in range(img.shape[2]):
+ ssims.append(_ssim(img[..., i], img2[..., i]))
+ return np.array(ssims).mean()
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
+ """Calculate SSIM (structural similarity) (PyTorch version).
+
+ Ref:
+ Image quality assessment: From error visibility to structural similarity
+
+ The results are the same as that of the official released MATLAB code in
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+ For three-channel images, SSIM is calculated for each channel and then
+ averaged.
+
+ Args:
+ img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+ img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+ crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: SSIM result.
+ """
+
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+
+ if crop_border != 0:
+ img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
+ img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]
+
+ if test_y_channel:
+ img = rgb2ycbcr_pt(img, y_only=True)
+ img2 = rgb2ycbcr_pt(img2, y_only=True)
+
+ img = img.to(torch.float64)
+ img2 = img2.to(torch.float64)
+
+ ssim = _ssim_pth(img * 255., img2 * 255.)
+ return ssim
+
+
+def _ssim(img, img2):
+ """Calculate SSIM (structural similarity) for one channel images.
+
+ It is called by :func:`calculate_ssim`.
+
+ Args:
+ img (ndarray): A 2D single-channel image with range [0, 255].
+ img2 (ndarray): A 2D single-channel image with range [0, 255].
+
+ Returns:
+ float: SSIM result.
+ """
+
+ c1 = (0.01 * 255)**2
+ c2 = (0.03 * 255)**2
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
+ return ssim_map.mean()
+
+
+def _ssim_pth(img, img2):
+ """Calculate SSIM (structural similarity) (PyTorch version).
+
+ It is called by :func:`calculate_ssim_pt`.
+
+ Args:
+ img (Tensor): Images with range [0, 255], shape (n, 3/1, h, w).
+ img2 (Tensor): Images with range [0, 255], shape (n, 3/1, h, w).
+
+ Returns:
+ float: SSIM result.
+ """
+ c1 = (0.01 * 255)**2
+ c2 = (0.03 * 255)**2
+
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+ window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device)
+
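+ # grouped conv applies the same 11x11 Gaussian to each channel independently (depthwise filtering)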
+ mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode
+ mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img2.shape[1]) - mu2_sq
+ sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2
+
+ cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2)
+ ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map
+ return ssim_map.mean([1, 2, 3])
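A closing side-by-side sketch of the numpy and PyTorch paths (illustrative; the numpy functions expect [0, 255] HWC arrays, the _pt variants [0, 1] NCHW tensors):

    import numpy as np
    import torch
    from r_basicsr.metrics.psnr_ssim import (calculate_psnr, calculate_psnr_pt,
                                             calculate_ssim, calculate_ssim_pt)

    gt = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
    noisy = np.clip(gt + np.random.randn(64, 64, 3) * 5., 0, 255)

    def to_pt(a):
        # HWC [0, 255] ndarray -> NCHW [0, 1] float tensor
        return torch.from_numpy(a.transpose(2, 0, 1)[None] / 255.)

    print(calculate_psnr(noisy, gt, crop_border=0))
    print(calculate_psnr_pt(to_pt(noisy), to_pt(gt), crop_border=0).item())
    print(calculate_ssim(noisy, gt, crop_border=0))
    print(calculate_ssim_pt(to_pt(noisy), to_pt(gt), crop_border=0).item())

The two PSNR values should agree to within the 1e-8 epsilon added in calculate_psnr_pt; the SSIM pair likewise tracks closely, since both paths use the same 11x11 Gaussian window and C1/C2 constants.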