# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmrazor.models.builder import OPS


def test_shuffle_series():
    """Build each ShuffleNet-series op via ``OPS`` and check its forward pass.

    Every config below uses ``in_channels == out_channels == 16`` and
    ``stride=1``, so each op is expected to return a tensor with exactly the
    input's shape (batch, channels, height, width).
    """
    tensor = torch.randn(16, 16, 32, 32)

    # ShuffleBlock with 7x7, 5x5 and 3x3 kernels, followed by ShuffleXception.
    op_cfgs = [
        dict(
            type='ShuffleBlock',
            in_channels=16,
            out_channels=16,
            kernel_size=kernel_size,
            stride=1) for kernel_size in (7, 5, 3)
    ]
    op_cfgs.append(
        dict(
            type='ShuffleXception', in_channels=16, out_channels=16,
            stride=1))

    for cfg in op_cfgs:
        op = OPS.build(cfg)

        # test forward: check the full (N, C, H, W) shape so that a change in
        # batch size or width is caught too; the config in the message
        # identifies which op failed.
        outputs = op(tensor)
        assert outputs.shape == tensor.shape, (cfg, outputs.shape)


def test_darts_series():
    """Build each DARTS-series op via ``OPS`` and check its forward pass.

    Every config below uses ``in_channels == out_channels == 16`` and
    ``stride=1``, so each op is expected to return a tensor with exactly the
    input's shape (batch, channels, height, width).
    """
    tensor = torch.randn(16, 16, 32, 32)

    op_cfgs = [
        # test avg pool bn
        dict(
            type='DartsPoolBN',
            in_channels=16,
            out_channels=16,
            kernel_size=3,
            pool_type='avg',
            stride=1),
        # test max pool bn
        dict(
            type='DartsPoolBN',
            in_channels=16,
            out_channels=16,
            kernel_size=3,
            pool_type='max',
            stride=1),
    ]

    # test DartsSepConv and DartsDilConv at both 3x3 and 5x5 (previously the
    # 3x3 DartsSepConv config was simply checked twice).
    # NOTE(review): the 5x5 cases assume both ops pad so that stride=1 keeps
    # the spatial size for kernel_size=5 as well -- confirm against the op
    # implementations.
    for op_type in ('DartsSepConv', 'DartsDilConv'):
        for kernel_size in (3, 5):
            op_cfgs.append(
                dict(
                    type=op_type,
                    in_channels=16,
                    out_channels=16,
                    kernel_size=kernel_size,
                    stride=1))

    # test DartsSkipConnect
    op_cfgs.append(
        dict(
            type='DartsSkipConnect', in_channels=16, out_channels=16,
            stride=1))

    for cfg in op_cfgs:
        op = OPS.build(cfg)

        # test forward: check the full (N, C, H, W) shape so that a change in
        # batch size or width is caught too; the config in the message
        # identifies which op failed.
        outputs = op(tensor)
        assert outputs.shape == tensor.shape, (cfg, outputs.shape)