mmcv/tests/test_cnn/test_non_local.py

90 lines
2.9 KiB
Python
Raw Normal View History

import pytest
import torch
import torch.nn as nn
from mmcv.cnn import NonLocal1d, NonLocal2d, NonLocal3d
from mmcv.cnn.bricks.non_local import _NonLocalNd
def test_nonlocal():
with pytest.raises(ValueError):
# mode should be in ['embedded_gaussian', 'dot_product']
_NonLocalNd(3, mode='unsupport_mode')
# _NonLocalNd
_NonLocalNd(3, norm_cfg=dict(type='BN'))
# Not Zero initialization
_NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=True)
# NonLocal3d
imgs = torch.randn(2, 3, 10, 20, 20)
nonlocal_3d = NonLocal3d(3)
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
# NonLocal is only implemented on gpu in parrots
imgs = imgs.cuda()
nonlocal_3d.cuda()
out = nonlocal_3d(imgs)
assert out.shape == imgs.shape
nonlocal_3d = NonLocal3d(3, mode='dot_product')
assert nonlocal_3d.mode == 'dot_product'
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
nonlocal_3d.cuda()
out = nonlocal_3d(imgs)
assert out.shape == imgs.shape
nonlocal_3d = NonLocal3d(3, mode='dot_product', sub_sample=True)
for m in [nonlocal_3d.g, nonlocal_3d.phi]:
assert isinstance(m, nn.Sequential) and len(m) == 2
assert isinstance(m[1], nn.MaxPool3d)
assert m[1].kernel_size == (1, 2, 2)
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
nonlocal_3d.cuda()
out = nonlocal_3d(imgs)
assert out.shape == imgs.shape
# NonLocal2d
imgs = torch.randn(2, 3, 20, 20)
nonlocal_2d = NonLocal2d(3)
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
imgs = imgs.cuda()
nonlocal_2d.cuda()
out = nonlocal_2d(imgs)
assert out.shape == imgs.shape
nonlocal_2d = NonLocal2d(3, mode='dot_product', sub_sample=True)
for m in [nonlocal_2d.g, nonlocal_2d.phi]:
assert isinstance(m, nn.Sequential) and len(m) == 2
assert isinstance(m[1], nn.MaxPool2d)
assert m[1].kernel_size == (2, 2)
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
nonlocal_2d.cuda()
out = nonlocal_2d(imgs)
assert out.shape == imgs.shape
# NonLocal1d
imgs = torch.randn(2, 3, 20)
nonlocal_1d = NonLocal1d(3)
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
imgs = imgs.cuda()
nonlocal_1d.cuda()
out = nonlocal_1d(imgs)
assert out.shape == imgs.shape
nonlocal_1d = NonLocal1d(3, mode='dot_product', sub_sample=True)
for m in [nonlocal_1d.g, nonlocal_1d.phi]:
assert isinstance(m, nn.Sequential) and len(m) == 2
assert isinstance(m[1], nn.MaxPool1d)
assert m[1].kernel_size == 2
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
nonlocal_1d.cuda()
out = nonlocal_1d(imgs)
assert out.shape == imgs.shape