mirror of https://github.com/open-mmlab/mmcv.git
Add `_NonLocalNd` module (#331)
* add non_local module
* rewrite non local module comments
* perfect docstring and adjust init function
* not to init norm layer
* Correct initialize when there is a norm
* set normal method for both embedded_gaussian and dot_product

branch: pull/356/head
parent 6bb244f255
commit dcc20f3a4b
mmcv/cnn/__init__.py
@@ -1,9 +1,9 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .alexnet import AlexNet
 from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
-                     PADDING_LAYERS, UPSAMPLE_LAYERS, ConvModule, Scale,
-                     build_activation_layer, build_conv_layer,
-                     build_norm_layer, build_padding_layer,
+                     PADDING_LAYERS, UPSAMPLE_LAYERS, ConvModule, NonLocal1d,
+                     NonLocal2d, NonLocal3d, Scale, build_activation_layer,
+                     build_conv_layer, build_norm_layer, build_padding_layer,
                      build_upsample_layer, is_norm)
 from .resnet import ResNet, make_res_layer
 from .vgg import VGG, make_vgg_layer
@@ -16,7 +16,7 @@ __all__ = [
     'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
     'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
     'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
-    'build_padding_layer', 'build_upsample_layer', 'is_norm',
-    'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
-    'UPSAMPLE_LAYERS', 'Scale'
+    'build_padding_layer', 'build_upsample_layer', 'is_norm', 'NonLocal1d',
+    'NonLocal2d', 'NonLocal3d', 'ACTIVATION_LAYERS', 'CONV_LAYERS',
+    'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
 ]
mmcv/cnn/bricks/__init__.py
@@ -1,6 +1,7 @@
 from .activation import build_activation_layer
 from .conv import build_conv_layer
 from .conv_module import ConvModule
+from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
 from .norm import build_norm_layer, is_norm
 from .padding import build_padding_layer
 from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
@@ -11,6 +12,6 @@ from .upsample import build_upsample_layer
 __all__ = [
     'ConvModule', 'build_activation_layer', 'build_conv_layer',
     'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
-    'is_norm', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
-    'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
+    'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ACTIVATION_LAYERS',
+    'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
 ]
mmcv/cnn/bricks/non_local.py (new file)
@@ -0,0 +1,240 @@
from abc import ABCMeta

import torch
import torch.nn as nn

from ..weight_init import constant_init, normal_init
from .conv_module import ConvModule


class _NonLocalNd(nn.Module, metaclass=ABCMeta):
    """Basic Non-local module.

    This module is proposed in "Non-local Neural Networks".
    Paper reference: https://arxiv.org/abs/1711.07971

    Args:
        in_channels (int): Channels of the input feature map.
        reduction (int): Channel reduction ratio. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
            Default: True.
        conv_cfg (None | dict): The config dict for convolution layers.
            If not specified, it will use `nn.Conv2d` for convolution layers.
            Default: None.
        norm_cfg (None | dict): The config dict for normalization layers.
            (This parameter only applies to conv_out.) Default: None.
        mode (str): Options are `embedded_gaussian` and `dot_product`.
            Default: `embedded_gaussian`.
    """
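
    # For reference, Eq. (1) of the paper defines the non-local operation as
    #   y_i = 1/C(x) * sum_j f(x_i, x_j) * g(x_j),
    # where f is the pairwise function and C(x) the normalization factor:
    # a softmax over j for `embedded_gaussian` (which, combined with
    # use_scale, is exactly scaled dot-product attention), and 1/N for
    # `dot_product`, with N the number of positions.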

    def __init__(self,
                 in_channels,
                 reduction=2,
                 use_scale=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 mode='embedded_gaussian',
                 **kwargs):
        super(_NonLocalNd, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
        self.inter_channels = in_channels // reduction
        self.mode = mode

        if mode not in ['embedded_gaussian', 'dot_product']:
            raise ValueError(
                "Mode should be 'embedded_gaussian' or 'dot_product', "
                f'but got {mode} instead.')

        # g, theta, phi default to plain `nn.ConvNd` layers (1x1 conv, no
        # norm or activation); ConvModule is used so that the conv type
        # can be switched via conv_cfg.
        self.g = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)
        self.theta = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)
        self.phi = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)
        self.conv_out = ConvModule(
            self.inter_channels,
            self.in_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        self.init_weights(**kwargs)

    def init_weights(self, std=0.01, zeros_init=True):
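        # Zero-initializing the last conv (or its norm layer, when one is
        # present) makes the whole block an identity mapping at the start
        # of training, so it can be inserted into a pretrained network
        # without disturbing its initial behavior (cf. the implementation
        # notes of the paper).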
        for m in [self.g, self.theta, self.phi]:
            normal_init(m.conv, std=std)
        if zeros_init:
            if self.conv_out.norm_cfg is None:
                constant_init(self.conv_out.conv, 0)
            else:
                constant_init(self.conv_out.norm, 0)
        else:
            if self.conv_out.norm_cfg is None:
                normal_init(self.conv_out.conv, std=std)
            else:
                normal_init(self.conv_out.norm, std=std)

    def embedded_gaussian(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= theta_x.shape[-1]**0.5
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def dot_product(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight /= pairwise_weight.shape[-1]
        return pairwise_weight

    def forward(self, x):
        # Assume `reduction = 1`, then `inter_channels = C`
        # NonLocal1d x: [N, C, H]
        # NonLocal2d x: [N, C, H, W]
        # NonLocal3d x: [N, C, T, H, W]
        n = x.size(0)

        # NonLocal1d g_x: [N, H, C]
        # NonLocal2d g_x: [N, HxW, C]
        # NonLocal3d g_x: [N, TxHxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # NonLocal1d theta_x: [N, H, C]
        # NonLocal2d theta_x: [N, HxW, C]
        # NonLocal3d theta_x: [N, TxHxW, C]
        theta_x = self.theta(x).view(n, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        # NonLocal1d phi_x: [N, C, H]
        # NonLocal2d phi_x: [N, C, HxW]
        # NonLocal3d phi_x: [N, C, TxHxW]
        phi_x = self.phi(x).view(n, self.inter_channels, -1)

        pairwise_func = getattr(self, self.mode)
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # NonLocal1d y: [N, H, C]
        # NonLocal2d y: [N, HxW, C]
        # NonLocal3d y: [N, TxHxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # NonLocal1d y: [N, C, H]
        # NonLocal2d y: [N, C, H, W]
        # NonLocal3d y: [N, C, T, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        output = x + self.conv_out(y)

        return output


class NonLocal1d(_NonLocalNd):
    """1D Non-local module.

    Args:
        in_channels (int): Same as `_NonLocalNd`.
        sub_sample (bool): Whether to append a max pooling layer after
            `g` and `phi`, sub-sampling their outputs to reduce
            computation. Default: False.
        conv_cfg (None | dict): Same as `_NonLocalNd`.
            Default: dict(type='Conv1d').
    """

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv1d'),
                 **kwargs):
        super(NonLocal1d, self).__init__(
            in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample

        if sub_sample:
            max_pool_layer = nn.MaxPool1d(kernel_size=2)
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)


class NonLocal2d(_NonLocalNd):
    """2D Non-local module.

    Args:
        in_channels (int): Same as `_NonLocalNd`.
        sub_sample (bool): Whether to append a max pooling layer after
            `g` and `phi`, sub-sampling their outputs to reduce
            computation. Default: False.
        conv_cfg (None | dict): Same as `_NonLocalNd`.
            Default: dict(type='Conv2d').
    """

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv2d'),
                 **kwargs):
        super(NonLocal2d, self).__init__(
            in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample

        if sub_sample:
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)


class NonLocal3d(_NonLocalNd):
    """3D Non-local module.

    Args:
        in_channels (int): Same as `_NonLocalNd`.
        sub_sample (bool): Whether to append a max pooling layer after
            `g` and `phi`, sub-sampling their outputs to reduce computation
            (note that the sub-sampling is applied to the spatial
            dimensions only). Default: False.
        conv_cfg (None | dict): Same as `_NonLocalNd`.
            Default: dict(type='Conv3d').
    """

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv3d'),
                 **kwargs):
        super(NonLocal3d, self).__init__(
            in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample

        if sub_sample:
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
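            # kernel_size=(1, 2, 2) pools H and W but leaves the temporal
            # dimension T untouched, i.e. the sub-sampling is spatial only.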
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)
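As a quick sanity check (not part of the commit), the new bricks act as drop-in residual blocks, so the output always has the same shape as the input:

    import torch
    from mmcv.cnn import NonLocal2d

    block = NonLocal2d(16)            # 16 input channels, reduction=2
    x = torch.randn(2, 16, 20, 20)    # [N, C, H, W]
    out = block(x)                    # x + conv_out(attention(x))
    assert out.shape == x.shape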
tests/test_cnn/test_non_local.py (new file)
@@ -0,0 +1,89 @@
import pytest
import torch
import torch.nn as nn

from mmcv.cnn import NonLocal1d, NonLocal2d, NonLocal3d
from mmcv.cnn.bricks.non_local import _NonLocalNd


def test_nonlocal():
    with pytest.raises(ValueError):
        # mode should be in ['embedded_gaussian', 'dot_product']
        _NonLocalNd(3, mode='unsupport_mode')

    # _NonLocalNd with a norm layer on conv_out (zero-initialized by default)
    _NonLocalNd(3, norm_cfg=dict(type='BN'))
    # non-zero initialization
    _NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=False)

    # NonLocal3d
    imgs = torch.randn(2, 3, 10, 20, 20)
    nonlocal_3d = NonLocal3d(3)
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            # NonLocal is only implemented on gpu in parrots
            imgs = imgs.cuda()
            nonlocal_3d.cuda()
    out = nonlocal_3d(imgs)
    assert out.shape == imgs.shape

    nonlocal_3d = NonLocal3d(3, mode='dot_product')
    assert nonlocal_3d.mode == 'dot_product'
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            nonlocal_3d.cuda()
    out = nonlocal_3d(imgs)
    assert out.shape == imgs.shape

    nonlocal_3d = NonLocal3d(3, mode='dot_product', sub_sample=True)
    for m in [nonlocal_3d.g, nonlocal_3d.phi]:
        assert isinstance(m, nn.Sequential) and len(m) == 2
        assert isinstance(m[1], nn.MaxPool3d)
        assert m[1].kernel_size == (1, 2, 2)
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            nonlocal_3d.cuda()
    out = nonlocal_3d(imgs)
    assert out.shape == imgs.shape

    # NonLocal2d
    imgs = torch.randn(2, 3, 20, 20)
    nonlocal_2d = NonLocal2d(3)
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            nonlocal_2d.cuda()
    out = nonlocal_2d(imgs)
    assert out.shape == imgs.shape

    nonlocal_2d = NonLocal2d(3, mode='dot_product', sub_sample=True)
    for m in [nonlocal_2d.g, nonlocal_2d.phi]:
        assert isinstance(m, nn.Sequential) and len(m) == 2
        assert isinstance(m[1], nn.MaxPool2d)
        assert m[1].kernel_size == (2, 2)
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            nonlocal_2d.cuda()
    out = nonlocal_2d(imgs)
    assert out.shape == imgs.shape

    # NonLocal1d
    imgs = torch.randn(2, 3, 20)
    nonlocal_1d = NonLocal1d(3)
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            nonlocal_1d.cuda()
    out = nonlocal_1d(imgs)
    assert out.shape == imgs.shape

    nonlocal_1d = NonLocal1d(3, mode='dot_product', sub_sample=True)
    for m in [nonlocal_1d.g, nonlocal_1d.phi]:
        assert isinstance(m, nn.Sequential) and len(m) == 2
        assert isinstance(m[1], nn.MaxPool1d)
        assert m[1].kernel_size == 2
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            nonlocal_1d.cuda()
    out = nonlocal_1d(imgs)
    assert out.shape == imgs.shape
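Assuming the test path above, the new module can be exercised locally with `pytest tests/test_cnn/test_non_local.py`.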