Diffstat (limited to 'r_basicsr/utils/options.py')
-rw-r--r--  r_basicsr/utils/options.py  194
1 file changed, 194 insertions(+), 0 deletions(-)
diff --git a/r_basicsr/utils/options.py b/r_basicsr/utils/options.py
new file mode 100644
index 0000000..cf90df4
--- /dev/null
+++ b/r_basicsr/utils/options.py
@@ -0,0 +1,194 @@
+import argparse
+import random
+import torch
+import yaml
+from collections import OrderedDict
+from os import path as osp
+
+from r_basicsr.utils import set_random_seed
+from r_basicsr.utils.dist_util import get_dist_info, init_dist, master_only
+
+
+def ordered_yaml():
+ """Support OrderedDict for yaml.
+
+ Returns:
+ yaml Loader and Dumper.
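+
+    Example (illustrative):
+        Loader, Dumper = ordered_yaml()
+        with open('train.yml', 'r') as f:
+            opt = yaml.load(f, Loader=Loader)  # keys keep their file order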
+ """
+    try:
+        from yaml import CDumper as Dumper
+        from yaml import CLoader as Loader
+    except ImportError:
+        from yaml import Dumper, Loader
+
+    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+    def dict_representer(dumper, data):
+        return dumper.represent_dict(data.items())
+
+    def dict_constructor(loader, node):
+        return OrderedDict(loader.construct_pairs(node))
+
+    Dumper.add_representer(OrderedDict, dict_representer)
+    Loader.add_constructor(_mapping_tag, dict_constructor)
+    return Loader, Dumper
+
+
+def dict2str(opt, indent_level=1):
+ """dict to string for printing options.
+
+ Args:
+ opt (dict): Option dict.
+ indent_level (int): Indent level. Default: 1.
+
+ Return:
+ (str): Option string for printing.
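+
+    Example (illustrative):
+        dict2str({'a': 1, 'b': {'c': 2}}) renders as:
+
+          a: 1
+          b:[
+            c: 2
+          ]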
+ """
+    msg = '\n'
+    for k, v in opt.items():
+        if isinstance(v, dict):
+            msg += ' ' * (indent_level * 2) + k + ':['
+            msg += dict2str(v, indent_level + 1)
+            msg += ' ' * (indent_level * 2) + ']\n'
+        else:
+            msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
+    return msg
+
+
+def _postprocess_yml_value(value):
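+    """Convert a string value from the command line into a Python object.
+
+    Illustrative conversions (as implemented below):
+        '~' / 'none'   -> None
+        'true'         -> True
+        '!!float 1e-4' -> 0.0001
+        '3'            -> 3
+        '0.5'          -> 0.5
+        '[1, 2]'       -> [1, 2]  (via eval)
+    """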
+    # None
+    if value == '~' or value.lower() == 'none':
+        return None
+    # bool
+    if value.lower() == 'true':
+        return True
+    elif value.lower() == 'false':
+        return False
+    # !!float number
+    if value.startswith('!!float'):
+        return float(value.replace('!!float', ''))
+    # number
+    if value.isdigit():
+        return int(value)
+    elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
+        return float(value)
+    # list
+    if value.startswith('['):
+        return eval(value)
+    # str
+    return value
+
+
+def parse_options(root_path, is_train=True):
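+    """Parse command-line arguments and the option YAML file.
+
+    Args:
+        root_path (str): Root path of the project; experiments/results
+            directories are created under it.
+        is_train (bool): Training or test mode. Default: True.
+
+    Usage sketch (illustrative):
+        opt, args = parse_options(root_path, is_train=True)
+    """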
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
+    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
+    parser.add_argument('--auto_resume', action='store_true')
+    parser.add_argument('--debug', action='store_true')
+    parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--force_yml', nargs='+', default=None, help='Force-update options in the YAML file. Example: train:ema_decay=0.999')
+    args = parser.parse_args()
+
+    # parse yml to dict
+    with open(args.opt, mode='r') as f:
+        opt = yaml.load(f, Loader=ordered_yaml()[0])
+
+    # distributed settings
+    if args.launcher == 'none':
+        opt['dist'] = False
+        print('Disable distributed.', flush=True)
+    else:
+        opt['dist'] = True
+        if args.launcher == 'slurm' and 'dist_params' in opt:
+            init_dist(args.launcher, **opt['dist_params'])
+        else:
+            init_dist(args.launcher)
+    opt['rank'], opt['world_size'] = get_dist_info()
+
+    # random seed
+    seed = opt.get('manual_seed')
+    if seed is None:
+        seed = random.randint(1, 10000)
+        opt['manual_seed'] = seed
+    set_random_seed(seed + opt['rank'])
+
+    # force to update yml options
+    if args.force_yml is not None:
+        for entry in args.force_yml:
+            # creating new keys is not supported yet
+            keys, value = entry.split('=', 1)
+            keys, value = keys.strip(), value.strip()
+            value = _postprocess_yml_value(value)
+            eval_str = 'opt'
+            for key in keys.split(':'):
+                eval_str += f'["{key}"]'
+            eval_str += '=value'
+            # assign the value using exec
+            exec(eval_str)
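+            # e.g. the entry 'train:ema_decay=0.999' runs: opt["train"]["ema_decay"] = 0.999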
+
+    opt['auto_resume'] = args.auto_resume
+    opt['is_train'] = is_train
+
+    # debug setting
+    if args.debug and not opt['name'].startswith('debug'):
+        opt['name'] = 'debug_' + opt['name']
+
+    if opt['num_gpu'] == 'auto':
+        opt['num_gpu'] = torch.cuda.device_count()
+
+    # datasets
+    for phase, dataset in opt['datasets'].items():
+        # for multiple datasets, e.g., val_1, val_2; test_1, test_2
+        phase = phase.split('_')[0]
+        dataset['phase'] = phase
+        if 'scale' in opt:
+            dataset['scale'] = opt['scale']
+        if dataset.get('dataroot_gt') is not None:
+            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
+        if dataset.get('dataroot_lq') is not None:
+            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
+
+    # paths
+    for key, val in opt['path'].items():
+        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
+            opt['path'][key] = osp.expanduser(val)
+
+    if is_train:
+        experiments_root = osp.join(root_path, 'experiments', opt['name'])
+        opt['path']['experiments_root'] = experiments_root
+        opt['path']['models'] = osp.join(experiments_root, 'models')
+        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
+        opt['path']['log'] = experiments_root
+        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
+
+        # change some options for debug mode
+        if 'debug' in opt['name']:
+            if 'val' in opt:
+                opt['val']['val_freq'] = 8
+            opt['logger']['print_freq'] = 1
+            opt['logger']['save_checkpoint_freq'] = 8
+    else:  # test
+        results_root = osp.join(root_path, 'results', opt['name'])
+        opt['path']['results_root'] = results_root
+        opt['path']['log'] = results_root
+        opt['path']['visualization'] = osp.join(results_root, 'visualization')
+
+    return opt, args
+
+
+@master_only
+def copy_opt_file(opt_file, experiments_root):
+    # copy the yml file to the experiment root
+    import sys
+    import time
+    from shutil import copyfile
+    cmd = ' '.join(sys.argv)
+    filename = osp.join(experiments_root, osp.basename(opt_file))
+    copyfile(opt_file, filename)
+
+    with open(filename, 'r+') as f:
+        lines = f.readlines()
+        lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
+        f.seek(0)
+        f.writelines(lines)
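+
+# Usage sketch (illustrative): typically called once from the training entry
+# script after parse_options, e.g.
+#     copy_opt_file(args.opt, opt['path']['experiments_root'])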