diff --git a/.gitignore b/.gitignore index e1496625..32163484 100644 --- a/.gitignore +++ b/.gitignore @@ -113,6 +113,7 @@ venv.bak/ *.log.json /work_dirs /mmrazor/.mim +*.out # Pytorch *.pth diff --git a/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py index e28588ec..75a1db29 100644 --- a/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict + import torch.nn as nn from mmengine.model import BaseModel @@ -6,8 +8,8 @@ from mmrazor.registry import TASK_UTILS from mmrazor.utils import get_placeholder from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput, DefaultMMDemoInput, DefaultMMDetDemoInput, - DefaultMMRotateDemoInput, DefaultMMSegDemoInput, - DefaultMMYoloDemoInput) + DefaultMMPoseDemoInput, DefaultMMRotateDemoInput, + DefaultMMSegDemoInput, DefaultMMYoloDemoInput) try: from mmdet.models import BaseDetector @@ -24,19 +26,28 @@ try: except Exception: BaseSegmentor = get_placeholder('mmseg') -default_demo_input_class = { - BaseDetector: DefaultMMDetDemoInput, - ImageClassifier: DefaultMMClsDemoInput, - BaseSegmentor: DefaultMMSegDemoInput, - BaseModel: DefaultMMDemoInput, - nn.Module: BaseDemoInput -} +# New +try: + from mmpose.models import TopdownPoseEstimator +except Exception: + TopdownPoseEstimator = get_placeholder('mmpose') + +default_demo_input_class = OrderedDict([ + (BaseDetector, DefaultMMDetDemoInput), + (ImageClassifier, DefaultMMClsDemoInput), + (BaseSegmentor, DefaultMMSegDemoInput), + (TopdownPoseEstimator, DefaultMMPoseDemoInput), + (BaseModel, DefaultMMDemoInput), + (nn.Module, BaseDemoInput), +]) + default_demo_input_class_for_scope = { 'mmcls': DefaultMMClsDemoInput, 'mmdet': DefaultMMDetDemoInput, 'mmseg': DefaultMMSegDemoInput, 'mmrotate': DefaultMMRotateDemoInput, 'mmyolo': DefaultMMYoloDemoInput, + 'mmpose': DefaultMMPoseDemoInput, 'torchvision': BaseDemoInput, } diff --git a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py index d8fa8bcf..8664f3a2 100644 --- a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py @@ -122,3 +122,17 @@ class DefaultMMYoloDemoInput(DefaultMMDetDemoInput): """Default demo input generator for mmyolo models.""" default_shape = (1, 3, 125, 320) + + +@TASK_UTILS.register_module() +class DefaultMMPoseDemoInput(DefaultMMDemoInput): + """Default demo input generator for mmpose models.""" + + def _get_mm_data(self, model, input_shape, training=False): + from mmpose.models import TopdownPoseEstimator + + from .mmpose_demo_input import demo_mmpose_inputs + assert isinstance(model, TopdownPoseEstimator) + + data = demo_mmpose_inputs(model, input_shape) + return data diff --git a/mmrazor/models/task_modules/demo_inputs/mmpose_demo_input.py b/mmrazor/models/task_modules/demo_inputs/mmpose_demo_input.py new file mode 100644 index 00000000..98ebb961 --- /dev/null +++ b/mmrazor/models/task_modules/demo_inputs/mmpose_demo_input.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Include functions to generate mmpose demo inputs. + +Modified from mmpose. +""" + +import torch +from mmpose.models.heads import (CPMHead, DSNTHead, HeatmapHead, + IntegralRegressionHead, MSPNHead, + RegressionHead, RLEHead, SimCCHead, + ViPNASHead) +from mmpose.testing._utils import get_packed_inputs + +from mmrazor.utils import get_placeholder + +try: + from mmpose.models import PoseDataPreProcessor + from mmpose.structures import PoseDataSample +except ImportError: + PoseDataPreProcessor = get_placeholder('mmpose') + PoseDataSample = get_placeholder('mmpose') + + +def demo_mmpose_inputs(model, for_training=False, batch_size=1): + + input_shape = ( + 1, + 3, + ) + model.head.decoder.input_size + imgs = torch.randn(*input_shape) + + batch_data_samples = [] + + if isinstance(model.head, HeatmapHead): + batch_data_samples = [ + inputs['data_sample'] for inputs in get_packed_inputs( + batch_size, + num_keypoints=model.head.out_channels, + heatmap_size=model.head.decoder.heatmap_size[::-1]) + ] + elif isinstance(model.head, MSPNHead): + batch_data_samples = [ + inputs['data_sample'] for inputs in get_packed_inputs( + batch_size=batch_size, + num_instances=1, + num_keypoints=model.head.out_channels, + heatmap_size=model.head.decoder.heatmap_size, + with_heatmap=True, + with_reg_label=False, + num_levels=model.head.num_stages * model.head.num_units) + ] + elif isinstance(model.head, CPMHead): + batch_data_samples = [ + inputs['data_sample'] for inputs in get_packed_inputs( + batch_size=batch_size, + num_instances=1, + num_keypoints=model.head.out_channels, + heatmap_size=model.head.decoder.heatmap_size[::-1], + with_heatmap=True, + with_reg_label=False) + ] + elif isinstance(model.head, SimCCHead): + # bug + batch_data_samples = [ + inputs['data_sample'] for inputs in get_packed_inputs( + batch_size, + num_keypoints=model.head.out_channels, + simcc_split_ratio=model.head.decoder.simcc_split_ratio, + input_size=model.head.decoder.input_size, + with_simcc_label=True) + ] + elif isinstance(model.head, ViPNASHead): + batch_data_samples = [ + inputs['data_sample'] for inputs in get_packed_inputs( + batch_size, + num_keypoints=model.head.out_channels, + ) + ] + elif isinstance(model.head, DSNTHead): + batch_data_samples = [ + inputs['data_sample'] for inputs in get_packed_inputs( + batch_size, + num_keypoints=model.head.num_joints, + with_reg_label=True) + ] + elif isinstance(model.head, IntegralRegressionHead): + batch_data_samples = [ + inputs['data_sample'] for inputs in get_packed_inputs( + batch_size, + num_keypoints=model.head.num_joints, + with_reg_label=True) + ] + elif isinstance(model.head, RegressionHead): + batch_data_samples = [ + inputs['data_sample'] for inputs in get_packed_inputs( + batch_size, + num_keypoints=model.head.num_joints, + with_reg_label=True) + ] + elif isinstance(model.head, RLEHead): + batch_data_samples = [ + inputs['data_sample'] for inputs in get_packed_inputs( + batch_size, + num_keypoints=model.head.num_joints, + with_reg_label=True) + ] + else: + raise AssertionError('Head Type is Not Predefined') + + mm_inputs = { + 'inputs': torch.FloatTensor(imgs), + 'data_samples': batch_data_samples + } + + # check data preprocessor + if not hasattr(model, + 'data_preprocessor') or model.data_preprocessor is None: + model.data_preprocessor = PoseDataPreProcessor() + + mm_inputs = model.data_preprocessor(mm_inputs, for_training) + + return mm_inputs diff --git a/tests/data/model_library.py b/tests/data/model_library.py index d2783a63..d917dcc3 100644 --- a/tests/data/model_library.py +++ b/tests/data/model_library.py @@ -647,8 +647,27 @@ class MMSegModelLibrary(MMModelLibrary): return config -# tools +class MMPoseModelLibrary(MMModelLibrary): + default_includes: List = [ + 'hand', + 'face', + 'wholebody', + 'body', + 'animal', + ] + base_config_path = '/' + repo = 'mmpose' + def __init__(self, include=default_includes, exclude=[]) -> None: + super().__init__(include, exclude=exclude) + + @classmethod + def _config_process(cls, config: Dict): + config['_scope_'] = 'mmpose' + return config + + +# tools def revert_sync_batchnorm(module): # this is very similar to the function that it is trying to revert: diff --git a/tests/data/tracer_passed_models.py b/tests/data/tracer_passed_models.py index 7e316357..ade28214 100644 --- a/tests/data/tracer_passed_models.py +++ b/tests/data/tracer_passed_models.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .model_library import (MMClsModelLibrary, MMDetModelLibrary, DefaultModelLibrary, TorchModelLibrary, - MMSegModelLibrary) + MMPoseModelLibrary, MMSegModelLibrary) class PassedModelManager: @@ -32,6 +32,7 @@ class FxPassedModelManager(PassedModelManager): _mmcls_library = None _mmseg_library = None _mmdet_library = None + _mmpose_library = None def libraries(self, full=False): if full: @@ -41,6 +42,7 @@ class FxPassedModelManager(PassedModelManager): self.__class__.mmcls_library(), self.__class__.mmseg_library(), self.__class__.mmdet_library(), + self.__class__.mmpose_library(), ] else: return [self.__class__.default_library()] @@ -304,6 +306,21 @@ class FxPassedModelManager(PassedModelManager): cls._mmseg_library = MMSegModelLibrary(include=include) return cls._mmseg_library + + @classmethod + def mmpose_library(cls): + mmpose_include = [ + 'hand', + 'face', + 'wholebody', + 'body', + 'animal', + ] + if cls._mmpose_library is None: + cls._mmpose_library = MMPoseModelLibrary(include=mmpose_include) + + return cls._mmpose_library + # for backward tracer @@ -314,6 +331,8 @@ class BackwardPassedModelManager(PassedModelManager): _mmcls_library = None _mmseg_library = None _mmdet_library = None + _mmpose_library = None + def libraries(self, full=False): if full: @@ -323,6 +342,7 @@ class BackwardPassedModelManager(PassedModelManager): self.__class__.mmcls_library(), self.__class__.mmseg_library(), self.__class__.mmdet_library(), + self.__class__.mmpose_library(), ] else: return [self.__class__.default_library()] @@ -481,6 +501,20 @@ class BackwardPassedModelManager(PassedModelManager): cls._mmseg_library = MMSegModelLibrary(include=include) return cls._mmseg_library + @classmethod + def mmpose_library(cls): + mmpose_include = [ + 'hand', + 'face', + 'wholebody', + 'body', + 'animal', + ] + + if cls._mmpose_library is None: + cls._mmpose_library = MMPoseModelLibrary(include=mmpose_include) + return cls._mmpose_library + fx_passed_library = FxPassedModelManager() backward_passed_library = BackwardPassedModelManager() diff --git a/tests/test_data.py b/tests/test_data.py index b0dc0770..df3e07f6 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -6,8 +6,8 @@ import torch from .data.model_library import (DefaultModelLibrary, MMClsModelLibrary, MMDetModelLibrary, MMModelLibrary, - MMSegModelLibrary, ModelGenerator, - TorchModelLibrary) + MMPoseModelLibrary, MMSegModelLibrary, + ModelGenerator, TorchModelLibrary) from .data.models import SingleLineModel from .data.tracer_passed_models import (BackwardPassedModelManager, FxPassedModelManager) @@ -45,6 +45,16 @@ class TestModelLibrary(unittest.TestCase): if not TEST_DATA: self.skipTest('not test data to save time.') library = MMSegModelLibrary() + print(library.short_names()) + + self.assertTrue(library.is_default_includes_cover_all_models()) + + # New + def test_mmpose(self): + if not TEST_DATA: + self.skipTest('not test data to save time.') + library = MMPoseModelLibrary() + print(library.short_names()) self.assertTrue(library.is_default_includes_cover_all_models()) def test_get_model_by_config(self):