diff options
Diffstat (limited to 'r_basicsr/archs/dfdnet_arch.py')
-rw-r--r-- | r_basicsr/archs/dfdnet_arch.py | 169 |
1 files changed, 169 insertions, 0 deletions
diff --git a/r_basicsr/archs/dfdnet_arch.py b/r_basicsr/archs/dfdnet_arch.py new file mode 100644 index 0000000..04093f0 --- /dev/null +++ b/r_basicsr/archs/dfdnet_arch.py @@ -0,0 +1,169 @@ +import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.spectral_norm import spectral_norm
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
+from .vgg_arch import VGGFeatureExtractor
+
+
+class SFTUpBlock(nn.Module):
+ """Spatial feature transform (SFT) with upsampling block.
+
+ Args:
+ in_channel (int): Number of input channels.
+ out_channel (int): Number of output channels.
+ kernel_size (int): Kernel size in convolutions. Default: 3.
+ padding (int): Padding in convolutions. Default: 1.
+ """
+
+ def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
+ super(SFTUpBlock, self).__init__()
+ self.conv1 = nn.Sequential(
+ Blur(in_channel),
+ spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+ nn.LeakyReLU(0.04, True),
+ # The official codes use two LeakyReLU here, so 0.04 for equivalent
+ )
+ self.convup = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+ spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+ nn.LeakyReLU(0.2, True),
+ )
+
+ # for SFT scale and shift
+ self.scale_block = nn.Sequential(
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
+ self.shift_block = nn.Sequential(
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
+ # The official codes use sigmoid for shift block, do not know why
+
+ def forward(self, x, updated_feat):
+ out = self.conv1(x)
+ # SFT
+ scale = self.scale_block(updated_feat)
+ shift = self.shift_block(updated_feat)
+ out = out * scale + shift
+ # upsample
+ out = self.convup(out)
+ return out
+
+
+@ARCH_REGISTRY.register()
+class DFDNet(nn.Module):
+ """DFDNet: Deep Face Dictionary Network.
+
+ It only processes faces with 512x512 size.
+
+ Args:
+ num_feat (int): Number of feature channels.
+ dict_path (str): Path to the facial component dictionary.
+ """
+
+ def __init__(self, num_feat, dict_path):
+ super().__init__()
+ self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
+ # part_sizes: [80, 80, 50, 110]
+ channel_sizes = [128, 256, 512, 512]
+ self.feature_sizes = np.array([256, 128, 64, 32])
+ self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
+ self.flag_dict_device = False
+
+ # dict
+ self.dict = torch.load(dict_path)
+
+ # vgg face extractor
+ self.vgg_extractor = VGGFeatureExtractor(
+ layer_name_list=self.vgg_layers,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=True,
+ requires_grad=False)
+
+ # attention block for fusing dictionary features and input features
+ self.attn_blocks = nn.ModuleDict()
+ for idx, feat_size in enumerate(self.feature_sizes):
+ for name in self.parts:
+ self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])
+
+ # multi scale dilation block
+ self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])
+
+ # upsampling and reconstruction
+ self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
+ self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
+ self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
+ self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
+ self.upsample4 = nn.Sequential(
+ spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
+ UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
+
+ def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
+ """swap the features from the dictionary."""
+ # get the original vgg features
+ part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
+ # resize original vgg features
+ part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
+ # use adaptive instance normalization to adjust color and illuminations
+ dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
+ # get similarity scores
+ similarity_score = F.conv2d(part_resize_feat, dict_feat)
+ similarity_score = F.softmax(similarity_score.view(-1), dim=0)
+ # select the most similar features in the dict (after norm)
+ select_idx = torch.argmax(similarity_score)
+ swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
+ # attention
+ attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
+ attn_feat = attn * swap_feat
+ # update features
+ updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
+ return updated_feat
+
+ def put_dict_to_device(self, x):
+ if self.flag_dict_device is False:
+ for k, v in self.dict.items():
+ for kk, vv in v.items():
+ self.dict[k][kk] = vv.to(x)
+ self.flag_dict_device = True
+
+ def forward(self, x, part_locations):
+ """
+ Now only support testing with batch size = 0.
+
+ Args:
+ x (Tensor): Input faces with shape (b, c, 512, 512).
+ part_locations (list[Tensor]): Part locations.
+ """
+ self.put_dict_to_device(x)
+ # extract vggface features
+ vgg_features = self.vgg_extractor(x)
+ # update vggface features using the dictionary for each part
+ updated_vgg_features = []
+ batch = 0 # only supports testing with batch size = 0
+ for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
+ dict_features = self.dict[f'{f_size}']
+ vgg_feat = vgg_features[vgg_layer]
+ updated_feat = vgg_feat.clone()
+
+ # swap features from dictionary
+ for part_idx, part_name in enumerate(self.parts):
+ location = (part_locations[part_idx][batch] // (512 / f_size)).int()
+ updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
+ f_size)
+
+ updated_vgg_features.append(updated_feat)
+
+ vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
+ # use updated vgg features to modulate the upsampled features with
+ # SFT (Spatial Feature Transform) scaling and shifting manner.
+ upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
+ upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
+ upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
+ upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
+ out = self.upsample4(upsampled_feat)
+
+ return out
|