From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:00:30 +0000 Subject: Add files via upload --- r_basicsr/archs/edvr_arch.py | 383 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 383 insertions(+) create mode 100644 r_basicsr/archs/edvr_arch.py (limited to 'r_basicsr/archs/edvr_arch.py') diff --git a/r_basicsr/archs/edvr_arch.py b/r_basicsr/archs/edvr_arch.py new file mode 100644 index 0000000..401b9b2 --- /dev/null +++ b/r_basicsr/archs/edvr_arch.py @@ -0,0 +1,383 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer + + +class PCDAlignment(nn.Module): + """Alignment module using Pyramid, Cascading and Deformable convolution + (PCD). It is used in EDVR. + + Ref: + EDVR: Video Restoration with Enhanced Deformable Convolutional Networks + + Args: + num_feat (int): Channel number of middle features. Default: 64. + deformable_groups (int): Deformable groups. Defaults: 8. + """ + + def __init__(self, num_feat=64, deformable_groups=8): + super(PCDAlignment, self).__init__() + + # Pyramid has three levels: + # L3: level 3, 1/4 spatial size + # L2: level 2, 1/2 spatial size + # L1: level 1, original spatial size + self.offset_conv1 = nn.ModuleDict() + self.offset_conv2 = nn.ModuleDict() + self.offset_conv3 = nn.ModuleDict() + self.dcn_pack = nn.ModuleDict() + self.feat_conv = nn.ModuleDict() + + # Pyramids + for i in range(3, 0, -1): + level = f'l{i}' + self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + if i == 3: + self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + else: + self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups) + + if i < 3: + self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + + # Cascading dcn + self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups) + + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, nbr_feat_l, ref_feat_l): + """Align neighboring frame features to the reference frame features. + + Args: + nbr_feat_l (list[Tensor]): Neighboring feature list. It + contains three pyramid levels (L1, L2, L3), + each with shape (b, c, h, w). + ref_feat_l (list[Tensor]): Reference feature list. It + contains three pyramid levels (L1, L2, L3), + each with shape (b, c, h, w). + + Returns: + Tensor: Aligned features. + """ + # Pyramids + upsampled_offset, upsampled_feat = None, None + for i in range(3, 0, -1): + level = f'l{i}' + offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1) + offset = self.lrelu(self.offset_conv1[level](offset)) + if i == 3: + offset = self.lrelu(self.offset_conv2[level](offset)) + else: + offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1))) + offset = self.lrelu(self.offset_conv3[level](offset)) + + feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset) + if i < 3: + feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1)) + if i > 1: + feat = self.lrelu(feat) + + if i > 1: # upsample offset and features + # x2: when we upsample the offset, we should also enlarge + # the magnitude. + upsampled_offset = self.upsample(offset) * 2 + upsampled_feat = self.upsample(feat) + + # Cascading + offset = torch.cat([feat, ref_feat_l[0]], dim=1) + offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset)))) + feat = self.lrelu(self.cas_dcnpack(feat, offset)) + return feat + + +class TSAFusion(nn.Module): + """Temporal Spatial Attention (TSA) fusion module. + + Temporal: Calculate the correlation between center frame and + neighboring frames; + Spatial: It has 3 pyramid levels, the attention is similar to SFT. + (SFT: Recovering realistic texture in image super-resolution by deep + spatial feature transform.) + + Args: + num_feat (int): Channel number of middle features. Default: 64. + num_frame (int): Number of frames. Default: 5. + center_frame_idx (int): The index of center frame. Default: 2. + """ + + def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2): + super(TSAFusion, self).__init__() + self.center_frame_idx = center_frame_idx + # temporal attention (before fusion conv) + self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1) + + # spatial attention (after fusion conv) + self.max_pool = nn.MaxPool2d(3, stride=2, padding=1) + self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1) + self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1) + self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1) + self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1) + self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1) + self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1) + self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + + def forward(self, aligned_feat): + """ + Args: + aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w). + + Returns: + Tensor: Features after TSA with the shape (b, c, h, w). + """ + b, t, c, h, w = aligned_feat.size() + # temporal attention + embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone()) + embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w)) + embedding = embedding.view(b, t, -1, h, w) # (b, t, c, h, w) + + corr_l = [] # correlation list + for i in range(t): + emb_neighbor = embedding[:, i, :, :, :] + corr = torch.sum(emb_neighbor * embedding_ref, 1) # (b, h, w) + corr_l.append(corr.unsqueeze(1)) # (b, 1, h, w) + corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1)) # (b, t, h, w) + corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w) + corr_prob = corr_prob.contiguous().view(b, -1, h, w) # (b, t*c, h, w) + aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob + + # fusion + feat = self.lrelu(self.feat_fusion(aligned_feat)) + + # spatial attention + attn = self.lrelu(self.spatial_attn1(aligned_feat)) + attn_max = self.max_pool(attn) + attn_avg = self.avg_pool(attn) + attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1))) + # pyramid levels + attn_level = self.lrelu(self.spatial_attn_l1(attn)) + attn_max = self.max_pool(attn_level) + attn_avg = self.avg_pool(attn_level) + attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1))) + attn_level = self.lrelu(self.spatial_attn_l3(attn_level)) + attn_level = self.upsample(attn_level) + + attn = self.lrelu(self.spatial_attn3(attn)) + attn_level + attn = self.lrelu(self.spatial_attn4(attn)) + attn = self.upsample(attn) + attn = self.spatial_attn5(attn) + attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn))) + attn = torch.sigmoid(attn) + + # after initialization, * 2 makes (attn * 2) to be close to 1. + feat = feat * attn * 2 + attn_add + return feat + + +class PredeblurModule(nn.Module): + """Pre-dublur module. + + Args: + num_in_ch (int): Channel number of input image. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + hr_in (bool): Whether the input has high resolution. Default: False. + """ + + def __init__(self, num_in_ch=3, num_feat=64, hr_in=False): + super(PredeblurModule, self).__init__() + self.hr_in = hr_in + + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + if self.hr_in: + # downsample x4 by stride conv + self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + + # generate feature pyramid + self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + + self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat) + self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat) + self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat) + self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)]) + + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, x): + feat_l1 = self.lrelu(self.conv_first(x)) + if self.hr_in: + feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1)) + feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1)) + + # generate feature pyramid + feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1)) + feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2)) + + feat_l3 = self.upsample(self.resblock_l3(feat_l3)) + feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3 + feat_l2 = self.upsample(self.resblock_l2_2(feat_l2)) + + for i in range(2): + feat_l1 = self.resblock_l1[i](feat_l1) + feat_l1 = feat_l1 + feat_l2 + for i in range(2, 5): + feat_l1 = self.resblock_l1[i](feat_l1) + return feat_l1 + + +@ARCH_REGISTRY.register() +class EDVR(nn.Module): + """EDVR network structure for video super-resolution. + + Now only support X4 upsampling factor. + Paper: + EDVR: Video Restoration with Enhanced Deformable Convolutional Networks + + Args: + num_in_ch (int): Channel number of input image. Default: 3. + num_out_ch (int): Channel number of output image. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_frame (int): Number of input frames. Default: 5. + deformable_groups (int): Deformable groups. Defaults: 8. + num_extract_block (int): Number of blocks for feature extraction. + Default: 5. + num_reconstruct_block (int): Number of blocks for reconstruction. + Default: 10. + center_frame_idx (int): The index of center frame. Frame counting from + 0. Default: Middle of input frames. + hr_in (bool): Whether the input has high resolution. Default: False. + with_predeblur (bool): Whether has predeblur module. + Default: False. + with_tsa (bool): Whether has TSA module. Default: True. + """ + + def __init__(self, + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_frame=5, + deformable_groups=8, + num_extract_block=5, + num_reconstruct_block=10, + center_frame_idx=None, + hr_in=False, + with_predeblur=False, + with_tsa=True): + super(EDVR, self).__init__() + if center_frame_idx is None: + self.center_frame_idx = num_frame // 2 + else: + self.center_frame_idx = center_frame_idx + self.hr_in = hr_in + self.with_predeblur = with_predeblur + self.with_tsa = with_tsa + + # extract features for each frame + if self.with_predeblur: + self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in) + self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1) + else: + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + + # extract pyramid features + self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat) + self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + + # pcd and tsa module + self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups) + if self.with_tsa: + self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx) + else: + self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1) + + # reconstruction + self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat) + # upsample + self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) + self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1) + self.pixel_shuffle = nn.PixelShuffle(2) + self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, x): + b, t, c, h, w = x.size() + if self.hr_in: + assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.') + else: + assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.') + + x_center = x[:, self.center_frame_idx, :, :, :].contiguous() + + # extract features for each frame + # L1 + if self.with_predeblur: + feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w))) + if self.hr_in: + h, w = h // 4, w // 4 + else: + feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w))) + + feat_l1 = self.feature_extraction(feat_l1) + # L2 + feat_l2 = self.lrelu(self.conv_l2_1(feat_l1)) + feat_l2 = self.lrelu(self.conv_l2_2(feat_l2)) + # L3 + feat_l3 = self.lrelu(self.conv_l3_1(feat_l2)) + feat_l3 = self.lrelu(self.conv_l3_2(feat_l3)) + + feat_l1 = feat_l1.view(b, t, -1, h, w) + feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2) + feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4) + + # PCD alignment + ref_feat_l = [ # reference feature list + feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(), + feat_l3[:, self.center_frame_idx, :, :, :].clone() + ] + aligned_feat = [] + for i in range(t): + nbr_feat_l = [ # neighboring feature list + feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone() + ] + aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l)) + aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w) + + if not self.with_tsa: + aligned_feat = aligned_feat.view(b, -1, h, w) + feat = self.fusion(aligned_feat) + + out = self.reconstruction(feat) + out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) + out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) + out = self.lrelu(self.conv_hr(out)) + out = self.conv_last(out) + if self.hr_in: + base = x_center + else: + base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False) + out += base + return out -- cgit v1.2.3