Mirror of https://github.com/open-mmlab/mmdeploy.git (synced 2025-01-14 08:09:43 +08:00)
Support mmaction2:dev-1.x (#1012)
* support tsn
* support slowfast
* fix export info & End2EndModel
* add test
* fix forward
* fix lint
* update tests
* add onnxruntime 2d config
* fix ort-gpu
* add mmaction.yml, need to update
* fix reviews
* add ann.txt
* add visualize
* fix lint
* rebase
* add conftest.py
* fix circle ci
* fix registry
* fix regression test
This commit is contained in:
parent 72923e7844
commit 2020e74480
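Not part of the diff below, but for orientation: a minimal sketch of how the new VideoRecognition task can be exercised once this commit lands. It only uses helpers that appear in this diff (build_task_processor, create_input, build_backend_model) together with the test assets added here; the 'end2end.onnx' path is a placeholder for an already converted model.

import torch

from mmdeploy.apis import build_task_processor
from mmdeploy.utils import load_config

# one of the deploy configs added in this commit plus the test model config
deploy_cfg, model_cfg = load_config(
    'configs/mmaction/video-recognition/video-recognition_onnxruntime_static.py',
    'tests/test_codebase/test_mmaction/data/model.py')
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')

video = 'tests/test_codebase/test_mmaction/data/video/demo.mp4'
data, inputs = task_processor.create_input(video, input_shape=(224, 224))

# 'end2end.onnx' is a placeholder for a model converted beforehand
model = task_processor.build_backend_model(['end2end.onnx'])
with torch.no_grad():
    results = model(inputs, data_samples=data['data_samples'], mode='predict')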
@@ -0,0 +1,15 @@
_base_ = ['./video-recognition_static.py']

onnx_config = dict(
    dynamic_axes={
        'input': {
            0: 'batch',
            1: 'num_crops * num_segs',
            3: 'height',
            4: 'width'
        },
        'output': {
            0: 'batch',
        }
    },
    input_shape=None)
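For readers unfamiliar with the mechanism: `dynamic_axes` here follows the `torch.onnx.export` convention, so the keys match the configured `input_names`/`output_names` and the values name the axes that stay symbolic (batch, crop/segment count, spatial size). A standalone sketch of the same idea, with a toy module standing in for the recognizer:

import torch


class Toy(torch.nn.Module):

    def forward(self, x):
        # collapse crops/segments and spatial dims into one score per batch item
        return x.mean(dim=(1, 2, 3, 4))


# axes 0, 1, 3 and 4 stay symbolic, mirroring the deploy config above
torch.onnx.export(
    Toy(),
    torch.randn(1, 25, 3, 224, 224),
    'toy.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch', 1: 'num_crops * num_segs', 3: 'height', 4: 'width'},
        'output': {0: 'batch'},
    },
    opset_version=11)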
@@ -0,0 +1,14 @@
_base_ = ['./video-recognition_static.py', '../../_base_/backends/tensorrt.py']

onnx_config = dict(input_shape=[224, 224])

backend_config = dict(
    common_config=dict(max_workspace_size=1 << 30),
    model_inputs=[
        dict(
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 250, 3, 224, 224],
                    opt_shape=[1, 250, 3, 224, 224],
                    max_shape=[1, 250, 3, 224, 224])))
    ])
@@ -0,0 +1,16 @@
_base_ = ['./video-recognition_static.py']

onnx_config = dict(
    dynamic_axes={
        'input': {
            0: 'batch',
            1: 'num_crops * num_segs',
            3: 'time',
            4: 'height',
            5: 'width'
        },
        'output': {
            0: 'batch',
        }
    },
    input_shape=None)
@@ -0,0 +1,14 @@
_base_ = ['./video-recognition_static.py', '../../_base_/backends/tensorrt.py']

onnx_config = dict(input_shape=[256, 256])

backend_config = dict(
    common_config=dict(max_workspace_size=1 << 30),
    model_inputs=[
        dict(
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 30, 3, 32, 256, 256],
                    opt_shape=[1, 30, 3, 32, 256, 256],
                    max_shape=[1, 30, 3, 32, 256, 256])))
    ])
@@ -0,0 +1,5 @@
_base_ = [
    './video-recognition_static.py', '../../_base_/backends/onnxruntime.py'
]

onnx_config = dict(input_shape=None)
@@ -0,0 +1,3 @@
_base_ = ['../../_base_/onnx_config.py']

codebase_config = dict(type='mmaction', task='VideoRecognition')
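The deploy configs above all build on this file through `_base_` inheritance; mmengine merges the chain when the config is loaded. A quick sketch, assuming it is run from an mmdeploy checkout:

from mmengine import Config

cfg = Config.fromfile('configs/mmaction/video-recognition/'
                      'video-recognition_2d_tensorrt_static-224x224.py')
print(cfg.codebase_config)  # type='mmaction', task='VideoRecognition', from the base
print(cfg.onnx_config.input_shape)  # [224, 224], set in the child config
print(cfg.backend_config.common_config.max_workspace_size)  # 1 << 30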
4  mmdeploy/codebase/mmaction/__init__.py  Normal file
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .deploy import *  # noqa: F401,F403
from .models import *  # noqa: F401,F403
6  mmdeploy/codebase/mmaction/deploy/__init__.py  Normal file
@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .mmaction import MMACTION
from .video_recognition import VideoRecognition

__all__ = ['MMACTION', 'VideoRecognition']
19  mmdeploy/codebase/mmaction/deploy/mmaction.py  Normal file
@@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import Registry

from mmdeploy.codebase.base import CODEBASE, MMCodebase
from mmdeploy.utils import Codebase

MMACTION_TASK = Registry('mmaction_tasks')


@CODEBASE.register_module(Codebase.MMACTION.value)
class MMACTION(MMCodebase):
    """MMAction codebase class."""

    task_registry = MMACTION_TASK

    @classmethod
    def register_all_modules(cls):
        from mmaction.utils.setup_env import register_all_modules
        register_all_modules(True)
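A rough sketch of how the two registries above are intended to be used; it mirrors what the tests added later in this commit do through `import_codebase`:

from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Codebase

# importing the codebase package registers MMACTION in the CODEBASE registry
import_codebase(Codebase.MMACTION)

from mmdeploy.codebase.mmaction.deploy.mmaction import MMACTION_TASK
# the task registry is expected to now contain the 'VideoRecognition' task
print(MMACTION_TASK.module_dict.keys())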
273  mmdeploy/codebase/mmaction/deploy/video_recognition.py  Normal file
@@ -0,0 +1,273 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from operator import itemgetter
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import mmengine
import numpy as np
import torch
from mmengine.dataset import pseudo_collate
from mmengine.model import BaseDataPreprocessor

from mmdeploy.codebase.base import BaseTask
from mmdeploy.utils import Task, get_root_logger
from mmdeploy.utils.config_utils import get_input_shape
from .mmaction import MMACTION_TASK


def process_model_config(model_cfg: mmengine.Config,
                         imgs: Union[Sequence[str], Sequence[np.ndarray]],
                         input_shape: Optional[Sequence[int]] = None):
    """Process the model config.

    Args:
        model_cfg (mmengine.Config): The model config.
        imgs (Sequence[str] | Sequence[np.ndarray]): Input image(s), accepted
            data types are List[str] and List[np.ndarray].
        input_shape (list[int]): A list of two integers in (width, height)
            format specifying the input shape. Default: None.

    Returns:
        mmengine.Config: the model config after processing.
    """
    logger = get_root_logger()
    cfg = model_cfg.deepcopy()
    test_pipeline_cfg = cfg.test_pipeline
    if 'Init' not in test_pipeline_cfg[0]['type']:
        test_pipeline_cfg = [dict(type='OpenCVInit')] + test_pipeline_cfg
    else:
        test_pipeline_cfg[0] = dict(type='OpenCVInit')
    for i, trans in enumerate(test_pipeline_cfg):
        if 'Decode' in trans['type']:
            test_pipeline_cfg[i] = dict(type='OpenCVDecode')
    cfg.test_pipeline = test_pipeline_cfg

    # check whether input_shape is valid
    if input_shape is not None:
        has_crop = False
        crop_size = -1
        has_resize = False
        scale = (-1, -1)
        keep_ratio = True
        for trans in cfg.test_pipeline:
            if trans['type'] == 'Resize':
                has_resize = True
                keep_ratio = trans.get('keep_ratio', True)
                scale = trans.scale
            if trans['type'] in ['TenCrop', 'CenterCrop', 'ThreeCrop']:
                has_crop = True
                crop_size = trans.crop_size

        if has_crop and tuple(input_shape) != (crop_size, crop_size):
            logger.error(
                f'`input shape` should be equal to `crop_size`: {crop_size},'
                f' but given: {input_shape}')
        if has_resize and (not has_crop):
            if keep_ratio:
                logger.error('Resize should set `keep_ratio` to False'
                             ' when `input shape` is given.')
            if tuple(input_shape) != scale:
                logger.error(
                    f'`input shape` should be equal to `scale`: {scale},'
                    f' but given: {input_shape}')
    return cfg


@MMACTION_TASK.register_module(Task.VIDEO_RECOGNITION.value)
class VideoRecognition(BaseTask):
    """VideoRecognition task class.

    Args:
        model_cfg (Config): Original PyTorch model config file.
        deploy_cfg (Config): Deployment config file or loaded Config
            object.
        device (str): A string represents device type.
    """

    def __init__(self, model_cfg: mmengine.Config, deploy_cfg: mmengine.Config,
                 device: str):
        super(VideoRecognition, self).__init__(model_cfg, deploy_cfg, device)

    def build_data_preprocessor(self):
        model_cfg = self.model_cfg
        preprocess_cfg = model_cfg.get('preprocess_cfg', None)
        from mmengine.registry import MODELS
        if preprocess_cfg is not None:
            data_preprocessor = MODELS.build(preprocess_cfg)
        else:
            data_preprocessor = BaseDataPreprocessor()

        return data_preprocessor

    def build_backend_model(self,
                            model_files: Sequence[str] = None,
                            **kwargs) -> torch.nn.Module:
        """Initialize backend model.

        Args:
            model_files (Sequence[str]): Input model files.

        Returns:
            nn.Module: An initialized backend model.
        """
        from .video_recognition_model import build_video_recognition_model
        model = build_video_recognition_model(
            model_files, self.model_cfg, self.deploy_cfg, device=self.device)
        model.to(self.device)
        model.eval()
        return model

    def create_input(self,
                     imgs: Union[str, np.ndarray],
                     input_shape: Sequence[int] = None,
                     data_preprocessor: Optional[BaseDataPreprocessor] = None)\
            -> Tuple[Dict, torch.Tensor]:
        """Create input for video recognition.

        Args:
            imgs (str | Sequence[str]): Input video path(s); only `str` and
                sequences of `str` are accepted.
            input_shape (list[int]): A list of two integers in (width, height)
                format specifying the input shape. Defaults to `None`.

        Returns:
            tuple: (data, inputs), where `data` is the collated input dict
                with meta information and `inputs` is the input tensor.
        """
        if isinstance(imgs, (list, tuple)):
            if not all(isinstance(img, str) for img in imgs):
                raise AssertionError('imgs must be strings')
        elif isinstance(imgs, str):
            imgs = [imgs]
        else:
            raise AssertionError('imgs must be strings')

        from mmcv.transforms.wrappers import Compose
        model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
        test_pipeline = Compose(model_cfg.test_pipeline)

        data = []
        for img in imgs:
            data_ = dict(filename=img, label=-1, start_index=0, modality='RGB')
            data_ = test_pipeline(data_)
            data.append(data_)

        data = pseudo_collate(data)
        if data_preprocessor is not None:
            data = data_preprocessor(data, False)
            return data, data['inputs']
        else:
            return data, BaseTask.get_tensor_from_input(data)

    def visualize(self,
                  image: str,
                  result: list,
                  output_file: str,
                  window_name: str = '',
                  show_result: bool = False,
                  **kwargs):
        """Visualize predictions of a model.

        Args:
            image (str): Input video to draw predictions on.
            result (list): A list of predictions.
            output_file (str): Output file to save the drawn video.
            window_name (str): The name of visualization window. Defaults to
                an empty string.
            show_result (bool): Whether to show result in windows, defaults
                to `False`.
        """
        logger = get_root_logger()
        try:
            import decord
            from moviepy.editor import ImageSequenceClip
        except Exception:
            logger.warning('Please install moviepy and decord to '
                           'enable visualize for mmaction')
            return

        save_dir, save_name = osp.split(output_file)
        video = decord.VideoReader(image)
        frames = [x.asnumpy()[..., ::-1] for x in video]
        pred_scores = result.pred_scores.item.tolist()
        score_tuples = tuple(zip(range(len(pred_scores)), pred_scores))
        score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)
        top1_item = score_sorted[0]
        short_edge_length = min(frames[0].shape[:2])
        scale = short_edge_length // 224.
        img_scale = min(max(scale, 0.3), 3.0)
        text_cfg = {
            'positions': np.array([(img_scale * 5, ) * 2]).astype(np.int32),
            'font_sizes': int(img_scale * 7),
            'font_families': 'monospace',
            'colors': 'white',
            'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round')
        }

        visualizer = self.get_visualizer(window_name, save_dir)
        out_frames = []
        for i, frame in enumerate(frames):
            visualizer.set_image(frame)
            texts = [f'Frame {i} of total {len(frames)} frames']
            texts.append(
                f'top-1 label: {top1_item[0]}, score: {top1_item[1]}')
            visualizer.draw_texts('\n'.join(texts), **text_cfg)
            drawn_img = visualizer.get_image()
            out_frames.append(drawn_img)
        out_frames = [x[..., ::-1] for x in out_frames]
        video_clips = ImageSequenceClip(out_frames, fps=30)
        output_file = output_file[:output_file.rfind('.')] + '.mp4'
        video_clips.write_videofile(output_file)

    @staticmethod
    def get_partition_cfg(partition_type: str) -> Dict:
        """Get a certain partition config.

        Args:
            partition_type (str): A string specifying partition type.

        Returns:
            dict: A dictionary of partition config.
        """
        raise NotImplementedError('Not supported yet.')

    @staticmethod
    def get_tensor_from_input(input_data: Dict[str, Any],
                              **kwargs) -> torch.Tensor:
        """Get input tensor from input data.

        Args:
            input_data (dict): Input data containing meta info and image
                tensor.
        Returns:
            torch.Tensor: An image in `Tensor`.
        """
        return input_data['inputs']

    def get_preprocess(self) -> Dict:
        """Get the preprocess information for SDK.

        Return:
            dict: Composed of the preprocess information.
        """
        input_shape = get_input_shape(self.deploy_cfg)
        model_cfg = process_model_config(self.model_cfg, [''], input_shape)
        preprocess = model_cfg.test_pipeline
        return preprocess

    def get_postprocess(self) -> Dict:
        """Get the postprocess information for SDK.

        Return:
            dict: Composed of the postprocess information.
        """
        postprocess = self.model_cfg.model.cls_head
        return postprocess

    def get_model_name(self) -> str:
        """Get the model name.

        Return:
            str: the name of the model.
        """
        assert 'type' in self.model_cfg.model, 'model config contains no type'
        name = self.model_cfg.model.type.lower()
        return name
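A short illustration (not part of the commit) of what `process_model_config` does to a test pipeline: the Decord-based init/decode transforms are swapped for their OpenCV counterparts, presumably so the deployment pipeline matches what the SDK can run.

from mmengine import Config

from mmdeploy.codebase.mmaction.deploy.video_recognition import \
    process_model_config

model_cfg = Config.fromfile('tests/test_codebase/test_mmaction/data/model.py')
print([t['type'] for t in model_cfg.test_pipeline][:3])
# ['DecordInit', 'SampleFrames', 'DecordDecode']

cfg = process_model_config(model_cfg, ['demo.mp4'], input_shape=None)
print([t['type'] for t in cfg.test_pipeline][:3])
# ['OpenCVInit', 'SampleFrames', 'OpenCVDecode']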
150  mmdeploy/codebase/mmaction/deploy/video_recognition_model.py  Normal file
@@ -0,0 +1,150 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, List, Optional, Sequence, Union

import mmengine
import torch
from mmaction.utils import LabelList
from mmengine import Config
from mmengine.model import BaseDataPreprocessor
from mmengine.registry import Registry
from mmengine.structures import BaseDataElement, LabelData

from mmdeploy.codebase.base import BaseBackendModel
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
                            get_root_logger, load_config)

__BACKEND_MODEL = Registry('backend_video_recognizer')


@__BACKEND_MODEL.register_module('end2end')
class End2EndModel(BaseBackendModel):
    """End to end model for inference of video recognition.

    Args:
        backend (Backend): The backend enum, specifying backend type.
        backend_files (Sequence[str]): Paths to all required backend files
            (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
        device (str): A string represents device type.
        deploy_cfg (str | mmengine.Config): Deployment config file or loaded
            Config object.
        model_cfg (str | mmengine.Config): Model config file or loaded Config
            object.
    """

    def __init__(self,
                 backend: Backend,
                 backend_files: Sequence[str],
                 device: str,
                 deploy_cfg: Union[str, Config] = None,
                 model_cfg: Union[str, Config] = None,
                 **kwargs):
        super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg)
        model_cfg, deploy_cfg = load_config(model_cfg, deploy_cfg)
        from mmaction.registry import MODELS
        preprocessor_cfg = model_cfg.model.get('data_preprocessor', None)
        if preprocessor_cfg is not None:
            self.data_preprocessor = MODELS.build(
                model_cfg.model.data_preprocessor)
        else:
            self.data_preprocessor = BaseDataPreprocessor()
        self.deploy_cfg = deploy_cfg
        self.model_cfg = model_cfg
        self._init_wrapper(
            backend=backend,
            backend_files=backend_files,
            device=device,
            **kwargs)
        self.device = device

    def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
                      device: str, **kwargs):
        """Initialize backend wrapper.

        Args:
            backend (Backend): The backend enum, specifying backend type.
            backend_files (Sequence[str]): Paths to all required backend files
                (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
            device (str): A string specifying device type.
        """
        output_names = self.output_names
        self.wrapper = BaseBackendModel._build_wrapper(
            backend=backend,
            backend_files=backend_files,
            device=device,
            input_names=[self.input_name],
            output_names=output_names,
            deploy_cfg=self.deploy_cfg,
            **kwargs)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[BaseDataElement]] = None,
                mode: str = 'predict') -> Any:
        """Run forward inference.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[``ActionDataSample``], optional): The
                annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to ``predict``.

        Returns:
            list: A list of predictions.
        """
        assert mode == 'predict', \
            'Backend model only supports mode == "predict",' f' but got {mode}'

        if inputs.device != torch.device(self.device):
            get_root_logger().warning(f'expected input device {self.device}'
                                      f' but got {inputs.device}.')
        inputs = inputs.to(self.device)
        cls_scores = self.wrapper({self.input_name:
                                   inputs})[self.output_names[0]]

        predictions: LabelList = []
        for score in cls_scores:
            label = LabelData(item=score)
            predictions.append(label)

        for data_sample, pred_instances in zip(data_samples, predictions):
            data_sample.pred_scores = pred_instances

        return data_samples


def build_video_recognition_model(model_files: Sequence[str],
                                  model_cfg: Union[str, mmengine.Config],
                                  deploy_cfg: Union[str, mmengine.Config],
                                  device: str, **kwargs):
    """Build a video recognition model for different backends.

    Args:
        model_files (Sequence[str]): Input model file(s).
        model_cfg (str | mmengine.Config): Input model config file or Config
            object.
        deploy_cfg (str | mmengine.Config): Input deployment config file or
            Config object.
        device (str): Device to load the model onto.

    Returns:
        BaseBackendModel: Video recognizer for a configured backend.
    """
    # load cfg if necessary
    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)

    backend = get_backend(deploy_cfg)
    model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end')

    backend_video_recognizer = __BACKEND_MODEL.build(
        dict(
            type=model_type,
            backend=backend,
            backend_files=model_files,
            device=device,
            deploy_cfg=deploy_cfg,
            model_cfg=model_cfg,
            **kwargs))

    return backend_video_recognizer
3  mmdeploy/codebase/mmaction/models/__init__.py  Normal file
@@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .recognizers import *  # noqa: F401,F403
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .base import base_recognizer__forward

__all__ = ['base_recognizer__forward']
37  mmdeploy/codebase/mmaction/models/recognizers/base.py  Normal file
@@ -0,0 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmaction.utils import OptSampleList
from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    'mmaction.models.recognizers.BaseRecognizer.forward')
def base_recognizer__forward(ctx,
                             self,
                             inputs: Tensor,
                             data_samples: OptSampleList = None,
                             mode: str = 'tensor',
                             **kwargs):
    """Rewrite `forward` of BaseRecognizer for default backend.

    Args:
        inputs (torch.Tensor): The input tensor with shape
            (N, C, ...) in general.
        data_samples (List[``ActionDataSample``], optional): The
            annotation data of every sample. Defaults to None.
        mode (str): Return what kind of value. Defaults to ``tensor``.

    Returns:
        torch.Tensor: The classification scores after averaging clips.
    """

    assert mode == 'predict'

    feats, predict_kwargs = self.extract_feat(inputs, test_mode=True)
    cls_scores = self.cls_head(feats, **predict_kwargs)
    num_segs = cls_scores.shape[0] // len(data_samples)
    cls_scores = self.cls_head.average_clip(cls_scores, num_segs=num_segs)

    return cls_scores
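A worked example of the `num_segs` computation above for the TSN test config added in this commit (25 sampled clips, TenCrop, batch of one video):

batch = 1
num_clips = 25   # SampleFrames(num_clips=25) in the test pipeline
num_crops = 10   # TenCrop
cls_score_rows = batch * num_clips * num_crops  # rows produced by cls_head
num_segs = cls_score_rows // batch              # 250, passed to average_clip
print(num_segs)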
@@ -27,6 +27,7 @@ class Task(AdvancedEnum):
    VOXEL_DETECTION = 'VoxelDetection'
    POSE_DETECTION = 'PoseDetection'
    ROTATED_DETECTION = 'RotatedDetection'
    VIDEO_RECOGNITION = 'VideoRecognition'


class Codebase(AdvancedEnum):
@@ -39,6 +40,7 @@ class Codebase(AdvancedEnum):
    MMDET3D = 'mmdet3d'
    MMPOSE = 'mmpose'
    MMROTATE = 'mmrotate'
    MMACTION = 'mmaction'


class IR(AdvancedEnum):
56  tests/regression/mmaction.yml  Normal file
@@ -0,0 +1,56 @@
globals:
  codebase_dir: ../mmaction2
  checkpoint_force_download: False
  images:
    video: &video ../mmaction2/demo/demo.mp4

  metric_info: &metric_info
    Top 1 Accuracy:
      metric_key: acc/top1
      tolerance: 1
      multi_value: 100
      dataset: Kinetics-400
    Top 5 Accuracy:
      metric_key: acc/top5
      tolerance: 1
      multi_value: 100
      dataset: Kinetics-400
  convert_image: &convert_image
    input_img: *video
    test_img: *video
  backend_test: &default_backend_test True
  sdk:
    sdk_dynamic: &sdk_dynamic ""

onnxruntime:
  pipeline_ort_static_fp32: &pipeline_ort_static_fp32
    convert_image: *convert_image
    deploy_config: configs/mmaction/video-recognition/video-recognition_onnxruntime_static.py
    backend_test: *default_backend_test

tensorrt:
  pipeline_trt_2d_static_fp32: &pipeline_trt_2d_static_fp32
    convert_image: *convert_image
    deploy_config: configs/mmaction/video-recognition/video-recognition_2d_tensorrt_static-224x224.py
    backend_test: *default_backend_test
  pipeline_trt_3d_static_fp32: &pipeline_trt_3d_static_fp32
    convert_image: *convert_image
    deploy_config: configs/mmaction/video-recognition/video-recognition_3d_tensorrt_static-256x256.py
    backend_test: *default_backend_test

models:
  - name: TSN
    metafile: configs/recognition/tsn/metafile.yml
    model_configs:
      - configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py
    pipelines:
      - *pipeline_ort_static_fp32
      - *pipeline_trt_2d_static_fp32

  - name: SlowFast
    metafile: configs/recognition/slowfast/metafile.yml
    model_configs:
      - configs/recognition/slowfast/slowfast_r50_8xb8-4x16x1-256e_kinetics400-rgb.py
    pipelines:
      - *pipeline_ort_static_fp32
      - *pipeline_trt_3d_static_fp32
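An interpretation sketch of the metric entries above, based only on the field names (the regression runner itself is not shown in this diff): the measured fraction is scaled by `multi_value` and compared against the value recorded in the model's metafile, within `tolerance`. Both numbers below are placeholders.

published_top1 = 73.0    # placeholder for the metafile value, in percent
measured_top1 = 0.7285   # placeholder for acc/top1 reported by the backend test
scaled = measured_top1 * 100                 # multi_value: 100
passed = abs(scaled - published_top1) <= 1   # tolerance: 1
print(passed)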
10  tests/test_codebase/test_mmaction/conftest.py  Normal file
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest


@pytest.fixture(autouse=True)
def init_test():
    # init default scope
    from mmaction.utils import register_all_modules

    register_all_modules(True)
1  tests/test_codebase/test_mmaction/data/ann.txt  Normal file
@@ -0,0 +1 @@
demo.mp4 6
186  tests/test_codebase/test_mmaction/data/model.py  Normal file
@@ -0,0 +1,186 @@
# Copyright (c) OpenMMLab. All rights reserved.
model = dict(
    type='Recognizer2D',
    backbone=dict(
        type='ResNet',
        pretrained='https://download.pytorch.org/models/resnet50-11ad3fa6.pth',
        depth=50,
        norm_eval=False),
    cls_head=dict(
        type='TSNHead',
        num_classes=400,
        in_channels=2048,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.4,
        init_std=0.01,
        average_clips=None),
    data_preprocessor=dict(
        type='ActionDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        format_shape='NCHW'),
    train_cfg=None,
    test_cfg=None)
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=100, val_begin=1, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
param_scheduler = [
    dict(
        type='MultiStepLR',
        begin=0,
        end=100,
        by_epoch=True,
        milestones=[40, 80],
        gamma=0.1)
]
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
    clip_grad=dict(max_norm=40, norm_type=2))
default_scope = 'mmaction'
default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=20, ignore_last=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', interval=3, save_best='auto', max_keep_ckpts=3),
    sampler_seed=dict(type='DistSamplerSeedHook'))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
log_processor = dict(type='LogProcessor', window_size=20, by_epoch=True)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='ActionVisualizer', vis_backends=[dict(type='LocalVisBackend')])
log_level = 'INFO'
load_from = None
resume = False
dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/video'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/ann.txt'
train_pipeline = [
    dict(type='DecordInit'),
    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(
        type='MultiScaleCrop',
        input_size=224,
        scales=(1, 0.875, 0.75, 0.66),
        random_crop=False,
        max_wh_scale_gap=1),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='PackActionInputs')
]
val_pipeline = [
    dict(type='DecordInit'),
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=3,
        test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='PackActionInputs')
]
test_pipeline = [
    dict(type='DecordInit'),
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=25,
        test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='TenCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='PackActionInputs')
]
train_dataloader = dict(
    batch_size=32,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='VideoDataset',
        ann_file='data/kinetics400/kinetics400_train_list_videos.txt',
        data_prefix=dict(video='data/kinetics400/videos_train'),
        pipeline=[
            dict(type='DecordInit'),
            dict(
                type='SampleFrames', clip_len=1, frame_interval=1,
                num_clips=3),
            dict(type='DecordDecode'),
            dict(type='Resize', scale=(-1, 256)),
            dict(
                type='MultiScaleCrop',
                input_size=224,
                scales=(1, 0.875, 0.75, 0.66),
                random_crop=False,
                max_wh_scale_gap=1),
            dict(type='Resize', scale=(224, 224), keep_ratio=False),
            dict(type='Flip', flip_ratio=0.5),
            dict(type='FormatShape', input_format='NCHW'),
            dict(type='PackActionInputs')
        ]))
val_dataloader = dict(
    batch_size=32,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='VideoDataset',
        ann_file='tests/test_codebase/test_mmaction/data/ann.txt',
        data_prefix=dict(video='tests/test_codebase/test_mmaction/data/video'),
        pipeline=[
            dict(type='DecordInit'),
            dict(
                type='SampleFrames',
                clip_len=1,
                frame_interval=1,
                num_clips=3,
                test_mode=True),
            dict(type='DecordDecode'),
            dict(type='Resize', scale=(-1, 256)),
            dict(type='CenterCrop', crop_size=224),
            dict(type='FormatShape', input_format='NCHW'),
            dict(type='PackActionInputs')
        ],
        test_mode=True))
test_dataloader = dict(
    batch_size=1,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='VideoDataset',
        ann_file='tests/test_codebase/test_mmaction/data/ann.txt',
        data_prefix=dict(video='tests/test_codebase/test_mmaction/data/video'),
        pipeline=[
            dict(type='DecordInit'),
            dict(
                type='SampleFrames',
                clip_len=1,
                frame_interval=1,
                num_clips=25,
                test_mode=True),
            dict(type='DecordDecode'),
            dict(type='Resize', scale=(-1, 256)),
            dict(type='TenCrop', crop_size=224),
            dict(type='FormatShape', input_format='NCHW'),
            dict(type='PackActionInputs')
        ],
        test_mode=True))
val_evaluator = dict(type='AccMetric')
test_evaluator = dict(type='AccMetric')
BIN  tests/test_codebase/test_mmaction/data/video/demo.mp4  Executable file
Binary file not shown.
52  tests/test_codebase/test_mmaction/test_mmaction_model.py  Normal file
@@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine import Config

from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase, load_config
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs

try:
    import_codebase(Codebase.MMACTION)
except ImportError:
    pytest.skip(
        f'{Codebase.MMACTION} is not installed.', allow_module_level=True)


@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
@pytest.mark.parametrize('model_cfg_path',
                         ['tests/test_codebase/test_mmaction/data/model.py'])
def test_forward_of_base_recognizer(model_cfg_path, backend):
    check_backend(backend)
    deploy_cfg = Config(
        dict(
            backend_config=dict(type='onnxruntime'),
            codebase_config=dict(type='mmaction', task='VideoRecognition'),
            onnx_config=dict(
                type='onnx',
                export_params=True,
                keep_initializers_as_inputs=False,
                opset_version=11,
                input_shape=None,
                input_names=['input'],
                output_names=['output'])))

    model_cfg = load_config(model_cfg_path)[0]
    from mmaction.apis import init_recognizer
    model = init_recognizer(model_cfg, None, device='cpu')

    img = torch.randn(1, 3, 3, 224, 224)
    from mmaction.structures import ActionDataSample
    data_sample = ActionDataSample()
    img_meta = dict(img_shape=(224, 224))
    data_sample.set_metainfo(img_meta)
    rewrite_inputs = {'inputs': img}
    wrapped_model = WrapModel(
        model, 'forward', data_samples=[data_sample], mode='predict')
    rewrite_outputs, _ = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

    assert rewrite_outputs is not None
106  tests/test_codebase/test_mmaction/test_video_recognition.py  Normal file
@@ -0,0 +1,106 @@
# Copyright (c) OpenMMLab. All rights reserved.
from tempfile import NamedTemporaryFile, TemporaryDirectory

import pytest
import torch
from mmengine import Config

import mmdeploy.backend.onnxruntime as ort_apis
from mmdeploy.apis import build_task_processor
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Codebase, load_config
from mmdeploy.utils.test import SwitchBackendWrapper

try:
    import_codebase(Codebase.MMACTION)
except ImportError:
    pytest.skip(
        f'{Codebase.MMACTION} is not installed.', allow_module_level=True)

model_cfg_path = 'tests/test_codebase/test_mmaction/data/model.py'
model_cfg = load_config(model_cfg_path)[0]
deploy_cfg = Config(
    dict(
        backend_config=dict(type='onnxruntime'),
        codebase_config=dict(type='mmaction', task='VideoRecognition'),
        onnx_config=dict(
            type='onnx',
            export_params=True,
            keep_initializers_as_inputs=False,
            opset_version=11,
            input_shape=None,
            input_names=['input'],
            output_names=['output'])))

onnx_file = NamedTemporaryFile(suffix='.onnx').name
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
img_shape = (224, 224)
num_classes = 400
video = 'tests/test_codebase/test_mmaction/data/video/demo.mp4'


@pytest.fixture
def backend_model():
    from mmdeploy.backend.onnxruntime import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
    wrapper = SwitchBackendWrapper(ORTWrapper)
    wrapper.set(outputs={
        'output': torch.rand(1, num_classes),
    })

    yield task_processor.build_backend_model([''])

    wrapper.recover()


def test_build_backend_model(backend_model):
    assert isinstance(backend_model, torch.nn.Module)


def test_create_input():
    inputs = task_processor.create_input(video, input_shape=img_shape)
    assert isinstance(inputs, tuple) and len(inputs) == 2


def test_build_pytorch_model():
    from mmaction.models.recognizers.base import BaseRecognizer
    model = task_processor.build_pytorch_model(None)
    assert isinstance(model, BaseRecognizer)


def test_get_tensor_from_input():
    input_data = {'inputs': torch.ones(3, 4, 5)}
    inputs = task_processor.get_tensor_from_input(input_data)
    assert torch.equal(inputs, torch.ones(3, 4, 5))


def test_get_model_name():
    model_name = task_processor.get_model_name()
    assert isinstance(model_name, str) and model_name is not None


def test_build_dataset_and_dataloader():
    from torch.utils.data import DataLoader, Dataset
    dataset = task_processor.build_dataset(
        dataset_cfg=model_cfg.test_dataloader.dataset)
    assert isinstance(dataset, Dataset), 'Failed to build dataset'
    dataloader_cfg = task_processor.model_cfg.test_dataloader
    dataloader = task_processor.build_dataloader(dataloader_cfg)
    assert isinstance(dataloader, DataLoader), 'Failed to build dataloader'


def test_build_test_runner(backend_model):
    from mmdeploy.codebase.base.runner import DeployTestRunner
    temp_dir = TemporaryDirectory().name
    runner = task_processor.build_test_runner(backend_model, temp_dir)
    assert isinstance(runner, DeployTestRunner)


def test_get_preprocess():
    process = task_processor.get_preprocess()
    assert process is not None


def test_get_postprocess():
    process = task_processor.get_postprocess()
    assert isinstance(process, dict)
@@ -0,0 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.

import pytest
import torch
from mmengine import Config

import mmdeploy.backend.onnxruntime as ort_apis
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase, load_config
from mmdeploy.utils.test import SwitchBackendWrapper, backend_checker

IMAGE_SIZE = 224

try:
    import_codebase(Codebase.MMACTION)
except ImportError:
    pytest.skip(
        f'{Codebase.MMACTION} is not installed.', allow_module_level=True)


@backend_checker(Backend.ONNXRUNTIME)
class TestEnd2EndModel:

    @classmethod
    def setup_class(cls):
        # force add backend wrapper regardless of plugins
        from mmdeploy.backend.onnxruntime import ORTWrapper
        ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

        # simplify backend inference
        cls.wrapper = SwitchBackendWrapper(ORTWrapper)
        cls.outputs = {
            'outputs': torch.rand(1, 400),
        }
        cls.wrapper.set(outputs=cls.outputs)
        deploy_cfg = Config({'onnx_config': {'output_names': ['outputs']}})
        model_cfg_path = 'tests/test_codebase/test_mmaction/data/model.py'
        model_cfg = load_config(model_cfg_path)[0]

        from mmdeploy.codebase.mmaction.deploy.video_recognition_model import \
            End2EndModel
        cls.end2end_model = End2EndModel(
            Backend.ONNXRUNTIME, [''],
            device='cpu',
            deploy_cfg=deploy_cfg,
            model_cfg=model_cfg)

    @classmethod
    def teardown_class(cls):
        cls.wrapper.recover()

    def test_forward(self):
        inputs = torch.rand(1, 3, 3, IMAGE_SIZE, IMAGE_SIZE)
        from mmaction.structures import ActionDataSample
        data_sample = ActionDataSample(
            metainfo=dict(img_shape=(IMAGE_SIZE, IMAGE_SIZE)))
        results = self.end2end_model.forward(
            inputs, [data_sample], mode='predict')
        assert results is not None, 'failed to get output using '\
            'End2EndModel'


@backend_checker(Backend.ONNXRUNTIME)
def test_build_video_recognition_model():
    model_cfg_path = 'tests/test_codebase/test_mmaction/data/model.py'
    model_cfg = load_config(model_cfg_path)[0]
    deploy_cfg = Config(
        dict(
            backend_config=dict(type='onnxruntime'),
            onnx_config=dict(output_names=['outputs']),
            codebase_config=dict(type='mmaction')))

    from mmdeploy.backend.onnxruntime import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

    # simplify backend inference
    with SwitchBackendWrapper(ORTWrapper) as wrapper:
        wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
        from mmdeploy.codebase.mmaction.deploy.video_recognition_model import (
            End2EndModel, build_video_recognition_model)
        classifier = build_video_recognition_model([''], model_cfg, deploy_cfg,
                                                   'cpu')
        assert isinstance(classifier, End2EndModel)