mirror of
https://github.com/open-mmlab/mmrazor.git
synced 2025-06-03 15:02:54 +08:00
* move build_arch_param from mutable to mutator * fix UT of diff mutable and mutator * modify based on shiguang's comments * remove mutator from the unittest of mutable
88 lines
3.0 KiB
Python
88 lines
3.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmrazor.models import * # noqa:F403,F401
|
|
from mmrazor.registry import MODELS
|
|
|
|
# Expose torch's built-in ``nn.Conv2d`` in the ``MODELS`` registry under the
# alias 'torchConv2d'. Per mmengine's Registry docs, ``force=True`` replaces
# any existing entry registered under that name instead of raising.
MODELS.register_module(
    name='torchConv2d',
    module=nn.Conv2d,
    force=True,
)
|
|
|
|
|
|
class TestDiffChoiceRoute(TestCase):
    """Unit tests for the ``DiffChoiceRoute`` mutable built via ``MODELS``."""

    def test_forward_arch_param(self):
        """Check ``forward_arch_param`` with ``with_arch_param`` set to True
        and then False, followed by ``sample_choice`` and ``dump_chosen``."""
        edge_specs = [
            ('first_edge', nn.Conv2d(32, 32, 3, 1, 1)),
            ('second_edge', nn.Conv2d(32, 32, 5, 1, 2)),
            ('third_edge', nn.MaxPool2d(3, 1, 1)),
            ('fourth_edge', nn.MaxPool2d(5, 1, 2)),
            ('fifth_edge', nn.MaxPool2d(7, 1, 3)),
        ]
        candidate_edges = nn.ModuleDict()
        for edge_name, edge_module in edge_specs:
            candidate_edges.add_module(edge_name, edge_module)

        route_cfg = dict(
            type='DiffChoiceRoute',
            edges=candidate_edges,
            with_arch_param=True,
        )

        # Case 1: with_arch_param=True.
        route = MODELS.build(route_cfg)
        arch_param = nn.Parameter(torch.randn(len(candidate_edges)))
        # One (4, 32, 64, 64) input tensor per edge.
        x = [torch.randn(4, 32, 64, 64) for _ in edge_specs]
        output = route.forward_arch_param(x=x, arch_param=arch_param)
        assert output is not None

        # Case 2: with_arch_param=False. This is a shallow copy, so the new
        # config still references the very same ``candidate_edges`` object.
        plain_route_cfg = {**route_cfg, 'with_arch_param': False}
        plain_route = MODELS.build(plain_route_cfg)
        arch_param = nn.Parameter(torch.randn(len(candidate_edges)))
        output = plain_route.forward_arch_param(x=x, arch_param=arch_param)
        assert output is not None

        plain_route.fix_chosen(chosen=['first_edge'])

        # sample_choice: the edge count is deliberately re-read here instead
        # of reusing an earlier value. NOTE(review): if the route keeps a
        # reference to ``candidate_edges``, fix_chosen may have shrunk it —
        # confirm against DiffChoiceRoute.
        arch_param = nn.Parameter(torch.randn(len(candidate_edges)))
        plain_route.sample_choice(arch_param)

        # dump_chosen is expected to trip an assertion at this point.
        with pytest.raises(AssertionError):
            plain_route.dump_chosen()

    def test_forward_fixed(self):
        """Fix a subset of edges, then check ``forward_fixed``,
        ``num_choices`` and that a second ``fix_chosen`` call raises."""
        candidate_edges = nn.ModuleDict({
            'first_edge': nn.Conv2d(32, 32, 3, 1, 1),
            'second_edge': nn.Conv2d(32, 32, 5, 1, 2),
            'third_edge': nn.Conv2d(32, 32, 7, 1, 3),
            'fourth_edge': nn.MaxPool2d(3, 1, 1),
            'fifth_edge': nn.AvgPool2d(3, 1, 1),
        })
        # Record the edge count before fix_chosen has a chance to run.
        num_edges = len(candidate_edges)

        route = MODELS.build(
            dict(
                type='DiffChoiceRoute',
                edges=candidate_edges,
                with_arch_param=True,
            ))

        route.fix_chosen(chosen=['first_edge', 'second_edge', 'fifth_edge'])
        assert route.is_fixed is True

        # Inputs are still supplied for every original edge, not just the
        # three that were kept.
        x = [torch.randn(4, 32, 64, 64) for _ in range(num_edges)]
        output = route.forward_fixed(x)
        assert output is not None
        # Three edges were kept by fix_chosen above.
        assert route.num_choices == 3

        # Once fixed, the route must refuse another fix_chosen call.
        with pytest.raises(AttributeError):
            route.fix_chosen(chosen=['first_edge'])