Add `_NonLocalNd` module (#331)

* add non_local module

* rewrite non local module comments

* polish docstring and adjust init function

* do not init norm layer

* Correct initialization when there is a norm layer

* set normal method for both embedded_gaussian and dot_product
pull/356/head
Jintao Lin 2020-06-18 20:56:48 +08:00 committed by GitHub
parent 6bb244f255
commit dcc20f3a4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 338 additions and 8 deletions

View File

@ -1,9 +1,9 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .alexnet import AlexNet
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, UPSAMPLE_LAYERS, ConvModule, Scale,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer,
PADDING_LAYERS, UPSAMPLE_LAYERS, ConvModule, NonLocal1d,
NonLocal2d, NonLocal3d, Scale, build_activation_layer,
build_conv_layer, build_norm_layer, build_padding_layer,
build_upsample_layer, is_norm)
from .resnet import ResNet, make_res_layer
from .vgg import VGG, make_vgg_layer
@ -16,7 +16,7 @@ __all__ = [
'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
'build_padding_layer', 'build_upsample_layer', 'is_norm',
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
'UPSAMPLE_LAYERS', 'Scale'
'build_padding_layer', 'build_upsample_layer', 'is_norm', 'NonLocal1d',
'NonLocal2d', 'NonLocal3d', 'ACTIVATION_LAYERS', 'CONV_LAYERS',
'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
]

View File

@ -1,6 +1,7 @@
from .activation import build_activation_layer
from .conv import build_conv_layer
from .conv_module import ConvModule
from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
from .norm import build_norm_layer, is_norm
from .padding import build_padding_layer
from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
@ -11,6 +12,6 @@ from .upsample import build_upsample_layer
__all__ = [
'ConvModule', 'build_activation_layer', 'build_conv_layer',
'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
'is_norm', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ACTIVATION_LAYERS',
'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
]

View File

@ -0,0 +1,240 @@
from abc import ABCMeta
import torch
import torch.nn as nn
from ..weight_init import constant_init, normal_init
from .conv_module import ConvModule
class _NonLocalNd(nn.Module, metaclass=ABCMeta):
    """Basic Non-local module.

    This module is proposed in
    "Non-local Neural Networks"
    Paper reference: https://arxiv.org/abs/1711.07971

    Args:
        in_channels (int): Channels of the input feature map.
        reduction (int): Channel reduction ratio. The embeddings use
            `max(in_channels // reduction, 1)` channels. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
            Default: True.
        conv_cfg (None | dict): The config dict for convolution layers.
            If not specified, it will use `nn.Conv2d` for convolution layers.
            Default: None.
        norm_cfg (None | dict): The config dict for normalization layers.
            Default: None. (This parameter is only applicable to conv_out.)
        mode (str): Options are `embedded_gaussian` and `dot_product`.
            Default: embedded_gaussian.
        **kwargs: Forwarded to `init_weights` (`std`, `zeros_init`).

    Raises:
        ValueError: If `mode` is not one of the supported options.
    """

    def __init__(self,
                 in_channels,
                 reduction=2,
                 use_scale=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 mode='embedded_gaussian',
                 **kwargs):
        super(_NonLocalNd, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
        # Keep at least one channel: `in_channels // reduction` is 0 whenever
        # `in_channels < reduction`, which would build empty embeddings.
        self.inter_channels = max(in_channels // reduction, 1)
        self.mode = mode

        if mode not in ['embedded_gaussian', 'dot_product']:
            raise ValueError(
                "Mode should be in 'embedded_gaussian' or 'dot_product', "
                f'but got {mode} instead.')

        # g, theta, phi are defaulted as `nn.ConvNd`.
        # Here we use ConvModule for potential usage.
        self.g = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)
        self.theta = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)
        self.phi = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)
        # Only the output projection receives the normalization config.
        self.conv_out = ConvModule(
            self.inter_channels,
            self.in_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        self.init_weights(**kwargs)

    @staticmethod
    def _embedding_module(module):
        """Return the embedding layer held by ``module``.

        Subclasses replace `g`/`phi` with `nn.Sequential(embedding, pool)`
        when sub-sampling is enabled; in that case the embedding is the first
        element of the sequence. Otherwise ``module`` is returned unchanged.
        """
        if isinstance(module, nn.Sequential):
            return module[0]
        return module

    def init_weights(self, std=0.01, zeros_init=True):
        """Initialize the embedding and output layers.

        Args:
            std (float): Standard deviation passed to `normal_init`.
                Default: 0.01.
            zeros_init (bool): If True, the last layer of `conv_out` (its norm
                layer when `norm_cfg` is set, otherwise its conv layer) is
                initialized with constant 0; if False it is initialized with
                `normal_init` using `std`. Default: True.
        """
        for m in [self.g, self.theta, self.phi]:
            # Unwrap first so this also works after `g`/`phi` were wrapped
            # in nn.Sequential by a sub-sampling subclass.
            normal_init(self._embedding_module(m).conv, std=std)

        # The layer that produces the final residual value is the norm layer
        # when one is configured, otherwise the conv itself.
        if self.conv_out.norm_cfg is None:
            last_layer = self.conv_out.conv
        else:
            last_layer = self.conv_out.norm

        if zeros_init:
            constant_init(last_layer, 0)
        else:
            normal_init(last_layer, std=std)

    def embedded_gaussian(self, theta_x, phi_x):
        """Softmax-normalized pairwise weights of `theta_x @ phi_x`.

        When `self.use_scale` is set, the product is divided by
        `sqrt(theta_x.shape[-1])` before the softmax over the last dim.
        """
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= theta_x.shape[-1]**0.5
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def dot_product(self, theta_x, phi_x):
        """Pairwise weights of `theta_x @ phi_x` divided by the size of the
        last dimension of the product."""
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight /= pairwise_weight.shape[-1]
        return pairwise_weight

    def forward(self, x):
        """Apply the non-local operation and add the result to `x`.

        Args:
            x (Tensor): Input whose first dim is the batch and whose second
                dim is the channel dim; trailing dims are flattened
                internally and restored for the output.

        Returns:
            Tensor: `x + conv_out(y)`, where `y` is the aggregated feature
            reshaped to `[N, inter_channels, *x.size()[2:]]`.
        """
        # Assume `reduction = 1`, then `inter_channels = C`
        # NonLocal1d x: [N, C, H]
        # NonLocal2d x: [N, C, H, W]
        # NonLocal3d x: [N, C, T, H, W]
        n = x.size(0)

        # NonLocal1d g_x: [N, H, C]
        # NonLocal2d g_x: [N, HxW, C]
        # NonLocal3d g_x: [N, TxHxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # NonLocal1d theta_x: [N, H, C]
        # NonLocal2d theta_x: [N, HxW, C]
        # NonLocal3d theta_x: [N, TxHxW, C]
        theta_x = self.theta(x).view(n, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        # NonLocal1d phi_x: [N, C, H]
        # NonLocal2d phi_x: [N, C, HxW]
        # NonLocal3d phi_x: [N, C, TxHxW]
        phi_x = self.phi(x).view(n, self.inter_channels, -1)

        # `self.mode` was validated in __init__, so it names one of the
        # pairwise methods defined on this class.
        pairwise_func = getattr(self, self.mode)
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # NonLocal1d y: [N, H, C]
        # NonLocal2d y: [N, HxW, C]
        # NonLocal3d y: [N, TxHxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # NonLocal1d y: [N, C, H]
        # NonLocal2d y: [N, C, H, W]
        # NonLocal3d y: [N, C, T, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        output = x + self.conv_out(y)

        return output
class NonLocal1d(_NonLocalNd):
    """Non-local block for 1D inputs.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, one shared `nn.MaxPool1d(kernel_size=2)`
            is appended after both the `g` and the `phi` embeddings.
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv1d').
        **kwargs: Passed through unchanged to the parent constructor.
    """

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv1d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        if not sub_sample:
            return

        # A single pooling layer instance is reused by both branches.
        pool = nn.MaxPool1d(kernel_size=2)
        for branch in ('g', 'phi'):
            embedding = getattr(self, branch)
            setattr(self, branch, nn.Sequential(embedding, pool))
class NonLocal2d(_NonLocalNd):
    """Non-local block for 2D inputs.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, one shared
            `nn.MaxPool2d(kernel_size=(2, 2))` is appended after both the `g`
            and the `phi` embeddings. Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv2d').
        **kwargs: Passed through unchanged to the parent constructor.
    """

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv2d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        if not sub_sample:
            return

        # A single pooling layer instance is reused by both branches.
        pool = nn.MaxPool2d(kernel_size=(2, 2))
        for branch in ('g', 'phi'):
            embedding = getattr(self, branch)
            setattr(self, branch, nn.Sequential(embedding, pool))
class NonLocal3d(_NonLocalNd):
    """Non-local block for 3D inputs.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, one shared
            `nn.MaxPool3d(kernel_size=(1, 2, 2))` is appended after both the
            `g` and the `phi` embeddings; the kernel size of 1 on the first
            pooled dim leaves that dim un-pooled. Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv3d').
        **kwargs: Passed through unchanged to the parent constructor.
    """

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv3d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        if not sub_sample:
            return

        # A single pooling layer instance is reused by both branches.
        pool = nn.MaxPool3d(kernel_size=(1, 2, 2))
        for branch in ('g', 'phi'):
            embedding = getattr(self, branch)
            setattr(self, branch, nn.Sequential(embedding, pool))

View File

@ -0,0 +1,89 @@
import pytest
import torch
import torch.nn as nn
from mmcv.cnn import NonLocal1d, NonLocal2d, NonLocal3d
from mmcv.cnn.bricks.non_local import _NonLocalNd
def test_nonlocal():
    """Smoke-test construction and forward shape of the non-local modules.

    Covers the invalid-mode error, both weight-initialization branches of
    `_NonLocalNd` (zero and non-zero, with and without a norm layer), and
    shape preservation of NonLocal1d/2d/3d in the default, `dot_product`
    and sub-sampled configurations.
    """
    with pytest.raises(ValueError):
        # mode should be in ['embedded_gaussian', 'dot_product']
        _NonLocalNd(3, mode='unsupport_mode')

    # _NonLocalNd with zero initialization (the default)
    _NonLocalNd(3, norm_cfg=dict(type='BN'))
    _NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=True)

    # Not Zero initialization
    _NonLocalNd(3, zeros_init=False)
    _NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=False)

    # NonLocal is only implemented on gpu in parrots, so move inputs and
    # modules to CUDA there (when a device is available).
    use_cuda = torch.__version__ == 'parrots' and torch.cuda.is_available()

    # NonLocal3d
    imgs = torch.randn(2, 3, 10, 20, 20)
    nonlocal_3d = NonLocal3d(3)
    if use_cuda:
        imgs = imgs.cuda()
        nonlocal_3d.cuda()
    out = nonlocal_3d(imgs)
    assert out.shape == imgs.shape

    nonlocal_3d = NonLocal3d(3, mode='dot_product')
    assert nonlocal_3d.mode == 'dot_product'
    if use_cuda:
        nonlocal_3d.cuda()
    out = nonlocal_3d(imgs)
    assert out.shape == imgs.shape

    nonlocal_3d = NonLocal3d(3, mode='dot_product', sub_sample=True)
    for m in [nonlocal_3d.g, nonlocal_3d.phi]:
        assert isinstance(m, nn.Sequential) and len(m) == 2
        assert isinstance(m[1], nn.MaxPool3d)
        assert m[1].kernel_size == (1, 2, 2)
    if use_cuda:
        nonlocal_3d.cuda()
    out = nonlocal_3d(imgs)
    assert out.shape == imgs.shape

    # NonLocal2d
    imgs = torch.randn(2, 3, 20, 20)
    nonlocal_2d = NonLocal2d(3)
    if use_cuda:
        imgs = imgs.cuda()
        nonlocal_2d.cuda()
    out = nonlocal_2d(imgs)
    assert out.shape == imgs.shape

    nonlocal_2d = NonLocal2d(3, mode='dot_product', sub_sample=True)
    for m in [nonlocal_2d.g, nonlocal_2d.phi]:
        assert isinstance(m, nn.Sequential) and len(m) == 2
        assert isinstance(m[1], nn.MaxPool2d)
        assert m[1].kernel_size == (2, 2)
    if use_cuda:
        nonlocal_2d.cuda()
    out = nonlocal_2d(imgs)
    assert out.shape == imgs.shape

    # NonLocal1d
    imgs = torch.randn(2, 3, 20)
    nonlocal_1d = NonLocal1d(3)
    if use_cuda:
        imgs = imgs.cuda()
        nonlocal_1d.cuda()
    out = nonlocal_1d(imgs)
    assert out.shape == imgs.shape

    nonlocal_1d = NonLocal1d(3, mode='dot_product', sub_sample=True)
    for m in [nonlocal_1d.g, nonlocal_1d.phi]:
        assert isinstance(m, nn.Sequential) and len(m) == 2
        assert isinstance(m[1], nn.MaxPool1d)
        assert m[1].kernel_size == 2
    if use_cuda:
        nonlocal_1d.cuda()
    out = nonlocal_1d(imgs)
    assert out.shape == imgs.shape