Add pairwise function for 'gaussian' and 'concatenation' mode in NonLocal. (#424)

* add pairwise function for 'gaussian' and 'concatenation' mode

* rename test function

* decrease the complexity of nonlocal unittest

* fix typo and make unittest more complete

* add unittest when zero_init is False

* minor fix

* pack theta and phi

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
pull/434/head
Jintao Lin 2020-07-17 23:44:50 +08:00 committed by GitHub
parent 49fdf3cfa0
commit 6ece0e5d19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 231 additions and 40 deletions

View File

@ -14,6 +14,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
This module is proposed in
"Non-local Neural Networks"
Paper reference: https://arxiv.org/abs/1711.07971
Code reference: https://github.com/AlexHex7/Non-local_pytorch
Args:
in_channels (int): Channels of the input feature map.
@ -26,8 +27,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
Default: None.
norm_cfg (None | dict): The config dict for normalization layers.
Default: None. (This parameter is only applicable to conv_out.)
mode (str): Options are `embedded_gaussian` and `dot_product`.
Default: embedded_gaussian.
mode (str): Options are `gaussian`, `concatenation`,
`embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
"""
def __init__(self,
@ -42,13 +43,15 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
self.in_channels = in_channels
self.reduction = reduction
self.use_scale = use_scale
self.inter_channels = in_channels // reduction
self.inter_channels = max(in_channels // reduction, 1)
self.mode = mode
if mode not in ['embedded_gaussian', 'dot_product']:
raise ValueError(
"Mode should be in 'embedded_gaussian' or 'dot_product', "
f'but got {mode} instead.')
if mode not in [
'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
]:
raise ValueError("Mode should be in 'gaussian', 'concatenation', "
f"'embedded_gaussian' or 'dot_product', but got "
f'{mode} instead.')
# g, theta, phi are defaulted as `nn.ConvNd`.
# Here we use ConvModule for potential usage.
@ -58,18 +61,6 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
kernel_size=1,
conv_cfg=conv_cfg,
act_cfg=None)
self.theta = ConvModule(
self.in_channels,
self.inter_channels,
kernel_size=1,
conv_cfg=conv_cfg,
act_cfg=None)
self.phi = ConvModule(
self.in_channels,
self.inter_channels,
kernel_size=1,
conv_cfg=conv_cfg,
act_cfg=None)
self.conv_out = ConvModule(
self.inter_channels,
self.in_channels,
@ -78,11 +69,38 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
norm_cfg=norm_cfg,
act_cfg=None)
if self.mode != 'gaussian':
self.theta = ConvModule(
self.in_channels,
self.inter_channels,
kernel_size=1,
conv_cfg=conv_cfg,
act_cfg=None)
self.phi = ConvModule(
self.in_channels,
self.inter_channels,
kernel_size=1,
conv_cfg=conv_cfg,
act_cfg=None)
if self.mode == 'concatenation':
self.concat_project = ConvModule(
self.inter_channels * 2,
1,
kernel_size=1,
stride=1,
padding=0,
bias=False,
act_cfg=dict(type='ReLU'))
self.init_weights(**kwargs)
def init_weights(self, std=0.01, zeros_init=True):
for m in [self.g, self.theta, self.phi]:
normal_init(m.conv, std=std)
if self.mode != 'gaussian':
for m in [self.g, self.theta, self.phi]:
normal_init(m.conv, std=std)
else:
normal_init(self.g.conv, std=std)
if zeros_init:
if self.conv_out.norm_cfg is None:
constant_init(self.conv_out.conv, 0)
@ -94,6 +112,14 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
else:
normal_init(self.conv_out.norm, std=std)
def gaussian(self, theta_x, phi_x):
# NonLocal1d pairwise_weight: [N, H, H]
# NonLocal2d pairwise_weight: [N, HxW, HxW]
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
pairwise_weight = torch.matmul(theta_x, phi_x)
pairwise_weight = pairwise_weight.softmax(dim=-1)
return pairwise_weight
def embedded_gaussian(self, theta_x, phi_x):
# NonLocal1d pairwise_weight: [N, H, H]
# NonLocal2d pairwise_weight: [N, HxW, HxW]
@ -113,8 +139,27 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
pairwise_weight /= pairwise_weight.shape[-1]
return pairwise_weight
def concatenation(self, theta_x, phi_x):
# NonLocal1d pairwise_weight: [N, H, H]
# NonLocal2d pairwise_weight: [N, HxW, HxW]
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
h = theta_x.size(2)
w = phi_x.size(3)
theta_x = theta_x.repeat(1, 1, 1, w)
phi_x = phi_x.repeat(1, 1, h, 1)
concat_feature = torch.cat([theta_x, phi_x], dim=1)
pairwise_weight = self.concat_project(concat_feature)
n, _, h, w = pairwise_weight.size()
pairwise_weight = pairwise_weight.view(n, h, w)
pairwise_weight /= pairwise_weight.shape[-1]
return pairwise_weight
def forward(self, x):
# Assume `reduction = 1`, then `inter_channels = C`
# or `inter_channels = C` when `mode="gaussian"`
# NonLocal1d x: [N, C, H]
# NonLocal2d x: [N, C, H, W]
# NonLocal3d x: [N, C, T, H, W]
@ -126,16 +171,23 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
g_x = self.g(x).view(n, self.inter_channels, -1)
g_x = g_x.permute(0, 2, 1)
# NonLocal1d theta_x: [N, H, C]
# NonLocal2d theta_x: [N, HxW, C]
# NonLocal3d theta_x: [N, TxHxW, C]
theta_x = self.theta(x).view(n, self.inter_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
# NonLocal1d phi_x: [N, C, H]
# NonLocal2d phi_x: [N, C, HxW]
# NonLocal3d phi_x: [N, C, TxHxW]
phi_x = self.phi(x).view(n, self.inter_channels, -1)
# NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
# NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
# NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
if self.mode == 'gaussian':
theta_x = x.view(n, self.in_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
if self.sub_sample:
phi_x = self.phi(x).view(n, self.in_channels, -1)
else:
phi_x = x.view(n, self.in_channels, -1)
elif self.mode == 'concatenation':
theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
else:
theta_x = self.theta(x).view(n, self.inter_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
phi_x = self.phi(x).view(n, self.inter_channels, -1)
pairwise_func = getattr(self, self.mode)
# NonLocal1d pairwise_weight: [N, H, H]
@ -183,7 +235,10 @@ class NonLocal1d(_NonLocalNd):
if sub_sample:
max_pool_layer = nn.MaxPool1d(kernel_size=2)
self.g = nn.Sequential(self.g, max_pool_layer)
self.phi = nn.Sequential(self.phi, max_pool_layer)
if self.mode != 'gaussian':
self.phi = nn.Sequential(self.phi, max_pool_layer)
else:
self.phi = max_pool_layer
@PLUGIN_LAYERS.register_module()
@ -214,7 +269,10 @@ class NonLocal2d(_NonLocalNd):
if sub_sample:
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
self.g = nn.Sequential(self.g, max_pool_layer)
self.phi = nn.Sequential(self.phi, max_pool_layer)
if self.mode != 'gaussian':
self.phi = nn.Sequential(self.phi, max_pool_layer)
else:
self.phi = max_pool_layer
class NonLocal3d(_NonLocalNd):
@ -241,4 +299,7 @@ class NonLocal3d(_NonLocalNd):
if sub_sample:
max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
self.g = nn.Sequential(self.g, max_pool_layer)
self.phi = nn.Sequential(self.phi, max_pool_layer)
if self.mode != 'gaussian':
self.phi = nn.Sequential(self.phi, max_pool_layer)
else:
self.phi = max_pool_layer

View File

@ -11,12 +11,17 @@ def test_nonlocal():
# mode should be in ['embedded_gaussian', 'dot_product']
_NonLocalNd(3, mode='unsupport_mode')
# _NonLocalNd
# _NonLocalNd with zero initialization
_NonLocalNd(3)
_NonLocalNd(3, norm_cfg=dict(type='BN'))
# Not Zero initialization
_NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=True)
# NonLocal3d
# _NonLocalNd without zero initialization
_NonLocalNd(3, zeros_init=False)
_NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=False)
def test_nonlocal3d():
# NonLocal3d with 'embedded_gaussian' mode
imgs = torch.randn(2, 3, 10, 20, 20)
nonlocal_3d = NonLocal3d(3)
if torch.__version__ == 'parrots':
@ -27,6 +32,7 @@ def test_nonlocal():
out = nonlocal_3d(imgs)
assert out.shape == imgs.shape
# NonLocal3d with 'dot_product' mode
nonlocal_3d = NonLocal3d(3, mode='dot_product')
assert nonlocal_3d.mode == 'dot_product'
if torch.__version__ == 'parrots':
@ -35,6 +41,38 @@ def test_nonlocal():
out = nonlocal_3d(imgs)
assert out.shape == imgs.shape
# NonLocal3d with 'concatenation' mode
nonlocal_3d = NonLocal3d(3, mode='concatenation')
assert nonlocal_3d.mode == 'concatenation'
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
nonlocal_3d.cuda()
out = nonlocal_3d(imgs)
assert out.shape == imgs.shape
# NonLocal3d with 'gaussian' mode
nonlocal_3d = NonLocal3d(3, mode='gaussian')
assert not hasattr(nonlocal_3d, 'phi')
assert nonlocal_3d.mode == 'gaussian'
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
nonlocal_3d.cuda()
out = nonlocal_3d(imgs)
assert out.shape == imgs.shape
# NonLocal3d with 'gaussian' mode and sub_sample
nonlocal_3d = NonLocal3d(3, mode='gaussian', sub_sample=True)
assert isinstance(nonlocal_3d.g, nn.Sequential) and len(nonlocal_3d.g) == 2
assert isinstance(nonlocal_3d.g[1], nn.MaxPool3d)
assert nonlocal_3d.g[1].kernel_size == (1, 2, 2)
assert isinstance(nonlocal_3d.phi, nn.MaxPool3d)
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
nonlocal_3d.cuda()
out = nonlocal_3d(imgs)
assert out.shape == imgs.shape
# NonLocal3d with 'dot_product' mode and sub_sample
nonlocal_3d = NonLocal3d(3, mode='dot_product', sub_sample=True)
for m in [nonlocal_3d.g, nonlocal_3d.phi]:
assert isinstance(m, nn.Sequential) and len(m) == 2
@ -46,7 +84,9 @@ def test_nonlocal():
out = nonlocal_3d(imgs)
assert out.shape == imgs.shape
# NonLocal2d
def test_nonlocal2d():
# NonLocal2d with 'embedded_gaussian' mode
imgs = torch.randn(2, 3, 20, 20)
nonlocal_2d = NonLocal2d(3)
if torch.__version__ == 'parrots':
@ -56,6 +96,50 @@ def test_nonlocal():
out = nonlocal_2d(imgs)
assert out.shape == imgs.shape
# NonLocal2d with 'dot_product' mode
imgs = torch.randn(2, 3, 20, 20)
nonlocal_2d = NonLocal2d(3, mode='dot_product')
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
imgs = imgs.cuda()
nonlocal_2d.cuda()
out = nonlocal_2d(imgs)
assert out.shape == imgs.shape
# NonLocal2d with 'concatenation' mode
imgs = torch.randn(2, 3, 20, 20)
nonlocal_2d = NonLocal2d(3, mode='concatenation')
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
imgs = imgs.cuda()
nonlocal_2d.cuda()
out = nonlocal_2d(imgs)
assert out.shape == imgs.shape
# NonLocal2d with 'gaussian' mode
imgs = torch.randn(2, 3, 20, 20)
nonlocal_2d = NonLocal2d(3, mode='gaussian')
assert not hasattr(nonlocal_2d, 'phi')
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
imgs = imgs.cuda()
nonlocal_2d.cuda()
out = nonlocal_2d(imgs)
assert out.shape == imgs.shape
# NonLocal2d with 'gaussian' mode and sub_sample
nonlocal_2d = NonLocal2d(3, mode='gaussian', sub_sample=True)
assert isinstance(nonlocal_2d.g, nn.Sequential) and len(nonlocal_2d.g) == 2
assert isinstance(nonlocal_2d.g[1], nn.MaxPool2d)
assert nonlocal_2d.g[1].kernel_size == (2, 2)
assert isinstance(nonlocal_2d.phi, nn.MaxPool2d)
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
nonlocal_2d.cuda()
out = nonlocal_2d(imgs)
assert out.shape == imgs.shape
# NonLocal2d with 'dot_product' mode and sub_sample
nonlocal_2d = NonLocal2d(3, mode='dot_product', sub_sample=True)
for m in [nonlocal_2d.g, nonlocal_2d.phi]:
assert isinstance(m, nn.Sequential) and len(m) == 2
@ -67,7 +151,9 @@ def test_nonlocal():
out = nonlocal_2d(imgs)
assert out.shape == imgs.shape
# NonLocal1d
def test_nonlocal1d():
# NonLocal1d with 'embedded_gaussian' mode
imgs = torch.randn(2, 3, 20)
nonlocal_1d = NonLocal1d(3)
if torch.__version__ == 'parrots':
@ -77,6 +163,50 @@ def test_nonlocal():
out = nonlocal_1d(imgs)
assert out.shape == imgs.shape
# NonLocal1d with 'dot_product' mode
imgs = torch.randn(2, 3, 20)
nonlocal_1d = NonLocal1d(3, mode='dot_product')
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
imgs = imgs.cuda()
nonlocal_1d.cuda()
out = nonlocal_1d(imgs)
assert out.shape == imgs.shape
# NonLocal1d with 'concatenation' mode
imgs = torch.randn(2, 3, 20)
nonlocal_1d = NonLocal1d(3, mode='concatenation')
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
imgs = imgs.cuda()
nonlocal_1d.cuda()
out = nonlocal_1d(imgs)
assert out.shape == imgs.shape
# NonLocal1d with 'gaussian' mode
imgs = torch.randn(2, 3, 20)
nonlocal_1d = NonLocal1d(3, mode='gaussian')
assert not hasattr(nonlocal_1d, 'phi')
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
imgs = imgs.cuda()
nonlocal_1d.cuda()
out = nonlocal_1d(imgs)
assert out.shape == imgs.shape
# NonLocal1d with 'gaussian' mode and sub_sample
nonlocal_1d = NonLocal1d(3, mode='gaussian', sub_sample=True)
assert isinstance(nonlocal_1d.g, nn.Sequential) and len(nonlocal_1d.g) == 2
assert isinstance(nonlocal_1d.g[1], nn.MaxPool1d)
assert nonlocal_1d.g[1].kernel_size == 2
assert isinstance(nonlocal_1d.phi, nn.MaxPool1d)
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
nonlocal_1d.cuda()
out = nonlocal_1d(imgs)
assert out.shape == imgs.shape
# NonLocal1d with 'dot_product' mode and sub_sample
nonlocal_1d = NonLocal1d(3, mode='dot_product', sub_sample=True)
for m in [nonlocal_1d.g, nonlocal_1d.phi]:
assert isinstance(m, nn.Sequential) and len(m) == 2