mirror of https://github.com/open-mmlab/mmcv.git
221 lines
7.6 KiB
Python
221 lines
7.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmcv.cnn import NonLocal1d, NonLocal2d, NonLocal3d
|
|
from mmcv.cnn.bricks.non_local import _NonLocalNd
|
|
|
|
|
|
def test_nonlocal():
|
|
with pytest.raises(ValueError):
|
|
# mode should be in ['embedded_gaussian', 'dot_product']
|
|
_NonLocalNd(3, mode='unsupport_mode')
|
|
|
|
# _NonLocalNd with zero initialization
|
|
_NonLocalNd(3)
|
|
_NonLocalNd(3, norm_cfg=dict(type='BN'))
|
|
|
|
# _NonLocalNd without zero initialization
|
|
_NonLocalNd(3, zeros_init=False)
|
|
_NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=False)
|
|
|
|
|
|
def test_nonlocal3d():
|
|
# NonLocal3d with 'embedded_gaussian' mode
|
|
imgs = torch.randn(2, 3, 10, 20, 20)
|
|
nonlocal_3d = NonLocal3d(3)
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
# NonLocal is only implemented on gpu in parrots
|
|
imgs = imgs.cuda()
|
|
nonlocal_3d.cuda()
|
|
out = nonlocal_3d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal3d with 'dot_product' mode
|
|
nonlocal_3d = NonLocal3d(3, mode='dot_product')
|
|
assert nonlocal_3d.mode == 'dot_product'
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
nonlocal_3d.cuda()
|
|
out = nonlocal_3d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal3d with 'concatenation' mode
|
|
nonlocal_3d = NonLocal3d(3, mode='concatenation')
|
|
assert nonlocal_3d.mode == 'concatenation'
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
nonlocal_3d.cuda()
|
|
out = nonlocal_3d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal3d with 'gaussian' mode
|
|
nonlocal_3d = NonLocal3d(3, mode='gaussian')
|
|
assert not hasattr(nonlocal_3d, 'phi')
|
|
assert nonlocal_3d.mode == 'gaussian'
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
nonlocal_3d.cuda()
|
|
out = nonlocal_3d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal3d with 'gaussian' mode and sub_sample
|
|
nonlocal_3d = NonLocal3d(3, mode='gaussian', sub_sample=True)
|
|
assert isinstance(nonlocal_3d.g, nn.Sequential) and len(nonlocal_3d.g) == 2
|
|
assert isinstance(nonlocal_3d.g[1], nn.MaxPool3d)
|
|
assert nonlocal_3d.g[1].kernel_size == (1, 2, 2)
|
|
assert isinstance(nonlocal_3d.phi, nn.MaxPool3d)
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
nonlocal_3d.cuda()
|
|
out = nonlocal_3d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal3d with 'dot_product' mode and sub_sample
|
|
nonlocal_3d = NonLocal3d(3, mode='dot_product', sub_sample=True)
|
|
for m in [nonlocal_3d.g, nonlocal_3d.phi]:
|
|
assert isinstance(m, nn.Sequential) and len(m) == 2
|
|
assert isinstance(m[1], nn.MaxPool3d)
|
|
assert m[1].kernel_size == (1, 2, 2)
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
nonlocal_3d.cuda()
|
|
out = nonlocal_3d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
|
|
def test_nonlocal2d():
|
|
# NonLocal2d with 'embedded_gaussian' mode
|
|
imgs = torch.randn(2, 3, 20, 20)
|
|
nonlocal_2d = NonLocal2d(3)
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
imgs = imgs.cuda()
|
|
nonlocal_2d.cuda()
|
|
out = nonlocal_2d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal2d with 'dot_product' mode
|
|
imgs = torch.randn(2, 3, 20, 20)
|
|
nonlocal_2d = NonLocal2d(3, mode='dot_product')
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
imgs = imgs.cuda()
|
|
nonlocal_2d.cuda()
|
|
out = nonlocal_2d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal2d with 'concatenation' mode
|
|
imgs = torch.randn(2, 3, 20, 20)
|
|
nonlocal_2d = NonLocal2d(3, mode='concatenation')
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
imgs = imgs.cuda()
|
|
nonlocal_2d.cuda()
|
|
out = nonlocal_2d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal2d with 'gaussian' mode
|
|
imgs = torch.randn(2, 3, 20, 20)
|
|
nonlocal_2d = NonLocal2d(3, mode='gaussian')
|
|
assert not hasattr(nonlocal_2d, 'phi')
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
imgs = imgs.cuda()
|
|
nonlocal_2d.cuda()
|
|
out = nonlocal_2d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal2d with 'gaussian' mode and sub_sample
|
|
nonlocal_2d = NonLocal2d(3, mode='gaussian', sub_sample=True)
|
|
assert isinstance(nonlocal_2d.g, nn.Sequential) and len(nonlocal_2d.g) == 2
|
|
assert isinstance(nonlocal_2d.g[1], nn.MaxPool2d)
|
|
assert nonlocal_2d.g[1].kernel_size == (2, 2)
|
|
assert isinstance(nonlocal_2d.phi, nn.MaxPool2d)
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
nonlocal_2d.cuda()
|
|
out = nonlocal_2d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal2d with 'dot_product' mode and sub_sample
|
|
nonlocal_2d = NonLocal2d(3, mode='dot_product', sub_sample=True)
|
|
for m in [nonlocal_2d.g, nonlocal_2d.phi]:
|
|
assert isinstance(m, nn.Sequential) and len(m) == 2
|
|
assert isinstance(m[1], nn.MaxPool2d)
|
|
assert m[1].kernel_size == (2, 2)
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
nonlocal_2d.cuda()
|
|
out = nonlocal_2d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
|
|
def test_nonlocal1d():
|
|
# NonLocal1d with 'embedded_gaussian' mode
|
|
imgs = torch.randn(2, 3, 20)
|
|
nonlocal_1d = NonLocal1d(3)
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
imgs = imgs.cuda()
|
|
nonlocal_1d.cuda()
|
|
out = nonlocal_1d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal1d with 'dot_product' mode
|
|
imgs = torch.randn(2, 3, 20)
|
|
nonlocal_1d = NonLocal1d(3, mode='dot_product')
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
imgs = imgs.cuda()
|
|
nonlocal_1d.cuda()
|
|
out = nonlocal_1d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal1d with 'concatenation' mode
|
|
imgs = torch.randn(2, 3, 20)
|
|
nonlocal_1d = NonLocal1d(3, mode='concatenation')
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
imgs = imgs.cuda()
|
|
nonlocal_1d.cuda()
|
|
out = nonlocal_1d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal1d with 'gaussian' mode
|
|
imgs = torch.randn(2, 3, 20)
|
|
nonlocal_1d = NonLocal1d(3, mode='gaussian')
|
|
assert not hasattr(nonlocal_1d, 'phi')
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
imgs = imgs.cuda()
|
|
nonlocal_1d.cuda()
|
|
out = nonlocal_1d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal1d with 'gaussian' mode and sub_sample
|
|
nonlocal_1d = NonLocal1d(3, mode='gaussian', sub_sample=True)
|
|
assert isinstance(nonlocal_1d.g, nn.Sequential) and len(nonlocal_1d.g) == 2
|
|
assert isinstance(nonlocal_1d.g[1], nn.MaxPool1d)
|
|
assert nonlocal_1d.g[1].kernel_size == 2
|
|
assert isinstance(nonlocal_1d.phi, nn.MaxPool1d)
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
nonlocal_1d.cuda()
|
|
out = nonlocal_1d(imgs)
|
|
assert out.shape == imgs.shape
|
|
|
|
# NonLocal1d with 'dot_product' mode and sub_sample
|
|
nonlocal_1d = NonLocal1d(3, mode='dot_product', sub_sample=True)
|
|
for m in [nonlocal_1d.g, nonlocal_1d.phi]:
|
|
assert isinstance(m, nn.Sequential) and len(m) == 2
|
|
assert isinstance(m[1], nn.MaxPool1d)
|
|
assert m[1].kernel_size == 2
|
|
if torch.__version__ == 'parrots':
|
|
if torch.cuda.is_available():
|
|
nonlocal_1d.cuda()
|
|
out = nonlocal_1d(imgs)
|
|
assert out.shape == imgs.shape
|