# Copyright (c) OpenMMLab. All rights reserved.
import sys
from unittest import TestCase
from unittest.mock import MagicMock

import torch
import torch.nn as nn

from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
                            DefaultOptimWrapperConstructor, OptimWrapper,
                            build_optim_wrapper)
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
from mmengine.registry import build_from_cfg
from mmengine.utils.dl_utils import mmcv_full_available

MMCV_FULL_AVAILABLE = mmcv_full_available()
if not MMCV_FULL_AVAILABLE:
    sys.modules['mmcv.ops'] = MagicMock(
        DeformConv2d=dict, ModulatedDeformConv2d=dict)


class ExampleModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.param1 = nn.Parameter(torch.ones(1))
        self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
        self.bn = nn.BatchNorm2d(2)
        self.sub = SubModel()
        if MMCV_FULL_AVAILABLE:
            from mmcv.ops import DeformConv2dPack
            self.dcn = DeformConv2dPack(
                3, 4, kernel_size=3, deformable_groups=1)


class ExampleDuplicateModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.param1 = nn.Parameter(torch.ones(1))
        self.conv1 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=False))
        self.conv2 = nn.Sequential(nn.Conv2d(4, 2, kernel_size=1))
        self.bn = nn.BatchNorm2d(2)
        self.sub = SubModel()
        self.conv3 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=False))
        self.conv3[0] = self.conv1[0]
        if MMCV_FULL_AVAILABLE:
            from mmcv.ops import DeformConv2dPack
            self.dcn = DeformConv2dPack(
                3, 4, kernel_size=3, deformable_groups=1)

    def forward(self, x):
        return x


class SubModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2)
        self.gn = nn.GroupNorm(2, 2)
        self.param1 = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x


class PseudoDataParallel(nn.Module):

    def __init__(self):
        super().__init__()
        self.module = ExampleModel()

    def forward(self, x):
        return x


class TestBuilder(TestCase):

    def setUp(self):
        self.model = ExampleModel()
        self.base_lr = 0.01
        self.momentum = 0.0001
        self.base_wd = 0.9

    def _check_default_optimizer(self, optimizer, model, prefix=''):
        assert isinstance(optimizer, torch.optim.SGD)
        assert optimizer.defaults['lr'] == self.base_lr
        assert optimizer.defaults['momentum'] == self.momentum
        assert optimizer.defaults['weight_decay'] == self.base_wd
        param_groups = optimizer.param_groups[0]
        if MMCV_FULL_AVAILABLE:
            param_names = [
                'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
                'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
                'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight',
                'dcn.conv_offset.weight', 'dcn.conv_offset.bias'
            ]
        else:
            param_names = [
                'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
                'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
                'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
            ]
        param_dict = dict(model.named_parameters())
        assert len(param_groups['params']) == len(param_names)
        for i in range(len(param_groups['params'])):
            assert torch.equal(param_groups['params'][i],
                               param_dict[prefix + param_names[i]])

    def _check_sgd_optimizer(self,
                             optimizer,
                             model,
                             prefix='',
                             bias_lr_mult=1,
                             bias_decay_mult=1,
                             norm_decay_mult=1,
                             dwconv_decay_mult=1,
                             dcn_offset_lr_mult=1,
                             bypass_duplicate=False):
        param_groups = optimizer.param_groups
        assert isinstance(optimizer, torch.optim.SGD)
        assert optimizer.defaults['lr'] == self.base_lr
        assert optimizer.defaults['momentum'] == self.momentum
        assert optimizer.defaults['weight_decay'] == self.base_wd
        model_parameters = list(model.parameters())
        assert len(param_groups) == len(model_parameters)
        for i, param in enumerate(model_parameters):
            param_group = param_groups[i]
            assert torch.equal(param_group['params'][0], param)
            assert param_group['momentum'] == self.momentum

        # param1
        param1 = param_groups[0]
        assert param1['lr'] == self.base_lr
        assert param1['weight_decay'] == self.base_wd
        # conv1.weight
        conv1_weight = param_groups[1]
        assert conv1_weight['lr'] == self.base_lr
        assert conv1_weight['weight_decay'] == self.base_wd
        # conv2.weight
        conv2_weight = param_groups[2]
        assert conv2_weight['lr'] == self.base_lr
        assert conv2_weight['weight_decay'] == self.base_wd
        # conv2.bias
        conv2_bias = param_groups[3]
        assert conv2_bias['lr'] == self.base_lr * bias_lr_mult
        assert conv2_bias['weight_decay'] == self.base_wd * bias_decay_mult
        # bn.weight
        bn_weight = param_groups[4]
        assert bn_weight['lr'] == self.base_lr
        assert bn_weight['weight_decay'] == self.base_wd * norm_decay_mult
        # bn.bias
        bn_bias = param_groups[5]
        assert bn_bias['lr'] == self.base_lr
        assert bn_bias['weight_decay'] == self.base_wd * norm_decay_mult
        # sub.param1
        sub_param1 = param_groups[6]
        assert sub_param1['lr'] == self.base_lr
        assert sub_param1['weight_decay'] == self.base_wd
        # sub.conv1.weight
        sub_conv1_weight = param_groups[7]
        assert sub_conv1_weight['lr'] == self.base_lr
        assert sub_conv1_weight[
            'weight_decay'] == self.base_wd * dwconv_decay_mult
        # sub.conv1.bias
        sub_conv1_bias = param_groups[8]
        assert sub_conv1_bias['lr'] == self.base_lr * bias_lr_mult
        assert sub_conv1_bias[
            'weight_decay'] == self.base_wd * dwconv_decay_mult
        # sub.gn.weight
        sub_gn_weight = param_groups[9]
        assert sub_gn_weight['lr'] == self.base_lr
        assert sub_gn_weight['weight_decay'] == self.base_wd * norm_decay_mult
        # sub.gn.bias
        sub_gn_bias = param_groups[10]
        assert sub_gn_bias['lr'] == self.base_lr
        assert sub_gn_bias['weight_decay'] == self.base_wd * norm_decay_mult

        # test dcn which requires cuda is available and
        # mmcv-full has been installed
        if torch.cuda.is_available() and MMCV_FULL_AVAILABLE:
            dcn_conv_weight = param_groups[11]
            assert dcn_conv_weight['lr'] == self.base_lr
            assert dcn_conv_weight['weight_decay'] == self.base_wd

            dcn_offset_weight = param_groups[12]
            assert dcn_offset_weight['lr'] == self.base_lr * dcn_offset_lr_mult
            assert dcn_offset_weight['weight_decay'] == self.base_wd

            dcn_offset_bias = param_groups[13]
            assert dcn_offset_bias['lr'] == self.base_lr * dcn_offset_lr_mult
            assert dcn_offset_bias['weight_decay'] == self.base_wd

    def test_torch_optimizers(self):
        torch_optimizers = [
            'ASGD', 'Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'LBFGS',
            'Optimizer', 'RMSprop', 'Rprop', 'SGD', 'SparseAdam'
        ]
        assert set(torch_optimizers).issubset(set(TORCH_OPTIMIZERS))

    def test_build_optimizer(self):
        # test build function without ``constructor`` and ``paramwise_cfg``
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg)
        self._check_default_optimizer(optim_wrapper.optimizer, self.model)

        # test build optimizer without type in optim_wrapper_cfg
        optim_wrapper_cfg = dict(
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg)
        self.assertIsInstance(optim_wrapper, OptimWrapper)
        self._check_default_optimizer(optim_wrapper.optimizer, self.model)

        # test build function with invalid ``constructor``
        with self.assertRaises(KeyError):
            optim_wrapper_cfg['constructor'] = 'INVALID_CONSTRUCTOR'
            build_optim_wrapper(self.model, optim_wrapper_cfg)

        # test build function with invalid ``paramwise_cfg``
        with self.assertRaises(KeyError):
            optim_wrapper_cfg['paramwise_cfg'] = dict(invalid_mult=1)
            build_optim_wrapper(self.model, optim_wrapper_cfg)

        optim_wrapper_cfg.pop('optimizer')
        optim_wrapper_cfg.pop('constructor')
        optim_wrapper_cfg.pop('paramwise_cfg')
        self.assertRaisesRegex(
            AssertionError, '`optim_wrapper_cfg` must contain',
            lambda: build_optim_wrapper(self.model, optim_wrapper_cfg))

    def test_build_default_optimizer_constructor(self):
        optim_wrapper = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        paramwise_cfg = dict(
            bias_lr_mult=2,
            bias_decay_mult=0.5,
            norm_decay_mult=0,
            dwconv_decay_mult=0.1,
            dcn_offset_lr_mult=0.1)
        optim_constructor_cfg = dict(
            type='DefaultOptimWrapperConstructor',
            optim_wrapper_cfg=optim_wrapper,
            paramwise_cfg=paramwise_cfg)
        optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
            optim_constructor_cfg)
        optim_wrapper = optim_constructor(self.model)
        self._check_sgd_optimizer(optim_wrapper.optimizer, self.model,
                                  **paramwise_cfg)

    def test_build_custom_optimizer_constructor(self):
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))

        @OPTIM_WRAPPER_CONSTRUCTORS.register_module()
        class MyOptimizerConstructor(DefaultOptimWrapperConstructor):

            def __call__(self, model):
                if hasattr(model, 'module'):
                    model = model.module

                conv1_lr_mult = self.paramwise_cfg.get('conv1_lr_mult', 1.)

                params = []
                for name, param in model.named_parameters():
                    param_group = {'params': [param]}
                    if name.startswith('conv1') and param.requires_grad:
                        param_group['lr'] = self.base_lr * conv1_lr_mult
                    params.append(param_group)
                self.optimizer_cfg['params'] = params
                return build_from_cfg(self.optimizer_cfg, OPTIMIZERS)

        paramwise_cfg = dict(conv1_lr_mult=5)
        optim_constructor_cfg = dict(
            type='MyOptimizerConstructor',
            optim_wrapper_cfg=optim_wrapper_cfg,
            paramwise_cfg=paramwise_cfg)
        optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
            optim_constructor_cfg)
        optimizer = optim_constructor(self.model)
        param_groups = optimizer.param_groups
        assert isinstance(optimizer, torch.optim.SGD)
        assert optimizer.defaults['lr'] == self.base_lr
        assert optimizer.defaults['momentum'] == self.momentum
        assert optimizer.defaults['weight_decay'] == self.base_wd
        for i, param in enumerate(self.model.parameters()):
            param_group = param_groups[i]
            assert torch.equal(param_group['params'][0], param)
            assert param_group['momentum'] == self.momentum
        # conv1.weight
        assert param_groups[1][
            'lr'] == self.base_lr * paramwise_cfg['conv1_lr_mult']
        assert param_groups[1]['weight_decay'] == self.base_wd

    def test_default_optimizer_constructor(self):
        with self.assertRaises(TypeError):
            # optimizer_cfg must be a dict
            optimizer_cfg = []
            optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg)
            optim_constructor(self.model)

        with self.assertRaises(TypeError):
            # paramwise_cfg must be a dict or None
            optim_wrapper_cfg = dict(
                type='OptimWrapper',
                optimizer=dict(lr=0.0001, weight_decay=None))
            paramwise_cfg = ['error']
            optim_constructor = DefaultOptimWrapperConstructor(
                optim_wrapper_cfg, paramwise_cfg)
            optim_constructor(self.model)

        with self.assertRaises(ValueError):
            # bias_decay_mult/norm_decay_mult is specified but weight_decay
            # is None
            optim_wrapper_cfg = dict(
                type='OptimWrapper',
                optimizer=dict(lr=0.0001, weight_decay=None))
            paramwise_cfg = dict(bias_decay_mult=1, norm_decay_mult=1)
            optim_constructor = DefaultOptimWrapperConstructor(
                optim_wrapper_cfg, paramwise_cfg)
            optim_constructor(self.model)

        # basic config with ExampleModel
        optimizer_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg)
        optim_wrapper = optim_constructor(self.model)
        self._check_default_optimizer(optim_wrapper.optimizer, self.model)

    def test_default_optimizer_constructor_with_model_wrapper(self):
        # basic config with pseudo data parallel
        model = PseudoDataParallel()
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        paramwise_cfg = None
        optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
        optim_wrapper = optim_constructor(model)
        self._check_default_optimizer(
            optim_wrapper.optimizer, model, prefix='module.')

        # paramwise_cfg with pseudo data parallel
        model = PseudoDataParallel()
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        paramwise_cfg = dict(
            bias_lr_mult=2,
            bias_decay_mult=0.5,
            norm_decay_mult=0,
            dwconv_decay_mult=0.1,
            dcn_offset_lr_mult=0.1)
        optim_constructor = DefaultOptimWrapperConstructor(
            optim_wrapper_cfg, paramwise_cfg)
        optim_wrapper = optim_constructor(model)
        self._check_sgd_optimizer(
            optim_wrapper.optimizer, model, prefix='module.', **paramwise_cfg)

        # basic config with DataParallel
        if torch.cuda.is_available():
            model = torch.nn.DataParallel(ExampleModel())
            optim_wrapper_cfg = dict(
                type='OptimWrapper',
                optimizer=dict(
                    type='SGD',
                    lr=self.base_lr,
                    weight_decay=self.base_wd,
                    momentum=self.momentum))
            paramwise_cfg = None
            optim_constructor = DefaultOptimWrapperConstructor(
                optim_wrapper_cfg)
            optim_wrapper = optim_constructor(model)
            self._check_default_optimizer(
                optim_wrapper.optimizer, model, prefix='module.')

        # paramwise_cfg with DataParallel
        if torch.cuda.is_available():
            model = torch.nn.DataParallel(self.model)
            optim_wrapper_cfg = dict(
                type='OptimWrapper',
                optimizer=dict(
                    type='SGD',
                    lr=self.base_lr,
                    weight_decay=self.base_wd,
                    momentum=self.momentum))
            paramwise_cfg = dict(
                bias_lr_mult=2,
                bias_decay_mult=0.5,
                norm_decay_mult=0,
                dwconv_decay_mult=0.1,
                dcn_offset_lr_mult=0.1)
            optim_constructor = DefaultOptimWrapperConstructor(
                optim_wrapper_cfg, paramwise_cfg)
            optim_wrapper = optim_constructor(model)
            self._check_sgd_optimizer(
                optim_wrapper.optimizer,
                model,
                prefix='module.',
                **paramwise_cfg)

    def test_default_optimizer_constructor_with_empty_paramwise_cfg(self):
        # Empty paramwise_cfg with ExampleModel
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        paramwise_cfg = dict()
        optim_constructor = DefaultOptimWrapperConstructor(
            optim_wrapper_cfg, paramwise_cfg)
        optim_wrapper = optim_constructor(self.model)
        self._check_default_optimizer(optim_wrapper.optimizer, self.model)

        # Empty paramwise_cfg with ExampleModel and no grad
        model = ExampleModel()
        for param in model.parameters():
            param.requires_grad = False
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        paramwise_cfg = dict()
        optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
        optim_wrapper = optim_constructor(model)
        self._check_default_optimizer(optim_wrapper.optimizer, model)

    def test_default_optimizer_constructor_with_paramwise_cfg(self):
        # paramwise_cfg with ExampleModel
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        paramwise_cfg = dict(
            bias_lr_mult=2,
            bias_decay_mult=0.5,
            norm_decay_mult=0,
            dwconv_decay_mult=0.1,
            dcn_offset_lr_mult=0.1)
        optim_constructor = DefaultOptimWrapperConstructor(
            optim_wrapper_cfg, paramwise_cfg)
        optim_wrapper = optim_constructor(self.model)
        self._check_sgd_optimizer(optim_wrapper.optimizer, self.model,
                                  **paramwise_cfg)

    def test_default_optimizer_constructor_no_grad(self):
        # paramwise_cfg with ExampleModel and no grad
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        paramwise_cfg = dict(
            bias_lr_mult=2,
            bias_decay_mult=0.5,
            norm_decay_mult=0,
            dwconv_decay_mult=0.1,
            dcn_offset_lr_mult=0.1)

        for param in self.model.parameters():
            param.requires_grad = False
        optim_constructor = DefaultOptimWrapperConstructor(
            optim_wrapper_cfg, paramwise_cfg)
        optim_wrapper = optim_constructor(self.model)
        optimizer = optim_wrapper.optimizer
        param_groups = optimizer.param_groups
        assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
        assert optimizer.defaults['lr'] == self.base_lr
        assert optimizer.defaults['momentum'] == self.momentum
        assert optimizer.defaults['weight_decay'] == self.base_wd
        for i, (name, param) in enumerate(self.model.named_parameters()):
            param_group = param_groups[i]
            assert torch.equal(param_group['params'][0], param)
            assert param_group['momentum'] == self.momentum
            assert param_group['lr'] == self.base_lr
            assert param_group['weight_decay'] == self.base_wd

    def test_default_optimizer_constructor_bypass_duplicate(self):
        # paramwise_cfg with bypass_duplicate option
        model = ExampleDuplicateModel()
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        paramwise_cfg = dict(
            bias_lr_mult=2,
            bias_decay_mult=0.5,
            norm_decay_mult=0,
            dwconv_decay_mult=0.1)

        with self.assertRaisesRegex(
                ValueError,
                'some parameters appear in more than one parameter group'):
            optim_constructor = DefaultOptimWrapperConstructor(
                optim_wrapper_cfg, paramwise_cfg)
            optim_constructor(model)

        paramwise_cfg = dict(
            bias_lr_mult=2,
            bias_decay_mult=0.5,
            norm_decay_mult=0,
            dwconv_decay_mult=0.1,
            dcn_offset_lr_mult=0.1,
            bypass_duplicate=True)
        optim_constructor = DefaultOptimWrapperConstructor(
            optim_wrapper_cfg, paramwise_cfg)
        self.assertWarnsRegex(
            Warning,
            'conv3.0 is duplicate. It is skipped since bypass_duplicate=True',
            lambda: optim_constructor(model))
        optim_wrapper = optim_constructor(model)
        model_parameters = list(model.parameters())
        num_params = 14 if MMCV_FULL_AVAILABLE else 11
        assert len(optim_wrapper.optimizer.param_groups) == len(
            model_parameters) == num_params
        self._check_sgd_optimizer(optim_wrapper.optimizer, model,
                                  **paramwise_cfg)

    def test_default_optimizer_constructor_custom_key(self):
        # test DefaultOptimWrapperConstructor with custom_keys and
        # ExampleModel
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD',
                lr=self.base_lr,
                weight_decay=self.base_wd,
                momentum=self.momentum))
        paramwise_cfg = dict(
            custom_keys={
                'param1': dict(lr_mult=10),
                'sub': dict(lr_mult=0.1, decay_mult=0),
                'sub.gn': dict(lr_mult=0.01),
                'non_exist_key': dict(lr_mult=0.0)
            },
            norm_decay_mult=0.5)

        with self.assertRaises(TypeError):
            # custom_keys should be a dict
            paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001])
            optim_constructor = DefaultOptimWrapperConstructor(
                optim_wrapper_cfg, paramwise_cfg_)
            optimizer = optim_constructor(self.model)

        with self.assertRaises(ValueError):
            # if 'decay_mult' is specified in custom_keys, weight_decay
            # should be specified
            optim_wrapper_cfg_ = dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01))
            paramwise_cfg_ = dict(
                custom_keys={'.backbone': dict(decay_mult=0.5)})
            optim_constructor = DefaultOptimWrapperConstructor(
                optim_wrapper_cfg_, paramwise_cfg_)
            optim_constructor(self.model)

        optim_constructor = DefaultOptimWrapperConstructor(
            optim_wrapper_cfg, paramwise_cfg)
        optimizer = optim_constructor(self.model).optimizer
        # check optimizer type and default config
        assert isinstance(optimizer, torch.optim.SGD)
        assert optimizer.defaults['lr'] == self.base_lr
        assert optimizer.defaults['momentum'] == self.momentum
        assert optimizer.defaults['weight_decay'] == self.base_wd

        # check params groups
        param_groups = optimizer.param_groups

        groups = []
        group_settings = []
        # group 1, matches of 'param1'
        # 'param1' is the longest match for 'sub.param1'
        groups.append(['param1', 'sub.param1'])
        group_settings.append({
            'lr': self.base_lr * 10,
            'momentum': self.momentum,
            'weight_decay': self.base_wd,
        })
        # group 2, matches of 'sub.gn'
        groups.append(['sub.gn.weight', 'sub.gn.bias'])
        group_settings.append({
            'lr': self.base_lr * 0.01,
            'momentum': self.momentum,
            'weight_decay': self.base_wd,
        })
        # group 3, matches of 'sub'
        groups.append(['sub.conv1.weight', 'sub.conv1.bias'])
        group_settings.append({
            'lr': self.base_lr * 0.1,
            'momentum': self.momentum,
            'weight_decay': 0,
        })
        # group 4, bn is configured by 'norm_decay_mult'
        groups.append(['bn.weight', 'bn.bias'])
        group_settings.append({
            'lr': self.base_lr,
            'momentum': self.momentum,
            'weight_decay': self.base_wd * 0.5,
        })
        # group 5, default group
        groups.append(['conv1.weight', 'conv2.weight', 'conv2.bias'])
        group_settings.append({
            'lr': self.base_lr,
            'momentum': self.momentum,
            'weight_decay': self.base_wd
        })

        num_params = 14 if MMCV_FULL_AVAILABLE else 11
        assert len(param_groups) == num_params
        for i, (name, param) in enumerate(self.model.named_parameters()):
            assert torch.equal(param_groups[i]['params'][0], param)
            for group, settings in zip(groups, group_settings):
                if name in group:
                    for setting in settings:
                        assert param_groups[i][setting] == settings[
                            setting], f'{name} {setting}'

        # test DefaultOptimWrapperConstructor with custom_keys and
        # ExampleModel 2
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='SGD', lr=self.base_lr, momentum=self.momentum))
        paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)})

        optim_constructor = DefaultOptimWrapperConstructor(
            optim_wrapper_cfg, paramwise_cfg)
        optimizer = optim_constructor(self.model).optimizer
        # check optimizer type and default config
        assert isinstance(optimizer, torch.optim.SGD)
        assert optimizer.defaults['lr'] == self.base_lr
        assert optimizer.defaults['momentum'] == self.momentum
        assert optimizer.defaults['weight_decay'] == 0

        # check params groups
        param_groups = optimizer.param_groups

        groups = []
        group_settings = []
        # group 1, matches of 'param1'
        groups.append(['param1', 'sub.param1'])
        group_settings.append({
            'lr': self.base_lr * 10,
            'momentum': self.momentum,
            'weight_decay': 0,
        })
        # group 2, default group
        groups.append([
            'sub.conv1.weight', 'sub.conv1.bias', 'sub.gn.weight',
            'sub.gn.bias', 'conv1.weight', 'conv2.weight', 'conv2.bias',
            'bn.weight', 'bn.bias'
        ])
        group_settings.append({
            'lr': self.base_lr,
            'momentum': self.momentum,
            'weight_decay': 0
        })

        num_params = 14 if MMCV_FULL_AVAILABLE else 11
        assert len(param_groups) == num_params
        for i, (name, param) in enumerate(self.model.named_parameters()):
            assert torch.equal(param_groups[i]['params'][0], param)
            for group, settings in zip(groups, group_settings):
                if name in group:
                    for setting in settings:
                        assert param_groups[i][setting] == settings[
                            setting], f'{name} {setting}'