mmcv/tests/test_cnn/test_scale.py
Kai Chen 45111e193d
Add building bricks of cnn (#247)
* add building bricks of cnn

* add unit tests

* use registry for building bricks

* minor updates

* add scale layer

* add test for scale

* add doc string

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
2020-05-01 00:32:25 +08:00

import torch

from mmcv.cnn.bricks import Scale


def test_scale():
    # test default scale
    scale = Scale()
    assert scale.scale.data == 1.
    assert scale.scale.dtype == torch.float
    x = torch.rand(1, 3, 64, 64)
    output = scale(x)
    assert output.shape == (1, 3, 64, 64)

    # test given scale
    scale = Scale(10.)
    assert scale.scale.data == 10.
    assert scale.scale.dtype == torch.float
    x = torch.rand(1, 3, 64, 64)
    output = scale(x)
    assert output.shape == (1, 3, 64, 64)
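The test checks only three things: the parameter's initial value, its dtype, and that the output keeps the input's shape. Below is a minimal sketch of a layer that satisfies those checks. It assumes the layer stores a single learnable float scalar as an `nn.Parameter` and multiplies the input by it. `ScaleSketch` and `check_scale_sketch` are hypothetical names, not part of mmcv, and this is not mmcv's actual source. The check function repeats the test's assertions and adds a value check and a gradient check that the original test does not perform.

```python
import torch
import torch.nn as nn


class ScaleSketch(nn.Module):
    """Hypothetical stand-in for a Scale layer (assumption, not mmcv's code):
    multiplies the input by one learnable scalar."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        # A 0-dim float32 tensor wrapped as a Parameter so an optimizer can update it.
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcasting a 0-dim tensor leaves the input's shape unchanged.
        return x * self.scale


def check_scale_sketch():
    # Same checks as test_scale for a given scale.
    layer = ScaleSketch(10.)
    assert layer.scale.data == 10.
    assert layer.scale.dtype == torch.float
    x = torch.rand(1, 3, 64, 64)
    output = layer(x)
    assert output.shape == (1, 3, 64, 64)
    # Extra check: every element is multiplied by the scalar.
    assert torch.allclose(output, x * 10.)
    # Extra check: the scalar receives a gradient, d(sum(x * s))/ds = sum(x).
    output.sum().backward()
    assert torch.allclose(layer.scale.grad, x.sum())


check_scale_sketch()
```

Storing the factor as a 0-dim `nn.Parameter` instead of a plain Python float registers it in `module.parameters()`, so it is trained along with the rest of the network. Broadcasting also keeps the output the same shape as the input, which is what the shape assertions in `test_scale` rely on.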