mmcv/tests/test_ops/test_wrappers.py

199 lines
6.5 KiB
Python

from collections import OrderedDict
from itertools import product
from unittest.mock import patch
import torch
import torch.nn as nn
from mmcv.ops import Conv2d, ConvTranspose2d, Linear, MaxPool2d
torch.__version__ = '1.1' # force test
def test_conv2d():
"""
CommandLine:
xdoctest -m tests/test_wrappers.py test_conv2d
"""
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
('in_channel', [1, 3]), ('out_channel', [1, 3]),
('kernel_size', [3, 5]), ('stride', [1, 2]),
('padding', [0, 1]), ('dilation', [1, 2])])
# train mode
for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
*list(test_cases.values())):
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_cha, in_h, in_w)
torch.manual_seed(0)
wrapper = Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_cha, in_h, in_w).requires_grad_(True)
torch.manual_seed(0)
ref = nn.Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_cha, in_h, in_w)
wrapper = Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
wrapper.eval()
wrapper(x_empty)
def test_conv_transposed_2d():
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
('in_channel', [1, 3]), ('out_channel', [1, 3]),
('kernel_size', [3, 5]), ('stride', [1, 2]),
('padding', [0, 1]), ('dilation', [1, 2])])
for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
*list(test_cases.values())):
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
# out padding must be smaller than either stride or dilation
op = min(s, d) - 1
torch.manual_seed(0)
wrapper = ConvTranspose2d(
in_cha,
out_cha,
k,
stride=s,
padding=p,
dilation=d,
output_padding=op)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_cha, in_h, in_w)
torch.manual_seed(0)
ref = nn.ConvTranspose2d(
in_cha,
out_cha,
k,
stride=s,
padding=p,
dilation=d,
output_padding=op)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_cha, in_h, in_w)
wrapper = ConvTranspose2d(
in_cha, out_cha, k, stride=s, padding=p, dilation=d, output_padding=op)
wrapper.eval()
wrapper(x_empty)
def test_max_pool_2d():
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
('in_channel', [1, 3]), ('out_channel', [1, 3]),
('kernel_size', [3, 5]), ('stride', [1, 2]),
('padding', [0, 1]), ('dilation', [1, 2])])
for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
*list(test_cases.values())):
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
wrapper = MaxPool2d(k, stride=s, padding=p, dilation=d)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_cha, in_h, in_w)
ref = nn.MaxPool2d(k, stride=s, padding=p, dilation=d)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
assert torch.equal(wrapper(x_normal), ref_out)
def test_linear():
test_cases = OrderedDict([
('in_w', [10, 20]),
('in_h', [10, 20]),
('in_feature', [1, 3]),
('out_feature', [1, 3]),
])
for in_h, in_w, in_feature, out_feature in product(
*list(test_cases.values())):
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_feature, requires_grad=True)
torch.manual_seed(0)
wrapper = Linear(in_feature, out_feature)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_feature)
torch.manual_seed(0)
ref = nn.Linear(in_feature, out_feature)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_feature)
wrapper = Linear(in_feature, out_feature)
wrapper.eval()
wrapper(x_empty)
def test_nn_op_forward_called():
torch.__version__ = '1.4.1'
for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']:
with patch(f'torch.nn.{m}.forward') as nn_module_forward:
# randn input
x_empty = torch.randn(0, 3, 10, 10)
wrapper = eval(m)(3, 2, 1)
wrapper(x_empty)
nn_module_forward.assert_called_with(x_empty)
# non-randn input
x_normal = torch.randn(1, 3, 10, 10)
wrapper = eval(m)(3, 2, 1)
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)
with patch('torch.nn.Linear.forward') as nn_module_forward:
# randn input
x_empty = torch.randn(0, 3)
wrapper = Linear(3, 3)
wrapper(x_empty)
nn_module_forward.assert_not_called()
# non-randn input
x_normal = torch.randn(1, 3)
wrapper = Linear(3, 3)
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)