mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Add Pytorch HardSwish assertion in unit test (#1294)
* assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test
This commit is contained in:
parent
58c02dd169
commit
304df56c78
@ -2,6 +2,7 @@
|
|||||||
import mmcv
|
import mmcv
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from mmcv.utils import TORCH_VERSION, digit_version
|
||||||
|
|
||||||
from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer,
|
from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer,
|
||||||
make_divisible)
|
make_divisible)
|
||||||
@ -108,7 +109,6 @@ def test_inv_residualv3():
|
|||||||
assert inv_module.expand_conv.conv.kernel_size == (1, 1)
|
assert inv_module.expand_conv.conv.kernel_size == (1, 1)
|
||||||
assert inv_module.expand_conv.conv.stride == (1, 1)
|
assert inv_module.expand_conv.conv.stride == (1, 1)
|
||||||
assert inv_module.expand_conv.conv.padding == (0, 0)
|
assert inv_module.expand_conv.conv.padding == (0, 0)
|
||||||
assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
|
|
||||||
|
|
||||||
assert isinstance(inv_module.depthwise_conv.conv,
|
assert isinstance(inv_module.depthwise_conv.conv,
|
||||||
mmcv.cnn.bricks.Conv2dAdaptivePadding)
|
mmcv.cnn.bricks.Conv2dAdaptivePadding)
|
||||||
@ -116,11 +116,27 @@ def test_inv_residualv3():
|
|||||||
assert inv_module.depthwise_conv.conv.stride == (2, 2)
|
assert inv_module.depthwise_conv.conv.stride == (2, 2)
|
||||||
assert inv_module.depthwise_conv.conv.padding == (0, 0)
|
assert inv_module.depthwise_conv.conv.padding == (0, 0)
|
||||||
assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
|
assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
|
||||||
assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)
|
|
||||||
assert inv_module.linear_conv.conv.kernel_size == (1, 1)
|
assert inv_module.linear_conv.conv.kernel_size == (1, 1)
|
||||||
assert inv_module.linear_conv.conv.stride == (1, 1)
|
assert inv_module.linear_conv.conv.stride == (1, 1)
|
||||||
assert inv_module.linear_conv.conv.padding == (0, 0)
|
assert inv_module.linear_conv.conv.padding == (0, 0)
|
||||||
assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d)
|
assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d)
|
||||||
|
|
||||||
|
if (TORCH_VERSION == 'parrots'
|
||||||
|
or digit_version(TORCH_VERSION) < digit_version('1.7')):
|
||||||
|
# Note: Use PyTorch official HSwish
|
||||||
|
# when torch>=1.7 after MMCV >= 1.4.5.
|
||||||
|
# Hardswish is not supported when PyTorch version < 1.6.
|
||||||
|
# And Hardswish in PyTorch 1.6 does not support inplace.
|
||||||
|
# More details could be found from:
|
||||||
|
# https://github.com/open-mmlab/mmcv/pull/1709
|
||||||
|
assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
|
||||||
|
assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)
|
||||||
|
else:
|
||||||
|
assert isinstance(inv_module.expand_conv.activate, torch.nn.Hardswish)
|
||||||
|
assert isinstance(inv_module.depthwise_conv.activate,
|
||||||
|
torch.nn.Hardswish)
|
||||||
|
|
||||||
x = torch.rand(1, 32, 64, 64)
|
x = torch.rand(1, 32, 64, 64)
|
||||||
output = inv_module(x)
|
output = inv_module(x)
|
||||||
assert output.shape == (1, 40, 32, 32)
|
assert output.shape == (1, 40, 32, 32)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user