fix ascend (#1667)
parent 85320df2b4
commit 99d6fb3190
@@ -94,7 +94,7 @@
 </tr>
 <tr>
 <td>MMDEPLOY_TARGET_BACKENDS</td>
-<td>{"trt", "ort", "pplnn", "ncnn", "openvino", "torchscript", "snpe", "tvm"}</td>
+<td>{"trt", "ort", "pplnn", "ncnn", "openvino", "torchscript", "snpe", "tvm", "acl"}</td>
 <td>N/A</td>
 <td>Enable target inference engines. <b>By default, no target inference engine is set, since it highly depends on the use case.</b> When more than one engine is specified, they have to be set as a semicolon-separated list of inference backend names, e.g. <pre><code>-DMMDEPLOY_TARGET_BACKENDS="trt;ort;pplnn;ncnn;openvino"</code></pre>
 After specifying the inference engine, its package path has to be passed to cmake as follows,<br>
@@ -97,7 +97,7 @@
 <tr>
 <td>MMDEPLOY_TARGET_BACKENDS</td>
-<td>{"trt", "ort", "pplnn", "ncnn", "openvino", "torchscript", "snpe", "coreml", "tvm"}</td>
+<td>{"trt", "ort", "pplnn", "ncnn", "openvino", "torchscript", "snpe", "coreml", "tvm", "acl"}</td>
 <td>N/A</td>
 <td><b>By default, the SDK sets no target backend</b>, since the choice depends heavily on the application scenario. When selecting more than one backend, separate them with semicolons, e.g. <pre><code>-DMMDEPLOY_TARGET_BACKENDS="trt;ort;pplnn;ncnn;openvino"</code></pre>
 At build time, almost every backend requires some path variables to be set so that its dependency packages can be found.<br>
@@ -75,19 +75,21 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
         # check headless
         import tkinter
         tkinter.Tk()
-        if isinstance(img, str) or not isinstance(img, Sequence):
-            img = [img]
-        for single_img in img:
-            task_processor.visualize(
-                image=single_img,
-                model=model,
-                result=result,
-                output_file=output_file,
-                window_name=backend.value,
-                show_result=show_result)
     except Exception as e:
         from mmdeploy.utils import get_root_logger
         logger = get_root_logger()
         logger.warn(
             f'render and display result skipped for headless device, exception {e}'  # noqa: E501
         )
         show_result = False
+
+    if isinstance(img, str) or not isinstance(img, Sequence):
+        img = [img]
+    for single_img in img:
+        task_processor.visualize(
+            image=single_img,
+            model=model,
+            result=result,
+            output_file=output_file,
+            window_name=backend.value,
+            show_result=show_result)
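The hunk above moves the actual visualization out of the tkinter try/except: a headless machine (the usual case for Ascend servers) now still writes output_file and only loses the on-screen window, instead of skipping rendering entirely. A small self-contained sketch of that probe-then-downgrade pattern; display_available and visualize_all are illustrative names, not mmdeploy functions:

from typing import Sequence, Union


def display_available() -> bool:
    """Return True if a GUI window can be opened, False on headless hosts."""
    try:
        import tkinter
        tkinter.Tk().destroy()  # creating a root window fails without a display
        return True
    except Exception:
        return False


def visualize_all(imgs: Union[str, Sequence[str]], show_result: bool = True) -> None:
    """Render every image; on headless machines only write files, never skip."""
    show_result = show_result and display_available()
    if isinstance(imgs, str) or not isinstance(imgs, Sequence):
        imgs = [imgs]
    for img in imgs:
        # stand-in for task_processor.visualize(...): save to disk, optionally show
        print(f'visualize {img} (show={show_result})')


visualize_all(['demo.jpg', 'demo2.jpg'])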
@@ -85,7 +85,8 @@ class AscendManager(BaseBackendManager):

         om_files = []
         for model_id, onnx_path in enumerate(ir_files):
-            om_path = osp.splitext(onnx_path)[0] + '.om'
+            om_name = osp.splitext(osp.split(onnx_path)[1])[0] + '.om'
+            om_path = osp.join(work_dir, om_name)
             from_onnx(onnx_path, work_dir, model_inputs[model_id])
             om_files.append(om_path)
         backend_files = om_files
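The .om file produced by from_onnx is written into work_dir, so the recorded backend file must be built from work_dir plus the ONNX basename rather than by swapping the extension on the original ONNX path, which may live in a different directory. A minimal sketch of the corrected path construction; the file name end2end.onnx is only an illustrative example:

import os.path as osp


def om_path_for(onnx_path: str, work_dir: str) -> str:
    """Path of the .om file that the converter writes into work_dir."""
    om_name = osp.splitext(osp.split(onnx_path)[1])[0] + '.om'
    return osp.join(work_dir, om_name)


# An ONNX file living outside the work directory still maps into work_dir:
print(om_path_for('/tmp/export/end2end.onnx', 'work_dir'))  # -> work_dir/end2end.om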
@@ -393,6 +393,7 @@ class AscendWrapper(BaseWrapper):
         self.__ascend_execute()

         for binding in self._model_desc.outputs:
             tensor = outputs[binding.name]
             self._copy_buffer_to_tensor(
                 self._output.buffers[binding.index], tensor)
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any
+
 import torch
 from packaging import version
 from torch import Tensor
@@ -405,9 +407,14 @@ def multiclass_nms__torchscript(ctx,
 class AscendBatchNMSOp(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, bboxes: torch.Tensor, scores: torch.Tensor,
-                score_threshold: float, iou_threshold: float,
-                max_size_per_class: int, max_total_size: int):
+    def forward(ctx,
+                bboxes: torch.Tensor,
+                scores: torch.Tensor,
+                score_threshold: float,
+                iou_threshold: float,
+                max_size_per_class: int,
+                max_total_size: int,
+                dummy_output: Any = None):
         """Dummy nms forward
         Args:
             boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
@@ -416,6 +423,7 @@ class AscendBatchNMSOp(torch.autograd.Function):
             iou_threshold (float): the iou threshold.
             max_size_per_class (int): max size per class.
             max_total_size (int): max total size.
+            dummy_output (Any): output for this op

         Returns:
             (torch.Tensor): boxes,(1, N, 4)
@@ -425,16 +433,29 @@ class AscendBatchNMSOp(torch.autograd.Function):
         """

         # Python implementation for onnx export
-        nmsed_boxes = bboxes[:, :max_total_size, 0, :]
-        nmsed_scores = scores[:, :max_total_size, 0]
-        nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
-        nmsed_num = torch.Tensor([max_total_size])
+        if dummy_output is not None:
+            dets, labels = dummy_output
+            nmsed_boxes = dets[..., :4]
+            nmsed_scores = dets[..., 4]
+            nmsed_classes = labels
+            nmsed_num = labels.new_tensor([max_total_size])
+        else:
+            nmsed_boxes = bboxes[:, :max_total_size, 0, :]
+            nmsed_scores = scores[:, :max_total_size, 0]
+            nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
+            nmsed_num = torch.Tensor([max_total_size])

         return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num

     @staticmethod
-    def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class,
-                 max_t_size):
+    def symbolic(g,
+                 bboxes,
+                 scores,
+                 score_thr,
+                 iou_thr,
+                 max_size_p_class,
+                 max_t_size,
+                 dummy_output=None):
         nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op(
             'mmdeploy::BatchMultiClassNMS',
             bboxes,
@@ -480,11 +501,21 @@ def multiclass_nms__ascend(ctx,
         tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
         and `labels` of shape [N, num_det].
     """
+    origin_output = ctx.origin_func(
+        boxes,
+        scores,
+        max_output_boxes_per_class=max_output_boxes_per_class,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        pre_top_k=pre_top_k,
+        keep_top_k=keep_top_k)
+
     boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
     keep_top_k = max_output_boxes_per_class if keep_top_k < 0 else min(
         max_output_boxes_per_class, keep_top_k)
     nmsed_boxes, nmsed_scores, nmsed_classes, _ = AscendBatchNMSOp.apply(
-        boxes, scores, score_threshold, iou_threshold, keep_top_k, keep_top_k)
+        boxes, scores, score_threshold, iou_threshold, keep_top_k, keep_top_k,
+        origin_output)

     dets = torch.cat([nmsed_boxes, nmsed_scores.unsqueeze(2)], dim=-1)
     return dets, nmsed_classes
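This is the core of the fix for the NMS rewrite: before emitting the custom op, the rewriter calls ctx.origin_func to obtain a reference result and threads it into the autograd Function as a trailing dummy_output argument. forward returns that reference so the traced graph records correctly shaped tensors, while symbolic emits only the backend op and ignores the dummy. A self-contained sketch of the same mechanism with made-up names (FakeBackendOp and custom_domain::Double are not mmdeploy APIs):

import torch


class FakeBackendOp(torch.autograd.Function):
    """Stand-in for an op that only exists in the target backend."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, dummy_output: torch.Tensor = None):
        # During export, return the precomputed reference so tracing records
        # the correct output shape and dtype.
        if dummy_output is not None:
            return dummy_output
        # Fallback approximation when no reference result is supplied.
        return x * 2

    @staticmethod
    def symbolic(g, x, dummy_output=None):
        # The dummy tensor never enters the exported graph; only the custom
        # op node does.
        return g.op('custom_domain::Double', x)


x = torch.randn(1, 3)
reference = x * 2  # plays the role of origin_output = ctx.origin_func(...)
print(FakeBackendOp.apply(x, reference).shape)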
@@ -250,6 +250,9 @@ def __gather_topk__trt(ctx,
 @FUNCTION_REWRITER.register_rewriter(
     'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
     backend=Backend.COREML.value)
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
+    backend=Backend.ASCEND.value)
 def __gather_topk__nonbatch(ctx,
                             *inputs: Sequence[torch.Tensor],
                             inds: torch.Tensor,
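The two decorators above reuse one non-batched implementation for both the CoreML and the Ascend backends. A toy sketch of how a backend-keyed rewriter registry of this kind can work; register_rewriter below is illustrative and not the mmdeploy implementation:

from typing import Callable, Dict, Tuple

_REWRITERS: Dict[Tuple[str, str], Callable] = {}


def register_rewriter(func_name: str, backend: str = 'default'):
    """Map (function path, backend) to a replacement callable."""

    def decorator(rewriter: Callable) -> Callable:
        _REWRITERS[(func_name, backend)] = rewriter
        return rewriter

    return decorator


@register_rewriter('utils.__gather_topk', backend='coreml')
@register_rewriter('utils.__gather_topk', backend='ascend')
def gather_topk_nonbatch(*inputs, inds):
    # non-batched gather: index every input directly with `inds`
    return [x[inds] for x in inputs]


print(sorted(_REWRITERS))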
@@ -110,6 +110,7 @@ class AscendRoiExtractor(Function):
     @staticmethod
     def symbolic(g, *args):
         """Symbolic function for creating onnx op."""
+        args = args[:-1]
         aligned = args[-1]
         featmap_strides = [1 / stride for stride in args[-2]]
         finest_scale = args[-3]
@@ -137,20 +138,26 @@ class AscendRoiExtractor(Function):
     @staticmethod
     def forward(ctx, *args):
         """Run forward."""
-        # aligned = args[-1]
-        featmap_strides = args[-2]
-        # finest_scale = args[-3]
-        # roi_scale_factor = args[-4]
-        # sampling_ratio = args[-5]
-        output_size = args[-7]
-        inputs = args[:len(featmap_strides)]
-        rois = args[len(featmap_strides)]
+        dummy_output = args[-1]
+        args = args[:-1]

-        num_proposals = rois.shape[0]
-        channel = inputs[0].shape[1]
+        if dummy_output is not None:
+            return dummy_output
+        else:
+            # aligned = args[-1]
+            featmap_strides = args[-2]
+            # finest_scale = args[-3]
+            # roi_scale_factor = args[-4]
+            # sampling_ratio = args[-5]
+            output_size = args[-7]
+            inputs = args[:len(featmap_strides)]
+            rois = args[len(featmap_strides)]

-        return rois.new_zeros(
-            (num_proposals, channel, output_size[1], output_size[0]))
+            num_proposals = rois.shape[0]
+            channel = inputs[0].shape[1]
+
+            return rois.new_zeros(
+                (num_proposals, channel, output_size[1], output_size[0]))


 @FUNCTION_REWRITER.register_rewriter(
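Because Function.apply only takes positional arguments, the rewriter appends the precomputed output as the last argument; both symbolic and forward then drop it with args = args[:-1] before reading the remaining options by negative index, exactly as in the hunk above. A tiny sketch of that parsing convention (names and sample values are illustrative):

def parse_roi_args(*args):
    """Mimic AscendRoiExtractor's trailing-argument layout."""
    dummy_output, args = args[-1], args[:-1]  # dummy output is always appended last
    aligned = args[-1]
    featmap_strides = args[-2]
    output_size = args[-7]
    num_levels = len(featmap_strides)
    feats, rois = args[:num_levels], args[num_levels]
    return dummy_output, aligned, featmap_strides, output_size, feats, rois


feats = ['p2', 'p3']  # placeholders for two feature maps
print(parse_roi_args(*feats, 'rois', (7, 7), 'avg', 2, 1.0, 56, [4.0, 8.0], True, None))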
@@ -183,9 +190,12 @@ def single_roi_extractor__forward__ascend(ctx,
         roi_scale_factor = 1.0

     featmap_strides = [float(s) for s in featmap_strides]

+    origin_output = ctx.origin_func(self, feats, rois, roi_scale_factor)
     return AscendRoiExtractor.apply(*feats, rois, out_size, pool_mode,
                                     sampling_ratio, roi_scale_factor,
-                                    finest_scale, featmap_strides, aligned)
+                                    finest_scale, featmap_strides, aligned,
+                                    origin_output)


 @FUNCTION_REWRITER.register_rewriter(