60 lines
1.7 KiB
Python
60 lines
1.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
from mmengine.utils import digit_version
|
|
|
|
from mmcls.models.utils import channel_shuffle, is_tracing, make_divisible
|
|
|
|
|
|
def test_make_divisible():
|
|
# test min_value is None
|
|
result = make_divisible(34, 8, None)
|
|
assert result == 32
|
|
|
|
# test when new_value > min_ratio * value
|
|
result = make_divisible(10, 8, min_ratio=0.9)
|
|
assert result == 16
|
|
|
|
# test min_value = 0.8
|
|
result = make_divisible(33, 8, min_ratio=0.8)
|
|
assert result == 32
|
|
|
|
|
|
def test_channel_shuffle():
|
|
x = torch.randn(1, 24, 56, 56)
|
|
with pytest.raises(AssertionError):
|
|
# num_channels should be divisible by groups
|
|
channel_shuffle(x, 7)
|
|
|
|
groups = 3
|
|
batch_size, num_channels, height, width = x.size()
|
|
channels_per_group = num_channels // groups
|
|
out = channel_shuffle(x, groups)
|
|
# test the output value when groups = 3
|
|
for b in range(batch_size):
|
|
for c in range(num_channels):
|
|
c_out = c % channels_per_group * groups + c // channels_per_group
|
|
for i in range(height):
|
|
for j in range(width):
|
|
assert x[b, c, i, j] == out[b, c_out, i, j]
|
|
|
|
|
|
@pytest.mark.skipif(
|
|
digit_version(torch.__version__) < digit_version('1.6.0'),
|
|
reason='torch.jit.is_tracing is not available before 1.6.0')
|
|
def test_is_tracing():
|
|
|
|
def foo(x):
|
|
if is_tracing():
|
|
return x
|
|
else:
|
|
return x.tolist()
|
|
|
|
x = torch.rand(3)
|
|
# test without trace
|
|
assert isinstance(foo(x), list)
|
|
|
|
# test with trace
|
|
traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
|
|
assert isinstance(traced_foo(x), torch.Tensor)
|