fix bug of building ConvModule with HSigmoid using inplace=True and a… ()

* fix bug of building ConvModule with HSigmoid using inplace=True and add corresponding unittest

* fix linting
pull/372/head
louzana 2020-06-28 23:27:14 +08:00 committed by GitHub
parent c0f5492ee9
commit 2c6fc5fd9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 2 deletions
mmcv/cnn/bricks
tests/test_cnn

View File

@ -140,7 +140,9 @@ class ConvModule(nn.Module):
if self.with_activation:
act_cfg_ = act_cfg.copy()
# nn.Tanh has no 'inplace' argument
if act_cfg_['type'] not in ['Tanh', 'PReLU', 'Sigmoid']:
if act_cfg_['type'] not in [
'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid'
]:
act_cfg_.setdefault('inplace', inplace)
self.activate = build_activation_layer(act_cfg_)

View File

@ -4,7 +4,7 @@ import pytest
import torch
import torch.nn as nn
from mmcv.cnn.bricks import CONV_LAYERS, ConvModule
from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish
@CONV_LAYERS.register_module()
@ -135,6 +135,18 @@ def test_conv_module():
output = conv(x)
assert output.shape == (1, 8, 256, 256)
# HSwish
conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSwish'))
assert isinstance(conv.activate, HSwish)
output = conv(x)
assert output.shape == (1, 8, 256, 256)
# HSigmoid
conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSigmoid'))
assert isinstance(conv.activate, HSigmoid)
output = conv(x)
assert output.shape == (1, 8, 256, 256)
def test_bias():
# bias: auto, without norm