[Enhance] Use PyTorch HSwish implementation. (#1709)

* [Enhance] Use PyTorch HSwish implementation.

* fix conv test

* upgrade version

* add version comments
pull/1717/head
RangiLyu 2022-02-12 14:34:35 +08:00 committed by GitHub
parent 6e2b1067ba
commit 62c1b7f68b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 2 deletions

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.utils import TORCH_VERSION, digit_version
from .registry import ACTIVATION_LAYERS
@ACTIVATION_LAYERS.register_module()
class HSwish(nn.Module):
"""Hard Swish Module.
@ -27,3 +27,12 @@ class HSwish(nn.Module):
def forward(self, x):
return x * self.act(x + 3) / 6
if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.7')):
# Hardswish is not supported when PyTorch version < 1.6.
# And Hardswish in PyTorch 1.6 does not support inplace.
ACTIVATION_LAYERS.register_module(module=HSwish)
else:
ACTIVATION_LAYERS.register_module(module=nn.Hardswish, name='HSwish')

View File

@ -6,6 +6,7 @@ import torch
import torch.nn as nn
from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish
from mmcv.utils import TORCH_VERSION, digit_version
@CONV_LAYERS.register_module()
@ -138,7 +139,12 @@ def test_conv_module():
# HSwish
conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSwish'))
assert isinstance(conv.activate, HSwish)
if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.7')):
assert isinstance(conv.activate, HSwish)
else:
assert isinstance(conv.activate, nn.Hardswish)
output = conv(x)
assert output.shape == (1, 8, 256, 256)