mirror of https://github.com/open-mmlab/mmdeploy.git
BUG P0 (#1044)
parent b87afb9ebb
commit e912e86d24
@@ -22,3 +22,66 @@ def coreml_nms(context, node):
         max_boxes=max_boxes)
 
     context.add(tuple(results), torch_name=node.outputs[0])
+
+
+@register_torch_op
+def log2(context, node):
+    """bind log2."""
+    import numpy as np
+    inputs = _get_inputs(context, node)
+    x = inputs[0]
+    log_x = mb.log(x=x)
+    context.add(mb.mul(x=log_x, y=1 / np.log(2.0)), node.name)
+
+
+@register_torch_op
+def roi_align(context, node):
+    """roi align."""
+    inputs = _get_inputs(context, node)
+
+    x = context[node.inputs[0]]
+    input_shape = x.shape  # (B, C, h_in, w_in)
+    if len(input_shape) != 4:
+        raise ValueError(
+            '"CropResize" op: expected input rank 4, got {}'.format(x.rank))
+
+    const_box_info = True
+    if context[node.inputs[1]].val is None or context[
+            node.inputs[2]].val is None:
+        const_box_info = False
+
+    extrapolation_value = context[node.inputs[2]].val
+    # CoreML index information along with boxes
+    if const_box_info:
+        boxes = context[node.inputs[1]].val
+        # CoreML expects boxes/ROI in
+        # [N, 1, 5, 1, 1] format
+        boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
+    else:
+        boxes = inputs[1]
+        boxes = mb.reshape(
+            x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
+    # Get Height and Width of crop
+    h_out = inputs[3]
+    w_out = inputs[4]
+
+    # Torch input format: [B, C, h_in, w_in]
+    # CoreML input format: [B, C, h_in, w_in]
+
+    # Crop Resize
+    x = mb.crop_resize(
+        x=x,
+        roi=boxes,
+        target_height=h_out.val,
+        target_width=w_out.val,
+        normalized_coordinates=False,
+        spatial_scale=extrapolation_value,
+        box_coordinate_mode='CORNERS_WIDTH_FIRST',
+        sampling_mode='OFFSET_CORNERS',
+    )
+
+    # CoreML output format: [N, 1, C, h_out, w_out]
+    # Torch output format: [N, C, h_out, w_out]
+    x = mb.squeeze(x=x, axes=[1])
+
+    context.add(x, torch_name=node.outputs[0])
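Note (reviewer sketch, not part of the diff): the new `log2` binding composes the MIL ops `mb.log` and `mb.mul` using the change-of-base identity log2(x) = ln(x) / ln(2), and the `roi_align` binding maps the torch op onto CoreML's `crop_resize`, reshaping boxes to the [N, 1, 5, 1, 1] layout it expects and squeezing the extra axis from its [N, 1, C, h_out, w_out] output. A minimal NumPy check of the identity the `log2` binding relies on, with illustrative values:

    import numpy as np

    x = np.array([1.0, 2.0, 8.0, 1024.0])
    # ln(x) * (1 / ln(2)) is exactly log2(x), which is what mb.mul(x=mb.log(x), y=1 / np.log(2.0)) emits
    assert np.allclose(np.log(x) * (1 / np.log(2.0)), np.log2(x))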
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 
-from mmdeploy.codebase.mmdet import (get_post_processing_params,
+from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
                                      multiclass_nms,
                                      pad_with_value_if_necessary)
 from mmdeploy.core import FUNCTION_REWRITER
@@ -104,11 +104,17 @@ def rpn_head__get_bboxes(ctx,
 
         if pre_topk > 0:
             _, topk_inds = scores.squeeze(2).topk(pre_topk)
-            batch_inds = torch.arange(batch_size, device=device).unsqueeze(-1)
-            prior_inds = topk_inds.new_zeros((1, 1))
-            anchors = anchors[prior_inds, topk_inds, :]
-            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
-            scores = scores[batch_inds, topk_inds, :]
+            bbox_pred, scores = gather_topk(
+                bbox_pred,
+                scores,
+                inds=topk_inds,
+                batch_size=batch_size,
+                is_batched=True)
+            anchors = gather_topk(
+                anchors,
+                inds=topk_inds,
+                batch_size=batch_size,
+                is_batched=False)
         mlvl_valid_bboxes.append(bbox_pred)
         mlvl_scores.append(scores)
         mlvl_valid_anchors.append(anchors)
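Note (reviewer sketch, not part of the diff): the removed lines performed the batched top-k gather by hand with advanced indexing; the `gather_topk` helper is expected to wrap the same pattern (with `is_batched=False` sharing one set of anchors across the batch). A minimal PyTorch sketch of the indexing being replaced, with made-up sizes for illustration:

    import torch

    batch_size, num_priors, pre_topk = 2, 1000, 100
    scores = torch.rand(batch_size, num_priors, 1)
    bbox_pred = torch.rand(batch_size, num_priors, 4)

    _, topk_inds = scores.squeeze(2).topk(pre_topk)        # (B, pre_topk)
    batch_inds = torch.arange(batch_size).unsqueeze(-1)    # (B, 1), broadcasts over pre_topk
    bbox_pred = bbox_pred[batch_inds, topk_inds, :]        # (B, pre_topk, 4)
    scores = scores[batch_inds, topk_inds, :]              # (B, pre_topk, 1)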
@@ -316,3 +316,39 @@ def single_roi_extractor__forward__openvino(ctx,
     args = (output_size, featmap_strides, sample_num, rois, *feats)
     result = SingleRoIExtractorOpenVINO.apply(*args)
     return result
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
+    backend=Backend.COREML.value)
+@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
+def single_roi_extractor__forward__coreml(ctx,
+                                          self,
+                                          feats,
+                                          rois,
+                                          roi_scale_factor=None):
+    """Rewrite `forward` of SingleRoIExtractor for coreml."""
+    out_size = self.roi_layers[0].output_size
+    num_levels = len(feats)
+    roi_feats = feats[0].new_zeros(rois.shape[0], self.out_channels, *out_size)
+    if num_levels == 1:
+        assert len(rois) > 0, 'The number of rois should be positive'
+        self.roi_layers[0].use_torchvision = True
+        return self.roi_layers[0](feats[0], rois)
+
+    target_lvls = self.map_roi_levels(rois, num_levels)
+
+    if roi_scale_factor is not None:
+        rois = self.roi_rescale(rois, roi_scale_factor)
+
+    for i in range(num_levels):
+        mask = target_lvls == i
+        # inds = mask.nonzero(as_tuple=False).squeeze(1)
+        rois_t = rois * mask.unsqueeze(-1)
+        # use the roi align in torchvision
+        self.roi_layers[i].use_torchvision = True
+        roi_feats_t = self.roi_layers[i](feats[i], rois_t)
+        roi_feats = roi_feats + roi_feats_t * (rois_t[:, -1] > 0).reshape(
+            -1, 1, 1, 1)
+    # slice to recover original size
+    return roi_feats
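Note (reviewer sketch, not part of the diff): this CoreML rewrite avoids the data-dependent `nonzero()` indexing used by the upstream `SingleRoIExtractor.forward` (kept as the commented-out line above). Every level processes all rois at a static shape: rois that do not belong to the level are zeroed out by the mask, and the per-level outputs are merged with a validity weight. A small PyTorch sketch of the masking trick, with illustrative sizes:

    import torch

    num_rois, num_levels = 8, 4
    rois = torch.rand(num_rois, 5)                    # (batch_idx, x1, y1, x2, y2), values illustrative
    target_lvls = torch.randint(0, num_levels, (num_rois,))

    i = 2
    mask = target_lvls == i                           # bool, shape (num_rois,)
    rois_t = rois * mask.unsqueeze(-1)                # same static shape; non-level rois become zeros
    valid = (rois_t[:, -1] > 0).reshape(-1, 1, 1, 1)  # weight used to accumulate per-level features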