from collections import OrderedDict
from itertools import product
from unittest.mock import patch

import torch
import torch.nn as nn

from mmcv.ops import Conv2d, ConvTranspose2d, Linear, MaxPool2d

# NOTE: these tests fake ``torch.__version__`` to steer the mmcv wrappers
# onto (or off) their own empty-tensor code path. The fake version is applied
# per test with ``unittest.mock.patch`` rather than assigned at import time:
# pytest imports every test module during collection, so a module-level
# assignment would leak the fake version into the whole session.


@patch('torch.__version__', '1.1')
def test_conv2d():
    """Compare mmcv ``Conv2d`` with ``nn.Conv2d`` on empty and non-empty input.

    CommandLine:
        xdoctest -m tests/test_wrappers.py test_conv2d
    """
    test_cases = OrderedDict([('in_h', [10, 20]), ('in_w', [10, 20]),
                              ('in_channel', [1, 3]), ('out_channel', [1, 3]),
                              ('kernel_size', [3, 5]), ('stride', [1, 2]),
                              ('padding', [0, 1]), ('dilation', [1, 2])])

    for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
            *test_cases.values()):
        # train mode: wrapper op on an empty batch (batch size 0)
        x_empty = torch.randn(0, in_cha, in_h, in_w)
        torch.manual_seed(0)
        wrapper = Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
        wrapper_out = wrapper(x_empty)

        # torch op on a batch of 3 as shape reference; re-seeding before
        # construction gives it the same initial weights as ``wrapper``
        x_normal = torch.randn(3, in_cha, in_h, in_w).requires_grad_(True)
        torch.manual_seed(0)
        ref = nn.Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
        ref_out = ref(x_normal)

        assert wrapper_out.shape[0] == 0
        assert wrapper_out.shape[1:] == ref_out.shape[1:]

        # backward through the empty output must still populate weight grads
        wrapper_out.sum().backward()
        assert wrapper.weight.grad is not None
        assert wrapper.weight.grad.shape == wrapper.weight.shape

        assert torch.equal(wrapper(x_normal), ref_out)

        # eval mode: an empty batch must still yield the same output shape
        x_eval = torch.randn(0, in_cha, in_h, in_w)
        eval_wrapper = Conv2d(
            in_cha, out_cha, k, stride=s, padding=p, dilation=d)
        eval_wrapper.eval()
        assert eval_wrapper(x_eval).shape == wrapper_out.shape


@patch('torch.__version__', '1.1')
def test_conv_transposed_2d():
    """Compare mmcv ``ConvTranspose2d`` with ``nn.ConvTranspose2d``."""
    test_cases = OrderedDict([('in_h', [10, 20]), ('in_w', [10, 20]),
                              ('in_channel', [1, 3]), ('out_channel', [1, 3]),
                              ('kernel_size', [3, 5]), ('stride', [1, 2]),
                              ('padding', [0, 1]), ('dilation', [1, 2])])

    for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
            *test_cases.values()):
        # train mode: wrapper op on an empty batch (batch size 0)
        x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
        # output padding must be smaller than either stride or dilation
        op = min(s, d) - 1
        torch.manual_seed(0)
        wrapper = ConvTranspose2d(
            in_cha,
            out_cha,
            k,
            stride=s,
            padding=p,
            dilation=d,
            output_padding=op)
        wrapper_out = wrapper(x_empty)

        # torch op on a batch of 3 as shape reference; re-seeding before
        # construction gives it the same initial weights as ``wrapper``
        x_normal = torch.randn(3, in_cha, in_h, in_w)
        torch.manual_seed(0)
        ref = nn.ConvTranspose2d(
            in_cha,
            out_cha,
            k,
            stride=s,
            padding=p,
            dilation=d,
            output_padding=op)
        ref_out = ref(x_normal)

        assert wrapper_out.shape[0] == 0
        assert wrapper_out.shape[1:] == ref_out.shape[1:]

        # backward through the empty output must still populate weight grads
        wrapper_out.sum().backward()
        assert wrapper.weight.grad is not None
        assert wrapper.weight.grad.shape == wrapper.weight.shape

        assert torch.equal(wrapper(x_normal), ref_out)

        # eval mode: an empty batch must still yield the same output shape
        x_eval = torch.randn(0, in_cha, in_h, in_w)
        eval_wrapper = ConvTranspose2d(
            in_cha,
            out_cha,
            k,
            stride=s,
            padding=p,
            dilation=d,
            output_padding=op)
        eval_wrapper.eval()
        assert eval_wrapper(x_eval).shape == wrapper_out.shape


@patch('torch.__version__', '1.1')
def test_max_pool_2d():
    """Compare mmcv ``MaxPool2d`` with ``nn.MaxPool2d``."""
    # pooling preserves the channel count, so there is no out_channel axis
    test_cases = OrderedDict([('in_h', [10, 20]), ('in_w', [10, 20]),
                              ('in_channel', [1, 3]), ('kernel_size', [3, 5]),
                              ('stride', [1, 2]), ('padding', [0, 1]),
                              ('dilation', [1, 2])])

    for in_h, in_w, in_cha, k, s, p, d in product(*test_cases.values()):
        # wrapper op on an empty batch (batch size 0)
        x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
        wrapper = MaxPool2d(k, stride=s, padding=p, dilation=d)
        wrapper_out = wrapper(x_empty)

        # torch op on a batch of 3 as shape reference
        x_normal = torch.randn(3, in_cha, in_h, in_w)
        ref = nn.MaxPool2d(k, stride=s, padding=p, dilation=d)
        ref_out = ref(x_normal)

        assert wrapper_out.shape[0] == 0
        assert wrapper_out.shape[1:] == ref_out.shape[1:]

        assert torch.equal(wrapper(x_normal), ref_out)


@patch('torch.__version__', '1.1')
def test_linear():
    """Compare mmcv ``Linear`` with ``nn.Linear``."""
    # a linear layer only sees (batch, features); no spatial axes are needed
    test_cases = OrderedDict([
        ('in_feature', [1, 3]),
        ('out_feature', [1, 3]),
    ])

    for in_feature, out_feature in product(*test_cases.values()):
        # train mode: wrapper op on an empty batch (batch size 0)
        x_empty = torch.randn(0, in_feature, requires_grad=True)
        torch.manual_seed(0)
        wrapper = Linear(in_feature, out_feature)
        wrapper_out = wrapper(x_empty)

        # torch op on a batch of 3 as shape reference; re-seeding before
        # construction gives it the same initial weights as ``wrapper``
        x_normal = torch.randn(3, in_feature)
        torch.manual_seed(0)
        ref = nn.Linear(in_feature, out_feature)
        ref_out = ref(x_normal)

        assert wrapper_out.shape[0] == 0
        assert wrapper_out.shape[1:] == ref_out.shape[1:]

        # backward through the empty output must still populate weight grads
        wrapper_out.sum().backward()
        assert wrapper.weight.grad is not None
        assert wrapper.weight.grad.shape == wrapper.weight.shape

        assert torch.equal(wrapper(x_normal), ref_out)

        # eval mode: an empty batch must still yield the same output shape
        x_eval = torch.randn(0, in_feature)
        eval_wrapper = Linear(in_feature, out_feature)
        eval_wrapper.eval()
        assert eval_wrapper(x_eval).shape == wrapper_out.shape


@patch('torch.__version__', '1.4.1')
def test_nn_op_forward_called():
    """Check when the wrappers delegate to the underlying ``torch.nn`` forward.

    With the faked version '1.4.1', the conv/pool wrappers are expected to
    call the ``torch.nn`` forward for both empty and non-empty input, while
    ``Linear`` is expected to skip it for empty input.
    """
    # explicit (name, class) pairs instead of ``eval`` on the class name
    wrapper_classes = [('Conv2d', Conv2d),
                       ('ConvTranspose2d', ConvTranspose2d),
                       ('MaxPool2d', MaxPool2d)]
    for name, wrapper_cls in wrapper_classes:
        with patch(f'torch.nn.{name}.forward') as nn_module_forward:
            # empty input (batch size 0)
            x_empty = torch.randn(0, 3, 10, 10)
            wrapper = wrapper_cls(3, 2, 1)
            wrapper(x_empty)
            nn_module_forward.assert_called_with(x_empty)

            # non-empty input
            x_normal = torch.randn(1, 3, 10, 10)
            wrapper = wrapper_cls(3, 2, 1)
            wrapper(x_normal)
            nn_module_forward.assert_called_with(x_normal)

    with patch('torch.nn.Linear.forward') as nn_module_forward:
        # empty input (batch size 0)
        x_empty = torch.randn(0, 3)
        wrapper = Linear(3, 3)
        wrapper(x_empty)
        nn_module_forward.assert_not_called()

        # non-empty input
        x_normal = torch.randn(1, 3)
        wrapper = Linear(3, 3)
        wrapper(x_normal)
        nn_module_forward.assert_called_with(x_normal)