mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Add Pytorch HardSwish assertion in unit test (#1294)
* assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test
This commit is contained in:
parent
58c02dd169
commit
304df56c78
@ -2,6 +2,7 @@
|
||||
import mmcv
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.utils import TORCH_VERSION, digit_version
|
||||
|
||||
from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer,
|
||||
make_divisible)
|
||||
@ -108,7 +109,6 @@ def test_inv_residualv3():
|
||||
assert inv_module.expand_conv.conv.kernel_size == (1, 1)
|
||||
assert inv_module.expand_conv.conv.stride == (1, 1)
|
||||
assert inv_module.expand_conv.conv.padding == (0, 0)
|
||||
assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
|
||||
|
||||
assert isinstance(inv_module.depthwise_conv.conv,
|
||||
mmcv.cnn.bricks.Conv2dAdaptivePadding)
|
||||
@ -116,11 +116,27 @@ def test_inv_residualv3():
|
||||
assert inv_module.depthwise_conv.conv.stride == (2, 2)
|
||||
assert inv_module.depthwise_conv.conv.padding == (0, 0)
|
||||
assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
|
||||
assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)
|
||||
|
||||
assert inv_module.linear_conv.conv.kernel_size == (1, 1)
|
||||
assert inv_module.linear_conv.conv.stride == (1, 1)
|
||||
assert inv_module.linear_conv.conv.padding == (0, 0)
|
||||
assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d)
|
||||
|
||||
if (TORCH_VERSION == 'parrots'
|
||||
or digit_version(TORCH_VERSION) < digit_version('1.7')):
|
||||
# Note: Use PyTorch official HSwish
|
||||
# when torch>=1.7 after MMCV >= 1.4.5.
|
||||
# Hardswish is not supported when PyTorch version < 1.6.
|
||||
# And Hardswish in PyTorch 1.6 does not support inplace.
|
||||
# More details could be found from:
|
||||
# https://github.com/open-mmlab/mmcv/pull/1709
|
||||
assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
|
||||
assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)
|
||||
else:
|
||||
assert isinstance(inv_module.expand_conv.activate, torch.nn.Hardswish)
|
||||
assert isinstance(inv_module.depthwise_conv.activate,
|
||||
torch.nn.Hardswish)
|
||||
|
||||
x = torch.rand(1, 32, 64, 64)
|
||||
output = inv_module(x)
|
||||
assert output.shape == (1, 40, 32, 32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user