From 2c6fc5fd9bdb4c71367a1bf14e0dd55bc7327d2f Mon Sep 17 00:00:00 2001 From: louzana <67114543+louzana@users.noreply.github.com> Date: Sun, 28 Jun 2020 23:27:14 +0800 Subject: [PATCH] =?UTF-8?q?fix=20bug=20of=20building=20ConvModule=20with?= =?UTF-8?q?=20HSigmoid=20using=20inplace=3DTrue=20and=20a=E2=80=A6=20(#369?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix bug of building ConvModule with HSigmoid using inplace=True and add corresponding unittest * fix linting --- mmcv/cnn/bricks/conv_module.py | 4 +++- tests/test_cnn/test_conv_module.py | 14 +++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py index bb0886470..365e6c73d 100644 --- a/mmcv/cnn/bricks/conv_module.py +++ b/mmcv/cnn/bricks/conv_module.py @@ -140,7 +140,9 @@ class ConvModule(nn.Module): if self.with_activation: act_cfg_ = act_cfg.copy() # nn.Tanh has no 'inplace' argument - if act_cfg_['type'] not in ['Tanh', 'PReLU', 'Sigmoid']: + if act_cfg_['type'] not in [ + 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid' + ]: act_cfg_.setdefault('inplace', inplace) self.activate = build_activation_layer(act_cfg_) diff --git a/tests/test_cnn/test_conv_module.py b/tests/test_cnn/test_conv_module.py index b29193426..3f058bab3 100644 --- a/tests/test_cnn/test_conv_module.py +++ b/tests/test_cnn/test_conv_module.py @@ -4,7 +4,7 @@ import pytest import torch import torch.nn as nn -from mmcv.cnn.bricks import CONV_LAYERS, ConvModule +from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish @CONV_LAYERS.register_module() @@ -135,6 +135,18 @@ def test_conv_module(): output = conv(x) assert output.shape == (1, 8, 256, 256) + # HSwish + conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSwish')) + assert isinstance(conv.activate, HSwish) + output = conv(x) + assert output.shape == (1, 8, 256, 256) + + # HSigmoid + conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSigmoid')) + assert isinstance(conv.activate, HSigmoid) + output = conv(x) + assert output.shape == (1, 8, 256, 256) + def test_bias(): # bias: auto, without norm