mirror of https://github.com/open-mmlab/mmcv.git
add ConvTranspose3d
parent c390e327fa
commit 8ccea20234
@@ -1,14 +1,17 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .alexnet import AlexNet
 # yapf: disable
 from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                      PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
                      ContextBlock, Conv2d, ConvAWS2d, ConvModule,
-                     ConvTranspose2d, ConvWS2d, DepthwiseSeparableConvModule,
-                     GeneralizedAttention, HSigmoid, HSwish, Linear, MaxPool2d,
-                     NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
+                     ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+                     DepthwiseSeparableConvModule, GeneralizedAttention,
+                     HSigmoid, HSwish, Linear, MaxPool2d, NonLocal1d,
+                     NonLocal2d, NonLocal3d, Scale, Swish,
                      build_activation_layer, build_conv_layer,
                      build_norm_layer, build_padding_layer, build_plugin_layer,
                      build_upsample_layer, conv_ws_2d, is_norm)
 # yapf: enable
 from .resnet import ResNet, make_res_layer
 from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
                     fuse_conv_bn, get_model_complexity_info, kaiming_init,
@@ -26,5 +29,5 @@ __all__ = [
     'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
     'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
     'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
-    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d'
+    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d'
 ]
@@ -17,7 +17,8 @@ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
 from .scale import Scale
 from .swish import Swish
 from .upsample import build_upsample_layer
-from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
+from .wrappers import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
+                       MaxPool2d)
 
 __all__ = [
     'ConvModule', 'build_activation_layer', 'build_conv_layer',
@@ -27,5 +28,6 @@ __all__ = [
     'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
     'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
     'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
-    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d'
+    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
+    'ConvTranspose3d'
 ]
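
Note (not part of the diff): after the two export hunks above, the new wrapper is reachable from both package levels. A minimal check, assuming this commit is installed:

    >>> from mmcv.cnn import ConvTranspose3d
    >>> from mmcv.cnn.bricks import ConvTranspose3d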
@@ -78,7 +78,30 @@ class ConvTranspose2d(nn.ConvTranspose2d):
             else:
                 return empty
 
-        return super(ConvTranspose2d, self).forward(x)
+        return super().forward(x)
 
 
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv3d')
+@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+class ConvTranspose3d(nn.ConvTranspose3d):
+
+    def forward(self, x):
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+            out_shape = [x.shape[0], self.out_channels]
+            for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
+                                         self.padding, self.stride,
+                                         self.dilation, self.output_padding):
+                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            if self.training:
+                # produce dummy gradient to avoid DDP warning.
+                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                return empty + dummy
+            else:
+                return empty
+
+        return super().forward(x)
+
+
 class MaxPool2d(nn.MaxPool2d):
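
Gloss (mine, not from the commit): the empty-input branch mirrors the existing ConvTranspose2d wrapper. For each of the three spatial dimensions it applies the standard transposed-convolution size formula, out = (i - 1) * s - 2 * p + (d * (k - 1) + 1) + op, so a zero-sized batch still yields an output of the right shape, and in training mode the dummy term keeps every parameter in the autograd graph so DDP sees gradients for them. A small sketch with made-up numbers:

    import torch
    from mmcv.cnn.bricks import ConvTranspose3d

    # i=10, k=3, s=2, p=1, d=1, op=1 per spatial dim:
    # (10 - 1) * 2 - 2 * 1 + (1 * (3 - 1) + 1) + 1 = 20
    conv = ConvTranspose3d(4, 8, kernel_size=3, stride=2, padding=1,
                           output_padding=1)
    x = torch.randn(0, 4, 10, 10, 10)  # zero-sized batch
    assert tuple(conv(x).shape) == (0, 8, 20, 20, 20)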
@@ -5,7 +5,8 @@ from unittest.mock import patch
 import torch
 import torch.nn as nn
 
-from mmcv.cnn.bricks import Conv2d, ConvTranspose2d, Linear, MaxPool2d
+from mmcv.cnn.bricks import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
+                             MaxPool2d)
 
 
 @patch('torch.__version__', '1.1')
@@ -105,6 +106,61 @@ def test_conv_transposed_2d():
         wrapper(x_empty)
 
 
+@patch('torch.__version__', '1.1')
+def test_conv_transposed_3d():
+    test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
+                              ('in_t', [10, 20]), ('in_channel', [1, 3]),
+                              ('out_channel', [1, 3]), ('kernel_size', [3, 5]),
+                              ('stride', [1, 2]), ('padding', [0, 1]),
+                              ('dilation', [1, 2])])
+
+    for in_h, in_w, in_t, in_cha, out_cha, k, s, p, d in product(
+            *list(test_cases.values())):
+        # wrapper op with 0-dim input
+        x_empty = torch.randn(0, in_cha, in_t, in_h, in_w, requires_grad=True)
+        # out padding must be smaller than either stride or dilation
+        op = min(s, d) - 1
+        torch.manual_seed(0)
+        wrapper = ConvTranspose3d(
+            in_cha,
+            out_cha,
+            k,
+            stride=s,
+            padding=p,
+            dilation=d,
+            output_padding=op)
+        wrapper_out = wrapper(x_empty)
+
+        # torch op with 3-dim input as shape reference
+        x_normal = torch.randn(3, in_cha, in_t, in_h, in_w)
+        torch.manual_seed(0)
+        ref = nn.ConvTranspose3d(
+            in_cha,
+            out_cha,
+            k,
+            stride=s,
+            padding=p,
+            dilation=d,
+            output_padding=op)
+        ref_out = ref(x_normal)
+
+        assert wrapper_out.shape[0] == 0
+        assert wrapper_out.shape[1:] == ref_out.shape[1:]
+
+        wrapper_out.sum().backward()
+        assert wrapper.weight.grad is not None
+        assert wrapper.weight.grad.shape == wrapper.weight.shape
+
+        assert torch.equal(wrapper(x_normal), ref_out)
+
+    # eval mode
+    x_empty = torch.randn(0, in_cha, in_t, in_h, in_w)
+    wrapper = ConvTranspose3d(
+        in_cha, out_cha, k, stride=s, padding=p, dilation=d, output_padding=op)
+    wrapper.eval()
+    wrapper(x_empty)
+
+
 @patch('torch.__version__', '1.1')
 def test_max_pool_2d():
     test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
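
One more usage note (my sketch, not from the diff): because of the register_module('deconv3d') decorators, the layer can also be built from a config dict through mmcv's layer builders. The cfg keys below follow the pattern used for 'deconv'/ConvTranspose2d and are assumptions, not taken from this commit:

    from mmcv.cnn import build_upsample_layer

    # 'deconv3d' resolves to ConvTranspose3d in UPSAMPLE_LAYERS; all other
    # keys are forwarded to its constructor.
    upsample = build_upsample_layer(
        dict(type='deconv3d', in_channels=4, out_channels=8,
             kernel_size=2, stride=2))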