53 lines
1.6 KiB
Python
53 lines
1.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
|
|
from mmrazor.models.builder import MUTABLES
|
|
|
|
|
|
def test_one_shot_op():
    """Build a ``OneShotOP`` over ShuffleNet candidates and run a forward.

    Every candidate is built with the same ``choice_args`` (16 -> 16
    channels, stride 1), so the forward output is expected to have exactly
    the input's shape.
    """
    oneshot_op_cfg = dict(
        type='OneShotOP',
        space_id='test',
        num_chosen=1,
        choices=dict(
            shuffle_3x3=dict(type='ShuffleBlock', kernel_size=3),
            shuffle_5x5=dict(type='ShuffleBlock', kernel_size=5),
            shuffle_7x7=dict(type='ShuffleBlock', kernel_size=7),
            shuffle_xception=dict(type='ShuffleXception'),
        ),
        choice_args=dict(in_channels=16, out_channels=16, stride=1))

    model = MUTABLES.build(oneshot_op_cfg)

    # Dim 1 must equal choice_args['in_channels'] (16).
    tensor = torch.randn(16, 16, 32, 32)

    # test forward
    outputs = model(tensor)
    # stride=1 and in_channels == out_channels, so the whole shape is
    # preserved. Compare all four dims (not only channels and height) so a
    # regression in batch or width handling is also caught.
    assert outputs.shape == (16, 16, 32, 32)
|
|
|
|
|
|
def test_differentiable_op():
    """Build a ``DifferentiableOP`` over DARTS candidates and run a forward.

    Every candidate is built with the same ``choice_args`` (16 -> 16
    channels, stride 2), so the 32x32 input is expected to come out as
    16x16 with batch and channel count unchanged.
    """
    differentiable_op_cfg = dict(
        type='DifferentiableOP',
        space_id='test',
        num_chosen=1,
        with_arch_param=True,
        choices=dict(
            zero=dict(type='DartsZero'),
            skip_connect=dict(type='DartsSkipConnect'),
            dil_conv_3x3=dict(type='DartsDilConv', kernel_size=3),
            dil_conv_5x5=dict(type='DartsDilConv', kernel_size=5),
            sep_conv_3x3=dict(type='DartsSepConv', kernel_size=3),
            sep_conv_5x5=dict(type='DartsSepConv', kernel_size=5),
        ),
        choice_args=dict(in_channels=16, out_channels=16, stride=2))

    model = MUTABLES.build(differentiable_op_cfg)
    # with_arch_param=True: the architecture parameter is created by the op
    # itself and then passed explicitly to forward() below (presumably the
    # per-choice mixing weights -- confirm against DifferentiableOP).
    arch_param = model.build_arch_param()

    # Dim 1 must equal choice_args['in_channels'] (16).
    tensor = torch.randn(16, 16, 32, 32)

    # test forward
    outputs = model(tensor, arch_param)
    # stride=2 halves both spatial dims; batch and channels are unchanged.
    # Compare all four dims so a width or batch regression is also caught.
    assert outputs.shape == (16, 16, 16, 16)
|