add ConvTranspose3d

pull/652/head
dreamerlin 2020-11-15 18:56:02 +08:00
parent c390e327fa
commit 8ccea20234
4 changed files with 92 additions and 8 deletions

View File

@ -1,14 +1,17 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .alexnet import AlexNet
# yapf: disable
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
ContextBlock, Conv2d, ConvAWS2d, ConvModule,
ConvTranspose2d, ConvWS2d, DepthwiseSeparableConvModule,
GeneralizedAttention, HSigmoid, HSwish, Linear, MaxPool2d,
NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
ConvTranspose2d, ConvTranspose3d, ConvWS2d,
DepthwiseSeparableConvModule, GeneralizedAttention,
HSigmoid, HSwish, Linear, MaxPool2d, NonLocal1d,
NonLocal2d, NonLocal3d, Scale, Swish,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_plugin_layer,
build_upsample_layer, conv_ws_2d, is_norm)
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
fuse_conv_bn, get_model_complexity_info, kaiming_init,
@ -26,5 +29,5 @@ __all__ = [
'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d'
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d'
]

View File

@ -17,7 +17,8 @@ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
from .scale import Scale
from .swish import Swish
from .upsample import build_upsample_layer
from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
from .wrappers import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
MaxPool2d)
__all__ = [
'ConvModule', 'build_activation_layer', 'build_conv_layer',
@ -27,5 +28,6 @@ __all__ = [
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d'
'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
'ConvTranspose3d'
]

View File

@ -78,7 +78,30 @@ class ConvTranspose2d(nn.ConvTranspose2d):
else:
return empty
return super(ConvTranspose2d, self).forward(x)
return super().forward(x)
@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv3d')
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
class ConvTranspose3d(nn.ConvTranspose3d):
def forward(self, x):
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
self.padding, self.stride,
self.dilation, self.output_padding):
out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
empty = NewEmptyTensorOp.apply(x, out_shape)
if self.training:
# produce dummy gradient to avoid DDP warning.
dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
return empty + dummy
else:
return empty
return super().forward(x)
class MaxPool2d(nn.MaxPool2d):

View File

@ -5,7 +5,8 @@ from unittest.mock import patch
import torch
import torch.nn as nn
from mmcv.cnn.bricks import Conv2d, ConvTranspose2d, Linear, MaxPool2d
from mmcv.cnn.bricks import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
MaxPool2d)
@patch('torch.__version__', '1.1')
@ -105,6 +106,61 @@ def test_conv_transposed_2d():
wrapper(x_empty)
@patch('torch.__version__', '1.1')
def test_conv_transposed_3d():
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
('in_t', [10, 20]), ('in_channel', [1, 3]),
('out_channel', [1, 3]), ('kernel_size', [3, 5]),
('stride', [1, 2]), ('padding', [0, 1]),
('dilation', [1, 2])])
for in_h, in_w, in_t, in_cha, out_cha, k, s, p, d in product(
*list(test_cases.values())):
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_cha, in_t, in_h, in_w, requires_grad=True)
# out padding must be smaller than either stride or dilation
op = min(s, d) - 1
torch.manual_seed(0)
wrapper = ConvTranspose3d(
in_cha,
out_cha,
k,
stride=s,
padding=p,
dilation=d,
output_padding=op)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_cha, in_t, in_h, in_w)
torch.manual_seed(0)
ref = nn.ConvTranspose3d(
in_cha,
out_cha,
k,
stride=s,
padding=p,
dilation=d,
output_padding=op)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_cha, in_t, in_h, in_w)
wrapper = ConvTranspose3d(
in_cha, out_cha, k, stride=s, padding=p, dilation=d, output_padding=op)
wrapper.eval()
wrapper(x_empty)
@patch('torch.__version__', '1.1')
def test_max_pool_2d():
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),