Migrate op (#392)

* migrate op

* migrate unittest

* update build no torch

* add back use_torch_vision for roi align

* fix type and unit test

* ignore test logging when no torch

* fix no torch ci test

* skip test registry

* remove coverage report when no torch

* fix mac ci order

* install latest pillow when no torch

* mv convws to bricks
Cao Yuhang 2020-07-08 17:29:15 +08:00 committed by GitHub
parent 33ca908529
commit d5cbf7eed1
13 changed files with 639 additions and 116 deletions

.github/workflows/build.yml

@@ -133,10 +133,6 @@ jobs:
strategy:
matrix:
python-version: [3.7]
torch: [1.4.0]
include:
- torch: 1.4.0
torchvision: 0.4.2
steps:
- uses: actions/checkout@v2
@@ -151,15 +147,10 @@ jobs:
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Install Pillow
run: pip install Pillow==6.2.2
if: ${{matrix.torchvision == '0.4.2'}}
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
run: pip install Pillow
- name: Run unittests and generate coverage report
run: |
coverage run --branch --source=mmcv -m pytest tests/
coverage xml
coverage report -m
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_registry.py
build_macos:
runs-on: macos-latest
@@ -181,15 +172,15 @@ jobs:
run: brew install ffmpeg jpeg-turbo
- name: Install unittest dependencies
run: pip install pytest coverage lmdb PyTurboJPEG
- name: Build and install
run: |
rm -rf .eggs
CC=clang CXX=clang++ CFLAGS='-stdlib=libc++' pip install -e .
- name: Install Pillow
run: pip install Pillow==6.2.2
if: ${{matrix.torchvision == '0.4.2'}}
- name: Install PyTorch
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
- name: Build and install
run: |
rm -rf .eggs
CC=clang CXX=clang++ CFLAGS='-stdlib=libc++' pip install -e .
- name: Run unittests
run: |
# The timing on macos VMs is not precise, so we skip the progressbar tests

mmcv/cnn/__init__.py

@@ -2,11 +2,12 @@
from .alexnet import AlexNet
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
ContextBlock, ConvModule, GeneralizedAttention, HSigmoid,
HSwish, NonLocal1d, NonLocal2d, NonLocal3d, Scale,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_plugin_layer,
build_upsample_layer, is_norm)
ContextBlock, ConvAWS2d, ConvModule, ConvWS2d,
GeneralizedAttention, HSigmoid, HSwish, NonLocal1d,
NonLocal2d, NonLocal3d, Scale, build_activation_layer,
build_conv_layer, build_norm_layer, build_padding_layer,
build_plugin_layer, build_upsample_layer, conv_ws_2d,
is_norm)
from .resnet import ResNet, make_res_layer
from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
get_model_complexity_info, kaiming_init, normal_init,
@@ -22,5 +23,6 @@ __all__ = [
'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ContextBlock',
'HSigmoid', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info'
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
'ConvAWS2d', 'ConvWS2d'
]

mmcv/cnn/bricks/__init__.py

@@ -2,6 +2,7 @@ from .activation import build_activation_layer
from .context_block import ContextBlock
from .conv import build_conv_layer
from .conv_module import ConvModule
from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
from .generalized_attention import GeneralizedAttention
from .hsigmoid import HSigmoid
from .hswish import HSwish
@@ -20,5 +21,6 @@ __all__ = [
'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale'
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
'conv_ws_2d'
]

mmcv/cnn/bricks/conv_ws.py

@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import CONV_LAYERS
def conv_ws_2d(input,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
eps=1e-5):
c_in = weight.size(0)
weight_flat = weight.view(c_in, -1)
mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
weight = (weight - mean) / (std + eps)
return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
@CONV_LAYERS.register_module('ConvWS')
class ConvWS2d(nn.Conv2d):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
eps=1e-5):
super(ConvWS2d, self).__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
self.eps = eps
def forward(self, x):
return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups, self.eps)
@CONV_LAYERS.register_module(name='ConvAWS')
class ConvAWS2d(nn.Conv2d):
"""AWS (Adaptive Weight Standardization)
This is a variant of Weight Standardization
(https://arxiv.org/pdf/1903.10520.pdf).
It is used in DetectoRS to avoid NaN
(https://arxiv.org/pdf/2006.02334.pdf).
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the conv kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of
the input. Default: 0
dilation (int or tuple, optional): Spacing between kernel elements.
Default: 1
groups (int, optional): Number of blocked connections from input
channels to output channels. Default: 1
bias (bool, optional): If set True, adds a learnable bias to the
output. Default: True
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
self.register_buffer('weight_gamma',
torch.ones(self.out_channels, 1, 1, 1))
self.register_buffer('weight_beta',
torch.zeros(self.out_channels, 1, 1, 1))
def _get_weight(self, weight):
weight_flat = weight.view(weight.size(0), -1)
mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
weight = (weight - mean) / std
weight = self.weight_gamma * weight + self.weight_beta
return weight
def forward(self, x):
weight = self._get_weight(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding,
self.dilation, self.groups)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Override default load function.
AWS overrides the function _load_from_state_dict to recover
weight_gamma and weight_beta if they are missing. If weight_gamma and
weight_beta are found in the checkpoint, this function will return
after super()._load_from_state_dict. Otherwise, it will compute the
mean and std of the pretrained weights and store them in weight_beta
and weight_gamma.
"""
self.weight_gamma.data.fill_(-1)
local_missing_keys = []
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, local_missing_keys,
unexpected_keys, error_msgs)
if self.weight_gamma.data.mean() > 0:
for k in local_missing_keys:
missing_keys.append(k)
return
weight = self.weight.data
weight_flat = weight.view(weight.size(0), -1)
mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
self.weight_beta.data.copy_(mean)
self.weight_gamma.data.copy_(std)
missing_gamma_beta = [
k for k in local_missing_keys
if k.endswith('weight_gamma') or k.endswith('weight_beta')
]
for k in missing_gamma_beta:
local_missing_keys.remove(k)
for k in local_missing_keys:
missing_keys.append(k)
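
For context, a minimal usage sketch of the migrated layers, both directly and through the 'ConvWS'/'ConvAWS' names registered with CONV_LAYERS above; the tensor shapes are illustrative assumptions, not part of this change:

```python
import torch
from mmcv.cnn import ConvAWS2d, ConvWS2d, build_conv_layer

x = torch.randn(1, 16, 32, 32)  # illustrative input

# Both layers are drop-in replacements for nn.Conv2d.
conv_ws = ConvWS2d(16, 32, kernel_size=3, padding=1)
conv_aws = ConvAWS2d(16, 32, kernel_size=3, padding=1)
assert conv_ws(x).shape == conv_aws(x).shape == (1, 32, 32, 32)

# Via the registry, using the names registered above.
conv = build_conv_layer(dict(type='ConvWS'), 16, 32, 3, padding=1)
```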

mmcv/ops/__init__.py

@@ -1,7 +1,6 @@
from .bbox import bbox_overlaps
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .conv_ws import ConvWS2d, conv_ws_2d
from .corner_pool import CornerPool
from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
@@ -19,13 +18,14 @@ from .point_sample import (SimpleRoIAlign, point_sample,
from .psa_mask import PSAMask
from .roi_align import RoIAlign, roi_align
from .roi_pool import RoIPool, roi_pool
from .saconv import SAConv2d
from .sync_bn import SyncBatchNorm
from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
__all__ = [
'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
'carafe_naive', 'ConvWS2d', 'conv_ws_2d', 'CornerPool', 'DeformConv2d',
'DeformConv2dPack', 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack',
'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
'get_compiler_version', 'get_compiling_cuda_version', 'MaskedConv2d',
@@ -33,5 +33,6 @@ __all__ = [
'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign'
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d'
]

mmcv/ops/conv_ws.py (deleted)

@@ -1,49 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F
from ..cnn import CONV_LAYERS
def conv_ws_2d(input,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
eps=1e-5):
c_in = weight.size(0)
weight_flat = weight.view(c_in, -1)
mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
weight = (weight - mean) / (std + eps)
return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
@CONV_LAYERS.register_module('ConvWS')
class ConvWS2d(nn.Conv2d):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
eps=1e-5):
super(ConvWS2d, self).__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
self.eps = eps
def forward(self, x):
return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups, self.eps)

mmcv/ops/point_sample.py

@@ -162,11 +162,11 @@ def point_sample(input, points, align_corners=False, **kwargs):
class SimpleRoIAlign(nn.Module):
def __init__(self, out_size, spatial_scale, aligned=True):
def __init__(self, output_size, spatial_scale, aligned=True):
"""Simple RoI align in PointRend, faster than standard RoIAlign.
Args:
out_size (tuple[int]): h, w
output_size (tuple[int]): h, w
spatial_scale (float): scale the input boxes by this number
aligned (bool): if False, use the legacy implementation in
MMDetection, align_corners=True will be used in F.grid_sample.
@@ -174,7 +174,7 @@ class SimpleRoIAlign(nn.Module):
"""
super(SimpleRoIAlign, self).__init__()
self.out_size = _pair(out_size)
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)
# to be consistent with other RoI ops
self.use_torchvision = False
@@ -185,7 +185,7 @@ class SimpleRoIAlign(nn.Module):
num_imgs = features.size(0)
num_rois = rois.size(0)
rel_roi_points = generate_grid(
num_rois, self.out_size, device=rois.device)
num_rois, self.output_size, device=rois.device)
point_feats = []
for batch_ind in range(num_imgs):
@@ -203,12 +203,12 @@ class SimpleRoIAlign(nn.Module):
channels = features.size(1)
roi_feats = torch.cat(point_feats, dim=0)
roi_feats = roi_feats.reshape(num_rois, channels, *self.out_size)
roi_feats = roi_feats.reshape(num_rois, channels, *self.output_size)
return roi_feats
def __repr__(self):
format_str = self.__class__.__name__
format_str += '(out_size={}, spatial_scale={}'.format(
self.out_size, self.spatial_scale)
format_str += '(output_size={}, spatial_scale={}'.format(
self.output_size, self.spatial_scale)
return format_str
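
A short usage sketch for the renamed argument (illustrative tensors; SimpleRoIAlign is pure PyTorch, so this runs on CPU):

```python
import torch
from mmcv.ops import SimpleRoIAlign

feats = torch.rand(1, 8, 16, 16)
rois = torch.tensor([[0., 0., 0., 7., 7.]])  # (batch_idx, x1, y1, x2, y2)

layer = SimpleRoIAlign(output_size=(7, 7), spatial_scale=1.0)
out = layer(feats, rois)
assert out.shape == (1, 8, 7, 7)  # num_rois, channels, *output_size
```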

mmcv/ops/roi_align.py

@@ -105,6 +105,7 @@ class RoIAlign(nn.Module):
pool_mode (str, 'avg' or 'max'): pooling mode in each bin.
aligned (bool): if False, use the legacy implementation in
MMDetection. If True, align the results more perfectly.
use_torchvision (bool): whether to use roi_align from torchvision.
Note:
The implementation of RoIAlign when aligned=True is modified from
@@ -135,7 +136,8 @@ class RoIAlign(nn.Module):
spatial_scale=1.0,
sampling_ratio=0,
pool_mode='avg',
aligned=True):
aligned=True,
use_torchvision=False):
super(RoIAlign, self).__init__()
self.output_size = _pair(output_size)
@@ -143,6 +145,9 @@ class RoIAlign(nn.Module):
self.sampling_ratio = int(sampling_ratio)
self.pool_mode = pool_mode
self.aligned = aligned
self.use_torchvision = use_torchvision
assert not (use_torchvision and
aligned), 'Torchvision does not support aligned RoIAlign'
def forward(self, input, rois):
"""
@@ -151,8 +156,13 @@ class RoIAlign(nn.Module):
rois: Bx5 boxes. First column is the index into N.\
The other 4 columns are xyxy.
"""
return roi_align(input, rois, self.output_size, self.spatial_scale,
self.sampling_ratio, self.pool_mode, self.aligned)
if self.use_torchvision:
from torchvision.ops import roi_align as tv_roi_align
return tv_roi_align(input, rois, self.output_size,
self.spatial_scale, self.sampling_ratio)
else:
return roi_align(input, rois, self.output_size, self.spatial_scale,
self.sampling_ratio, self.pool_mode, self.aligned)
def __repr__(self):
s = self.__class__.__name__
@@ -160,5 +170,6 @@ class RoIAlign(nn.Module):
s += f'spatial_scale={self.spatial_scale}, '
s += f'sampling_ratio={self.sampling_ratio}, '
s += f'pool_mode={self.pool_mode}, '
s += f'aligned={self.aligned})'
s += f'aligned={self.aligned}, '
s += f'use_torchvision={self.use_torchvision})'
return s
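
A sketch of the two backends this change makes available; per the assertion above, the torchvision path requires aligned=False (tensors are illustrative):

```python
import torch
from mmcv.ops import RoIAlign

feats = torch.rand(1, 1, 16, 16)
rois = torch.tensor([[0., 0., 0., 7., 7.]])  # (batch_idx, x1, y1, x2, y2)

# Default: mmcv's own kernel, with the more exact aligned=True behavior.
mmcv_layer = RoIAlign(output_size=7, spatial_scale=1.0, sampling_ratio=2)
out = mmcv_layer(feats, rois)

# Torchvision backend; only usable with aligned=False.
tv_layer = RoIAlign(output_size=7, spatial_scale=1.0, sampling_ratio=2,
                    aligned=False, use_torchvision=True)
out_tv = tv_layer(feats, rois)
```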

mmcv/ops/saconv.py

@@ -0,0 +1,125 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
from mmcv.ops.deform_conv import deform_conv2d
@CONV_LAYERS.register_module(name='SAC')
class SAConv2d(ConvAWS2d):
"""SAC (Switchable Atrous Convolution)
This is an implementation of SAC in DetectoRS
(https://arxiv.org/pdf/2006.02334.pdf).
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of
the input. Default: 0
padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
``'replicate'`` or ``'circular'``. Default: ``'zeros'``
dilation (int or tuple, optional): Spacing between kernel elements.
Default: 1
groups (int, optional): Number of blocked connections from input
channels to output channels. Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the
output. Default: ``True``
use_deform (bool, optional): If ``True``, replace convolution with
deformable convolution. Default: ``False``
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
use_deform=False):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
self.use_deform = use_deform
self.switch = nn.Conv2d(
self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
self.pre_context = nn.Conv2d(
self.in_channels, self.in_channels, kernel_size=1, bias=True)
self.post_context = nn.Conv2d(
self.out_channels, self.out_channels, kernel_size=1, bias=True)
if self.use_deform:
self.offset_s = nn.Conv2d(
self.in_channels,
18,
kernel_size=3,
padding=1,
stride=stride,
bias=True)
self.offset_l = nn.Conv2d(
self.in_channels,
18,
kernel_size=3,
padding=1,
stride=stride,
bias=True)
self.init_weights()
def init_weights(self):
constant_init(self.switch, 0, bias=1)
self.weight_diff.data.zero_()
constant_init(self.pre_context, 0)
constant_init(self.post_context, 0)
if self.use_deform:
constant_init(self.offset_s, 0)
constant_init(self.offset_l, 0)
def forward(self, x):
# pre-context
avg_x = F.adaptive_avg_pool2d(x, output_size=1)
avg_x = self.pre_context(avg_x)
avg_x = avg_x.expand_as(x)
x = x + avg_x
# switch
avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
switch = self.switch(avg_x)
# sac
weight = self._get_weight(self.weight)
if self.use_deform:
offset = self.offset_s(avg_x)
out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1)
else:
out_s = super().conv2d_forward(x, weight)
ori_p = self.padding
ori_d = self.dilation
self.padding = tuple(3 * p for p in self.padding)
self.dilation = tuple(3 * d for d in self.dilation)
weight = weight + self.weight_diff
if self.use_deform:
offset = self.offset_l(avg_x)
out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1)
else:
out_l = super().conv2d_forward(x, weight)
out = switch * out_s + (1 - switch) * out_l
self.padding = ori_p
self.dilation = ori_d
# post-context
avg_x = F.adaptive_avg_pool2d(out, output_size=1)
avg_x = self.post_context(avg_x)
avg_x = avg_x.expand_as(out)
out = out + avg_x
return out
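
A minimal sketch of the new layer in use (illustrative shapes; use_deform=True would additionally need the deform_conv extension, and the conv2d_forward call above ties this code to PyTorch versions where nn.Conv2d still exposes that method):

```python
import torch
from mmcv.ops import SAConv2d

x = torch.randn(2, 16, 32, 32)
sac = SAConv2d(16, 32, kernel_size=3, padding=1)
out = sac(x)
# Both the small (dilation=1) and large (dilation=3, padding=3) branches
# preserve the spatial size here, so the switch can blend them elementwise.
assert out.shape == (2, 32, 32, 32)
```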

tests/test_image/test_image_misc.py

@@ -1,12 +1,17 @@
# Copyright (c) Open-MMLab. All rights reserved.
import numpy as np
import pytest
import torch
from numpy.testing import assert_array_equal
import mmcv
try:
import torch
except ImportError:
torch = None
@pytest.mark.skipif(torch is None, reason='requires torch library')
def test_tensor2imgs():
# test tensor obj

tests/test_corner_pool.py

@@ -1,33 +1,58 @@
"""
CommandLine:
pytest tests/test_corner_pool.py
"""
import pytest
import torch
from torch.autograd import gradcheck
from mmcv.ops import CornerPool
class TestCornerPool(object):
def test_corner_pool_device_and_dtypes_cpu():
"""
CommandLine:
xdoctest -m tests/test_corner_pool.py \
test_corner_pool_device_and_dtypes_cpu
"""
with pytest.raises(AssertionError):
# pool mode must be in ['bottom', 'left', 'right', 'top']
pool = CornerPool('corner')
def test_corner_pool_top_gradcheck(self):
if not torch.cuda.is_available():
return
from mmcv.ops import CornerPool
input = torch.randn(2, 4, 5, 5, requires_grad=True, device='cuda')
gradcheck(CornerPool('top'), (input, ), atol=1e-3, eps=1e-4)
def test_corner_pool_bottom_gradcheck(self):
if not torch.cuda.is_available():
return
from mmcv.ops import CornerPool
input = torch.randn(2, 4, 5, 5, requires_grad=True, device='cuda')
gradcheck(CornerPool('bottom'), (input, ), atol=1e-3, eps=1e-4)
def test_corner_pool_left_gradcheck(self):
if not torch.cuda.is_available():
return
from mmcv.ops import CornerPool
input = torch.randn(2, 4, 5, 5, requires_grad=True, device='cuda')
gradcheck(CornerPool('left'), (input, ), atol=1e-3, eps=1e-4)
def test_corner_pool_right_gradcheck(self):
if not torch.cuda.is_available():
return
from mmcv.ops import CornerPool
input = torch.randn(2, 4, 5, 5, requires_grad=True, device='cuda')
gradcheck(CornerPool('right'), (input, ), atol=1e-3, eps=1e-4)
lr_tensor = torch.tensor([[[[0, 0, 0, 0, 0], [2, 1, 3, 0, 2],
[5, 4, 1, 1, 6], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]])
tb_tensor = torch.tensor([[[[0, 3, 1, 0, 0], [0, 1, 1, 0, 0],
[0, 3, 4, 0, 0], [0, 2, 2, 0, 0],
[0, 0, 2, 0, 0]]]])
# Left Pool
left_answer = torch.tensor([[[[0, 0, 0, 0, 0], [3, 3, 3, 2, 2],
[6, 6, 6, 6, 6], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]])
pool = CornerPool('left')
left_tensor = pool(lr_tensor)
assert left_tensor.type() == lr_tensor.type()
assert torch.equal(left_tensor, left_answer)
# Right Pool
right_answer = torch.tensor([[[[0, 0, 0, 0, 0], [2, 2, 3, 3, 3],
[5, 5, 5, 5, 6], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]])
pool = CornerPool('right')
right_tensor = pool(lr_tensor)
assert right_tensor.type() == lr_tensor.type()
assert torch.equal(right_tensor, right_answer)
# Top Pool
top_answer = torch.tensor([[[[0, 3, 4, 0, 0], [0, 3, 4, 0, 0],
[0, 3, 4, 0, 0], [0, 2, 2, 0, 0],
[0, 0, 2, 0, 0]]]])
pool = CornerPool('top')
top_tensor = pool(tb_tensor)
assert top_tensor.type() == tb_tensor.type()
assert torch.equal(top_tensor, top_answer)
# Bottom Pool
bottom_answer = torch.tensor([[[[0, 3, 1, 0, 0], [0, 3, 1, 0, 0],
[0, 3, 4, 0, 0], [0, 3, 4, 0, 0],
[0, 3, 4, 0, 0]]]])
pool = CornerPool('bottom')
bottom_tensor = pool(tb_tensor)
assert bottom_tensor.type() == tb_tensor.type()
assert torch.equal(bottom_tensor, bottom_answer)
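
The expected tensors above pin down the semantics: corner pooling is a running max toward one border, e.g. a left pool computes out[..., j] = max(x[..., j:]). A small sketch reproducing the left-pool case from the test data (cummax assumes torch >= 1.5):

```python
import torch

def left_pool(x):
    # out[..., j] = max(x[..., j:]), via flip + cummax + flip
    return x.flip(-1).cummax(-1).values.flip(-1)

lr_tensor = torch.tensor([[[[0, 0, 0, 0, 0], [2, 1, 3, 0, 2],
                            [5, 4, 1, 1, 6], [0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0]]]])
left_answer = torch.tensor([[[[0, 0, 0, 0, 0], [3, 3, 3, 2, 2],
                              [6, 6, 6, 6, 6], [0, 0, 0, 0, 0],
                              [0, 0, 0, 0, 0]]]])
assert torch.equal(left_pool(lr_tensor), left_answer)
```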

tests/test_merge_cells.py

@@ -0,0 +1,65 @@
"""
CommandLine:
pytest tests/test_merge_cells.py
"""
import torch
import torch.nn.functional as F
from mmcv.ops.merge_cells import (BaseMergeCell, ConcatCell, GlobalPoolingCell,
SumCell)
def test_sum_cell():
inputs_x = torch.randn([2, 256, 32, 32])
inputs_y = torch.randn([2, 256, 16, 16])
sum_cell = SumCell(256, 256)
output = sum_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
assert output.size() == inputs_x.size()
output = sum_cell(inputs_x, inputs_y, out_size=inputs_y.shape[-2:])
assert output.size() == inputs_y.size()
output = sum_cell(inputs_x, inputs_y)
assert output.size() == inputs_x.size()
def test_concat_cell():
inputs_x = torch.randn([2, 256, 32, 32])
inputs_y = torch.randn([2, 256, 16, 16])
concat_cell = ConcatCell(256, 256)
output = concat_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
assert output.size() == inputs_x.size()
output = concat_cell(inputs_x, inputs_y, out_size=inputs_y.shape[-2:])
assert output.size() == inputs_y.size()
output = concat_cell(inputs_x, inputs_y)
assert output.size() == inputs_x.size()
def test_global_pool_cell():
inputs_x = torch.randn([2, 256, 32, 32])
inputs_y = torch.randn([2, 256, 32, 32])
gp_cell = GlobalPoolingCell(with_out_conv=False)
gp_cell_out = gp_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
assert (gp_cell_out.size() == inputs_x.size())
gp_cell = GlobalPoolingCell(256, 256)
gp_cell_out = gp_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
assert (gp_cell_out.size() == inputs_x.size())
def test_resize_methods():
inputs_x = torch.randn([2, 256, 128, 128])
target_resize_sizes = [(128, 128), (256, 256)]
resize_methods_list = ['nearest', 'bilinear']
for method in resize_methods_list:
merge_cell = BaseMergeCell(upsample_mode=method)
for target_size in target_resize_sizes:
merge_cell_out = merge_cell._resize(inputs_x, target_size)
gt_out = F.interpolate(inputs_x, size=target_size, mode=method)
assert merge_cell_out.equal(gt_out)
target_size = (64, 64) # resize to a smaller size
merge_cell = BaseMergeCell()
merge_cell_out = merge_cell._resize(inputs_x, target_size)
kernel_size = inputs_x.shape[-1] // target_size[-1]
gt_out = F.max_pool2d(
inputs_x, kernel_size=kernel_size, stride=kernel_size)
assert (merge_cell_out == gt_out).all()
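
The resize checks above fix the _resize contract: interpolate up to larger targets, max-pool (with an exact integer stride) down to smaller ones. A hypothetical standalone helper mirroring that rule:

```python
import torch.nn.functional as F

def resize_like_merge_cell(x, size, upsample_mode='nearest'):
    # mirrors what the tests above assert about BaseMergeCell._resize
    if size[-1] >= x.shape[-1]:
        return F.interpolate(x, size=size, mode=upsample_mode)
    kernel = x.shape[-1] // size[-1]  # assumes an integer ratio
    return F.max_pool2d(x, kernel_size=kernel, stride=kernel)
```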

tests/test_wrappers.py

@@ -0,0 +1,198 @@
from collections import OrderedDict
from itertools import product
from unittest.mock import patch
import torch
import torch.nn as nn
from mmcv.ops import Conv2d, ConvTranspose2d, Linear, MaxPool2d
torch.__version__ = '1.1' # force test
def test_conv2d():
"""
CommandLine:
xdoctest -m tests/test_wrappers.py test_conv2d
"""
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
('in_channel', [1, 3]), ('out_channel', [1, 3]),
('kernel_size', [3, 5]), ('stride', [1, 2]),
('padding', [0, 1]), ('dilation', [1, 2])])
# train mode
for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
*list(test_cases.values())):
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_cha, in_h, in_w)
torch.manual_seed(0)
wrapper = Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_cha, in_h, in_w).requires_grad_(True)
torch.manual_seed(0)
ref = nn.Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_cha, in_h, in_w)
wrapper = Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
wrapper.eval()
wrapper(x_empty)
def test_conv_transposed_2d():
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
('in_channel', [1, 3]), ('out_channel', [1, 3]),
('kernel_size', [3, 5]), ('stride', [1, 2]),
('padding', [0, 1]), ('dilation', [1, 2])])
for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
*list(test_cases.values())):
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
# out padding must be smaller than either stride or dilation
op = min(s, d) - 1
torch.manual_seed(0)
wrapper = ConvTranspose2d(
in_cha,
out_cha,
k,
stride=s,
padding=p,
dilation=d,
output_padding=op)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_cha, in_h, in_w)
torch.manual_seed(0)
ref = nn.ConvTranspose2d(
in_cha,
out_cha,
k,
stride=s,
padding=p,
dilation=d,
output_padding=op)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_cha, in_h, in_w)
wrapper = ConvTranspose2d(
in_cha, out_cha, k, stride=s, padding=p, dilation=d, output_padding=op)
wrapper.eval()
wrapper(x_empty)
def test_max_pool_2d():
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
('in_channel', [1, 3]), ('out_channel', [1, 3]),
('kernel_size', [3, 5]), ('stride', [1, 2]),
('padding', [0, 1]), ('dilation', [1, 2])])
for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
*list(test_cases.values())):
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
wrapper = MaxPool2d(k, stride=s, padding=p, dilation=d)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_cha, in_h, in_w)
ref = nn.MaxPool2d(k, stride=s, padding=p, dilation=d)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
assert torch.equal(wrapper(x_normal), ref_out)
def test_linear():
test_cases = OrderedDict([
('in_w', [10, 20]),
('in_h', [10, 20]),
('in_feature', [1, 3]),
('out_feature', [1, 3]),
])
for in_h, in_w, in_feature, out_feature in product(
*list(test_cases.values())):
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_feature, requires_grad=True)
torch.manual_seed(0)
wrapper = Linear(in_feature, out_feature)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_feature)
torch.manual_seed(0)
ref = nn.Linear(in_feature, out_feature)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_feature)
wrapper = Linear(in_feature, out_feature)
wrapper.eval()
wrapper(x_empty)
def test_nn_op_forward_called():
torch.__version__ = '1.4.1'
for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']:
with patch(f'torch.nn.{m}.forward') as nn_module_forward:
# randn input
x_empty = torch.randn(0, 3, 10, 10)
wrapper = eval(m)(3, 2, 1)
wrapper(x_empty)
nn_module_forward.assert_called_with(x_empty)
# non-randn input
x_normal = torch.randn(1, 3, 10, 10)
wrapper = eval(m)(3, 2, 1)
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)
with patch('torch.nn.Linear.forward') as nn_module_forward:
# randn input
x_empty = torch.randn(0, 3)
wrapper = Linear(3, 3)
wrapper(x_empty)
nn_module_forward.assert_not_called()
# non-randn input
x_normal = torch.randn(1, 3)
wrapper = Linear(3, 3)
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)
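
For reference, a minimal sketch of the behavior these wrapper tests protect: on old PyTorch the mmcv wrappers accept zero-size batches and return correctly shaped empty outputs, while on newer versions (the patched-forward test above) they defer to the native modules:

```python
import torch
from mmcv.ops import Conv2d  # mmcv wrapper, not torch.nn.Conv2d

conv = Conv2d(3, 8, kernel_size=3, padding=1)
x_empty = torch.randn(0, 3, 16, 16)  # zero-size batch
out = conv(x_empty)
assert out.shape == (0, 8, 16, 16)
```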