[Feature] Add mmocr ncnn support (#53)

* first

* fix0

* fix1

* dirty work

* wip

* add allocator

* finally done!

* lint

* fix lint

* better gather

* better onnx2ncnn

* fix expand

* [Fix] NCNN TensorSlice op bugs (#42)

* fix custom ops support, fix multiple mark bug, add name mapping

* check if the value_info need to be added

* remove unnecessary print

* add nms implement

* two stage split wip

* add two stage split

* add split retinanet visualize

* add two stage split (wip)

* finish two stage split

* fix lint

* move parse string to mmdeploy.utils

* add calib data generator

* create calib dataset

* finish end2end int8

* add split two stage tensorrt visualize

* fix tensorslice bugs

* fix lint

* fix clang-format

* remove comments

* int param

* fix lint

Co-authored-by: grimoire <yaoqian@sensetime.com>

* add two stage ncnn support

* remove unused ops

* git unused config

* remove no_grad, should add in refactor

* add ncnn wrapper

* fix lint

* size return tuple

* Resolve grammar error

* Fix lint

* Trim Trailing Whitespace

* fix trim

* update wrapper

* remove logs

* remove

* csrc optimize

* add ncnn dbnet support

* finish crnn support

* add comment

Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
pull/12/head
q.yao 2021-09-03 15:16:20 +08:00 committed by GitHub
parent 2b98040b26
commit e73d9fb50b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 169 additions and 4 deletions

View File

@ -1917,6 +1917,7 @@ static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
// or LSTM(bi) <= LSTM(bi) - Transpose Constant - Reshape - Transpose
if (node->op_type() == "LSTM" || node->op_type() == "GRU" ||
node->op_type() == "RNN") {
if (node_reference[node->output(0)] != 1) continue;
@ -1926,6 +1927,13 @@ static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
// skip over a Constant node that sits between the Transpose and the Reshape
if (node3->op_type() == "Constant") {
if (i + 3 >= node_count) continue;
node3 = mutable_graph->mutable_node(i + 3);
i += 1;
}
if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
continue;

View File

@ -0,0 +1,6 @@
# Inherit settings from the shared torch2onnx base config.
_base_ = ['../_base_/torch2onnx.py']
codebase = 'mmocr'
# 'TextDetection' or 'TextRecognition'
task = 'TextDetection'
# The exported model has one input named 'input' and one output named
# 'output'.
pytorch2onnx = dict(input_names=['input'], output_names=['output'])

View File

@ -0,0 +1 @@
_base_ = ['./base_static.py', '../_base_/backends/ncnn.py']

View File

@ -240,6 +240,48 @@ class TensorRTRecognizer(DeployBaseRecognizer):
return trt_pred
class NCNNDetector(DeployBaseTextDetector):
    """Text detector that runs inference through an ``NCNNWrapper``.

    Args:
        model_file (Iterable[str]): Indexed container with at least two
            entries; ``model_file[0]`` and ``model_file[1]`` are forwarded,
            in that order, as the first two positional arguments of
            ``NCNNWrapper``.
        cfg (mmcv.Config | mmcv.ConfigDict): Forwarded to the base class.
        device_id (int): Forwarded to the base class.
        show_score (bool): Forwarded to the base class. Defaults to False.
    """

    def __init__(self,
                 model_file: Iterable[str],
                 cfg: Union[mmcv.Config, mmcv.ConfigDict],
                 device_id: int,
                 show_score: bool = False):
        super().__init__(cfg, device_id, show_score)
        # Local import; presumably keeps NCNN optional until this backend
        # is actually instantiated -- TODO confirm.
        from mmdeploy.apis.ncnn import NCNNWrapper
        first_file, second_file = model_file[0], model_file[1]
        self.model = NCNNWrapper(
            first_file, second_file, output_names=['output'])

    def forward_of_backend(self,
                           img: torch.Tensor,
                           img_metas: Iterable,
                           rescale: bool = False):
        """Feed ``img`` to the wrapped model and return its 'output' entry.

        Args:
            img (torch.Tensor): Passed to the model under the key 'input'.
            img_metas (Iterable): Not used by this method.
            rescale (bool): Not used by this method.

        Returns:
            The value stored under 'output' in the wrapper's result.
        """
        outputs = self.model({'input': img})
        return outputs['output']
class NCNNRecognizer(DeployBaseRecognizer):
    """Text recognizer that runs inference through an ``NCNNWrapper``.

    Args:
        model_file (Iterable[str]): Indexed container with at least two
            entries; ``model_file[0]`` and ``model_file[1]`` are forwarded,
            in that order, as the first two positional arguments of
            ``NCNNWrapper``.
        cfg (mmcv.Config | mmcv.ConfigDict): Forwarded to the base class.
        device_id (int): Forwarded to the base class.
        show_score (bool): Forwarded to the base class. Defaults to False.
    """

    def __init__(self,
                 model_file: Iterable[str],
                 cfg: Union[mmcv.Config, mmcv.ConfigDict],
                 device_id: int,
                 show_score: bool = False):
        super().__init__(cfg, device_id, show_score)
        # Local import; presumably keeps NCNN optional until this backend
        # is actually instantiated -- TODO confirm.
        from mmdeploy.apis.ncnn import NCNNWrapper
        first_file, second_file = model_file[0], model_file[1]
        self.model = NCNNWrapper(
            first_file, second_file, output_names=['output'])

    def forward_of_backend(self,
                           img: torch.Tensor,
                           img_metas: Iterable,
                           rescale: bool = False):
        """Feed ``img`` to the wrapped model and return its 'output' entry.

        Args:
            img (torch.Tensor): Passed to the model under the key 'input'.
            img_metas (Iterable): Not used by this method.
            rescale (bool): Not used by this method.

        Returns:
            The value stored under 'output' in the wrapper's result.
        """
        outputs = self.model({'input': img})
        return outputs['output']
def get_classes_from_config(model_cfg: Union[str, mmcv.Config], **kwargs):
# load cfg if necessary
model_cfg = load_config(model_cfg)[0]
@ -268,9 +310,15 @@ TASK_TENSORRT_MAP = {
Task.TEXT_RECOGNITION: TensorRTRecognizer
}
# Task -> wrapper class to use when the backend is NCNN.
TASK_NCNN_MAP = {
    Task.TEXT_DETECTION: NCNNDetector,
    Task.TEXT_RECOGNITION: NCNNRecognizer
}

# Backend -> per-task wrapper table. Each backend appears exactly once; the
# stale comma-less TENSORRT entry that preceded the current one is dropped
# because it made the literal invalid.
BACKEND_TASK_MAP = {
    Backend.ONNXRUNTIME: TASK_ONNXRUNTIME_MAP,
    Backend.TENSORRT: TASK_TENSORRT_MAP,
    Backend.NCNN: TASK_NCNN_MAP
}

View File

@ -1,4 +1,6 @@
from .layers import * # noqa: F401, F403
from .recognizer.base import forward_of_base_recognizer
from .recognizer.decoders import * # noqa: F401, F403
from .recognizer.encode_decode_recognizer import \
simple_test_of_encode_decode_recognizer

View File

@ -0,0 +1,3 @@
from .lstm_layer import forward_of_bidirectionallstm
__all__ = ['forward_of_bidirectionallstm']

View File

@ -0,0 +1,15 @@
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textrecog.layers.lstm_layer'
'.BidirectionalLSTM.forward',
backend='ncnn')
def forward_of_bidirectionallstm(ctx, self, input):
    """Run ``self.rnn`` with ``batch_first`` enabled, then ``self.embedding``.

    ``self.rnn.batch_first`` is switched to True only for the duration of
    the call and is then put back to whatever it was before, even if the
    RNN call raises.

    Args:
        ctx: Rewriter context; not used here.
        self: Object exposing a callable ``rnn`` (with a ``batch_first``
            attribute, returning a pair) and a callable ``embedding``.
        input: Value passed to ``self.rnn``.

    Returns:
        The result of ``self.embedding`` applied to the first element
        returned by ``self.rnn``.
    """
    rnn = self.rnn
    previous_batch_first = rnn.batch_first
    rnn.batch_first = True
    try:
        recurrent, _ = rnn(input)
    finally:
        # Restore the caller's setting instead of forcing False, so the
        # module is never left in a modified state.
        rnn.batch_first = previous_batch_first
    return self.embedding(recurrent)

View File

@ -0,0 +1,3 @@
from .crnn_decoder import forward_train_of_crnndecoder
__all__ = ['forward_train_of_crnndecoder']

View File

@ -0,0 +1,19 @@
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textrecog.decoders.CRNNDecoder.forward_train',
backend='ncnn')
def forward_train_of_crnndecoder(ctx, self, feat, out_enc, targets_dict,
                                 img_metas):
    """Decode ``feat`` into a 3-D sequence tensor.

    Args:
        ctx: Rewriter context; not used here.
        self: Decoder object exposing ``rnn_flag`` and a callable
            ``decoder``.
        feat: 4-D tensor whose size along dimension 2 must be 1.
        out_enc: Not used by this function.
        targets_dict: Not used by this function.
        img_metas: Not used by this function.

    Returns:
        With ``rnn_flag`` set, ``self.decoder`` applied to ``feat`` after
        dropping dimension 2 and swapping the last two dimensions.
        Otherwise ``self.decoder(feat)`` with its last dimension moved to
        position 1 and the two trailing dimensions merged.
    """
    assert feat.size(2) == 1, 'feature height must be 1'
    if not self.rnn_flag:
        decoded = self.decoder(feat).permute(0, 3, 1, 2).contiguous()
        batch, width, channels, height = decoded.size()
        return decoded.view(batch, width, channels * height)
    # [N, C, 1, W] -> [N, C, W] -> [N, W, C]
    sequence = feat.squeeze(2).permute(0, 2, 1)
    return self.decoder(sequence)

View File

@ -1,10 +1,11 @@
from .getattribute import getattribute_static
from .interpolate import interpolate_static
from .linear import linear_ncnn
from .repeat import repeat_static
from .size import size_of_tensor_static
from .topk import topk_dynamic, topk_static
# Public names of this package. Each name is listed once; the stale pre-change
# entries are dropped because adjacent string literals without a comma were
# being concatenated into a single bogus name.
__all__ = [
    'getattribute_static', 'interpolate_static', 'linear_ncnn',
    'repeat_static', 'size_of_tensor_static', 'topk_static', 'topk_dynamic'
]

View File

@ -0,0 +1,42 @@
from typing import Union
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.linear', backend='ncnn')
def linear_ncnn(
    ctx,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Union[torch.Tensor, None] = None,
):
    """Apply a linear layer, adding the bias in an ncnn-friendly layout.

    For inputs with more than two dimensions and a bias, the bias is not
    handed to the original function. Instead the un-biased result is
    reshaped to 4-D with the output features on dimension 1, the bias is
    added as a ``[1, C, 1, 1]`` tensor, and the original layout is restored.

    Args:
        ctx: Rewriter context; ``ctx.origin_func`` is the un-rewritten
            linear function.
        input (torch.Tensor): Input tensor.
        weight (torch.Tensor): Weight passed to ``ctx.origin_func``.
        bias (torch.Tensor | None): Optional bias. Defaults to None.

    Returns:
        torch.Tensor: Same values the original function would return for
        ``(input, weight, bias)``.
    """
    origin_func = ctx.origin_func
    dim = input.dim()
    if dim <= 2:
        # No extra dimensions to broadcast over (and transpose(1, dim - 1)
        # is not valid for 1-D data): defer to the original function.
        return origin_func(input, weight, bias)
    if bias is None:
        # Nothing to add, so the layout shuffle below is unnecessary.
        return origin_func(input, weight)

    out = origin_func(input, weight)
    # permute: bring the output features to dimension 1
    out = out.transpose(1, dim - 1)
    # ncnn only support [c, h, w] and [c, 1, 1] broadcast
    out_shape = out.shape
    out = out.reshape([out_shape[0], out_shape[1], -1, 1])
    # add bias
    out = out + bias.view([1, -1, 1, 1])
    # permute back
    out = out.reshape(out_shape)
    return out.transpose(1, dim - 1)

View File

@ -2,8 +2,10 @@ from .adaptive_avg_pool import (adaptive_avg_pool1d_op, adaptive_avg_pool2d_op,
adaptive_avg_pool3d_op)
from .grid_sampler import grid_sampler_default
from .instance_norm import instance_norm_trt
from .squeeze import squeeze_default
# Public names of this package. Each name is listed once; the stale pre-change
# entry line is dropped because adjacent string literals without a comma were
# being concatenated into a single bogus name.
__all__ = [
    'adaptive_avg_pool1d_op', 'adaptive_avg_pool2d_op',
    'adaptive_avg_pool3d_op', 'grid_sampler_default', 'instance_norm_trt',
    'squeeze_default'
]

View File

@ -0,0 +1,15 @@
import torch.onnx.symbolic_helper as sym_help
from mmdeploy.core import SYMBOLIC_REGISTER
@SYMBOLIC_REGISTER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(ctx, g, self, dim=None):
    """Emit an ONNX ``Squeeze`` node, with explicit axes where possible.

    Args:
        ctx: Rewriter context; not used here.
        g: Graph object whose ``op`` method creates the node.
        self: Value to squeeze; ``self.type().sizes()`` is consulted when
            ``dim`` is None.
        dim: Optional constant axis. When None, every axis whose recorded
            size is 1 is listed.

    Returns:
        The node produced by ``g.op('Squeeze', ...)``.
    """
    if dim is not None:
        axis = sym_help._get_const(dim, 'i', 'dim')
        return g.op('Squeeze', self, axes_i=[axis])

    sizes = self.type().sizes()
    if sizes is None:
        # No shape recorded for this value: iterating would raise, so emit
        # Squeeze without axes, which ONNX defines as "remove every
        # single-sized dimension".
        return g.op('Squeeze', self)
    dims = [axis for axis, size in enumerate(sizes) if size == 1]
    return g.op('Squeeze', self, axes_i=dims)