mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
BUG P0 (#1044)
This commit is contained in:
parent
b87afb9ebb
commit
e912e86d24
@ -22,3 +22,66 @@ def coreml_nms(context, node):
|
|||||||
max_boxes=max_boxes)
|
max_boxes=max_boxes)
|
||||||
|
|
||||||
context.add(tuple(results), torch_name=node.outputs[0])
|
context.add(tuple(results), torch_name=node.outputs[0])
|
||||||
|
|
||||||
|
|
||||||
|
@register_torch_op
|
||||||
|
def log2(context, node):
|
||||||
|
"""bind log2."""
|
||||||
|
import numpy as np
|
||||||
|
inputs = _get_inputs(context, node)
|
||||||
|
x = inputs[0]
|
||||||
|
log_x = mb.log(x=x)
|
||||||
|
context.add(mb.mul(x=log_x, y=1 / np.log(2.0)), node.name)
|
||||||
|
|
||||||
|
|
||||||
|
@register_torch_op
|
||||||
|
def roi_align(context, node):
|
||||||
|
"""roi align."""
|
||||||
|
inputs = _get_inputs(context, node)
|
||||||
|
|
||||||
|
x = context[node.inputs[0]]
|
||||||
|
input_shape = x.shape # (B, C, h_in, w_in)
|
||||||
|
if len(input_shape) != 4:
|
||||||
|
raise ValueError(
|
||||||
|
'"CropResize" op: expected input rank 4, got {}'.format(x.rank))
|
||||||
|
|
||||||
|
const_box_info = True
|
||||||
|
if context[node.inputs[1]].val is None or context[
|
||||||
|
node.inputs[2]].val is None:
|
||||||
|
const_box_info = False
|
||||||
|
|
||||||
|
extrapolation_value = context[node.inputs[2]].val
|
||||||
|
# CoreML index information along with boxes
|
||||||
|
if const_box_info:
|
||||||
|
boxes = context[node.inputs[1]].val
|
||||||
|
# CoreML expects boxes/ROI in
|
||||||
|
# [N, 1, 5, 1, 1] format
|
||||||
|
boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
|
||||||
|
else:
|
||||||
|
boxes = inputs[1]
|
||||||
|
boxes = mb.reshape(
|
||||||
|
x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
|
||||||
|
# Get Height and Width of crop
|
||||||
|
h_out = inputs[3]
|
||||||
|
w_out = inputs[4]
|
||||||
|
|
||||||
|
# Torch input format: [B, C, h_in, w_in]
|
||||||
|
# CoreML input format: [B, C, h_in, w_in]
|
||||||
|
|
||||||
|
# Crop Resize
|
||||||
|
x = mb.crop_resize(
|
||||||
|
x=x,
|
||||||
|
roi=boxes,
|
||||||
|
target_height=h_out.val,
|
||||||
|
target_width=w_out.val,
|
||||||
|
normalized_coordinates=False,
|
||||||
|
spatial_scale=extrapolation_value,
|
||||||
|
box_coordinate_mode='CORNERS_WIDTH_FIRST',
|
||||||
|
sampling_mode='OFFSET_CORNERS',
|
||||||
|
)
|
||||||
|
|
||||||
|
# CoreML output format: [N, 1, C, h_out, w_out]
|
||||||
|
# Torch output format: [N, C, h_out, w_out]
|
||||||
|
x = mb.squeeze(x=x, axes=[1])
|
||||||
|
|
||||||
|
context.add(x, torch_name=node.outputs[0])
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmdeploy.codebase.mmdet import (get_post_processing_params,
|
from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
|
||||||
multiclass_nms,
|
multiclass_nms,
|
||||||
pad_with_value_if_necessary)
|
pad_with_value_if_necessary)
|
||||||
from mmdeploy.core import FUNCTION_REWRITER
|
from mmdeploy.core import FUNCTION_REWRITER
|
||||||
@ -104,11 +104,17 @@ def rpn_head__get_bboxes(ctx,
|
|||||||
|
|
||||||
if pre_topk > 0:
|
if pre_topk > 0:
|
||||||
_, topk_inds = scores.squeeze(2).topk(pre_topk)
|
_, topk_inds = scores.squeeze(2).topk(pre_topk)
|
||||||
batch_inds = torch.arange(batch_size, device=device).unsqueeze(-1)
|
bbox_pred, scores = gather_topk(
|
||||||
prior_inds = topk_inds.new_zeros((1, 1))
|
bbox_pred,
|
||||||
anchors = anchors[prior_inds, topk_inds, :]
|
scores,
|
||||||
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
|
inds=topk_inds,
|
||||||
scores = scores[batch_inds, topk_inds, :]
|
batch_size=batch_size,
|
||||||
|
is_batched=True)
|
||||||
|
anchors = gather_topk(
|
||||||
|
anchors,
|
||||||
|
inds=topk_inds,
|
||||||
|
batch_size=batch_size,
|
||||||
|
is_batched=False)
|
||||||
mlvl_valid_bboxes.append(bbox_pred)
|
mlvl_valid_bboxes.append(bbox_pred)
|
||||||
mlvl_scores.append(scores)
|
mlvl_scores.append(scores)
|
||||||
mlvl_valid_anchors.append(anchors)
|
mlvl_valid_anchors.append(anchors)
|
||||||
|
@ -316,3 +316,39 @@ def single_roi_extractor__forward__openvino(ctx,
|
|||||||
args = (output_size, featmap_strides, sample_num, rois, *feats)
|
args = (output_size, featmap_strides, sample_num, rois, *feats)
|
||||||
result = SingleRoIExtractorOpenVINO.apply(*args)
|
result = SingleRoIExtractorOpenVINO.apply(*args)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
|
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
|
||||||
|
backend=Backend.COREML.value)
|
||||||
|
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
|
||||||
|
def single_roi_extractor__forward__coreml(ctx,
|
||||||
|
self,
|
||||||
|
feats,
|
||||||
|
rois,
|
||||||
|
roi_scale_factor=None):
|
||||||
|
"""Rewrite `forward` of SingleRoIExtractor for coreml."""
|
||||||
|
out_size = self.roi_layers[0].output_size
|
||||||
|
num_levels = len(feats)
|
||||||
|
roi_feats = feats[0].new_zeros(rois.shape[0], self.out_channels, *out_size)
|
||||||
|
if num_levels == 1:
|
||||||
|
assert len(rois) > 0, 'The number of rois should be positive'
|
||||||
|
self.roi_layers[0].use_torchvision = True
|
||||||
|
return self.roi_layers[0](feats[0], rois)
|
||||||
|
|
||||||
|
target_lvls = self.map_roi_levels(rois, num_levels)
|
||||||
|
|
||||||
|
if roi_scale_factor is not None:
|
||||||
|
rois = self.roi_rescale(rois, roi_scale_factor)
|
||||||
|
|
||||||
|
for i in range(num_levels):
|
||||||
|
mask = target_lvls == i
|
||||||
|
# inds = mask.nonzero(as_tuple=False).squeeze(1)
|
||||||
|
rois_t = rois * mask.unsqueeze(-1)
|
||||||
|
# use the roi align in torhcvision
|
||||||
|
self.roi_layers[i].use_torchvision = True
|
||||||
|
roi_feats_t = self.roi_layers[i](feats[i], rois_t)
|
||||||
|
roi_feats = roi_feats + roi_feats_t * (rois_t[:, -1] > 0).reshape(
|
||||||
|
-1, 1, 1, 1)
|
||||||
|
# slice to recover original size
|
||||||
|
return roi_feats
|
||||||
|
Loading…
x
Reference in New Issue
Block a user