38 lines
1.1 KiB
Python
38 lines
1.1 KiB
Python
|
import pytest
|
||
|
import torch
|
||
|
|
||
|
from mmcls.models.utils import channel_shuffle, make_divisible
|
||
|
|
||
|
|
||
|
def test_make_divisible():
    """Check the rounding behaviour of ``make_divisible``.

    Covers the default ``min_value`` (``None``), a case where the rounded
    value is too small relative to ``min_ratio`` and must be bumped up, and
    a case where a custom ``min_ratio`` does not trigger the bump.
    """
    # min_value is None: 34 is expected to round to 32, a multiple of 8.
    result = make_divisible(34, 8, None)
    assert result == 32

    # Case new_value < min_ratio * value: 10 rounds to 8, which is below
    # 0.9 * 10 = 9, so the expected result is bumped up by one divisor to 16.
    result = make_divisible(10, 8, min_ratio=0.9)
    assert result == 16

    # min_ratio = 0.8: 33 rounds to 32, which is not below 0.8 * 33 = 26.4,
    # so no bump is expected.
    result = make_divisible(33, 8, min_ratio=0.8)
    assert result == 32
|
||
|
|
||
|
|
||
|
def test_channel_shuffle():
    """Check ``channel_shuffle`` input validation and channel permutation."""
    x = torch.randn(1, 24, 56, 56)
    with pytest.raises(AssertionError):
        # num_channels (24) should be divisible by groups; 7 does not divide 24
        channel_shuffle(x, 7)

    groups = 3
    num_channels = x.size(1)
    channels_per_group = num_channels // groups
    out = channel_shuffle(x, groups)
    assert out.shape == x.shape

    # Test the output value when groups = 3. Writing input channel c as
    # (c // channels_per_group, c % channels_per_group) in a
    # (groups, channels_per_group) grid, the expected output index
    # c % channels_per_group * groups + c // channels_per_group is that grid
    # transposed and flattened. Whole (batch, H, W) slices are compared at
    # once instead of element by element, which keeps the test fast while
    # still requiring exact equality.
    for c in range(num_channels):
        c_out = c % channels_per_group * groups + c // channels_per_group
        assert torch.equal(x[:, c], out[:, c_out]), f'channel {c} -> {c_out}'
|