mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix skip_layer for RF-Next (#2489)
* judge skip_layer by fullname * lint * skip_layer first * update unit testpull/2487/head
parent
30d975a5f9
commit
935ba78b39
|
@ -143,7 +143,10 @@ class RFSearchHook(Hook):
|
|||
module.estimate_rates()
|
||||
module.expand_rates()
|
||||
|
||||
def wrap_model(self, model: nn.Module, search_op: str = 'Conv2d'):
|
||||
def wrap_model(self,
|
||||
model: nn.Module,
|
||||
search_op: str = 'Conv2d',
|
||||
prefix: str = ''):
|
||||
"""wrap model to support searchable conv op.
|
||||
|
||||
Args:
|
||||
|
@ -152,9 +155,18 @@ class RFSearchHook(Hook):
|
|||
Defaults to 'Conv2d'.
|
||||
init_rates (int, optional): Set to other initial dilation rates.
|
||||
Defaults to None.
|
||||
prefix (str): Prefix for function recursion. Defaults to ''.
|
||||
"""
|
||||
op = 'torch.nn.' + search_op
|
||||
for name, module in model.named_children():
|
||||
if prefix == '':
|
||||
fullname = 'module.' + name
|
||||
else:
|
||||
fullname = prefix + '.' + name
|
||||
if self.config['search']['skip_layer'] is not None:
|
||||
if any(layer in fullname
|
||||
for layer in self.config['search']['skip_layer']):
|
||||
continue
|
||||
if isinstance(module, eval(op)):
|
||||
if 1 < module.kernel_size[0] and \
|
||||
0 != module.kernel_size[0] % 2 or \
|
||||
|
@ -167,14 +179,8 @@ class RFSearchHook(Hook):
|
|||
logger.info('Wrap model %s to %s.' %
|
||||
(str(module), str(moduleWrap)))
|
||||
setattr(model, name, moduleWrap)
|
||||
elif isinstance(module, BaseConvRFSearchOp):
|
||||
pass
|
||||
else:
|
||||
if self.config['search']['skip_layer'] is not None:
|
||||
if any(layer in name
|
||||
for layer in self.config['search']['skip_layer']):
|
||||
continue
|
||||
self.wrap_model(module, search_op)
|
||||
elif not isinstance(module, BaseConvRFSearchOp):
|
||||
self.wrap_model(module, search_op, fullname)
|
||||
|
||||
def set_model(self,
|
||||
model: nn.Module,
|
||||
|
@ -198,6 +204,10 @@ class RFSearchHook(Hook):
|
|||
fullname = 'module.' + name
|
||||
else:
|
||||
fullname = prefix + '.' + name
|
||||
if self.config['search']['skip_layer'] is not None:
|
||||
if any(layer in fullname
|
||||
for layer in self.config['search']['skip_layer']):
|
||||
continue
|
||||
if isinstance(module, eval(op)):
|
||||
if 1 < module.kernel_size[0] and \
|
||||
0 != module.kernel_size[0] % 2 or \
|
||||
|
@ -224,11 +234,5 @@ class RFSearchHook(Hook):
|
|||
logger.info(
|
||||
'Set module %s dilation as: [%d %d]' %
|
||||
(fullname, module.dilation[0], module.dilation[1]))
|
||||
elif isinstance(module, BaseConvRFSearchOp):
|
||||
pass
|
||||
else:
|
||||
if self.config['search']['skip_layer'] is not None:
|
||||
if any(layer in fullname
|
||||
for layer in self.config['search']['skip_layer']):
|
||||
continue
|
||||
elif not isinstance(module, BaseConvRFSearchOp):
|
||||
self.set_model(module, search_op, init_rates, fullname)
|
||||
|
|
|
@ -16,36 +16,36 @@ from tests.test_runner.test_hooks import _build_demo_runner
|
|||
|
||||
def test_rfsearchhook():
|
||||
|
||||
def conv(in_channels, out_channels, kernel_size, stride, padding,
|
||||
dilation):
|
||||
return nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation)
|
||||
|
||||
class Model(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=1,
|
||||
out_channels=2,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1)
|
||||
self.conv2 = nn.Conv2d(
|
||||
in_channels=2,
|
||||
out_channels=2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dilation=1)
|
||||
self.conv3 = nn.Conv2d(
|
||||
in_channels=1,
|
||||
out_channels=2,
|
||||
kernel_size=(1, 3),
|
||||
stride=1,
|
||||
padding=(0, 1),
|
||||
dilation=1)
|
||||
self.stem = conv(1, 2, 3, 1, 1, 1)
|
||||
self.conv0 = conv(2, 2, 3, 1, 1, 1)
|
||||
self.layer0 = nn.Sequential(
|
||||
conv(2, 2, 3, 1, 1, 1), conv(2, 2, 3, 1, 1, 1))
|
||||
self.conv1 = conv(2, 2, 1, 1, 0, 1)
|
||||
self.conv2 = conv(2, 2, 3, 1, 1, 1)
|
||||
self.conv3 = conv(2, 2, (1, 3), 1, (0, 1), 1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.conv1(x)
|
||||
x2 = self.conv2(x1)
|
||||
return x2
|
||||
x1 = self.stem(x)
|
||||
x2 = self.layer0(x1)
|
||||
x3 = self.conv0(x2)
|
||||
x4 = self.conv1(x3)
|
||||
x5 = self.conv2(x4)
|
||||
x6 = self.conv3(x5)
|
||||
return x6
|
||||
|
||||
def train_step(self, x, optimizer, **kwargs):
|
||||
return dict(loss=self(x).mean(), num_samples=x.shape[0])
|
||||
|
@ -63,13 +63,14 @@ def test_rfsearchhook():
|
|||
mmin=1,
|
||||
mmax=24,
|
||||
num_branches=2,
|
||||
skip_layer=['stem', 'layer1'])),
|
||||
skip_layer=['stem', 'conv0', 'layer0.1'])),
|
||||
)
|
||||
|
||||
# hook for search
|
||||
rfsearchhook_search = RFSearchHook(
|
||||
'search', rfsearch_cfg['config'], by_epoch=True, verbose=True)
|
||||
rfsearchhook_search.config['structure'] = {
|
||||
'module.layer0.0': [1, 1],
|
||||
'module.conv2': [2, 2],
|
||||
'module.conv3': [1, 1]
|
||||
}
|
||||
|
@ -80,6 +81,7 @@ def test_rfsearchhook():
|
|||
by_epoch=True,
|
||||
verbose=True)
|
||||
rfsearchhook_fixed_single_branch.config['structure'] = {
|
||||
'module.layer0.0': [1, 1],
|
||||
'module.conv2': [2, 2],
|
||||
'module.conv3': [1, 1]
|
||||
}
|
||||
|
@ -90,14 +92,22 @@ def test_rfsearchhook():
|
|||
by_epoch=True,
|
||||
verbose=True)
|
||||
rfsearchhook_fixed_multi_branch.config['structure'] = {
|
||||
'module.layer0.0': [1, 1],
|
||||
'module.conv2': [2, 2],
|
||||
'module.conv3': [1, 1]
|
||||
}
|
||||
|
||||
def test_skip_layer():
|
||||
assert not isinstance(model.stem, Conv2dRFSearchOp)
|
||||
assert not isinstance(model.conv0, Conv2dRFSearchOp)
|
||||
assert isinstance(model.layer0[0], Conv2dRFSearchOp)
|
||||
assert not isinstance(model.layer0[1], Conv2dRFSearchOp)
|
||||
|
||||
# 1. test init_model() with mode of search
|
||||
model = Model()
|
||||
rfsearchhook_search.init_model(model)
|
||||
|
||||
test_skip_layer()
|
||||
assert not isinstance(model.conv1, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv2, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv3, Conv2dRFSearchOp)
|
||||
|
@ -111,6 +121,7 @@ def test_rfsearchhook():
|
|||
runner.register_hook(rfsearchhook_search)
|
||||
runner.run([loader], [('train', 1)])
|
||||
|
||||
test_skip_layer()
|
||||
assert not isinstance(model.conv1, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv2, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv3, Conv2dRFSearchOp)
|
||||
|
@ -145,6 +156,7 @@ def test_rfsearchhook():
|
|||
model = Model()
|
||||
rfsearchhook_fixed_multi_branch.init_model(model)
|
||||
|
||||
test_skip_layer()
|
||||
assert not isinstance(model.conv1, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv2, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv3, Conv2dRFSearchOp)
|
||||
|
@ -157,6 +169,7 @@ def test_rfsearchhook():
|
|||
runner.register_hook(rfsearchhook_fixed_multi_branch)
|
||||
runner.run([loader], [('train', 1)])
|
||||
|
||||
test_skip_layer()
|
||||
assert not isinstance(model.conv1, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv2, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv3, Conv2dRFSearchOp)
|
||||
|
|
Loading…
Reference in New Issue