mirror of
https://github.com/open-mmlab/mmrazor.git
synced 2025-06-03 15:02:54 +08:00
* tmp * add new mmdet models * add docstring * pass test and pre-commit * rm razor tracer * update fx tracer, now it can automatically wrap methods and functions. * update tracer passed models * add warning for torch <1.12.0 fix bug for python3.6 update placeholder to support placeholder.XXX * fix bug * update docs * fix lint * fix parse_cfg in configs * restore mutablechannel * test ite prune algorithm when using dist * add get_model_from_path to MMModelLibrary * add mm models to DefaultModelLibrary * add uts * fix bug * fix bug * add uts * add uts * add uts * add uts * fix bug * restore ite_prune_algorithm * update doc * PruneTracer -> ChannelAnalyzer * prune_tracer -> channel_analyzer * add test for fxtracer * fix bug * fix bug * PruneTracer -> ChannelAnalyzer refine * CustomFxTracer -> MMFxTracer * fix bug when test with torch<1.12 * update print log * fix lint * rm unused code Co-authored-by: liukai <liukai@pjlab.org.cn> Co-authored-by: jacky <jacky@xx.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: liukai <your_email@abc.example>
84 lines
2.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from .data.model_library import (DefaultModelLibrary, MMClsModelLibrary,
|
|
MMDetModelLibrary, MMModelLibrary,
|
|
MMSegModelLibrary, ModelGenerator,
|
|
TorchModelLibrary)
|
|
from .data.models import SingleLineModel
|
|
from .data.tracer_passed_models import (BackwardPassedModelManager,
|
|
FxPassedModelManager)
|
|
|
|
# Tests guarded by this flag are skipped unless the TEST_DATA environment
# variable is set to exactly the string 'true'.
TEST_DATA = os.environ.get('TEST_DATA') == 'true'
|
|
|
|
|
|
class TestModelLibrary(unittest.TestCase):
    """Tests for the model libraries and the tracer-passed model managers.

    The per-library coverage tests are skipped unless the module-level
    ``TEST_DATA`` flag is true, to save time in ordinary test runs.
    """

    def _skip_without_test_data(self):
        """Skip the calling test unless the module-level ``TEST_DATA`` is true.

        Evaluated at run time (not as a class decorator) so that importing
        this module never depends on the flag.
        """
        if not TEST_DATA:
            self.skipTest('not test data to save time.')

    def test_mmcls(self):
        """Default includes of the mmcls library cover all of its models."""
        self._skip_without_test_data()
        # NOTE(review): 'cutmax' may be a typo for 'cutmix'; left unchanged
        # because it decides which configs are excluded — confirm against
        # the mmcls config names before editing.
        library = MMClsModelLibrary(exclude=['cutmax', 'cifar'])
        self.assertTrue(library.is_default_includes_cover_all_models())

    def test_defaul_library(self):
        """Default includes of the default library cover all of its models."""
        self._skip_without_test_data()
        library = DefaultModelLibrary()
        self.assertTrue(library.is_default_includes_cover_all_models())

    def test_torchlibrary(self):
        """Default includes of the torch library cover all of its models."""
        self._skip_without_test_data()
        library = TorchModelLibrary()
        self.assertTrue(library.is_default_includes_cover_all_models())

    def test_mmdet(self):
        """Default includes of the mmdet library cover all of its models."""
        self._skip_without_test_data()
        library = MMDetModelLibrary()
        self.assertTrue(library.is_default_includes_cover_all_models())

    def test_mmseg(self):
        """Default includes of the mmseg library cover all of its models."""
        self._skip_without_test_data()
        library = MMSegModelLibrary()
        self.assertTrue(library.is_default_includes_cover_all_models())

    def test_get_model_by_config(self):
        """A model class can be resolved from an ``mmcls::`` config path."""
        config = 'mmcls::resnet/resnet34_8xb32_in1k.py'
        Model = MMModelLibrary.get_model_from_path(config)
        # Construction succeeding is the whole check here.
        _ = Model()

    def test_passed_models(self):
        """Listing the Fx- and backward-passed models must not raise."""
        try:
            print(FxPassedModelManager().include_models())
            print(BackwardPassedModelManager().include_models())
        except Exception as e:
            # A bare ``self.fail()`` reports no reason; include the original
            # error so the failure message is actionable.
            self.fail(f'include_models() raised {type(e).__name__}: {e}')
|
|
|
|
|
|
class TestModels(unittest.TestCase):
    """Smoke tests for the models provided by the default model library."""

    def _test_a_model(self, Model):
        """Instantiate ``Model`` and check a forward pass returns ``[2, 1000]``.

        Args:
            Model: zero-argument callable returning a module. Assumed to be a
                classifier accepting ``(N, 3, 224, 224)`` input with 1000
                output classes — TODO confirm for every library entry.
        """
        model = Model()
        x = torch.rand(2, 3, 224, 224)
        # Only the output shape is asserted, so there is no need to build an
        # autograd graph; this saves memory/time when looping over many models.
        with torch.no_grad():
            y = model(x)
        self.assertSequenceEqual(y.shape, [2, 1000])

    def test_models(self):
        """Run the forward smoke test on every model in the default library."""
        library = DefaultModelLibrary()
        for Model in library.include_models():
            with self.subTest(model=Model):
                self._test_a_model(Model)

    def test_generator(self):
        """A ModelGenerator-built model toggles ``training`` via eval()/train()."""
        Model = ModelGenerator('model', SingleLineModel)
        model = Model()
        model.eval()
        self.assertFalse(model.training)
        model.train()
        self.assertTrue(model.training)
|