Support mmaction2:dev-1.x (#1012)

* support tsn

* support slowfast

* fix export info & End2EndModel

* add test

* fix forward

* fix lint

* update tests

* add onnxruntime 2d config

* fix ort-gpu

* add mmaction.yml, need to update

* fix reviews

* add ann.txt

* add visualize

* fix lint

* rebase

* add conftest.py

* fix circle ci

* fix registry

* fix regression test
Chen Xin 2022-10-19 15:42:57 +08:00 committed by GitHub
parent 72923e7844
commit 2020e74480
23 changed files with 1060 additions and 0 deletions
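
For orientation, here is a minimal sketch (not part of this commit) of how the pieces added in this PR fit together at inference time: the `VideoRecognition` task processor, its OpenCV-based preprocessing, and the `End2EndModel` backend wrapper. It only uses APIs that appear in this diff (`build_task_processor`, `load_config`, `build_data_preprocessor`, `create_input`, `build_backend_model`); the config paths are taken from mmaction.yml further down, the model config path is relative to an mmaction2 checkout, and `end2end.onnx` is assumed to be the output of an earlier model conversion rather than a file added here.

# Hedged usage sketch; all paths below are placeholders.
from mmdeploy.apis import build_task_processor
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Codebase, load_config

import_codebase(Codebase.MMACTION)  # register the new mmaction codebase

deploy_cfg, model_cfg = load_config(
    'configs/mmaction/video-recognition/video-recognition_onnxruntime_static.py',
    '../mmaction2/configs/recognition/tsn/'
    'tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py')
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')

# Preprocess one video with the OpenCV pipeline patched in by
# process_model_config(); the data preprocessor batches and normalizes frames.
data_preprocessor = task_processor.build_data_preprocessor()
data, inputs = task_processor.create_input(
    '../mmaction2/demo/demo.mp4',
    input_shape=None,
    data_preprocessor=data_preprocessor)

# Wrap the exported ONNX file as an nn.Module and run prediction; the
# End2EndModel attaches class scores to each returned data sample.
backend_model = task_processor.build_backend_model(['end2end.onnx'])
results = backend_model(inputs, data_samples=data['data_samples'],
                        mode='predict')
print(results[0].pred_scores)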


@@ -0,0 +1,15 @@
_base_ = ['./video-recognition_static.py']
onnx_config = dict(
dynamic_axes={
'input': {
0: 'batch',
1: 'num_crops * num_segs',
3: 'height',
4: 'width'
},
'output': {
0: 'batch',
}
},
input_shape=None)


@@ -0,0 +1,14 @@
_base_ = ['./video-recognition_static.py', '../../_base_/backends/tensorrt.py']
onnx_config = dict(input_shape=[224, 224])
backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 250, 3, 224, 224],
opt_shape=[1, 250, 3, 224, 224],
max_shape=[1, 250, 3, 224, 224])))
])


@@ -0,0 +1,16 @@
_base_ = ['./video-recognition_static.py']
onnx_config = dict(
dynamic_axes={
'input': {
0: 'batch',
1: 'num_crops * num_segs',
3: 'time',
4: 'height',
5: 'width'
},
'output': {
0: 'batch',
}
},
input_shape=None)


@@ -0,0 +1,14 @@
_base_ = ['./video-recognition_static.py', '../../_base_/backends/tensorrt.py']
onnx_config = dict(input_shape=[256, 256])
backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 30, 3, 32, 256, 256],
opt_shape=[1, 30, 3, 32, 256, 256],
max_shape=[1, 30, 3, 32, 256, 256])))
])


@@ -0,0 +1,5 @@
_base_ = [
'./video-recognition_static.py', '../../_base_/backends/onnxruntime.py'
]
onnx_config = dict(input_shape=None)
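
Although this ONNX Runtime config exports with `input_shape=None`, the 2D recognizer still consumes a 5-D tensor laid out as (batch, num_crops * num_segs, C, H, W), as declared by the dynamic-axes config and the TensorRT shapes earlier in this diff. Below is a hedged sanity check of an exported model with plain onnxruntime; the file name `end2end.onnx` and the 250 = 25 segments x 10 crops clip count (TSN's TenCrop test pipeline) are assumptions, not values fixed by this PR.

import numpy as np
import onnxruntime as ort

# Feed a dummy tensor in (batch, num_crops * num_segs, C, H, W) layout and
# check that the recognizer returns one score vector per video.
sess = ort.InferenceSession('end2end.onnx')
dummy = np.random.rand(1, 250, 3, 224, 224).astype(np.float32)
scores = sess.run(None, {'input': dummy})[0]
print(scores.shape)  # expected (1, num_classes), e.g. (1, 400) for Kinetics-400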


@@ -0,0 +1,3 @@
_base_ = ['../../_base_/onnx_config.py']
codebase_config = dict(type='mmaction', task='VideoRecognition')


@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .deploy import * # noqa: F401,F403
from .models import * # noqa: F401,F403


@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mmaction import MMACTION
from .video_recognition import VideoRecognition
__all__ = ['MMACTION', 'VideoRecognition']


@@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import Registry
from mmdeploy.codebase.base import CODEBASE, MMCodebase
from mmdeploy.utils import Codebase
MMACTION_TASK = Registry('mmaction_tasks')
@CODEBASE.register_module(Codebase.MMACTION.value)
class MMACTION(MMCodebase):
"""MMAction codebase class."""
task_registry = MMACTION_TASK
@classmethod
def register_all_modules(cls):
from mmaction.utils.setup_env import register_all_modules
register_all_modules(True)


@@ -0,0 +1,273 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from operator import itemgetter
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import mmengine
import numpy as np
import torch
from mmengine.dataset import pseudo_collate
from mmengine.model import BaseDataPreprocessor
from mmdeploy.codebase.base import BaseTask
from mmdeploy.utils import Task, get_root_logger
from mmdeploy.utils.config_utils import get_input_shape
from .mmaction import MMACTION_TASK
def process_model_config(model_cfg: mmengine.Config,
imgs: Union[Sequence[str], Sequence[np.ndarray]],
input_shape: Optional[Sequence[int]] = None):
"""Process the model config.
Args:
model_cfg (mmengine.Config): The model config.
        imgs (Sequence[str] | Sequence[np.ndarray]): Input image(s); accepted
            data types are List[str] and List[np.ndarray].
        input_shape (list[int]): A list of two integers in (width, height)
            format specifying the input shape. Default: None.
Returns:
mmengine.Config: the model config after processing.
"""
logger = get_root_logger()
cfg = model_cfg.deepcopy()
test_pipeline_cfg = cfg.test_pipeline
if 'Init' not in test_pipeline_cfg[0]['type']:
test_pipeline_cfg = [dict(type='OpenCVInit')] + test_pipeline_cfg
else:
test_pipeline_cfg[0] = dict(type='OpenCVInit')
for i, trans in enumerate(test_pipeline_cfg):
if 'Decode' in trans['type']:
test_pipeline_cfg[i] = dict(type='OpenCVDecode')
cfg.test_pipeline = test_pipeline_cfg
# check whether input_shape is valid
if input_shape is not None:
has_crop = False
crop_size = -1
has_resize = False
scale = (-1, -1)
keep_ratio = True
for trans in cfg.test_pipeline:
if trans['type'] == 'Resize':
has_resize = True
keep_ratio = trans.get('keep_ratio', True)
scale = trans.scale
if trans['type'] in ['TenCrop', 'CenterCrop', 'ThreeCrop']:
has_crop = True
crop_size = trans.crop_size
if has_crop and tuple(input_shape) != (crop_size, crop_size):
logger.error(
f'`input shape` should be equal to `crop_size`: {crop_size},'
f' but given: {input_shape}')
if has_resize and (not has_crop):
if keep_ratio:
logger.error('Resize should set `keep_ratio` to False'
' when `input shape` is given.')
if tuple(input_shape) != scale:
logger.error(
f'`input shape` should be equal to `scale`: {scale},'
f' but given: {input_shape}')
return cfg
@MMACTION_TASK.register_module(Task.VIDEO_RECOGNITION.value)
class VideoRecognition(BaseTask):
"""VideoRecognition task class.
Args:
model_cfg (Config): Original PyTorch model config file.
deploy_cfg (Config): Deployment config file or loaded Config
object.
        device (str): A string representing the device type.
"""
def __init__(self, model_cfg: mmengine.Config, deploy_cfg: mmengine.Config,
device: str):
super(VideoRecognition, self).__init__(model_cfg, deploy_cfg, device)
def build_data_preprocessor(self):
model_cfg = self.model_cfg
preprocess_cfg = model_cfg.get('preprocess_cfg', None)
from mmengine.registry import MODELS
if preprocess_cfg is not None:
data_preprocessor = MODELS.build(preprocess_cfg)
else:
data_preprocessor = BaseDataPreprocessor()
return data_preprocessor
def build_backend_model(self,
model_files: Sequence[str] = None,
**kwargs) -> torch.nn.Module:
"""Initialize backend model.
Args:
model_files (Sequence[str]): Input model files.
Returns:
nn.Module: An initialized backend model.
"""
from .video_recognition_model import build_video_recognition_model
model = build_video_recognition_model(
model_files, self.model_cfg, self.deploy_cfg, device=self.device)
model.to(self.device)
model.eval()
return model
def create_input(self,
imgs: Union[str, np.ndarray],
input_shape: Sequence[int] = None,
data_preprocessor: Optional[BaseDataPreprocessor] = None)\
-> Tuple[Dict, torch.Tensor]:
"""Create input for video recognition.
Args:
            imgs (str | np.ndarray): Input image(s); accepted data types are
                `str` and `np.ndarray`.
            input_shape (list[int]): A list of two integers in (width, height)
                format specifying input shape. Defaults to `None`.
        Returns:
            tuple: (data, inputs), the collated data dict and the input tensor.
"""
if isinstance(imgs, (list, tuple)):
if not all(isinstance(img, str) for img in imgs):
raise AssertionError('imgs must be strings')
elif isinstance(imgs, str):
imgs = [imgs]
else:
raise AssertionError('imgs must be strings')
from mmcv.transforms.wrappers import Compose
model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
test_pipeline = Compose(model_cfg.test_pipeline)
data = []
for img in imgs:
data_ = dict(filename=img, label=-1, start_index=0, modality='RGB')
data_ = test_pipeline(data_)
data.append(data_)
data = pseudo_collate(data)
if data_preprocessor is not None:
data = data_preprocessor(data, False)
return data, data['inputs']
else:
return data, BaseTask.get_tensor_from_input(data)
def visualize(self,
image: str,
result: list,
output_file: str,
window_name: str = '',
show_result: bool = False,
**kwargs):
"""Visualize predictions of a model.
Args:
image (str): Input video to draw predictions on.
result (list): A list of predictions.
output_file (str): Output file to save drawn image.
window_name (str): The name of visualization window. Defaults to
an empty string.
show_result (bool): Whether to show result in windows, defaults
to `False`.
"""
logger = get_root_logger()
try:
import decord
from moviepy.editor import ImageSequenceClip
except Exception:
            logger.warning('Please install moviepy and decord to '
                           'enable visualization for mmaction')
save_dir, save_name = osp.split(output_file)
video = decord.VideoReader(image)
frames = [x.asnumpy()[..., ::-1] for x in video]
pred_scores = result.pred_scores.item.tolist()
score_tuples = tuple(zip(range(len(pred_scores)), pred_scores))
score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)
top1_item = score_sorted[0]
short_edge_length = min(frames[0].shape[:2])
scale = short_edge_length // 224.
img_scale = min(max(scale, 0.3), 3.0)
text_cfg = {
'positions': np.array([(img_scale * 5, ) * 2]).astype(np.int32),
'font_sizes': int(img_scale * 7),
'font_families': 'monospace',
'colors': 'white',
'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round')
}
visualizer = self.get_visualizer(window_name, save_dir)
out_frames = []
for i, frame in enumerate(frames):
visualizer.set_image(frame)
texts = [f'Frame {i} of total {len(frames)} frames']
            texts.append(f'top-1 label: {top1_item[0]}, score: {top1_item[1]}')
visualizer.draw_texts('\n'.join(texts), **text_cfg)
drawn_img = visualizer.get_image()
out_frames.append(drawn_img)
out_frames = [x[..., ::-1] for x in out_frames]
video_clips = ImageSequenceClip(out_frames, fps=30)
output_file = output_file[:output_file.rfind('.')] + '.mp4'
video_clips.write_videofile(output_file)
@staticmethod
def get_partition_cfg(partition_type: str) -> Dict:
"""Get a certain partition config.
Args:
partition_type (str): A string specifying partition type.
Returns:
dict: A dictionary of partition config.
"""
raise NotImplementedError('Not supported yet.')
@staticmethod
def get_tensor_from_input(input_data: Dict[str, Any],
**kwargs) -> torch.Tensor:
"""Get input tensor from input data.
Args:
input_data (dict): Input data containing meta info and image
tensor.
Returns:
torch.Tensor: An image in `Tensor`.
"""
return input_data['inputs']
def get_preprocess(self) -> Dict:
"""Get the preprocess information for SDK.
Return:
dict: Composed of the preprocess information.
"""
input_shape = get_input_shape(self.deploy_cfg)
model_cfg = process_model_config(self.model_cfg, [''], input_shape)
preprocess = model_cfg.test_pipeline
return preprocess
def get_postprocess(self) -> Dict:
"""Get the postprocess information for SDK.
Return:
dict: Composed of the postprocess information.
"""
postprocess = self.model_cfg.model.cls_head
return postprocess
def get_model_name(self) -> str:
"""Get the model name.
Return:
str: the name of the model.
"""
assert 'type' in self.model_cfg.model, 'model config contains no type'
name = self.model_cfg.model.type.lower()
return name
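
Continuing the inference sketch from the top of this diff, the new `visualize` method writes an annotated `.mp4` next to the requested output path (it needs `decord` and `moviepy`, as the import guard above notes). A hedged usage example; `task_processor` and `results` come from the earlier sketch and the video and output paths are placeholders.

# Draw the predicted top-1 label onto each frame of the demo video and save
# the result; visualize() rewrites the output extension to .mp4 itself.
task_processor.visualize(
    image='../mmaction2/demo/demo.mp4',
    result=results[0],
    output_file='work_dir/output.mp4')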


@@ -0,0 +1,150 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Optional, Sequence, Union
import mmengine
import torch
from mmaction.utils import LabelList
from mmengine import Config
from mmengine.model import BaseDataPreprocessor
from mmengine.registry import Registry
from mmengine.structures import BaseDataElement, LabelData
from mmdeploy.codebase.base import BaseBackendModel
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
get_root_logger, load_config)
__BACKEND_MODEL = Registry('backend_video_recognizer')
@__BACKEND_MODEL.register_module('end2end')
class End2EndModel(BaseBackendModel):
"""End to end model for inference of video recognition.
Args:
backend (Backend): The backend enum, specifying backend type.
backend_files (Sequence[str]): Paths to all required backend files(e.g.
'.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
        device (str): A string representing the device type.
deploy_cfg (str | mmengine.Config): Deployment config file or loaded
Config object.
model_cfg (str | mmengine.Config): Model config file or loaded Config
object.
"""
def __init__(self,
backend: Backend,
backend_files: Sequence[str],
device: str,
deploy_cfg: Union[str, Config] = None,
model_cfg: Union[str, Config] = None,
**kwargs):
super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg)
model_cfg, deploy_cfg = load_config(model_cfg, deploy_cfg)
from mmaction.registry import MODELS
preprocessor_cfg = model_cfg.model.get('data_preprocessor', None)
if preprocessor_cfg is not None:
self.data_preprocessor = MODELS.build(
model_cfg.model.data_preprocessor)
else:
self.data_preprocessor = BaseDataPreprocessor()
self.deploy_cfg = deploy_cfg
self.model_cfg = model_cfg
self._init_wrapper(
backend=backend,
backend_files=backend_files,
device=device,
**kwargs)
self.device = device
def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
device: str, **kwargs):
"""Initialize backend wrapper.
Args:
backend (Backend): The backend enum, specifying backend type.
backend_files (Sequence[str]): Paths to all required backend files
(e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
device (str): A string specifying device type.
"""
output_names = self.output_names
self.wrapper = BaseBackendModel._build_wrapper(
backend=backend,
backend_files=backend_files,
device=device,
input_names=[self.input_name],
output_names=output_names,
deploy_cfg=self.deploy_cfg,
**kwargs)
def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'predict') -> Any:
"""Run forward inference.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
            data_samples (List[``ActionDataSample``], optional): The
                annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to ``predict``.
        Returns:
            list: A list of predictions.
        """
        assert mode == 'predict', \
            'Backend model only supports mode==predict,' f' but got {mode}'
if inputs.device != torch.device(self.device):
            get_root_logger().warning(f'expected input device {self.device},'
                                      f' but got {inputs.device}.')
inputs = inputs.to(self.device)
cls_scores = self.wrapper({self.input_name:
inputs})[self.output_names[0]]
predictions: LabelList = []
for score in cls_scores:
label = LabelData(item=score)
predictions.append(label)
for data_sample, pred_instances in zip(data_samples, predictions):
data_sample.pred_scores = pred_instances
return data_samples
def build_video_recognition_model(model_files: Sequence[str],
model_cfg: Union[str, mmengine.Config],
deploy_cfg: Union[str, mmengine.Config],
device: str, **kwargs):
"""Build video recognition model for different backends.
Args:
model_files (Sequence[str]): Input model file(s).
model_cfg (str | mmengine.Config): Input model config file or Config
object.
deploy_cfg (str | mmengine.Config): Input deployment config file or
Config object.
        device (str): Device on which to load the model.
Returns:
BaseBackendModel: Video recognizer for a configured backend.
"""
# load cfg if necessary
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
backend = get_backend(deploy_cfg)
model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end')
backend_video_recognizer = __BACKEND_MODEL.build(
dict(
type=model_type,
backend=backend,
backend_files=model_files,
device=device,
deploy_cfg=deploy_cfg,
model_cfg=model_cfg,
**kwargs))
return backend_video_recognizer


@@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .recognizers import * # noqa: F401,F403


@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import base_recognizer__forward
__all__ = ['base_recognizer__forward']


@@ -0,0 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmaction.utils import OptSampleList
from torch import Tensor
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmaction.models.recognizers.BaseRecognizer.forward')
def base_recognizer__forward(ctx,
self,
inputs: Tensor,
data_samples: OptSampleList = None,
mode: str = 'tensor',
**kwargs):
"""Rewrite `forward` of Recognizer2D for default backend.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[``ActionDataSample``], optional): The
annotation data of every samples. Defaults to None.
mode (str): Return what kind of value. Defaults to ``tensor``.
Returns:
return a list of `ActionDataSample`
"""
assert mode == 'predict'
feats, predict_kwargs = self.extract_feat(inputs, test_mode=True)
cls_scores = self.cls_head(feats, **predict_kwargs)
num_segs = cls_scores.shape[0] // len(data_samples)
cls_scores = self.cls_head.average_clip(cls_scores, num_segs=num_segs)
return cls_scores


@@ -27,6 +27,7 @@ class Task(AdvancedEnum):
VOXEL_DETECTION = 'VoxelDetection'
POSE_DETECTION = 'PoseDetection'
ROTATED_DETECTION = 'RotatedDetection'
VIDEO_RECOGNITION = 'VideoRecognition'
class Codebase(AdvancedEnum):
@@ -39,6 +40,7 @@ class Codebase(AdvancedEnum):
MMDET3D = 'mmdet3d'
MMPOSE = 'mmpose'
MMROTATE = 'mmrotate'
MMACTION = 'mmaction'
class IR(AdvancedEnum):


@@ -0,0 +1,56 @@
globals:
codebase_dir: ../mmaction2
checkpoint_force_download: False
images:
video: &video ../mmaction2/demo/demo.mp4
metric_info: &metric_info
Top 1 Accuracy:
metric_key: acc/top1
tolerance: 1
multi_value: 100
dataset: Kinetics-400
Top 5 Accuracy:
metric_key: acc/top5
tolerance: 1
multi_value: 100
dataset: Kinetics-400
convert_image: &convert_image
input_img: *video
test_img: *video
backend_test: &default_backend_test True
sdk:
sdk_dynamic: &sdk_dynamic ""
onnxruntime:
pipeline_ort_static_fp32: &pipeline_ort_static_fp32
convert_image: *convert_image
deploy_config: configs/mmaction/video-recognition/video-recognition_onnxruntime_static.py
backend_test: *default_backend_test
tensorrt:
pipeline_trt_2d_static_fp32: &pipeline_trt_2d_static_fp32
convert_image: *convert_image
deploy_config: configs/mmaction/video-recognition/video-recognition_2d_tensorrt_static-224x224.py
backend_test: *default_backend_test
pipeline_trt_3d_static_fp32: &pipeline_trt_3d_static_fp32
convert_image: *convert_image
deploy_config: configs/mmaction/video-recognition/video-recognition_3d_tensorrt_static-256x256.py
backend_test: *default_backend_test
models:
- name: TSN
metafile: configs/recognition/tsn/metafile.yml
model_configs:
- configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py
pipelines:
- *pipeline_ort_static_fp32
- *pipeline_trt_2d_static_fp32
- name: SlowFast
metafile: configs/recognition/slowfast/metafile.yml
model_configs:
- configs/recognition/slowfast/slowfast_r50_8xb8-4x16x1-256e_kinetics400-rgb.py
pipelines:
- *pipeline_ort_static_fp32
- *pipeline_trt_3d_static_fp32


@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
@pytest.fixture(autouse=True)
def init_test():
# init default scope
from mmaction.utils import register_all_modules
register_all_modules(True)


@@ -0,0 +1 @@
demo.mp4 6


@@ -0,0 +1,186 @@
# Copyright (c) OpenMMLab. All rights reserved.
model = dict(
type='Recognizer2D',
backbone=dict(
type='ResNet',
pretrained='https://download.pytorch.org/models/resnet50-11ad3fa6.pth',
depth=50,
norm_eval=False),
cls_head=dict(
type='TSNHead',
num_classes=400,
in_channels=2048,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.4,
init_std=0.01,
average_clips=None),
data_preprocessor=dict(
type='ActionDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
format_shape='NCHW'),
train_cfg=None,
test_cfg=None)
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=100, val_begin=1, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=100,
by_epoch=True,
milestones=[40, 80],
gamma=0.1)
]
optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
clip_grad=dict(max_norm=40, norm_type=2))
default_scope = 'mmaction'
default_hooks = dict(
runtime_info=dict(type='RuntimeInfoHook'),
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=20, ignore_last=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', interval=3, save_best='auto', max_keep_ckpts=3),
sampler_seed=dict(type='DistSamplerSeedHook'))
env_cfg = dict(
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'))
log_processor = dict(type='LogProcessor', window_size=20, by_epoch=True)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='ActionVisualizer', vis_backends=[dict(type='LocalVisBackend')])
log_level = 'INFO'
load_from = None
resume = False
dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/video'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/ann.txt'
train_pipeline = [
dict(type='DecordInit'),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
]
val_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=3,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
]
test_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=25,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='TenCrop', crop_size=224),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
]
train_dataloader = dict(
batch_size=32,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='VideoDataset',
ann_file='data/kinetics400/kinetics400_train_list_videos.txt',
data_prefix=dict(video='data/kinetics400/videos_train'),
pipeline=[
dict(type='DecordInit'),
dict(
type='SampleFrames', clip_len=1, frame_interval=1,
num_clips=3),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
]))
val_dataloader = dict(
batch_size=32,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='VideoDataset',
ann_file='tests/test_codebase/test_mmaction/data/ann.txt',
data_prefix=dict(video='tests/test_codebase/test_mmaction/data/video'),
pipeline=[
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=3,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
],
test_mode=True))
test_dataloader = dict(
batch_size=1,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='VideoDataset',
ann_file='tests/test_codebase/test_mmaction/data/ann.txt',
data_prefix=dict(video='tests/test_codebase/test_mmaction/data/video'),
pipeline=[
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=25,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='TenCrop', crop_size=224),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
],
test_mode=True))
val_evaluator = dict(type='AccMetric')
test_evaluator = dict(type='AccMetric')

Binary file not shown.


@@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine import Config
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase, load_config
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
try:
import_codebase(Codebase.MMACTION)
except ImportError:
pytest.skip(
f'{Codebase.MMACTION} is not installed.', allow_module_level=True)
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
@pytest.mark.parametrize('model_cfg_path',
['tests/test_codebase/test_mmaction/data/model.py'])
def test_forward_of_base_recognizer(model_cfg_path, backend):
check_backend(backend)
deploy_cfg = Config(
dict(
backend_config=dict(type='onnxruntime'),
codebase_config=dict(type='mmaction', task='VideoRecognition'),
onnx_config=dict(
type='onnx',
export_params=True,
keep_initializers_as_inputs=False,
opset_version=11,
input_shape=None,
input_names=['input'],
output_names=['output'])))
model_cfg = load_config(model_cfg_path)[0]
from mmaction.apis import init_recognizer
model = init_recognizer(model_cfg, None, device='cpu')
img = torch.randn(1, 3, 3, 224, 224)
from mmaction.structures import ActionDataSample
data_sample = ActionDataSample()
img_meta = dict(img_shape=(224, 224))
data_sample.set_metainfo(img_meta)
rewrite_inputs = {'inputs': img}
wrapped_model = WrapModel(
model, 'forward', data_samples=[data_sample], mode='predict')
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
assert rewrite_outputs is not None


@@ -0,0 +1,106 @@
# Copyright (c) OpenMMLab. All rights reserved.
from tempfile import NamedTemporaryFile, TemporaryDirectory
import pytest
import torch
from mmengine import Config
import mmdeploy.backend.onnxruntime as ort_apis
from mmdeploy.apis import build_task_processor
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Codebase, load_config
from mmdeploy.utils.test import SwitchBackendWrapper
try:
import_codebase(Codebase.MMACTION)
except ImportError:
pytest.skip(
f'{Codebase.MMACTION} is not installed.', allow_module_level=True)
model_cfg_path = 'tests/test_codebase/test_mmaction/data/model.py'
model_cfg = load_config(model_cfg_path)[0]
deploy_cfg = Config(
dict(
backend_config=dict(type='onnxruntime'),
codebase_config=dict(type='mmaction', task='VideoRecognition'),
onnx_config=dict(
type='onnx',
export_params=True,
keep_initializers_as_inputs=False,
opset_version=11,
input_shape=None,
input_names=['input'],
output_names=['output'])))
onnx_file = NamedTemporaryFile(suffix='.onnx').name
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
img_shape = (224, 224)
num_classes = 400
video = 'tests/test_codebase/test_mmaction/data/video/demo.mp4'
@pytest.fixture
def backend_model():
from mmdeploy.backend.onnxruntime import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
wrapper = SwitchBackendWrapper(ORTWrapper)
wrapper.set(outputs={
'output': torch.rand(1, num_classes),
})
yield task_processor.build_backend_model([''])
wrapper.recover()
def test_build_backend_model(backend_model):
assert isinstance(backend_model, torch.nn.Module)
def test_create_input():
inputs = task_processor.create_input(video, input_shape=img_shape)
assert isinstance(inputs, tuple) and len(inputs) == 2
def test_build_pytorch_model():
from mmaction.models.recognizers.base import BaseRecognizer
model = task_processor.build_pytorch_model(None)
assert isinstance(model, BaseRecognizer)
def test_get_tensor_from_input():
input_data = {'inputs': torch.ones(3, 4, 5)}
inputs = task_processor.get_tensor_from_input(input_data)
assert torch.equal(inputs, torch.ones(3, 4, 5))
def test_get_model_name():
model_name = task_processor.get_model_name()
assert isinstance(model_name, str) and model_name is not None
def test_build_dataset_and_dataloader():
from torch.utils.data import DataLoader, Dataset
dataset = task_processor.build_dataset(
dataset_cfg=model_cfg.test_dataloader.dataset)
assert isinstance(dataset, Dataset), 'Failed to build dataset'
dataloader_cfg = task_processor.model_cfg.test_dataloader
dataloader = task_processor.build_dataloader(dataloader_cfg)
assert isinstance(dataloader, DataLoader), 'Failed to build dataloader'
def test_build_test_runner(backend_model):
from mmdeploy.codebase.base.runner import DeployTestRunner
temp_dir = TemporaryDirectory().name
runner = task_processor.build_test_runner(backend_model, temp_dir)
assert isinstance(runner, DeployTestRunner)
def test_get_preprocess():
process = task_processor.get_preprocess()
assert process is not None
def test_get_postprocess():
process = task_processor.get_postprocess()
assert isinstance(process, dict)


@@ -0,0 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine import Config
import mmdeploy.backend.onnxruntime as ort_apis
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase, load_config
from mmdeploy.utils.test import SwitchBackendWrapper, backend_checker
IMAGE_SIZE = 224
try:
import_codebase(Codebase.MMACTION)
except ImportError:
pytest.skip(
f'{Codebase.MMACTION} is not installed.', allow_module_level=True)
@backend_checker(Backend.ONNXRUNTIME)
class TestEnd2EndModel:
@classmethod
def setup_class(cls):
# force add backend wrapper regardless of plugins
from mmdeploy.backend.onnxruntime import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
cls.wrapper = SwitchBackendWrapper(ORTWrapper)
cls.outputs = {
'outputs': torch.rand(1, 400),
}
cls.wrapper.set(outputs=cls.outputs)
deploy_cfg = Config({'onnx_config': {'output_names': ['outputs']}})
model_cfg_path = 'tests/test_codebase/test_mmaction/data/model.py'
model_cfg = load_config(model_cfg_path)[0]
from mmdeploy.codebase.mmaction.deploy.video_recognition_model import \
End2EndModel
cls.end2end_model = End2EndModel(
Backend.ONNXRUNTIME, [''],
device='cpu',
deploy_cfg=deploy_cfg,
model_cfg=model_cfg)
@classmethod
def teardown_class(cls):
cls.wrapper.recover()
def test_forward(self):
inputs = torch.rand(1, 3, 3, IMAGE_SIZE, IMAGE_SIZE)
from mmaction.structures import ActionDataSample
data_sample = ActionDataSample(
metainfo=dict(img_shape=(IMAGE_SIZE, IMAGE_SIZE)))
results = self.end2end_model.forward(
inputs, [data_sample], mode='predict')
assert results is not None, 'failed to get output using '\
'End2EndModel'
@backend_checker(Backend.ONNXRUNTIME)
def test_build_video_recognition_model():
model_cfg_path = 'tests/test_codebase/test_mmaction/data/model.py'
model_cfg = load_config(model_cfg_path)[0]
deploy_cfg = Config(
dict(
backend_config=dict(type='onnxruntime'),
onnx_config=dict(output_names=['outputs']),
codebase_config=dict(type='mmaction')))
from mmdeploy.backend.onnxruntime import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
with SwitchBackendWrapper(ORTWrapper) as wrapper:
wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
from mmdeploy.codebase.mmaction.deploy.video_recognition_model import (
End2EndModel, build_video_recognition_model)
classifier = build_video_recognition_model([''], model_cfg, deploy_cfg,
'cpu')
assert isinstance(classifier, End2EndModel)