diff --git a/mmcv/runner/optimizer/default_constructor.py b/mmcv/runner/optimizer/default_constructor.py
index d51621929..7b64ec464 100644
--- a/mmcv/runner/optimizer/default_constructor.py
+++ b/mmcv/runner/optimizer/default_constructor.py
@@ -14,6 +14,17 @@ class DefaultOptimizerConstructor:
     By default each parameter share the same optimizer settings, and we
     provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
     It is a dict and may contain the following fields:
+
+    - ``custom_keys`` (dict): Specifies parameter-wise settings by keys. If
+      one of the keys in ``custom_keys`` is a substring of the name of a
+      parameter, then the setting of the parameter will be specified by
+      ``custom_keys[key]`` and other settings like ``bias_lr_mult`` will
+      be ignored. It should be noted that the aforementioned ``key`` is the
+      longest key that is a substring of the name of the parameter. If there
+      are multiple matched keys with the same length, the key with the lower
+      alphabetical order will be chosen.
+      ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
+      and ``decay_mult``. See Example 2 below.
     - ``bias_lr_mult`` (float): It will be multiplied to the learning rate
       for all bias parameters (except for those in normalization layers).
@@ -39,7 +50,7 @@ class DefaultOptimizerConstructor:
                   lr, weight_decay, momentum, etc.
         paramwise_cfg (dict, optional): Parameter-wise options.
 
-    Example:
+    Example 1:
         >>> model = torch.nn.modules.Conv1d(1, 1, 1)
         >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
         >>>                      weight_decay=0.0001)
         >>> paramwise_cfg = dict(bias_lr_mult=2., bias_decay_mult=0.)
@@ -47,6 +58,18 @@ class DefaultOptimizerConstructor:
         >>> optim_builder = DefaultOptimizerConstructor(
         >>>     optimizer_cfg, paramwise_cfg)
         >>> optimizer = optim_builder(model)
+
+    Example 2:
+        >>> # assume model has attributes model.backbone and model.cls_head
+        >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
+        >>> paramwise_cfg = dict(custom_keys={
+        >>>     '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
+        >>> optim_builder = DefaultOptimizerConstructor(
+        >>>     optimizer_cfg, paramwise_cfg)
+        >>> optimizer = optim_builder(model)
+        >>> # Then the `lr` and `weight_decay` for model.backbone are
+        >>> # (0.01 * 0.1, 0.95 * 0.9). The `lr` and `weight_decay` for
+        >>> # model.cls_head are (0.01, 0.95).
     """
 
     def __init__(self, optimizer_cfg, paramwise_cfg=None):
@@ -63,6 +86,13 @@ class DefaultOptimizerConstructor:
         if not isinstance(self.paramwise_cfg, dict):
             raise TypeError('paramwise_cfg should be None or a dict, '
                             f'but got {type(self.paramwise_cfg)}')
+
+        if ('custom_keys' in self.paramwise_cfg
+                and not isinstance(self.paramwise_cfg['custom_keys'], dict)):
+            raise TypeError(
+                'If specified, custom_keys must be a dict, '
+                f'but got {type(self.paramwise_cfg["custom_keys"])}')
+
         # get base lr and weight decay
         # weight_decay must be explicitly specified if mult is specified
         if ('bias_decay_mult' in self.paramwise_cfg
@@ -93,6 +123,10 @@ class DefaultOptimizerConstructor:
             prefix (str): The prefix of the module
         """
         # get param-wise options
+        custom_keys = self.paramwise_cfg.get('custom_keys', {})
+        # sort keys alphabetically first, then by descending length
+        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
+
         bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
         bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
         norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
@@ -115,23 +149,35 @@ class DefaultOptimizerConstructor:
                 warnings.warn(f'{prefix} is duplicate. It is skipped since '
                               f'bypass_duplicate={bypass_duplicate}')
                 continue
-            # bias_lr_mult affects all bias parameters except for norm.bias
-            if name == 'bias' and not is_norm:
-                param_group['lr'] = self.base_lr * bias_lr_mult
-            # apply weight decay policies
-            if self.base_wd is not None:
-                # norm decay
-                if is_norm:
+            # if the parameter matches a custom key, ignore the other rules
+            is_custom = False
+            for key in sorted_keys:
+                if key in f'{prefix}.{name}':
+                    is_custom = True
+                    param_group['lr'] = self.base_lr * custom_keys[key].get(
+                        'lr_mult', 1.)
                     param_group[
-                        'weight_decay'] = self.base_wd * norm_decay_mult
-                # depth-wise conv
-                elif is_dwconv:
-                    param_group[
-                        'weight_decay'] = self.base_wd * dwconv_decay_mult
-                # bias lr and decay
-                elif name == 'bias':
-                    param_group[
-                        'weight_decay'] = self.base_wd * bias_decay_mult
+                        'weight_decay'] = self.base_wd * custom_keys[key].get(
+                            'decay_mult', 1.)
+                    break
+            if not is_custom:
+                # bias_lr_mult affects all bias parameters except for norm.bias
+                if name == 'bias' and not is_norm:
+                    param_group['lr'] = self.base_lr * bias_lr_mult
+                # apply weight decay policies
+                if self.base_wd is not None:
+                    # norm decay
+                    if is_norm:
+                        param_group[
+                            'weight_decay'] = self.base_wd * norm_decay_mult
+                    # depth-wise conv
+                    elif is_dwconv:
+                        param_group[
+                            'weight_decay'] = self.base_wd * dwconv_decay_mult
+                    # bias lr and decay
+                    elif name == 'bias':
+                        param_group[
+                            'weight_decay'] = self.base_wd * bias_decay_mult
             params.append(param_group)
 
         for child_name, child_mod in module.named_children():
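For reference, here is a minimal standalone sketch of the key-selection rule that the `sorted_keys` lookup in `add_params` implements: among all `custom_keys` entries that occur as substrings of a parameter's full name, the longest one wins, and ties are broken by alphabetical order. The helper name `match_custom_key` and the sample keys below are illustrative only and are not part of this patch.

```python
def match_custom_key(param_name, custom_keys):
    """Return the custom key governing ``param_name``, or None.

    Candidate keys are all keys that occur as substrings of the parameter
    name; the longest candidate wins, and ties go to the alphabetically
    smaller key (the same rule as ``sorted_keys`` in ``add_params``).
    """
    # sort alphabetically first, then by descending length; the first
    # substring hit is therefore the longest (and, on ties, the
    # alphabetically smallest) match
    sorted_keys = sorted(sorted(custom_keys), key=len, reverse=True)
    for key in sorted_keys:
        if key in param_name:
            return key
    return None


custom_keys = {
    'backbone': dict(lr_mult=0.1),
    'backbone.layer0': dict(lr_mult=0.0),
}
assert match_custom_key('backbone.layer0.conv.weight',
                        custom_keys) == 'backbone.layer0'
assert match_custom_key('backbone.layer1.conv.weight',
                        custom_keys) == 'backbone'
assert match_custom_key('cls_head.fc.weight', custom_keys) is None
```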
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 12918a47c..4df7e768e 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -355,6 +355,86 @@ def test_default_optimizer_constructor():
     assert len(optimizer.param_groups) == len(model_parameters) == 11
     check_optimizer(optimizer, model, **paramwise_cfg)
 
+    # test DefaultOptimizerConstructor with custom_keys and ExampleModel
+    model = ExampleModel()
+    optimizer_cfg = dict(
+        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
+    paramwise_cfg = dict(
+        custom_keys={
+            'param1': dict(lr_mult=10),
+            'sub': dict(lr_mult=0.1, decay_mult=0),
+            'sub.gn': dict(lr_mult=0.01),
+            'non_exist_key': dict(lr_mult=0.0)
+        },
+        norm_decay_mult=0.5)
+
+    with pytest.raises(TypeError):
+        # custom_keys should be a dict
+        paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001])
+        optim_constructor = DefaultOptimizerConstructor(
+            optimizer_cfg, paramwise_cfg_)
+        optimizer = optim_constructor(model)
+
+    optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
+                                                    paramwise_cfg)
+    optimizer = optim_constructor(model)
+    # check optimizer type and default config
+    assert isinstance(optimizer, torch.optim.SGD)
+    assert optimizer.defaults['lr'] == base_lr
+    assert optimizer.defaults['momentum'] == momentum
+    assert optimizer.defaults['weight_decay'] == base_wd
+
+    # check param groups
+    param_groups = optimizer.param_groups
+
+    groups = []
+    group_settings = []
+    # group 1, matches of 'param1'
+    # 'param1' is the longest match for 'sub.param1'
+    groups.append(['param1', 'sub.param1'])
+    group_settings.append({
+        'lr': base_lr * 10,
+        'momentum': momentum,
+        'weight_decay': base_wd,
+    })
+    # group 2, matches of 'sub.gn'
+    groups.append(['sub.gn.weight', 'sub.gn.bias'])
+    group_settings.append({
+        'lr': base_lr * 0.01,
+        'momentum': momentum,
+        'weight_decay': base_wd,
+    })
+    # group 3, matches of 'sub'
+    groups.append(['sub.conv1.weight', 'sub.conv1.bias'])
+    group_settings.append({
+        'lr': base_lr * 0.1,
+        'momentum': momentum,
+        'weight_decay': 0,
+    })
+    # group 4, bn is configured by 'norm_decay_mult'
+    groups.append(['bn.weight', 'bn.bias'])
+    group_settings.append({
+        'lr': base_lr,
+        'momentum': momentum,
+        'weight_decay': base_wd * 0.5,
+    })
+    # group 5, default group
+    groups.append(['conv1.weight', 'conv2.weight', 'conv2.bias'])
+    group_settings.append({
+        'lr': base_lr,
+        'momentum': momentum,
+        'weight_decay': base_wd
+    })
+
+    assert len(param_groups) == 11
+    for i, (name, param) in enumerate(model.named_parameters()):
+        assert torch.equal(param_groups[i]['params'][0], param)
+        for group, settings in zip(groups, group_settings):
+            if name in group:
+                for setting in settings:
+                    assert param_groups[i][setting] == settings[
+                        setting], f'{name} {setting}'
+
 
 def test_torch_optimizers():
     torch_optimizers = [
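For completeness, a quick end-to-end usage sketch of the new option against a toy model, in the spirit of Example 2 above. This assumes mmcv is installed and that `DefaultOptimizerConstructor` is importable from `mmcv.runner`; the module layout (`ToyModel`, `backbone`, `cls_head`) and the multipliers are illustrative only.

```python
import torch.nn as nn

from mmcv.runner import DefaultOptimizerConstructor


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, 3)
        self.cls_head = nn.Linear(8, 10)


model = ToyModel()
optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
# every parameter whose name contains 'backbone' gets lr * 0.1 and
# weight_decay * 0.5; all other parameters keep the base settings
paramwise_cfg = dict(
    custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=0.5)})
optim_builder = DefaultOptimizerConstructor(optimizer_cfg, paramwise_cfg)
optimizer = optim_builder(model)

# param groups are created per parameter, in named_parameters() order
for (name, _), group in zip(model.named_parameters(), optimizer.param_groups):
    # backbone.* -> lr=0.001, weight_decay=5e-05; cls_head.* -> 0.01, 0.0001
    print(name, group['lr'], group['weight_decay'])
```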