mmrazor/tests/test_registry/test_registry.py

127 lines
4.0 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import unittest
2022-07-14 22:38:35 +08:00
from typing import Dict, Optional, Union
from unittest import TestCase
import torch.nn as nn
from mmengine import fileio
from mmengine.config import Config
from mmengine.model import BaseModel
from mmrazor.models import * # noqa: F403, F401
from mmrazor.models.algorithms.base import BaseAlgorithm
2022-07-06 01:52:35 +08:00
from mmrazor.models.mutables import OneShotMutableOP
from mmrazor.registry import MODELS
from mmrazor.structures import load_fix_subnet
from mmrazor.utils import ValidFixMutable
@MODELS.register_module()
class MockModel(BaseModel):
    """Toy model made of two ``OneShotMutableOP`` stages.

    Each stage wraps three candidate 1x1 ``nn.Conv2d`` ops keyed
    ``'conv1'``, ``'conv2'`` and ``'conv3'``. The first stage maps 3 -> 8
    channels and the second maps 8 -> 16 channels.
    """

    def __init__(self):
        super().__init__()
        candidate_names = ('conv1', 'conv2', 'conv3')
        # Candidates are created stage by stage, in key order.
        stage1_choices = nn.ModuleDict(
            {name: nn.Conv2d(3, 8, 1) for name in candidate_names})
        stage2_choices = nn.ModuleDict(
            {name: nn.Conv2d(8, 16, 1) for name in candidate_names})
        self.mutable1 = OneShotMutableOP(stage1_choices)
        self.mutable2 = OneShotMutableOP(stage2_choices)

    def forward(self, x):
        """Feed ``x`` through ``mutable1`` and then ``mutable2``."""
        return self.mutable2(self.mutable1(x))
@MODELS.register_module()
class MockAlgorithm(BaseAlgorithm):
    """Minimal algorithm that can optionally fix its architecture to a subnet.

    Args:
        architecture (BaseModel | dict): Model instance or config, passed
            straight to ``BaseAlgorithm``.
        fix_subnet (ValidFixMutable, optional): Subnet description. When
            provided, ``load_fix_subnet`` is applied with the
            ``'architecture.'`` prefix and ``is_supernet`` is set to False.
            When None, the model is left as is and ``is_supernet`` is True.
    """

    def __init__(self,
                 architecture: Union[BaseModel, Dict],
                 fix_subnet: Optional[ValidFixMutable] = None):
        super().__init__(architecture)
        keep_supernet = fix_subnet is None
        if not keep_supernet:
            # Following fix_subnet, delete the unchosen part of the supernet.
            load_fix_subnet(self, fix_subnet, prefix='architecture.')
        # Assigned only after loading succeeds, matching the original order.
        self.is_supernet = keep_supernet
class TestRegistry(TestCase):
    """Tests for building razor models through the ``MODELS`` registry."""

    def setUp(self) -> None:
        # Config for the cfg_path build test, which is currently disabled
        # (see the TODO in ``test_build_razor_from_cfg``).
        self.arch_cfg_path = dict(
            cfg_path='mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
            pretrained=False)
        return super().setUp()

    def _build_pruned_resnet50(self, mode: str):
        """Build ResNet-50 through ``mmrazor.sub_model`` with a fixed subnet.

        Both pruning tests share this config. They differ only in how
        ``mmrazor.sub_model`` interprets ``fix_subnet``.

        Args:
            mode (str): Value forwarded as the ``mode`` of
                ``mmrazor.sub_model`` (``'mutator'`` or ``'mutable'`` here).

        Returns:
            The model built by ``MODELS.build``.
        """
        fix_subnet = fileio.load('tests/data/test_registry/subnet.json')
        init_cfg = dict(
            type='Pretrained',
            checkpoint='tests/data/test_registry/subnet_weight.pth')
        model_cfg = dict(
            # use mmrazor's build_func
            type='mmrazor.sub_model',
            cfg=dict(
                cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py',
                pretrained=False),
            fix_subnet=fix_subnet,
            mode=mode,
            init_cfg=init_cfg)
        return MODELS.build(model_cfg)

    def test_build_razor_from_cfg(self):
        # test cfg_path
        # TODO: relies on mmengine:HAOCHENYE/config_new_feature
        # model = MODELS.build(self.arch_cfg_path)
        # self.assertIsNotNone(model)

        # test fix subnet
        cfg = Config.fromfile(
            'tests/data/test_registry/registry_subnet_config.py')
        model = MODELS.build(cfg.model)
        # The original discarded this result; at least check a model came back.
        self.assertIsNotNone(model)

        # test return architecture
        cfg = Config.fromfile(
            'tests/data/test_registry/registry_architecture_config.py')
        model = MODELS.build(cfg.model)
        self.assertIsInstance(model, BaseModel)

    def test_build_subnet_prune_from_cfg_by_mutator(self):
        model = self._build_pruned_resnet50(mode='mutator')
        self.assertIsInstance(model, BaseModel)
        # Make sure the model is pruned. Use a unittest assertion rather than
        # a bare ``assert``, which is stripped under ``python -O``.
        self.assertEqual(model.backbone.layer1[0].conv1.weight.size(0), 41)

    def test_build_subnet_prune_from_cfg_by_mutable(self):
        model = self._build_pruned_resnet50(mode='mutable')
        self.assertIsInstance(model, BaseModel)
# Allow running this test module directly with the unittest runner.
if __name__ == '__main__':
    unittest.main()