mirror of https://github.com/open-mmlab/mmcv.git
[Enhance] Add ConvModule.turn_on_fast_conv_bn_eval to reduce repetitive code and dynamically bind conv during forward (#2835)
parent
8d87bf44ca
commit
3df6414efa
|
@ -209,15 +209,7 @@ class ConvModule(nn.Module):
|
|||
else:
|
||||
self.norm_name = None # type: ignore
|
||||
|
||||
# fast_conv_bn_eval works for conv + bn
|
||||
# with `track_running_stats` option
|
||||
if fast_conv_bn_eval and self.norm and isinstance(
|
||||
self.norm, _BatchNorm) and self.norm.track_running_stats:
|
||||
self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward,
|
||||
self.norm, self.conv)
|
||||
else:
|
||||
self.fast_conv_bn_eval_forward = None # type: ignore
|
||||
self.original_conv_forward = self.conv.forward
|
||||
self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
|
||||
|
||||
# build activation layer
|
||||
if self.with_activation:
|
||||
|
@ -278,11 +270,13 @@ class ConvModule(nn.Module):
|
|||
self.order[layer_index + 1] == 'norm' and norm and \
|
||||
self.with_norm and not self.norm.training and \
|
||||
self.fast_conv_bn_eval_forward is not None:
|
||||
self.conv.forward = self.fast_conv_bn_eval_forward
|
||||
self.conv.forward = partial(self.fast_conv_bn_eval_forward,
|
||||
self.norm, self.conv)
|
||||
layer_index += 1
|
||||
x = self.conv(x)
|
||||
del self.conv.forward
|
||||
else:
|
||||
self.conv.forward = self.original_conv_forward
|
||||
x = self.conv(x)
|
||||
x = self.conv(x)
|
||||
elif layer == 'norm' and norm and self.with_norm:
|
||||
x = self.norm(x)
|
||||
elif layer == 'act' and activate and self.with_activation:
|
||||
|
@ -290,6 +284,16 @@ class ConvModule(nn.Module):
|
|||
layer_index += 1
|
||||
return x
|
||||
|
||||
def turn_on_fast_conv_bn_eval(self, fast_conv_bn_eval=True):
|
||||
# fast_conv_bn_eval works for conv + bn
|
||||
# with `track_running_stats` option
|
||||
if fast_conv_bn_eval and self.norm \
|
||||
and isinstance(self.norm, _BatchNorm) \
|
||||
and self.norm.track_running_stats:
|
||||
self.fast_conv_bn_eval_forward = fast_conv_bn_eval_forward
|
||||
else:
|
||||
self.fast_conv_bn_eval_forward = None # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
|
||||
bn: torch.nn.modules.batchnorm._BatchNorm,
|
||||
|
@ -327,14 +331,6 @@ class ConvModule(nn.Module):
|
|||
self.norm_name, norm = 'bn', bn
|
||||
self.add_module(self.norm_name, norm)
|
||||
|
||||
# fast_conv_bn_eval works for conv + bn
|
||||
# with `track_running_stats` option
|
||||
if fast_conv_bn_eval and self.norm and isinstance(
|
||||
self.norm, _BatchNorm) and self.norm.track_running_stats:
|
||||
self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward,
|
||||
self.norm, self.conv)
|
||||
else:
|
||||
self.fast_conv_bn_eval_forward = None # type: ignore
|
||||
self.original_conv_forward = self.conv.forward
|
||||
self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
|
||||
|
||||
return self
|
||||
|
|
|
@ -76,14 +76,26 @@ def test_conv_module():
|
|||
assert output.shape == (1, 8, 255, 255)
|
||||
|
||||
# conv + norm with fast mode
|
||||
conv = ConvModule(
|
||||
3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
|
||||
conv.norm.eval()
|
||||
x = torch.rand(1, 3, 256, 256)
|
||||
fast_mode_output = conv(x)
|
||||
conv.conv.forward = conv.original_conv_forward
|
||||
plain_implementation = conv.activate(conv.norm(conv.conv(x)))
|
||||
assert torch.allclose(fast_mode_output, plain_implementation, atol=1e-5)
|
||||
fast_conv = ConvModule(
|
||||
3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
|
||||
plain_conv = ConvModule(
|
||||
3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=False).eval()
|
||||
for fast_param, plain_param in zip(fast_conv.state_dict().values(),
|
||||
plain_conv.state_dict().values()):
|
||||
plain_param.copy_(fast_param)
|
||||
|
||||
fast_mode_output = fast_conv(x)
|
||||
plain_mode_output = plain_conv(x)
|
||||
assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
|
||||
|
||||
# `conv` attribute can be dynamically modified in fast mode
|
||||
fast_conv = ConvModule(
|
||||
3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
|
||||
new_conv = nn.Conv2d(3, 8, 2).eval()
|
||||
fast_conv.conv = new_conv
|
||||
fast_mode_output = fast_conv(x)
|
||||
plain_mode_output = fast_conv.activate(fast_conv.norm(new_conv(x)))
|
||||
assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
|
||||
|
||||
# conv + act
|
||||
conv = ConvModule(3, 8, 2)
|
||||
|
|
Loading…
Reference in New Issue