summaryrefslogtreecommitdiffstats
path: root/r_basicsr/archs/spynet_arch.py
diff options
context:
space:
mode:
Diffstat (limited to 'r_basicsr/archs/spynet_arch.py')
-rw-r--r--r_basicsr/archs/spynet_arch.py96
1 files changed, 96 insertions, 0 deletions
diff --git a/r_basicsr/archs/spynet_arch.py b/r_basicsr/archs/spynet_arch.py
new file mode 100644
index 0000000..2bd143c
--- /dev/null
+++ b/r_basicsr/archs/spynet_arch.py
@@ -0,0 +1,96 @@
+import math
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import flow_warp
+
+
+class BasicModule(nn.Module):
+ """Basic Module for SpyNet.
+ """
+
+ def __init__(self):
+ super(BasicModule, self).__init__()
+
+ self.basic_module = nn.Sequential(
+ nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
+
+ def forward(self, tensor_input):
+ return self.basic_module(tensor_input)
+
+
+@ARCH_REGISTRY.register()
+class SpyNet(nn.Module):
+ """SpyNet architecture.
+
+ Args:
+ load_path (str): path for pretrained SpyNet. Default: None.
+ """
+
+ def __init__(self, load_path=None):
+ super(SpyNet, self).__init__()
+ self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
+ if load_path:
+ self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def preprocess(self, tensor_input):
+ tensor_output = (tensor_input - self.mean) / self.std
+ return tensor_output
+
+ def process(self, ref, supp):
+ flow = []
+
+ ref = [self.preprocess(ref)]
+ supp = [self.preprocess(supp)]
+
+ for level in range(5):
+ ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
+ supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
+
+ flow = ref[0].new_zeros(
+ [ref[0].size(0), 2,
+ int(math.floor(ref[0].size(2) / 2.0)),
+ int(math.floor(ref[0].size(3) / 2.0))])
+
+ for level in range(len(ref)):
+ upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
+
+ if upsampled_flow.size(2) != ref[level].size(2):
+ upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
+ if upsampled_flow.size(3) != ref[level].size(3):
+ upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')
+
+ flow = self.basic_module[level](torch.cat([
+ ref[level],
+ flow_warp(
+ supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
+ upsampled_flow
+ ], 1)) + upsampled_flow
+
+ return flow
+
+ def forward(self, ref, supp):
+ assert ref.size() == supp.size()
+
+ h, w = ref.size(2), ref.size(3)
+ w_floor = math.floor(math.ceil(w / 32.0) * 32.0)
+ h_floor = math.floor(math.ceil(h / 32.0) * 32.0)
+
+ ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
+ supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
+
+ flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False)
+
+ flow[:, 0, :, :] *= float(w) / float(w_floor)
+ flow[:, 1, :, :] *= float(h) / float(h_floor)
+
+ return flow