mirror of https://github.com/open-mmlab/mmcv.git
[Enhance] Use PyTorch HSwish implementation. (#1709)
* [Enhance] Use PyTorch HSwish implementation. * fix conv test * upgrade version * add version commentspull/1717/head
parent
6e2b1067ba
commit
62c1b7f68b
|
@ -1,10 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
|
||||
from mmcv.utils import TORCH_VERSION, digit_version
|
||||
from .registry import ACTIVATION_LAYERS
|
||||
|
||||
|
||||
@ACTIVATION_LAYERS.register_module()
|
||||
class HSwish(nn.Module):
|
||||
"""Hard Swish Module.
|
||||
|
||||
|
@ -27,3 +27,12 @@ class HSwish(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
return x * self.act(x + 3) / 6
|
||||
|
||||
|
||||
if (TORCH_VERSION == 'parrots'
|
||||
or digit_version(TORCH_VERSION) < digit_version('1.7')):
|
||||
# Hardswish is not supported when PyTorch version < 1.6.
|
||||
# And Hardswish in PyTorch 1.6 does not support inplace.
|
||||
ACTIVATION_LAYERS.register_module(module=HSwish)
|
||||
else:
|
||||
ACTIVATION_LAYERS.register_module(module=nn.Hardswish, name='HSwish')
|
||||
|
|
|
@ -6,6 +6,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish
|
||||
from mmcv.utils import TORCH_VERSION, digit_version
|
||||
|
||||
|
||||
@CONV_LAYERS.register_module()
|
||||
|
@ -138,7 +139,12 @@ def test_conv_module():
|
|||
|
||||
# HSwish
|
||||
conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSwish'))
|
||||
assert isinstance(conv.activate, HSwish)
|
||||
if (TORCH_VERSION == 'parrots'
|
||||
or digit_version(TORCH_VERSION) < digit_version('1.7')):
|
||||
assert isinstance(conv.activate, HSwish)
|
||||
else:
|
||||
assert isinstance(conv.activate, nn.Hardswish)
|
||||
|
||||
output = conv(x)
|
||||
assert output.shape == (1, 8, 256, 256)
|
||||
|
||||
|
|
Loading…
Reference in New Issue