mmcv/tests/test_cnn/test_swish.py
Xiaojie Li c3d8eb34ff
add Swish activation (#522)
* update impad

* fix docstring

* add shape for impad

* fix unit test

* remove old version & fix doc

* fix linting

* fix doc

* add linear decay learning rate scheduler

* fix impad

* fix setup.cfg

* fix linting

* add yapf

* add swish

* fix lr_updater

* fix lr_updater.py

* update swish

* add swish

* fix inplace

* fix typo

Co-authored-by: lixiaojie <lixiaojie@sensetime.com>
2020-08-27 00:39:17 +08:00


import torch

from mmcv.cnn.bricks import Swish


def test_swish():
    act = Swish()
    input = torch.randn(1, 3, 64, 64)
    # Swish(x) = x * sigmoid(x), computed elementwise
    expected_output = input * torch.sigmoid(input)
    output = act(input)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.equal(output, expected_output)
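The test pins down the behaviour it expects from the brick: the output has the input's shape and equals `x * sigmoid(x)` elementwise. The block below is a minimal sketch of a module with that behaviour, assuming only what the test's `expected_output` implies. `SwishSketch` is a hypothetical name used for illustration and is not the mmcv source. The last check relies on PyTorch 1.7 or later, where `torch.nn.functional.silu` computes the same function (Swish with beta = 1).

```python
import torch
from torch import nn


class SwishSketch(nn.Module):
    """Hypothetical stand-in for mmcv.cnn.bricks.Swish: f(x) = x * sigmoid(x)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Out-of-place: returns a new tensor and leaves ``x`` unmodified.
        return x * torch.sigmoid(x)


x = torch.randn(1, 3, 64, 64)
y = SwishSketch()(x)
# Elementwise activation, so the output shape matches the input shape.
assert y.shape == x.shape
# With beta = 1, Swish coincides with SiLU (torch.nn.functional.silu, PyTorch >= 1.7).
assert torch.allclose(y, torch.nn.functional.silu(x), atol=1e-6)
```

`torch.equal` in the test only passes when the brick evaluates exactly the same expression as `expected_output`. If an implementation computed the value by a different route, such as a fused SiLU kernel, the results could differ in the last bits, and `torch.allclose` would be the more robust comparison.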