mmrazor/tests/test_data.py
LKJacky 1c03a07350
Enhance the Abilities of the Tracer for Pruning. (#371)
* tmp

* add new mmdet models

* add docstring

* pass test and pre-commit

* rm razor tracer

* update fx tracer, now it can automatically wrap methods and functions.

* update tracer passed models

* add warning for torch <1.12.0

fix bug for python3.6

update placeholder to support placeholder.XXX

* fix bug

* update docs

* fix lint

* fix parse_cfg in configs

* restore mutablechannel

* test ite prune algorithm when using dist

* add get_model_from_path to MMModelLibrary

* add mm models to DefaultModelLibrary

* add uts

* fix bug

* fix bug

* add uts

* add uts

* add uts

* add uts

* fix bug

* restore ite_prune_algorithm

* update doc

* PruneTracer -> ChannelAnalyzer

* prune_tracer -> channel_analyzer

* add test for fxtracer

* fix bug

* fix bug

* PruneTracer -> ChannelAnalyzer

refine

* CustomFxTracer -> MMFxTracer

* fix bug when test with torch<1.12

* update print log

* fix lint

* rm useless code

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: liukai <your_email@abc.example>
2022-12-08 15:59:27 +08:00

84 lines
2.8 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
import torch
from .data.model_library import (DefaultModelLibrary, MMClsModelLibrary,
MMDetModelLibrary, MMModelLibrary,
MMSegModelLibrary, ModelGenerator,
TorchModelLibrary)
from .data.models import SingleLineModel
from .data.tracer_passed_models import (BackwardPassedModelManager,
FxPassedModelManager)
# Opt-in switch read once at import time: it is True only when the
# TEST_DATA environment variable is exactly the string 'true'.
TEST_DATA = os.environ.get('TEST_DATA') == 'true'
class TestModelLibrary(unittest.TestCase):
    """Tests for the model libraries used by the test suite.

    The library-coverage tests are skipped unless the module-level
    ``TEST_DATA`` flag is true.
    """

    def _check_library_coverage(self, library_factory, **kwargs):
        """Skip unless ``TEST_DATA`` is set, then check library coverage.

        Args:
            library_factory: Callable that builds the library under test.
                It is only invoked after the skip check, so a skipped test
                never constructs the library.
            **kwargs: Keyword arguments forwarded to ``library_factory``.
        """
        if not TEST_DATA:
            self.skipTest('not test data to save time.')
        library = library_factory(**kwargs)
        self.assertTrue(library.is_default_includes_cover_all_models())

    def test_mmcls(self):
        """Coverage check for ``MMClsModelLibrary``."""
        # NOTE(review): 'cutmax' may have been meant as 'cutmix' -- confirm
        # against the mmcls config names before changing it.
        self._check_library_coverage(
            MMClsModelLibrary, exclude=['cutmax', 'cifar'])

    def test_defaul_library(self):
        """Coverage check for ``DefaultModelLibrary``."""
        self._check_library_coverage(DefaultModelLibrary)

    def test_torchlibrary(self):
        """Coverage check for ``TorchModelLibrary``."""
        self._check_library_coverage(TorchModelLibrary)

    def test_mmdet(self):
        """Coverage check for ``MMDetModelLibrary``."""
        self._check_library_coverage(MMDetModelLibrary)

    def test_mmseg(self):
        """Coverage check for ``MMSegModelLibrary``."""
        self._check_library_coverage(MMSegModelLibrary)

    def test_get_model_by_config(self):
        """A model obtained from a config path must be instantiable."""
        config = 'mmcls::resnet/resnet34_8xb32_in1k.py'
        Model = MMModelLibrary.get_model_from_path(config)
        _ = Model()

    def test_passed_models(self):
        """Listing the tracer-passed models must not raise."""
        managers = (FxPassedModelManager, BackwardPassedModelManager)
        for manager_cls in managers:
            try:
                print(manager_cls().include_models())
            except Exception as exc:
                # Report which manager failed and why instead of an
                # empty failure message.
                self.fail(f'{manager_cls.__name__}().include_models() '
                          f'raised {exc!r}')
class TestModels(unittest.TestCase):
    """Forward-pass and model-generator checks for the test models."""

    def _test_a_model(self, Model):
        """Instantiate ``Model`` and check the shape of its output.

        Args:
            Model: Zero-argument callable returning a model; the model is
                called on a random tensor of shape ``(2, 3, 224, 224)``
                and its output shape must equal ``[2, 1000]``.
        """
        batch_size, num_outputs = 2, 1000
        net = Model()
        inputs = torch.rand(batch_size, 3, 224, 224)
        outputs = net(inputs)
        self.assertSequenceEqual(outputs.shape, [batch_size, num_outputs])

    def test_models(self):
        """Run the shape check on every model of the default library."""
        for model_cls in DefaultModelLibrary().include_models():
            # One sub-test per model so a single failure does not hide
            # the results for the remaining models.
            with self.subTest(model=model_cls):
                self._test_a_model(model_cls)

    def test_generator(self):
        """Check that ``eval()``/``train()`` flip ``training``."""
        net = ModelGenerator('model', SingleLineModel)()
        for switch_mode, expected in ((net.eval, False), (net.train, True)):
            switch_mode()
            self.assertEqual(net.training, expected)