From 304df56c78dd353aa8660cf3c8154ce1a39dfdd1 Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Wed, 16 Feb 2022 19:47:58 +0800 Subject: [PATCH] [Fix] Add Pytorch HardSwish assertion in unit test (#1294) * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test --- .../test_models/test_backbones/test_blocks.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/test_models/test_backbones/test_blocks.py b/tests/test_models/test_backbones/test_blocks.py index ad3ad2d8c..77c8564a4 100644 --- a/tests/test_models/test_backbones/test_blocks.py +++ b/tests/test_models/test_backbones/test_blocks.py @@ -2,6 +2,7 @@ import mmcv import pytest import torch +from mmcv.utils import TORCH_VERSION, digit_version from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer, make_divisible) @@ -108,7 +109,6 @@ def test_inv_residualv3(): assert inv_module.expand_conv.conv.kernel_size == (1, 1) assert inv_module.expand_conv.conv.stride == (1, 1) assert inv_module.expand_conv.conv.padding == (0, 0) - assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish) assert isinstance(inv_module.depthwise_conv.conv, mmcv.cnn.bricks.Conv2dAdaptivePadding) @@ -116,11 +116,27 @@ def test_inv_residualv3(): assert inv_module.depthwise_conv.conv.stride == (2, 2) assert inv_module.depthwise_conv.conv.padding == (0, 0) assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d) - assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish) + assert inv_module.linear_conv.conv.kernel_size == (1, 1) assert inv_module.linear_conv.conv.stride == (1, 1) assert inv_module.linear_conv.conv.padding == (0, 0) assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d) + + if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.7')): + # Note: Use PyTorch official HSwish + # when torch>=1.7 after MMCV >= 1.4.5. + # Hardswish is not supported when PyTorch version < 1.6. + # And Hardswish in PyTorch 1.6 does not support inplace. + # More details could be found from: + # https://github.com/open-mmlab/mmcv/pull/1709 + assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish) + assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish) + else: + assert isinstance(inv_module.expand_conv.activate, torch.nn.Hardswish) + assert isinstance(inv_module.depthwise_conv.activate, + torch.nn.Hardswish) + x = torch.rand(1, 32, 64, 64) output = inv_module(x) assert output.shape == (1, 40, 32, 32)