mirror of https://github.com/open-mmlab/mmcv.git
Migrate op (#392)
* migrate op * migrate unittest * update build no torch * add back use_torch_vision for roi align * fix type and unit test * ignore test logging when no torch * fix no torch ci test * skip test registry * remove coverage report when no torch * fix mac ci order * install latest pillow when no torch * mv convws to briskpull/382/head^2
parent
33ca908529
commit
d5cbf7eed1
|
@ -133,10 +133,6 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
torch: [1.4.0]
|
||||
include:
|
||||
- torch: 1.4.0
|
||||
torchvision: 0.4.2
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
@ -151,15 +147,10 @@ jobs:
|
|||
- name: Build and install
|
||||
run: rm -rf .eggs && pip install -e .
|
||||
- name: Install Pillow
|
||||
run: pip install Pillow==6.2.2
|
||||
if: ${{matrix.torchvision == '0.4.2'}}
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
run: pip install Pillow
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source=mmcv -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_registry.py
|
||||
|
||||
build_macos:
|
||||
runs-on: macos-latest
|
||||
|
@ -181,15 +172,15 @@ jobs:
|
|||
run: brew install ffmpeg jpeg-turbo
|
||||
- name: Install unittest dependencies
|
||||
run: pip install pytest coverage lmdb PyTurboJPEG
|
||||
- name: Build and install
|
||||
run: |
|
||||
rm -rf .eggs
|
||||
CC=clang CXX=clang++ CFLAGS='-stdlib=libc++' pip install -e .
|
||||
- name: Install Pillow
|
||||
run: pip install Pillow==6.2.2
|
||||
if: ${{matrix.torchvision == '0.4.2'}}
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Build and install
|
||||
run: |
|
||||
rm -rf .eggs
|
||||
CC=clang CXX=clang++ CFLAGS='-stdlib=libc++' pip install -e .
|
||||
- name: Run unittests
|
||||
run: |
|
||||
# The timing on macos VMs is not precise, so we skip the progressbar tests
|
||||
|
|
|
@ -2,11 +2,12 @@
|
|||
from .alexnet import AlexNet
|
||||
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
|
||||
PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
|
||||
ContextBlock, ConvModule, GeneralizedAttention, HSigmoid,
|
||||
HSwish, NonLocal1d, NonLocal2d, NonLocal3d, Scale,
|
||||
build_activation_layer, build_conv_layer,
|
||||
build_norm_layer, build_padding_layer, build_plugin_layer,
|
||||
build_upsample_layer, is_norm)
|
||||
ContextBlock, ConvAWS2d, ConvModule, ConvWS2d,
|
||||
GeneralizedAttention, HSigmoid, HSwish, NonLocal1d,
|
||||
NonLocal2d, NonLocal3d, Scale, build_activation_layer,
|
||||
build_conv_layer, build_norm_layer, build_padding_layer,
|
||||
build_plugin_layer, build_upsample_layer, conv_ws_2d,
|
||||
is_norm)
|
||||
from .resnet import ResNet, make_res_layer
|
||||
from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
|
||||
get_model_complexity_info, kaiming_init, normal_init,
|
||||
|
@ -22,5 +23,6 @@ __all__ = [
|
|||
'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ContextBlock',
|
||||
'HSigmoid', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
|
||||
'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
|
||||
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info'
|
||||
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
|
||||
'ConvAWS2d', 'ConvWS2d'
|
||||
]
|
||||
|
|
|
@ -2,6 +2,7 @@ from .activation import build_activation_layer
|
|||
from .context_block import ContextBlock
|
||||
from .conv import build_conv_layer
|
||||
from .conv_module import ConvModule
|
||||
from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
|
||||
from .generalized_attention import GeneralizedAttention
|
||||
from .hsigmoid import HSigmoid
|
||||
from .hswish import HSwish
|
||||
|
@ -20,5 +21,6 @@ __all__ = [
|
|||
'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
|
||||
'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
|
||||
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
|
||||
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale'
|
||||
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
|
||||
'conv_ws_2d'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .registry import CONV_LAYERS
|
||||
|
||||
|
||||
def conv_ws_2d(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
eps=1e-5):
|
||||
c_in = weight.size(0)
|
||||
weight_flat = weight.view(c_in, -1)
|
||||
mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
|
||||
std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
|
||||
weight = (weight - mean) / (std + eps)
|
||||
return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
|
||||
|
||||
|
||||
@CONV_LAYERS.register_module('ConvWS')
|
||||
class ConvWS2d(nn.Conv2d):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
eps=1e-5):
|
||||
super(ConvWS2d, self).__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
|
||||
self.dilation, self.groups, self.eps)
|
||||
|
||||
|
||||
@CONV_LAYERS.register_module(name='ConvAWS')
|
||||
class ConvAWS2d(nn.Conv2d):
|
||||
"""AWS (Adaptive Weight Standardization)
|
||||
|
||||
This is a variant of Weight Standardization
|
||||
(https://arxiv.org/pdf/1903.10520.pdf)
|
||||
It is used in DetectoRS to avoid NaN
|
||||
(https://arxiv.org/pdf/2006.02334.pdf)
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the conv kernel
|
||||
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
||||
padding (int or tuple, optional): Zero-padding added to both sides of
|
||||
the input. Default: 0
|
||||
dilation (int or tuple, optional): Spacing between kernel elements.
|
||||
Default: 1
|
||||
groups (int, optional): Number of blocked connections from input
|
||||
channels to output channels. Default: 1
|
||||
bias (bool, optional): If set True, adds a learnable bias to the
|
||||
output. Default: True
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias)
|
||||
self.register_buffer('weight_gamma',
|
||||
torch.ones(self.out_channels, 1, 1, 1))
|
||||
self.register_buffer('weight_beta',
|
||||
torch.zeros(self.out_channels, 1, 1, 1))
|
||||
|
||||
def _get_weight(self, weight):
|
||||
weight_flat = weight.view(weight.size(0), -1)
|
||||
mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
|
||||
std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
|
||||
weight = (weight - mean) / std
|
||||
weight = self.weight_gamma * weight + self.weight_beta
|
||||
return weight
|
||||
|
||||
def forward(self, x):
|
||||
weight = self._get_weight(self.weight)
|
||||
return F.conv2d(x, weight, self.bias, self.stride, self.padding,
|
||||
self.dilation, self.groups)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
"""Override default load function.
|
||||
|
||||
AWS overrides the function _load_from_state_dict to recover
|
||||
weight_gamma and weight_beta if they are missing. If weight_gamma and
|
||||
weight_beta are found in the checkpoint, this function will return
|
||||
after super()._load_from_state_dict. Otherwise, it will compute the
|
||||
mean and std of the pretrained weights and store them in weight_beta
|
||||
and weight_gamma.
|
||||
"""
|
||||
|
||||
self.weight_gamma.data.fill_(-1)
|
||||
local_missing_keys = []
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata,
|
||||
strict, local_missing_keys,
|
||||
unexpected_keys, error_msgs)
|
||||
if self.weight_gamma.data.mean() > 0:
|
||||
for k in local_missing_keys:
|
||||
missing_keys.append(k)
|
||||
return
|
||||
weight = self.weight.data
|
||||
weight_flat = weight.view(weight.size(0), -1)
|
||||
mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
|
||||
std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
|
||||
self.weight_beta.data.copy_(mean)
|
||||
self.weight_gamma.data.copy_(std)
|
||||
missing_gamma_beta = [
|
||||
k for k in local_missing_keys
|
||||
if k.endswith('weight_gamma') or k.endswith('weight_beta')
|
||||
]
|
||||
for k in missing_gamma_beta:
|
||||
local_missing_keys.remove(k)
|
||||
for k in local_missing_keys:
|
||||
missing_keys.append(k)
|
|
@ -1,7 +1,6 @@
|
|||
from .bbox import bbox_overlaps
|
||||
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
|
||||
from .cc_attention import CrissCrossAttention
|
||||
from .conv_ws import ConvWS2d, conv_ws_2d
|
||||
from .corner_pool import CornerPool
|
||||
from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
|
||||
from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
|
||||
|
@ -19,13 +18,14 @@ from .point_sample import (SimpleRoIAlign, point_sample,
|
|||
from .psa_mask import PSAMask
|
||||
from .roi_align import RoIAlign, roi_align
|
||||
from .roi_pool import RoIPool, roi_pool
|
||||
from .saconv import SAConv2d
|
||||
from .sync_bn import SyncBatchNorm
|
||||
from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
|
||||
|
||||
__all__ = [
|
||||
'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
|
||||
'carafe_naive', 'ConvWS2d', 'conv_ws_2d', 'CornerPool', 'DeformConv2d',
|
||||
'DeformConv2dPack', 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
|
||||
'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack',
|
||||
'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
|
||||
'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
|
||||
'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
|
||||
'get_compiler_version', 'get_compiling_cuda_version', 'MaskedConv2d',
|
||||
|
@ -33,5 +33,6 @@ __all__ = [
|
|||
'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
|
||||
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
|
||||
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
|
||||
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign'
|
||||
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
|
||||
'SAConv2d'
|
||||
]
|
||||
|
|
|
@ -1,49 +0,0 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..cnn import CONV_LAYERS
|
||||
|
||||
|
||||
def conv_ws_2d(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
eps=1e-5):
|
||||
c_in = weight.size(0)
|
||||
weight_flat = weight.view(c_in, -1)
|
||||
mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
|
||||
std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
|
||||
weight = (weight - mean) / (std + eps)
|
||||
return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
|
||||
|
||||
|
||||
@CONV_LAYERS.register_module('ConvWS')
|
||||
class ConvWS2d(nn.Conv2d):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
eps=1e-5):
|
||||
super(ConvWS2d, self).__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
|
||||
self.dilation, self.groups, self.eps)
|
|
@ -162,11 +162,11 @@ def point_sample(input, points, align_corners=False, **kwargs):
|
|||
|
||||
class SimpleRoIAlign(nn.Module):
|
||||
|
||||
def __init__(self, out_size, spatial_scale, aligned=True):
|
||||
def __init__(self, output_size, spatial_scale, aligned=True):
|
||||
"""Simple RoI align in PointRend, faster than standard RoIAlign.
|
||||
|
||||
Args:
|
||||
out_size (tuple[int]): h, w
|
||||
output_size (tuple[int]): h, w
|
||||
spatial_scale (float): scale the input boxes by this number
|
||||
aligned (bool): if False, use the legacy implementation in
|
||||
MMDetection, align_corners=True will be used in F.grid_sample.
|
||||
|
@ -174,7 +174,7 @@ class SimpleRoIAlign(nn.Module):
|
|||
"""
|
||||
|
||||
super(SimpleRoIAlign, self).__init__()
|
||||
self.out_size = _pair(out_size)
|
||||
self.output_size = _pair(output_size)
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
# to be consistent with other RoI ops
|
||||
self.use_torchvision = False
|
||||
|
@ -185,7 +185,7 @@ class SimpleRoIAlign(nn.Module):
|
|||
num_imgs = features.size(0)
|
||||
num_rois = rois.size(0)
|
||||
rel_roi_points = generate_grid(
|
||||
num_rois, self.out_size, device=rois.device)
|
||||
num_rois, self.output_size, device=rois.device)
|
||||
|
||||
point_feats = []
|
||||
for batch_ind in range(num_imgs):
|
||||
|
@ -203,12 +203,12 @@ class SimpleRoIAlign(nn.Module):
|
|||
|
||||
channels = features.size(1)
|
||||
roi_feats = torch.cat(point_feats, dim=0)
|
||||
roi_feats = roi_feats.reshape(num_rois, channels, *self.out_size)
|
||||
roi_feats = roi_feats.reshape(num_rois, channels, *self.output_size)
|
||||
|
||||
return roi_feats
|
||||
|
||||
def __repr__(self):
|
||||
format_str = self.__class__.__name__
|
||||
format_str += '(out_size={}, spatial_scale={}'.format(
|
||||
self.out_size, self.spatial_scale)
|
||||
format_str += '(output_size={}, spatial_scale={}'.format(
|
||||
self.output_size, self.spatial_scale)
|
||||
return format_str
|
||||
|
|
|
@ -105,6 +105,7 @@ class RoIAlign(nn.Module):
|
|||
pool_mode (str, 'avg' or 'max'): pooling mode in each bin.
|
||||
aligned (bool): if False, use the legacy implementation in
|
||||
MMDetection. If True, align the results more perfectly.
|
||||
use_torchvision (bool): whether to use roi_align from torchvision.
|
||||
|
||||
Note:
|
||||
The implementation of RoIAlign when aligned=True is modified from
|
||||
|
@ -135,7 +136,8 @@ class RoIAlign(nn.Module):
|
|||
spatial_scale=1.0,
|
||||
sampling_ratio=0,
|
||||
pool_mode='avg',
|
||||
aligned=True):
|
||||
aligned=True,
|
||||
use_torchvision=False):
|
||||
super(RoIAlign, self).__init__()
|
||||
|
||||
self.output_size = _pair(output_size)
|
||||
|
@ -143,6 +145,9 @@ class RoIAlign(nn.Module):
|
|||
self.sampling_ratio = int(sampling_ratio)
|
||||
self.pool_mode = pool_mode
|
||||
self.aligned = aligned
|
||||
self.use_torchvision = use_torchvision
|
||||
assert not (use_torchvision and
|
||||
aligned), 'Torchvision does not support aligned RoIAlgin'
|
||||
|
||||
def forward(self, input, rois):
|
||||
"""
|
||||
|
@ -151,8 +156,13 @@ class RoIAlign(nn.Module):
|
|||
rois: Bx5 boxes. First column is the index into N.\
|
||||
The other 4 columns are xyxy.
|
||||
"""
|
||||
return roi_align(input, rois, self.output_size, self.spatial_scale,
|
||||
self.sampling_ratio, self.pool_mode, self.aligned)
|
||||
if self.use_torchvision:
|
||||
from torchvision.ops import roi_align as tv_roi_align
|
||||
return tv_roi_align(input, rois, self.output_size,
|
||||
self.spatial_scale, self.sampling_ratio)
|
||||
else:
|
||||
return roi_align(input, rois, self.output_size, self.spatial_scale,
|
||||
self.sampling_ratio, self.pool_mode, self.aligned)
|
||||
|
||||
def __repr__(self):
|
||||
s = self.__class__.__name__
|
||||
|
@ -160,5 +170,6 @@ class RoIAlign(nn.Module):
|
|||
s += f'spatial_scale={self.spatial_scale}, '
|
||||
s += f'sampling_ratio={self.sampling_ratio}, '
|
||||
s += f'pool_mode={self.pool_mode}, '
|
||||
s += f'aligned={self.aligned})'
|
||||
s += f'aligned={self.aligned}, '
|
||||
s += f'use_torchvision={self.use_torchvision})'
|
||||
return s
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
|
||||
from mmcv.ops.deform_conv import deform_conv2d
|
||||
|
||||
|
||||
@CONV_LAYERS.register_module(name='SAC')
|
||||
class SAConv2d(ConvAWS2d):
|
||||
"""SAC (Switchable Atrous Convolution)
|
||||
|
||||
This is an implementation of SAC in DetectoRS
|
||||
(https://arxiv.org/pdf/2006.02334.pdf).
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
||||
padding (int or tuple, optional): Zero-padding added to both sides of
|
||||
the input. Default: 0
|
||||
padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
|
||||
``'replicate'`` or ``'circular'``. Default: ``'zeros'``
|
||||
dilation (int or tuple, optional): Spacing between kernel elements.
|
||||
Default: 1
|
||||
groups (int, optional): Number of blocked connections from input
|
||||
channels to output channels. Default: 1
|
||||
bias (bool, optional): If ``True``, adds a learnable bias to the
|
||||
output. Default: ``True``
|
||||
use_deform: If ``True``, replace convolution with deformable
|
||||
convolution. Default: ``False``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
use_deform=False):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias)
|
||||
self.use_deform = use_deform
|
||||
self.switch = nn.Conv2d(
|
||||
self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
|
||||
self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
|
||||
self.pre_context = nn.Conv2d(
|
||||
self.in_channels, self.in_channels, kernel_size=1, bias=True)
|
||||
self.post_context = nn.Conv2d(
|
||||
self.out_channels, self.out_channels, kernel_size=1, bias=True)
|
||||
if self.use_deform:
|
||||
self.offset_s = nn.Conv2d(
|
||||
self.in_channels,
|
||||
18,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=stride,
|
||||
bias=True)
|
||||
self.offset_l = nn.Conv2d(
|
||||
self.in_channels,
|
||||
18,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=stride,
|
||||
bias=True)
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
constant_init(self.switch, 0, bias=1)
|
||||
self.weight_diff.data.zero_()
|
||||
constant_init(self.pre_context, 0)
|
||||
constant_init(self.post_context, 0)
|
||||
if self.use_deform:
|
||||
constant_init(self.offset_s, 0)
|
||||
constant_init(self.offset_l, 0)
|
||||
|
||||
def forward(self, x):
|
||||
# pre-context
|
||||
avg_x = F.adaptive_avg_pool2d(x, output_size=1)
|
||||
avg_x = self.pre_context(avg_x)
|
||||
avg_x = avg_x.expand_as(x)
|
||||
x = x + avg_x
|
||||
# switch
|
||||
avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
|
||||
avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
|
||||
switch = self.switch(avg_x)
|
||||
# sac
|
||||
weight = self._get_weight(self.weight)
|
||||
if self.use_deform:
|
||||
offset = self.offset_s(avg_x)
|
||||
out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
|
||||
self.dilation, self.groups, 1)
|
||||
else:
|
||||
out_s = super().conv2d_forward(x, weight)
|
||||
ori_p = self.padding
|
||||
ori_d = self.dilation
|
||||
self.padding = tuple(3 * p for p in self.padding)
|
||||
self.dilation = tuple(3 * d for d in self.dilation)
|
||||
weight = weight + self.weight_diff
|
||||
if self.use_deform:
|
||||
offset = self.offset_l(avg_x)
|
||||
out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
|
||||
self.dilation, self.groups, 1)
|
||||
else:
|
||||
out_l = super().conv2d_forward(x, weight)
|
||||
out = switch * out_s + (1 - switch) * out_l
|
||||
self.padding = ori_p
|
||||
self.dilation = ori_d
|
||||
# post-context
|
||||
avg_x = F.adaptive_avg_pool2d(out, output_size=1)
|
||||
avg_x = self.post_context(avg_x)
|
||||
avg_x = avg_x.expand_as(out)
|
||||
out = out + avg_x
|
||||
return out
|
|
@ -1,12 +1,17 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
import mmcv
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
torch = None
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch is None, reason='requires torch library')
|
||||
def test_tensor2imgs():
|
||||
|
||||
# test tensor obj
|
||||
|
|
|
@ -1,33 +1,58 @@
|
|||
"""
|
||||
CommandLine:
|
||||
pytest tests/test_corner_pool.py
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
from torch.autograd import gradcheck
|
||||
|
||||
from mmcv.ops import CornerPool
|
||||
|
||||
|
||||
class TestCornerPool(object):
|
||||
def test_corner_pool_device_and_dtypes_cpu():
|
||||
"""
|
||||
CommandLine:
|
||||
xdoctest -m tests/test_corner_pool.py \
|
||||
test_corner_pool_device_and_dtypes_cpu
|
||||
"""
|
||||
with pytest.raises(AssertionError):
|
||||
# pool mode must in ['bottom', 'left', 'right', 'top']
|
||||
pool = CornerPool('corner')
|
||||
|
||||
def test_corner_pool_top_gradcheck(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
from mmcv.ops import CornerPool
|
||||
input = torch.randn(2, 4, 5, 5, requires_grad=True, device='cuda')
|
||||
gradcheck(CornerPool('top'), (input, ), atol=1e-3, eps=1e-4)
|
||||
|
||||
def test_corner_pool_bottom_gradcheck(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
from mmcv.ops import CornerPool
|
||||
input = torch.randn(2, 4, 5, 5, requires_grad=True, device='cuda')
|
||||
gradcheck(CornerPool('bottom'), (input, ), atol=1e-3, eps=1e-4)
|
||||
|
||||
def test_corner_pool_left_gradcheck(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
from mmcv.ops import CornerPool
|
||||
input = torch.randn(2, 4, 5, 5, requires_grad=True, device='cuda')
|
||||
gradcheck(CornerPool('left'), (input, ), atol=1e-3, eps=1e-4)
|
||||
|
||||
def test_corner_pool_right_gradcheck(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
from mmcv.ops import CornerPool
|
||||
input = torch.randn(2, 4, 5, 5, requires_grad=True, device='cuda')
|
||||
gradcheck(CornerPool('right'), (input, ), atol=1e-3, eps=1e-4)
|
||||
lr_tensor = torch.tensor([[[[0, 0, 0, 0, 0], [2, 1, 3, 0, 2],
|
||||
[5, 4, 1, 1, 6], [0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]]]])
|
||||
tb_tensor = torch.tensor([[[[0, 3, 1, 0, 0], [0, 1, 1, 0, 0],
|
||||
[0, 3, 4, 0, 0], [0, 2, 2, 0, 0],
|
||||
[0, 0, 2, 0, 0]]]])
|
||||
# Left Pool
|
||||
left_answer = torch.tensor([[[[0, 0, 0, 0, 0], [3, 3, 3, 2, 2],
|
||||
[6, 6, 6, 6, 6], [0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]]]])
|
||||
pool = CornerPool('left')
|
||||
left_tensor = pool(lr_tensor)
|
||||
assert left_tensor.type() == lr_tensor.type()
|
||||
assert torch.equal(left_tensor, left_answer)
|
||||
# Right Pool
|
||||
right_answer = torch.tensor([[[[0, 0, 0, 0, 0], [2, 2, 3, 3, 3],
|
||||
[5, 5, 5, 5, 6], [0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]]]])
|
||||
pool = CornerPool('right')
|
||||
right_tensor = pool(lr_tensor)
|
||||
assert right_tensor.type() == lr_tensor.type()
|
||||
assert torch.equal(right_tensor, right_answer)
|
||||
# Top Pool
|
||||
top_answer = torch.tensor([[[[0, 3, 4, 0, 0], [0, 3, 4, 0, 0],
|
||||
[0, 3, 4, 0, 0], [0, 2, 2, 0, 0],
|
||||
[0, 0, 2, 0, 0]]]])
|
||||
pool = CornerPool('top')
|
||||
top_tensor = pool(tb_tensor)
|
||||
assert top_tensor.type() == tb_tensor.type()
|
||||
assert torch.equal(top_tensor, top_answer)
|
||||
# Bottom Pool
|
||||
bottom_answer = torch.tensor([[[[0, 3, 1, 0, 0], [0, 3, 1, 0, 0],
|
||||
[0, 3, 4, 0, 0], [0, 3, 4, 0, 0],
|
||||
[0, 3, 4, 0, 0]]]])
|
||||
pool = CornerPool('bottom')
|
||||
bottom_tensor = pool(tb_tensor)
|
||||
assert bottom_tensor.type() == tb_tensor.type()
|
||||
assert torch.equal(bottom_tensor, bottom_answer)
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
"""
|
||||
CommandLine:
|
||||
pytest tests/test_merge_cells.py
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmcv.ops.merge_cells import (BaseMergeCell, ConcatCell, GlobalPoolingCell,
|
||||
SumCell)
|
||||
|
||||
|
||||
def test_sum_cell():
|
||||
inputs_x = torch.randn([2, 256, 32, 32])
|
||||
inputs_y = torch.randn([2, 256, 16, 16])
|
||||
sum_cell = SumCell(256, 256)
|
||||
output = sum_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
|
||||
assert output.size() == inputs_x.size()
|
||||
output = sum_cell(inputs_x, inputs_y, out_size=inputs_y.shape[-2:])
|
||||
assert output.size() == inputs_y.size()
|
||||
output = sum_cell(inputs_x, inputs_y)
|
||||
assert output.size() == inputs_x.size()
|
||||
|
||||
|
||||
def test_concat_cell():
|
||||
inputs_x = torch.randn([2, 256, 32, 32])
|
||||
inputs_y = torch.randn([2, 256, 16, 16])
|
||||
concat_cell = ConcatCell(256, 256)
|
||||
output = concat_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
|
||||
assert output.size() == inputs_x.size()
|
||||
output = concat_cell(inputs_x, inputs_y, out_size=inputs_y.shape[-2:])
|
||||
assert output.size() == inputs_y.size()
|
||||
output = concat_cell(inputs_x, inputs_y)
|
||||
assert output.size() == inputs_x.size()
|
||||
|
||||
|
||||
def test_global_pool_cell():
|
||||
inputs_x = torch.randn([2, 256, 32, 32])
|
||||
inputs_y = torch.randn([2, 256, 32, 32])
|
||||
gp_cell = GlobalPoolingCell(with_out_conv=False)
|
||||
gp_cell_out = gp_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
|
||||
assert (gp_cell_out.size() == inputs_x.size())
|
||||
gp_cell = GlobalPoolingCell(256, 256)
|
||||
gp_cell_out = gp_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
|
||||
assert (gp_cell_out.size() == inputs_x.size())
|
||||
|
||||
|
||||
def test_resize_methods():
|
||||
inputs_x = torch.randn([2, 256, 128, 128])
|
||||
target_resize_sizes = [(128, 128), (256, 256)]
|
||||
resize_methods_list = ['nearest', 'bilinear']
|
||||
|
||||
for method in resize_methods_list:
|
||||
merge_cell = BaseMergeCell(upsample_mode=method)
|
||||
for target_size in target_resize_sizes:
|
||||
merge_cell_out = merge_cell._resize(inputs_x, target_size)
|
||||
gt_out = F.interpolate(inputs_x, size=target_size, mode=method)
|
||||
assert merge_cell_out.equal(gt_out)
|
||||
|
||||
target_size = (64, 64) # resize to a smaller size
|
||||
merge_cell = BaseMergeCell()
|
||||
merge_cell_out = merge_cell._resize(inputs_x, target_size)
|
||||
kernel_size = inputs_x.shape[-1] // target_size[-1]
|
||||
gt_out = F.max_pool2d(
|
||||
inputs_x, kernel_size=kernel_size, stride=kernel_size)
|
||||
assert (merge_cell_out == gt_out).all()
|
|
@ -0,0 +1,198 @@
|
|||
from collections import OrderedDict
|
||||
from itertools import product
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmcv.ops import Conv2d, ConvTranspose2d, Linear, MaxPool2d
|
||||
|
||||
torch.__version__ = '1.1' # force test
|
||||
|
||||
|
||||
def test_conv2d():
|
||||
"""
|
||||
CommandLine:
|
||||
xdoctest -m tests/test_wrappers.py test_conv2d
|
||||
"""
|
||||
|
||||
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
|
||||
('in_channel', [1, 3]), ('out_channel', [1, 3]),
|
||||
('kernel_size', [3, 5]), ('stride', [1, 2]),
|
||||
('padding', [0, 1]), ('dilation', [1, 2])])
|
||||
|
||||
# train mode
|
||||
for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
|
||||
*list(test_cases.values())):
|
||||
# wrapper op with 0-dim input
|
||||
x_empty = torch.randn(0, in_cha, in_h, in_w)
|
||||
torch.manual_seed(0)
|
||||
wrapper = Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
|
||||
wrapper_out = wrapper(x_empty)
|
||||
|
||||
# torch op with 3-dim input as shape reference
|
||||
x_normal = torch.randn(3, in_cha, in_h, in_w).requires_grad_(True)
|
||||
torch.manual_seed(0)
|
||||
ref = nn.Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
|
||||
ref_out = ref(x_normal)
|
||||
|
||||
assert wrapper_out.shape[0] == 0
|
||||
assert wrapper_out.shape[1:] == ref_out.shape[1:]
|
||||
|
||||
wrapper_out.sum().backward()
|
||||
assert wrapper.weight.grad is not None
|
||||
assert wrapper.weight.grad.shape == wrapper.weight.shape
|
||||
|
||||
assert torch.equal(wrapper(x_normal), ref_out)
|
||||
|
||||
# eval mode
|
||||
x_empty = torch.randn(0, in_cha, in_h, in_w)
|
||||
wrapper = Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
|
||||
wrapper.eval()
|
||||
wrapper(x_empty)
|
||||
|
||||
|
||||
def test_conv_transposed_2d():
|
||||
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
|
||||
('in_channel', [1, 3]), ('out_channel', [1, 3]),
|
||||
('kernel_size', [3, 5]), ('stride', [1, 2]),
|
||||
('padding', [0, 1]), ('dilation', [1, 2])])
|
||||
|
||||
for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
|
||||
*list(test_cases.values())):
|
||||
# wrapper op with 0-dim input
|
||||
x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
|
||||
# out padding must be smaller than either stride or dilation
|
||||
op = min(s, d) - 1
|
||||
torch.manual_seed(0)
|
||||
wrapper = ConvTranspose2d(
|
||||
in_cha,
|
||||
out_cha,
|
||||
k,
|
||||
stride=s,
|
||||
padding=p,
|
||||
dilation=d,
|
||||
output_padding=op)
|
||||
wrapper_out = wrapper(x_empty)
|
||||
|
||||
# torch op with 3-dim input as shape reference
|
||||
x_normal = torch.randn(3, in_cha, in_h, in_w)
|
||||
torch.manual_seed(0)
|
||||
ref = nn.ConvTranspose2d(
|
||||
in_cha,
|
||||
out_cha,
|
||||
k,
|
||||
stride=s,
|
||||
padding=p,
|
||||
dilation=d,
|
||||
output_padding=op)
|
||||
ref_out = ref(x_normal)
|
||||
|
||||
assert wrapper_out.shape[0] == 0
|
||||
assert wrapper_out.shape[1:] == ref_out.shape[1:]
|
||||
|
||||
wrapper_out.sum().backward()
|
||||
assert wrapper.weight.grad is not None
|
||||
assert wrapper.weight.grad.shape == wrapper.weight.shape
|
||||
|
||||
assert torch.equal(wrapper(x_normal), ref_out)
|
||||
|
||||
# eval mode
|
||||
x_empty = torch.randn(0, in_cha, in_h, in_w)
|
||||
wrapper = ConvTranspose2d(
|
||||
in_cha, out_cha, k, stride=s, padding=p, dilation=d, output_padding=op)
|
||||
wrapper.eval()
|
||||
wrapper(x_empty)
|
||||
|
||||
|
||||
def test_max_pool_2d():
|
||||
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
|
||||
('in_channel', [1, 3]), ('out_channel', [1, 3]),
|
||||
('kernel_size', [3, 5]), ('stride', [1, 2]),
|
||||
('padding', [0, 1]), ('dilation', [1, 2])])
|
||||
|
||||
for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
|
||||
*list(test_cases.values())):
|
||||
# wrapper op with 0-dim input
|
||||
x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
|
||||
wrapper = MaxPool2d(k, stride=s, padding=p, dilation=d)
|
||||
wrapper_out = wrapper(x_empty)
|
||||
|
||||
# torch op with 3-dim input as shape reference
|
||||
x_normal = torch.randn(3, in_cha, in_h, in_w)
|
||||
ref = nn.MaxPool2d(k, stride=s, padding=p, dilation=d)
|
||||
ref_out = ref(x_normal)
|
||||
|
||||
assert wrapper_out.shape[0] == 0
|
||||
assert wrapper_out.shape[1:] == ref_out.shape[1:]
|
||||
|
||||
assert torch.equal(wrapper(x_normal), ref_out)
|
||||
|
||||
|
||||
def test_linear():
|
||||
test_cases = OrderedDict([
|
||||
('in_w', [10, 20]),
|
||||
('in_h', [10, 20]),
|
||||
('in_feature', [1, 3]),
|
||||
('out_feature', [1, 3]),
|
||||
])
|
||||
|
||||
for in_h, in_w, in_feature, out_feature in product(
|
||||
*list(test_cases.values())):
|
||||
# wrapper op with 0-dim input
|
||||
x_empty = torch.randn(0, in_feature, requires_grad=True)
|
||||
torch.manual_seed(0)
|
||||
wrapper = Linear(in_feature, out_feature)
|
||||
wrapper_out = wrapper(x_empty)
|
||||
|
||||
# torch op with 3-dim input as shape reference
|
||||
x_normal = torch.randn(3, in_feature)
|
||||
torch.manual_seed(0)
|
||||
ref = nn.Linear(in_feature, out_feature)
|
||||
ref_out = ref(x_normal)
|
||||
|
||||
assert wrapper_out.shape[0] == 0
|
||||
assert wrapper_out.shape[1:] == ref_out.shape[1:]
|
||||
|
||||
wrapper_out.sum().backward()
|
||||
assert wrapper.weight.grad is not None
|
||||
assert wrapper.weight.grad.shape == wrapper.weight.shape
|
||||
|
||||
assert torch.equal(wrapper(x_normal), ref_out)
|
||||
|
||||
# eval mode
|
||||
x_empty = torch.randn(0, in_feature)
|
||||
wrapper = Linear(in_feature, out_feature)
|
||||
wrapper.eval()
|
||||
wrapper(x_empty)
|
||||
|
||||
|
||||
def test_nn_op_forward_called():
|
||||
torch.__version__ = '1.4.1'
|
||||
|
||||
for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']:
|
||||
with patch(f'torch.nn.{m}.forward') as nn_module_forward:
|
||||
# randn input
|
||||
x_empty = torch.randn(0, 3, 10, 10)
|
||||
wrapper = eval(m)(3, 2, 1)
|
||||
wrapper(x_empty)
|
||||
nn_module_forward.assert_called_with(x_empty)
|
||||
|
||||
# non-randn input
|
||||
x_normal = torch.randn(1, 3, 10, 10)
|
||||
wrapper = eval(m)(3, 2, 1)
|
||||
wrapper(x_normal)
|
||||
nn_module_forward.assert_called_with(x_normal)
|
||||
|
||||
with patch('torch.nn.Linear.forward') as nn_module_forward:
|
||||
# randn input
|
||||
x_empty = torch.randn(0, 3)
|
||||
wrapper = Linear(3, 3)
|
||||
wrapper(x_empty)
|
||||
nn_module_forward.assert_not_called()
|
||||
|
||||
# non-randn input
|
||||
x_normal = torch.randn(1, 3)
|
||||
wrapper = Linear(3, 3)
|
||||
wrapper(x_normal)
|
||||
nn_module_forward.assert_called_with(x_normal)
|
Loading…
Reference in New Issue