# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
|
|
from mmseg.models.utils import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
|
|
nlc_to_nchw)
|
|
|
|
|
|
def test_nchw2nlc2nchw():
    """Exercise ``nchw2nlc2nchw`` with three wrapped callables.

    Each callable asserts the shape of the tensor it is handed and returns
    it unchanged, so the test checks both the layout seen inside the wrapper
    and the shape of the final result:

    * a plain single-argument callable,
    * a callable that also receives the extra keyword ``arg=100``,
    * a callable invoked with ``contiguous=True``, where both the tensor
      seen by the callable and the final output must be contiguous.
    """
    # Test nchw2nlc2nchw function
    shape_nchw = (4, 2, 5, 5)
    # 25 == 5 * 5: the two spatial dims collapse into one sequence dim.
    shape_nlc = (4, 25, 2)

    def test_func(x):
        assert x.shape == torch.Size(shape_nlc)
        return x

    x = torch.rand(*shape_nchw)
    output = nchw2nlc2nchw(test_func, x)
    assert output.shape == torch.Size(shape_nchw)

    def test_func2(x, arg):
        assert x.shape == torch.Size(shape_nlc)
        # The extra keyword must reach the wrapped callable untouched.
        assert arg == 100
        return x

    x = torch.rand(*shape_nchw)
    output = nchw2nlc2nchw(test_func2, x, arg=100)
    assert output.shape == torch.Size(shape_nchw)

    def test_func3(x):
        assert x.is_contiguous()
        assert x.shape == torch.Size(shape_nlc)
        return x

    x = torch.rand(*shape_nchw)
    output = nchw2nlc2nchw(test_func3, x, contiguous=True)
    assert output.shape == torch.Size(shape_nchw)
    assert output.is_contiguous()
|
|
|
|
|
|
def test_nlc2nchw2nlc():
    """Exercise ``nlc2nchw2nlc`` with three wrapped callables.

    Each callable asserts the shape of the tensor it is handed and returns
    it unchanged, so the test checks both the layout seen inside the wrapper
    and the shape of the final result:

    * a plain single-argument callable,
    * a callable that also receives the extra keyword ``arg=100``,
    * a callable invoked with ``contiguous=True``, where both the tensor
      seen by the callable and the final output must be contiguous.
    """
    # Test nlc2nchw2nlc function
    shape_nchw = (4, 2, 5, 5)
    # 25 == 5 * 5: the two spatial dims collapse into one sequence dim.
    shape_nlc = (4, 25, 2)

    def test_func(x):
        assert x.shape == torch.Size(shape_nchw)
        return x

    x = torch.rand(*shape_nlc)
    # shape_nchw[2:] is the (H, W) pair passed as the spatial size.
    output = nlc2nchw2nlc(test_func, x, shape_nchw[2:])
    assert output.shape == torch.Size(shape_nlc)

    def test_func2(x, arg):
        assert x.shape == torch.Size(shape_nchw)
        # The extra keyword must reach the wrapped callable untouched.
        assert arg == 100
        return x

    x = torch.rand(*shape_nlc)
    output = nlc2nchw2nlc(test_func2, x, shape_nchw[2:], arg=100)
    assert output.shape == torch.Size(shape_nlc)

    def test_func3(x):
        assert x.is_contiguous()
        assert x.shape == torch.Size(shape_nchw)
        return x

    x = torch.rand(*shape_nlc)
    output = nlc2nchw2nlc(test_func3, x, shape_nchw[2:], contiguous=True)
    assert output.shape == torch.Size(shape_nlc)
    assert output.is_contiguous()
|
|
|
|
|
|
def test_nchw_to_nlc():
    """Check that ``nchw_to_nlc`` maps a (4, 2, 5, 5) tensor to (4, 25, 2)."""
    # Test nchw_to_nlc function
    shape_nchw = (4, 2, 5, 5)
    shape_nlc = (4, 25, 2)
    x = torch.rand(*shape_nchw)
    y = nchw_to_nlc(x)
    assert y.shape == torch.Size(shape_nlc)
|
|
|
|
|
|
def test_nlc_to_nchw():
    """Check that ``nlc_to_nchw`` maps a (4, 25, 2) tensor to (4, 2, 5, 5)."""
    # Test nlc_to_nchw function
    shape_nchw = (4, 2, 5, 5)
    shape_nlc = (4, 25, 2)
    x = torch.rand(*shape_nlc)
    # Take the spatial size from shape_nchw (== (5, 5)) so the expected
    # shape and the argument cannot drift apart.
    y = nlc_to_nchw(x, shape_nchw[2:])
    assert y.shape == torch.Size(shape_nchw)
|