[Fix] Add Pytorch HardSwish assertion in unit test (#1294)

* assert original HardSwish when PyTorch > 1.6 in unit test

* assert original HardSwish when PyTorch > 1.6 in unit test

* assert original HardSwish when PyTorch > 1.6 in unit test

* assert original HardSwish when PyTorch > 1.6 in unit test

* assert original HardSwish when PyTorch > 1.6 in unit test

* assert original HardSwish when PyTorch > 1.6 in unit test
This commit is contained in:
MengzhangLI 2022-02-16 19:47:58 +08:00 committed by GitHub
parent 58c02dd169
commit 304df56c78

View File

@ -2,6 +2,7 @@
import mmcv import mmcv
import pytest import pytest
import torch import torch
from mmcv.utils import TORCH_VERSION, digit_version
from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer, from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer,
make_divisible) make_divisible)
@ -108,7 +109,6 @@ def test_inv_residualv3():
assert inv_module.expand_conv.conv.kernel_size == (1, 1) assert inv_module.expand_conv.conv.kernel_size == (1, 1)
assert inv_module.expand_conv.conv.stride == (1, 1) assert inv_module.expand_conv.conv.stride == (1, 1)
assert inv_module.expand_conv.conv.padding == (0, 0) assert inv_module.expand_conv.conv.padding == (0, 0)
assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
assert isinstance(inv_module.depthwise_conv.conv, assert isinstance(inv_module.depthwise_conv.conv,
mmcv.cnn.bricks.Conv2dAdaptivePadding) mmcv.cnn.bricks.Conv2dAdaptivePadding)
@ -116,11 +116,27 @@ def test_inv_residualv3():
assert inv_module.depthwise_conv.conv.stride == (2, 2) assert inv_module.depthwise_conv.conv.stride == (2, 2)
assert inv_module.depthwise_conv.conv.padding == (0, 0) assert inv_module.depthwise_conv.conv.padding == (0, 0)
assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d) assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)
assert inv_module.linear_conv.conv.kernel_size == (1, 1) assert inv_module.linear_conv.conv.kernel_size == (1, 1)
assert inv_module.linear_conv.conv.stride == (1, 1) assert inv_module.linear_conv.conv.stride == (1, 1)
assert inv_module.linear_conv.conv.padding == (0, 0) assert inv_module.linear_conv.conv.padding == (0, 0)
assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d) assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d)
if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.7')):
# Note: Use PyTorch official HSwish
# when torch>=1.7 after MMCV >= 1.4.5.
# Hardswish is not supported when PyTorch version < 1.6.
# And Hardswish in PyTorch 1.6 does not support inplace.
# More details could be found from:
# https://github.com/open-mmlab/mmcv/pull/1709
assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)
else:
assert isinstance(inv_module.expand_conv.activate, torch.nn.Hardswish)
assert isinstance(inv_module.depthwise_conv.activate,
torch.nn.Hardswish)
x = torch.rand(1, 32, 64, 64) x = torch.rand(1, 32, 64, 64)
output = inv_module(x) output = inv_module(x)
assert output.shape == (1, 40, 32, 32) assert output.shape == (1, 40, 32, 32)