# Mirror of https://github.com/open-mmlab/mmcv.git
# Python source, 199 lines (6.5 KiB)
from collections import OrderedDict
from itertools import product
from unittest.mock import patch
import torch
import torch.nn as nn
from mmcv.ops import Conv2d, ConvTranspose2d, Linear, MaxPool2d
# Module-level override of the version string reported by torch, applied as
# soon as this test module is imported.  NOTE(review): presumably an old
# version makes the mmcv.ops wrappers use their own empty-tensor handling
# instead of delegating to torch.nn -- TODO confirm against mmcv.ops.
# Nothing in this file restores the real version string, so the override
# persists after this module's tests have run.
torch.__version__ = '1.1' # force test


@patch('torch.__version__', '1.1')
def test_conv2d():
    """Check the ``Conv2d`` wrapper against ``nn.Conv2d``.

    For every combination in the parameter grid below, the wrapper is run on
    an empty input (batch size 0). Its output shape, excluding the batch
    dimension, is compared with that of ``nn.Conv2d`` on a batch of 3.

    The test also checks three further things:

    * ``wrapper.weight`` receives a gradient of the right shape from the
      empty output;
    * the wrapper reproduces ``nn.Conv2d`` exactly on non-empty input;
    * an empty forward in eval mode yields the same shape.

    ``torch.__version__`` is patched to ``'1.1'`` only for the duration of
    this test and is restored afterwards, so the result does not depend on
    which tests ran before. NOTE(review): presumably the old version string
    makes the wrapper use its own empty-tensor handling -- TODO confirm.

    CommandLine:
        xdoctest -m tests/test_wrappers.py test_conv2d
    """
    test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
                              ('in_channel', [1, 3]), ('out_channel', [1, 3]),
                              ('kernel_size', [3, 5]), ('stride', [1, 2]),
                              ('padding', [0, 1]), ('dilation', [1, 2])])

    # Unpack in the same order as the keys of ``test_cases``.
    for in_w, in_h, in_cha, out_cha, k, s, p, d in product(
            *test_cases.values()):
        conv_kwargs = dict(stride=s, padding=p, dilation=d)

        # train mode: wrapper op with an empty (batch size 0) input
        x_empty = torch.randn(0, in_cha, in_h, in_w)
        # Re-seed before building each module so wrapper and ref start from
        # identical weights, which the torch.equal check below relies on.
        torch.manual_seed(0)
        wrapper = Conv2d(in_cha, out_cha, k, **conv_kwargs)
        wrapper_out = wrapper(x_empty)

        # torch op with a batch of 3 as the shape reference
        x_normal = torch.randn(3, in_cha, in_h, in_w).requires_grad_(True)
        torch.manual_seed(0)
        ref = nn.Conv2d(in_cha, out_cha, k, **conv_kwargs)
        ref_out = ref(x_normal)

        assert wrapper_out.shape[0] == 0
        assert wrapper_out.shape[1:] == ref_out.shape[1:]

        wrapper_out.sum().backward()
        assert wrapper.weight.grad is not None
        assert wrapper.weight.grad.shape == wrapper.weight.shape

        assert torch.equal(wrapper(x_normal), ref_out)

        # eval mode: exercised for every combination of the grid (not only
        # the last one), with the output shape checked as well
        x_empty = torch.randn(0, in_cha, in_h, in_w)
        eval_wrapper = Conv2d(in_cha, out_cha, k, **conv_kwargs)
        eval_wrapper.eval()
        eval_out = eval_wrapper(x_empty)
        assert eval_out.shape[0] == 0
        assert eval_out.shape[1:] == ref_out.shape[1:]


@patch('torch.__version__', '1.1')
def test_conv_transposed_2d():
    """Check the ``ConvTranspose2d`` wrapper against ``nn.ConvTranspose2d``.

    For every combination in the parameter grid below, the wrapper is run on
    an empty input (batch size 0). Its output shape, excluding the batch
    dimension, is compared with that of ``nn.ConvTranspose2d`` on a batch
    of 3.

    The test also checks three further things:

    * ``wrapper.weight`` receives a gradient of the right shape from the
      empty output;
    * the wrapper reproduces ``nn.ConvTranspose2d`` exactly on non-empty
      input;
    * an empty forward in eval mode yields the same shape.

    ``torch.__version__`` is patched to ``'1.1'`` only for the duration of
    this test and is restored afterwards, so the result does not depend on
    which tests ran before. NOTE(review): presumably the old version string
    makes the wrapper use its own empty-tensor handling -- TODO confirm.
    """
    test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
                              ('in_channel', [1, 3]), ('out_channel', [1, 3]),
                              ('kernel_size', [3, 5]), ('stride', [1, 2]),
                              ('padding', [0, 1]), ('dilation', [1, 2])])

    # Unpack in the same order as the keys of ``test_cases``.
    for in_w, in_h, in_cha, out_cha, k, s, p, d in product(
            *test_cases.values()):
        # output padding must be smaller than either stride or dilation
        op = min(s, d) - 1
        conv_kwargs = dict(
            stride=s, padding=p, dilation=d, output_padding=op)

        # train mode: wrapper op with an empty (batch size 0) input
        x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
        # Re-seed before building each module so wrapper and ref start from
        # identical weights, which the torch.equal check below relies on.
        torch.manual_seed(0)
        wrapper = ConvTranspose2d(in_cha, out_cha, k, **conv_kwargs)
        wrapper_out = wrapper(x_empty)

        # torch op with a batch of 3 as the shape reference
        x_normal = torch.randn(3, in_cha, in_h, in_w)
        torch.manual_seed(0)
        ref = nn.ConvTranspose2d(in_cha, out_cha, k, **conv_kwargs)
        ref_out = ref(x_normal)

        assert wrapper_out.shape[0] == 0
        assert wrapper_out.shape[1:] == ref_out.shape[1:]

        wrapper_out.sum().backward()
        assert wrapper.weight.grad is not None
        assert wrapper.weight.grad.shape == wrapper.weight.shape

        assert torch.equal(wrapper(x_normal), ref_out)

        # eval mode: exercised for every combination of the grid (not only
        # the last one), with the output shape checked as well
        x_empty = torch.randn(0, in_cha, in_h, in_w)
        eval_wrapper = ConvTranspose2d(in_cha, out_cha, k, **conv_kwargs)
        eval_wrapper.eval()
        eval_out = eval_wrapper(x_empty)
        assert eval_out.shape[0] == 0
        assert eval_out.shape[1:] == ref_out.shape[1:]


@patch('torch.__version__', '1.1')
def test_max_pool_2d():
    """Check the ``MaxPool2d`` wrapper against ``nn.MaxPool2d``.

    For every combination in the parameter grid below, the wrapper is run on
    an empty input (batch size 0). Its output shape, excluding the batch
    dimension, is compared with that of ``nn.MaxPool2d`` on a batch of 3.
    The wrapper must also reproduce ``nn.MaxPool2d`` exactly on that
    non-empty batch.

    ``torch.__version__`` is patched to ``'1.1'`` only for the duration of
    this test and is restored afterwards, so the result does not depend on
    which tests ran before. NOTE(review): presumably the old version string
    makes the wrapper use its own empty-tensor handling -- TODO confirm.
    """
    # The pooling modules take no channel arguments, so an output-channel
    # axis would only repeat identical cases; it is intentionally absent.
    test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
                              ('in_channel', [1, 3]),
                              ('kernel_size', [3, 5]), ('stride', [1, 2]),
                              ('padding', [0, 1]), ('dilation', [1, 2])])

    # Unpack in the same order as the keys of ``test_cases``.
    for in_w, in_h, in_cha, k, s, p, d in product(*test_cases.values()):
        pool_kwargs = dict(stride=s, padding=p, dilation=d)

        # wrapper op with an empty (batch size 0) input
        x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
        wrapper = MaxPool2d(k, **pool_kwargs)
        wrapper_out = wrapper(x_empty)

        # torch op with a batch of 3 as the shape reference
        x_normal = torch.randn(3, in_cha, in_h, in_w)
        ref = nn.MaxPool2d(k, **pool_kwargs)
        ref_out = ref(x_normal)

        assert wrapper_out.shape[0] == 0
        assert wrapper_out.shape[1:] == ref_out.shape[1:]

        assert torch.equal(wrapper(x_normal), ref_out)


@patch('torch.__version__', '1.1')
def test_linear():
    """Check the ``Linear`` wrapper against ``nn.Linear``.

    For every combination of feature sizes below, the wrapper is run on an
    empty input (batch size 0). Its output shape, excluding the batch
    dimension, is compared with that of ``nn.Linear`` on a batch of 3.

    The test also checks three further things:

    * ``wrapper.weight`` receives a gradient of the right shape from the
      empty output;
    * the wrapper reproduces ``nn.Linear`` exactly on non-empty input;
    * an empty forward in eval mode yields the same shape.

    ``torch.__version__`` is patched to ``'1.1'`` only for the duration of
    this test and is restored afterwards, so the result does not depend on
    which tests ran before. NOTE(review): presumably the old version string
    makes the wrapper use its own empty-tensor handling -- TODO confirm.
    """
    # Inputs here are 2-D (batch, in_feature), so only the feature sizes are
    # varied; spatial sizes would never reach the layer.
    test_cases = OrderedDict([
        ('in_feature', [1, 3]),
        ('out_feature', [1, 3]),
    ])

    for in_feature, out_feature in product(*test_cases.values()):
        # train mode: wrapper op with an empty (batch size 0) input
        x_empty = torch.randn(0, in_feature, requires_grad=True)
        # Re-seed before building each module so wrapper and ref start from
        # identical weights, which the torch.equal check below relies on.
        torch.manual_seed(0)
        wrapper = Linear(in_feature, out_feature)
        wrapper_out = wrapper(x_empty)

        # torch op with a batch of 3 as the shape reference
        x_normal = torch.randn(3, in_feature)
        torch.manual_seed(0)
        ref = nn.Linear(in_feature, out_feature)
        ref_out = ref(x_normal)

        assert wrapper_out.shape[0] == 0
        assert wrapper_out.shape[1:] == ref_out.shape[1:]

        wrapper_out.sum().backward()
        assert wrapper.weight.grad is not None
        assert wrapper.weight.grad.shape == wrapper.weight.shape

        assert torch.equal(wrapper(x_normal), ref_out)

        # eval mode: exercised for every combination (not only the last
        # one), with the output shape checked as well
        x_empty = torch.randn(0, in_feature)
        eval_wrapper = Linear(in_feature, out_feature)
        eval_wrapper.eval()
        eval_out = eval_wrapper(x_empty)
        assert eval_out.shape[0] == 0
        assert eval_out.shape[1:] == ref_out.shape[1:]


@patch('torch.__version__', '1.4.1')
def test_nn_op_forward_called():
    """Check when the wrappers delegate to the stock ``torch.nn`` forward.

    ``torch.nn.<Op>.forward`` is replaced by a mock for each check. With
    ``torch.__version__`` patched to ``'1.4.1'``:

    * ``Conv2d``, ``ConvTranspose2d`` and ``MaxPool2d`` are expected to call
      it for both empty (batch size 0) and non-empty input;
    * ``Linear`` is expected to skip it for empty input and call it for
      non-empty input.

    The version is patched rather than assigned, so it is restored when the
    test finishes instead of leaking into tests that run afterwards.
    """
    # (patch target name, wrapper class) pairs; explicit classes instead of
    # eval() on the name.
    ops = [('Conv2d', Conv2d), ('ConvTranspose2d', ConvTranspose2d),
           ('MaxPool2d', MaxPool2d)]
    for op_name, op_cls in ops:
        with patch(f'torch.nn.{op_name}.forward') as nn_module_forward:
            # empty (batch size 0) input
            x_empty = torch.randn(0, 3, 10, 10)
            wrapper = op_cls(3, 2, 1)
            wrapper(x_empty)
            nn_module_forward.assert_called_with(x_empty)

            # non-empty input
            x_normal = torch.randn(1, 3, 10, 10)
            wrapper = op_cls(3, 2, 1)
            wrapper(x_normal)
            nn_module_forward.assert_called_with(x_normal)

    with patch('torch.nn.Linear.forward') as nn_module_forward:
        # empty (batch size 0) input: the stock forward must not be used
        x_empty = torch.randn(0, 3)
        wrapper = Linear(3, 3)
        wrapper(x_empty)
        nn_module_forward.assert_not_called()

        # non-empty input
        x_normal = torch.randn(1, 3)
        wrapper = Linear(3, 3)
        wrapper(x_normal)
        nn_module_forward.assert_called_with(x_normal)