mirror of https://github.com/open-mmlab/mmcv.git
Add pairwise function for 'gaussian' and 'concatenation' mode in NonLocal. (#424)
* add pairwise function for 'gaussian' and 'concatenation' mode
* rename test function
* decrease the complexity of nonlocal unittest
* fix typo and make unittest more complete
* add unittest when zero_init is False
* minor fix
* pack theta and phi

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
pull/434/head
parent 49fdf3cfa0
commit 6ece0e5d19
|
@ -14,6 +14,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||||
This module is proposed in
|
This module is proposed in
|
||||||
"Non-local Neural Networks"
|
"Non-local Neural Networks"
|
||||||
Paper reference: https://arxiv.org/abs/1711.07971
|
Paper reference: https://arxiv.org/abs/1711.07971
|
||||||
|
Code reference: https://github.com/AlexHex7/Non-local_pytorch
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_channels (int): Channels of the input feature map.
|
in_channels (int): Channels of the input feature map.
|
||||||
|
@ -26,8 +27,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||||
Default: None.
|
Default: None.
|
||||||
norm_cfg (None | dict): The config dict for normalization layers.
|
norm_cfg (None | dict): The config dict for normalization layers.
|
||||||
Default: None. (This parameter is only applicable to conv_out.)
|
Default: None. (This parameter is only applicable to conv_out.)
|
||||||
mode (str): Options are `embedded_gaussian` and `dot_product`.
|
mode (str): Options are `gaussian`, `concatenation`,
|
||||||
Default: embedded_gaussian.
|
`embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -42,13 +43,15 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
self.use_scale = use_scale
|
self.use_scale = use_scale
|
||||||
self.inter_channels = in_channels // reduction
|
self.inter_channels = max(in_channels // reduction, 1)
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
if mode not in ['embedded_gaussian', 'dot_product']:
|
if mode not in [
|
||||||
raise ValueError(
|
'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
|
||||||
"Mode should be in 'embedded_gaussian' or 'dot_product', "
|
]:
|
||||||
f'but got {mode} instead.')
|
raise ValueError("Mode should be in 'gaussian', 'concatenation', "
|
||||||
|
f"'embedded_gaussian' or 'dot_product', but got "
|
||||||
|
f'{mode} instead.')
|
||||||
|
|
||||||
# g, theta, phi are defaulted as `nn.ConvNd`.
|
# g, theta, phi are defaulted as `nn.ConvNd`.
|
||||||
# Here we use ConvModule for potential usage.
|
# Here we use ConvModule for potential usage.
|
||||||
|
@ -58,18 +61,6 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
conv_cfg=conv_cfg,
|
conv_cfg=conv_cfg,
|
||||||
act_cfg=None)
|
act_cfg=None)
|
||||||
self.theta = ConvModule(
|
|
||||||
self.in_channels,
|
|
||||||
self.inter_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
conv_cfg=conv_cfg,
|
|
||||||
act_cfg=None)
|
|
||||||
self.phi = ConvModule(
|
|
||||||
self.in_channels,
|
|
||||||
self.inter_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
conv_cfg=conv_cfg,
|
|
||||||
act_cfg=None)
|
|
||||||
self.conv_out = ConvModule(
|
self.conv_out = ConvModule(
|
||||||
self.inter_channels,
|
self.inter_channels,
|
||||||
self.in_channels,
|
self.in_channels,
|
||||||
|
@ -78,11 +69,38 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
act_cfg=None)
|
act_cfg=None)
|
||||||
|
|
||||||
|
if self.mode != 'gaussian':
|
||||||
|
self.theta = ConvModule(
|
||||||
|
self.in_channels,
|
||||||
|
self.inter_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
conv_cfg=conv_cfg,
|
||||||
|
act_cfg=None)
|
||||||
|
self.phi = ConvModule(
|
||||||
|
self.in_channels,
|
||||||
|
self.inter_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
conv_cfg=conv_cfg,
|
||||||
|
act_cfg=None)
|
||||||
|
|
||||||
|
if self.mode == 'concatenation':
|
||||||
|
self.concat_project = ConvModule(
|
||||||
|
self.inter_channels * 2,
|
||||||
|
1,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias=False,
|
||||||
|
act_cfg=dict(type='ReLU'))
|
||||||
|
|
||||||
self.init_weights(**kwargs)
|
self.init_weights(**kwargs)
|
||||||
|
|
||||||
def init_weights(self, std=0.01, zeros_init=True):
|
def init_weights(self, std=0.01, zeros_init=True):
|
||||||
for m in [self.g, self.theta, self.phi]:
|
if self.mode != 'gaussian':
|
||||||
normal_init(m.conv, std=std)
|
for m in [self.g, self.theta, self.phi]:
|
||||||
|
normal_init(m.conv, std=std)
|
||||||
|
else:
|
||||||
|
normal_init(self.g.conv, std=std)
|
||||||
if zeros_init:
|
if zeros_init:
|
||||||
if self.conv_out.norm_cfg is None:
|
if self.conv_out.norm_cfg is None:
|
||||||
constant_init(self.conv_out.conv, 0)
|
constant_init(self.conv_out.conv, 0)
|
||||||
|
@ -94,6 +112,14 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||||
else:
|
else:
|
||||||
normal_init(self.conv_out.norm, std=std)
|
normal_init(self.conv_out.norm, std=std)
|
||||||
|
|
||||||
|
def gaussian(self, theta_x, phi_x):
|
||||||
|
# NonLocal1d pairwise_weight: [N, H, H]
|
||||||
|
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
||||||
|
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
|
||||||
|
pairwise_weight = torch.matmul(theta_x, phi_x)
|
||||||
|
pairwise_weight = pairwise_weight.softmax(dim=-1)
|
||||||
|
return pairwise_weight
|
||||||
|
|
||||||
def embedded_gaussian(self, theta_x, phi_x):
|
def embedded_gaussian(self, theta_x, phi_x):
|
||||||
# NonLocal1d pairwise_weight: [N, H, H]
|
# NonLocal1d pairwise_weight: [N, H, H]
|
||||||
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
||||||
|
@ -113,8 +139,27 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||||
pairwise_weight /= pairwise_weight.shape[-1]
|
pairwise_weight /= pairwise_weight.shape[-1]
|
||||||
return pairwise_weight
|
return pairwise_weight
|
||||||
|
|
||||||
|
def concatenation(self, theta_x, phi_x):
|
||||||
|
# NonLocal1d pairwise_weight: [N, H, H]
|
||||||
|
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
||||||
|
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
|
||||||
|
h = theta_x.size(2)
|
||||||
|
w = phi_x.size(3)
|
||||||
|
theta_x = theta_x.repeat(1, 1, 1, w)
|
||||||
|
phi_x = phi_x.repeat(1, 1, h, 1)
|
||||||
|
|
||||||
|
concat_feature = torch.cat([theta_x, phi_x], dim=1)
|
||||||
|
pairwise_weight = self.concat_project(concat_feature)
|
||||||
|
n, _, h, w = pairwise_weight.size()
|
||||||
|
pairwise_weight = pairwise_weight.view(n, h, w)
|
||||||
|
pairwise_weight /= pairwise_weight.shape[-1]
|
||||||
|
|
||||||
|
return pairwise_weight
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Assume `reduction = 1`, then `inter_channels = C`
|
# Assume `reduction = 1`, then `inter_channels = C`
|
||||||
|
# or `inter_channels = C` when `mode="gaussian"`
|
||||||
|
|
||||||
# NonLocal1d x: [N, C, H]
|
# NonLocal1d x: [N, C, H]
|
||||||
# NonLocal2d x: [N, C, H, W]
|
# NonLocal2d x: [N, C, H, W]
|
||||||
# NonLocal3d x: [N, C, T, H, W]
|
# NonLocal3d x: [N, C, T, H, W]
|
||||||
|
@ -126,16 +171,23 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||||
g_x = self.g(x).view(n, self.inter_channels, -1)
|
g_x = self.g(x).view(n, self.inter_channels, -1)
|
||||||
g_x = g_x.permute(0, 2, 1)
|
g_x = g_x.permute(0, 2, 1)
|
||||||
|
|
||||||
# NonLocal1d theta_x: [N, H, C]
|
# NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
|
||||||
# NonLocal2d theta_x: [N, HxW, C]
|
# NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
|
||||||
# NonLocal3d theta_x: [N, TxHxW, C]
|
# NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
|
||||||
theta_x = self.theta(x).view(n, self.inter_channels, -1)
|
if self.mode == 'gaussian':
|
||||||
theta_x = theta_x.permute(0, 2, 1)
|
theta_x = x.view(n, self.in_channels, -1)
|
||||||
|
theta_x = theta_x.permute(0, 2, 1)
|
||||||
# NonLocal1d phi_x: [N, C, H]
|
if self.sub_sample:
|
||||||
# NonLocal2d phi_x: [N, C, HxW]
|
phi_x = self.phi(x).view(n, self.in_channels, -1)
|
||||||
# NonLocal3d phi_x: [N, C, TxHxW]
|
else:
|
||||||
phi_x = self.phi(x).view(n, self.inter_channels, -1)
|
phi_x = x.view(n, self.in_channels, -1)
|
||||||
|
elif self.mode == 'concatenation':
|
||||||
|
theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
|
||||||
|
phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
|
||||||
|
else:
|
||||||
|
theta_x = self.theta(x).view(n, self.inter_channels, -1)
|
||||||
|
theta_x = theta_x.permute(0, 2, 1)
|
||||||
|
phi_x = self.phi(x).view(n, self.inter_channels, -1)
|
||||||
|
|
||||||
pairwise_func = getattr(self, self.mode)
|
pairwise_func = getattr(self, self.mode)
|
||||||
# NonLocal1d pairwise_weight: [N, H, H]
|
# NonLocal1d pairwise_weight: [N, H, H]
|
||||||
|
@ -183,7 +235,10 @@ class NonLocal1d(_NonLocalNd):
|
||||||
if sub_sample:
|
if sub_sample:
|
||||||
max_pool_layer = nn.MaxPool1d(kernel_size=2)
|
max_pool_layer = nn.MaxPool1d(kernel_size=2)
|
||||||
self.g = nn.Sequential(self.g, max_pool_layer)
|
self.g = nn.Sequential(self.g, max_pool_layer)
|
||||||
self.phi = nn.Sequential(self.phi, max_pool_layer)
|
if self.mode != 'gaussian':
|
||||||
|
self.phi = nn.Sequential(self.phi, max_pool_layer)
|
||||||
|
else:
|
||||||
|
self.phi = max_pool_layer
|
||||||
|
|
||||||
|
|
||||||
@PLUGIN_LAYERS.register_module()
|
@PLUGIN_LAYERS.register_module()
|
||||||
|
@ -214,7 +269,10 @@ class NonLocal2d(_NonLocalNd):
|
||||||
if sub_sample:
|
if sub_sample:
|
||||||
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
|
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
|
||||||
self.g = nn.Sequential(self.g, max_pool_layer)
|
self.g = nn.Sequential(self.g, max_pool_layer)
|
||||||
self.phi = nn.Sequential(self.phi, max_pool_layer)
|
if self.mode != 'gaussian':
|
||||||
|
self.phi = nn.Sequential(self.phi, max_pool_layer)
|
||||||
|
else:
|
||||||
|
self.phi = max_pool_layer
|
||||||
|
|
||||||
|
|
||||||
class NonLocal3d(_NonLocalNd):
|
class NonLocal3d(_NonLocalNd):
|
||||||
|
@ -241,4 +299,7 @@ class NonLocal3d(_NonLocalNd):
|
||||||
if sub_sample:
|
if sub_sample:
|
||||||
max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
|
max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
|
||||||
self.g = nn.Sequential(self.g, max_pool_layer)
|
self.g = nn.Sequential(self.g, max_pool_layer)
|
||||||
self.phi = nn.Sequential(self.phi, max_pool_layer)
|
if self.mode != 'gaussian':
|
||||||
|
self.phi = nn.Sequential(self.phi, max_pool_layer)
|
||||||
|
else:
|
||||||
|
self.phi = max_pool_layer
|
||||||
|
|
|
@ -11,12 +11,17 @@ def test_nonlocal():
|
||||||
# mode should be in ['embedded_gaussian', 'dot_product']
|
# mode should be in ['embedded_gaussian', 'dot_product']
|
||||||
_NonLocalNd(3, mode='unsupport_mode')
|
_NonLocalNd(3, mode='unsupport_mode')
|
||||||
|
|
||||||
# _NonLocalNd
|
# _NonLocalNd with zero initialization
|
||||||
|
_NonLocalNd(3)
|
||||||
_NonLocalNd(3, norm_cfg=dict(type='BN'))
|
_NonLocalNd(3, norm_cfg=dict(type='BN'))
|
||||||
# Not Zero initialization
|
|
||||||
_NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=True)
|
|
||||||
|
|
||||||
# NonLocal3d
|
# _NonLocalNd without zero initialization
|
||||||
|
_NonLocalNd(3, zeros_init=False)
|
||||||
|
_NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nonlocal3d():
|
||||||
|
# NonLocal3d with 'embedded_gaussian' mode
|
||||||
imgs = torch.randn(2, 3, 10, 20, 20)
|
imgs = torch.randn(2, 3, 10, 20, 20)
|
||||||
nonlocal_3d = NonLocal3d(3)
|
nonlocal_3d = NonLocal3d(3)
|
||||||
if torch.__version__ == 'parrots':
|
if torch.__version__ == 'parrots':
|
||||||
|
@ -27,6 +32,7 @@ def test_nonlocal():
|
||||||
out = nonlocal_3d(imgs)
|
out = nonlocal_3d(imgs)
|
||||||
assert out.shape == imgs.shape
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal3d with 'dot_product' mode
|
||||||
nonlocal_3d = NonLocal3d(3, mode='dot_product')
|
nonlocal_3d = NonLocal3d(3, mode='dot_product')
|
||||||
assert nonlocal_3d.mode == 'dot_product'
|
assert nonlocal_3d.mode == 'dot_product'
|
||||||
if torch.__version__ == 'parrots':
|
if torch.__version__ == 'parrots':
|
||||||
|
@ -35,6 +41,38 @@ def test_nonlocal():
|
||||||
out = nonlocal_3d(imgs)
|
out = nonlocal_3d(imgs)
|
||||||
assert out.shape == imgs.shape
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal3d with 'concatenation' mode
|
||||||
|
nonlocal_3d = NonLocal3d(3, mode='concatenation')
|
||||||
|
assert nonlocal_3d.mode == 'concatenation'
|
||||||
|
if torch.__version__ == 'parrots':
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
nonlocal_3d.cuda()
|
||||||
|
out = nonlocal_3d(imgs)
|
||||||
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal3d with 'gaussian' mode
|
||||||
|
nonlocal_3d = NonLocal3d(3, mode='gaussian')
|
||||||
|
assert not hasattr(nonlocal_3d, 'phi')
|
||||||
|
assert nonlocal_3d.mode == 'gaussian'
|
||||||
|
if torch.__version__ == 'parrots':
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
nonlocal_3d.cuda()
|
||||||
|
out = nonlocal_3d(imgs)
|
||||||
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal3d with 'gaussian' mode and sub_sample
|
||||||
|
nonlocal_3d = NonLocal3d(3, mode='gaussian', sub_sample=True)
|
||||||
|
assert isinstance(nonlocal_3d.g, nn.Sequential) and len(nonlocal_3d.g) == 2
|
||||||
|
assert isinstance(nonlocal_3d.g[1], nn.MaxPool3d)
|
||||||
|
assert nonlocal_3d.g[1].kernel_size == (1, 2, 2)
|
||||||
|
assert isinstance(nonlocal_3d.phi, nn.MaxPool3d)
|
||||||
|
if torch.__version__ == 'parrots':
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
nonlocal_3d.cuda()
|
||||||
|
out = nonlocal_3d(imgs)
|
||||||
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal3d with 'dot_product' mode and sub_sample
|
||||||
nonlocal_3d = NonLocal3d(3, mode='dot_product', sub_sample=True)
|
nonlocal_3d = NonLocal3d(3, mode='dot_product', sub_sample=True)
|
||||||
for m in [nonlocal_3d.g, nonlocal_3d.phi]:
|
for m in [nonlocal_3d.g, nonlocal_3d.phi]:
|
||||||
assert isinstance(m, nn.Sequential) and len(m) == 2
|
assert isinstance(m, nn.Sequential) and len(m) == 2
|
||||||
|
@ -46,7 +84,9 @@ def test_nonlocal():
|
||||||
out = nonlocal_3d(imgs)
|
out = nonlocal_3d(imgs)
|
||||||
assert out.shape == imgs.shape
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
# NonLocal2d
|
|
||||||
|
def test_nonlocal2d():
|
||||||
|
# NonLocal2d with 'embedded_gaussian' mode
|
||||||
imgs = torch.randn(2, 3, 20, 20)
|
imgs = torch.randn(2, 3, 20, 20)
|
||||||
nonlocal_2d = NonLocal2d(3)
|
nonlocal_2d = NonLocal2d(3)
|
||||||
if torch.__version__ == 'parrots':
|
if torch.__version__ == 'parrots':
|
||||||
|
@ -56,6 +96,50 @@ def test_nonlocal():
|
||||||
out = nonlocal_2d(imgs)
|
out = nonlocal_2d(imgs)
|
||||||
assert out.shape == imgs.shape
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal2d with 'dot_product' mode
|
||||||
|
imgs = torch.randn(2, 3, 20, 20)
|
||||||
|
nonlocal_2d = NonLocal2d(3, mode='dot_product')
|
||||||
|
if torch.__version__ == 'parrots':
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
imgs = imgs.cuda()
|
||||||
|
nonlocal_2d.cuda()
|
||||||
|
out = nonlocal_2d(imgs)
|
||||||
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal2d with 'concatenation' mode
|
||||||
|
imgs = torch.randn(2, 3, 20, 20)
|
||||||
|
nonlocal_2d = NonLocal2d(3, mode='concatenation')
|
||||||
|
if torch.__version__ == 'parrots':
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
imgs = imgs.cuda()
|
||||||
|
nonlocal_2d.cuda()
|
||||||
|
out = nonlocal_2d(imgs)
|
||||||
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal2d with 'gaussian' mode
|
||||||
|
imgs = torch.randn(2, 3, 20, 20)
|
||||||
|
nonlocal_2d = NonLocal2d(3, mode='gaussian')
|
||||||
|
assert not hasattr(nonlocal_2d, 'phi')
|
||||||
|
if torch.__version__ == 'parrots':
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
imgs = imgs.cuda()
|
||||||
|
nonlocal_2d.cuda()
|
||||||
|
out = nonlocal_2d(imgs)
|
||||||
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal2d with 'gaussian' mode and sub_sample
|
||||||
|
nonlocal_2d = NonLocal2d(3, mode='gaussian', sub_sample=True)
|
||||||
|
assert isinstance(nonlocal_2d.g, nn.Sequential) and len(nonlocal_2d.g) == 2
|
||||||
|
assert isinstance(nonlocal_2d.g[1], nn.MaxPool2d)
|
||||||
|
assert nonlocal_2d.g[1].kernel_size == (2, 2)
|
||||||
|
assert isinstance(nonlocal_2d.phi, nn.MaxPool2d)
|
||||||
|
if torch.__version__ == 'parrots':
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
nonlocal_2d.cuda()
|
||||||
|
out = nonlocal_2d(imgs)
|
||||||
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal2d with 'dot_product' mode and sub_sample
|
||||||
nonlocal_2d = NonLocal2d(3, mode='dot_product', sub_sample=True)
|
nonlocal_2d = NonLocal2d(3, mode='dot_product', sub_sample=True)
|
||||||
for m in [nonlocal_2d.g, nonlocal_2d.phi]:
|
for m in [nonlocal_2d.g, nonlocal_2d.phi]:
|
||||||
assert isinstance(m, nn.Sequential) and len(m) == 2
|
assert isinstance(m, nn.Sequential) and len(m) == 2
|
||||||
|
@ -67,7 +151,9 @@ def test_nonlocal():
|
||||||
out = nonlocal_2d(imgs)
|
out = nonlocal_2d(imgs)
|
||||||
assert out.shape == imgs.shape
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
# NonLocal1d
|
|
||||||
|
def test_nonlocal1d():
|
||||||
|
# NonLocal1d with 'embedded_gaussian' mode
|
||||||
imgs = torch.randn(2, 3, 20)
|
imgs = torch.randn(2, 3, 20)
|
||||||
nonlocal_1d = NonLocal1d(3)
|
nonlocal_1d = NonLocal1d(3)
|
||||||
if torch.__version__ == 'parrots':
|
if torch.__version__ == 'parrots':
|
||||||
|
@ -77,6 +163,50 @@ def test_nonlocal():
|
||||||
out = nonlocal_1d(imgs)
|
out = nonlocal_1d(imgs)
|
||||||
assert out.shape == imgs.shape
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal1d with 'dot_product' mode
|
||||||
|
imgs = torch.randn(2, 3, 20)
|
||||||
|
nonlocal_1d = NonLocal1d(3, mode='dot_product')
|
||||||
|
if torch.__version__ == 'parrots':
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
imgs = imgs.cuda()
|
||||||
|
nonlocal_1d.cuda()
|
||||||
|
out = nonlocal_1d(imgs)
|
||||||
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal1d with 'concatenation' mode
|
||||||
|
imgs = torch.randn(2, 3, 20)
|
||||||
|
nonlocal_1d = NonLocal1d(3, mode='concatenation')
|
||||||
|
if torch.__version__ == 'parrots':
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
imgs = imgs.cuda()
|
||||||
|
nonlocal_1d.cuda()
|
||||||
|
out = nonlocal_1d(imgs)
|
||||||
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal1d with 'gaussian' mode
|
||||||
|
imgs = torch.randn(2, 3, 20)
|
||||||
|
nonlocal_1d = NonLocal1d(3, mode='gaussian')
|
||||||
|
assert not hasattr(nonlocal_1d, 'phi')
|
||||||
|
if torch.__version__ == 'parrots':
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
imgs = imgs.cuda()
|
||||||
|
nonlocal_1d.cuda()
|
||||||
|
out = nonlocal_1d(imgs)
|
||||||
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal1d with 'gaussian' mode and sub_sample
|
||||||
|
nonlocal_1d = NonLocal1d(3, mode='gaussian', sub_sample=True)
|
||||||
|
assert isinstance(nonlocal_1d.g, nn.Sequential) and len(nonlocal_1d.g) == 2
|
||||||
|
assert isinstance(nonlocal_1d.g[1], nn.MaxPool1d)
|
||||||
|
assert nonlocal_1d.g[1].kernel_size == 2
|
||||||
|
assert isinstance(nonlocal_1d.phi, nn.MaxPool1d)
|
||||||
|
if torch.__version__ == 'parrots':
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
nonlocal_1d.cuda()
|
||||||
|
out = nonlocal_1d(imgs)
|
||||||
|
assert out.shape == imgs.shape
|
||||||
|
|
||||||
|
# NonLocal1d with 'dot_product' mode and sub_sample
|
||||||
nonlocal_1d = NonLocal1d(3, mode='dot_product', sub_sample=True)
|
nonlocal_1d = NonLocal1d(3, mode='dot_product', sub_sample=True)
|
||||||
for m in [nonlocal_1d.g, nonlocal_1d.phi]:
|
for m in [nonlocal_1d.g, nonlocal_1d.phi]:
|
||||||
assert isinstance(m, nn.Sequential) and len(m) == 2
|
assert isinstance(m, nn.Sequential) and len(m) == 2
|
||||||
|
|
Loading…
Reference in New Issue