diff options
Diffstat (limited to 'r_basicsr/archs/srresnet_arch.py')
-rw-r--r-- | r_basicsr/archs/srresnet_arch.py | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/r_basicsr/archs/srresnet_arch.py b/r_basicsr/archs/srresnet_arch.py new file mode 100644 index 0000000..99b56a4 --- /dev/null +++ b/r_basicsr/archs/srresnet_arch.py @@ -0,0 +1,65 @@ +from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer
+
+
+@ARCH_REGISTRY.register()
+class MSRResNet(nn.Module):
+ """Modified SRResNet.
+
+ A compacted version modified from SRResNet in
+ "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
+ It uses residual blocks without BN, similar to EDSR.
+ Currently, it supports x2, x3 and x4 upsampling scale factor.
+
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_out_ch (int): Channel number of outputs. Default: 3.
+ num_feat (int): Channel number of intermediate features. Default: 64.
+ num_block (int): Block number in the body network. Default: 16.
+ upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
+ super(MSRResNet, self).__init__()
+ self.upscale = upscale
+
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)
+
+ # upsampling
+ if self.upscale in [2, 3]:
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
+ self.pixel_shuffle = nn.PixelShuffle(self.upscale)
+ elif self.upscale == 4:
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+ self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+ self.pixel_shuffle = nn.PixelShuffle(2)
+
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ # initialization
+ default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
+ if self.upscale == 4:
+ default_init_weights(self.upconv2, 0.1)
+
+ def forward(self, x):
+ feat = self.lrelu(self.conv_first(x))
+ out = self.body(feat)
+
+ if self.upscale == 4:
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+ elif self.upscale in [2, 3]:
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+
+ out = self.conv_last(self.lrelu(self.conv_hr(out)))
+ base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
+ out += base
+ return out
|