[Fix] Fix skip_layer for RF-Next (#2489)

* judge skip_layer by fullname

* lint

* skip_layer first

* update unit test
Zhongyu Li 2022-12-28 15:27:34 +08:00 committed by GitHub
parent 30d975a5f9
commit 935ba78b39
2 changed files with 58 additions and 41 deletions
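
Note: before this change, wrap_model matched skip_layer against the bare child name, and only in its else branch after a conv could already have been wrapped, so a dotted entry such as 'layer0.1' never matched. The fix builds a dotted fullname during recursion and runs the skip check before any wrap/recurse decision. A minimal standalone sketch of the traversal order the fix introduces (iter_wrappable is an illustrative name, not part of the hook; the odd-kernel-size check and BaseConvRFSearchOp handling are omitted):

import torch.nn as nn

def iter_wrappable(model: nn.Module, skip_layer, prefix: str = ''):
    """Yield (fullname, conv) pairs a hook like this would wrap:
    build the dotted path, check skip_layer first, then wrap or recurse."""
    for name, module in model.named_children():
        fullname = 'module.' + name if prefix == '' else prefix + '.' + name
        # Substring match on the full dotted path, so an entry like
        # 'layer0.1' can target a single block inside an nn.Sequential.
        if skip_layer is not None and any(s in fullname for s in skip_layer):
            continue
        if isinstance(module, nn.Conv2d):
            yield fullname, module
        else:
            yield from iter_wrappable(module, skip_layer, fullname)

model = nn.Module()
model.layer0 = nn.Sequential(nn.Conv2d(1, 2, 3), nn.Conv2d(2, 2, 3))
# Only module.layer0.1 is skipped; module.layer0.0 is still yielded.
print([n for n, _ in iter_wrappable(model, ['layer0.1'])])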

@@ -143,7 +143,10 @@ class RFSearchHook(Hook):
                 module.estimate_rates()
                 module.expand_rates()
 
-    def wrap_model(self, model: nn.Module, search_op: str = 'Conv2d'):
+    def wrap_model(self,
+                   model: nn.Module,
+                   search_op: str = 'Conv2d',
+                   prefix: str = ''):
         """wrap model to support searchable conv op.
 
         Args:
@@ -152,9 +155,18 @@ class RFSearchHook(Hook):
                 Defaults to 'Conv2d'.
             init_rates (int, optional): Set to other initial dilation rates.
                 Defaults to None.
+            prefix (str): Prefix for function recursion. Defaults to ''.
         """
         op = 'torch.nn.' + search_op
         for name, module in model.named_children():
+            if prefix == '':
+                fullname = 'module.' + name
+            else:
+                fullname = prefix + '.' + name
+            if self.config['search']['skip_layer'] is not None:
+                if any(layer in fullname
+                       for layer in self.config['search']['skip_layer']):
+                    continue
             if isinstance(module, eval(op)):
                 if 1 < module.kernel_size[0] and \
                         0 != module.kernel_size[0] % 2 or \
@@ -167,14 +179,8 @@ class RFSearchHook(Hook):
                     logger.info('Wrap model %s to %s.' %
                                 (str(module), str(moduleWrap)))
                     setattr(model, name, moduleWrap)
-            elif isinstance(module, BaseConvRFSearchOp):
-                pass
-            else:
-                if self.config['search']['skip_layer'] is not None:
-                    if any(layer in name
-                           for layer in self.config['search']['skip_layer']):
-                        continue
-                self.wrap_model(module, search_op)
+            elif not isinstance(module, BaseConvRFSearchOp):
+                self.wrap_model(module, search_op, fullname)
 
     def set_model(self,
                   model: nn.Module,
@@ -198,6 +204,10 @@ class RFSearchHook(Hook):
                 fullname = 'module.' + name
             else:
                 fullname = prefix + '.' + name
+            if self.config['search']['skip_layer'] is not None:
+                if any(layer in fullname
+                       for layer in self.config['search']['skip_layer']):
+                    continue
             if isinstance(module, eval(op)):
                 if 1 < module.kernel_size[0] and \
                         0 != module.kernel_size[0] % 2 or \
@@ -224,11 +234,5 @@ class RFSearchHook(Hook):
                     logger.info(
                         'Set module %s dilation as: [%d %d]' %
                         (fullname, module.dilation[0], module.dilation[1]))
-            elif isinstance(module, BaseConvRFSearchOp):
-                pass
-            else:
-                if self.config['search']['skip_layer'] is not None:
-                    if any(layer in fullname
-                           for layer in self.config['search']['skip_layer']):
-                        continue
+            elif not isinstance(module, BaseConvRFSearchOp):
                 self.set_model(module, search_op, init_rates, fullname)
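
The core of the bug is visible in the removed any(layer in name ...) check of wrap_model: while recursing into an nn.Sequential, named_children() yields bare indices like '0' and '1', so a dotted skip_layer entry could never match. A quick illustration of the difference (values mirror the test below):

# Child names seen inside model.layer0 are bare indices, so matching
# against `name` (old behavior) can never hit a dotted entry.
name = '1'                    # what named_children() yields inside layer0
fullname = 'module.layer0.1'  # dotted path built by the fixed code
entry = 'layer0.1'
assert entry not in name      # old check: the layer is never skipped
assert entry in fullname      # new check: the layer is skipped as intended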

@@ -16,36 +16,36 @@ from tests.test_runner.test_hooks import _build_demo_runner
 def test_rfsearchhook():
 
+    def conv(in_channels, out_channels, kernel_size, stride, padding,
+             dilation):
+        return nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation)
+
     class Model(nn.Module):
 
         def __init__(self):
             super().__init__()
-            self.conv1 = nn.Conv2d(
-                in_channels=1,
-                out_channels=2,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                dilation=1)
-            self.conv2 = nn.Conv2d(
-                in_channels=2,
-                out_channels=2,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                dilation=1)
-            self.conv3 = nn.Conv2d(
-                in_channels=1,
-                out_channels=2,
-                kernel_size=(1, 3),
-                stride=1,
-                padding=(0, 1),
-                dilation=1)
+            self.stem = conv(1, 2, 3, 1, 1, 1)
+            self.conv0 = conv(2, 2, 3, 1, 1, 1)
+            self.layer0 = nn.Sequential(
+                conv(2, 2, 3, 1, 1, 1), conv(2, 2, 3, 1, 1, 1))
+            self.conv1 = conv(2, 2, 1, 1, 0, 1)
+            self.conv2 = conv(2, 2, 3, 1, 1, 1)
+            self.conv3 = conv(2, 2, (1, 3), 1, (0, 1), 1)
 
         def forward(self, x):
-            x1 = self.conv1(x)
-            x2 = self.conv2(x1)
-            return x2
+            x1 = self.stem(x)
+            x2 = self.layer0(x1)
+            x3 = self.conv0(x2)
+            x4 = self.conv1(x3)
+            x5 = self.conv2(x4)
+            x6 = self.conv3(x5)
+            return x6
 
         def train_step(self, x, optimizer, **kwargs):
             return dict(loss=self(x).mean(), num_samples=x.shape[0])
@@ -63,13 +63,14 @@ def test_rfsearchhook():
                 mmin=1,
                 mmax=24,
                 num_branches=2,
-                skip_layer=['stem', 'layer1'])),
+                skip_layer=['stem', 'conv0', 'layer0.1'])),
     )
 
     # hook for search
     rfsearchhook_search = RFSearchHook(
         'search', rfsearch_cfg['config'], by_epoch=True, verbose=True)
     rfsearchhook_search.config['structure'] = {
+        'module.layer0.0': [1, 1],
         'module.conv2': [2, 2],
         'module.conv3': [1, 1]
     }
@@ -80,6 +81,7 @@ def test_rfsearchhook():
         by_epoch=True,
         verbose=True)
     rfsearchhook_fixed_single_branch.config['structure'] = {
+        'module.layer0.0': [1, 1],
         'module.conv2': [2, 2],
         'module.conv3': [1, 1]
     }
@@ -90,14 +92,22 @@ def test_rfsearchhook():
         by_epoch=True,
         verbose=True)
     rfsearchhook_fixed_multi_branch.config['structure'] = {
+        'module.layer0.0': [1, 1],
         'module.conv2': [2, 2],
         'module.conv3': [1, 1]
     }
 
+    def test_skip_layer():
+        assert not isinstance(model.stem, Conv2dRFSearchOp)
+        assert not isinstance(model.conv0, Conv2dRFSearchOp)
+        assert isinstance(model.layer0[0], Conv2dRFSearchOp)
+        assert not isinstance(model.layer0[1], Conv2dRFSearchOp)
+
     # 1. test init_model() with mode of search
     model = Model()
     rfsearchhook_search.init_model(model)
+    test_skip_layer()
     assert not isinstance(model.conv1, Conv2dRFSearchOp)
     assert isinstance(model.conv2, Conv2dRFSearchOp)
     assert isinstance(model.conv3, Conv2dRFSearchOp)
@@ -111,6 +121,7 @@ def test_rfsearchhook():
     runner.register_hook(rfsearchhook_search)
     runner.run([loader], [('train', 1)])
 
+    test_skip_layer()
     assert not isinstance(model.conv1, Conv2dRFSearchOp)
     assert isinstance(model.conv2, Conv2dRFSearchOp)
     assert isinstance(model.conv3, Conv2dRFSearchOp)
@@ -145,6 +156,7 @@ def test_rfsearchhook():
     model = Model()
     rfsearchhook_fixed_multi_branch.init_model(model)
+    test_skip_layer()
     assert not isinstance(model.conv1, Conv2dRFSearchOp)
     assert isinstance(model.conv2, Conv2dRFSearchOp)
     assert isinstance(model.conv3, Conv2dRFSearchOp)
@@ -157,6 +169,7 @@ def test_rfsearchhook():
     runner.register_hook(rfsearchhook_fixed_multi_branch)
     runner.run([loader], [('train', 1)])
 
+    test_skip_layer()
     assert not isinstance(model.conv1, Conv2dRFSearchOp)
     assert isinstance(model.conv2, Conv2dRFSearchOp)
     assert isinstance(model.conv3, Conv2dRFSearchOp)
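
For reference, a condensed usage sketch distilled from the updated test: skip_layer entries are matched as substrings of the dotted module path, so 'stem' and 'conv0' exclude those modules outright while 'layer0.1' excludes only the second block of model.layer0. Only the search fields visible in this diff are spelled out; the elided hyper-parameters and the import path are assumptions based on mmcv's rfsearch layout, not confirmed by the diff.

from mmcv.cnn.rfsearch import RFSearchHook  # assumed export path

rfsearch_cfg = dict(
    config=dict(
        search=dict(
            # ... other search hyper-parameters as configured in the test
            mmin=1,
            mmax=24,
            num_branches=2,
            skip_layer=['stem', 'conv0', 'layer0.1'])))

# Constructed as in the test; on init_model() the hook wraps matching
# convs with Conv2dRFSearchOp and leaves every skipped path untouched.
hook = RFSearchHook(
    'search', rfsearch_cfg['config'], by_epoch=True, verbose=True)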