Mirror of https://github.com/open-mmlab/mmcv.git (synced 2025-06-03 21:54:52 +08:00)
* update impad * fix docstring * add shape for impad * fix unit test * remove old version & fix doc * fix linting * fix doc * add linear decay learning rate scheduler * fix impad * fix setup.cfg * fix linting * add yapf * add swish * fix lr_updater * fix lr_updater.py * update swish * add swish * fix inplace * fix typo Co-authored-by: lixiaojie <lixiaojie@sensetime.com>
16 lines · 378 B · Python
import torch
|
|
from torch.nn.functional import sigmoid
|
|
|
|
from mmcv.cnn.bricks import Swish
|
|
|
|
|
|
def test_swish():
    """Check that ``Swish`` computes ``x * sigmoid(x)`` without mutating ``x``.

    The reference output is computed directly with ``torch.sigmoid``. The
    ``Swish`` module's output is then compared against it for both shape and
    exact values. The input is also checked to be unchanged, which guards
    against an in-place implementation. An in-place version would still pass
    the value check, because the reference is computed before the call.
    """
    act = Swish()
    # Named ``x`` rather than ``input`` to avoid shadowing the builtin.
    x = torch.randn(1, 3, 64, 64)
    x_before = x.clone()

    # ``torch.nn.functional.sigmoid`` is deprecated; ``torch.sigmoid`` is the
    # supported spelling and produces the same values.
    expected_output = x * torch.sigmoid(x)

    output = act(x)

    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.equal(output, expected_output)
    # test that the input tensor was not modified in place
    assert torch.equal(x, x_before)
|