summaryrefslogtreecommitdiffstats
path: root/r_basicsr/utils/diffjpeg.py
diff options
context:
space:
mode:
Diffstat (limited to 'r_basicsr/utils/diffjpeg.py')
-rw-r--r--r_basicsr/utils/diffjpeg.py515
1 files changed, 515 insertions, 0 deletions
diff --git a/r_basicsr/utils/diffjpeg.py b/r_basicsr/utils/diffjpeg.py
new file mode 100644
index 0000000..c055c1b
--- /dev/null
+++ b/r_basicsr/utils/diffjpeg.py
@@ -0,0 +1,515 @@
+"""
+Modified from https://github.com/mlomnitz/DiffJPEG
+
+For images not divisible by 8
+https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
+"""
+import itertools
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+# ------------------------ utils ------------------------#
+y_table = np.array(
+ [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
+ [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
+ [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
+ dtype=np.float32).T
+y_table = nn.Parameter(torch.from_numpy(y_table))
+c_table = np.empty((8, 8), dtype=np.float32)
+c_table.fill(99)
+c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
+c_table = nn.Parameter(torch.from_numpy(c_table))
+
+
+def diff_round(x):
+ """ Differentiable rounding function
+ """
+ return torch.round(x) + (x - torch.round(x))**3
+
+
+def quality_to_factor(quality):
+ """ Calculate factor corresponding to quality
+
+ Args:
+ quality(float): Quality for jpeg compression.
+
+ Returns:
+ float: Compression factor.
+ """
+ if quality < 50:
+ quality = 5000. / quality
+ else:
+ quality = 200. - quality * 2
+ return quality / 100.
+
+
+# ------------------------ compression ------------------------#
+class RGB2YCbCrJpeg(nn.Module):
+ """ Converts RGB image to YCbCr
+ """
+
+ def __init__(self):
+ super(RGB2YCbCrJpeg, self).__init__()
+ matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
+ dtype=np.float32).T
+ self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+ def forward(self, image):
+ """
+ Args:
+ image(Tensor): batch x 3 x height x width
+
+ Returns:
+ Tensor: batch x height x width x 3
+ """
+ image = image.permute(0, 2, 3, 1)
+ result = torch.tensordot(image, self.matrix, dims=1) + self.shift
+ return result.view(image.shape)
+
+
+class ChromaSubsampling(nn.Module):
+ """ Chroma subsampling on CbCr channels
+ """
+
+ def __init__(self):
+ super(ChromaSubsampling, self).__init__()
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width x 3
+
+ Returns:
+ y(tensor): batch x height x width
+ cb(tensor): batch x height/2 x width/2
+ cr(tensor): batch x height/2 x width/2
+ """
+ image_2 = image.permute(0, 3, 1, 2).clone()
+ cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
+ cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
+ cb = cb.permute(0, 2, 3, 1)
+ cr = cr.permute(0, 2, 3, 1)
+ return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
+
+
+class BlockSplitting(nn.Module):
+ """ Splitting image into patches
+ """
+
+ def __init__(self):
+ super(BlockSplitting, self).__init__()
+ self.k = 8
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x h*w/64 x h x w
+ """
+ height, _ = image.shape[1:3]
+ batch_size = image.shape[0]
+ image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+ return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
+
+
+class DCT8x8(nn.Module):
+ """ Discrete Cosine Transformation
+ """
+
+ def __init__(self):
+ super(DCT8x8, self).__init__()
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+ for x, y, u, v in itertools.product(range(8), repeat=4):
+ tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+ self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ image = image - 128
+ result = self.scale * torch.tensordot(image, self.tensor, dims=2)
+ result.view(image.shape)
+ return result
+
+
+class YQuantize(nn.Module):
+ """ JPEG Quantization for Y channel
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding):
+ super(YQuantize, self).__init__()
+ self.rounding = rounding
+ self.y_table = y_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ image = image.float() / (self.y_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ image = image.float() / table
+ image = self.rounding(image)
+ return image
+
+
+class CQuantize(nn.Module):
+ """ JPEG Quantization for CbCr channels
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding):
+ super(CQuantize, self).__init__()
+ self.rounding = rounding
+ self.c_table = c_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ image = image.float() / (self.c_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ image = image.float() / table
+ image = self.rounding(image)
+ return image
+
+
+class CompressJpeg(nn.Module):
+ """Full JPEG compression algorithm
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding=torch.round):
+ super(CompressJpeg, self).__init__()
+ self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
+ self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
+ self.c_quantize = CQuantize(rounding=rounding)
+ self.y_quantize = YQuantize(rounding=rounding)
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x 3 x height x width
+
+ Returns:
+ dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8.
+ """
+ y, cb, cr = self.l1(image * 255)
+ components = {'y': y, 'cb': cb, 'cr': cr}
+ for k in components.keys():
+ comp = self.l2(components[k])
+ if k in ('cb', 'cr'):
+ comp = self.c_quantize(comp, factor=factor)
+ else:
+ comp = self.y_quantize(comp, factor=factor)
+
+ components[k] = comp
+
+ return components['y'], components['cb'], components['cr']
+
+
+# ------------------------ decompression ------------------------#
+
+
+class YDequantize(nn.Module):
+ """Dequantize Y channel
+ """
+
+ def __init__(self):
+ super(YDequantize, self).__init__()
+ self.y_table = y_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ out = image * (self.y_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ out = image * table
+ return out
+
+
+class CDequantize(nn.Module):
+ """Dequantize CbCr channel
+ """
+
+ def __init__(self):
+ super(CDequantize, self).__init__()
+ self.c_table = c_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ out = image * (self.c_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ out = image * table
+ return out
+
+
+class iDCT8x8(nn.Module):
+ """Inverse discrete Cosine Transformation
+ """
+
+ def __init__(self):
+ super(iDCT8x8, self).__init__()
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+ self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+ for x, y, u, v in itertools.product(range(8), repeat=4):
+ tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ image = image * self.alpha
+ result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
+ result.view(image.shape)
+ return result
+
+
+class BlockMerging(nn.Module):
+ """Merge patches into image
+ """
+
+ def __init__(self):
+ super(BlockMerging, self).__init__()
+
+ def forward(self, patches, height, width):
+ """
+ Args:
+ patches(tensor) batch x height*width/64, height x width
+ height(int)
+ width(int)
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ k = 8
+ batch_size = patches.shape[0]
+ image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+ return image_transposed.contiguous().view(batch_size, height, width)
+
+
+class ChromaUpsampling(nn.Module):
+ """Upsample chroma layers
+ """
+
+ def __init__(self):
+ super(ChromaUpsampling, self).__init__()
+
+ def forward(self, y, cb, cr):
+ """
+ Args:
+ y(tensor): y channel image
+ cb(tensor): cb channel
+ cr(tensor): cr channel
+
+ Returns:
+ Tensor: batch x height x width x 3
+ """
+
+ def repeat(x, k=2):
+ height, width = x.shape[1:3]
+ x = x.unsqueeze(-1)
+ x = x.repeat(1, 1, k, k)
+ x = x.view(-1, height * k, width * k)
+ return x
+
+ cb = repeat(cb)
+ cr = repeat(cr)
+ return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
+
+
+class YCbCr2RGBJpeg(nn.Module):
+ """Converts YCbCr image to RGB JPEG
+ """
+
+ def __init__(self):
+ super(YCbCr2RGBJpeg, self).__init__()
+
+ matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
+ self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width x 3
+
+ Returns:
+ Tensor: batch x 3 x height x width
+ """
+ result = torch.tensordot(image + self.shift, self.matrix, dims=1)
+ return result.view(image.shape).permute(0, 3, 1, 2)
+
+
+class DeCompressJpeg(nn.Module):
+ """Full JPEG decompression algorithm
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding=torch.round):
+ super(DeCompressJpeg, self).__init__()
+ self.c_dequantize = CDequantize()
+ self.y_dequantize = YDequantize()
+ self.idct = iDCT8x8()
+ self.merging = BlockMerging()
+ self.chroma = ChromaUpsampling()
+ self.colors = YCbCr2RGBJpeg()
+
+ def forward(self, y, cb, cr, imgh, imgw, factor=1):
+ """
+ Args:
+ compressed(dict(tensor)): batch x h*w/64 x 8 x 8
+ imgh(int)
+ imgw(int)
+ factor(float)
+
+ Returns:
+ Tensor: batch x 3 x height x width
+ """
+ components = {'y': y, 'cb': cb, 'cr': cr}
+ for k in components.keys():
+ if k in ('cb', 'cr'):
+ comp = self.c_dequantize(components[k], factor=factor)
+ height, width = int(imgh / 2), int(imgw / 2)
+ else:
+ comp = self.y_dequantize(components[k], factor=factor)
+ height, width = imgh, imgw
+ comp = self.idct(comp)
+ components[k] = self.merging(comp, height, width)
+ #
+ image = self.chroma(components['y'], components['cb'], components['cr'])
+ image = self.colors(image)
+
+ image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
+ return image / 255
+
+
+# ------------------------ main DiffJPEG ------------------------ #
+
+
+class DiffJPEG(nn.Module):
+ """This JPEG algorithm result is slightly different from cv2.
+ DiffJPEG supports batch processing.
+
+ Args:
+ differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round
+ """
+
+ def __init__(self, differentiable=True):
+ super(DiffJPEG, self).__init__()
+ if differentiable:
+ rounding = diff_round
+ else:
+ rounding = torch.round
+
+ self.compress = CompressJpeg(rounding=rounding)
+ self.decompress = DeCompressJpeg(rounding=rounding)
+
+ def forward(self, x, quality):
+ """
+ Args:
+ x (Tensor): Input image, bchw, rgb, [0, 1]
+ quality(float): Quality factor for jpeg compression scheme.
+ """
+ factor = quality
+ if isinstance(factor, (int, float)):
+ factor = quality_to_factor(factor)
+ else:
+ for i in range(factor.size(0)):
+ factor[i] = quality_to_factor(factor[i])
+ h, w = x.size()[-2:]
+ h_pad, w_pad = 0, 0
+ # why should use 16
+ if h % 16 != 0:
+ h_pad = 16 - h % 16
+ if w % 16 != 0:
+ w_pad = 16 - w % 16
+ x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
+
+ y, cb, cr = self.compress(x, factor=factor)
+ recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
+ recovered = recovered[:, :, 0:h, 0:w]
+ return recovered
+
+
+if __name__ == '__main__':
+ import cv2
+
+ from r_basicsr.utils import img2tensor, tensor2img
+
+ img_gt = cv2.imread('test.png') / 255.
+
+ # -------------- cv2 -------------- #
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
+ _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
+ img_lq = np.float32(cv2.imdecode(encimg, 1))
+ cv2.imwrite('cv2_JPEG_20.png', img_lq)
+
+ # -------------- DiffJPEG -------------- #
+ jpeger = DiffJPEG(differentiable=False).cuda()
+ img_gt = img2tensor(img_gt)
+ img_gt = torch.stack([img_gt, img_gt]).cuda()
+ quality = img_gt.new_tensor([20, 40])
+ out = jpeger(img_gt, quality=quality)
+
+ cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
+ cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))