mmcv/tests/test_cnn/test_swish.py
Zaida Zhou 6e9ce18323
Add copyright pre-commit-hook (#1742)
* first commit

* Add copyright pre-commit-hook
2022-02-24 09:24:25 +08:00

17 lines
420 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.cnn.bricks import Swish
def test_swish():
    """Check that ``Swish`` computes ``x * sigmoid(x)`` element-wise.

    The reference value is built with ``torch.sigmoid`` and compared
    bit-for-bit via ``torch.equal``. This presumes ``Swish`` uses the same
    element-wise ops; if its implementation changes, switch to
    ``torch.allclose``.
    """
    act = Swish()
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = torch.randn(1, 3, 64, 64)
    # ``torch.sigmoid`` is the supported API; ``F.sigmoid`` is deprecated
    # and emits a warning on recent PyTorch versions.
    expected_output = inputs * torch.sigmoid(inputs)
    output = act(inputs)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.equal(output, expected_output)