145 lines
4.7 KiB
Python
145 lines
4.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import unittest
|
|
from typing import Dict, Optional, Union
|
|
from unittest import TestCase
|
|
|
|
import torch.nn as nn
|
|
from mmengine import fileio
|
|
from mmengine.config import Config
|
|
from mmengine.model import BaseModel
|
|
|
|
from mmrazor.models import * # noqa: F403, F401
|
|
from mmrazor.models.algorithms.base import BaseAlgorithm
|
|
from mmrazor.models.mutables import OneShotMutableOP
|
|
from mmrazor.registry import MODELS
|
|
from mmrazor.structures import load_fix_subnet
|
|
from mmrazor.utils import ValidFixMutable
|
|
|
|
|
|
@MODELS.register_module()
class MockModel(BaseModel):
    """Mock model built from two chained ``OneShotMutableOP`` modules.

    Each mutable wraps a ``ModuleDict`` of three candidate 1x1 convolutions
    keyed ``conv1``/``conv2``/``conv3``: 3 -> 8 channels for ``mutable1``
    and 8 -> 16 channels for ``mutable2``.
    """

    def __init__(self):
        super().__init__()
        candidate_names = ('conv1', 'conv2', 'conv3')

        # Candidates are created in the same order as the key tuple, first
        # for the 3 -> 8 stage and then for the 8 -> 16 stage.
        first_stage = nn.ModuleDict(
            {name: nn.Conv2d(3, 8, 1)
             for name in candidate_names})
        second_stage = nn.ModuleDict(
            {name: nn.Conv2d(8, 16, 1)
             for name in candidate_names})

        self.mutable1 = OneShotMutableOP(first_stage)
        self.mutable2 = OneShotMutableOP(second_stage)

    def forward(self, x):
        """Feed ``x`` through ``mutable1`` and then ``mutable2``."""
        return self.mutable2(self.mutable1(x))
|
|
|
|
|
|
@MODELS.register_module()
class MockAlgorithm(BaseAlgorithm):
    """Mock algorithm that can optionally be fixed to a subnet.

    Args:
        architecture (Union[BaseModel, Dict]): Forwarded unchanged to
            ``BaseAlgorithm.__init__``.
        fix_subnet (ValidFixMutable, optional): When provided, it is applied
            with ``load_fix_subnet`` using the ``'architecture.'`` prefix and
            ``is_supernet`` is set to ``False``. When ``None`` (default),
            ``is_supernet`` is set to ``True``.
    """

    def __init__(self,
                 architecture: Union[BaseModel, Dict],
                 fix_subnet: Optional[ValidFixMutable] = None):
        super().__init__(architecture)

        # Nothing to fix: the algorithm keeps acting as a supernet.
        if fix_subnet is None:
            self.is_supernet = True
            return

        # According to fix_subnet, delete the unchosen part of supernet.
        load_fix_subnet(self, fix_subnet, prefix='architecture.')
        self.is_supernet = False
|
|
|
|
|
|
class TestRegistry(TestCase):
    """Tests for building models through the ``MODELS`` registry.

    All fixture paths are relative, so the tests must be started from a
    working directory that contains ``tests/data/test_registry``.
    """

    def setUp(self) -> None:
        # Only referenced by the currently disabled cfg_path check in
        # ``test_build_razor_from_cfg``.
        self.arch_cfg_path = dict(
            cfg_path='mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
            pretrained=False)

        return super().setUp()

    def _build_sub_model(self, mode: str):
        """Build an ``mmrazor.sub_model`` for the subnet fixture.

        Args:
            mode (str): Value passed as the ``mode`` entry of the config
                (the tests use ``'mutator'`` and ``'mutable'``).

        Returns:
            The object returned by ``MODELS.build`` for that config.
        """
        mutator_cfg = fileio.load('tests/data/test_registry/subnet.json')
        init_cfg = dict(
            type='Pretrained',
            checkpoint='tests/data/test_registry/subnet_weight.pth')
        model_cfg = dict(
            # use mmrazor's build_func
            type='mmrazor.sub_model',
            cfg=dict(
                cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py',
                pretrained=False),
            fix_subnet=mutator_cfg,
            mode=mode,
            init_cfg=init_cfg)
        return MODELS.build(model_cfg)

    def test_build_razor_from_cfg(self):
        # test cfg_path
        # TODO rely on mmengine:HAOCHENYE/config_new_feature
        # model = MODELS.build(self.arch_cfg_path)
        # self.assertIsNotNone(model)

        # test fix subnet
        # Build-only check: it passes as long as building does not raise.
        cfg = Config.fromfile(
            'tests/data/test_registry/registry_subnet_config.py')
        MODELS.build(cfg.model)

        # test return architecture
        cfg = Config.fromfile(
            'tests/data/test_registry/registry_architecture_config.py')
        model = MODELS.build(cfg.model)
        self.assertIsInstance(model, BaseModel)

    def test_build_subnet_prune_from_cfg(self):
        # test fix subnet
        model = self._build_sub_model(mode='mutator')
        self.assertIsInstance(model, BaseModel)

    def test_build_subnet_prune_from_cfg_by_mutator(self):
        # test fix subnet
        model = self._build_sub_model(mode='mutator')
        self.assertIsInstance(model, BaseModel)
        # make sure the model is pruned
        # assertEqual (rather than a bare ``assert``) still runs under
        # ``python -O`` and reports both values when it fails.
        self.assertEqual(
            model.backbone.layer1[0].conv1.weight.size()[0], 41)

    def test_build_subnet_prune_from_cfg_by_mutable(self):
        # test fix subnet
        model = self._build_sub_model(mode='mutable')
        self.assertIsInstance(model, BaseModel)
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
|