From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001
From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com>
Date: Fri, 17 Jan 2025 11:00:30 +0000
Subject: Add files via upload

---
 r_basicsr/models/edvr_model.py | 62 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 r_basicsr/models/edvr_model.py

diff --git a/r_basicsr/models/edvr_model.py b/r_basicsr/models/edvr_model.py
new file mode 100644
index 0000000..1475033
--- /dev/null
+++ b/r_basicsr/models/edvr_model.py
@@ -0,0 +1,62 @@
+from r_basicsr.utils import get_root_logger
+from r_basicsr.utils.registry import MODEL_REGISTRY
+from .video_base_model import VideoBaseModel
+
+
+@MODEL_REGISTRY.register()
+class EDVRModel(VideoBaseModel):
+    """EDVR Model.
+
+    Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.  # noqa: E501
+    """
+
+    def __init__(self, opt):
+        super(EDVRModel, self).__init__(opt)
+        if self.is_train:
+            self.train_tsa_iter = opt['train'].get('tsa_iter')
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
+        logger = get_root_logger()
+        logger.info(f'Multiply the learning rate for dcn with {dcn_lr_mul}.')
+        if dcn_lr_mul == 1:
+            optim_params = self.net_g.parameters()
+        else:  # separate dcn params and normal params for different lr
+            normal_params = []
+            dcn_params = []
+            for name, param in self.net_g.named_parameters():
+                if 'dcn' in name:
+                    dcn_params.append(param)
+                else:
+                    normal_params.append(param)
+            optim_params = [
+                {  # add normal params first
+                    'params': normal_params,
+                    'lr': train_opt['optim_g']['lr']
+                },
+                {
+                    'params': dcn_params,
+                    'lr': train_opt['optim_g']['lr'] * dcn_lr_mul
+                },
+            ]
+
+        optim_type = train_opt['optim_g'].pop('type')
+        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+        self.optimizers.append(self.optimizer_g)
+
+    def optimize_parameters(self, current_iter):
+        if self.train_tsa_iter:
+            if current_iter == 1:
+                logger = get_root_logger()
+                logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
+                for name, param in self.net_g.named_parameters():
+                    if 'fusion' not in name:
+                        param.requires_grad = False
+            elif current_iter == self.train_tsa_iter:
+                logger = get_root_logger()
+                logger.warning('Train all the parameters.')
+                for param in self.net_g.parameters():
+                    param.requires_grad = True
+
+        super(EDVRModel, self).optimize_parameters(current_iter)
--
cgit v1.2.3
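
Note (not part of the patch): a minimal sketch of the options dictionary this model reads. Only the key paths ('is_train', 'train', 'tsa_iter', 'dcn_lr_mul', 'optim_g', 'type', 'lr') come from the code above; the values are illustrative assumptions in the style of typical BasicSR EDVR configs, not taken from this repository.

    # Illustrative options dict; values are assumptions, key paths mirror
    # the accesses in EDVRModel.__init__ and setup_optimizers.
    opt = {
        'is_train': True,
        'train': {
            'tsa_iter': 50000,   # TSA-only warm-up length; absent/None disables the phase
            'dcn_lr_mul': 1,     # lr multiplier for DCN params; 1 keeps a single param group
            'optim_g': {
                'type': 'Adam',
                'lr': 4e-4,
                'betas': [0.9, 0.99],
            },
        },
    }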
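
The two-group optimizer built in setup_optimizers uses standard PyTorch per-parameter-group learning rates. A self-contained sketch of the same technique, assuming a hypothetical stand-in for net_g whose DCN-related parameter names contain 'dcn' (as EDVR's PCD alignment module's do):

    import torch
    from torch import nn

    # Hypothetical stand-in network; any module works as long as the
    # DCN-related parameter names contain 'dcn'.
    net = nn.ModuleDict({'dcn': nn.Conv2d(3, 3, 3), 'recon': nn.Conv2d(3, 3, 3)})

    base_lr, dcn_lr_mul = 4e-4, 0.1  # illustrative values
    dcn_params = [p for n, p in net.named_parameters() if 'dcn' in n]
    normal_params = [p for n, p in net.named_parameters() if 'dcn' not in n]

    # Normal params first, then DCN params at a scaled learning rate,
    # matching the group ordering in setup_optimizers.
    optimizer = torch.optim.Adam([
        {'params': normal_params, 'lr': base_lr},
        {'params': dcn_params, 'lr': base_lr * dcn_lr_mul},
    ])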
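
The staged training in optimize_parameters is a freeze/unfreeze schedule keyed on the iteration counter. A distilled sketch of that pattern, again with a hypothetical stand-in where only the 'fusion' submodule is the TSA part:

    from torch import nn

    # Hypothetical stand-in: 'fusion' plays the role of EDVR's TSA module.
    net = nn.ModuleDict({'fusion': nn.Conv2d(3, 3, 3), 'recon': nn.Conv2d(3, 3, 3)})
    tsa_iter = 50000  # illustrative

    def set_trainable(current_iter):
        # Iteration 1: freeze everything outside the TSA fusion module.
        if current_iter == 1:
            for name, param in net.named_parameters():
                param.requires_grad = 'fusion' in name
        # At tsa_iter: unfreeze the whole network for joint training.
        elif current_iter == tsa_iter:
            for param in net.parameters():
                param.requires_grad = True

Frozen parameters can stay in the optimizer's param groups during the TSA-only phase: PyTorch optimizers skip any parameter whose .grad is None, so no regrouping is needed when the rest of the network is unfrozen.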