[Enhance] Add ConvModule.turn_on_fast_conv_bn_eval to reduce repetitive code and dynamically bind conv during forward (#2835)

pull/2831/head^2
youkaichao 2023-06-16 12:43:53 +08:00 committed by GitHub
parent 8d87bf44ca
commit 3df6414efa
2 changed files with 37 additions and 29 deletions
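For context, the fast_conv_bn_eval option relies on the standard trick of folding a BatchNorm layer that is in eval mode into the preceding convolution: each output channel of the conv weight is scaled by gamma / sqrt(running_var + eps), and the running mean plus the BN shift are absorbed into the bias. The snippet below is a minimal, self-contained sketch of that folding; the helper name fold_bn_into_conv_weights is illustrative and is not mmcv's fast_conv_bn_eval_forward.

import torch
import torch.nn as nn
import torch.nn.functional as F


def fold_bn_into_conv_weights(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    # per-channel scale: gamma / sqrt(running_var + eps)
    scale = bn.weight * torch.rsqrt(bn.running_var + bn.eps)
    # scale the conv weight along the output-channel dimension
    weight = conv.weight * scale.reshape(-1, 1, 1, 1)
    # absorb the running mean and the BN shift (beta) into the bias
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    bias = (bias - bn.running_mean) * scale + bn.bias
    return weight, bias


conv = nn.Conv2d(3, 8, 2).eval()
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)    # make the check non-trivial
bn.running_var.uniform_(0.5, 1.5)
x = torch.rand(1, 3, 32, 32)
w, b = fold_bn_into_conv_weights(conv, bn)
fused = F.conv2d(x, w, b, conv.stride, conv.padding, conv.dilation, conv.groups)
assert torch.allclose(fused, bn(conv(x)), atol=1e-5)

Because only running statistics are used, the fused result matches bn(conv(x)) only when the BN layer is in eval mode and tracks running stats, which is exactly the condition guarded in the diff below.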


@@ -209,15 +209,7 @@ class ConvModule(nn.Module):
         else:
             self.norm_name = None  # type: ignore
 
-        # fast_conv_bn_eval works for conv + bn
-        # with `track_running_stats` option
-        if fast_conv_bn_eval and self.norm and isinstance(
-                self.norm, _BatchNorm) and self.norm.track_running_stats:
-            self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward,
-                                                     self.norm, self.conv)
-        else:
-            self.fast_conv_bn_eval_forward = None  # type: ignore
-        self.original_conv_forward = self.conv.forward
+        self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
 
         # build activation layer
         if self.with_activation:
@@ -278,11 +270,13 @@ class ConvModule(nn.Module):
                         self.order[layer_index + 1] == 'norm' and norm and \
                         self.with_norm and not self.norm.training and \
                         self.fast_conv_bn_eval_forward is not None:
-                    self.conv.forward = self.fast_conv_bn_eval_forward
+                    self.conv.forward = partial(self.fast_conv_bn_eval_forward,
+                                                self.norm, self.conv)
                     layer_index += 1
+                    x = self.conv(x)
+                    del self.conv.forward
                 else:
-                    self.conv.forward = self.original_conv_forward
-                x = self.conv(x)
+                    x = self.conv(x)
             elif layer == 'norm' and norm and self.with_norm:
                 x = self.norm(x)
             elif layer == 'act' and activate and self.with_activation:
@@ -290,6 +284,16 @@ class ConvModule(nn.Module):
             layer_index += 1
         return x
 
+    def turn_on_fast_conv_bn_eval(self, fast_conv_bn_eval=True):
+        # fast_conv_bn_eval works for conv + bn
+        # with `track_running_stats` option
+        if fast_conv_bn_eval and self.norm \
+                and isinstance(self.norm, _BatchNorm) \
+                and self.norm.track_running_stats:
+            self.fast_conv_bn_eval_forward = fast_conv_bn_eval_forward
+        else:
+            self.fast_conv_bn_eval_forward = None  # type: ignore
+
     @staticmethod
     def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
                             bn: torch.nn.modules.batchnorm._BatchNorm,
@@ -327,14 +331,6 @@ class ConvModule(nn.Module):
         self.norm_name, norm = 'bn', bn
         self.add_module(self.norm_name, norm)
 
-        # fast_conv_bn_eval works for conv + bn
-        # with `track_running_stats` option
-        if fast_conv_bn_eval and self.norm and isinstance(
-                self.norm, _BatchNorm) and self.norm.track_running_stats:
-            self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward,
-                                                     self.norm, self.conv)
-        else:
-            self.fast_conv_bn_eval_forward = None  # type: ignore
-        self.original_conv_forward = self.conv.forward
+        self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
 
         return self
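Taken together, the changes above mean the fast path is configured through the new turn_on_fast_conv_bn_eval method and can be toggled after construction, while forward binds self.norm and self.conv only for the duration of a single call. A small usage sketch, assuming the ConvModule constructor arguments and the turn_on_fast_conv_bn_eval signature shown in this diff:

import torch
from mmcv.cnn import ConvModule

# conv + BN + ReLU block with the fused conv-bn eval path enabled
conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
conv.eval()  # the fused path only triggers when the BN layer is in eval mode

x = torch.rand(1, 3, 256, 256)
y_fast = conv(x)  # conv.forward is bound via partial(...) and deleted afterwards

# the optimization can be switched off (or back on) after construction
conv.turn_on_fast_conv_bn_eval(False)
y_plain = conv(x)
assert torch.allclose(y_fast, y_plain, atol=1e-5)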


@@ -76,14 +76,26 @@ def test_conv_module():
     assert output.shape == (1, 8, 255, 255)
 
     # conv + norm with fast mode
-    conv = ConvModule(
-        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
-    conv.norm.eval()
-    x = torch.rand(1, 3, 256, 256)
-    fast_mode_output = conv(x)
-    conv.conv.forward = conv.original_conv_forward
-    plain_implementation = conv.activate(conv.norm(conv.conv(x)))
-    assert torch.allclose(fast_mode_output, plain_implementation, atol=1e-5)
+    fast_conv = ConvModule(
+        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
+    plain_conv = ConvModule(
+        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=False).eval()
+    for fast_param, plain_param in zip(fast_conv.state_dict().values(),
+                                       plain_conv.state_dict().values()):
+        plain_param.copy_(fast_param)
+    fast_mode_output = fast_conv(x)
+    plain_mode_output = plain_conv(x)
+    assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
+
+    # `conv` attribute can be dynamically modified in fast mode
+    fast_conv = ConvModule(
+        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
+    new_conv = nn.Conv2d(3, 8, 2).eval()
+    fast_conv.conv = new_conv
+    fast_mode_output = fast_conv(x)
+    plain_mode_output = fast_conv.activate(fast_conv.norm(new_conv(x)))
+    assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
 
     # conv + act
     conv = ConvModule(3, 8, 2)
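The second new test case covers the "dynamically bind conv during forward" part of the commit title: the old code captured self.norm and self.conv in a functools.partial at __init__ time, so assigning a new module to the conv attribute would leave the fused path pointing at the stale convolution, whereas the new code builds the partial inside forward and deletes it right after the call. A generic illustration of that early- versus late-binding difference (the Block class and its names are purely illustrative):

from functools import partial


class Block:
    def __init__(self, op):
        self.op = op
        # early binding: op is captured once at construction time
        self.apply_early = partial(lambda op, x: op(x), self.op)

    def apply_late(self, x):
        # late binding: look up self.op on every call, then bind it
        return partial(lambda op, x: op(x), self.op)(x)


blk = Block(lambda x: x + 1)
blk.op = lambda x: x * 10       # swap the wrapped op afterwards
assert blk.apply_early(3) == 4  # still uses the op captured in __init__
assert blk.apply_late(3) == 30  # picks up the new op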