From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:00:30 +0000 Subject: Add files via upload --- r_basicsr/archs/tof_arch.py | 172 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) create mode 100644 r_basicsr/archs/tof_arch.py (limited to 'r_basicsr/archs/tof_arch.py') diff --git a/r_basicsr/archs/tof_arch.py b/r_basicsr/archs/tof_arch.py new file mode 100644 index 0000000..6b0fefa --- /dev/null +++ b/r_basicsr/archs/tof_arch.py @@ -0,0 +1,172 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from r_basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import flow_warp + + +class BasicModule(nn.Module): + """Basic module of SPyNet. + + Note that unlike the architecture in spynet_arch.py, the basic module + here contains batch normalization. + """ + + def __init__(self): + super(BasicModule, self).__init__() + self.basic_module = nn.Sequential( + nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(32), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(64), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(32), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(16), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) + + def forward(self, tensor_input): + """ + Args: + tensor_input (Tensor): Input tensor with shape (b, 8, h, w). + 8 channels contain: + [reference image (3), neighbor image (3), initial flow (2)]. + + Returns: + Tensor: Estimated flow with shape (b, 2, h, w) + """ + return self.basic_module(tensor_input) + + +class SPyNetTOF(nn.Module): + """SPyNet architecture for TOF. + + Note that this implementation is specifically for TOFlow. Please use + spynet_arch.py for general use. They differ in the following aspects: + 1. The basic modules here contain BatchNorm. + 2. Normalization and denormalization are not done here, as + they are done in TOFlow. + Paper: + Optical Flow Estimation using a Spatial Pyramid Network + Code reference: + https://github.com/Coldog2333/pytoflow + + Args: + load_path (str): Path for pretrained SPyNet. Default: None. + """ + + def __init__(self, load_path=None): + super(SPyNetTOF, self).__init__() + + self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)]) + if load_path: + self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) + + def forward(self, ref, supp): + """ + Args: + ref (Tensor): Reference image with shape of (b, 3, h, w). + supp: The supporting image to be warped: (b, 3, h, w). + + Returns: + Tensor: Estimated optical flow: (b, 2, h, w). + """ + num_batches, _, h, w = ref.size() + ref = [ref] + supp = [supp] + + # generate downsampled frames + for _ in range(3): + ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) + supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) + + # flow computation + flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16) + for i in range(4): + flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + flow = flow_up + self.basic_module[i]( + torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1)) + return flow + + +@ARCH_REGISTRY.register() +class TOFlow(nn.Module): + """PyTorch implementation of TOFlow. + + In TOFlow, the LR frames are pre-upsampled and have the same size with + the GT frames. + Paper: + Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018 + Code reference: + 1. https://github.com/anchen1011/toflow + 2. https://github.com/Coldog2333/pytoflow + + Args: + adapt_official_weights (bool): Whether to adapt the weights translated + from the official implementation. Set to false if you want to + train from scratch. Default: False + """ + + def __init__(self, adapt_official_weights=False): + super(TOFlow, self).__init__() + self.adapt_official_weights = adapt_official_weights + self.ref_idx = 0 if adapt_official_weights else 3 + + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + # flow estimation module + self.spynet = SPyNetTOF() + + # reconstruction module + self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4) + self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4) + self.conv_3 = nn.Conv2d(64, 64, 1) + self.conv_4 = nn.Conv2d(64, 3, 1) + + # activation function + self.relu = nn.ReLU(inplace=True) + + def normalize(self, img): + return (img - self.mean) / self.std + + def denormalize(self, img): + return img * self.std + self.mean + + def forward(self, lrs): + """ + Args: + lrs: Input lr frames: (b, 7, 3, h, w). + + Returns: + Tensor: SR frame: (b, 3, h, w). + """ + # In the official implementation, the 0-th frame is the reference frame + if self.adapt_official_weights: + lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :] + + num_batches, num_lrs, _, h, w = lrs.size() + + lrs = self.normalize(lrs.view(-1, 3, h, w)) + lrs = lrs.view(num_batches, num_lrs, 3, h, w) + + lr_ref = lrs[:, self.ref_idx, :, :, :] + lr_aligned = [] + for i in range(7): # 7 frames + if i == self.ref_idx: + lr_aligned.append(lr_ref) + else: + lr_supp = lrs[:, i, :, :, :] + flow = self.spynet(lr_ref, lr_supp) + lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1))) + + # reconstruction + hr = torch.stack(lr_aligned, dim=1) + hr = hr.view(num_batches, -1, h, w) + hr = self.relu(self.conv_1(hr)) + hr = self.relu(self.conv_2(hr)) + hr = self.relu(self.conv_3(hr)) + hr = self.conv_4(hr) + lr_ref + + return self.denormalize(hr) -- cgit v1.2.3