Diffstat (limited to 'r_basicsr/archs/hifacegan_util.py')
-rw-r--r-- | r_basicsr/archs/hifacegan_util.py | 255
1 file changed, 255 insertions, 0 deletions
diff --git a/r_basicsr/archs/hifacegan_util.py b/r_basicsr/archs/hifacegan_util.py
new file mode 100644
index 0000000..b63b928
--- /dev/null
+++ b/r_basicsr/archs/hifacegan_util.py
@@ -0,0 +1,255 @@
+import re
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+# Warning: spectral norm could be buggy
+# under eval mode and multi-GPU inference
+# A workaround is sticking to single-GPU inference and train mode
+from torch.nn.utils import spectral_norm
+
+
+class SPADE(nn.Module):
+
+ def __init__(self, config_text, norm_nc, label_nc):
+ super().__init__()
+
+ assert config_text.startswith('spade')
+ parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
+ param_free_norm_type = str(parsed.group(1))
+ ks = int(parsed.group(2))
+
+ if param_free_norm_type == 'instance':
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc)
+ elif param_free_norm_type == 'syncbatch':
+            print('SyncBatchNorm is currently not supported under single-GPU mode; switching to "instance" instead')
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc)
+ elif param_free_norm_type == 'batch':
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
+ else:
+ raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')
+
+        # The dimension of the intermediate embedding space: norm_nc, capped at 128.
+ nhidden = 128 if norm_nc > 128 else norm_nc
+
+ pw = ks // 2
+ self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
+
+ def forward(self, x, segmap):
+
+ # Part 1. generate parameter-free normalized activations
+ normalized = self.param_free_norm(x)
+
+ # Part 2. produce scaling and bias conditioned on semantic map
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
+ actv = self.mlp_shared(segmap)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+
+ # apply scale and bias
+ out = normalized * gamma + beta
+
+ return out
+
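+# Illustrative usage sketch for SPADE (not part of the original file; the
+# tensor sizes below are assumptions chosen for demonstration):
+#     spade = SPADE('spadeinstance3x3', norm_nc=64, label_nc=3)
+#     x = torch.randn(1, 64, 32, 32)       # feature map to be denormalized
+#     seg = torch.randn(1, 3, 256, 256)    # guidance map, resized to x's size internally
+#     out = spade(x, seg)                  # same shape as x: (1, 64, 32, 32)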
+
+class SPADEResnetBlock(nn.Module):
+ """
+ ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
+ it takes in the segmentation map as input, learns the skip connection if necessary,
+ and applies normalization first and then convolution.
+    This follows a fairly standard design for unconditional or
+    class-conditional GANs built from residual blocks.
+    The code was inspired by https://github.com/LMescheder/GAN_stability.
+ """
+
+ def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
+ super().__init__()
+ # Attributes
+ self.learned_shortcut = (fin != fout)
+ fmiddle = min(fin, fout)
+
+ # create conv layers
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
+ if self.learned_shortcut:
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+
+ # apply spectral norm if specified
+ if 'spectral' in norm_g:
+ self.conv_0 = spectral_norm(self.conv_0)
+ self.conv_1 = spectral_norm(self.conv_1)
+ if self.learned_shortcut:
+ self.conv_s = spectral_norm(self.conv_s)
+
+ # define normalization layers
+ spade_config_str = norm_g.replace('spectral', '')
+ self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
+ self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
+ if self.learned_shortcut:
+ self.norm_s = SPADE(spade_config_str, fin, semantic_nc)
+
+ # note the resnet block with SPADE also takes in |seg|,
+ # the semantic segmentation map as input
+ def forward(self, x, seg):
+ x_s = self.shortcut(x, seg)
+ dx = self.conv_0(self.act(self.norm_0(x, seg)))
+ dx = self.conv_1(self.act(self.norm_1(dx, seg)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, seg):
+ if self.learned_shortcut:
+ x_s = self.conv_s(self.norm_s(x, seg))
+ else:
+ x_s = x
+ return x_s
+
+ def act(self, x):
+ return F.leaky_relu(x, 2e-1)
+
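+# Illustrative usage sketch for SPADEResnetBlock (not part of the original
+# file; the sizes and the 'spadeinstance3x3' config are assumptions):
+#     block = SPADEResnetBlock(64, 64, norm_g='spadeinstance3x3', semantic_nc=3)
+#     x = torch.randn(1, 64, 32, 32)
+#     seg = torch.randn(1, 3, 256, 256)    # passed to every SPADE layer inside the block
+#     out = block(x, seg)                  # residual output, shape (1, 64, 32, 32)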
+
+class BaseNetwork(nn.Module):
+ """ A basis for hifacegan archs with custom initialization """
+
+ def init_weights(self, init_type='normal', gain=0.02):
+
+ def init_func(m):
+ classname = m.__class__.__name__
+ if classname.find('BatchNorm2d') != -1:
+ if hasattr(m, 'weight') and m.weight is not None:
+ init.normal_(m.weight.data, 1.0, gain)
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == 'xavier_uniform':
+ init.xavier_uniform_(m.weight.data, gain=1.0)
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain=gain)
+ elif init_type == 'none': # uses pytorch's default init method
+ m.reset_parameters()
+ else:
+ raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+
+ self.apply(init_func)
+
+ # propagate to children
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights(init_type, gain)
+
+    def forward(self, x):
+        # intentionally a no-op; subclasses define their own forward pass
+        pass
+
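+# Illustrative usage sketch for BaseNetwork.init_weights (not part of the
+# original file; TinyNet is a hypothetical subclass used only for demonstration):
+#     class TinyNet(BaseNetwork):
+#         def __init__(self):
+#             super().__init__()
+#             self.conv = nn.Conv2d(3, 8, 3, padding=1)
+#     net = TinyNet()
+#     net.init_weights(init_type='xavier', gain=0.02)   # re-initializes conv weight and bias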
+
+def lip2d(x, logit, kernel=3, stride=2, padding=1):
+ weight = logit.exp()
+ return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
+
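+# lip2d is a weighted average pooling: with w = exp(logit),
+#     out = avg_pool(x * w) / avg_pool(w)
+# which equals sum(x * w) / sum(w) over each window, so locations with larger
+# logits dominate their pooling window.
+# Shape sketch (sizes are illustrative, not from the original file):
+#     x = torch.randn(1, 8, 32, 32)
+#     logit = torch.randn(1, 8, 32, 32)
+#     y = lip2d(x, logit)                  # default kernel=3, stride=2 -> (1, 8, 16, 16)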
+
+class SoftGate(nn.Module):
+ COEFF = 12.0
+
+ def forward(self, x):
+ return torch.sigmoid(x).mul(self.COEFF)
+
+
+class SimplifiedLIP(nn.Module):
+
+ def __init__(self, channels):
+ super(SimplifiedLIP, self).__init__()
+ self.logit = nn.Sequential(
+ nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
+ SoftGate())
+
+ def init_layer(self):
+ self.logit[0].weight.data.fill_(0.0)
+
+ def forward(self, x):
+ frac = lip2d(x, self.logit(x))
+ return frac
+
+
+class LIPEncoder(BaseNetwork):
+ """Local Importance-based Pooling (Ziteng Gao et.al.,ICCV 2019)"""
+
+ def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
+ super().__init__()
+ self.sw = sw
+ self.sh = sh
+ self.max_ratio = 16
+ # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
+ kw = 3
+ pw = (kw - 1) // 2
+
+ model = [
+ nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
+ norm_layer(ngf),
+ nn.ReLU(),
+ ]
+ cur_ratio = 1
+ for i in range(n_2xdown):
+ next_ratio = min(cur_ratio * 2, self.max_ratio)
+ model += [
+ SimplifiedLIP(ngf * cur_ratio),
+ nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
+ norm_layer(ngf * next_ratio),
+ ]
+ cur_ratio = next_ratio
+ if i < n_2xdown - 1:
+ model += [nn.ReLU(inplace=True)]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ return self.model(x)
+
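+# Illustrative usage sketch for LIPEncoder (not part of the original file; the
+# argument values are assumptions). Each SimplifiedLIP stage halves the spatial
+# size, while the stride-1 convolutions only change the channel count:
+#     enc = LIPEncoder(input_nc=3, ngf=16, sw=16, sh=16, n_2xdown=2)
+#     x = torch.randn(1, 3, 64, 64)
+#     y = enc(x)                           # (1, 64, 16, 16): channels 16*4, spatial /4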
+
+def get_nonspade_norm_layer(norm_type='instance'):
+    # helper function to get the number of output channels of the previous layer
+ def get_out_channel(layer):
+ if hasattr(layer, 'out_channels'):
+ return getattr(layer, 'out_channels')
+ return layer.weight.size(0)
+
+ # this function will be returned
+ def add_norm_layer(layer):
+        nonlocal norm_type
+        # default to the full norm_type so that subnorm_type is always defined,
+        # even when there is no 'spectral' prefix (avoids a NameError below)
+        subnorm_type = norm_type
+        if norm_type.startswith('spectral'):
+            layer = spectral_norm(layer)
+            subnorm_type = norm_type[len('spectral'):]
+
+ if subnorm_type == 'none' or len(subnorm_type) == 0:
+ return layer
+
+ # remove bias in the previous layer, which is meaningless
+ # since it has no effect after normalization
+ if getattr(layer, 'bias', None) is not None:
+ delattr(layer, 'bias')
+ layer.register_parameter('bias', None)
+
+ if subnorm_type == 'batch':
+ norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
+ elif subnorm_type == 'sync_batch':
+            print('SyncBatchNorm is currently not supported under single-GPU mode; switching to "instance" instead')
+ # norm_layer = SynchronizedBatchNorm2d(
+ # get_out_channel(layer), affine=True)
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
+ elif subnorm_type == 'instance':
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
+ else:
+ raise ValueError(f'normalization layer {subnorm_type} is not recognized')
+
+ return nn.Sequential(layer, norm_layer)
+
+ print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
+ return add_norm_layer
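+# Illustrative usage sketch for get_nonspade_norm_layer (not part of the
+# original file): the returned closure wraps a conv layer, strips its now
+# redundant bias, and appends the requested normalization.
+#     norm_layer = get_nonspade_norm_layer('spectralinstance')
+#     block = norm_layer(nn.Conv2d(3, 64, 3, padding=1))
+#     # -> nn.Sequential(spectral-normed conv without bias, nn.InstanceNorm2d(64))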