From dcc20f3a4bccb1dda12649fc1faaaa61cca0a8a2 Mon Sep 17 00:00:00 2001
From: Jintao Lin
Date: Thu, 18 Jun 2020 20:56:48 +0800
Subject: [PATCH] Add `_NonLocalNd` module (#331)

* add non_local module

* rewrite non local module comments

* perfect docstring and adjust init function

* not to init norm layer

* Correct initialize when there is a norm

* set normal method for both embedded_gaussian and dot_product
---
 mmcv/cnn/__init__.py             |  12 +-
 mmcv/cnn/bricks/__init__.py      |   5 +-
 mmcv/cnn/bricks/non_local.py     | 240 +++++++++++++++++++++++++++++++
 tests/test_cnn/test_non_local.py |  89 ++++++++++++
 4 files changed, 338 insertions(+), 8 deletions(-)
 create mode 100644 mmcv/cnn/bricks/non_local.py
 create mode 100644 tests/test_cnn/test_non_local.py

diff --git a/mmcv/cnn/__init__.py b/mmcv/cnn/__init__.py
index 71b1097ac..5ad1b4f84 100644
--- a/mmcv/cnn/__init__.py
+++ b/mmcv/cnn/__init__.py
@@ -1,9 +1,9 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .alexnet import AlexNet
 from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
-                     PADDING_LAYERS, UPSAMPLE_LAYERS, ConvModule, Scale,
-                     build_activation_layer, build_conv_layer,
-                     build_norm_layer, build_padding_layer,
+                     PADDING_LAYERS, UPSAMPLE_LAYERS, ConvModule, NonLocal1d,
+                     NonLocal2d, NonLocal3d, Scale, build_activation_layer,
+                     build_conv_layer, build_norm_layer, build_padding_layer,
                      build_upsample_layer, is_norm)
 from .resnet import ResNet, make_res_layer
 from .vgg import VGG, make_vgg_layer
@@ -16,7 +16,7 @@ __all__ = [
     'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
     'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
     'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
-    'build_padding_layer', 'build_upsample_layer', 'is_norm',
-    'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
-    'UPSAMPLE_LAYERS', 'Scale'
+    'build_padding_layer', 'build_upsample_layer', 'is_norm', 'NonLocal1d',
+    'NonLocal2d', 'NonLocal3d', 'ACTIVATION_LAYERS', 'CONV_LAYERS',
+    'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
 ]
diff --git a/mmcv/cnn/bricks/__init__.py b/mmcv/cnn/bricks/__init__.py
index 4155af1f9..9bbb95057 100644
--- a/mmcv/cnn/bricks/__init__.py
+++ b/mmcv/cnn/bricks/__init__.py
@@ -1,6 +1,7 @@
 from .activation import build_activation_layer
 from .conv import build_conv_layer
 from .conv_module import ConvModule
+from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
 from .norm import build_norm_layer, is_norm
 from .padding import build_padding_layer
 from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
@@ -11,6 +12,6 @@ from .upsample import build_upsample_layer
 __all__ = [
     'ConvModule', 'build_activation_layer', 'build_conv_layer',
     'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
-    'is_norm', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
-    'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
+    'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ACTIVATION_LAYERS',
+    'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
 ]
diff --git a/mmcv/cnn/bricks/non_local.py b/mmcv/cnn/bricks/non_local.py
new file mode 100644
index 000000000..474d98b05
--- /dev/null
+++ b/mmcv/cnn/bricks/non_local.py
@@ -0,0 +1,240 @@
+from abc import ABCMeta
+
+import torch
+import torch.nn as nn
+
+from ..weight_init import constant_init, normal_init
+from .conv_module import ConvModule
+
+
+class _NonLocalNd(nn.Module, metaclass=ABCMeta):
+    """Basic Non-local module.
+
+    This module is proposed in the paper
+    "Non-local Neural Networks".
+    Paper reference: https://arxiv.org/abs/1711.07971
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+        reduction (int): Channel reduction ratio. Default: 2.
+        use_scale (bool): Whether to scale pairwise_weight by
+            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
+            Default: True.
+        conv_cfg (None | dict): The config dict for convolution layers.
+            If not specified, `nn.Conv2d` will be used for all conv layers.
+            Default: None.
+        norm_cfg (None | dict): The config dict for normalization layers.
+            (Only applicable to conv_out.) Default: None.
+        mode (str): Options are `embedded_gaussian` and `dot_product`.
+            Default: embedded_gaussian.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 reduction=2,
+                 use_scale=True,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 mode='embedded_gaussian',
+                 **kwargs):
+        super(_NonLocalNd, self).__init__()
+        self.in_channels = in_channels
+        self.reduction = reduction
+        self.use_scale = use_scale
+        self.inter_channels = in_channels // reduction
+        self.mode = mode
+
+        if mode not in ['embedded_gaussian', 'dot_product']:
+            raise ValueError(
+                "Mode should be in 'embedded_gaussian' or 'dot_product', "
+                f'but got {mode} instead.')
+
+        # g, theta and phi default to plain `nn.ConvNd` 1x1 convolutions;
+        # ConvModule is used here so that a custom conv_cfg can be passed in.
+        self.g = ConvModule(
+            self.in_channels,
+            self.inter_channels,
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            act_cfg=None)
+        self.theta = ConvModule(
+            self.in_channels,
+            self.inter_channels,
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            act_cfg=None)
+        self.phi = ConvModule(
+            self.in_channels,
+            self.inter_channels,
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            act_cfg=None)
+        self.conv_out = ConvModule(
+            self.inter_channels,
+            self.in_channels,
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=None)
+
+        self.init_weights(**kwargs)
+
+    def init_weights(self, std=0.01, zeros_init=True):
+        for m in [self.g, self.theta, self.phi]:
+            normal_init(m.conv, std=std)
+        if zeros_init:
+            if self.conv_out.norm_cfg is None:
+                constant_init(self.conv_out.conv, 0)
+            else:
+                constant_init(self.conv_out.norm, 0)
+        else:
+            if self.conv_out.norm_cfg is None:
+                normal_init(self.conv_out.conv, std=std)
+            else:
+                normal_init(self.conv_out.norm, std=std)
+
+    def embedded_gaussian(self, theta_x, phi_x):
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        if self.use_scale:
+            # theta_x.shape[-1] is `self.inter_channels`
+            pairwise_weight /= theta_x.shape[-1]**0.5
+        pairwise_weight = pairwise_weight.softmax(dim=-1)
+        return pairwise_weight
+
+    def dot_product(self, theta_x, phi_x):
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        pairwise_weight /= pairwise_weight.shape[-1]
+        return pairwise_weight
+
+    def forward(self, x):
+        # Assume `reduction = 1`, then `inter_channels = C`
+        # NonLocal1d x: [N, C, H]
+        # NonLocal2d x: [N, C, H, W]
+        # NonLocal3d x: [N, C, T, H, W]
+        n = x.size(0)
+
+        # NonLocal1d g_x: [N, H, C]
+        # NonLocal2d g_x: [N, HxW, C]
+        # NonLocal3d g_x: [N, TxHxW, C]
+        g_x = self.g(x).view(n, self.inter_channels, -1)
+        g_x = g_x.permute(0, 2, 1)
+
+        # NonLocal1d theta_x: [N, H, C]
+        # NonLocal2d theta_x: [N, HxW, C]
+        # NonLocal3d theta_x: [N, TxHxW, C]
+        theta_x = self.theta(x).view(n, self.inter_channels, -1)
+        theta_x = theta_x.permute(0, 2, 1)
+
+        # NonLocal1d phi_x: [N, C, H]
+        # NonLocal2d phi_x: [N, C, HxW]
+        # NonLocal3d phi_x: [N, C, TxHxW]
+        phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+        pairwise_func = getattr(self, self.mode)
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        pairwise_weight = pairwise_func(theta_x, phi_x)
+
+        # NonLocal1d y: [N, H, C]
+        # NonLocal2d y: [N, HxW, C]
+        # NonLocal3d y: [N, TxHxW, C]
+        y = torch.matmul(pairwise_weight, g_x)
+        # NonLocal1d y: [N, C, H]
+        # NonLocal2d y: [N, C, H, W]
+        # NonLocal3d y: [N, C, T, H, W]
+        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+                                                    *x.size()[2:])
+
+        output = x + self.conv_out(y)
+
+        return output
+
+
+class NonLocal1d(_NonLocalNd):
+    """1D Non-local module.
+
+    Args:
+        in_channels (int): Same as `_NonLocalNd`.
+        sub_sample (bool): Whether to subsample the `g` and `phi` features
+            with max pooling (pooling is applied on spatial dims only).
+            Default: False.
+        conv_cfg (None | dict): Same as `_NonLocalNd`.
+            Default: dict(type='Conv1d').
+    """
+
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv1d'),
+                 **kwargs):
+        super(NonLocal1d, self).__init__(
+            in_channels, conv_cfg=conv_cfg, **kwargs)
+
+        self.sub_sample = sub_sample
+
+        if sub_sample:
+            max_pool_layer = nn.MaxPool1d(kernel_size=2)
+            self.g = nn.Sequential(self.g, max_pool_layer)
+            self.phi = nn.Sequential(self.phi, max_pool_layer)
+
+
+class NonLocal2d(_NonLocalNd):
+    """2D Non-local module.
+
+    Args:
+        in_channels (int): Same as `_NonLocalNd`.
+        sub_sample (bool): Whether to subsample the `g` and `phi` features
+            with max pooling (pooling is applied on spatial dims only).
+            Default: False.
+        conv_cfg (None | dict): Same as `_NonLocalNd`.
+            Default: dict(type='Conv2d').
+    """
+
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv2d'),
+                 **kwargs):
+        super(NonLocal2d, self).__init__(
+            in_channels, conv_cfg=conv_cfg, **kwargs)
+
+        self.sub_sample = sub_sample
+
+        if sub_sample:
+            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
+            self.g = nn.Sequential(self.g, max_pool_layer)
+            self.phi = nn.Sequential(self.phi, max_pool_layer)
+
+
+class NonLocal3d(_NonLocalNd):
+    """3D Non-local module.
+
+    Args:
+        in_channels (int): Same as `_NonLocalNd`.
+        sub_sample (bool): Whether to subsample the `g` and `phi` features
+            with max pooling (pooling is applied on spatial dims only).
+            Default: False.
+        conv_cfg (None | dict): Same as `_NonLocalNd`.
+            Default: dict(type='Conv3d').
+    """
+
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv3d'),
+                 **kwargs):
+        super(NonLocal3d, self).__init__(
+            in_channels, conv_cfg=conv_cfg, **kwargs)
+        self.sub_sample = sub_sample
+
+        if sub_sample:
+            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
+            self.g = nn.Sequential(self.g, max_pool_layer)
+            self.phi = nn.Sequential(self.phi, max_pool_layer)
diff --git a/tests/test_cnn/test_non_local.py b/tests/test_cnn/test_non_local.py
new file mode 100644
index 000000000..207b7b1fe
--- /dev/null
+++ b/tests/test_cnn/test_non_local.py
@@ -0,0 +1,89 @@
+import pytest
+import torch
+import torch.nn as nn
+
+from mmcv.cnn import NonLocal1d, NonLocal2d, NonLocal3d
+from mmcv.cnn.bricks.non_local import _NonLocalNd
+
+
+def test_nonlocal():
+    with pytest.raises(ValueError):
+        # mode should be in ['embedded_gaussian', 'dot_product']
+        _NonLocalNd(3, mode='unsupport_mode')
+
+    # _NonLocalNd with zero initialization (the default)
+    _NonLocalNd(3, norm_cfg=dict(type='BN'))
+    # _NonLocalNd without zero initialization
+    _NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=False)
+
+    # NonLocal3d
+    imgs = torch.randn(2, 3, 10, 20, 20)
+    nonlocal_3d = NonLocal3d(3)
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            # NonLocal is only implemented on gpu in parrots
+            imgs = imgs.cuda()
+            nonlocal_3d.cuda()
+    out = nonlocal_3d(imgs)
+    assert out.shape == imgs.shape
+
+    nonlocal_3d = NonLocal3d(3, mode='dot_product')
+    assert nonlocal_3d.mode == 'dot_product'
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            nonlocal_3d.cuda()
+    out = nonlocal_3d(imgs)
+    assert out.shape == imgs.shape
+
+    nonlocal_3d = NonLocal3d(3, mode='dot_product', sub_sample=True)
+    for m in [nonlocal_3d.g, nonlocal_3d.phi]:
+        assert isinstance(m, nn.Sequential) and len(m) == 2
+        assert isinstance(m[1], nn.MaxPool3d)
+        assert m[1].kernel_size == (1, 2, 2)
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            nonlocal_3d.cuda()
+    out = nonlocal_3d(imgs)
+    assert out.shape == imgs.shape
+
+    # NonLocal2d
+    imgs = torch.randn(2, 3, 20, 20)
+    nonlocal_2d = NonLocal2d(3)
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            imgs = imgs.cuda()
+            nonlocal_2d.cuda()
+    out = nonlocal_2d(imgs)
+    assert out.shape == imgs.shape
+
+    nonlocal_2d = NonLocal2d(3, mode='dot_product', sub_sample=True)
+    for m in [nonlocal_2d.g, nonlocal_2d.phi]:
+        assert isinstance(m, nn.Sequential) and len(m) == 2
+        assert isinstance(m[1], nn.MaxPool2d)
+        assert m[1].kernel_size == (2, 2)
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            nonlocal_2d.cuda()
+    out = nonlocal_2d(imgs)
+    assert out.shape == imgs.shape
+
+    # NonLocal1d
+    imgs = torch.randn(2, 3, 20)
+    nonlocal_1d = NonLocal1d(3)
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            imgs = imgs.cuda()
+            nonlocal_1d.cuda()
+    out = nonlocal_1d(imgs)
+    assert out.shape == imgs.shape
+
+    nonlocal_1d = NonLocal1d(3, mode='dot_product', sub_sample=True)
+    for m in [nonlocal_1d.g, nonlocal_1d.phi]:
+        assert isinstance(m, nn.Sequential) and len(m) == 2
+        assert isinstance(m[1], nn.MaxPool1d)
+        assert m[1].kernel_size == 2
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            nonlocal_1d.cuda()
+    out = nonlocal_1d(imgs)
+    assert out.shape == imgs.shape
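
A minimal usage sketch of the new bricks (assuming only what this patch adds;
the channel count, input sizes, and the explicit `reduction` / `mode`
arguments below are arbitrary or simply restate the defaults):

    import torch
    from mmcv.cnn import NonLocal2d

    x = torch.randn(2, 16, 20, 20)  # [N, C, H, W]
    block = NonLocal2d(16, reduction=2, mode='embedded_gaussian')
    out = block(x)  # out = x + conv_out(pairwise_weight @ g(x))
    assert out.shape == x.shape  # the residual design preserves shape

    # With sub_sample=True, g and phi are max-pooled (spatial dims only),
    # shrinking pairwise_weight from [N, HxW, HxW] to [N, HxW, HxW/4].
    block_ss = NonLocal2d(16, sub_sample=True)
    assert block_ss(x).shape == x.shape

Because the block is residual and shape-preserving, it can be dropped into an
existing backbone without changing any downstream feature-map sizes.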