mirror of https://github.com/open-mmlab/mmcv.git
28 lines
900 B
Python
28 lines
900 B
Python
|
import torch
|
||
|
|
||
|
from mmcv.cnn.bricks import Conv2dAdaptivePadding
|
||
|
|
||
|
|
||
|
def test_conv2d_samepadding():
|
||
|
# test Conv2dAdaptivePadding with stride=1
|
||
|
inputs = torch.rand((1, 3, 28, 28))
|
||
|
conv = Conv2dAdaptivePadding(3, 3, kernel_size=3, stride=1)
|
||
|
output = conv(inputs)
|
||
|
assert output.shape == inputs.shape
|
||
|
|
||
|
inputs = torch.rand((1, 3, 13, 13))
|
||
|
conv = Conv2dAdaptivePadding(3, 3, kernel_size=3, stride=1)
|
||
|
output = conv(inputs)
|
||
|
assert output.shape == inputs.shape
|
||
|
|
||
|
# test Conv2dAdaptivePadding with stride=2
|
||
|
inputs = torch.rand((1, 3, 28, 28))
|
||
|
conv = Conv2dAdaptivePadding(3, 3, kernel_size=3, stride=2)
|
||
|
output = conv(inputs)
|
||
|
assert output.shape == torch.Size([1, 3, 14, 14])
|
||
|
|
||
|
inputs = torch.rand((1, 3, 13, 13))
|
||
|
conv = Conv2dAdaptivePadding(3, 3, kernel_size=3, stride=2)
|
||
|
output = conv(inputs)
|
||
|
assert output.shape == torch.Size([1, 3, 7, 7])
|