From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:00:30 +0000 Subject: Add files via upload --- r_basicsr/models/base_model.py | 380 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 380 insertions(+) create mode 100644 r_basicsr/models/base_model.py (limited to 'r_basicsr/models/base_model.py') diff --git a/r_basicsr/models/base_model.py b/r_basicsr/models/base_model.py new file mode 100644 index 0000000..bd2faad --- /dev/null +++ b/r_basicsr/models/base_model.py @@ -0,0 +1,380 @@ +import os +import time +import torch +from collections import OrderedDict +from copy import deepcopy +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from r_basicsr.models import lr_scheduler as lr_scheduler +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.dist_util import master_only + + +class BaseModel(): + """Base model.""" + + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def save(self, epoch, current_iter): + """Save networks and training state.""" + pass + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """Validation function. + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + """ + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def _initialize_best_metric_results(self, dataset_name): + """Initialize the best metric results dict for recording the best metric value and iteration.""" + if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results: + return + elif not hasattr(self, 'best_metric_results'): + self.best_metric_results = dict() + + # add a dataset record + record = dict() + for metric, content in self.opt['val']['metrics'].items(): + better = content.get('better', 'higher') + init_val = float('-inf') if better == 'higher' else float('inf') + record[metric] = dict(better=better, val=init_val, iter=-1) + self.best_metric_results[dataset_name] = record + + def _update_best_metric_result(self, dataset_name, metric, val, current_iter): + if self.best_metric_results[dataset_name][metric]['better'] == 'higher': + if val >= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + else: + if val <= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.net_g) + + net_g_params = dict(net_g.named_parameters()) + net_g_ema_params = dict(self.net_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. + + Args: + net (nn.Module) + """ + net = net.to(self.device) + if self.opt['dist']: + find_unused_parameters = self.opt.get('find_unused_parameters', False) + net = DistributedDataParallel( + net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) + elif self.opt['num_gpu'] > 1: + net = DataParallel(net) + return net + + def get_optimizer(self, optim_type, params, lr, **kwargs): + if optim_type == 'Adam': + optimizer = torch.optim.Adam(params, lr, **kwargs) + else: + raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.') + return optimizer + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler'])) + else: + raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}' + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger = get_root_logger() + logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warm-up. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. + """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. + + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warm-up iter numbers. -1 for no warm-up. + Default: -1. + """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. + + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(save_dict, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with different name or different size when loading models. + + 1. Print keys with different names. + 2. If strict=False, print the same key but with different tensor size. + It also ignore these keys with different sizes (not load). + + Args: + crt_net (torch model): Current network. + load_net (dict): Loaded network. + strict (bool): Whether strictly loaded. Default: True. + """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + logger = get_root_logger() + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning(f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + logger = get_root_logger() + net = self.get_bare_model(net) + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], save_filename) + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(state, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. + """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict -- cgit v1.2.3