author | Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> | 2025-01-17 11:00:30 +0000
---|---|---
committer | GitHub <noreply@github.com> | 2025-01-17 11:00:30 +0000
commit | 495ffc4777522e40941753e3b1b79c02f84b25b4 (patch) |
tree | 5130fcb8676afdcb619a5e5eaef3ac28e135bc08 /r_basicsr/archs |
parent | febd45814cd41560c5247aacb111d8d013f3a303 (diff) |
download | Comfyui-reactor-node-495ffc4777522e40941753e3b1b79c02f84b25b4.tar.gz |
Add files via upload
Diffstat (limited to 'r_basicsr/archs')
24 files changed, 6144 insertions, 0 deletions
diff --git a/r_basicsr/archs/__init__.py b/r_basicsr/archs/__init__.py
new file mode 100644
index 0000000..4a3f3c4
--- /dev/null
+++ b/r_basicsr/archs/__init__.py
@@ -0,0 +1,25 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from r_basicsr.utils import get_root_logger, scandir
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+__all__ = ['build_network']
+
+# automatically scan and import arch modules for registry
+# scan all the files under the 'archs' folder and collect files ending with
+# '_arch.py'
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+# import all the arch modules
+_arch_modules = [importlib.import_module(f'r_basicsr.archs.{file_name}') for file_name in arch_filenames]
+
+
+def build_network(opt):
+ opt = deepcopy(opt)
+ network_type = opt.pop('type')
+ net = ARCH_REGISTRY.get(network_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Network [{net.__class__.__name__}] is created.')
+ return net
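Usage sketch (illustrative, not part of this commit): build_network pops 'type', looks the class up in ARCH_REGISTRY, and passes the remaining keys to its constructor. The option values below are assumptions chosen to match BasicVSR, which is registered later in this commit.

from r_basicsr.archs import build_network

# 'type' selects the registered class; the rest become constructor kwargs.
opt = {'type': 'BasicVSR', 'num_feat': 64, 'num_block': 15}
net = build_network(opt)   # -> BasicVSR(num_feat=64, num_block=15)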
diff --git a/r_basicsr/archs/arch_util.py b/r_basicsr/archs/arch_util.py
new file mode 100644
index 0000000..2b27eac
--- /dev/null
+++ b/r_basicsr/archs/arch_util.py
@@ -0,0 +1,322 @@
+import collections.abc
+import math
+import torch
+import torchvision
+import warnings
+try:
+ from distutils.version import LooseVersion
+except ImportError:  # distutils was removed in Python 3.12; fall back to packaging
+ from packaging.version import Version
+ LooseVersion = Version
+from itertools import repeat
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from r_basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
+from r_basicsr.utils import get_root_logger
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+ """Initialize network weights.
+
+ Args:
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+ scale (float): Scale initialized weights, especially for residual
+ blocks. Default: 1.
+ bias_fill (float): The value to fill bias. Default: 0
+ kwargs (dict): Other arguments for initialization function.
+ """
+ if not isinstance(module_list, list):
+ module_list = [module_list]
+ for module in module_list:
+ for m in module.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, _BatchNorm):
+ init.constant_(m.weight, 1)
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+
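Usage sketch (illustrative, not part of the diff): residual branches in this file are initialized with a reduced scale, as ResidualBlockNoBN does further below.

import torch.nn as nn
from r_basicsr.archs.arch_util import default_init_weights

conv = nn.Conv2d(64, 64, 3, 1, 1)
default_init_weights(conv, scale=0.1)   # Kaiming init, then the weights are scaled by 0.1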
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+
+ Args:
+ basic_block (nn.module): nn.module class for basic block.
+ num_basic_block (int): number of blocks.
+
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
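Usage sketch (illustrative): stack several identical blocks into one trunk, as the VSR models in this commit do.

from r_basicsr.archs.arch_util import ResidualBlockNoBN, make_layer

trunk = make_layer(ResidualBlockNoBN, 15, num_feat=64)   # nn.Sequential of 15 blocks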
+
+class ResidualBlockNoBN(nn.Module):
+ """Residual block without BN.
+
+ It has a style of:
+ ---Conv-ReLU-Conv-+-
+ |________________|
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ res_scale (float): Residual scale. Default: 1.
+ pytorch_init (bool): If set to True, use pytorch default init,
+ otherwise, use default_init_weights. Default: False.
+ """
+
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+ super(ResidualBlockNoBN, self).__init__()
+ self.res_scale = res_scale
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.relu = nn.ReLU(inplace=True)
+
+ if not pytorch_init:
+ default_init_weights([self.conv1, self.conv2], 0.1)
+
+ def forward(self, x):
+ identity = x
+ out = self.conv2(self.relu(self.conv1(x)))
+ return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
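Shape sketch (illustrative): a x4 upsampler is built from two conv + PixelShuffle(2) stages, since 4 = 2^2.

import torch
from r_basicsr.archs.arch_util import Upsample

up = Upsample(scale=4, num_feat=64)
y = up(torch.rand(1, 64, 32, 32))   # -> torch.Size([1, 64, 128, 128])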
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+ """Warp an image or feature map with optical flow.
+
+ Args:
+ x (Tensor): Tensor with size (n, c, h, w).
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
+ Default: 'zeros'.
+ align_corners (bool): Before pytorch 1.3, the default value is
+ align_corners=True. After pytorch 1.3, the default value is
+ align_corners=False. Here, we use True as the default.
+
+ Returns:
+ Tensor: Warped image or feature map.
+ """
+ assert x.size()[-2:] == flow.size()[1:3]
+ _, _, h, w = x.size()
+ # create mesh grid
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
+ grid.requires_grad = False
+
+ vgrid = grid + flow
+ # scale grid to [-1,1]
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+ # TODO, what if align_corners=False
+ return output
+
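Sanity-check sketch (illustrative, not part of the diff): warping with an all-zero flow should reproduce the input up to interpolation error.

import torch
from r_basicsr.archs.arch_util import flow_warp

x = torch.rand(1, 3, 8, 8)
zero_flow = torch.zeros(1, 8, 8, 2)                      # flow has shape (n, h, w, 2)
assert torch.allclose(flow_warp(x, zero_flow), x, atol=1e-5)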
+
+def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+ """Resize a flow according to ratio or shape.
+
+ Args:
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+ size_type (str): 'ratio' or 'shape'.
+ sizes (list[int | float]): the ratio for resizing or the final output
+ shape.
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+ ratio > 1.0).
+ 2) The order of output_size should be [out_h, out_w].
+ interp_mode (str): The mode of interpolation for resizing.
+ Default: 'bilinear'.
+ align_corners (bool): Whether align corners. Default: False.
+
+ Returns:
+ Tensor: Resized flow.
+ """
+ _, _, flow_h, flow_w = flow.size()
+ if size_type == 'ratio':
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+ elif size_type == 'shape':
+ output_h, output_w = sizes[0], sizes[1]
+ else:
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+ input_flow = flow.clone()
+ ratio_h = output_h / flow_h
+ ratio_w = output_w / flow_w
+ input_flow[:, 0, :, :] *= ratio_w
+ input_flow[:, 1, :, :] *= ratio_h
+ resized_flow = F.interpolate(
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+ return resized_flow
+
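Sketch (illustrative): when a flow field is resized, its displacement values are rescaled too, since they are measured in pixels of the new resolution.

import torch
from r_basicsr.archs.arch_util import resize_flow

flow = torch.ones(1, 2, 32, 32)
big = resize_flow(flow, 'ratio', [2.0, 2.0])   # shape (1, 2, 64, 64), values ~2.0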
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+
+ Returns:
+ Tensor: the pixel unshuffled feature.
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
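Shape sketch (illustrative): pixel_unshuffle trades spatial resolution for channels, the inverse of nn.PixelShuffle.

import torch
from r_basicsr.archs.arch_util import pixel_unshuffle

x = torch.rand(2, 3, 16, 16)
y = pixel_unshuffle(x, scale=2)   # -> torch.Size([2, 12, 8, 8]): c*scale^2 channels, h/scale, w/scale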
+
+class DCNv2Pack(ModulatedDeformConvPack):
+ """Modulated deformable conv for deformable alignment.
+
+ Different from the official DCNv2Pack, which generates offsets and masks
+ from the preceding features, this DCNv2Pack takes separate features to
+ generate offsets and masks.
+
+ Ref:
+ Delving Deep into Deformable Alignment in Video Super-Resolution.
+ """
+
+ def forward(self, x, feat):
+ out = self.conv_offset(feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+
+ offset_absmean = torch.mean(torch.abs(offset))
+ if offset_absmean > 50:
+ logger = get_root_logger()
+ logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
+
+ if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, mask)
+ else:
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups, self.deformable_groups)
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ low = norm_cdf((a - mean) / std)
+ up = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [low, up], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution.
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+
+ The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+# From PyTorch
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
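Sketch (illustrative): the to_*tuple helpers normalize scalar hyper-parameters to tuples, mirroring torch's internal _ntuple; iterables pass through unchanged.

from r_basicsr.archs.arch_util import to_2tuple

to_2tuple(3)        # -> (3, 3)
to_2tuple((3, 5))   # -> (3, 5)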
diff --git a/r_basicsr/archs/basicvsr_arch.py b/r_basicsr/archs/basicvsr_arch.py
new file mode 100644
index 0000000..b812c7f
--- /dev/null
+++ b/r_basicsr/archs/basicvsr_arch.py
@@ -0,0 +1,336 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
+from .edvr_arch import PCDAlignment, TSAFusion
+from .spynet_arch import SpyNet
+
+
+@ARCH_REGISTRY.register()
+class BasicVSR(nn.Module):
+ """A recurrent network for video SR. Now only x4 is supported.
+
+ Args:
+ num_feat (int): Number of channels. Default: 64.
+ num_block (int): Number of residual blocks for each branch. Default: 15
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+ """
+
+ def __init__(self, num_feat=64, num_block=15, spynet_path=None):
+ super().__init__()
+ self.num_feat = num_feat
+
+ # alignment
+ self.spynet = SpyNet(spynet_path)
+
+ # propagation
+ self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+ self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+
+ # reconstruction
+ self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+
+ self.pixel_shuffle = nn.PixelShuffle(2)
+
+ # activation functions
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def get_flow(self, x):
+ b, n, c, h, w = x.size()
+
+ x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
+ x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+ flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
+ flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
+
+ return flows_forward, flows_backward
+
+ def forward(self, x):
+ """Forward function of BasicVSR.
+
+ Args:
+ x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
+ """
+ flows_forward, flows_backward = self.get_flow(x)
+ b, n, _, h, w = x.size()
+
+ # backward branch
+ out_l = []
+ feat_prop = x.new_zeros(b, self.num_feat, h, w)
+ for i in range(n - 1, -1, -1):
+ x_i = x[:, i, :, :, :]
+ if i < n - 1:
+ flow = flows_backward[:, i, :, :, :]
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
+ feat_prop = self.backward_trunk(feat_prop)
+ out_l.insert(0, feat_prop)
+
+ # forward branch
+ feat_prop = torch.zeros_like(feat_prop)
+ for i in range(0, n):
+ x_i = x[:, i, :, :, :]
+ if i > 0:
+ flow = flows_forward[:, i - 1, :, :, :]
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
+ feat_prop = self.forward_trunk(feat_prop)
+
+ # upsample
+ out = torch.cat([out_l[i], feat_prop], dim=1)
+ out = self.lrelu(self.fusion(out))
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+ out = self.lrelu(self.conv_hr(out))
+ out = self.conv_last(out)
+ base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
+ out += base
+ out_l[i] = out
+
+ return torch.stack(out_l, dim=1)
+
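Shape sketch (illustrative, not part of the commit): with spynet_path=None the flow network is randomly initialized, which is enough to check shapes; BasicVSR maps (b, n, 3, h, w) low-res frames to 4x outputs.

import torch
from r_basicsr.archs.basicvsr_arch import BasicVSR

model = BasicVSR(num_feat=64, num_block=15, spynet_path=None)
lr = torch.rand(1, 5, 3, 64, 64)   # (b, n, c, h, w)
sr = model(lr)                     # -> torch.Size([1, 5, 3, 256, 256])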
+
+class ConvResidualBlocks(nn.Module):
+ """Conv and residual block used in BasicVSR.
+
+ Args:
+ num_in_ch (int): Number of input channels. Default: 3.
+ num_out_ch (int): Number of output channels. Default: 64.
+ num_block (int): Number of residual blocks. Default: 15.
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
+ super().__init__()
+ self.main = nn.Sequential(
+ nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))
+
+ def forward(self, fea):
+ return self.main(fea)
+
+
+@ARCH_REGISTRY.register()
+class IconVSR(nn.Module):
+ """IconVSR, proposed also in the BasicVSR paper.
+
+ Args:
+ num_feat (int): Number of channels. Default: 64.
+ num_block (int): Number of residual blocks for each branch. Default: 15.
+ keyframe_stride (int): Keyframe stride. Default: 5.
+ temporal_padding (int): Temporal padding. Default: 2.
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+ edvr_path (str): Path to the pretrained EDVR model. Default: None.
+ """
+
+ def __init__(self,
+ num_feat=64,
+ num_block=15,
+ keyframe_stride=5,
+ temporal_padding=2,
+ spynet_path=None,
+ edvr_path=None):
+ super().__init__()
+
+ self.num_feat = num_feat
+ self.temporal_padding = temporal_padding
+ self.keyframe_stride = keyframe_stride
+
+ # keyframe_branch
+ self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
+ # alignment
+ self.spynet = SpyNet(spynet_path)
+
+ # propagation
+ self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
+ self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+
+ self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
+ self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)
+
+ # reconstruction
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+
+ self.pixel_shuffle = nn.PixelShuffle(2)
+
+ # activation functions
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def pad_spatial(self, x):
+ """Apply padding spatially.
+
+ Since the PCD module in EDVR requires that the resolution is a multiple
+ of 4, we apply padding to the input LR images if their resolution is
+ not divisible by 4.
+
+ Args:
+ x (Tensor): Input LR sequence with shape (n, t, c, h, w).
+ Returns:
+ Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
+ """
+ n, t, c, h, w = x.size()
+
+ pad_h = (4 - h % 4) % 4
+ pad_w = (4 - w % 4) % 4
+
+ # padding
+ x = x.view(-1, c, h, w)
+ x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
+
+ return x.view(n, t, c, h + pad_h, w + pad_w)
+
+ def get_flow(self, x):
+ b, n, c, h, w = x.size()
+
+ x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
+ x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+ flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
+ flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
+
+ return flows_forward, flows_backward
+
+ def get_keyframe_feature(self, x, keyframe_idx):
+ if self.temporal_padding == 2:
+ x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
+ elif self.temporal_padding == 3:
+ x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
+ x = torch.cat(x, dim=1)
+
+ num_frames = 2 * self.temporal_padding + 1
+ feats_keyframe = {}
+ for i in keyframe_idx:
+ feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
+ return feats_keyframe
+
+ def forward(self, x):
+ b, n, _, h_input, w_input = x.size()
+
+ x = self.pad_spatial(x)
+ h, w = x.shape[3:]
+
+ keyframe_idx = list(range(0, n, self.keyframe_stride))
+ if keyframe_idx[-1] != n - 1:
+ keyframe_idx.append(n - 1) # last frame is a keyframe
+
+ # compute flow and keyframe features
+ flows_forward, flows_backward = self.get_flow(x)
+ feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)
+
+ # backward branch
+ out_l = []
+ feat_prop = x.new_zeros(b, self.num_feat, h, w)
+ for i in range(n - 1, -1, -1):
+ x_i = x[:, i, :, :, :]
+ if i < n - 1:
+ flow = flows_backward[:, i, :, :, :]
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+ if i in keyframe_idx:
+ feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
+ feat_prop = self.backward_fusion(feat_prop)
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
+ feat_prop = self.backward_trunk(feat_prop)
+ out_l.insert(0, feat_prop)
+
+ # forward branch
+ feat_prop = torch.zeros_like(feat_prop)
+ for i in range(0, n):
+ x_i = x[:, i, :, :, :]
+ if i > 0:
+ flow = flows_forward[:, i - 1, :, :, :]
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+ if i in keyframe_idx:
+ feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
+ feat_prop = self.forward_fusion(feat_prop)
+
+ feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
+ feat_prop = self.forward_trunk(feat_prop)
+
+ # upsample
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+ out = self.lrelu(self.conv_hr(out))
+ out = self.conv_last(out)
+ base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
+ out += base
+ out_l[i] = out
+
+ return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
+
+
+class EDVRFeatureExtractor(nn.Module):
+ """EDVR feature extractor used in IconVSR.
+
+ Args:
+ num_input_frame (int): Number of input frames.
+ num_feat (int): Number of feature channels
+ load_path (str): Path to the pretrained weights of EDVR. Default: None.
+ """
+
+ def __init__(self, num_input_frame, num_feat, load_path):
+
+ super(EDVRFeatureExtractor, self).__init__()
+
+ self.center_frame_idx = num_input_frame // 2
+
+ # extract pyramid features
+ self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
+ self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
+ self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+ # pcd and tsa module
+ self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
+ self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ if load_path:
+ self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+ def forward(self, x):
+ b, n, c, h, w = x.size()
+
+ # extract features for each frame
+ # L1
+ feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
+ feat_l1 = self.feature_extraction(feat_l1)
+ # L2
+ feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
+ feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
+ # L3
+ feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
+ feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
+
+ feat_l1 = feat_l1.view(b, n, -1, h, w)
+ feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
+ feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)
+
+ # PCD alignment
+ ref_feat_l = [ # reference feature list
+ feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
+ feat_l3[:, self.center_frame_idx, :, :, :].clone()
+ ]
+ aligned_feat = []
+ for i in range(n):
+ nbr_feat_l = [ # neighboring feature list
+ feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
+ ]
+ aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
+ aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
+
+ # TSA fusion
+ return self.fusion(aligned_feat)
diff --git a/r_basicsr/archs/basicvsrpp_arch.py b/r_basicsr/archs/basicvsrpp_arch.py
new file mode 100644
index 0000000..f53b434
--- /dev/null
+++ b/r_basicsr/archs/basicvsrpp_arch.py
@@ -0,0 +1,407 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import warnings
+
+from r_basicsr.archs.arch_util import flow_warp
+from r_basicsr.archs.basicvsr_arch import ConvResidualBlocks
+from r_basicsr.archs.spynet_arch import SpyNet
+from r_basicsr.ops.dcn import ModulatedDeformConvPack
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class BasicVSRPlusPlus(nn.Module):
+ """BasicVSR++ network structure.
+ Supports either x4 upsampling or same-size output. Since DCN is used in this
+ model, it can only be used with CUDA enabled. If CUDA is not enabled,
+ feature alignment will be skipped. Besides, we adopt the official DCN
+ implementation, so the version of torch needs to be higher than 1.9.
+ Paper:
+ BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation
+ and Alignment
+ Args:
+ mid_channels (int, optional): Channel number of the intermediate
+ features. Default: 64.
+ num_blocks (int, optional): The number of residual blocks in each
+ propagation branch. Default: 7.
+ max_residue_magnitude (int): The maximum magnitude of the offset
+ residue (Eq. 6 in paper). Default: 10.
+ is_low_res_input (bool, optional): Whether the input is low-resolution
+ or not. If False, the output resolution is equal to the input
+ resolution. Default: True.
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+ cpu_cache_length (int, optional): When the length of sequence is larger
+ than this value, the intermediate features are sent to CPU. This
+ saves GPU memory, but slows down the inference speed. You can
+ increase this number if you have a GPU with large memory.
+ Default: 100.
+ """
+
+ def __init__(self,
+ mid_channels=64,
+ num_blocks=7,
+ max_residue_magnitude=10,
+ is_low_res_input=True,
+ spynet_path=None,
+ cpu_cache_length=100):
+
+ super().__init__()
+ self.mid_channels = mid_channels
+ self.is_low_res_input = is_low_res_input
+ self.cpu_cache_length = cpu_cache_length
+
+ # optical flow
+ self.spynet = SpyNet(spynet_path)
+
+ # feature extraction module
+ if is_low_res_input:
+ self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
+ else:
+ self.feat_extract = nn.Sequential(
+ nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ ConvResidualBlocks(mid_channels, mid_channels, 5))
+
+ # propagation branches
+ self.deform_align = nn.ModuleDict()
+ self.backbone = nn.ModuleDict()
+ modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
+ for i, module in enumerate(modules):
+ if torch.cuda.is_available():
+ self.deform_align[module] = SecondOrderDeformableAlignment(
+ 2 * mid_channels,
+ mid_channels,
+ 3,
+ padding=1,
+ deformable_groups=16,
+ max_residue_magnitude=max_residue_magnitude)
+ self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)
+
+ # upsampling module
+ self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)
+
+ self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
+ self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)
+
+ self.pixel_shuffle = nn.PixelShuffle(2)
+
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+ self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ # check if the sequence is augmented by flipping
+ self.is_mirror_extended = False
+
+ if len(self.deform_align) > 0:
+ self.is_with_alignment = True
+ else:
+ self.is_with_alignment = False
+ warnings.warn('Deformable alignment module is not added. '
+ 'Probably your CUDA is not configured correctly. DCN can only '
+ 'be used with CUDA enabled. Alignment is skipped now.')
+
+ def check_if_mirror_extended(self, lqs):
+ """Check whether the input is a mirror-extended sequence.
+ If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
+ (t-1-i)-th frame.
+ Args:
+ lqs (tensor): Input low quality (LQ) sequence with
+ shape (n, t, c, h, w).
+ """
+
+ if lqs.size(1) % 2 == 0:
+ lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
+ if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
+ self.is_mirror_extended = True
+
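Sketch of the check above (illustrative, not part of the diff): a clip concatenated with its temporal flip makes the two halves match after flipping, which is exactly what the method tests.

import torch

frames = torch.rand(1, 3, 3, 16, 16)
lqs = torch.cat([frames, frames.flip(1)], dim=1)   # (1, 6, 3, 16, 16)
lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
assert torch.norm(lqs_1 - lqs_2.flip(1)) == 0      # is_mirror_extended would be set to True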
+ def compute_flow(self, lqs):
+ """Compute optical flow using SPyNet for feature alignment.
+ Note that if the input is a mirror-extended sequence, 'flows_forward'
+ is not needed, since it is equal to 'flows_backward.flip(1)'.
+ Args:
+ lqs (tensor): Input low quality (LQ) sequence with
+ shape (n, t, c, h, w).
+ Return:
+ tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
+ flows used for forward-time propagation (current to previous).
+ 'flows_backward' corresponds to the flows used for
+ backward-time propagation (current to next).
+ """
+
+ n, t, c, h, w = lqs.size()
+ lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
+ lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+ flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)
+
+ if self.is_mirror_extended: # flows_forward = flows_backward.flip(1)
+ flows_forward = flows_backward.flip(1)
+ else:
+ flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)
+
+ if self.cpu_cache:
+ flows_backward = flows_backward.cpu()
+ flows_forward = flows_forward.cpu()
+
+ return flows_forward, flows_backward
+
+ def propagate(self, feats, flows, module_name):
+ """Propagate the latent features throughout the sequence.
+ Args:
+ feats dict(list[tensor]): Features from previous branches. Each
+ component is a list of tensors with shape (n, c, h, w).
+ flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
+ module_name (str): The name of the propagation branches. Can either
+ be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.
+ Return:
+ dict(list[tensor]): A dictionary containing all the propagated
+ features. Each key in the dictionary corresponds to a
+ propagation branch, which is represented by a list of tensors.
+ """
+
+ n, t, _, h, w = flows.size()
+
+ frame_idx = range(0, t + 1)
+ flow_idx = range(-1, t)
+ mapping_idx = list(range(0, len(feats['spatial'])))
+ mapping_idx += mapping_idx[::-1]
+
+ if 'backward' in module_name:
+ frame_idx = frame_idx[::-1]
+ flow_idx = frame_idx
+
+ feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
+ for i, idx in enumerate(frame_idx):
+ feat_current = feats['spatial'][mapping_idx[idx]]
+ if self.cpu_cache:
+ feat_current = feat_current.cuda()
+ feat_prop = feat_prop.cuda()
+ # second-order deformable alignment
+ if i > 0 and self.is_with_alignment:
+ flow_n1 = flows[:, flow_idx[i], :, :, :]
+ if self.cpu_cache:
+ flow_n1 = flow_n1.cuda()
+
+ cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
+
+ # initialize second-order features
+ feat_n2 = torch.zeros_like(feat_prop)
+ flow_n2 = torch.zeros_like(flow_n1)
+ cond_n2 = torch.zeros_like(cond_n1)
+
+ if i > 1: # second-order features
+ feat_n2 = feats[module_name][-2]
+ if self.cpu_cache:
+ feat_n2 = feat_n2.cuda()
+
+ flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
+ if self.cpu_cache:
+ flow_n2 = flow_n2.cuda()
+
+ flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
+ cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
+
+ # flow-guided deformable convolution
+ cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
+ feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
+ feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)
+
+ # concatenate and residual blocks
+ feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
+ if self.cpu_cache:
+ feat = [f.cuda() for f in feat]
+
+ feat = torch.cat(feat, dim=1)
+ feat_prop = feat_prop + self.backbone[module_name](feat)
+ feats[module_name].append(feat_prop)
+
+ if self.cpu_cache:
+ feats[module_name][-1] = feats[module_name][-1].cpu()
+ torch.cuda.empty_cache()
+
+ if 'backward' in module_name:
+ feats[module_name] = feats[module_name][::-1]
+
+ return feats
+
+ def upsample(self, lqs, feats):
+ """Compute the output image given the features.
+ Args:
+ lqs (tensor): Input low quality (LQ) sequence with
+ shape (n, t, c, h, w).
+ feats (dict): The features from the propagation branches.
+ Returns:
+ Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
+ """
+
+ outputs = []
+ num_outputs = len(feats['spatial'])
+
+ mapping_idx = list(range(0, num_outputs))
+ mapping_idx += mapping_idx[::-1]
+
+ for i in range(0, lqs.size(1)):
+ hr = [feats[k].pop(0) for k in feats if k != 'spatial']
+ hr.insert(0, feats['spatial'][mapping_idx[i]])
+ hr = torch.cat(hr, dim=1)
+ if self.cpu_cache:
+ hr = hr.cuda()
+
+ hr = self.reconstruction(hr)
+ hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
+ hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
+ hr = self.lrelu(self.conv_hr(hr))
+ hr = self.conv_last(hr)
+ if self.is_low_res_input:
+ hr += self.img_upsample(lqs[:, i, :, :, :])
+ else:
+ hr += lqs[:, i, :, :, :]
+
+ if self.cpu_cache:
+ hr = hr.cpu()
+ torch.cuda.empty_cache()
+
+ outputs.append(hr)
+
+ return torch.stack(outputs, dim=1)
+
+ def forward(self, lqs):
+ """Forward function for BasicVSR++.
+ Args:
+ lqs (tensor): Input low quality (LQ) sequence with
+ shape (n, t, c, h, w).
+ Returns:
+ Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
+ """
+
+ n, t, c, h, w = lqs.size()
+
+ # whether to cache the features in CPU
+ self.cpu_cache = True if t > self.cpu_cache_length else False
+
+ if self.is_low_res_input:
+ lqs_downsample = lqs.clone()
+ else:
+ lqs_downsample = F.interpolate(
+ lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)
+
+ # check whether the input is an extended sequence
+ self.check_if_mirror_extended(lqs)
+
+ feats = {}
+ # compute spatial features
+ if self.cpu_cache:
+ feats['spatial'] = []
+ for i in range(0, t):
+ feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
+ feats['spatial'].append(feat)
+ torch.cuda.empty_cache()
+ else:
+ feats_ = self.feat_extract(lqs.view(-1, c, h, w))
+ h, w = feats_.shape[2:]
+ feats_ = feats_.view(n, t, -1, h, w)
+ feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
+
+ # compute optical flow using the low-res inputs
+ assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
+ 'The height and width of low-res inputs must be at least 64, '
+ f'but got {h} and {w}.')
+ flows_forward, flows_backward = self.compute_flow(lqs_downsample)
+
+ # feature propagation
+ for iter_ in [1, 2]:
+ for direction in ['backward', 'forward']:
+ module = f'{direction}_{iter_}'
+
+ feats[module] = []
+
+ if direction == 'backward':
+ flows = flows_backward
+ elif flows_forward is not None:
+ flows = flows_forward
+ else:
+ flows = flows_backward.flip(1)
+
+ feats = self.propagate(feats, flows, module)
+ if self.cpu_cache:
+ del flows
+ torch.cuda.empty_cache()
+
+ return self.upsample(lqs, feats)
+
+
+class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
+ """Second-order deformable alignment module.
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ max_residue_magnitude (int): The maximum magnitude of the offset
+ residue (Eq. 6 in paper). Default: 10.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
+
+ super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Sequential(
+ nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
+ )
+
+ self.init_offset()
+
+ def init_offset(self):
+
+ def _constant_init(module, val, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+ _constant_init(self.conv_offset[-1], val=0, bias=0)
+
+ def forward(self, x, extra_feat, flow_1, flow_2):
+ extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
+ out = self.conv_offset(extra_feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+
+ # offset
+ offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
+ offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
+ offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
+ offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
+ offset = torch.cat([offset_1, offset_2], dim=1)
+
+ # mask
+ mask = torch.sigmoid(mask)
+
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, mask)
+
+
+# if __name__ == '__main__':
+# spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
+# model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
+# input = torch.rand(1, 2, 3, 64, 64).cuda()
+# output = model(input)
+# print('===================')
+# print(output.shape)
diff --git a/r_basicsr/archs/dfdnet_arch.py b/r_basicsr/archs/dfdnet_arch.py
new file mode 100644
index 0000000..04093f0
--- /dev/null
+++ b/r_basicsr/archs/dfdnet_arch.py
@@ -0,0 +1,169 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.spectral_norm import spectral_norm
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
+from .vgg_arch import VGGFeatureExtractor
+
+
+class SFTUpBlock(nn.Module):
+ """Spatial feature transform (SFT) with upsampling block.
+
+ Args:
+ in_channel (int): Number of input channels.
+ out_channel (int): Number of output channels.
+ kernel_size (int): Kernel size in convolutions. Default: 3.
+ padding (int): Padding in convolutions. Default: 1.
+ """
+
+ def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
+ super(SFTUpBlock, self).__init__()
+ self.conv1 = nn.Sequential(
+ Blur(in_channel),
+ spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+ nn.LeakyReLU(0.04, True),
+ # The official code stacks two LeakyReLU layers here; a single LeakyReLU(0.04) is equivalent (0.2 * 0.2 = 0.04)
+ )
+ self.convup = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+ spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+ nn.LeakyReLU(0.2, True),
+ )
+
+ # for SFT scale and shift
+ self.scale_block = nn.Sequential(
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
+ self.shift_block = nn.Sequential(
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
+ # The official code uses a sigmoid for the shift block; the reason is unclear
+
+ def forward(self, x, updated_feat):
+ out = self.conv1(x)
+ # SFT
+ scale = self.scale_block(updated_feat)
+ shift = self.shift_block(updated_feat)
+ out = out * scale + shift
+ # upsample
+ out = self.convup(out)
+ return out
+
+
+@ARCH_REGISTRY.register()
+class DFDNet(nn.Module):
+ """DFDNet: Deep Face Dictionary Network.
+
+ It only processes faces with 512x512 size.
+
+ Args:
+ num_feat (int): Number of feature channels.
+ dict_path (str): Path to the facial component dictionary.
+ """
+
+ def __init__(self, num_feat, dict_path):
+ super().__init__()
+ self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
+ # part_sizes: [80, 80, 50, 110]
+ channel_sizes = [128, 256, 512, 512]
+ self.feature_sizes = np.array([256, 128, 64, 32])
+ self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
+ self.flag_dict_device = False
+
+ # dict
+ self.dict = torch.load(dict_path)
+
+ # vgg face extractor
+ self.vgg_extractor = VGGFeatureExtractor(
+ layer_name_list=self.vgg_layers,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=True,
+ requires_grad=False)
+
+ # attention block for fusing dictionary features and input features
+ self.attn_blocks = nn.ModuleDict()
+ for idx, feat_size in enumerate(self.feature_sizes):
+ for name in self.parts:
+ self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])
+
+ # multi scale dilation block
+ self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])
+
+ # upsampling and reconstruction
+ self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
+ self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
+ self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
+ self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
+ self.upsample4 = nn.Sequential(
+ spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
+ UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
+
+ def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
+ """swap the features from the dictionary."""
+ # get the original vgg features
+ part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
+ # resize original vgg features
+ part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
+ # use adaptive instance normalization to adjust color and illuminations
+ dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
+ # get similarity scores
+ similarity_score = F.conv2d(part_resize_feat, dict_feat)
+ similarity_score = F.softmax(similarity_score.view(-1), dim=0)
+ # select the most similar features in the dict (after norm)
+ select_idx = torch.argmax(similarity_score)
+ swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
+ # attention
+ attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
+ attn_feat = attn * swap_feat
+ # update features
+ updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
+ return updated_feat
+
+ def put_dict_to_device(self, x):
+ if self.flag_dict_device is False:
+ for k, v in self.dict.items():
+ for kk, vv in v.items():
+ self.dict[k][kk] = vv.to(x)
+ self.flag_dict_device = True
+
+ def forward(self, x, part_locations):
+ """
+ Now only supports testing with batch size = 1.
+
+ Args:
+ x (Tensor): Input faces with shape (b, c, 512, 512).
+ part_locations (list[Tensor]): Part locations.
+ """
+ self.put_dict_to_device(x)
+ # extract vggface features
+ vgg_features = self.vgg_extractor(x)
+ # update vggface features using the dictionary for each part
+ updated_vgg_features = []
+ batch = 0 # batch index; only testing with batch size = 1 is supported
+ for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
+ dict_features = self.dict[f'{f_size}']
+ vgg_feat = vgg_features[vgg_layer]
+ updated_feat = vgg_feat.clone()
+
+ # swap features from dictionary
+ for part_idx, part_name in enumerate(self.parts):
+ location = (part_locations[part_idx][batch] // (512 / f_size)).int()
+ updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
+ f_size)
+
+ updated_vgg_features.append(updated_feat)
+
+ vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
+ # use updated vgg features to modulate the upsampled features with
+ # SFT (Spatial Feature Transform) scaling and shifting manner.
+ upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
+ upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
+ upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
+ upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
+ out = self.upsample4(upsampled_feat)
+
+ return out
diff --git a/r_basicsr/archs/dfdnet_util.py b/r_basicsr/archs/dfdnet_util.py
new file mode 100644
index 0000000..411e683
--- /dev/null
+++ b/r_basicsr/archs/dfdnet_util.py
@@ -0,0 +1,162 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.nn.utils.spectral_norm import spectral_norm
+
+
+class BlurFunctionBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, kernel, kernel_flip):
+ ctx.save_for_backward(kernel, kernel_flip)
+ grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_output):
+ kernel, _ = ctx.saved_tensors
+ grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
+ return grad_input, None, None
+
+
+class BlurFunction(Function):
+
+ @staticmethod
+ def forward(ctx, x, kernel, kernel_flip):
+ ctx.save_for_backward(kernel, kernel_flip)
+ output = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, kernel_flip = ctx.saved_tensors
+ grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
+ return grad_input, None, None
+
+
+blur = BlurFunction.apply
+
+
+class Blur(nn.Module):
+
+ def __init__(self, channel):
+ super().__init__()
+ kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
+ kernel = kernel.view(1, 1, 3, 3)
+ kernel = kernel / kernel.sum()
+ kernel_flip = torch.flip(kernel, [2, 3])
+
+ self.kernel = kernel.repeat(channel, 1, 1, 1)
+ self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1)
+
+ def forward(self, x):
+ return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x))
+
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ n, c = size[:2]
+ feat_var = feat.view(n, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(n, c, 1, 1)
+ feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ """Adaptive instance normalization.
+
+ Adjust the reference features to have similar color and illumination to
+ those in the degraded features.
+
+ Args:
+ content_feat (Tensor): The reference feature.
+ style_feat (Tensor): The degraded features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
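Sketch (illustrative, not part of the diff): AdaIN re-normalizes the content features to the per-channel statistics of the style features, i.e. out = (content - mu_c) / sigma_c * sigma_s + mu_s.

import torch
from r_basicsr.archs.dfdnet_util import adaptive_instance_normalization, calc_mean_std

content = torch.rand(1, 8, 32, 32)
style = torch.rand(1, 8, 32, 32) * 2.0 + 1.0
out = adaptive_instance_normalization(content, style)
# per-channel mean/std of `out` now match those of `style` (up to eps)
print(calc_mean_std(out)[0].flatten(), calc_mean_std(style)[0].flatten())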
+
+def AttentionBlock(in_channel):
+ return nn.Sequential(
+ spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+ spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))
+
+
+def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
+ """Conv block used in MSDilationBlock."""
+
+ return nn.Sequential(
+ spectral_norm(
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=((kernel_size - 1) // 2) * dilation,
+ bias=bias)),
+ nn.LeakyReLU(0.2),
+ spectral_norm(
+ nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=((kernel_size - 1) // 2) * dilation,
+ bias=bias)),
+ )
+
+
+class MSDilationBlock(nn.Module):
+ """Multi-scale dilation block."""
+
+ def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
+ super(MSDilationBlock, self).__init__()
+
+ self.conv_blocks = nn.ModuleList()
+ for i in range(4):
+ self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias))
+ self.conv_fusion = spectral_norm(
+ nn.Conv2d(
+ in_channels * 4,
+ in_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ bias=bias))
+
+ def forward(self, x):
+ out = []
+ for i in range(4):
+ out.append(self.conv_blocks[i](x))
+ out = torch.cat(out, 1)
+ out = self.conv_fusion(out) + x
+ return out
+
+
+class UpResBlock(nn.Module):
+
+ def __init__(self, in_channel):
+ super(UpResBlock, self).__init__()
+ self.body = nn.Sequential(
+ nn.Conv2d(in_channel, in_channel, 3, 1, 1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(in_channel, in_channel, 3, 1, 1),
+ )
+
+ def forward(self, x):
+ out = x + self.body(x)
+ return out
diff --git a/r_basicsr/archs/discriminator_arch.py b/r_basicsr/archs/discriminator_arch.py
new file mode 100644
index 0000000..2229748
--- /dev/null
+++ b/r_basicsr/archs/discriminator_arch.py
@@ -0,0 +1,150 @@
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class VGGStyleDiscriminator(nn.Module):
+ """VGG style discriminator with input size 128 x 128 or 256 x 256.
+
+ It is used to train SRGAN, ESRGAN, and VideoGAN.
+
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_feat (int): Channel number of base intermediate features. Default: 64.
+ """
+
+ def __init__(self, num_in_ch, num_feat, input_size=128):
+ super(VGGStyleDiscriminator, self).__init__()
+ self.input_size = input_size
+ assert self.input_size == 128 or self.input_size == 256, (
+ f'input size must be 128 or 256, but received {input_size}')
+
+ self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
+ self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
+ self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)
+
+ self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
+ self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
+ self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
+ self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)
+
+ self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
+ self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
+ self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
+ self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)
+
+ self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
+ self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+ self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+ self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+ self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
+ self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+ self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+ self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+ if self.input_size == 256:
+ self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
+ self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+ self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+ self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+ self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
+ self.linear2 = nn.Linear(100, 1)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')
+
+ feat = self.lrelu(self.conv0_0(x))
+ feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2
+
+ feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
+ feat = self.lrelu(self.bn1_1(self.conv1_1(feat))) # output spatial size: /4
+
+ feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
+ feat = self.lrelu(self.bn2_1(self.conv2_1(feat))) # output spatial size: /8
+
+ feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
+ feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: /16
+
+ feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
+ feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: /32
+
+ if self.input_size == 256:
+ feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
+ feat = self.lrelu(self.bn5_1(self.conv5_1(feat))) # output spatial size: / 64
+
+ # spatial size: (4, 4)
+ feat = feat.view(feat.size(0), -1)
+ feat = self.lrelu(self.linear1(feat))
+ out = self.linear2(feat)
+ return out
+
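Shape sketch (illustrative, not part of the commit): five stride-2 stages reduce 128 to 4, so linear1 sees num_feat * 8 * 4 * 4 features and the discriminator outputs one score per sample.

import torch
from r_basicsr.archs.discriminator_arch import VGGStyleDiscriminator

d = VGGStyleDiscriminator(num_in_ch=3, num_feat=64, input_size=128)
score = d(torch.rand(1, 3, 128, 128))   # -> torch.Size([1, 1])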
+
+@ARCH_REGISTRY.register(suffix='basicsr')
+class UNetDiscriminatorSN(nn.Module):
+ """Defines a U-Net discriminator with spectral normalization (SN)
+
+ It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_feat (int): Channel number of base intermediate features. Default: 64.
+ skip_connection (bool): Whether to use U-Net skip connections. Default: True.
+ """
+
+ def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
+ super(UNetDiscriminatorSN, self).__init__()
+ self.skip_connection = skip_connection
+ norm = spectral_norm
+ # the first convolution
+ self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
+ # downsample
+ self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
+ self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
+ self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
+ # upsample
+ self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
+ self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
+ self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
+ # extra convolutions
+ self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+ self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+ self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
+
+ def forward(self, x):
+ # downsample
+ x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
+ x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
+ x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
+ x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
+
+ # upsample
+ x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
+ x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x4 = x4 + x2
+ x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
+ x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x5 = x5 + x1
+ x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
+ x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x6 = x6 + x0
+
+ # extra convolutions
+ out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
+ out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
+ out = self.conv9(out)
+
+ return out
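Shape sketch (illustrative, not part of the commit): the U-Net discriminator is fully convolutional and returns a per-pixel realness map at the input resolution.

import torch
from r_basicsr.archs.discriminator_arch import UNetDiscriminatorSN

d = UNetDiscriminatorSN(num_in_ch=3, num_feat=64)
out = d(torch.rand(1, 3, 128, 128))   # -> torch.Size([1, 1, 128, 128])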
diff --git a/r_basicsr/archs/duf_arch.py b/r_basicsr/archs/duf_arch.py
new file mode 100644
index 0000000..9b963a5
--- /dev/null
+++ b/r_basicsr/archs/duf_arch.py
@@ -0,0 +1,277 @@
+import numpy as np
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+class DenseBlocksTemporalReduce(nn.Module):
+ """A concatenation of 3 dense blocks with reduction in temporal dimension.
+
+ Note that the output temporal dimension is 6 fewer than the input temporal dimension, since each of the 3 blocks trims one frame from both temporal ends.
+
+ Args:
+ num_feat (int): Number of channels in the blocks. Default: 64.
+ num_grow_ch (int): Growing factor of the dense blocks. Default: 32
+ adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
+ Set to false if you want to train from scratch. Default: False.
+ """
+
+ def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
+ super(DenseBlocksTemporalReduce, self).__init__()
+ if adapt_official_weights:
+ eps = 1e-3
+ momentum = 1e-3
+ else: # pytorch default values
+ eps = 1e-05
+ momentum = 0.1
+
+ self.temporal_reduce1 = nn.Sequential(
+ nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
+ nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+ self.temporal_reduce2 = nn.Sequential(
+ nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(
+ num_feat + num_grow_ch,
+ num_feat + num_grow_ch, (1, 1, 1),
+ stride=(1, 1, 1),
+ padding=(0, 0, 0),
+ bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+ self.temporal_reduce3 = nn.Sequential(
+ nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(
+ num_feat + 2 * num_grow_ch,
+ num_feat + 2 * num_grow_ch, (1, 1, 1),
+ stride=(1, 1, 1),
+ padding=(0, 0, 0),
+ bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
+ nn.ReLU(inplace=True),
+ nn.Conv3d(
+ num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
+
+ Returns:
+ Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w).
+ """
+ x1 = self.temporal_reduce1(x)
+ x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
+
+ x2 = self.temporal_reduce2(x1)
+ x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
+
+ x3 = self.temporal_reduce3(x2)
+ x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
+
+ return x3
+
+
+class DenseBlocks(nn.Module):
+ """ A concatenation of N dense blocks.
+
+ Args:
+ num_feat (int): Number of channels in the blocks. Default: 64.
+ num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
+ num_block (int): Number of dense blocks. The values are:
+ DUF-S (16 layers): 3
+ DUF-M (28 layers): 9
+ DUF-L (52 layers): 21
+ adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
+ Set to false if you want to train from scratch. Default: False.
+ """
+
+ def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
+ super(DenseBlocks, self).__init__()
+ if adapt_official_weights:
+ eps = 1e-3
+ momentum = 1e-3
+ else: # pytorch default values
+ eps = 1e-05
+ momentum = 0.1
+
+ self.dense_blocks = nn.ModuleList()
+ for i in range(0, num_block):
+ self.dense_blocks.append(
+ nn.Sequential(
+ nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(
+ num_feat + i * num_grow_ch,
+ num_feat + i * num_grow_ch, (1, 1, 1),
+ stride=(1, 1, 1),
+ padding=(0, 0, 0),
+ bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
+ nn.ReLU(inplace=True),
+ nn.Conv3d(
+ num_feat + i * num_grow_ch,
+ num_grow_ch, (3, 3, 3),
+ stride=(1, 1, 1),
+ padding=(1, 1, 1),
+ bias=True)))
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
+
+ Returns:
+ Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
+ """
+ for i in range(0, len(self.dense_blocks)):
+ y = self.dense_blocks[i](x)
+ x = torch.cat((x, y), 1)
+ return x
+
+
+class DynamicUpsamplingFilter(nn.Module):
+ """Dynamic upsampling filter used in DUF.
+
+ Ref: https://github.com/yhjo09/VSR-DUF.
+ It only supports inputs with 3 channels and applies the same filters to all 3 channels.
+
+ Args:
+ filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).
+ """
+
+ def __init__(self, filter_size=(5, 5)):
+ super(DynamicUpsamplingFilter, self).__init__()
+ if not isinstance(filter_size, tuple):
+ raise TypeError(f'The type of filter_size must be tuple, but got {type(filter_size)}')
+ if len(filter_size) != 2:
+ raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
+ # generate a local expansion filter, similar to im2col
+ self.filter_size = filter_size
+ filter_prod = np.prod(filter_size)
+ expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size) # (kh*kw, 1, kh, kw)
+ self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1) # repeat for all the 3 channels
+
+ def forward(self, x, filters):
+ """Forward function for DynamicUpsamplingFilter.
+
+ Args:
+ x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
+ filters (Tensor): Generated dynamic filters.
+ The shape is (n, filter_prod, upsampling_square, h, w).
+ filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
+ upsampling_square: similar to pixel shuffle,
+ upsampling_square = upsampling * upsampling
+ e.g., for x 4 upsampling, upsampling_square= 4*4 = 16
+
+ Returns:
+ Tensor: Filtered image with shape (n, 3*upsampling_square, h, w)
+ """
+ n, filter_prod, upsampling_square, h, w = filters.size()
+ kh, kw = self.filter_size
+ expanded_input = F.conv2d(
+ x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3) # (n, 3*filter_prod, h, w)
+ expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1,
+ 2) # (n, h, w, 3, filter_prod)
+ filters = filters.permute(0, 3, 4, 1, 2) # (n, h, w, filter_prod, upsampling_square)
+ out = torch.matmul(expanded_input, filters) # (n, h, w, 3, upsampling_square)
+ return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
+
+
+@ARCH_REGISTRY.register()
+class DUF(nn.Module):
+ """Network architecture for DUF
+
+ Paper: Jo et al. Deep Video Super-Resolution Network Using Dynamic
+ Upsampling Filters Without Explicit Motion Compensation, CVPR, 2018
+ Code reference:
+ https://github.com/yhjo09/VSR-DUF
+ For all the models below, 'adapt_official_weights' is only necessary when
+ loading the weights converted from the official TensorFlow weights.
+ Please set it to False if you are training the model from scratch.
+
+ There are three models of different sizes: DUF16Layers, DUF28Layers,
+ and DUF52Layers. This class is the base class for these models.
+
+ Args:
+ scale (int): The upsampling factor. Default: 4.
+ num_layer (int): The number of layers. Default: 52.
+ adapt_official_weights (bool): Whether to adapt the weights
+ translated from the official implementation. Set to false if you
+ want to train from scratch. Default: False.
+ """
+
+ def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
+ super(DUF, self).__init__()
+ self.scale = scale
+ if adapt_official_weights:
+ eps = 1e-3
+ momentum = 1e-3
+ else: # pytorch default values
+ eps = 1e-05
+ momentum = 0.1
+
+ self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
+ self.dynamic_filter = DynamicUpsamplingFilter((5, 5))
+
+ if num_layer == 16:
+ num_block = 3
+ num_grow_ch = 32
+ elif num_layer == 28:
+ num_block = 9
+ num_grow_ch = 16
+ elif num_layer == 52:
+ num_block = 21
+ num_grow_ch = 16
+ else:
+ raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')
+
+ self.dense_block1 = DenseBlocks(
+ num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
+ adapt_official_weights=adapt_official_weights) # T = 7
+ self.dense_block2 = DenseBlocksTemporalReduce(
+ 64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights) # T = 1
+ channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
+ self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
+ self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
+
+ self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+ self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+
+ self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+ self.conv3d_f2 = nn.Conv3d(
+ 512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): Input with shape (b, 7, c, h, w)
+
+ Returns:
+ Tensor: Output with shape (b, c, h * scale, w * scale)
+ """
+ num_batches, num_imgs, _, h, w = x.size()
+
+ x = x.permute(0, 2, 1, 3, 4) # (b, c, 7, h, w) for Conv3D
+ x_center = x[:, :, num_imgs // 2, :, :]
+
+ x = self.conv3d1(x)
+ x = self.dense_block1(x)
+ x = self.dense_block2(x)
+ x = F.relu(self.bn3d2(x), inplace=True)
+ x = F.relu(self.conv3d2(x), inplace=True)
+
+ # residual image
+ res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))
+
+ # filter
+ filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))
+ filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)
+
+ # dynamic filter
+ out = self.dynamic_filter(x_center, filter_)
+ out += res.squeeze_(2)
+ out = F.pixel_shuffle(out, self.scale)
+
+ return out
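
A minimal smoke-test sketch for the DUF network above (assuming the r_basicsr package is importable; sizes are illustrative only). With a 7-frame input, the dense blocks keep T = 7 and the temporal-reduce blocks shrink it to T = 1 before the dynamic filtering and pixel shuffle:

    import torch
    from r_basicsr.archs.duf_arch import DUF

    # 7-frame clip, 3 channels, 32x32 frames, x4 upscaling (DUF-S: 16 layers)
    model = DUF(scale=4, num_layer=16, adapt_official_weights=False).eval()
    frames = torch.rand(1, 7, 3, 32, 32)  # (b, t, c, h, w)
    with torch.no_grad():
        out = model(frames)
    print(out.shape)  # expected: torch.Size([1, 3, 128, 128])
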
diff --git a/r_basicsr/archs/ecbsr_arch.py b/r_basicsr/archs/ecbsr_arch.py new file mode 100644 index 0000000..9eb1d75 --- /dev/null +++ b/r_basicsr/archs/ecbsr_arch.py @@ -0,0 +1,274 @@ +import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+class SeqConv3x3(nn.Module):
+ """The re-parameterizable block used in the ECBSR architecture.
+
+ Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
+ Ref git repo: https://github.com/xindongzhang/ECBSR
+
+ Args:
+ seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
+ in_channels (int): Channel number of input.
+ out_channels (int): Channel number of output.
+ depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
+ """
+
+ def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
+ super(SeqConv3x3, self).__init__()
+ self.seq_type = seq_type
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ if self.seq_type == 'conv1x1-conv3x3':
+ self.mid_planes = int(out_channels * depth_multiplier)
+ conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0)
+ self.k0 = conv0.weight
+ self.b0 = conv0.bias
+
+ conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3)
+ self.k1 = conv1.weight
+ self.b1 = conv1.bias
+
+ elif self.seq_type == 'conv1x1-sobelx':
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
+ self.k0 = conv0.weight
+ self.b0 = conv0.bias
+
+ # init scale and bias
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
+ self.scale = nn.Parameter(scale)
+ bias = torch.randn(self.out_channels) * 1e-3
+ bias = torch.reshape(bias, (self.out_channels, ))
+ self.bias = nn.Parameter(bias)
+ # init mask
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
+ for i in range(self.out_channels):
+ self.mask[i, 0, 0, 0] = 1.0
+ self.mask[i, 0, 1, 0] = 2.0
+ self.mask[i, 0, 2, 0] = 1.0
+ self.mask[i, 0, 0, 2] = -1.0
+ self.mask[i, 0, 1, 2] = -2.0
+ self.mask[i, 0, 2, 2] = -1.0
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
+
+ elif self.seq_type == 'conv1x1-sobely':
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
+ self.k0 = conv0.weight
+ self.b0 = conv0.bias
+
+ # init scale and bias
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
+ self.scale = nn.Parameter(torch.FloatTensor(scale))
+ bias = torch.randn(self.out_channels) * 1e-3
+ bias = torch.reshape(bias, (self.out_channels, ))
+ self.bias = nn.Parameter(torch.FloatTensor(bias))
+ # init mask
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
+ for i in range(self.out_channels):
+ self.mask[i, 0, 0, 0] = 1.0
+ self.mask[i, 0, 0, 1] = 2.0
+ self.mask[i, 0, 0, 2] = 1.0
+ self.mask[i, 0, 2, 0] = -1.0
+ self.mask[i, 0, 2, 1] = -2.0
+ self.mask[i, 0, 2, 2] = -1.0
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
+
+ elif self.seq_type == 'conv1x1-laplacian':
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
+ self.k0 = conv0.weight
+ self.b0 = conv0.bias
+
+ # init scale and bias
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
+ self.scale = nn.Parameter(torch.FloatTensor(scale))
+ bias = torch.randn(self.out_channels) * 1e-3
+ bias = torch.reshape(bias, (self.out_channels, ))
+ self.bias = nn.Parameter(torch.FloatTensor(bias))
+ # init mask
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
+ for i in range(self.out_channels):
+ self.mask[i, 0, 0, 1] = 1.0
+ self.mask[i, 0, 1, 0] = 1.0
+ self.mask[i, 0, 1, 2] = 1.0
+ self.mask[i, 0, 2, 1] = 1.0
+ self.mask[i, 0, 1, 1] = -4.0
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
+ else:
+ raise ValueError('The type of seqconv is not supported!')
+
+ def forward(self, x):
+ if self.seq_type == 'conv1x1-conv3x3':
+ # conv-1x1
+ y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
+ # explicitly padding with bias
+ y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
+ b0_pad = self.b0.view(1, -1, 1, 1)
+ y0[:, :, 0:1, :] = b0_pad
+ y0[:, :, -1:, :] = b0_pad
+ y0[:, :, :, 0:1] = b0_pad
+ y0[:, :, :, -1:] = b0_pad
+ # conv-3x3
+ y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
+ else:
+ y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
+ # explicitly padding with bias
+ y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
+ b0_pad = self.b0.view(1, -1, 1, 1)
+ y0[:, :, 0:1, :] = b0_pad
+ y0[:, :, -1:, :] = b0_pad
+ y0[:, :, :, 0:1] = b0_pad
+ y0[:, :, :, -1:] = b0_pad
+ # conv-3x3
+ y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)
+ return y1
+
+ def rep_params(self):
+ device = self.k0.get_device()
+ if device < 0:
+ device = None
+
+ if self.seq_type == 'conv1x1-conv3x3':
+ # re-param conv kernel
+ rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
+ # re-param conv bias
+ rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
+ rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
+ else:
+ tmp = self.scale * self.mask
+ k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
+ for i in range(self.out_channels):
+ k1[i, i, :, :] = tmp[i, 0, :, :]
+ b1 = self.bias
+ # re-param conv kernel
+ rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
+ # re-param conv bias
+ rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
+ rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
+ return rep_weight, rep_bias
+
+
+class ECB(nn.Module):
+ """The ECB block used in the ECBSR architecture.
+
+ Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
+ Ref git repo: https://github.com/xindongzhang/ECBSR
+
+ Args:
+ in_channels (int): Channel number of input.
+ out_channels (int): Channel number of output.
+ depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
+ act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
+ with_idt (bool): Whether to use identity connection. Default: False.
+ """
+
+ def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
+ super(ECB, self).__init__()
+
+ self.depth_multiplier = depth_multiplier
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.act_type = act_type
+
+ if with_idt and (self.in_channels == self.out_channels):
+ self.with_idt = True
+ else:
+ self.with_idt = False
+
+ self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
+ self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
+ self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
+ self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
+ self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)
+
+ if self.act_type == 'prelu':
+ self.act = nn.PReLU(num_parameters=self.out_channels)
+ elif self.act_type == 'relu':
+ self.act = nn.ReLU(inplace=True)
+ elif self.act_type == 'rrelu':
+ self.act = nn.RReLU(lower=-0.05, upper=0.05)
+ elif self.act_type == 'softplus':
+ self.act = nn.Softplus()
+ elif self.act_type == 'linear':
+ pass
+ else:
+ raise ValueError('The type of activation is not supported!')
+
+ def forward(self, x):
+ if self.training:
+ y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
+ if self.with_idt:
+ y += x
+ else:
+ rep_weight, rep_bias = self.rep_params()
+ y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
+ if self.act_type != 'linear':
+ y = self.act(y)
+ return y
+
+ def rep_params(self):
+ weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
+ weight1, bias1 = self.conv1x1_3x3.rep_params()
+ weight2, bias2 = self.conv1x1_sbx.rep_params()
+ weight3, bias3 = self.conv1x1_sby.rep_params()
+ weight4, bias4 = self.conv1x1_lpl.rep_params()
+ rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
+ bias0 + bias1 + bias2 + bias3 + bias4)
+
+ if self.with_idt:
+ device = rep_weight.get_device()
+ if device < 0:
+ device = None
+ weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
+ for i in range(self.out_channels):
+ weight_idt[i, i, 1, 1] = 1.0
+ bias_idt = 0.0
+ rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
+ return rep_weight, rep_bias
+
+
+@ARCH_REGISTRY.register()
+class ECBSR(nn.Module):
+ """ECBSR architecture.
+
+ Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
+ Ref git repo: https://github.com/xindongzhang/ECBSR
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_block (int): Block number in the trunk network.
+ num_channel (int): Channel number.
+ with_idt (bool): Whether to use identity connections in the convolution layers.
+ act_type (str): Activation type.
+ scale (int): Upsampling factor.
+ """
+
+ def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
+ super(ECBSR, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.scale = scale
+
+ backbone = []
+ backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
+ for _ in range(num_block):
+ backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
+ backbone += [
+ ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
+ ]
+
+ self.backbone = nn.Sequential(*backbone)
+ self.upsampler = nn.PixelShuffle(scale)
+
+ def forward(self, x):
+ if self.num_in_ch > 1:
+ shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
+ else:
+ shortcut = x # broadcasting repeats the single input channel across the scale * scale output channels
+ y = self.backbone(x) + shortcut
+ y = self.upsampler(y)
+ return y
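
A short sketch (hypothetical sizes) of the re-parameterization described in the ECB docstring: in training mode the block sums the plain 3x3 conv, the four SeqConv3x3 branches and an optional identity, while in eval mode rep_params() folds them into a single 3x3 convolution; the two paths should match up to floating-point tolerance:

    import torch
    from r_basicsr.archs.ecbsr_arch import ECB

    block = ECB(in_channels=8, out_channels=8, depth_multiplier=2, act_type='linear', with_idt=True)
    x = torch.rand(1, 8, 16, 16)
    with torch.no_grad():
        block.train()
        y_branches = block(x)  # multi-branch sum (training path)
        block.eval()
        y_fused = block(x)     # single fused conv from rep_params() (inference path)
    print(torch.allclose(y_branches, y_fused, atol=1e-5))  # expected: True
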
diff --git a/r_basicsr/archs/edsr_arch.py b/r_basicsr/archs/edsr_arch.py new file mode 100644 index 0000000..4990b2c --- /dev/null +++ b/r_basicsr/archs/edsr_arch.py @@ -0,0 +1,61 @@ +import torch
+from torch import nn as nn
+
+from r_basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class EDSR(nn.Module):
+ """EDSR network structure.
+
+ Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
+ Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ num_block (int): Block number in the trunk network. Default: 16.
+ upscale (int): Upsampling factor. Supports 2^n and 3.
+ Default: 4.
+ res_scale (float): Used to scale the residual in residual block.
+ Default: 1.
+ img_range (float): Image range. Default: 255.
+ rgb_mean (tuple[float]): Image mean in RGB orders.
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
+ """
+
+ def __init__(self,
+ num_in_ch,
+ num_out_ch,
+ num_feat=64,
+ num_block=16,
+ upscale=4,
+ res_scale=1,
+ img_range=255.,
+ rgb_mean=(0.4488, 0.4371, 0.4040)):
+ super(EDSR, self).__init__()
+
+ self.img_range = img_range
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
+ self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ def forward(self, x):
+ self.mean = self.mean.type_as(x)
+
+ x = (x - self.mean) * self.img_range
+ x = self.conv_first(x)
+ res = self.conv_after_body(self.body(x))
+ res += x
+
+ x = self.conv_last(self.upsample(res))
+ x = x / self.img_range + self.mean
+
+ return x
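
A quick forward-pass sketch for EDSR (hypothetical sizes; assumes ResidualBlockNoBN and Upsample from arch_util are available as imported above):

    import torch
    from r_basicsr.archs.edsr_arch import EDSR

    model = EDSR(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4).eval()
    lr = torch.rand(1, 3, 24, 24)
    with torch.no_grad():
        sr = model(lr)
    print(sr.shape)  # expected: torch.Size([1, 3, 96, 96])
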
diff --git a/r_basicsr/archs/edvr_arch.py b/r_basicsr/archs/edvr_arch.py new file mode 100644 index 0000000..401b9b2 --- /dev/null +++ b/r_basicsr/archs/edvr_arch.py @@ -0,0 +1,383 @@ +import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
+
+
+class PCDAlignment(nn.Module):
+ """Alignment module using Pyramid, Cascading and Deformable convolution
+ (PCD). It is used in EDVR.
+
+ Ref:
+ EDVR: Video Restoration with Enhanced Deformable Convolutional Networks
+
+ Args:
+ num_feat (int): Channel number of middle features. Default: 64.
+ deformable_groups (int): Deformable groups. Defaults: 8.
+ """
+
+ def __init__(self, num_feat=64, deformable_groups=8):
+ super(PCDAlignment, self).__init__()
+
+ # Pyramid has three levels:
+ # L3: level 3, 1/4 spatial size
+ # L2: level 2, 1/2 spatial size
+ # L1: level 1, original spatial size
+ self.offset_conv1 = nn.ModuleDict()
+ self.offset_conv2 = nn.ModuleDict()
+ self.offset_conv3 = nn.ModuleDict()
+ self.dcn_pack = nn.ModuleDict()
+ self.feat_conv = nn.ModuleDict()
+
+ # Pyramids
+ for i in range(3, 0, -1):
+ level = f'l{i}'
+ self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+ if i == 3:
+ self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ else:
+ self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+ self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
+
+ if i < 3:
+ self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+
+ # Cascading dcn
+ self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+ self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
+
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def forward(self, nbr_feat_l, ref_feat_l):
+ """Align neighboring frame features to the reference frame features.
+
+ Args:
+ nbr_feat_l (list[Tensor]): Neighboring feature list. It
+ contains three pyramid levels (L1, L2, L3),
+ each with shape (b, c, h, w).
+ ref_feat_l (list[Tensor]): Reference feature list. It
+ contains three pyramid levels (L1, L2, L3),
+ each with shape (b, c, h, w).
+
+ Returns:
+ Tensor: Aligned features.
+ """
+ # Pyramids
+ upsampled_offset, upsampled_feat = None, None
+ for i in range(3, 0, -1):
+ level = f'l{i}'
+ offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
+ offset = self.lrelu(self.offset_conv1[level](offset))
+ if i == 3:
+ offset = self.lrelu(self.offset_conv2[level](offset))
+ else:
+ offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
+ offset = self.lrelu(self.offset_conv3[level](offset))
+
+ feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
+ if i < 3:
+ feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
+ if i > 1:
+ feat = self.lrelu(feat)
+
+ if i > 1: # upsample offset and features
+ # x2: when we upsample the offset, we should also enlarge
+ # the magnitude.
+ upsampled_offset = self.upsample(offset) * 2
+ upsampled_feat = self.upsample(feat)
+
+ # Cascading
+ offset = torch.cat([feat, ref_feat_l[0]], dim=1)
+ offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
+ feat = self.lrelu(self.cas_dcnpack(feat, offset))
+ return feat
+
+
+class TSAFusion(nn.Module):
+ """Temporal Spatial Attention (TSA) fusion module.
+
+ Temporal: Calculate the correlation between center frame and
+ neighboring frames;
+ Spatial: It has 3 pyramid levels, the attention is similar to SFT.
+ (SFT: Recovering realistic texture in image super-resolution by deep
+ spatial feature transform.)
+
+ Args:
+ num_feat (int): Channel number of middle features. Default: 64.
+ num_frame (int): Number of frames. Default: 5.
+ center_frame_idx (int): The index of center frame. Default: 2.
+ """
+
+ def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
+ super(TSAFusion, self).__init__()
+ self.center_frame_idx = center_frame_idx
+ # temporal attention (before fusion conv)
+ self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
+
+ # spatial attention (after fusion conv)
+ self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
+ self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
+ self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
+ self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
+ self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
+ self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
+ self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+ self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
+ self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+
+ def forward(self, aligned_feat):
+ """
+ Args:
+ aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).
+
+ Returns:
+ Tensor: Features after TSA with the shape (b, c, h, w).
+ """
+ b, t, c, h, w = aligned_feat.size()
+ # temporal attention
+ embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
+ embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
+ embedding = embedding.view(b, t, -1, h, w) # (b, t, c, h, w)
+
+ corr_l = [] # correlation list
+ for i in range(t):
+ emb_neighbor = embedding[:, i, :, :, :]
+ corr = torch.sum(emb_neighbor * embedding_ref, 1) # (b, h, w)
+ corr_l.append(corr.unsqueeze(1)) # (b, 1, h, w)
+ corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1)) # (b, t, h, w)
+ corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
+ corr_prob = corr_prob.contiguous().view(b, -1, h, w) # (b, t*c, h, w)
+ aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob
+
+ # fusion
+ feat = self.lrelu(self.feat_fusion(aligned_feat))
+
+ # spatial attention
+ attn = self.lrelu(self.spatial_attn1(aligned_feat))
+ attn_max = self.max_pool(attn)
+ attn_avg = self.avg_pool(attn)
+ attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
+ # pyramid levels
+ attn_level = self.lrelu(self.spatial_attn_l1(attn))
+ attn_max = self.max_pool(attn_level)
+ attn_avg = self.avg_pool(attn_level)
+ attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
+ attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
+ attn_level = self.upsample(attn_level)
+
+ attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
+ attn = self.lrelu(self.spatial_attn4(attn))
+ attn = self.upsample(attn)
+ attn = self.spatial_attn5(attn)
+ attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
+ attn = torch.sigmoid(attn)
+
+ # after initialization, the * 2 keeps (attn * 2) close to 1.
+ feat = feat * attn * 2 + attn_add
+ return feat
+
+
+class PredeblurModule(nn.Module):
+ """Pre-dublur module.
+
+ Args:
+ num_in_ch (int): Channel number of input image. Default: 3.
+ num_feat (int): Channel number of intermediate features. Default: 64.
+ hr_in (bool): Whether the input has high resolution. Default: False.
+ """
+
+ def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
+ super(PredeblurModule, self).__init__()
+ self.hr_in = hr_in
+
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ if self.hr_in:
+ # downsample x4 by stride conv
+ self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+
+ # generate feature pyramid
+ self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+
+ self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
+ self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
+ self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
+ self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])
+
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def forward(self, x):
+ feat_l1 = self.lrelu(self.conv_first(x))
+ if self.hr_in:
+ feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
+ feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))
+
+ # generate feature pyramid
+ feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
+ feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))
+
+ feat_l3 = self.upsample(self.resblock_l3(feat_l3))
+ feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
+ feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))
+
+ for i in range(2):
+ feat_l1 = self.resblock_l1[i](feat_l1)
+ feat_l1 = feat_l1 + feat_l2
+ for i in range(2, 5):
+ feat_l1 = self.resblock_l1[i](feat_l1)
+ return feat_l1
+
+
+@ARCH_REGISTRY.register()
+class EDVR(nn.Module):
+ """EDVR network structure for video super-resolution.
+
+ Currently only the x4 upsampling factor is supported.
+ Paper:
+ EDVR: Video Restoration with Enhanced Deformable Convolutional Networks
+
+ Args:
+ num_in_ch (int): Channel number of input image. Default: 3.
+ num_out_ch (int): Channel number of output image. Default: 3.
+ num_feat (int): Channel number of intermediate features. Default: 64.
+ num_frame (int): Number of input frames. Default: 5.
+ deformable_groups (int): Deformable groups. Defaults: 8.
+ num_extract_block (int): Number of blocks for feature extraction.
+ Default: 5.
+ num_reconstruct_block (int): Number of blocks for reconstruction.
+ Default: 10.
+ center_frame_idx (int): The index of center frame. Frame counting from
+ 0. Default: Middle of input frames.
+ hr_in (bool): Whether the input has high resolution. Default: False.
+ with_predeblur (bool): Whether to use the predeblur module.
+ Default: False.
+ with_tsa (bool): Whether to use the TSA module. Default: True.
+ """
+
+ def __init__(self,
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_frame=5,
+ deformable_groups=8,
+ num_extract_block=5,
+ num_reconstruct_block=10,
+ center_frame_idx=None,
+ hr_in=False,
+ with_predeblur=False,
+ with_tsa=True):
+ super(EDVR, self).__init__()
+ if center_frame_idx is None:
+ self.center_frame_idx = num_frame // 2
+ else:
+ self.center_frame_idx = center_frame_idx
+ self.hr_in = hr_in
+ self.with_predeblur = with_predeblur
+ self.with_tsa = with_tsa
+
+ # extract features for each frame
+ if self.with_predeblur:
+ self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)
+ self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
+ else:
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+
+ # extract pyramid features
+ self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
+ self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+ # pcd and tsa module
+ self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
+ if self.with_tsa:
+ self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
+ else:
+ self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
+
+ # reconstruction
+ self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
+ # upsample
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
+ self.pixel_shuffle = nn.PixelShuffle(2)
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def forward(self, x):
+ b, t, c, h, w = x.size()
+ if self.hr_in:
+ assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.')
+ else:
+ assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.')
+
+ x_center = x[:, self.center_frame_idx, :, :, :].contiguous()
+
+ # extract features for each frame
+ # L1
+ if self.with_predeblur:
+ feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
+ if self.hr_in:
+ h, w = h // 4, w // 4
+ else:
+ feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
+
+ feat_l1 = self.feature_extraction(feat_l1)
+ # L2
+ feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
+ feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
+ # L3
+ feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
+ feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
+
+ feat_l1 = feat_l1.view(b, t, -1, h, w)
+ feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
+ feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)
+
+ # PCD alignment
+ ref_feat_l = [ # reference feature list
+ feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
+ feat_l3[:, self.center_frame_idx, :, :, :].clone()
+ ]
+ aligned_feat = []
+ for i in range(t):
+ nbr_feat_l = [ # neighboring feature list
+ feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
+ ]
+ aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
+ aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
+
+ if not self.with_tsa:
+ aligned_feat = aligned_feat.view(b, -1, h, w)
+ feat = self.fusion(aligned_feat)
+
+ out = self.reconstruction(feat)
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+ out = self.lrelu(self.conv_hr(out))
+ out = self.conv_last(out)
+ if self.hr_in:
+ base = x_center
+ else:
+ base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
+ out += base
+ return out
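
An illustrative sketch for EDVR, not a full test: the PCD alignment uses DCNv2Pack, so the compiled deformable-convolution ops in r_basicsr.ops.dcn must be available, and the input height and width must be multiples of 4 (16 when hr_in=True):

    import torch
    from r_basicsr.archs.edvr_arch import EDVR

    model = EDVR(num_in_ch=3, num_out_ch=3, num_feat=64, num_frame=5, with_tsa=True).eval()
    clip = torch.rand(1, 5, 3, 64, 64)  # (b, t, c, h, w), h and w divisible by 4
    with torch.no_grad():
        out = model(clip)
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])
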
diff --git a/r_basicsr/archs/hifacegan_arch.py b/r_basicsr/archs/hifacegan_arch.py new file mode 100644 index 0000000..58df7e7 --- /dev/null +++ b/r_basicsr/archs/hifacegan_arch.py @@ -0,0 +1,259 @@ +import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer
+
+
+class SPADEGenerator(BaseNetwork):
+ """Generator with SPADEResBlock"""
+
+ def __init__(self,
+ num_in_ch=3,
+ num_feat=64,
+ use_vae=False,
+ z_dim=256,
+ crop_size=512,
+ norm_g='spectralspadesyncbatch3x3',
+ is_train=True,
+ init_train_phase=3): # progressive training disabled
+ super().__init__()
+ self.nf = num_feat
+ self.input_nc = num_in_ch
+ self.is_train = is_train
+ self.train_phase = init_train_phase
+
+ self.scale_ratio = 5 # hardcoded now
+ self.sw = crop_size // (2**self.scale_ratio)
+ self.sh = self.sw # 20210519: By default use square image, aspect_ratio = 1.0
+
+ if use_vae:
+ # In case of VAE, we will sample from random z vector
+ self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
+ else:
+ # Otherwise, we make the network deterministic by starting with
+ # downsampled segmentation map instead of random z
+ self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)
+
+ self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
+
+ self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
+ self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
+
+ self.ups = nn.ModuleList([
+ SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
+ SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
+ SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
+ SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
+ ])
+
+ self.to_rgbs = nn.ModuleList([
+ nn.Conv2d(8 * self.nf, 3, 3, padding=1),
+ nn.Conv2d(4 * self.nf, 3, 3, padding=1),
+ nn.Conv2d(2 * self.nf, 3, 3, padding=1),
+ nn.Conv2d(1 * self.nf, 3, 3, padding=1)
+ ])
+
+ self.up = nn.Upsample(scale_factor=2)
+
+ def encode(self, input_tensor):
+ """
+ Encode input_tensor into feature maps; can be overridden in derived classes.
+ Default: nearest downsampling of 2**5 = 32 times
+ """
+ h, w = input_tensor.size()[-2:]
+ sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
+ x = F.interpolate(input_tensor, size=(sh, sw))
+ return self.fc(x)
+
+ def forward(self, x):
+ # In the original SPADE, seg is a segmentation map, but here we use x instead.
+ seg = x
+
+ x = self.encode(x)
+ x = self.head_0(x, seg)
+
+ x = self.up(x)
+ x = self.g_middle_0(x, seg)
+ x = self.g_middle_1(x, seg)
+
+ if self.is_train:
+ phase = self.train_phase + 1
+ else:
+ phase = len(self.to_rgbs)
+
+ for i in range(phase):
+ x = self.up(x)
+ x = self.ups[i](x, seg)
+
+ x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
+ x = torch.tanh(x)
+
+ return x
+
+ def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
+ """
+ A helper method for subspace visualization. input_x and seg are different images.
+ For the first n levels (including the encoder) we use input_x; for the rest we use seg.
+
+ If mode = 'progressive', the output is like: AAABBB
+ If mode = 'one_plug', the output is like: AAABAA
+ If mode = 'one_ablate', the output is like: BBBABB
+ """
+
+ if seg is None:
+ return self.forward(input_x)
+
+ if self.is_train:
+ phase = self.train_phase + 1
+ else:
+ phase = len(self.to_rgbs)
+
+ if mode == 'progressive':
+ n = max(min(n, 4 + phase), 0)
+ guide_list = [input_x] * n + [seg] * (4 + phase - n)
+ elif mode == 'one_plug':
+ n = max(min(n, 4 + phase - 1), 0)
+ guide_list = [seg] * (4 + phase)
+ guide_list[n] = input_x
+ elif mode == 'one_ablate':
+ if n > 3 + phase:
+ return self.forward(input_x)
+ guide_list = [input_x] * (4 + phase)
+ guide_list[n] = seg
+
+ x = self.encode(guide_list[0])
+ x = self.head_0(x, guide_list[1])
+
+ x = self.up(x)
+ x = self.g_middle_0(x, guide_list[2])
+ x = self.g_middle_1(x, guide_list[3])
+
+ for i in range(phase):
+ x = self.up(x)
+ x = self.ups[i](x, guide_list[4 + i])
+
+ x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
+ x = torch.tanh(x)
+
+ return x
+
+
+@ARCH_REGISTRY.register()
+class HiFaceGAN(SPADEGenerator):
+ """
+ HiFaceGAN: SPADEGenerator with a learnable feature encoder
+ Current encoder design: LIPEncoder
+ """
+
+ def __init__(self,
+ num_in_ch=3,
+ num_feat=64,
+ use_vae=False,
+ z_dim=256,
+ crop_size=512,
+ norm_g='spectralspadesyncbatch3x3',
+ is_train=True,
+ init_train_phase=3):
+ super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
+ self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)
+
+ def encode(self, input_tensor):
+ return self.lip_encoder(input_tensor)
+
+
+@ARCH_REGISTRY.register()
+class HiFaceGANDiscriminator(BaseNetwork):
+ """
+ Inspired by pix2pixHD multiscale discriminator.
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_out_ch (int): Channel number of outputs. Default: 3.
+ conditional_d (bool): Whether to use a conditional discriminator.
+ Default: True.
+ num_d (int): Number of multiscale discriminators. Default: 2.
+ n_layers_d (int): Number of downsample layers in each D. Default: 4.
+ num_feat (int): Channel number of base intermediate features.
+ Default: 64.
+ norm_d (str): String to determine normalization layers in D.
+ Choices: [spectral][instance/batch/syncbatch]
+ Default: 'spectralinstance'.
+ keep_features (bool): Keep intermediate features for matching loss, etc.
+ Default: True.
+ """
+
+ def __init__(self,
+ num_in_ch=3,
+ num_out_ch=3,
+ conditional_d=True,
+ num_d=2,
+ n_layers_d=4,
+ num_feat=64,
+ norm_d='spectralinstance',
+ keep_features=True):
+ super().__init__()
+ self.num_d = num_d
+
+ input_nc = num_in_ch
+ if conditional_d:
+ input_nc += num_out_ch
+
+ for i in range(num_d):
+ subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
+ self.add_module(f'discriminator_{i}', subnet_d)
+
+ def downsample(self, x):
+ return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
+
+ # Returns list of lists of discriminator outputs.
+ # The final result is of size opt.num_d x opt.n_layers_D
+ def forward(self, x):
+ result = []
+ for _, _net_d in self.named_children():
+ out = _net_d(x)
+ result.append(out)
+ x = self.downsample(x)
+
+ return result
+
+
+class NLayerDiscriminator(BaseNetwork):
+ """Defines the PatchGAN discriminator with the specified arguments."""
+
+ def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
+ super().__init__()
+ kw = 4
+ padw = int(np.ceil((kw - 1.0) / 2))
+ nf = num_feat
+ self.keep_features = keep_features
+
+ norm_layer = get_nonspade_norm_layer(norm_d)
+ sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]]
+
+ for n in range(1, n_layers_d):
+ nf_prev = nf
+ nf = min(nf * 2, 512)
+ stride = 1 if n == n_layers_d - 1 else 2
+ sequence += [[
+ norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
+ nn.LeakyReLU(0.2, False)
+ ]]
+
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+
+ # We divide the layers into groups to extract intermediate layer outputs
+ for n in range(len(sequence)):
+ self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
+
+ def forward(self, x):
+ results = [x]
+ for submodel in self.children():
+ intermediate_output = submodel(results[-1])
+ results.append(intermediate_output)
+
+ if self.keep_features:
+ return results[1:]
+ else:
+ return results[-1]
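
A hedged usage sketch for the HiFaceGAN generator registered above (illustrative sizes): the input is downsampled 2**scale_ratio = 32 times internally, so a 512x512 crop gives a 16x16 bottleneck and the output comes back at the input resolution:

    import torch
    from r_basicsr.archs.hifacegan_arch import HiFaceGAN

    net_g = HiFaceGAN(num_in_ch=3, num_feat=64, crop_size=512, is_train=False).eval()
    face = torch.rand(1, 3, 512, 512)
    with torch.no_grad():
        restored = net_g(face)
    print(restored.shape)  # expected: torch.Size([1, 3, 512, 512])
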
diff --git a/r_basicsr/archs/hifacegan_util.py b/r_basicsr/archs/hifacegan_util.py new file mode 100644 index 0000000..b63b928 --- /dev/null +++ b/r_basicsr/archs/hifacegan_util.py @@ -0,0 +1,255 @@ +import re
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+# Warning: spectral norm could be buggy
+# under eval mode and multi-GPU inference
+# A workaround is sticking to single-GPU inference and train mode
+from torch.nn.utils import spectral_norm
+
+
+class SPADE(nn.Module):
+
+ def __init__(self, config_text, norm_nc, label_nc):
+ super().__init__()
+
+ assert config_text.startswith('spade')
+ parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
+ param_free_norm_type = str(parsed.group(1))
+ ks = int(parsed.group(2))
+
+ if param_free_norm_type == 'instance':
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc)
+ elif param_free_norm_type == 'syncbatch':
+ print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc)
+ elif param_free_norm_type == 'batch':
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
+ else:
+ raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')
+
+ # The dimension of the intermediate embedding space. Yes, hardcoded.
+ nhidden = 128 if norm_nc > 128 else norm_nc
+
+ pw = ks // 2
+ self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
+
+ def forward(self, x, segmap):
+
+ # Part 1. generate parameter-free normalized activations
+ normalized = self.param_free_norm(x)
+
+ # Part 2. produce scaling and bias conditioned on semantic map
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
+ actv = self.mlp_shared(segmap)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+
+ # apply scale and bias
+ out = normalized * gamma + beta
+
+ return out
+
+
+class SPADEResnetBlock(nn.Module):
+ """
+ ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
+ it takes in the segmentation map as input, learns the skip connection if necessary,
+ and applies normalization first and then convolution.
+ This architecture seemed like a standard architecture for unconditional or
+ class-conditional GAN architecture using residual block.
+ The code was inspired from https://github.com/LMescheder/GAN_stability.
+ """
+
+ def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
+ super().__init__()
+ # Attributes
+ self.learned_shortcut = (fin != fout)
+ fmiddle = min(fin, fout)
+
+ # create conv layers
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
+ if self.learned_shortcut:
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+
+ # apply spectral norm if specified
+ if 'spectral' in norm_g:
+ self.conv_0 = spectral_norm(self.conv_0)
+ self.conv_1 = spectral_norm(self.conv_1)
+ if self.learned_shortcut:
+ self.conv_s = spectral_norm(self.conv_s)
+
+ # define normalization layers
+ spade_config_str = norm_g.replace('spectral', '')
+ self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
+ self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
+ if self.learned_shortcut:
+ self.norm_s = SPADE(spade_config_str, fin, semantic_nc)
+
+ # note the resnet block with SPADE also takes in |seg|,
+ # the semantic segmentation map as input
+ def forward(self, x, seg):
+ x_s = self.shortcut(x, seg)
+ dx = self.conv_0(self.act(self.norm_0(x, seg)))
+ dx = self.conv_1(self.act(self.norm_1(dx, seg)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, seg):
+ if self.learned_shortcut:
+ x_s = self.conv_s(self.norm_s(x, seg))
+ else:
+ x_s = x
+ return x_s
+
+ def act(self, x):
+ return F.leaky_relu(x, 2e-1)
+
+
+class BaseNetwork(nn.Module):
+ """ A basis for hifacegan archs with custom initialization """
+
+ def init_weights(self, init_type='normal', gain=0.02):
+
+ def init_func(m):
+ classname = m.__class__.__name__
+ if classname.find('BatchNorm2d') != -1:
+ if hasattr(m, 'weight') and m.weight is not None:
+ init.normal_(m.weight.data, 1.0, gain)
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == 'xavier_uniform':
+ init.xavier_uniform_(m.weight.data, gain=1.0)
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain=gain)
+ elif init_type == 'none': # uses pytorch's default init method
+ m.reset_parameters()
+ else:
+ raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+
+ self.apply(init_func)
+
+ # propagate to children
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights(init_type, gain)
+
+ def forward(self, x):
+ pass
+
+
+def lip2d(x, logit, kernel=3, stride=2, padding=1):
+ weight = logit.exp()
+ return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
+
+
+class SoftGate(nn.Module):
+ COEFF = 12.0
+
+ def forward(self, x):
+ return torch.sigmoid(x).mul(self.COEFF)
+
+
+class SimplifiedLIP(nn.Module):
+
+ def __init__(self, channels):
+ super(SimplifiedLIP, self).__init__()
+ self.logit = nn.Sequential(
+ nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
+ SoftGate())
+
+ def init_layer(self):
+ self.logit[0].weight.data.fill_(0.0)
+
+ def forward(self, x):
+ frac = lip2d(x, self.logit(x))
+ return frac
+
+
+class LIPEncoder(BaseNetwork):
+ """Local Importance-based Pooling (Ziteng Gao et.al.,ICCV 2019)"""
+
+ def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
+ super().__init__()
+ self.sw = sw
+ self.sh = sh
+ self.max_ratio = 16
+ # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
+ kw = 3
+ pw = (kw - 1) // 2
+
+ model = [
+ nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
+ norm_layer(ngf),
+ nn.ReLU(),
+ ]
+ cur_ratio = 1
+ for i in range(n_2xdown):
+ next_ratio = min(cur_ratio * 2, self.max_ratio)
+ model += [
+ SimplifiedLIP(ngf * cur_ratio),
+ nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
+ norm_layer(ngf * next_ratio),
+ ]
+ cur_ratio = next_ratio
+ if i < n_2xdown - 1:
+ model += [nn.ReLU(inplace=True)]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+def get_nonspade_norm_layer(norm_type='instance'):
+ # helper function to get the number of output channels of the previous layer
+ def get_out_channel(layer):
+ if hasattr(layer, 'out_channels'):
+ return getattr(layer, 'out_channels')
+ return layer.weight.size(0)
+
+ # this function will be returned
+ def add_norm_layer(layer):
+ nonlocal norm_type
+ if norm_type.startswith('spectral'):
+ layer = spectral_norm(layer)
+ subnorm_type = norm_type[len('spectral'):]
+
+ if subnorm_type == 'none' or len(subnorm_type) == 0:
+ return layer
+
+ # remove bias in the previous layer, which is meaningless
+ # since it has no effect after normalization
+ if getattr(layer, 'bias', None) is not None:
+ delattr(layer, 'bias')
+ layer.register_parameter('bias', None)
+
+ if subnorm_type == 'batch':
+ norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
+ elif subnorm_type == 'sync_batch':
+ print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
+ # norm_layer = SynchronizedBatchNorm2d(
+ # get_out_channel(layer), affine=True)
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
+ elif subnorm_type == 'instance':
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
+ else:
+ raise ValueError(f'normalization layer {subnorm_type} is not recognized')
+
+ return nn.Sequential(layer, norm_layer)
+
+ print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
+ return add_norm_layer
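
A small sketch of how the norm-layer factory above is used (this mirrors what NLayerDiscriminator does; the legacy notice is printed when the factory itself is created):

    import torch.nn as nn
    from r_basicsr.archs.hifacegan_util import get_nonspade_norm_layer

    add_norm = get_nonspade_norm_layer('spectralinstance')
    layer = add_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1))
    print(layer)  # Sequential(spectral-normalised Conv2d with bias removed, InstanceNorm2d(128))
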
diff --git a/r_basicsr/archs/inception.py b/r_basicsr/archs/inception.py new file mode 100644 index 0000000..7db2b42 --- /dev/null +++ b/r_basicsr/archs/inception.py @@ -0,0 +1,307 @@ +# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
+# For FID metric
+
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.model_zoo import load_url
+from torchvision import models
+
+# Inception weights ported to Pytorch from
+# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
+LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
+
+
+class InceptionV3(nn.Module):
+ """Pretrained InceptionV3 network returning feature maps"""
+
+ # Index of default block of inception to return,
+ # corresponds to output of final average pooling
+ DEFAULT_BLOCK_INDEX = 3
+
+ # Maps feature dimensionality to their output blocks indices
+ BLOCK_INDEX_BY_DIM = {
+ 64: 0, # First max pooling features
+ 192: 1, # Second max pooling features
+ 768: 2, # Pre-aux classifier features
+ 2048: 3 # Final average pooling features
+ }
+
+ def __init__(self,
+ output_blocks=(DEFAULT_BLOCK_INDEX, ),  # trailing comma so the default is a tuple
+ resize_input=True,
+ normalize_input=True,
+ requires_grad=False,
+ use_fid_inception=True):
+ """Build pretrained InceptionV3.
+
+ Args:
+ output_blocks (list[int]): Indices of blocks to return features of.
+ Possible values are:
+ - 0: corresponds to output of first max pooling
+ - 1: corresponds to output of second max pooling
+ - 2: corresponds to output which is fed to aux classifier
+ - 3: corresponds to output of final average pooling
+ resize_input (bool): If true, bilinearly resizes input to width and
+ height 299 before feeding input to model. As the network
+ without fully connected layers is fully convolutional, it
+ should be able to handle inputs of arbitrary size, so resizing
+ might not be strictly needed. Default: True.
+ normalize_input (bool): If true, scales the input from range (0, 1)
+ to the range the pretrained Inception network expects,
+ namely (-1, 1). Default: True.
+ requires_grad (bool): If true, parameters of the model require
+ gradients. Possibly useful for finetuning the network.
+ Default: False.
+ use_fid_inception (bool): If true, uses the pretrained Inception
+ model used in Tensorflow's FID implementation.
+ If false, uses the pretrained Inception model available in
+ torchvision. The FID Inception model has different weights
+ and a slightly different structure from torchvision's
+ Inception model. If you want to compute FID scores, you are
+ strongly advised to set this parameter to true to get
+ comparable results. Default: True.
+ """
+ super(InceptionV3, self).__init__()
+
+ self.resize_input = resize_input
+ self.normalize_input = normalize_input
+ self.output_blocks = sorted(output_blocks)
+ self.last_needed_block = max(output_blocks)
+
+ assert self.last_needed_block <= 3, ('Last possible output block index is 3')
+
+ self.blocks = nn.ModuleList()
+
+ if use_fid_inception:
+ inception = fid_inception_v3()
+ else:
+ try:
+ inception = models.inception_v3(pretrained=True, init_weights=False)
+ except TypeError:
+ # pytorch < 1.5 does not have init_weights for inception_v3
+ inception = models.inception_v3(pretrained=True)
+
+ # Block 0: input to maxpool1
+ block0 = [
+ inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
+ nn.MaxPool2d(kernel_size=3, stride=2)
+ ]
+ self.blocks.append(nn.Sequential(*block0))
+
+ # Block 1: maxpool1 to maxpool2
+ if self.last_needed_block >= 1:
+ block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
+ self.blocks.append(nn.Sequential(*block1))
+
+ # Block 2: maxpool2 to aux classifier
+ if self.last_needed_block >= 2:
+ block2 = [
+ inception.Mixed_5b,
+ inception.Mixed_5c,
+ inception.Mixed_5d,
+ inception.Mixed_6a,
+ inception.Mixed_6b,
+ inception.Mixed_6c,
+ inception.Mixed_6d,
+ inception.Mixed_6e,
+ ]
+ self.blocks.append(nn.Sequential(*block2))
+
+ # Block 3: aux classifier to final avgpool
+ if self.last_needed_block >= 3:
+ block3 = [
+ inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
+ ]
+ self.blocks.append(nn.Sequential(*block3))
+
+ for param in self.parameters():
+ param.requires_grad = requires_grad
+
+ def forward(self, x):
+ """Get Inception feature maps.
+
+ Args:
+ x (Tensor): Input tensor of shape (b, 3, h, w).
+ Values are expected to be in range (-1, 1). You can also input
+ (0, 1) with setting normalize_input = True.
+
+ Returns:
+ list[Tensor]: Corresponding to the selected output block, sorted
+ ascending by index.
+ """
+ output = []
+
+ if self.resize_input:
+ x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
+
+ if self.normalize_input:
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
+
+ for idx, block in enumerate(self.blocks):
+ x = block(x)
+ if idx in self.output_blocks:
+ output.append(x)
+
+ if idx == self.last_needed_block:
+ break
+
+ return output
+
+
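
A hedged sketch of extracting the 2048-dimensional pool features used for FID with the class above (note that the default use_fid_inception=True loads the ported FID weights, downloading them unless the local copy exists):

    import torch
    from r_basicsr.archs.inception import InceptionV3

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
    net = InceptionV3(output_blocks=[block_idx]).eval()
    imgs = torch.rand(4, 3, 128, 128)  # values in (0, 1); normalize_input rescales to (-1, 1)
    with torch.no_grad():
        feats = net(imgs)[0]           # (4, 2048, 1, 1) after the final average pooling
    print(feats.squeeze(-1).squeeze(-1).shape)  # torch.Size([4, 2048])
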
+def fid_inception_v3():
+ """Build pretrained Inception model for FID computation.
+
+ The Inception model for FID computation uses a different set of weights
+ and has a slightly different structure than torchvision's Inception.
+
+ This method first constructs torchvision's Inception and then patches the
+ necessary parts that are different in the FID Inception model.
+ """
+ try:
+ inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False)
+ except TypeError:
+ # pytorch < 1.5 does not have init_weights for inception_v3
+ inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
+
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+ inception.Mixed_7b = FIDInceptionE_1(1280)
+ inception.Mixed_7c = FIDInceptionE_2(2048)
+
+ if os.path.exists(LOCAL_FID_WEIGHTS):
+ state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage)
+ else:
+ state_dict = load_url(FID_WEIGHTS_URL, progress=True)
+
+ inception.load_state_dict(state_dict)
+ return inception
+
+
+class FIDInceptionA(models.inception.InceptionA):
+ """InceptionA block patched for FID computation"""
+
+ def __init__(self, in_channels, pool_features):
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch5x5 = self.branch5x5_1(x)
+ branch5x5 = self.branch5x5_2(branch5x5)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+ # Patch: Tensorflow's average pool does not use the padded zeros in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(models.inception.InceptionC):
+ """InceptionC block patched for FID computation"""
+
+ def __init__(self, in_channels, channels_7x7):
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch7x7 = self.branch7x7_1(x)
+ branch7x7 = self.branch7x7_2(branch7x7)
+ branch7x7 = self.branch7x7_3(branch7x7)
+
+ branch7x7dbl = self.branch7x7dbl_1(x)
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+ # Patch: Tensorflow's average pool does not use the padded zeros in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(models.inception.InceptionE):
+ """First InceptionE block patched for FID computation"""
+
+ def __init__(self, in_channels):
+ super(FIDInceptionE_1, self).__init__(in_channels)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ # Patch: Tensorflow's average pool does not use the padded zeros in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(models.inception.InceptionE):
+ """Second InceptionE block patched for FID computation"""
+
+ def __init__(self, in_channels):
+ super(FIDInceptionE_2, self).__init__(in_channels)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ # Patch: The FID Inception model uses max pooling instead of average
+ # pooling. This is likely an error in this specific Inception
+ # implementation, as other Inception models use average pooling here
+ # (which matches the description in the paper).
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
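+
+
+# Editorial usage sketch (not part of the upstream file): extracting the
+# 2048-dim pooled features commonly used for FID. The argument values are
+# illustrative assumptions; constructing the model downloads the FID
+# Inception weights on first use.
+if __name__ == '__main__':
+    model = InceptionV3(output_blocks=[3])
+    model.eval()
+    imgs = torch.rand(2, 3, 299, 299)  # values in (0, 1); normalize_input rescales to (-1, 1)
+    with torch.no_grad():
+        feats = model(imgs)[0]  # output of block 3 (final avgpool)
+    print(feats.shape)  # expected: torch.Size([2, 2048, 1, 1])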
diff --git a/r_basicsr/archs/rcan_arch.py b/r_basicsr/archs/rcan_arch.py new file mode 100644 index 0000000..78f917e --- /dev/null +++ b/r_basicsr/archs/rcan_arch.py @@ -0,0 +1,135 @@ +import torch
+from torch import nn as nn
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import Upsample, make_layer
+
+
+class ChannelAttention(nn.Module):
+ """Channel attention used in RCAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
+ """
+
+ def __init__(self, num_feat, squeeze_factor=16):
+ super(ChannelAttention, self).__init__()
+ self.attention = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
+ nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
+
+ def forward(self, x):
+ y = self.attention(x)
+ return x * y
+
+
+class RCAB(nn.Module):
+ """Residual Channel Attention Block (RCAB) used in RCAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
+ res_scale (float): Scale the residual. Default: 1.
+ """
+
+ def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
+ super(RCAB, self).__init__()
+ self.res_scale = res_scale
+
+ self.rcab = nn.Sequential(
+ nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
+ ChannelAttention(num_feat, squeeze_factor))
+
+ def forward(self, x):
+ res = self.rcab(x) * self.res_scale
+ return res + x
+
+
+class ResidualGroup(nn.Module):
+ """Residual Group of RCAB.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_block (int): Block number in the body network.
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
+ res_scale (float): Scale the residual. Default: 1.
+ """
+
+ def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
+ super(ResidualGroup, self).__init__()
+
+ self.residual_group = make_layer(
+ RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
+ self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+ def forward(self, x):
+ res = self.conv(self.residual_group(x))
+ return res + x
+
+
+@ARCH_REGISTRY.register()
+class RCAN(nn.Module):
+ """Residual Channel Attention Networks.
+
+ Paper: Image Super-Resolution Using Very Deep Residual Channel Attention
+ Networks
+ Ref git repo: https://github.com/yulunzhang/RCAN.
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ num_group (int): Number of ResidualGroup. Default: 10.
+ num_block (int): Number of RCAB in ResidualGroup. Default: 16.
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
+ upscale (int): Upsampling factor. Support 2^n and 3.
+ Default: 4.
+ res_scale (float): Used to scale the residual in residual block.
+ Default: 1.
+ img_range (float): Image range. Default: 255.
+ rgb_mean (tuple[float]): Image mean in RGB orders.
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
+ """
+
+ def __init__(self,
+ num_in_ch,
+ num_out_ch,
+ num_feat=64,
+ num_group=10,
+ num_block=16,
+ squeeze_factor=16,
+ upscale=4,
+ res_scale=1,
+ img_range=255.,
+ rgb_mean=(0.4488, 0.4371, 0.4040)):
+ super(RCAN, self).__init__()
+
+ self.img_range = img_range
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(
+ ResidualGroup,
+ num_group,
+ num_feat=num_feat,
+ num_block=num_block,
+ squeeze_factor=squeeze_factor,
+ res_scale=res_scale)
+ self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ def forward(self, x):
+ self.mean = self.mean.type_as(x)
+
+ x = (x - self.mean) * self.img_range
+ x = self.conv_first(x)
+ res = self.conv_after_body(self.body(x))
+ res += x
+
+ x = self.conv_last(self.upsample(res))
+ x = x / self.img_range + self.mean
+
+ return x
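+
+
+# Editorial usage sketch (not part of the upstream file): a quick x4 forward
+# pass with a deliberately shallow configuration so it runs fast on CPU; the
+# sizes below are illustrative assumptions, not recommended settings.
+if __name__ == '__main__':
+    model = RCAN(num_in_ch=3, num_out_ch=3, num_feat=64, num_group=2, num_block=2, upscale=4)
+    lq = torch.rand(1, 3, 48, 48)
+    with torch.no_grad():
+        sr = model(lq)
+    print(sr.shape)  # expected: torch.Size([1, 3, 192, 192])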
diff --git a/r_basicsr/archs/ridnet_arch.py b/r_basicsr/archs/ridnet_arch.py new file mode 100644 index 0000000..5a9349f --- /dev/null +++ b/r_basicsr/archs/ridnet_arch.py @@ -0,0 +1,184 @@ +import torch
+import torch.nn as nn
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, make_layer
+
+
+class MeanShift(nn.Conv2d):
+ """ Data normalization with mean and std.
+
+ Args:
+ rgb_range (int): Maximum value of RGB.
+ rgb_mean (list[float]): Mean for RGB channels.
+ rgb_std (list[float]): Std for RGB channels.
+ sign (int): For subtraction, sign is -1, for addition, sign is 1.
+ Default: -1.
+ requires_grad (bool): Whether to update the self.weight and self.bias.
+ Default: True.
+ """
+
+ def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True):
+ super(MeanShift, self).__init__(3, 3, kernel_size=1)
+ std = torch.Tensor(rgb_std)
+ self.weight.data = torch.eye(3).view(3, 3, 1, 1)
+ self.weight.data.div_(std.view(3, 1, 1, 1))
+ self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
+ self.bias.data.div_(std)
+ for p in self.parameters():
+     p.requires_grad = requires_grad
+
+
+class EResidualBlockNoBN(nn.Module):
+ """Enhanced Residual block without BN.
+
+ There are three convolution layers in residual branch.
+
+ It has a style of:
+ ---Conv-ReLU-Conv-ReLU-Conv-+-ReLU-
+ |__________________________|
+ """
+
+ def __init__(self, in_channels, out_channels):
+ super(EResidualBlockNoBN, self).__init__()
+
+ self.body = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, 1, 1, 0),
+ )
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ out = self.body(x)
+ out = self.relu(out + x)
+ return out
+
+
+class MergeRun(nn.Module):
+ """ Merge-and-run unit.
+
+ This unit contains two branches with different dilated convolutions,
+ followed by a convolution to process the concatenated features.
+
+ Paper: Real Image Denoising with Feature Attention
+ Ref git repo: https://github.com/saeed-anwar/RIDNet
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
+ super(MergeRun, self).__init__()
+
+ self.dilation1 = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True))
+ self.dilation2 = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True))
+
+ self.aggregation = nn.Sequential(
+ nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True))
+
+ def forward(self, x):
+ dilation1 = self.dilation1(x)
+ dilation2 = self.dilation2(x)
+ out = torch.cat([dilation1, dilation2], dim=1)
+ out = self.aggregation(out)
+ out = out + x
+ return out
+
+
+class ChannelAttention(nn.Module):
+ """Channel attention.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
+ """
+
+ def __init__(self, mid_channels, squeeze_factor=16):
+ super(ChannelAttention, self).__init__()
+ self.attention = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
+ nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid())
+
+ def forward(self, x):
+ y = self.attention(x)
+ return x * y
+
+
+class EAM(nn.Module):
+ """Enhancement attention modules (EAM) in RIDNet.
+
+ This module contains a merge-and-run unit, a residual block,
+ an enhanced residual block and a feature attention unit.
+
+ Attributes:
+ merge: The merge-and-run unit.
+ block1: The residual block.
+ block2: The enhanced residual block.
+ ca: The feature/channel attention unit.
+ """
+
+ def __init__(self, in_channels, mid_channels, out_channels):
+ super(EAM, self).__init__()
+
+ self.merge = MergeRun(in_channels, mid_channels)
+ self.block1 = ResidualBlockNoBN(mid_channels)
+ self.block2 = EResidualBlockNoBN(mid_channels, out_channels)
+ self.ca = ChannelAttention(out_channels)
+ # The residual block in the paper contains a relu after addition.
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ out = self.merge(x)
+ out = self.relu(self.block1(out))
+ out = self.block2(out)
+ out = self.ca(out)
+ return out
+
+
+@ARCH_REGISTRY.register()
+class RIDNet(nn.Module):
+ """RIDNet: Real Image Denoising with Feature Attention.
+
+ Ref git repo: https://github.com/saeed-anwar/RIDNet
+
+ Args:
+ in_channels (int): Channel number of inputs.
+ mid_channels (int): Channel number of EAM modules.
+ Default: 64.
+ out_channels (int): Channel number of outputs.
+ num_block (int): Number of EAM. Default: 4.
+ img_range (float): Image range. Default: 255.
+ rgb_mean (tuple[float]): Image mean in RGB orders.
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
+ """
+
+ def __init__(self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ num_block=4,
+ img_range=255.,
+ rgb_mean=(0.4488, 0.4371, 0.4040),
+ rgb_std=(1.0, 1.0, 1.0)):
+ super(RIDNet, self).__init__()
+
+ self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
+ self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)
+
+ self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
+ self.body = make_layer(
+ EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
+ self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ res = self.sub_mean(x)
+ res = self.tail(self.body(self.relu(self.head(res))))
+ res = self.add_mean(res)
+
+ out = x + res
+ return out
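+
+
+# Editorial usage sketch (not part of the upstream file): RIDNet is a
+# denoiser, so the output has the same spatial size as the input. The
+# configuration below is an illustrative assumption.
+if __name__ == '__main__':
+    model = RIDNet(in_channels=3, mid_channels=64, out_channels=3, num_block=2)
+    noisy = torch.rand(1, 3, 64, 64)
+    with torch.no_grad():
+        denoised = model(noisy)
+    print(denoised.shape)  # expected: torch.Size([1, 3, 64, 64])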
diff --git a/r_basicsr/archs/rrdbnet_arch.py b/r_basicsr/archs/rrdbnet_arch.py new file mode 100644 index 0000000..305696b --- /dev/null +++ b/r_basicsr/archs/rrdbnet_arch.py @@ -0,0 +1,119 @@ +import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+ """Residual Dense Block.
+
+ Used in RRDB block in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat=64, num_grow_ch=32):
+ super(ResidualDenseBlock, self).__init__()
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ # initialization
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ # Empirically, we use 0.2 to scale the residual for better performance
+ return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+ """Residual in Residual Dense Block.
+
+ Used in RRDB-Net in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat, num_grow_ch=32):
+ super(RRDB, self).__init__()
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+ def forward(self, x):
+ out = self.rdb1(x)
+ out = self.rdb2(out)
+ out = self.rdb3(out)
+ # Empirically, we use 0.2 to scale the residual for better performance
+ return out * 0.2 + x
+
+
+@ARCH_REGISTRY.register()
+class RRDBNet(nn.Module):
+ """Networks consisting of Residual in Residual Dense Block, which is used
+ in ESRGAN.
+
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+ We extend ESRGAN for scale x2 and scale x1.
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
+ We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
+ and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_feat (int): Channel number of intermediate features.
+ Default: 64
+ num_block (int): Block number in the trunk network. Default: 23.
+ num_grow_ch (int): Channels for each growth. Default: 32.
+ """
+
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+ super(RRDBNet, self).__init__()
+ self.scale = scale
+ if scale == 2:
+ num_in_ch = num_in_ch * 4
+ elif scale == 1:
+ num_in_ch = num_in_ch * 16
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ # upsample
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ if self.scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ elif self.scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ else:
+ feat = x
+ feat = self.conv_first(feat)
+ body_feat = self.conv_body(self.body(feat))
+ feat = feat + body_feat
+ # upsample
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+ return out
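+
+
+# Editorial usage sketch (not part of the upstream file): for scale 4 the input
+# is fed directly to conv_first; for scale 1 or 2 it is pixel-unshuffled first,
+# as described in the class docstring. num_block is reduced here only so the
+# example runs quickly.
+if __name__ == '__main__':
+    model = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=2)
+    lq = torch.rand(1, 3, 32, 32)
+    with torch.no_grad():
+        sr = model(lq)
+    print(sr.shape)  # expected: torch.Size([1, 3, 128, 128])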
diff --git a/r_basicsr/archs/spynet_arch.py b/r_basicsr/archs/spynet_arch.py new file mode 100644 index 0000000..2bd143c --- /dev/null +++ b/r_basicsr/archs/spynet_arch.py @@ -0,0 +1,96 @@ +import math
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import flow_warp
+
+
+class BasicModule(nn.Module):
+ """Basic Module for SpyNet.
+ """
+
+ def __init__(self):
+ super(BasicModule, self).__init__()
+
+ self.basic_module = nn.Sequential(
+ nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
+
+ def forward(self, tensor_input):
+ return self.basic_module(tensor_input)
+
+
+@ARCH_REGISTRY.register()
+class SpyNet(nn.Module):
+ """SpyNet architecture.
+
+ Args:
+ load_path (str): path for pretrained SpyNet. Default: None.
+ """
+
+ def __init__(self, load_path=None):
+ super(SpyNet, self).__init__()
+ self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
+ if load_path:
+ self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def preprocess(self, tensor_input):
+ tensor_output = (tensor_input - self.mean) / self.std
+ return tensor_output
+
+ def process(self, ref, supp):
+ flow = []
+
+ ref = [self.preprocess(ref)]
+ supp = [self.preprocess(supp)]
+
+ for level in range(5):
+ ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
+ supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
+
+ flow = ref[0].new_zeros(
+ [ref[0].size(0), 2,
+ int(math.floor(ref[0].size(2) / 2.0)),
+ int(math.floor(ref[0].size(3) / 2.0))])
+
+ for level in range(len(ref)):
+ upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
+
+ if upsampled_flow.size(2) != ref[level].size(2):
+ upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
+ if upsampled_flow.size(3) != ref[level].size(3):
+ upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')
+
+ flow = self.basic_module[level](torch.cat([
+ ref[level],
+ flow_warp(
+ supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
+ upsampled_flow
+ ], 1)) + upsampled_flow
+
+ return flow
+
+ def forward(self, ref, supp):
+ assert ref.size() == supp.size()
+
+ h, w = ref.size(2), ref.size(3)
+ w_floor = math.floor(math.ceil(w / 32.0) * 32.0)
+ h_floor = math.floor(math.ceil(h / 32.0) * 32.0)
+
+ ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
+ supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
+
+ flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False)
+
+ flow[:, 0, :, :] *= float(w) / float(w_floor)
+ flow[:, 1, :, :] *= float(h) / float(h_floor)
+
+ return flow
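+
+
+# Editorial usage sketch (not part of the upstream file): estimating optical
+# flow between two frames. With load_path=None the weights are random, so the
+# flow values are meaningless, but the shapes below still hold; a pretrained
+# checkpoint path would normally be passed via load_path.
+if __name__ == '__main__':
+    flow_net = SpyNet(load_path=None)
+    ref = torch.rand(1, 3, 64, 64)
+    supp = torch.rand(1, 3, 64, 64)
+    with torch.no_grad():
+        flow = flow_net(ref, supp)
+    print(flow.shape)  # expected: torch.Size([1, 2, 64, 64])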
diff --git a/r_basicsr/archs/srresnet_arch.py b/r_basicsr/archs/srresnet_arch.py new file mode 100644 index 0000000..99b56a4 --- /dev/null +++ b/r_basicsr/archs/srresnet_arch.py @@ -0,0 +1,65 @@ +from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer
+
+
+@ARCH_REGISTRY.register()
+class MSRResNet(nn.Module):
+ """Modified SRResNet.
+
+ A compact version modified from SRResNet in
+ "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
+ It uses residual blocks without BN, similar to EDSR.
+ Currently, it supports x2, x3 and x4 upsampling scale factor.
+
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_out_ch (int): Channel number of outputs. Default: 3.
+ num_feat (int): Channel number of intermediate features. Default: 64.
+ num_block (int): Block number in the body network. Default: 16.
+ upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
+ super(MSRResNet, self).__init__()
+ self.upscale = upscale
+
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)
+
+ # upsampling
+ if self.upscale in [2, 3]:
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
+ self.pixel_shuffle = nn.PixelShuffle(self.upscale)
+ elif self.upscale == 4:
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+ self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+ self.pixel_shuffle = nn.PixelShuffle(2)
+
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ # initialization
+ default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
+ if self.upscale == 4:
+ default_init_weights(self.upconv2, 0.1)
+
+ def forward(self, x):
+ feat = self.lrelu(self.conv_first(x))
+ out = self.body(feat)
+
+ if self.upscale == 4:
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+ elif self.upscale in [2, 3]:
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+
+ out = self.conv_last(self.lrelu(self.conv_hr(out)))
+ base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
+ out += base
+ return out
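+
+
+# Editorial usage sketch (not part of the upstream file): the module above only
+# imports nn/F, so torch is imported locally for the dummy input. num_block is
+# reduced here purely for speed.
+if __name__ == '__main__':
+    import torch
+    model = MSRResNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=2, upscale=4)
+    lq = torch.rand(1, 3, 32, 32)
+    with torch.no_grad():
+        sr = model(lq)
+    print(sr.shape)  # expected: torch.Size([1, 3, 128, 128])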
diff --git a/r_basicsr/archs/srvgg_arch.py b/r_basicsr/archs/srvgg_arch.py new file mode 100644 index 0000000..f0e51e4 --- /dev/null +++ b/r_basicsr/archs/srvgg_arch.py @@ -0,0 +1,70 @@ +from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register(suffix='basicsr')
+class SRVGGNetCompact(nn.Module):
+ """A compact VGG-style network structure for super-resolution.
+
+ It is a compact network structure that performs upsampling in the last layer, so no convolution is
+ conducted in the HR feature space.
+
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_out_ch (int): Channel number of outputs. Default: 3.
+ num_feat (int): Channel number of intermediate features. Default: 64.
+ num_conv (int): Number of convolution layers in the body network. Default: 16.
+ upscale (int): Upsampling factor. Default: 4.
+ act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+ super(SRVGGNetCompact, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.num_out_ch = num_out_ch
+ self.num_feat = num_feat
+ self.num_conv = num_conv
+ self.upscale = upscale
+ self.act_type = act_type
+
+ self.body = nn.ModuleList()
+ # the first conv
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+ # the first activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the body structure
+ for _ in range(num_conv):
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+ # activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the last conv
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+ # upsample
+ self.upsampler = nn.PixelShuffle(upscale)
+
+ def forward(self, x):
+ out = x
+ for i in range(0, len(self.body)):
+ out = self.body[i](out)
+
+ out = self.upsampler(out)
+ # add the nearest upsampled image, so that the network learns the residual
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+ out += base
+ return out
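+
+
+# Editorial usage sketch (not part of the upstream file): upsampling happens
+# only in the final pixel-shuffle, so all body convolutions run at the LR
+# resolution. torch is imported locally because the module only imports nn/F.
+if __name__ == '__main__':
+    import torch
+    model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=4, upscale=4, act_type='prelu')
+    lq = torch.rand(1, 3, 32, 32)
+    with torch.no_grad():
+        sr = model(lq)
+    print(sr.shape)  # expected: torch.Size([1, 3, 128, 128])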
diff --git a/r_basicsr/archs/stylegan2_arch.py b/r_basicsr/archs/stylegan2_arch.py new file mode 100644 index 0000000..e8d571e --- /dev/null +++ b/r_basicsr/archs/stylegan2_arch.py @@ -0,0 +1,799 @@ +import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from r_basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
+from r_basicsr.ops.upfirdn2d import upfirdn2d
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+class NormStyleCode(nn.Module):
+
+ def forward(self, x):
+ """Normalize the style codes.
+
+ Args:
+ x (Tensor): Style codes with shape (b, c).
+
+ Returns:
+ Tensor: Normalized tensor.
+ """
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_resample_kernel(k):
+ """Make resampling kernel for UpFirDn.
+
+ Args:
+ k (list[int]): A list indicating the 1D resample kernel magnitude.
+
+ Returns:
+ Tensor: 2D resampled kernel.
+ """
+ k = torch.tensor(k, dtype=torch.float32)
+ if k.ndim == 1:
+ k = k[None, :] * k[:, None] # to 2D kernel, outer product
+ # normalize
+ k /= k.sum()
+ return k
+
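+# Illustrative note: for the commonly used kernel (1, 3, 3, 1), the outer
+# product above yields [[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]],
+# which sums to 64, so the normalized entries range from 1/64 to 9/64.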
+
+class UpFirDnUpsample(nn.Module):
+ """Upsample, FIR filter, and downsample (upsampole version).
+
+ References:
+ 1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501
+ 2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501
+
+ Args:
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude.
+ factor (int): Upsampling scale factor. Default: 2.
+ """
+
+ def __init__(self, resample_kernel, factor=2):
+ super(UpFirDnUpsample, self).__init__()
+ self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
+ self.factor = factor
+
+ pad = self.kernel.shape[0] - factor
+ self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)
+
+ def forward(self, x):
+ out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(factor={self.factor})')
+
+
+class UpFirDnDownsample(nn.Module):
+ """Upsample, FIR filter, and downsample (downsampole version).
+
+ Args:
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude.
+ factor (int): Downsampling scale factor. Default: 2.
+ """
+
+ def __init__(self, resample_kernel, factor=2):
+ super(UpFirDnDownsample, self).__init__()
+ self.kernel = make_resample_kernel(resample_kernel)
+ self.factor = factor
+
+ pad = self.kernel.shape[0] - factor
+ self.pad = ((pad + 1) // 2, pad // 2)
+
+ def forward(self, x):
+ out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(factor={self.factor})')
+
+
+class UpFirDnSmooth(nn.Module):
+ """Upsample, FIR filter, and downsample (smooth version).
+
+ Args:
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude.
+ upsample_factor (int): Upsampling scale factor. Default: 1.
+ downsample_factor (int): Downsampling scale factor. Default: 1.
+ kernel_size (int): Kernel size. Default: 1.
+ """
+
+ def __init__(self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1):
+ super(UpFirDnSmooth, self).__init__()
+ self.upsample_factor = upsample_factor
+ self.downsample_factor = downsample_factor
+ self.kernel = make_resample_kernel(resample_kernel)
+ if upsample_factor > 1:
+ self.kernel = self.kernel * (upsample_factor**2)
+
+ if upsample_factor > 1:
+ pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
+ self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
+ elif downsample_factor > 1:
+ pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
+ self.pad = ((pad + 1) // 2, pad // 2)
+ else:
+ raise NotImplementedError
+
+ def forward(self, x):
+ out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}'
+ f', downsample_factor={self.downsample_factor})')
+
+
+class EqualLinear(nn.Module):
+ """Equalized Linear as StyleGAN2.
+
+ Args:
+ in_channels (int): Size of each sample.
+ out_channels (int): Size of each output sample.
+ bias (bool): If set to ``False``, the layer will not learn an additive
+ bias. Default: ``True``.
+ bias_init_val (float): Bias initialized value. Default: 0.
+ lr_mul (float): Learning rate multiplier. Default: 1.
+ activation (None | str): The activation after ``linear`` operation.
+ Supported: 'fused_lrelu', None. Default: None.
+ """
+
+ def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
+ super(EqualLinear, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.lr_mul = lr_mul
+ self.activation = activation
+ if self.activation not in ['fused_lrelu', None]:
+ raise ValueError(f'Wrong activation value in EqualLinear: {activation}. '
+ "Supported ones are: ['fused_lrelu', None].")
+ self.scale = (1 / math.sqrt(in_channels)) * lr_mul
+
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, x):
+ if self.bias is None:
+ bias = None
+ else:
+ bias = self.bias * self.lr_mul
+ if self.activation == 'fused_lrelu':
+ out = F.linear(x, self.weight * self.scale)
+ out = fused_leaky_relu(out, bias)
+ else:
+ out = F.linear(x, self.weight * self.scale, bias=bias)
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, bias={self.bias is not None})')
+
+
+class ModulatedConv2d(nn.Module):
+ """Modulated Conv2d used in StyleGAN2.
+
+ There is no bias in ModulatedConv2d.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Size of the convolving kernel.
+ num_style_feat (int): Channel number of style features.
+ demodulate (bool): Whether to demodulate in the conv layer.
+ Default: True.
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+ Default: None.
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude. Default: (1, 3, 3, 1).
+ eps (float): A value added to the denominator for numerical stability.
+ Default: 1e-8.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ resample_kernel=(1, 3, 3, 1),
+ eps=1e-8):
+ super(ModulatedConv2d, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.demodulate = demodulate
+ self.sample_mode = sample_mode
+ self.eps = eps
+
+ if self.sample_mode == 'upsample':
+ self.smooth = UpFirDnSmooth(
+ resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size)
+ elif self.sample_mode == 'downsample':
+ self.smooth = UpFirDnSmooth(
+ resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)
+ elif self.sample_mode is None:
+ pass
+ else:
+ raise ValueError(f'Wrong sample mode {self.sample_mode}, '
+ "supported ones are ['upsample', 'downsample', None].")
+
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+ # modulation inside each modulated conv
+ self.modulation = EqualLinear(
+ num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
+
+ self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
+ self.padding = kernel_size // 2
+
+ def forward(self, x, style):
+ """Forward function.
+
+ Args:
+ x (Tensor): Tensor with shape (b, c, h, w).
+ style (Tensor): Tensor with shape (b, num_style_feat).
+
+ Returns:
+ Tensor: Modulated tensor after convolution.
+ """
+ b, c, h, w = x.shape # c = c_in
+ # weight modulation
+ style = self.modulation(style).view(b, 1, c, 1, 1)
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
+ weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
+
+ if self.demodulate:
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
+
+ weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
+
+ if self.sample_mode == 'upsample':
+ x = x.view(1, b * c, h, w)
+ weight = weight.view(b, self.out_channels, c, self.kernel_size, self.kernel_size)
+ weight = weight.transpose(1, 2).reshape(b * c, self.out_channels, self.kernel_size, self.kernel_size)
+ out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
+ out = out.view(b, self.out_channels, *out.shape[2:4])
+ out = self.smooth(out)
+ elif self.sample_mode == 'downsample':
+ x = self.smooth(x)
+ x = x.view(1, b * c, *x.shape[2:4])
+ out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
+ out = out.view(b, self.out_channels, *out.shape[2:4])
+ else:
+ x = x.view(1, b * c, h, w)
+ # weight: (b*c_out, c_in, k, k), groups=b
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
+ out = out.view(b, self.out_channels, *out.shape[2:4])
+
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, '
+ f'kernel_size={self.kernel_size}, '
+ f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+
+
+class StyleConv(nn.Module):
+ """Style conv.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Size of the convolving kernel.
+ num_style_feat (int): Channel number of style features.
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+ Default: None.
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude. Default: (1, 3, 3, 1).
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ resample_kernel=(1, 3, 3, 1)):
+ super(StyleConv, self).__init__()
+ self.modulated_conv = ModulatedConv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=demodulate,
+ sample_mode=sample_mode,
+ resample_kernel=resample_kernel)
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
+ self.activate = FusedLeakyReLU(out_channels)
+
+ def forward(self, x, style, noise=None):
+ # modulate
+ out = self.modulated_conv(x, style)
+ # noise injection
+ if noise is None:
+ b, _, h, w = out.shape
+ noise = out.new_empty(b, 1, h, w).normal_()
+ out = out + self.weight * noise
+ # activation (with bias)
+ out = self.activate(out)
+ return out
+
+
+class ToRGB(nn.Module):
+ """To RGB from features.
+
+ Args:
+ in_channels (int): Channel number of input.
+ num_style_feat (int): Channel number of style features.
+ upsample (bool): Whether to upsample. Default: True.
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude. Default: (1, 3, 3, 1).
+ """
+
+ def __init__(self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)):
+ super(ToRGB, self).__init__()
+ if upsample:
+ self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
+ else:
+ self.upsample = None
+ self.modulated_conv = ModulatedConv2d(
+ in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+ def forward(self, x, style, skip=None):
+ """Forward function.
+
+ Args:
+ x (Tensor): Feature tensor with shape (b, c, h, w).
+ style (Tensor): Tensor with shape (b, num_style_feat).
+ skip (Tensor): Base/skip tensor. Default: None.
+
+ Returns:
+ Tensor: RGB images.
+ """
+ out = self.modulated_conv(x, style)
+ out = out + self.bias
+ if skip is not None:
+ if self.upsample:
+ skip = self.upsample(skip)
+ out = out + skip
+ return out
+
+
+class ConstantInput(nn.Module):
+ """Constant input.
+
+ Args:
+ num_channel (int): Channel number of constant input.
+ size (int): Spatial size of constant input.
+ """
+
+ def __init__(self, num_channel, size):
+ super(ConstantInput, self).__init__()
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
+
+ def forward(self, batch):
+ out = self.weight.repeat(batch, 1, 1, 1)
+ return out
+
+
+@ARCH_REGISTRY.register()
+class StyleGAN2Generator(nn.Module):
+ """StyleGAN2 Generator.
+
+ Args:
+ out_size (int): The spatial size of outputs.
+ num_style_feat (int): Channel number of style features. Default: 512.
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
+ channel_multiplier (int): Channel multiplier for large networks of
+ StyleGAN2. Default: 2.
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude. An outer product is applied to extend the 1D resample
+ kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+ narrow (float): Narrow ratio for channels. Default: 1.0.
+ """
+
+ def __init__(self,
+ out_size,
+ num_style_feat=512,
+ num_mlp=8,
+ channel_multiplier=2,
+ resample_kernel=(1, 3, 3, 1),
+ lr_mlp=0.01,
+ narrow=1):
+ super(StyleGAN2Generator, self).__init__()
+ # Style MLP layers
+ self.num_style_feat = num_style_feat
+ style_mlp_layers = [NormStyleCode()]
+ for i in range(num_mlp):
+ style_mlp_layers.append(
+ EqualLinear(
+ num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
+ activation='fused_lrelu'))
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
+
+ channels = {
+ '4': int(512 * narrow),
+ '8': int(512 * narrow),
+ '16': int(512 * narrow),
+ '32': int(512 * narrow),
+ '64': int(256 * channel_multiplier * narrow),
+ '128': int(128 * channel_multiplier * narrow),
+ '256': int(64 * channel_multiplier * narrow),
+ '512': int(32 * channel_multiplier * narrow),
+ '1024': int(16 * channel_multiplier * narrow)
+ }
+ self.channels = channels
+
+ self.constant_input = ConstantInput(channels['4'], size=4)
+ self.style_conv1 = StyleConv(
+ channels['4'],
+ channels['4'],
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ resample_kernel=resample_kernel)
+ self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, resample_kernel=resample_kernel)
+
+ self.log_size = int(math.log(out_size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+ self.num_latent = self.log_size * 2 - 2
+
+ self.style_convs = nn.ModuleList()
+ self.to_rgbs = nn.ModuleList()
+ self.noises = nn.Module()
+
+ in_channels = channels['4']
+ # noise
+ for layer_idx in range(self.num_layers):
+ resolution = 2**((layer_idx + 5) // 2)
+ shape = [1, 1, resolution, resolution]
+ self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
+ # style convs and to_rgbs
+ for i in range(3, self.log_size + 1):
+ out_channels = channels[f'{2**i}']
+ self.style_convs.append(
+ StyleConv(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode='upsample',
+ resample_kernel=resample_kernel,
+ ))
+ self.style_convs.append(
+ StyleConv(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ resample_kernel=resample_kernel))
+ self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True, resample_kernel=resample_kernel))
+ in_channels = out_channels
+
+ def make_noise(self):
+ """Make noise for noise injection."""
+ device = self.constant_input.weight.device
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
+
+ return noises
+
+ def get_latent(self, x):
+ return self.style_mlp(x)
+
+ def mean_latent(self, num_latent):
+ latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
+ return latent
+
+ def forward(self,
+ styles,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ truncation=1,
+ truncation_latent=None,
+ inject_index=None,
+ return_latents=False):
+ """Forward function for StyleGAN2Generator.
+
+ Args:
+ styles (list[Tensor]): Sample codes of styles.
+ input_is_latent (bool): Whether input is latent style.
+ Default: False.
+ noise (Tensor | None): Input noise or None. Default: None.
+ randomize_noise (bool): Randomize noise, used when 'noise' is
+ None. Default: True.
+ truncation (float): Truncation ratio for the truncation trick. Default: 1.
+ truncation_latent (Tensor | None): Mean latent used for truncation. Default: None.
+ inject_index (int | None): The injection index for mixing noise.
+ Default: None.
+ return_latents (bool): Whether to return style latents.
+ Default: False.
+ """
+ # style codes -> latents with Style MLP layer
+ if not input_is_latent:
+ styles = [self.style_mlp(s) for s in styles]
+ # noises
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers # for each style conv layer
+ else: # use the stored noise
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
+ # style truncation
+ if truncation < 1:
+ style_truncation = []
+ for style in styles:
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
+ styles = style_truncation
+ # get style latent with injection
+ if len(styles) == 1:
+ inject_index = self.num_latent
+
+ if styles[0].ndim < 3:
+ # repeat latent code for all the layers
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ else: # used for encoder with different latent code for each layer
+ latent = styles[0]
+ elif len(styles) == 2: # mixing noises
+ if inject_index is None:
+ inject_index = random.randint(1, self.num_latent - 1)
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+ latent = torch.cat([latent1, latent2], 1)
+
+ # main generation
+ out = self.constant_input(latent.shape[0])
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
+ noise[2::2], self.to_rgbs):
+ out = conv1(out, latent[:, i], noise=noise1)
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip)
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+ else:
+ return image, None
+
+
+class ScaledLeakyReLU(nn.Module):
+ """Scaled LeakyReLU.
+
+ Args:
+ negative_slope (float): Negative slope. Default: 0.2.
+ """
+
+ def __init__(self, negative_slope=0.2):
+ super(ScaledLeakyReLU, self).__init__()
+ self.negative_slope = negative_slope
+
+ def forward(self, x):
+ out = F.leaky_relu(x, negative_slope=self.negative_slope)
+ return out * math.sqrt(2)
+
+
+class EqualConv2d(nn.Module):
+ """Equalized Linear as StyleGAN2.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Size of the convolving kernel.
+ stride (int): Stride of the convolution. Default: 1
+ padding (int): Zero-padding added to both sides of the input.
+ Default: 0.
+ bias (bool): If ``True``, adds a learnable bias to the output.
+ Default: ``True``.
+ bias_init_val (float): Bias initialized value. Default: 0.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
+ super(EqualConv2d, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, x):
+ out = F.conv2d(
+ x,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ )
+
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, '
+ f'kernel_size={self.kernel_size},'
+ f' stride={self.stride}, padding={self.padding}, '
+ f'bias={self.bias is not None})')
+
+
+class ConvLayer(nn.Sequential):
+ """Conv Layer used in StyleGAN2 Discriminator.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Kernel size.
+ downsample (bool): Whether downsample by a factor of 2.
+ Default: False.
+ resample_kernel (list[int]): A list indicating the 1D resample
+ kernel magnitude. An outer product is applied to
+ extend the 1D resample kernel to a 2D resample kernel.
+ Default: (1, 3, 3, 1).
+ bias (bool): Whether with bias. Default: True.
+ activate (bool): Whether to use activation. Default: True.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ downsample=False,
+ resample_kernel=(1, 3, 3, 1),
+ bias=True,
+ activate=True):
+ layers = []
+ # downsample
+ if downsample:
+ layers.append(
+ UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size))
+ stride = 2
+ self.padding = 0
+ else:
+ stride = 1
+ self.padding = kernel_size // 2
+ # conv
+ layers.append(
+ EqualConv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
+ and not activate))
+ # activation
+ if activate:
+ if bias:
+ layers.append(FusedLeakyReLU(out_channels))
+ else:
+ layers.append(ScaledLeakyReLU(0.2))
+
+ super(ConvLayer, self).__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ """Residual block used in StyleGAN2 Discriminator.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ resample_kernel (list[int]): A list indicating the 1D resample
+ kernel magnitude. An outer product is applied to
+ extend the 1D resample kernel to a 2D resample kernel.
+ Default: (1, 3, 3, 1).
+ """
+
+ def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
+ super(ResBlock, self).__init__()
+
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
+ self.conv2 = ConvLayer(
+ in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True)
+ self.skip = ConvLayer(
+ in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.conv2(out)
+ skip = self.skip(x)
+ out = (out + skip) / math.sqrt(2)
+ return out
+
+
+@ARCH_REGISTRY.register()
+class StyleGAN2Discriminator(nn.Module):
+ """StyleGAN2 Discriminator.
+
+ Args:
+ out_size (int): The spatial size of outputs.
+ channel_multiplier (int): Channel multiplier for large networks of
+ StyleGAN2. Default: 2.
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude. An outer product is applied to extend the 1D resample
+ kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
+ stddev_group (int): For group stddev statistics. Default: 4.
+ narrow (float): Narrow ratio for channels. Default: 1.0.
+ """
+
+ def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1), stddev_group=4, narrow=1):
+ super(StyleGAN2Discriminator, self).__init__()
+
+ channels = {
+ '4': int(512 * narrow),
+ '8': int(512 * narrow),
+ '16': int(512 * narrow),
+ '32': int(512 * narrow),
+ '64': int(256 * channel_multiplier * narrow),
+ '128': int(128 * channel_multiplier * narrow),
+ '256': int(64 * channel_multiplier * narrow),
+ '512': int(32 * channel_multiplier * narrow),
+ '1024': int(16 * channel_multiplier * narrow)
+ }
+
+ log_size = int(math.log(out_size, 2))
+
+ conv_body = [ConvLayer(3, channels[f'{out_size}'], 1, bias=True, activate=True)]
+
+ in_channels = channels[f'{out_size}']
+ for i in range(log_size, 2, -1):
+ out_channels = channels[f'{2**(i - 1)}']
+ conv_body.append(ResBlock(in_channels, out_channels, resample_kernel))
+ in_channels = out_channels
+ self.conv_body = nn.Sequential(*conv_body)
+
+ self.final_conv = ConvLayer(in_channels + 1, channels['4'], 3, bias=True, activate=True)
+ self.final_linear = nn.Sequential(
+ EqualLinear(
+ channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'),
+ EqualLinear(channels['4'], 1, bias=True, bias_init_val=0, lr_mul=1, activation=None),
+ )
+ self.stddev_group = stddev_group
+ self.stddev_feat = 1
+
+ def forward(self, x):
+ out = self.conv_body(x)
+
+ b, c, h, w = out.shape
+ # concatenate a group stddev statistics to out
+ group = min(b, self.stddev_group) # Minibatch must be divisible by (or smaller than) group_size
+ stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w)
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+ stddev = stddev.repeat(group, 1, h, w)
+ out = torch.cat([out, stddev], 1)
+
+ out = self.final_conv(out)
+ out = out.view(b, -1)
+ out = self.final_linear(out)
+
+ return out
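+
+
+# Editorial usage sketch (not part of the upstream file): generating an image
+# from a random style code and scoring it with the discriminator. This assumes
+# the r_basicsr.ops fused_act / upfirdn2d operators are usable in the current
+# environment (they may rely on a compiled extension); out_size=64 and
+# num_mlp=2 are illustrative assumptions to keep the example small.
+if __name__ == '__main__':
+    gen = StyleGAN2Generator(out_size=64, num_style_feat=512, num_mlp=2)
+    disc = StyleGAN2Discriminator(out_size=64)
+    z = torch.randn(1, 512)
+    with torch.no_grad():
+        img, _ = gen([z])
+        score = disc(img)
+    print(img.shape, score.shape)  # expected: torch.Size([1, 3, 64, 64]) torch.Size([1, 1])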
diff --git a/r_basicsr/archs/swinir_arch.py b/r_basicsr/archs/swinir_arch.py new file mode 100644 index 0000000..6be34df --- /dev/null +++ b/r_basicsr/archs/swinir_arch.py @@ -0,0 +1,956 @@ +# Modified from https://github.com/JingyunLiang/SwinIR
+# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
+# Originally Written by Ze Liu, Modified by Jingyun Liang.
+
+import math
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import to_2tuple, trunc_normal_
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
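+# Editorial note: dividing by keep_prob above compensates for the dropped
+# samples so that the expected value of the output matches the input, as in
+# standard stochastic-depth implementations.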
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+ """
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (b, h, w, c)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*b, window_size, window_size, c)
+ """
+ b, h, w, c = x.shape
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
+ return windows
+
+
+def window_reverse(windows, window_size, h, w):
+ """
+ Args:
+ windows: (num_windows*b, window_size, window_size, c)
+ window_size (int): Window size
+ h (int): Height of image
+ w (int): Width of image
+
+ Returns:
+ x: (b, h, w, c)
+ """
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
+ x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
+ return x
+
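+# Shape note (editorial): window_partition maps (b, h, w, c) to
+# (b * h//window_size * w//window_size, window_size, window_size, c) and
+# window_reverse inverts it; e.g. a (2, 64, 64, 96) tensor with window_size=8
+# becomes (128, 8, 8, 96) and back.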
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both shifted and non-shifted windows.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer('relative_position_index', relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*b, n, c)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ b_, n, c = x.shape
+ qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nw = mask.shape[0]
+ attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, n, n)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+ def flops(self, n):
+ # calculate flops for 1 window with token length of n
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += n * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * n * (self.dim // self.num_heads) * n
+ # x = (attn @ v)
+ flops += self.num_heads * n * n * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += n * self.dim * self.dim
+ return flops
+
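+# Usage sketch (illustrative sizes): attention over already-partitioned window
+# tokens of shape (num_windows*b, window_size*window_size, dim).
+#
+#   attn = WindowAttention(dim=96, window_size=to_2tuple(8), num_heads=6)
+#   tokens = torch.randn(4, 8 * 8, 96)
+#   out = attn(tokens)                    # (4, 64, 96)
+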
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, 'shift_size must be in the range [0, window_size)'
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ attn_mask = self.calculate_mask(self.input_resolution)
+ else:
+ attn_mask = None
+
+ self.register_buffer('attn_mask', attn_mask)
+
+ def calculate_mask(self, x_size):
+ # calculate attention mask for SW-MSA
+ h, w = x_size
+ img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
+ h_slices = (slice(0, -self.window_size), slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size), slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+ def forward(self, x, x_size):
+ h, w = x_size
+ b, _, c = x.shape
+ # assert seq_len == h * w, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(b, h, w, c)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
+
+ # W-MSA/SW-MSA (to be compatible with testing on images whose sizes are multiples of the window size)
+ if self.input_resolution == x_size:
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c
+ else:
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(b, h * w, c)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
+ f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')
+
+ def flops(self):
+ flops = 0
+ h, w = self.input_resolution
+ # norm1
+ flops += self.dim * h * w
+ # W-MSA/SW-MSA
+ nw = h * w / self.window_size / self.window_size
+ flops += nw * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * h * w
+ return flops
+
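+# Usage sketch (illustrative sizes): one shifted block on flattened (b, h*w, c)
+# tokens; an x_size matching input_resolution reuses the precomputed attn_mask.
+#
+#   blk = SwinTransformerBlock(dim=96, input_resolution=(16, 16), num_heads=6,
+#                              window_size=8, shift_size=4)
+#   x = torch.randn(2, 16 * 16, 96)
+#   y = blk(x, (16, 16))                  # same shape as x
+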
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: b, h*w, c
+ """
+ h, w = self.input_resolution
+ b, seq_len, c = x.shape
+ assert seq_len == h * w, 'input feature has wrong size'
+ assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) is not even.'
+
+ x = x.view(b, h, w, c)
+
+ x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
+ x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
+ x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
+ x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
+ x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
+ x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f'input_resolution={self.input_resolution}, dim={self.dim}'
+
+ def flops(self):
+ h, w = self.input_resolution
+ flops = h * w * self.dim
+ flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
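+# Shape sketch (illustrative sizes): PatchMerging halves the spatial resolution
+# and doubles the channel dimension.
+#
+#   merge = PatchMerging(input_resolution=(16, 16), dim=96)
+#   y = merge(torch.randn(2, 16 * 16, 96))   # (2, 8 * 8, 192)
+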
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer) for i in range(depth)
+ ])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, x_size):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, x_size)
+ else:
+ x = blk(x, x_size)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class RSTB(nn.Module):
+ """Residual Swin Transformer Block (RSTB).
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ img_size: Input image size.
+ patch_size: Patch size.
+ resi_connection: The convolutional block before residual connection.
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ img_size=224,
+ patch_size=4,
+ resi_connection='1conv'):
+ super(RSTB, self).__init__()
+
+ self.dim = dim
+ self.input_resolution = input_resolution
+
+ self.residual_group = BasicLayer(
+ dim=dim,
+ input_resolution=input_resolution,
+ depth=depth,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path,
+ norm_layer=norm_layer,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint)
+
+ if resi_connection == '1conv':
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv = nn.Sequential(
+ nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+
+ def forward(self, x, x_size):
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+ def flops(self):
+ flops = 0
+ flops += self.residual_group.flops()
+ h, w = self.input_resolution
+ flops += h * w * self.dim * self.dim * 9
+ flops += self.patch_embed.flops()
+ flops += self.patch_unembed.flops()
+
+ return flops
+
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ flops = 0
+ h, w = self.img_size
+ if self.norm is not None:
+ flops += h * w * self.embed_dim
+ return flops
+
+
+class PatchUnEmbed(nn.Module):
+ r""" Image to Patch Unembedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ def forward(self, x, x_size):
+ x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b c h w
+ return x
+
+ def flops(self):
+ flops = 0
+ return flops
+
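+# Shape sketch (illustrative sizes): as used in SwinIR, PatchEmbed flattens a
+# (b, c, h, w) feature map into tokens and PatchUnEmbed restores it.
+#
+#   embed = PatchEmbed(img_size=64, patch_size=1, embed_dim=96)
+#   unembed = PatchUnEmbed(img_size=64, patch_size=1, embed_dim=96)
+#   feat = torch.randn(2, 96, 64, 64)
+#   tokens = embed(feat)                  # (2, 64 * 64, 96)
+#   back = unembed(tokens, (64, 64))      # (2, 96, 64, 64)
+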
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
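+# Usage sketch (illustrative sizes): Upsample(4, 64) stacks two
+# conv -> PixelShuffle(2) stages, so the spatial size grows by 4x while the
+# channel count is preserved.
+#
+#   up = Upsample(scale=4, num_feat=64)
+#   y = up(torch.randn(1, 64, 16, 16))    # (1, 64, 64, 64)
+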
+
+class UpsampleOneStep(nn.Sequential):
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+ Used in lightweight SR to save parameters.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+
+ """
+
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+ self.num_feat = num_feat
+ self.input_resolution = input_resolution
+ m = []
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
+ m.append(nn.PixelShuffle(scale))
+ super(UpsampleOneStep, self).__init__(*m)
+
+ def flops(self):
+ h, w = self.input_resolution
+ flops = h * w * self.num_feat * 3 * 9
+ return flops
+
+
+@ARCH_REGISTRY.register()
+class SwinIR(nn.Module):
+ r""" SwinIR
+ A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on the Swin Transformer.
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 64
+ patch_size (int | tuple(int)): Patch size. Default: 1
+ in_chans (int): Number of input image channels. Default: 3
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
+ img_range: Image range. 1. or 255.
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
+ """
+
+ def __init__(self,
+ img_size=64,
+ patch_size=1,
+ in_chans=3,
+ embed_dim=96,
+ depths=(6, 6, 6, 6),
+ num_heads=(6, 6, 6, 6),
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False,
+ upscale=2,
+ img_range=1.,
+ upsampler='',
+ resi_connection='1conv',
+ **kwargs):
+ super(SwinIR, self).__init__()
+ num_in_ch = in_chans
+ num_out_ch = in_chans
+ num_feat = 64
+ self.img_range = img_range
+ if in_chans == 3:
+ rgb_mean = (0.4488, 0.4371, 0.4040)
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+ else:
+ self.mean = torch.zeros(1, 1, 1, 1)
+ self.upscale = upscale
+ self.upsampler = upsampler
+
+ # ------------------------- 1, shallow feature extraction ------------------------- #
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+ # ------------------------- 2, deep feature extraction ------------------------- #
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = embed_dim
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=embed_dim,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # merge non-overlapping patches into image
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=embed_dim,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build Residual Swin Transformer blocks (RSTB)
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(
+ dim=embed_dim,
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection)
+ self.layers.append(layer)
+ self.norm = norm_layer(self.num_features)
+
+ # build the last conv layer in deep feature extraction
+ if resi_connection == '1conv':
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv_after_body = nn.Sequential(
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR (to save parameters)
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+ (patches_resolution[0], patches_resolution[1]))
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR (less artifacts)
+ assert self.upscale == 4, 'only supports x4 now.'
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def forward_features(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # b seq_len c
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward(self, x):
+ self.mean = self.mean.type_as(x)
+ x = (x - self.mean) * self.img_range
+
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.conv_last(self.upsample(x))
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.upsample(x)
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ x_first = self.conv_first(x)
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
+ x = x + self.conv_last(res)
+
+ x = x / self.img_range + self.mean
+
+ return x
+
+ def flops(self):
+ flops = 0
+ h, w = self.patches_resolution
+ flops += h * w * 3 * self.embed_dim * 9
+ flops += self.patch_embed.flops()
+ for layer in self.layers:
+ flops += layer.flops()
+ flops += h * w * 3 * self.embed_dim * self.embed_dim
+ flops += self.upsample.flops()
+ return flops
+
+
+if __name__ == '__main__':
+ upscale = 4
+ window_size = 8
+ height = (1024 // upscale // window_size + 1) * window_size
+ width = (720 // upscale // window_size + 1) * window_size
+ model = SwinIR(
+ upscale=2,
+ img_size=(height, width),
+ window_size=window_size,
+ img_range=1.,
+ depths=[6, 6, 6, 6],
+ embed_dim=60,
+ num_heads=[6, 6, 6, 6],
+ mlp_ratio=2,
+ upsampler='pixelshuffledirect')
+ print(model)
+ print(height, width, model.flops() / 1e9)
+
+ x = torch.randn((1, 3, height, width))
+ x = model(x)
+ print(x.shape)
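+
+ # A further sketch (illustrative values): the classical-SR 'pixelshuffle'
+ # branch; the input sides are multiples of window_size, so no padding is needed.
+ sr_model = SwinIR(
+ upscale=4,
+ img_size=48,
+ window_size=8,
+ img_range=1.,
+ depths=[6, 6, 6, 6],
+ embed_dim=96,
+ num_heads=[6, 6, 6, 6],
+ mlp_ratio=2,
+ upsampler='pixelshuffle')
+ lq = torch.randn(1, 3, 48, 48)
+ print(sr_model(lq).shape) # expected: torch.Size([1, 3, 192, 192])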
diff --git a/r_basicsr/archs/tof_arch.py b/r_basicsr/archs/tof_arch.py new file mode 100644 index 0000000..6b0fefa --- /dev/null +++ b/r_basicsr/archs/tof_arch.py @@ -0,0 +1,172 @@ +import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import flow_warp
+
+
+class BasicModule(nn.Module):
+ """Basic module of SPyNet.
+
+ Note that unlike the architecture in spynet_arch.py, the basic module
+ here contains batch normalization.
+ """
+
+ def __init__(self):
+ super(BasicModule, self).__init__()
+ self.basic_module = nn.Sequential(
+ nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
+ nn.BatchNorm2d(32), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False),
+ nn.BatchNorm2d(64), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
+ nn.BatchNorm2d(32), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False),
+ nn.BatchNorm2d(16), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
+
+ def forward(self, tensor_input):
+ """
+ Args:
+ tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
+ 8 channels contain:
+ [reference image (3), neighbor image (3), initial flow (2)].
+
+ Returns:
+ Tensor: Estimated flow with shape (b, 2, h, w)
+ """
+ return self.basic_module(tensor_input)
+
+
+class SPyNetTOF(nn.Module):
+ """SPyNet architecture for TOF.
+
+ Note that this implementation is specifically for TOFlow. Please use
+ spynet_arch.py for general use. They differ in the following aspects:
+ 1. The basic modules here contain BatchNorm.
+ 2. Normalization and denormalization are not done here, as
+ they are done in TOFlow.
+ Paper:
+ Optical Flow Estimation using a Spatial Pyramid Network
+ Code reference:
+ https://github.com/Coldog2333/pytoflow
+
+ Args:
+ load_path (str): Path for pretrained SPyNet. Default: None.
+ """
+
+ def __init__(self, load_path=None):
+ super(SPyNetTOF, self).__init__()
+
+ self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)])
+ if load_path:
+ self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+ def forward(self, ref, supp):
+ """
+ Args:
+ ref (Tensor): Reference image with shape of (b, 3, h, w).
+ supp: The supporting image to be warped: (b, 3, h, w).
+
+ Returns:
+ Tensor: Estimated optical flow: (b, 2, h, w).
+ """
+ num_batches, _, h, w = ref.size()
+ ref = [ref]
+ supp = [supp]
+
+ # generate downsampled frames
+ for _ in range(3):
+ ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
+ supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
+
+ # flow computation
+ flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16)
+ for i in range(4):
+ flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
+ flow = flow_up + self.basic_module[i](
+ torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
+ return flow
+
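+# Usage sketch (illustrative sizes; h and w must be divisible by 16 for the
+# 4-level pyramid):
+#
+#   spynet = SPyNetTOF()
+#   ref = torch.randn(1, 3, 64, 64)
+#   supp = torch.randn(1, 3, 64, 64)
+#   flow = spynet(ref, supp)              # (1, 2, 64, 64)
+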
+
+@ARCH_REGISTRY.register()
+class TOFlow(nn.Module):
+ """PyTorch implementation of TOFlow.
+
+ In TOFlow, the LR frames are pre-upsampled and have the same size as
+ the GT frames.
+ Paper:
+ Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
+ Code reference:
+ 1. https://github.com/anchen1011/toflow
+ 2. https://github.com/Coldog2333/pytoflow
+
+ Args:
+ adapt_official_weights (bool): Whether to adapt the weights translated
+ from the official implementation. Set to false if you want to
+ train from scratch. Default: False
+ """
+
+ def __init__(self, adapt_official_weights=False):
+ super(TOFlow, self).__init__()
+ self.adapt_official_weights = adapt_official_weights
+ self.ref_idx = 0 if adapt_official_weights else 3
+
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ # flow estimation module
+ self.spynet = SPyNetTOF()
+
+ # reconstruction module
+ self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
+ self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4)
+ self.conv_3 = nn.Conv2d(64, 64, 1)
+ self.conv_4 = nn.Conv2d(64, 3, 1)
+
+ # activation function
+ self.relu = nn.ReLU(inplace=True)
+
+ def normalize(self, img):
+ return (img - self.mean) / self.std
+
+ def denormalize(self, img):
+ return img * self.std + self.mean
+
+ def forward(self, lrs):
+ """
+ Args:
+ lrs: Input lr frames: (b, 7, 3, h, w).
+
+ Returns:
+ Tensor: SR frame: (b, 3, h, w).
+ """
+ # In the official implementation, the 0-th frame is the reference frame
+ if self.adapt_official_weights:
+ lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
+
+ num_batches, num_lrs, _, h, w = lrs.size()
+
+ lrs = self.normalize(lrs.view(-1, 3, h, w))
+ lrs = lrs.view(num_batches, num_lrs, 3, h, w)
+
+ lr_ref = lrs[:, self.ref_idx, :, :, :]
+ lr_aligned = []
+ for i in range(7): # 7 frames
+ if i == self.ref_idx:
+ lr_aligned.append(lr_ref)
+ else:
+ lr_supp = lrs[:, i, :, :, :]
+ flow = self.spynet(lr_ref, lr_supp)
+ lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))
+
+ # reconstruction
+ hr = torch.stack(lr_aligned, dim=1)
+ hr = hr.view(num_batches, -1, h, w)
+ hr = self.relu(self.conv_1(hr))
+ hr = self.relu(self.conv_2(hr))
+ hr = self.relu(self.conv_3(hr))
+ hr = self.conv_4(hr) + lr_ref
+
+ return self.denormalize(hr)
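+
+# Usage sketch (illustrative sizes; inputs are expected in [0, 1]): TOFlow takes
+# a 7-frame window of pre-upsampled LR frames and returns the restored
+# reference frame.
+#
+#   net = TOFlow()
+#   lrs = torch.rand(1, 7, 3, 64, 64)
+#   out = net(lrs)                        # (1, 3, 64, 64)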
diff --git a/r_basicsr/archs/vgg_arch.py b/r_basicsr/archs/vgg_arch.py new file mode 100644 index 0000000..e6d9351 --- /dev/null +++ b/r_basicsr/archs/vgg_arch.py @@ -0,0 +1,161 @@ +import os
+import torch
+from collections import OrderedDict
+from torch import nn as nn
+from torchvision.models import vgg as vgg
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
+NAMES = {
+ 'vgg11': [
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
+ 'pool5'
+ ],
+ 'vgg13': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
+ ],
+ 'vgg16': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
+ 'pool5'
+ ],
+ 'vgg19': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
+ ]
+}
+
+
+def insert_bn(names):
+ """Insert bn layer after each conv.
+
+ Args:
+ names (list): The list of layer names.
+
+ Returns:
+ list: The list of layer names with bn layers.
+ """
+ names_bn = []
+ for name in names:
+ names_bn.append(name)
+ if 'conv' in name:
+ position = name.replace('conv', '')
+ names_bn.append('bn' + position)
+ return names_bn
+
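+# Example of the produced naming (illustrative input):
+#
+#   insert_bn(['conv1_1', 'relu1_1', 'pool1'])
+#   # -> ['conv1_1', 'bn1_1', 'relu1_1', 'pool1']
+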
+
+@ARCH_REGISTRY.register()
+class VGGFeatureExtractor(nn.Module):
+ """VGG network for feature extraction.
+
+ In this implementation, we allow users to choose whether to use normalization
+ on the input feature and the type of vgg network. Note that the pretrained
+ path must match the vgg type.
+
+ Args:
+ layer_name_list (list[str]): Forward function returns the corresponding
+ features according to the layer_name_list.
+ Example: ['relu1_1', 'relu2_1', 'relu3_1'].
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image. Importantly,
+ the input feature must be in the range [0, 1]. Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+ requires_grad (bool): If true, the parameters of VGG network will be
+ optimized. Default: False.
+ remove_pooling (bool): If true, the max pooling operations in VGG net
+ will be removed. Default: False.
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
+ """
+
+ def __init__(self,
+ layer_name_list,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ requires_grad=False,
+ remove_pooling=False,
+ pooling_stride=2):
+ super(VGGFeatureExtractor, self).__init__()
+
+ self.layer_name_list = layer_name_list
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ self.names = NAMES[vgg_type.replace('_bn', '')]
+ if 'bn' in vgg_type:
+ self.names = insert_bn(self.names)
+
+ # only borrow layers that will be used to avoid unused params
+ max_idx = 0
+ for v in layer_name_list:
+ idx = self.names.index(v)
+ if idx > max_idx:
+ max_idx = idx
+
+ if os.path.exists(VGG_PRETRAIN_PATH):
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
+ state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
+ vgg_net.load_state_dict(state_dict)
+ else:
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
+
+ features = vgg_net.features[:max_idx + 1]
+
+ modified_net = OrderedDict()
+ for k, v in zip(self.names, features):
+ if 'pool' in k:
+ # if remove_pooling is true, pooling operation will be removed
+ if remove_pooling:
+ continue
+ else:
+ # in some cases, we may want to change the default stride
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
+ else:
+ modified_net[k] = v
+
+ self.vgg_net = nn.Sequential(modified_net)
+
+ if not requires_grad:
+ self.vgg_net.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ else:
+ self.vgg_net.train()
+ for param in self.parameters():
+ param.requires_grad = True
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ if self.range_norm:
+ x = (x + 1) / 2
+ if self.use_input_norm:
+ x = (x - self.mean) / self.std
+
+ output = {}
+ for key, layer in self.vgg_net._modules.items():
+ x = layer(x)
+ if key in self.layer_name_list:
+ output[key] = x.clone()
+
+ return output
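+
+# Usage sketch (illustrative input; layer names must appear in NAMES for the
+# chosen vgg_type, and torchvision pretrained weights are downloaded if no
+# local copy exists at VGG_PRETRAIN_PATH):
+#
+#   extractor = VGGFeatureExtractor(['relu1_1', 'relu3_1'], vgg_type='vgg19')
+#   feats = extractor(torch.rand(1, 3, 224, 224))
+#   print({k: v.shape for k, v in feats.items()})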