mmcv/tests/test_cnn/test_silu.py
takuoko afff388692
[Enhancement] Support SiLU with torch < 1.7.0 (#2278)
* support silu torch<1.7.0

* fix test

* fix test

* fix inplace

* fix inplace

* Update tests/test_cnn/test_silu.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
2022-09-29 15:19:41 +08:00

29 lines
892 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn.bricks import build_activation_layer
def test_silu():
    """Check the ``SiLU`` layer built by ``build_activation_layer``.

    SiLU is ``x * sigmoid(x)``. The layer's output is compared against that
    reference in two configurations:

    * default config: a new tensor is returned and the argument is left
      unmodified;
    * ``inplace=True``: the argument itself is overwritten with the result
      and returned.
    """
    # Default (non-inplace) layer.
    act = build_activation_layer(dict(type='SiLU'))
    x = torch.randn(1, 3, 64, 64)
    x_orig = x.clone()
    expected_output = x * torch.sigmoid(x)
    output = act(x)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.allclose(output, expected_output)
    # A non-inplace layer must not alias or mutate its argument; comparing
    # values alone would not catch an accidental in-place write.
    assert output is not x
    assert torch.equal(x, x_orig)

    # test inplace
    act = build_activation_layer(dict(type='SiLU', inplace=True))
    assert act.inplace
    x = torch.randn(1, 3, 64, 64)
    # The reference is computed before the forward pass because the in-place
    # layer overwrites x with its result.
    expected_output = x * torch.sigmoid(x)
    output = act(x)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.allclose(output, expected_output)
    # The input tensor itself now holds the result and is what was returned.
    assert torch.allclose(x, expected_output)
    assert x is output