diff --git a/configs/mmaction/video-recognition/video-recognition_2d_dynamic.py b/configs/mmaction/video-recognition/video-recognition_2d_dynamic.py new file mode 100644 index 000000000..548799a3f --- /dev/null +++ b/configs/mmaction/video-recognition/video-recognition_2d_dynamic.py @@ -0,0 +1,15 @@ +_base_ = ['./video-recognition_static.py'] + +onnx_config = dict( + dynamic_axes={ + 'input': { + 0: 'batch', + 1: 'num_crops * num_segs', + 3: 'height', + 4: 'width' + }, + 'output': { + 0: 'batch', + } + }, + input_shape=None) diff --git a/configs/mmaction/video-recognition/video-recognition_2d_tensorrt_static-224x224.py b/configs/mmaction/video-recognition/video-recognition_2d_tensorrt_static-224x224.py new file mode 100644 index 000000000..828c4c112 --- /dev/null +++ b/configs/mmaction/video-recognition/video-recognition_2d_tensorrt_static-224x224.py @@ -0,0 +1,14 @@ +_base_ = ['./video-recognition_static.py', '../../_base_/backends/tensorrt.py'] + +onnx_config = dict(input_shape=[224, 224]) + +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 250, 3, 224, 224], + opt_shape=[1, 250, 3, 224, 224], + max_shape=[1, 250, 3, 224, 224]))) + ]) diff --git a/configs/mmaction/video-recognition/video-recognition_3d_dynamic.py b/configs/mmaction/video-recognition/video-recognition_3d_dynamic.py new file mode 100644 index 000000000..78da11876 --- /dev/null +++ b/configs/mmaction/video-recognition/video-recognition_3d_dynamic.py @@ -0,0 +1,16 @@ +_base_ = ['./video-recognition_static.py'] + +onnx_config = dict( + dynamic_axes={ + 'input': { + 0: 'batch', + 1: 'num_crops * num_segs', + 3: 'time', + 4: 'height', + 5: 'width' + }, + 'output': { + 0: 'batch', + } + }, + input_shape=None) diff --git a/configs/mmaction/video-recognition/video-recognition_3d_tensorrt_static-256x256.py b/configs/mmaction/video-recognition/video-recognition_3d_tensorrt_static-256x256.py new file mode 100644 index 000000000..1e20414fd --- /dev/null +++ b/configs/mmaction/video-recognition/video-recognition_3d_tensorrt_static-256x256.py @@ -0,0 +1,14 @@ +_base_ = ['./video-recognition_static.py', '../../_base_/backends/tensorrt.py'] + +onnx_config = dict(input_shape=[256, 256]) + +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 30, 3, 32, 256, 256], + opt_shape=[1, 30, 3, 32, 256, 256], + max_shape=[1, 30, 3, 32, 256, 256]))) + ]) diff --git a/configs/mmaction/video-recognition/video-recognition_onnxruntime_static.py b/configs/mmaction/video-recognition/video-recognition_onnxruntime_static.py new file mode 100644 index 000000000..fec2822ed --- /dev/null +++ b/configs/mmaction/video-recognition/video-recognition_onnxruntime_static.py @@ -0,0 +1,5 @@ +_base_ = [ + './video-recognition_static.py', '../../_base_/backends/onnxruntime.py' +] + +onnx_config = dict(input_shape=None) diff --git a/configs/mmaction/video-recognition/video-recognition_static.py b/configs/mmaction/video-recognition/video-recognition_static.py new file mode 100644 index 000000000..c4c9d5928 --- /dev/null +++ b/configs/mmaction/video-recognition/video-recognition_static.py @@ -0,0 +1,3 @@ +_base_ = ['../../_base_/onnx_config.py'] + +codebase_config = dict(type='mmaction', task='VideoRecognition') diff --git a/mmdeploy/codebase/mmaction/__init__.py b/mmdeploy/codebase/mmaction/__init__.py new file mode 100644 index 000000000..1cdbf57bd --- /dev/null +++ 
b/mmdeploy/codebase/mmaction/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .deploy import * # noqa: F401,F403 +from .models import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmaction/deploy/__init__.py b/mmdeploy/codebase/mmaction/deploy/__init__.py new file mode 100644 index 000000000..5bd37d477 --- /dev/null +++ b/mmdeploy/codebase/mmaction/deploy/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .mmaction import MMACTION +from .video_recognition import VideoRecognition + +__all__ = ['MMACTION', 'VideoRecognition'] diff --git a/mmdeploy/codebase/mmaction/deploy/mmaction.py b/mmdeploy/codebase/mmaction/deploy/mmaction.py new file mode 100644 index 000000000..4f018901f --- /dev/null +++ b/mmdeploy/codebase/mmaction/deploy/mmaction.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.registry import Registry + +from mmdeploy.codebase.base import CODEBASE, MMCodebase +from mmdeploy.utils import Codebase + +MMACTION_TASK = Registry('mmaction_tasks') + + +@CODEBASE.register_module(Codebase.MMACTION.value) +class MMACTION(MMCodebase): + """MMAction codebase class.""" + + task_registry = MMACTION_TASK + + @classmethod + def register_all_modules(cls): + from mmaction.utils.setup_env import register_all_modules + register_all_modules(True) diff --git a/mmdeploy/codebase/mmaction/deploy/video_recognition.py b/mmdeploy/codebase/mmaction/deploy/video_recognition.py new file mode 100644 index 000000000..ca9628a0c --- /dev/null +++ b/mmdeploy/codebase/mmaction/deploy/video_recognition.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from operator import itemgetter +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import mmengine +import numpy as np +import torch +from mmengine.dataset import pseudo_collate +from mmengine.model import BaseDataPreprocessor + +from mmdeploy.codebase.base import BaseTask +from mmdeploy.utils import Task, get_root_logger +from mmdeploy.utils.config_utils import get_input_shape +from .mmaction import MMACTION_TASK + + +def process_model_config(model_cfg: mmengine.Config, + imgs: Union[Sequence[str], Sequence[np.ndarray]], + input_shape: Optional[Sequence[int]] = None): + """Process the model config. + + Args: + model_cfg (mmengine.Config): The model config. + imgs (Sequence[str] | Sequence[np.ndarray]): Input image(s), accepted + data type are List[str], List[np.ndarray]. + input_shape (list[int]): A list of two integer in (width, height) + format specifying input shape. Default: None. + + Returns: + mmengine.Config: the model config after processing. 
+ """ + logger = get_root_logger() + cfg = model_cfg.deepcopy() + test_pipeline_cfg = cfg.test_pipeline + if 'Init' not in test_pipeline_cfg[0]['type']: + test_pipeline_cfg = [dict(type='OpenCVInit')] + test_pipeline_cfg + else: + test_pipeline_cfg[0] = dict(type='OpenCVInit') + for i, trans in enumerate(test_pipeline_cfg): + if 'Decode' in trans['type']: + test_pipeline_cfg[i] = dict(type='OpenCVDecode') + cfg.test_pipeline = test_pipeline_cfg + + # check whether input_shape is valid + if input_shape is not None: + has_crop = False + crop_size = -1 + has_resize = False + scale = (-1, -1) + keep_ratio = True + for trans in cfg.test_pipeline: + if trans['type'] == 'Resize': + has_resize = True + keep_ratio = trans.get('keep_ratio', True) + scale = trans.scale + if trans['type'] in ['TenCrop', 'CenterCrop', 'ThreeCrop']: + has_crop = True + crop_size = trans.crop_size + + if has_crop and tuple(input_shape) != (crop_size, crop_size): + logger.error( + f'`input shape` should be equal to `crop_size`: {crop_size},' + f' but given: {input_shape}') + if has_resize and (not has_crop): + if keep_ratio: + logger.error('Resize should set `keep_ratio` to False' + ' when `input shape` is given.') + if tuple(input_shape) != scale: + logger.error( + f'`input shape` should be equal to `scale`: {scale},' + f' but given: {input_shape}') + return cfg + + +@MMACTION_TASK.register_module(Task.VIDEO_RECOGNITION.value) +class VideoRecognition(BaseTask): + """VideoRecognition task class. + + Args: + model_cfg (Config): Original PyTorch model config file. + deploy_cfg (Config): Deployment config file or loaded Config + object. + device (str): A string represents device type. + """ + + def __init__(self, model_cfg: mmengine.Config, deploy_cfg: mmengine.Config, + device: str): + super(VideoRecognition, self).__init__(model_cfg, deploy_cfg, device) + + def build_data_preprocessor(self): + model_cfg = self.model_cfg + preprocess_cfg = model_cfg.get('preprocess_cfg', None) + from mmengine.registry import MODELS + if preprocess_cfg is not None: + data_preprocessor = MODELS.build(preprocess_cfg) + else: + data_preprocessor = BaseDataPreprocessor() + + return data_preprocessor + + def build_backend_model(self, + model_files: Sequence[str] = None, + **kwargs) -> torch.nn.Module: + """Initialize backend model. + + Args: + model_files (Sequence[str]): Input model files. + + Returns: + nn.Module: An initialized backend model. + """ + from .video_recognition_model import build_video_recognition_model + model = build_video_recognition_model( + model_files, self.model_cfg, self.deploy_cfg, device=self.device) + model.to(self.device) + model.eval() + return model + + def create_input(self, + imgs: Union[str, np.ndarray], + input_shape: Sequence[int] = None, + data_preprocessor: Optional[BaseDataPreprocessor] = None)\ + -> Tuple[Dict, torch.Tensor]: + """Create input for video recognition. + + Args: + imgs (str | np.ndarray): Input image(s), accepted data type are + `str`, `np.ndarray`. + input_shape (list[int]): A list of two integer in (width, height) + format specifying input shape. Defaults to `None`. + + Returns: + tuple: (data, img), meta information for the input image and input. 
+ """ + if isinstance(imgs, (list, tuple)): + if not all(isinstance(img, str) for img in imgs): + raise AssertionError('imgs must be strings') + elif isinstance(imgs, str): + imgs = [imgs] + else: + raise AssertionError('imgs must be strings') + + from mmcv.transforms.wrappers import Compose + model_cfg = process_model_config(self.model_cfg, imgs, input_shape) + test_pipeline = Compose(model_cfg.test_pipeline) + + data = [] + for img in imgs: + data_ = dict(filename=img, label=-1, start_index=0, modality='RGB') + data_ = test_pipeline(data_) + data.append(data_) + + data = pseudo_collate(data) + if data_preprocessor is not None: + data = data_preprocessor(data, False) + return data, data['inputs'] + else: + return data, BaseTask.get_tensor_from_input(data) + + def visualize(self, + image: str, + result: list, + output_file: str, + window_name: str = '', + show_result: bool = False, + **kwargs): + """Visualize predictions of a model. + + Args: + model (nn.Module): Input model. + image (str): Input video to draw predictions on. + result (list): A list of predictions. + output_file (str): Output file to save drawn image. + window_name (str): The name of visualization window. Defaults to + an empty string. + show_result (bool): Whether to show result in windows, defaults + to `False`. + """ + logger = get_root_logger() + try: + import decord + from moviepy.editor import ImageSequenceClip + except Exception: + logger.warn('Please install moviepy and decord to ' + 'enable visualize for mmaction') + + save_dir, save_name = osp.split(output_file) + video = decord.VideoReader(image) + frames = [x.asnumpy()[..., ::-1] for x in video] + pred_scores = result.pred_scores.item.tolist() + score_tuples = tuple(zip(range(len(pred_scores)), pred_scores)) + score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True) + top1_item = score_sorted[0] + short_edge_length = min(frames[0].shape[:2]) + scale = short_edge_length // 224. + img_scale = min(max(scale, 0.3), 3.0) + text_cfg = { + 'positions': np.array([(img_scale * 5, ) * 2]).astype(np.int32), + 'font_sizes': int(img_scale * 7), + 'font_families': 'monospace', + 'colors': 'white', + 'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round') + } + + visualizer = self.get_visualizer(window_name, save_dir) + out_frames = [] + for i, frame in enumerate(frames): + visualizer.set_image(frame) + texts = [f'Frame {i} of total {len(frames)} frames'] + texts.append(f'top-1 label: {top1_item[0]}, score: {top1_item[0]}') + visualizer.draw_texts('\n'.join(texts), **text_cfg) + drawn_img = visualizer.get_image() + out_frames.append(drawn_img) + out_frames = [x[..., ::-1] for x in out_frames] + video_clips = ImageSequenceClip(out_frames, fps=30) + output_file = output_file[:output_file.rfind('.')] + '.mp4' + video_clips.write_videofile(output_file) + + @staticmethod + def get_partition_cfg(partition_type: str) -> Dict: + """Get a certain partition config. + + Args: + partition_type (str): A string specifying partition type. + + Returns: + dict: A dictionary of partition config. + """ + raise NotImplementedError('Not supported yet.') + + @staticmethod + def get_tensor_from_input(input_data: Dict[str, Any], + **kwargs) -> torch.Tensor: + """Get input tensor from input data. + + Args: + input_data (dict): Input data containing meta info and image + tensor. + Returns: + torch.Tensor: An image in `Tensor`. + """ + return input_data['inputs'] + + def get_preprocess(self) -> Dict: + """Get the preprocess information for SDK. 
+ + Return: + dict: Composed of the preprocess information. + """ + input_shape = get_input_shape(self.deploy_cfg) + model_cfg = process_model_config(self.model_cfg, [''], input_shape) + preprocess = model_cfg.test_pipeline + return preprocess + + def get_postprocess(self) -> Dict: + """Get the postprocess information for SDK. + + Return: + dict: Composed of the postprocess information. + """ + postprocess = self.model_cfg.model.cls_head + return postprocess + + def get_model_name(self) -> str: + """Get the model name. + + Return: + str: the name of the model. + """ + assert 'type' in self.model_cfg.model, 'model config contains no type' + name = self.model_cfg.model.type.lower() + return name diff --git a/mmdeploy/codebase/mmaction/deploy/video_recognition_model.py b/mmdeploy/codebase/mmaction/deploy/video_recognition_model.py new file mode 100644 index 000000000..75f349bf3 --- /dev/null +++ b/mmdeploy/codebase/mmaction/deploy/video_recognition_model.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Any, List, Optional, Sequence, Union + +import mmengine +import torch +from mmaction.utils import LabelList +from mmengine import Config +from mmengine.model import BaseDataPreprocessor +from mmengine.registry import Registry +from mmengine.structures import BaseDataElement, LabelData + +from mmdeploy.codebase.base import BaseBackendModel +from mmdeploy.utils import (Backend, get_backend, get_codebase_config, + get_root_logger, load_config) + +__BACKEND_MODEL = Registry('backend_video_recognizer') + + +@__BACKEND_MODEL.register_module('end2end') +class End2EndModel(BaseBackendModel): + """End to end model for inference of video recognition. + + Args: + backend (Backend): The backend enum, specifying backend type. + backend_files (Sequence[str]): Paths to all required backend files(e.g. + '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn). + device (str): A string represents device type. + deploy_cfg (str | mmengine.Config): Deployment config file or loaded + Config object. + model_cfg (str | mmengine.Config): Model config file or loaded Config + object. + """ + + def __init__(self, + backend: Backend, + backend_files: Sequence[str], + device: str, + deploy_cfg: Union[str, Config] = None, + model_cfg: Union[str, Config] = None, + **kwargs): + super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg) + model_cfg, deploy_cfg = load_config(model_cfg, deploy_cfg) + from mmaction.registry import MODELS + preprocessor_cfg = model_cfg.model.get('data_preprocessor', None) + if preprocessor_cfg is not None: + self.data_preprocessor = MODELS.build( + model_cfg.model.data_preprocessor) + else: + self.data_preprocessor = BaseDataPreprocessor() + self.deploy_cfg = deploy_cfg + self.model_cfg = model_cfg + self._init_wrapper( + backend=backend, + backend_files=backend_files, + device=device, + **kwargs) + self.device = device + + def _init_wrapper(self, backend: Backend, backend_files: Sequence[str], + device: str, **kwargs): + """Initialize backend wrapper. + + Args: + backend (Backend): The backend enum, specifying backend type. + backend_files (Sequence[str]): Paths to all required backend files + (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn). + device (str): A string specifying device type. 
+ """ + output_names = self.output_names + self.wrapper = BaseBackendModel._build_wrapper( + backend=backend, + backend_files=backend_files, + device=device, + input_names=[self.input_name], + output_names=output_names, + deploy_cfg=self.deploy_cfg, + **kwargs) + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + mode: str = 'predict') -> Any: + """Run forward inference. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[``ActionDataSample``], optional): The + annotation data of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to ``predict``. + + Returns: + list: A list contains predictions. + """ + assert mode == 'predict', \ + 'Backend model only support mode==predict,' f' but get {mode}' + + if inputs.device != torch.device(self.device): + get_root_logger().warning(f'expect input device {self.device}' + f' but get {inputs.device}.') + inputs = inputs.to(self.device) + cls_scores = self.wrapper({self.input_name: + inputs})[self.output_names[0]] + + predictions: LabelList = [] + for score in cls_scores: + label = LabelData(item=score) + predictions.append(label) + + for data_sample, pred_instances in zip(data_samples, predictions): + data_sample.pred_scores = pred_instances + + return data_samples + + +def build_video_recognition_model(model_files: Sequence[str], + model_cfg: Union[str, mmengine.Config], + deploy_cfg: Union[str, mmengine.Config], + device: str, **kwargs): + """Build video recognition model for different backends. + + Args: + model_files (Sequence[str]): Input model file(s). + model_cfg (str | mmengine.Config): Input model config file or Config + object. + deploy_cfg (str | mmengine.Config): Input deployment config file or + Config object. + device (str): Device to input model. + + Returns: + BaseBackendModel: Video recognizer for a configured backend. + """ + # load cfg if necessary + deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) + + backend = get_backend(deploy_cfg) + model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end') + + backend_video_recognizer = __BACKEND_MODEL.build( + dict( + type=model_type, + backend=backend, + backend_files=model_files, + device=device, + deploy_cfg=deploy_cfg, + model_cfg=model_cfg, + **kwargs)) + + return backend_video_recognizer diff --git a/mmdeploy/codebase/mmaction/models/__init__.py b/mmdeploy/codebase/mmaction/models/__init__.py new file mode 100644 index 000000000..db721b1f3 --- /dev/null +++ b/mmdeploy/codebase/mmaction/models/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .recognizers import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmaction/models/recognizers/__init__.py b/mmdeploy/codebase/mmaction/models/recognizers/__init__.py new file mode 100644 index 000000000..ff8a52482 --- /dev/null +++ b/mmdeploy/codebase/mmaction/models/recognizers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .base import base_recognizer__forward + +__all__ = ['base_recognizer__forward'] diff --git a/mmdeploy/codebase/mmaction/models/recognizers/base.py b/mmdeploy/codebase/mmaction/models/recognizers/base.py new file mode 100644 index 000000000..5504f2166 --- /dev/null +++ b/mmdeploy/codebase/mmaction/models/recognizers/base.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +from mmaction.utils import OptSampleList +from torch import Tensor + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + 'mmaction.models.recognizers.BaseRecognizer.forward') +def base_recognizer__forward(ctx, + self, + inputs: Tensor, + data_samples: OptSampleList = None, + mode: str = 'tensor', + **kwargs): + """Rewrite `forward` of Recognizer2D for default backend. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[``ActionDataSample``], optional): The + annotation data of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to ``tensor``. + + Returns: + return a list of `ActionDataSample` + """ + + assert mode == 'predict' + + feats, predict_kwargs = self.extract_feat(inputs, test_mode=True) + cls_scores = self.cls_head(feats, **predict_kwargs) + num_segs = cls_scores.shape[0] // len(data_samples) + cls_scores = self.cls_head.average_clip(cls_scores, num_segs=num_segs) + + return cls_scores diff --git a/mmdeploy/utils/constants.py b/mmdeploy/utils/constants.py index 427470516..3f12cfa84 100644 --- a/mmdeploy/utils/constants.py +++ b/mmdeploy/utils/constants.py @@ -27,6 +27,7 @@ class Task(AdvancedEnum): VOXEL_DETECTION = 'VoxelDetection' POSE_DETECTION = 'PoseDetection' ROTATED_DETECTION = 'RotatedDetection' + VIDEO_RECOGNITION = 'VideoRecognition' class Codebase(AdvancedEnum): @@ -39,6 +40,7 @@ class Codebase(AdvancedEnum): MMDET3D = 'mmdet3d' MMPOSE = 'mmpose' MMROTATE = 'mmrotate' + MMACTION = 'mmaction' class IR(AdvancedEnum): diff --git a/tests/regression/mmaction.yml b/tests/regression/mmaction.yml new file mode 100644 index 000000000..aed265696 --- /dev/null +++ b/tests/regression/mmaction.yml @@ -0,0 +1,56 @@ +globals: + codebase_dir: ../mmaction2 + checkpoint_force_download: False + images: + video: &video ../mmaction2/demo/demo.mp4 + + metric_info: &metric_info + Top 1 Accuracy: + metric_key: acc/top1 + tolerance: 1 + multi_value: 100 + dataset: Kinetics-400 + Top 5 Accuracy: + metric_key: acc/top5 + tolerance: 1 + multi_value: 100 + dataset: Kinetics-400 + convert_image: &convert_image + input_img: *video + test_img: *video + backend_test: &default_backend_test True + sdk: + sdk_dynamic: &sdk_dynamic "" + +onnxruntime: + pipeline_ort_static_fp32: &pipeline_ort_static_fp32 + convert_image: *convert_image + deploy_config: configs/mmaction/video-recognition/video-recognition_onnxruntime_static.py + backend_test: *default_backend_test + +tensorrt: + pipeline_trt_2d_static_fp32: &pipeline_trt_2d_static_fp32 + convert_image: *convert_image + deploy_config: configs/mmaction/video-recognition/video-recognition_2d_tensorrt_static-224x224.py + backend_test: *default_backend_test + pipeline_trt_3d_static_fp32: &pipeline_trt_3d_static_fp32 + convert_image: *convert_image + deploy_config: configs/mmaction/video-recognition/video-recognition_3d_tensorrt_static-256x256.py + backend_test: *default_backend_test + +models: + - name: TSN + metafile: configs/recognition/tsn/metafile.yml + model_configs: + - configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py + pipelines: + - *pipeline_ort_static_fp32 + - *pipeline_trt_2d_static_fp32 + + - name: SlowFast + metafile: configs/recognition/slowfast/metafile.yml + model_configs: + - configs/recognition/slowfast/slowfast_r50_8xb8-4x16x1-256e_kinetics400-rgb.py + pipelines: + - *pipeline_ort_static_fp32 + - *pipeline_trt_3d_static_fp32 diff --git 
a/tests/test_codebase/test_mmaction/conftest.py b/tests/test_codebase/test_mmaction/conftest.py new file mode 100644 index 000000000..0bab54f0f --- /dev/null +++ b/tests/test_codebase/test_mmaction/conftest.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + + +@pytest.fixture(autouse=True) +def init_test(): + # init default scope + from mmaction.utils import register_all_modules + + register_all_modules(True) diff --git a/tests/test_codebase/test_mmaction/data/ann.txt b/tests/test_codebase/test_mmaction/data/ann.txt new file mode 100644 index 000000000..ae06ee6d2 --- /dev/null +++ b/tests/test_codebase/test_mmaction/data/ann.txt @@ -0,0 +1 @@ +demo.mp4 6 diff --git a/tests/test_codebase/test_mmaction/data/model.py b/tests/test_codebase/test_mmaction/data/model.py new file mode 100644 index 000000000..f73a95277 --- /dev/null +++ b/tests/test_codebase/test_mmaction/data/model.py @@ -0,0 +1,186 @@ +# Copyright (c) OpenMMLab. All rights reserved. +model = dict( + type='Recognizer2D', + backbone=dict( + type='ResNet', + pretrained='https://download.pytorch.org/models/resnet50-11ad3fa6.pth', + depth=50, + norm_eval=False), + cls_head=dict( + type='TSNHead', + num_classes=400, + in_channels=2048, + spatial_type='avg', + consensus=dict(type='AvgConsensus', dim=1), + dropout_ratio=0.4, + init_std=0.01, + average_clips=None), + data_preprocessor=dict( + type='ActionDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + format_shape='NCHW'), + train_cfg=None, + test_cfg=None) +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=100, val_begin=1, val_interval=1) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=100, + by_epoch=True, + milestones=[40, 80], + gamma=0.1) +] +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=40, norm_type=2)) +default_scope = 'mmaction' +default_hooks = dict( + runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=20, ignore_last=False), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', interval=3, save_best='auto', max_keep_ckpts=3), + sampler_seed=dict(type='DistSamplerSeedHook')) +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl')) +log_processor = dict(type='LogProcessor', window_size=20, by_epoch=True) +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='ActionVisualizer', vis_backends=[dict(type='LocalVisBackend')]) +log_level = 'INFO' +load_from = None +resume = False +dataset_type = 'VideoDataset' +data_root = 'data/kinetics400/videos_train' +data_root_val = 'data/video' +ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' +ann_file_val = 'data/ann.txt' +train_pipeline = [ + dict(type='DecordInit'), + dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict( + type='MultiScaleCrop', + input_size=224, + scales=(1, 0.875, 0.75, 0.66), + random_crop=False, + max_wh_scale_gap=1), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='FormatShape', input_format='NCHW'), + dict(type='PackActionInputs') +] +val_pipeline = [ + dict(type='DecordInit'), + dict( + 
type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=3, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='CenterCrop', crop_size=224), + dict(type='FormatShape', input_format='NCHW'), + dict(type='PackActionInputs') +] +test_pipeline = [ + dict(type='DecordInit'), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=25, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='TenCrop', crop_size=224), + dict(type='FormatShape', input_format='NCHW'), + dict(type='PackActionInputs') +] +train_dataloader = dict( + batch_size=32, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='VideoDataset', + ann_file='data/kinetics400/kinetics400_train_list_videos.txt', + data_prefix=dict(video='data/kinetics400/videos_train'), + pipeline=[ + dict(type='DecordInit'), + dict( + type='SampleFrames', clip_len=1, frame_interval=1, + num_clips=3), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict( + type='MultiScaleCrop', + input_size=224, + scales=(1, 0.875, 0.75, 0.66), + random_crop=False, + max_wh_scale_gap=1), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='FormatShape', input_format='NCHW'), + dict(type='PackActionInputs') + ])) +val_dataloader = dict( + batch_size=32, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type='VideoDataset', + ann_file='tests/test_codebase/test_mmaction/data/ann.txt', + data_prefix=dict(video='tests/test_codebase/test_mmaction/data/video'), + pipeline=[ + dict(type='DecordInit'), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=3, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='CenterCrop', crop_size=224), + dict(type='FormatShape', input_format='NCHW'), + dict(type='PackActionInputs') + ], + test_mode=True)) +test_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type='VideoDataset', + ann_file='tests/test_codebase/test_mmaction/data/ann.txt', + data_prefix=dict(video='tests/test_codebase/test_mmaction/data/video'), + pipeline=[ + dict(type='DecordInit'), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=25, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='TenCrop', crop_size=224), + dict(type='FormatShape', input_format='NCHW'), + dict(type='PackActionInputs') + ], + test_mode=True)) +val_evaluator = dict(type='AccMetric') +test_evaluator = dict(type='AccMetric') diff --git a/tests/test_codebase/test_mmaction/data/video/demo.mp4 b/tests/test_codebase/test_mmaction/data/video/demo.mp4 new file mode 100755 index 000000000..8a1ffbf2c Binary files /dev/null and b/tests/test_codebase/test_mmaction/data/video/demo.mp4 differ diff --git a/tests/test_codebase/test_mmaction/test_mmaction_model.py b/tests/test_codebase/test_mmaction/test_mmaction_model.py new file mode 100644 index 000000000..d16ec5033 --- /dev/null +++ b/tests/test_codebase/test_mmaction/test_mmaction_model.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import pytest +import torch +from mmengine import Config + +from mmdeploy.codebase import import_codebase +from mmdeploy.utils import Backend, Codebase, load_config +from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs + +try: + import_codebase(Codebase.MMACTION) +except ImportError: + pytest.skip( + f'{Codebase.MMACTION} is not installed.', allow_module_level=True) + + +@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME]) +@pytest.mark.parametrize('model_cfg_path', + ['tests/test_codebase/test_mmaction/data/model.py']) +def test_forward_of_base_recognizer(model_cfg_path, backend): + check_backend(backend) + deploy_cfg = Config( + dict( + backend_config=dict(type='onnxruntime'), + codebase_config=dict(type='mmaction', task='VideoRecognition'), + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + input_shape=None, + input_names=['input'], + output_names=['output']))) + + model_cfg = load_config(model_cfg_path)[0] + from mmaction.apis import init_recognizer + model = init_recognizer(model_cfg, None, device='cpu') + + img = torch.randn(1, 3, 3, 224, 224) + from mmaction.structures import ActionDataSample + data_sample = ActionDataSample() + img_meta = dict(img_shape=(224, 224)) + data_sample.set_metainfo(img_meta) + rewrite_inputs = {'inputs': img} + wrapped_model = WrapModel( + model, 'forward', data_samples=[data_sample], mode='predict') + rewrite_outputs, _ = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + assert rewrite_outputs is not None diff --git a/tests/test_codebase/test_mmaction/test_video_recognition.py b/tests/test_codebase/test_mmaction/test_video_recognition.py new file mode 100644 index 000000000..ef61df0db --- /dev/null +++ b/tests/test_codebase/test_mmaction/test_video_recognition.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from tempfile import NamedTemporaryFile, TemporaryDirectory + +import pytest +import torch +from mmengine import Config + +import mmdeploy.backend.onnxruntime as ort_apis +from mmdeploy.apis import build_task_processor +from mmdeploy.codebase import import_codebase +from mmdeploy.utils import Codebase, load_config +from mmdeploy.utils.test import SwitchBackendWrapper + +try: + import_codebase(Codebase.MMACTION) +except ImportError: + pytest.skip( + f'{Codebase.MMACTION} is not installed.', allow_module_level=True) + +model_cfg_path = 'tests/test_codebase/test_mmaction/data/model.py' +model_cfg = load_config(model_cfg_path)[0] +deploy_cfg = Config( + dict( + backend_config=dict(type='onnxruntime'), + codebase_config=dict(type='mmaction', task='VideoRecognition'), + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + input_shape=None, + input_names=['input'], + output_names=['output']))) + +onnx_file = NamedTemporaryFile(suffix='.onnx').name +task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu') +img_shape = (224, 224) +num_classes = 400 +video = 'tests/test_codebase/test_mmaction/data/video/demo.mp4' + + +@pytest.fixture +def backend_model(): + from mmdeploy.backend.onnxruntime import ORTWrapper + ort_apis.__dict__.update({'ORTWrapper': ORTWrapper}) + wrapper = SwitchBackendWrapper(ORTWrapper) + wrapper.set(outputs={ + 'output': torch.rand(1, num_classes), + }) + + yield task_processor.build_backend_model(['']) + + wrapper.recover() + + +def test_build_backend_model(backend_model): + assert isinstance(backend_model, torch.nn.Module) + + +def test_create_input(): + inputs = task_processor.create_input(video, input_shape=img_shape) + assert isinstance(inputs, tuple) and len(inputs) == 2 + + +def test_build_pytorch_model(): + from mmaction.models.recognizers.base import BaseRecognizer + model = task_processor.build_pytorch_model(None) + assert isinstance(model, BaseRecognizer) + + +def test_get_tensor_from_input(): + input_data = {'inputs': torch.ones(3, 4, 5)} + inputs = task_processor.get_tensor_from_input(input_data) + assert torch.equal(inputs, torch.ones(3, 4, 5)) + + +def test_get_model_name(): + model_name = task_processor.get_model_name() + assert isinstance(model_name, str) and model_name is not None + + +def test_build_dataset_and_dataloader(): + from torch.utils.data import DataLoader, Dataset + dataset = task_processor.build_dataset( + dataset_cfg=model_cfg.test_dataloader.dataset) + assert isinstance(dataset, Dataset), 'Failed to build dataset' + dataloader_cfg = task_processor.model_cfg.test_dataloader + dataloader = task_processor.build_dataloader(dataloader_cfg) + assert isinstance(dataloader, DataLoader), 'Failed to build dataloader' + + +def test_build_test_runner(backend_model): + from mmdeploy.codebase.base.runner import DeployTestRunner + temp_dir = TemporaryDirectory().name + runner = task_processor.build_test_runner(backend_model, temp_dir) + assert isinstance(runner, DeployTestRunner) + + +def test_get_preprocess(): + process = task_processor.get_preprocess() + assert process is not None + + +def test_get_postprocess(): + process = task_processor.get_postprocess() + assert isinstance(process, dict) diff --git a/tests/test_codebase/test_mmaction/test_video_recognition_model.py b/tests/test_codebase/test_mmaction/test_video_recognition_model.py new file mode 100644 index 000000000..5908ae763 --- /dev/null +++ b/tests/test_codebase/test_mmaction/test_video_recognition_model.py @@ -0,0 +1,83 @@ +# 
Copyright (c) OpenMMLab. All rights reserved. + +import pytest +import torch +from mmengine import Config + +import mmdeploy.backend.onnxruntime as ort_apis +from mmdeploy.codebase import import_codebase +from mmdeploy.utils import Backend, Codebase, load_config +from mmdeploy.utils.test import SwitchBackendWrapper, backend_checker + +IMAGE_SIZE = 224 + +try: + import_codebase(Codebase.MMACTION) +except ImportError: + pytest.skip( + f'{Codebase.MMACTION} is not installed.', allow_module_level=True) + + +@backend_checker(Backend.ONNXRUNTIME) +class TestEnd2EndModel: + + @classmethod + def setup_class(cls): + # force add backend wrapper regardless of plugins + from mmdeploy.backend.onnxruntime import ORTWrapper + ort_apis.__dict__.update({'ORTWrapper': ORTWrapper}) + + # simplify backend inference + cls.wrapper = SwitchBackendWrapper(ORTWrapper) + cls.outputs = { + 'outputs': torch.rand(1, 400), + } + cls.wrapper.set(outputs=cls.outputs) + deploy_cfg = Config({'onnx_config': {'output_names': ['outputs']}}) + model_cfg_path = 'tests/test_codebase/test_mmaction/data/model.py' + model_cfg = load_config(model_cfg_path)[0] + + from mmdeploy.codebase.mmaction.deploy.video_recognition_model import \ + End2EndModel + cls.end2end_model = End2EndModel( + Backend.ONNXRUNTIME, [''], + device='cpu', + deploy_cfg=deploy_cfg, + model_cfg=model_cfg) + + @classmethod + def teardown_class(cls): + cls.wrapper.recover() + + def test_forward(self): + inputs = torch.rand(1, 3, 3, IMAGE_SIZE, IMAGE_SIZE) + from mmaction.structures import ActionDataSample + data_sample = ActionDataSample( + metainfo=dict(img_shape=(IMAGE_SIZE, IMAGE_SIZE))) + results = self.end2end_model.forward( + inputs, [data_sample], mode='predict') + assert results is not None, 'failed to get output using '\ + 'End2EndModel' + + +@backend_checker(Backend.ONNXRUNTIME) +def test_build_video_recognition_model(): + model_cfg_path = 'tests/test_codebase/test_mmaction/data/model.py' + model_cfg = load_config(model_cfg_path)[0] + deploy_cfg = Config( + dict( + backend_config=dict(type='onnxruntime'), + onnx_config=dict(output_names=['outputs']), + codebase_config=dict(type='mmaction'))) + + from mmdeploy.backend.onnxruntime import ORTWrapper + ort_apis.__dict__.update({'ORTWrapper': ORTWrapper}) + + # simplify backend inference + with SwitchBackendWrapper(ORTWrapper) as wrapper: + wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg) + from mmdeploy.codebase.mmaction.deploy.video_recognition_model import ( + End2EndModel, build_video_recognition_model) + classifier = build_video_recognition_model([''], model_cfg, deploy_cfg, + 'cpu') + assert isinstance(classifier, End2EndModel)
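
For reference, below is a minimal usage sketch of the new VideoRecognition task end to end. It only reuses the APIs exercised by the unit tests in this patch (build_task_processor, create_input, build_backend_model, the predict-mode forward); the mmaction2 model config path and the 'end2end.onnx' file are placeholders and are assumed to have been produced beforehand (e.g. via mmdeploy's tools/deploy.py).

import torch
from mmaction.structures import ActionDataSample

from mmdeploy.apis import build_task_processor
from mmdeploy.utils import load_config

# Placeholder paths: an mmaction2 TSN config and an already exported ONNX file.
deploy_cfg = load_config('configs/mmaction/video-recognition/'
                         'video-recognition_onnxruntime_static.py')[0]
model_cfg = load_config('path/to/tsn_kinetics400_config.py')[0]

# The task processor registered for Task.VIDEO_RECOGNITION by this patch.
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')

# Preprocess a video through the OpenCV-based test pipeline set up in
# process_model_config(); returns the collated data dict and the input data.
data, inputs = task_processor.create_input('demo.mp4', input_shape=(224, 224))

# Wrap the exported model and run a predict-mode forward, mirroring the unit
# tests above; the dummy input follows the 2D deploy-config layout
# (batch, num_crops * num_segs, C, H, W).
backend_model = task_processor.build_backend_model(['end2end.onnx'])
data_sample = ActionDataSample()
data_sample.set_metainfo(dict(img_shape=(224, 224)))
with torch.no_grad():
    results = backend_model(torch.randn(1, 3, 3, 224, 224), [data_sample],
                            mode='predict')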