diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py
index bb0886470..365e6c73d 100644
--- a/mmcv/cnn/bricks/conv_module.py
+++ b/mmcv/cnn/bricks/conv_module.py
@@ -140,7 +140,9 @@ class ConvModule(nn.Module):
         if self.with_activation:
             act_cfg_ = act_cfg.copy()
             # nn.Tanh has no 'inplace' argument
-            if act_cfg_['type'] not in ['Tanh', 'PReLU', 'Sigmoid']:
+            if act_cfg_['type'] not in [
+                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid'
+            ]:
                 act_cfg_.setdefault('inplace', inplace)
             self.activate = build_activation_layer(act_cfg_)
 
diff --git a/tests/test_cnn/test_conv_module.py b/tests/test_cnn/test_conv_module.py
index b29193426..3f058bab3 100644
--- a/tests/test_cnn/test_conv_module.py
+++ b/tests/test_cnn/test_conv_module.py
@@ -4,7 +4,7 @@
 import pytest
 import torch
 import torch.nn as nn
-from mmcv.cnn.bricks import CONV_LAYERS, ConvModule
+from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish
 
 
 @CONV_LAYERS.register_module()
@@ -135,6 +135,18 @@ def test_conv_module():
     output = conv(x)
     assert output.shape == (1, 8, 256, 256)
 
+    # HSwish
+    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSwish'))
+    assert isinstance(conv.activate, HSwish)
+    output = conv(x)
+    assert output.shape == (1, 8, 256, 256)
+
+    # HSigmoid
+    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSigmoid'))
+    assert isinstance(conv.activate, HSigmoid)
+    output = conv(x)
+    assert output.shape == (1, 8, 256, 256)
+
 
 def test_bias():
     # bias: auto, without norm