mmsegmentation/tests/test_models/test_utils/test_shape_convert.py

90 lines
2.4 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.utils import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
nlc_to_nchw)
def test_nchw2nlc2nchw():
    """Test ``nchw2nlc2nchw`` (NCHW -> NLC for the callable -> NCHW).

    The spatial size is non-square (H != W) so that swapping height and
    width changes the output shape, and tensor values are compared (not
    only shapes) so that a wrong flatten/transpose order is caught too.
    """
    shape_nchw = (4, 2, 5, 6)
    n, c, h, w = shape_nchw
    shape_nlc = (n, h * w, c)

    x = torch.rand(*shape_nchw)
    # Independent reference for the NLC layout the callable should see:
    # tokens in row-major (h, w) order, channels as the last dim.
    expected_nlc = x.permute(0, 2, 3, 1).reshape(shape_nlc)

    def test_func(y):
        assert y.shape == torch.Size(shape_nlc)
        assert torch.equal(y, expected_nlc)
        return y

    output = nchw2nlc2nchw(test_func, x)
    assert output.shape == torch.Size(shape_nchw)
    # An identity callable must make the round trip reproduce the input.
    assert torch.equal(output, x)

    # Extra keyword arguments reach the wrapped callable.
    def test_func2(y, arg):
        assert y.shape == torch.Size(shape_nlc)
        assert arg == 100
        return y

    x = torch.rand(*shape_nchw)
    output = nchw2nlc2nchw(test_func2, x, arg=100)
    assert output.shape == torch.Size(shape_nchw)
    assert torch.equal(output, x)

    # With ``contiguous=True`` both the callable's input and the final
    # output are expected to be contiguous.
    def test_func3(y):
        assert y.is_contiguous()
        assert y.shape == torch.Size(shape_nlc)
        return y

    x = torch.rand(*shape_nchw)
    output = nchw2nlc2nchw(test_func3, x, contiguous=True)
    assert output.shape == torch.Size(shape_nchw)
    assert output.is_contiguous()
    assert torch.equal(output, x)
def test_nlc2nchw2nlc():
    """Test ``nlc2nchw2nlc`` (NLC -> NCHW for the callable -> NLC).

    The spatial size is non-square (H != W) so that swapping height and
    width changes the intermediate shape, and tensor values are compared
    (not only shapes) so that a wrong transpose/reshape order is caught.
    """
    shape_nchw = (4, 2, 5, 6)
    n, c, h, w = shape_nchw
    shape_nlc = (n, h * w, c)
    hw_shape = shape_nchw[2:]

    x = torch.rand(*shape_nlc)
    # Independent reference for the NCHW layout the callable should see:
    # token index h * W + w holds the channel vector of pixel (h, w).
    expected_nchw = x.reshape(n, h, w, c).permute(0, 3, 1, 2)

    def test_func(y):
        assert y.shape == torch.Size(shape_nchw)
        assert torch.equal(y, expected_nchw)
        return y

    output = nlc2nchw2nlc(test_func, x, hw_shape)
    assert output.shape == torch.Size(shape_nlc)
    # An identity callable must make the round trip reproduce the input.
    assert torch.equal(output, x)

    # Extra keyword arguments reach the wrapped callable.
    def test_func2(y, arg):
        assert y.shape == torch.Size(shape_nchw)
        assert arg == 100
        return y

    x = torch.rand(*shape_nlc)
    output = nlc2nchw2nlc(test_func2, x, hw_shape, arg=100)
    assert output.shape == torch.Size(shape_nlc)
    assert torch.equal(output, x)

    # With ``contiguous=True`` both the callable's input and the final
    # output are expected to be contiguous.
    def test_func3(y):
        assert y.is_contiguous()
        assert y.shape == torch.Size(shape_nchw)
        return y

    x = torch.rand(*shape_nlc)
    output = nlc2nchw2nlc(test_func3, x, hw_shape, contiguous=True)
    assert output.shape == torch.Size(shape_nlc)
    assert output.is_contiguous()
    assert torch.equal(output, x)
def test_nchw_to_nlc():
    """Test ``nchw_to_nlc`` on a non-square input, checking values too.

    A shape-only check would accept a plain ``reshape`` (which scrambles
    channels and tokens), and a square input would hide an H/W swap.
    """
    shape_nchw = (4, 2, 5, 6)
    n, c, h, w = shape_nchw
    shape_nlc = (n, h * w, c)
    x = torch.rand(*shape_nchw)
    y = nchw_to_nlc(x)
    assert y.shape == torch.Size(shape_nlc)
    # Tokens in row-major (h, w) order, channels as the last dim.
    assert torch.equal(y, x.permute(0, 2, 3, 1).reshape(shape_nlc))
def test_nlc_to_nchw():
    """Test ``nlc_to_nchw`` on a non-square input, checking values too.

    ``hw_shape`` is derived from ``shape_nchw`` instead of being hard-coded,
    so the two cannot drift apart.
    """
    shape_nchw = (4, 2, 5, 6)
    n, c, h, w = shape_nchw
    shape_nlc = (n, h * w, c)
    x = torch.rand(*shape_nlc)
    y = nlc_to_nchw(x, (h, w))
    assert y.shape == torch.Size(shape_nchw)
    # Token index h * W + w must land at spatial position (h, w).
    assert torch.equal(y, x.reshape(n, h, w, c).permute(0, 3, 1, 2))