CodeCamp #129 Testing the Robustness of the MMRazor Channel Dependency Resolution Tool on MMPOSE (#415)
* support mmpose tracer test * support mmpose tracer test 2 * support mmpose tracer test 2 * test trace on mmpose * test trace on mmpose * Delete run.sh * fix lint * restore models * clean code * note a bug for SimCCHead Co-authored-by: liukai <your_email@abc.example>pull/436/head
parent
67da3ad240
commit
25796d5437
mmrazor/models/task_modules/demo_inputs
|
@ -113,6 +113,7 @@ venv.bak/
|
|||
*.log.json
|
||||
/work_dirs
|
||||
/mmrazor/.mim
|
||||
*.out
|
||||
|
||||
# Pytorch
|
||||
*.pth
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch.nn as nn
|
||||
from mmengine.model import BaseModel
|
||||
|
||||
|
@ -6,8 +8,8 @@ from mmrazor.registry import TASK_UTILS
|
|||
from mmrazor.utils import get_placeholder
|
||||
from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput,
|
||||
DefaultMMDemoInput, DefaultMMDetDemoInput,
|
||||
DefaultMMRotateDemoInput, DefaultMMSegDemoInput,
|
||||
DefaultMMYoloDemoInput)
|
||||
DefaultMMPoseDemoInput, DefaultMMRotateDemoInput,
|
||||
DefaultMMSegDemoInput, DefaultMMYoloDemoInput)
|
||||
|
||||
try:
|
||||
from mmdet.models import BaseDetector
|
||||
|
@ -24,19 +26,28 @@ try:
|
|||
except Exception:
|
||||
BaseSegmentor = get_placeholder('mmseg')
|
||||
|
||||
default_demo_input_class = {
|
||||
BaseDetector: DefaultMMDetDemoInput,
|
||||
ImageClassifier: DefaultMMClsDemoInput,
|
||||
BaseSegmentor: DefaultMMSegDemoInput,
|
||||
BaseModel: DefaultMMDemoInput,
|
||||
nn.Module: BaseDemoInput
|
||||
}
|
||||
# New
|
||||
try:
|
||||
from mmpose.models import TopdownPoseEstimator
|
||||
except Exception:
|
||||
TopdownPoseEstimator = get_placeholder('mmpose')
|
||||
|
||||
default_demo_input_class = OrderedDict([
|
||||
(BaseDetector, DefaultMMDetDemoInput),
|
||||
(ImageClassifier, DefaultMMClsDemoInput),
|
||||
(BaseSegmentor, DefaultMMSegDemoInput),
|
||||
(TopdownPoseEstimator, DefaultMMPoseDemoInput),
|
||||
(BaseModel, DefaultMMDemoInput),
|
||||
(nn.Module, BaseDemoInput),
|
||||
])
|
||||
|
||||
default_demo_input_class_for_scope = {
|
||||
'mmcls': DefaultMMClsDemoInput,
|
||||
'mmdet': DefaultMMDetDemoInput,
|
||||
'mmseg': DefaultMMSegDemoInput,
|
||||
'mmrotate': DefaultMMRotateDemoInput,
|
||||
'mmyolo': DefaultMMYoloDemoInput,
|
||||
'mmpose': DefaultMMPoseDemoInput,
|
||||
'torchvision': BaseDemoInput,
|
||||
}
|
||||
|
||||
|
|
|
@ -122,3 +122,17 @@ class DefaultMMYoloDemoInput(DefaultMMDetDemoInput):
|
|||
"""Default demo input generator for mmyolo models."""
|
||||
|
||||
default_shape = (1, 3, 125, 320)
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class DefaultMMPoseDemoInput(DefaultMMDemoInput):
|
||||
"""Default demo input generator for mmpose models."""
|
||||
|
||||
def _get_mm_data(self, model, input_shape, training=False):
|
||||
from mmpose.models import TopdownPoseEstimator
|
||||
|
||||
from .mmpose_demo_input import demo_mmpose_inputs
|
||||
assert isinstance(model, TopdownPoseEstimator)
|
||||
|
||||
data = demo_mmpose_inputs(model, input_shape)
|
||||
return data
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Include functions to generate mmpose demo inputs.
|
||||
|
||||
Modified from mmpose.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from mmpose.models.heads import (CPMHead, DSNTHead, HeatmapHead,
|
||||
IntegralRegressionHead, MSPNHead,
|
||||
RegressionHead, RLEHead, SimCCHead,
|
||||
ViPNASHead)
|
||||
from mmpose.testing._utils import get_packed_inputs
|
||||
|
||||
from mmrazor.utils import get_placeholder
|
||||
|
||||
try:
|
||||
from mmpose.models import PoseDataPreProcessor
|
||||
from mmpose.structures import PoseDataSample
|
||||
except ImportError:
|
||||
PoseDataPreProcessor = get_placeholder('mmpose')
|
||||
PoseDataSample = get_placeholder('mmpose')
|
||||
|
||||
|
||||
def demo_mmpose_inputs(model, for_training=False, batch_size=1):
|
||||
|
||||
input_shape = (
|
||||
1,
|
||||
3,
|
||||
) + model.head.decoder.input_size
|
||||
imgs = torch.randn(*input_shape)
|
||||
|
||||
batch_data_samples = []
|
||||
|
||||
if isinstance(model.head, HeatmapHead):
|
||||
batch_data_samples = [
|
||||
inputs['data_sample'] for inputs in get_packed_inputs(
|
||||
batch_size,
|
||||
num_keypoints=model.head.out_channels,
|
||||
heatmap_size=model.head.decoder.heatmap_size[::-1])
|
||||
]
|
||||
elif isinstance(model.head, MSPNHead):
|
||||
batch_data_samples = [
|
||||
inputs['data_sample'] for inputs in get_packed_inputs(
|
||||
batch_size=batch_size,
|
||||
num_instances=1,
|
||||
num_keypoints=model.head.out_channels,
|
||||
heatmap_size=model.head.decoder.heatmap_size,
|
||||
with_heatmap=True,
|
||||
with_reg_label=False,
|
||||
num_levels=model.head.num_stages * model.head.num_units)
|
||||
]
|
||||
elif isinstance(model.head, CPMHead):
|
||||
batch_data_samples = [
|
||||
inputs['data_sample'] for inputs in get_packed_inputs(
|
||||
batch_size=batch_size,
|
||||
num_instances=1,
|
||||
num_keypoints=model.head.out_channels,
|
||||
heatmap_size=model.head.decoder.heatmap_size[::-1],
|
||||
with_heatmap=True,
|
||||
with_reg_label=False)
|
||||
]
|
||||
elif isinstance(model.head, SimCCHead):
|
||||
# bug
|
||||
batch_data_samples = [
|
||||
inputs['data_sample'] for inputs in get_packed_inputs(
|
||||
batch_size,
|
||||
num_keypoints=model.head.out_channels,
|
||||
simcc_split_ratio=model.head.decoder.simcc_split_ratio,
|
||||
input_size=model.head.decoder.input_size,
|
||||
with_simcc_label=True)
|
||||
]
|
||||
elif isinstance(model.head, ViPNASHead):
|
||||
batch_data_samples = [
|
||||
inputs['data_sample'] for inputs in get_packed_inputs(
|
||||
batch_size,
|
||||
num_keypoints=model.head.out_channels,
|
||||
)
|
||||
]
|
||||
elif isinstance(model.head, DSNTHead):
|
||||
batch_data_samples = [
|
||||
inputs['data_sample'] for inputs in get_packed_inputs(
|
||||
batch_size,
|
||||
num_keypoints=model.head.num_joints,
|
||||
with_reg_label=True)
|
||||
]
|
||||
elif isinstance(model.head, IntegralRegressionHead):
|
||||
batch_data_samples = [
|
||||
inputs['data_sample'] for inputs in get_packed_inputs(
|
||||
batch_size,
|
||||
num_keypoints=model.head.num_joints,
|
||||
with_reg_label=True)
|
||||
]
|
||||
elif isinstance(model.head, RegressionHead):
|
||||
batch_data_samples = [
|
||||
inputs['data_sample'] for inputs in get_packed_inputs(
|
||||
batch_size,
|
||||
num_keypoints=model.head.num_joints,
|
||||
with_reg_label=True)
|
||||
]
|
||||
elif isinstance(model.head, RLEHead):
|
||||
batch_data_samples = [
|
||||
inputs['data_sample'] for inputs in get_packed_inputs(
|
||||
batch_size,
|
||||
num_keypoints=model.head.num_joints,
|
||||
with_reg_label=True)
|
||||
]
|
||||
else:
|
||||
raise AssertionError('Head Type is Not Predefined')
|
||||
|
||||
mm_inputs = {
|
||||
'inputs': torch.FloatTensor(imgs),
|
||||
'data_samples': batch_data_samples
|
||||
}
|
||||
|
||||
# check data preprocessor
|
||||
if not hasattr(model,
|
||||
'data_preprocessor') or model.data_preprocessor is None:
|
||||
model.data_preprocessor = PoseDataPreProcessor()
|
||||
|
||||
mm_inputs = model.data_preprocessor(mm_inputs, for_training)
|
||||
|
||||
return mm_inputs
|
|
@ -647,8 +647,27 @@ class MMSegModelLibrary(MMModelLibrary):
|
|||
return config
|
||||
|
||||
|
||||
# tools
|
||||
class MMPoseModelLibrary(MMModelLibrary):
|
||||
default_includes: List = [
|
||||
'hand',
|
||||
'face',
|
||||
'wholebody',
|
||||
'body',
|
||||
'animal',
|
||||
]
|
||||
base_config_path = '/'
|
||||
repo = 'mmpose'
|
||||
|
||||
def __init__(self, include=default_includes, exclude=[]) -> None:
|
||||
super().__init__(include, exclude=exclude)
|
||||
|
||||
@classmethod
|
||||
def _config_process(cls, config: Dict):
|
||||
config['_scope_'] = 'mmpose'
|
||||
return config
|
||||
|
||||
|
||||
# tools
|
||||
|
||||
def revert_sync_batchnorm(module):
|
||||
# this is very similar to the function that it is trying to revert:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .model_library import (MMClsModelLibrary, MMDetModelLibrary,
|
||||
DefaultModelLibrary, TorchModelLibrary,
|
||||
MMSegModelLibrary)
|
||||
MMPoseModelLibrary, MMSegModelLibrary)
|
||||
|
||||
|
||||
class PassedModelManager:
|
||||
|
@ -32,6 +32,7 @@ class FxPassedModelManager(PassedModelManager):
|
|||
_mmcls_library = None
|
||||
_mmseg_library = None
|
||||
_mmdet_library = None
|
||||
_mmpose_library = None
|
||||
|
||||
def libraries(self, full=False):
|
||||
if full:
|
||||
|
@ -41,6 +42,7 @@ class FxPassedModelManager(PassedModelManager):
|
|||
self.__class__.mmcls_library(),
|
||||
self.__class__.mmseg_library(),
|
||||
self.__class__.mmdet_library(),
|
||||
self.__class__.mmpose_library(),
|
||||
]
|
||||
else:
|
||||
return [self.__class__.default_library()]
|
||||
|
@ -304,6 +306,21 @@ class FxPassedModelManager(PassedModelManager):
|
|||
cls._mmseg_library = MMSegModelLibrary(include=include)
|
||||
return cls._mmseg_library
|
||||
|
||||
|
||||
@classmethod
|
||||
def mmpose_library(cls):
|
||||
mmpose_include = [
|
||||
'hand',
|
||||
'face',
|
||||
'wholebody',
|
||||
'body',
|
||||
'animal',
|
||||
]
|
||||
if cls._mmpose_library is None:
|
||||
cls._mmpose_library = MMPoseModelLibrary(include=mmpose_include)
|
||||
|
||||
return cls._mmpose_library
|
||||
|
||||
# for backward tracer
|
||||
|
||||
|
||||
|
@ -314,6 +331,8 @@ class BackwardPassedModelManager(PassedModelManager):
|
|||
_mmcls_library = None
|
||||
_mmseg_library = None
|
||||
_mmdet_library = None
|
||||
_mmpose_library = None
|
||||
|
||||
|
||||
def libraries(self, full=False):
|
||||
if full:
|
||||
|
@ -323,6 +342,7 @@ class BackwardPassedModelManager(PassedModelManager):
|
|||
self.__class__.mmcls_library(),
|
||||
self.__class__.mmseg_library(),
|
||||
self.__class__.mmdet_library(),
|
||||
self.__class__.mmpose_library(),
|
||||
]
|
||||
else:
|
||||
return [self.__class__.default_library()]
|
||||
|
@ -481,6 +501,20 @@ class BackwardPassedModelManager(PassedModelManager):
|
|||
cls._mmseg_library = MMSegModelLibrary(include=include)
|
||||
return cls._mmseg_library
|
||||
|
||||
@classmethod
|
||||
def mmpose_library(cls):
|
||||
mmpose_include = [
|
||||
'hand',
|
||||
'face',
|
||||
'wholebody',
|
||||
'body',
|
||||
'animal',
|
||||
]
|
||||
|
||||
if cls._mmpose_library is None:
|
||||
cls._mmpose_library = MMPoseModelLibrary(include=mmpose_include)
|
||||
return cls._mmpose_library
|
||||
|
||||
|
||||
fx_passed_library = FxPassedModelManager()
|
||||
backward_passed_library = BackwardPassedModelManager()
|
||||
|
|
|
@ -6,8 +6,8 @@ import torch
|
|||
|
||||
from .data.model_library import (DefaultModelLibrary, MMClsModelLibrary,
|
||||
MMDetModelLibrary, MMModelLibrary,
|
||||
MMSegModelLibrary, ModelGenerator,
|
||||
TorchModelLibrary)
|
||||
MMPoseModelLibrary, MMSegModelLibrary,
|
||||
ModelGenerator, TorchModelLibrary)
|
||||
from .data.models import SingleLineModel
|
||||
from .data.tracer_passed_models import (BackwardPassedModelManager,
|
||||
FxPassedModelManager)
|
||||
|
@ -45,6 +45,16 @@ class TestModelLibrary(unittest.TestCase):
|
|||
if not TEST_DATA:
|
||||
self.skipTest('not test data to save time.')
|
||||
library = MMSegModelLibrary()
|
||||
print(library.short_names())
|
||||
|
||||
self.assertTrue(library.is_default_includes_cover_all_models())
|
||||
|
||||
# New
|
||||
def test_mmpose(self):
|
||||
if not TEST_DATA:
|
||||
self.skipTest('not test data to save time.')
|
||||
library = MMPoseModelLibrary()
|
||||
print(library.short_names())
|
||||
self.assertTrue(library.is_default_includes_cover_all_models())
|
||||
|
||||
def test_get_model_by_config(self):
|
||||
|
|
Loading…
Reference in New Issue