From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001
From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com>
Date: Fri, 17 Jan 2025 11:00:30 +0000
Subject: Add files via upload

---
 r_basicsr/archs/basicvsrpp_arch.py | 407 +++++++++++++++++++++++++++++++++++++
 1 file changed, 407 insertions(+)
 create mode 100644 r_basicsr/archs/basicvsrpp_arch.py

diff --git a/r_basicsr/archs/basicvsrpp_arch.py b/r_basicsr/archs/basicvsrpp_arch.py
new file mode 100644
index 0000000..f53b434
--- /dev/null
+++ b/r_basicsr/archs/basicvsrpp_arch.py
@@ -0,0 +1,407 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import warnings
+
+from r_basicsr.archs.arch_util import flow_warp
+from r_basicsr.archs.basicvsr_arch import ConvResidualBlocks
+from r_basicsr.archs.spynet_arch import SpyNet
+from r_basicsr.ops.dcn import ModulatedDeformConvPack
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class BasicVSRPlusPlus(nn.Module):
+    """BasicVSR++ network structure.
+    Supports either x4 upsampling or same-size output. Since DCN is used in
+    this model, it can only be used with CUDA enabled. If CUDA is not enabled,
+    feature alignment will be skipped. In addition, we adopt the official DCN
+    implementation, and the version of torch needs to be higher than 1.9.
+    Paper:
+        BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation
+        and Alignment
+    Args:
+        mid_channels (int, optional): Channel number of the intermediate
+            features. Default: 64.
+        num_blocks (int, optional): The number of residual blocks in each
+            propagation branch. Default: 7.
+        max_residue_magnitude (int): The maximum magnitude of the offset
+            residue (Eq. 6 in paper). Default: 10.
+        is_low_res_input (bool, optional): Whether the input is low-resolution
+            or not. If False, the output resolution is equal to the input
+            resolution. Default: True.
+        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+        cpu_cache_length (int, optional): When the length of sequence is larger
+            than this value, the intermediate features are sent to CPU. This
+            saves GPU memory, but slows down the inference speed. You can
+            increase this number if you have a GPU with large memory.
+            Default: 100.
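+    Example:
+        A minimal sketch (shapes are illustrative; a CUDA device is assumed
+        since DCN alignment requires it, and SPyNet is randomly initialized
+        when ``spynet_path`` is None):
+        >>> model = BasicVSRPlusPlus(mid_channels=64, num_blocks=7).cuda()
+        >>> lqs = torch.rand(1, 5, 3, 64, 64).cuda()  # (n, t, c, h, w)
+        >>> out = model(lqs)
+        >>> out.shape  # x4 upsampling of the 64x64 inputs
+        torch.Size([1, 5, 3, 256, 256])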
+ """ + + def __init__(self, + mid_channels=64, + num_blocks=7, + max_residue_magnitude=10, + is_low_res_input=True, + spynet_path=None, + cpu_cache_length=100): + + super().__init__() + self.mid_channels = mid_channels + self.is_low_res_input = is_low_res_input + self.cpu_cache_length = cpu_cache_length + + # optical flow + self.spynet = SpyNet(spynet_path) + + # feature extraction module + if is_low_res_input: + self.feat_extract = ConvResidualBlocks(3, mid_channels, 5) + else: + self.feat_extract = nn.Sequential( + nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True), + ConvResidualBlocks(mid_channels, mid_channels, 5)) + + # propagation branches + self.deform_align = nn.ModuleDict() + self.backbone = nn.ModuleDict() + modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2'] + for i, module in enumerate(modules): + if torch.cuda.is_available(): + self.deform_align[module] = SecondOrderDeformableAlignment( + 2 * mid_channels, + mid_channels, + 3, + padding=1, + deformable_groups=16, + max_residue_magnitude=max_residue_magnitude) + self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks) + + # upsampling module + self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5) + + self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True) + + self.pixel_shuffle = nn.PixelShuffle(2) + + self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + # check if the sequence is augmented by flipping + self.is_mirror_extended = False + + if len(self.deform_align) > 0: + self.is_with_alignment = True + else: + self.is_with_alignment = False + warnings.warn('Deformable alignment module is not added. ' + 'Probably your CUDA is not configured correctly. DCN can only ' + 'be used with CUDA enabled. Alignment is skipped now.') + + def check_if_mirror_extended(self, lqs): + """Check whether the input is a mirror-extended sequence. + If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the + (t-1-i)-th frame. + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + """ + + if lqs.size(1) % 2 == 0: + lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1) + if torch.norm(lqs_1 - lqs_2.flip(1)) == 0: + self.is_mirror_extended = True + + def compute_flow(self, lqs): + """Compute optical flow using SPyNet for feature alignment. + Note that if the input is an mirror-extended sequence, 'flows_forward' + is not needed, since it is equal to 'flows_backward.flip(1)'. + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + Return: + tuple(Tensor): Optical flow. 'flows_forward' corresponds to the + flows used for forward-time propagation (current to previous). + 'flows_backward' corresponds to the flows used for + backward-time propagation (current to next). 
+ """ + + n, t, c, h, w = lqs.size() + lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w) + lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w) + + flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w) + + if self.is_mirror_extended: # flows_forward = flows_backward.flip(1) + flows_forward = flows_backward.flip(1) + else: + flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w) + + if self.cpu_cache: + flows_backward = flows_backward.cpu() + flows_forward = flows_forward.cpu() + + return flows_forward, flows_backward + + def propagate(self, feats, flows, module_name): + """Propagate the latent features throughout the sequence. + Args: + feats dict(list[tensor]): Features from previous branches. Each + component is a list of tensors with shape (n, c, h, w). + flows (tensor): Optical flows with shape (n, t - 1, 2, h, w). + module_name (str): The name of the propgation branches. Can either + be 'backward_1', 'forward_1', 'backward_2', 'forward_2'. + Return: + dict(list[tensor]): A dictionary containing all the propagated + features. Each key in the dictionary corresponds to a + propagation branch, which is represented by a list of tensors. + """ + + n, t, _, h, w = flows.size() + + frame_idx = range(0, t + 1) + flow_idx = range(-1, t) + mapping_idx = list(range(0, len(feats['spatial']))) + mapping_idx += mapping_idx[::-1] + + if 'backward' in module_name: + frame_idx = frame_idx[::-1] + flow_idx = frame_idx + + feat_prop = flows.new_zeros(n, self.mid_channels, h, w) + for i, idx in enumerate(frame_idx): + feat_current = feats['spatial'][mapping_idx[idx]] + if self.cpu_cache: + feat_current = feat_current.cuda() + feat_prop = feat_prop.cuda() + # second-order deformable alignment + if i > 0 and self.is_with_alignment: + flow_n1 = flows[:, flow_idx[i], :, :, :] + if self.cpu_cache: + flow_n1 = flow_n1.cuda() + + cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1)) + + # initialize second-order features + feat_n2 = torch.zeros_like(feat_prop) + flow_n2 = torch.zeros_like(flow_n1) + cond_n2 = torch.zeros_like(cond_n1) + + if i > 1: # second-order features + feat_n2 = feats[module_name][-2] + if self.cpu_cache: + feat_n2 = feat_n2.cuda() + + flow_n2 = flows[:, flow_idx[i - 1], :, :, :] + if self.cpu_cache: + flow_n2 = flow_n2.cuda() + + flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1)) + cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1)) + + # flow-guided deformable convolution + cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1) + feat_prop = torch.cat([feat_prop, feat_n2], dim=1) + feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2) + + # concatenate and residual blocks + feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop] + if self.cpu_cache: + feat = [f.cuda() for f in feat] + + feat = torch.cat(feat, dim=1) + feat_prop = feat_prop + self.backbone[module_name](feat) + feats[module_name].append(feat_prop) + + if self.cpu_cache: + feats[module_name][-1] = feats[module_name][-1].cpu() + torch.cuda.empty_cache() + + if 'backward' in module_name: + feats[module_name] = feats[module_name][::-1] + + return feats + + def upsample(self, lqs, feats): + """Compute the output image given the features. + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + feats (dict): The features from the propgation branches. + Returns: + Tensor: Output HR sequence with shape (n, t, c, 4h, 4w). 
+ """ + + outputs = [] + num_outputs = len(feats['spatial']) + + mapping_idx = list(range(0, num_outputs)) + mapping_idx += mapping_idx[::-1] + + for i in range(0, lqs.size(1)): + hr = [feats[k].pop(0) for k in feats if k != 'spatial'] + hr.insert(0, feats['spatial'][mapping_idx[i]]) + hr = torch.cat(hr, dim=1) + if self.cpu_cache: + hr = hr.cuda() + + hr = self.reconstruction(hr) + hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr))) + hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr))) + hr = self.lrelu(self.conv_hr(hr)) + hr = self.conv_last(hr) + if self.is_low_res_input: + hr += self.img_upsample(lqs[:, i, :, :, :]) + else: + hr += lqs[:, i, :, :, :] + + if self.cpu_cache: + hr = hr.cpu() + torch.cuda.empty_cache() + + outputs.append(hr) + + return torch.stack(outputs, dim=1) + + def forward(self, lqs): + """Forward function for BasicVSR++. + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + Returns: + Tensor: Output HR sequence with shape (n, t, c, 4h, 4w). + """ + + n, t, c, h, w = lqs.size() + + # whether to cache the features in CPU + self.cpu_cache = True if t > self.cpu_cache_length else False + + if self.is_low_res_input: + lqs_downsample = lqs.clone() + else: + lqs_downsample = F.interpolate( + lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4) + + # check whether the input is an extended sequence + self.check_if_mirror_extended(lqs) + + feats = {} + # compute spatial features + if self.cpu_cache: + feats['spatial'] = [] + for i in range(0, t): + feat = self.feat_extract(lqs[:, i, :, :, :]).cpu() + feats['spatial'].append(feat) + torch.cuda.empty_cache() + else: + feats_ = self.feat_extract(lqs.view(-1, c, h, w)) + h, w = feats_.shape[2:] + feats_ = feats_.view(n, t, -1, h, w) + feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)] + + # compute optical flow using the low-res inputs + assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, ( + 'The height and width of low-res inputs must be at least 64, ' + f'but got {h} and {w}.') + flows_forward, flows_backward = self.compute_flow(lqs_downsample) + + # feature propgation + for iter_ in [1, 2]: + for direction in ['backward', 'forward']: + module = f'{direction}_{iter_}' + + feats[module] = [] + + if direction == 'backward': + flows = flows_backward + elif flows_forward is not None: + flows = flows_forward + else: + flows = flows_backward.flip(1) + + feats = self.propagate(feats, flows, module) + if self.cpu_cache: + del flows + torch.cuda.empty_cache() + + return self.upsample(lqs, feats) + + +class SecondOrderDeformableAlignment(ModulatedDeformConvPack): + """Second-order deformable alignment module. + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + max_residue_magnitude (int): The maximum magnitude of the offset + residue (Eq. 6 in paper). Default: 10. 
+ """ + + def __init__(self, *args, **kwargs): + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) + + super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1), + ) + + self.init_offset() + + def init_offset(self): + + def _constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + _constant_init(self.conv_offset[-1], val=0, bias=0) + + def forward(self, x, extra_feat, flow_1, flow_2): + extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1) + out = self.conv_offset(extra_feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + + # offset + offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) + offset_1, offset_2 = torch.chunk(offset, 2, dim=1) + offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1) + offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1) + offset = torch.cat([offset_1, offset_2], dim=1) + + # mask + mask = torch.sigmoid(mask) + + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, + self.dilation, mask) + + +# if __name__ == '__main__': +# spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth' +# model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda() +# input = torch.rand(1, 2, 3, 64, 64).cuda() +# output = model(input) +# print('===================') +# print(output.shape) -- cgit v1.2.3