mmcv/tests/test_cnn/test_swish.py
Jiazhen Wang fb486b96fd
[Fix] Fix some warnings in unittest (#1522)
* [Fix] fix some warnings in unittest

* [Impl] standardize some warnings

* [Fix] fix warning type in test_deprecation

* [Fix] fix warning type

* [Fix] continue fixing

* [Fix] fix some details

* [Fix] fix docstring

* [Fix] del useless statement

* [Fix] keep compatibility for torch < 1.5.0
2021-12-22 10:57:10 +08:00


import torch
import torch.nn.functional as F

from mmcv.cnn.bricks import Swish


def test_swish():
    act = Swish()
    input = torch.randn(1, 3, 64, 64)
    expected_output = input * F.sigmoid(input)
    output = act(input)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.equal(output, expected_output)