CodeCamp Testing the Robustness of the MMRazor Channel Dependency Resolution Tool on MMPOSE ()

* support mmpose tracer test

* support mmpose tracer test 2

* support mmpose tracer test 2

* test trace on mmpose

* test trace on mmpose

* Delete run.sh

* fix lint

* restore models

* clean code

* note a bug for SimCCHead

Co-authored-by: liukai <your_email@abc.example>
pull/436/head
Xinxinxin Xu 2023-01-12 10:46:20 +08:00 committed by GitHub
parent 67da3ad240
commit 25796d5437
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 224 additions and 13 deletions

1
.gitignore vendored
View File

@ -113,6 +113,7 @@ venv.bak/
*.log.json
/work_dirs
/mmrazor/.mim
*.out
# Pytorch
*.pth

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
import torch.nn as nn
from mmengine.model import BaseModel
@ -6,8 +8,8 @@ from mmrazor.registry import TASK_UTILS
from mmrazor.utils import get_placeholder
from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput,
DefaultMMDemoInput, DefaultMMDetDemoInput,
DefaultMMRotateDemoInput, DefaultMMSegDemoInput,
DefaultMMYoloDemoInput)
DefaultMMPoseDemoInput, DefaultMMRotateDemoInput,
DefaultMMSegDemoInput, DefaultMMYoloDemoInput)
try:
from mmdet.models import BaseDetector
@ -24,19 +26,28 @@ try:
except Exception:
BaseSegmentor = get_placeholder('mmseg')
default_demo_input_class = {
BaseDetector: DefaultMMDetDemoInput,
ImageClassifier: DefaultMMClsDemoInput,
BaseSegmentor: DefaultMMSegDemoInput,
BaseModel: DefaultMMDemoInput,
nn.Module: BaseDemoInput
}
# New
try:
from mmpose.models import TopdownPoseEstimator
except Exception:
TopdownPoseEstimator = get_placeholder('mmpose')
default_demo_input_class = OrderedDict([
(BaseDetector, DefaultMMDetDemoInput),
(ImageClassifier, DefaultMMClsDemoInput),
(BaseSegmentor, DefaultMMSegDemoInput),
(TopdownPoseEstimator, DefaultMMPoseDemoInput),
(BaseModel, DefaultMMDemoInput),
(nn.Module, BaseDemoInput),
])
default_demo_input_class_for_scope = {
'mmcls': DefaultMMClsDemoInput,
'mmdet': DefaultMMDetDemoInput,
'mmseg': DefaultMMSegDemoInput,
'mmrotate': DefaultMMRotateDemoInput,
'mmyolo': DefaultMMYoloDemoInput,
'mmpose': DefaultMMPoseDemoInput,
'torchvision': BaseDemoInput,
}

View File

@ -122,3 +122,17 @@ class DefaultMMYoloDemoInput(DefaultMMDetDemoInput):
"""Default demo input generator for mmyolo models."""
default_shape = (1, 3, 125, 320)
@TASK_UTILS.register_module()
class DefaultMMPoseDemoInput(DefaultMMDemoInput):
"""Default demo input generator for mmpose models."""
def _get_mm_data(self, model, input_shape, training=False):
from mmpose.models import TopdownPoseEstimator
from .mmpose_demo_input import demo_mmpose_inputs
assert isinstance(model, TopdownPoseEstimator)
data = demo_mmpose_inputs(model, input_shape)
return data

View File

@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Include functions to generate mmpose demo inputs.
Modified from mmpose.
"""
import torch
from mmpose.models.heads import (CPMHead, DSNTHead, HeatmapHead,
IntegralRegressionHead, MSPNHead,
RegressionHead, RLEHead, SimCCHead,
ViPNASHead)
from mmpose.testing._utils import get_packed_inputs
from mmrazor.utils import get_placeholder
try:
from mmpose.models import PoseDataPreProcessor
from mmpose.structures import PoseDataSample
except ImportError:
PoseDataPreProcessor = get_placeholder('mmpose')
PoseDataSample = get_placeholder('mmpose')
def demo_mmpose_inputs(model, for_training=False, batch_size=1):
input_shape = (
1,
3,
) + model.head.decoder.input_size
imgs = torch.randn(*input_shape)
batch_data_samples = []
if isinstance(model.head, HeatmapHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.out_channels,
heatmap_size=model.head.decoder.heatmap_size[::-1])
]
elif isinstance(model.head, MSPNHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size=batch_size,
num_instances=1,
num_keypoints=model.head.out_channels,
heatmap_size=model.head.decoder.heatmap_size,
with_heatmap=True,
with_reg_label=False,
num_levels=model.head.num_stages * model.head.num_units)
]
elif isinstance(model.head, CPMHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size=batch_size,
num_instances=1,
num_keypoints=model.head.out_channels,
heatmap_size=model.head.decoder.heatmap_size[::-1],
with_heatmap=True,
with_reg_label=False)
]
elif isinstance(model.head, SimCCHead):
# bug
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.out_channels,
simcc_split_ratio=model.head.decoder.simcc_split_ratio,
input_size=model.head.decoder.input_size,
with_simcc_label=True)
]
elif isinstance(model.head, ViPNASHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.out_channels,
)
]
elif isinstance(model.head, DSNTHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)
]
elif isinstance(model.head, IntegralRegressionHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)
]
elif isinstance(model.head, RegressionHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)
]
elif isinstance(model.head, RLEHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)
]
else:
raise AssertionError('Head Type is Not Predefined')
mm_inputs = {
'inputs': torch.FloatTensor(imgs),
'data_samples': batch_data_samples
}
# check data preprocessor
if not hasattr(model,
'data_preprocessor') or model.data_preprocessor is None:
model.data_preprocessor = PoseDataPreProcessor()
mm_inputs = model.data_preprocessor(mm_inputs, for_training)
return mm_inputs

View File

@ -647,8 +647,27 @@ class MMSegModelLibrary(MMModelLibrary):
return config
# tools
class MMPoseModelLibrary(MMModelLibrary):
default_includes: List = [
'hand',
'face',
'wholebody',
'body',
'animal',
]
base_config_path = '/'
repo = 'mmpose'
def __init__(self, include=default_includes, exclude=[]) -> None:
super().__init__(include, exclude=exclude)
@classmethod
def _config_process(cls, config: Dict):
config['_scope_'] = 'mmpose'
return config
# tools
def revert_sync_batchnorm(module):
# this is very similar to the function that it is trying to revert:

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .model_library import (MMClsModelLibrary, MMDetModelLibrary,
DefaultModelLibrary, TorchModelLibrary,
MMSegModelLibrary)
MMPoseModelLibrary, MMSegModelLibrary)
class PassedModelManager:
@ -32,6 +32,7 @@ class FxPassedModelManager(PassedModelManager):
_mmcls_library = None
_mmseg_library = None
_mmdet_library = None
_mmpose_library = None
def libraries(self, full=False):
if full:
@ -41,6 +42,7 @@ class FxPassedModelManager(PassedModelManager):
self.__class__.mmcls_library(),
self.__class__.mmseg_library(),
self.__class__.mmdet_library(),
self.__class__.mmpose_library(),
]
else:
return [self.__class__.default_library()]
@ -304,6 +306,21 @@ class FxPassedModelManager(PassedModelManager):
cls._mmseg_library = MMSegModelLibrary(include=include)
return cls._mmseg_library
@classmethod
def mmpose_library(cls):
mmpose_include = [
'hand',
'face',
'wholebody',
'body',
'animal',
]
if cls._mmpose_library is None:
cls._mmpose_library = MMPoseModelLibrary(include=mmpose_include)
return cls._mmpose_library
# for backward tracer
@ -314,6 +331,8 @@ class BackwardPassedModelManager(PassedModelManager):
_mmcls_library = None
_mmseg_library = None
_mmdet_library = None
_mmpose_library = None
def libraries(self, full=False):
if full:
@ -323,6 +342,7 @@ class BackwardPassedModelManager(PassedModelManager):
self.__class__.mmcls_library(),
self.__class__.mmseg_library(),
self.__class__.mmdet_library(),
self.__class__.mmpose_library(),
]
else:
return [self.__class__.default_library()]
@ -481,6 +501,20 @@ class BackwardPassedModelManager(PassedModelManager):
cls._mmseg_library = MMSegModelLibrary(include=include)
return cls._mmseg_library
@classmethod
def mmpose_library(cls):
mmpose_include = [
'hand',
'face',
'wholebody',
'body',
'animal',
]
if cls._mmpose_library is None:
cls._mmpose_library = MMPoseModelLibrary(include=mmpose_include)
return cls._mmpose_library
fx_passed_library = FxPassedModelManager()
backward_passed_library = BackwardPassedModelManager()

View File

@ -6,8 +6,8 @@ import torch
from .data.model_library import (DefaultModelLibrary, MMClsModelLibrary,
MMDetModelLibrary, MMModelLibrary,
MMSegModelLibrary, ModelGenerator,
TorchModelLibrary)
MMPoseModelLibrary, MMSegModelLibrary,
ModelGenerator, TorchModelLibrary)
from .data.models import SingleLineModel
from .data.tracer_passed_models import (BackwardPassedModelManager,
FxPassedModelManager)
@ -45,6 +45,16 @@ class TestModelLibrary(unittest.TestCase):
if not TEST_DATA:
self.skipTest('not test data to save time.')
library = MMSegModelLibrary()
print(library.short_names())
self.assertTrue(library.is_default_includes_cover_all_models())
# New
def test_mmpose(self):
if not TEST_DATA:
self.skipTest('not test data to save time.')
library = MMPoseModelLibrary()
print(library.short_names())
self.assertTrue(library.is_default_includes_cover_all_models())
def test_get_model_by_config(self):