Diffstat (limited to 'r_basicsr/models/edvr_model.py')
 r_basicsr/models/edvr_model.py | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+), 0 deletions(-)
diff --git a/r_basicsr/models/edvr_model.py b/r_basicsr/models/edvr_model.py
new file mode 100644
index 0000000..1475033
--- /dev/null
+++ b/r_basicsr/models/edvr_model.py
@@ -0,0 +1,62 @@
+from r_basicsr.utils import get_root_logger
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .video_base_model import VideoBaseModel
+
+
+@MODEL_REGISTRY.register()
+class EDVRModel(VideoBaseModel):
+ """EDVR Model.
+
+ Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. # noqa: E501
+ """
+
+ def __init__(self, opt):
+ super(EDVRModel, self).__init__(opt)
+ if self.is_train:
+ self.train_tsa_iter = opt['train'].get('tsa_iter')
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
+ logger = get_root_logger()
+        logger.info(f'Multiply the learning rate for dcn by {dcn_lr_mul}.')
+ if dcn_lr_mul == 1:
+ optim_params = self.net_g.parameters()
+ else: # separate dcn params and normal params for different lr
+ normal_params = []
+ dcn_params = []
+ for name, param in self.net_g.named_parameters():
+ if 'dcn' in name:
+ dcn_params.append(param)
+ else:
+ normal_params.append(param)
+ optim_params = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ },
+ {
+ 'params': dcn_params,
+ 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul
+ },
+ ]
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ def optimize_parameters(self, current_iter):
+ if self.train_tsa_iter:
+ if current_iter == 1:
+ logger = get_root_logger()
+ logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
+ for name, param in self.net_g.named_parameters():
+ if 'fusion' not in name:
+ param.requires_grad = False
+ elif current_iter == self.train_tsa_iter:
+ logger = get_root_logger()
+ logger.warning('Train all the parameters.')
+ for param in self.net_g.parameters():
+ param.requires_grad = True
+
+ super(EDVRModel, self).optimize_parameters(current_iter)
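
For context, here is a minimal standalone sketch of the two training tricks this model adds on top of VideoBaseModel: a separate learning rate for DCN parameters, and a TSA warm-up phase that freezes everything except the fusion module. The ToyEDVR module, its layer names, and the lr values below are illustrative assumptions, not the real EDVR network; only the 'dcn'/'fusion' name matching mirrors the code above.

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for EDVR: only the 'dcn' / 'fusion'
    # parameter-name pattern matters for this sketch.
    class ToyEDVR(nn.Module):
        def __init__(self):
            super().__init__()
            self.dcn_offset = nn.Conv2d(3, 3, 3, padding=1)  # name contains 'dcn'
            self.fusion = nn.Conv2d(3, 3, 3, padding=1)      # TSA fusion stand-in
            self.recon = nn.Conv2d(3, 3, 3, padding=1)

    net = ToyEDVR()
    base_lr, dcn_lr_mul = 4e-4, 0.25  # assumed values; set via opt['train'] in practice

    # Split parameters by name, mirroring EDVRModel.setup_optimizers.
    dcn_params = [p for n, p in net.named_parameters() if 'dcn' in n]
    normal_params = [p for n, p in net.named_parameters() if 'dcn' not in n]
    optimizer = torch.optim.Adam([
        {'params': normal_params, 'lr': base_lr},
        {'params': dcn_params, 'lr': base_lr * dcn_lr_mul},
    ])

    # TSA warm-up, mirroring EDVRModel.optimize_parameters:
    # train only 'fusion' parameters for the first tsa_iter iterations...
    for name, param in net.named_parameters():
        param.requires_grad = 'fusion' in name
    # ...then unfreeze everything once current_iter reaches tsa_iter.
    for param in net.parameters():
        param.requires_grad = True

In practice these knobs arrive through the options dict: train.dcn_lr_mul scales the learning rate of the DCN parameter group (default 1, in which case no split is made), and train.tsa_iter sets how many initial iterations train only the TSA fusion module before all parameters are unfrozen.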