[Feature] Support centernet dev1.x (#1219)
* support centernet head * add centernet head ut * add centernet * add centernet * add support models * fix mdformat * fix reg test * fix scale * fix test.py show_dir kwargs * fix for profile in T4 * fix dynamic shape * fix lint * move rescale and border to outside * fix ut * fix lint * update ort torchscript benchmark * fix centernet * fix ut * remove unused file * support centernet sdk * remove unused rewriter * fix lint * fix flake8 * remove unused line * fix lint * fix lint * fix doc links * fix mdformat * fix scale_factor as default * apart random pad and pad * fix sdk * fix centernet docs * fix code style of cpppull/1372/head
parent
3b6e1ba34d
commit
83756b97c6
|
@ -0,0 +1,14 @@
|
|||
_base_ = [
|
||||
'../_base_/base_dynamic.py', '../../_base_/backends/tensorrt-fp16.py'
|
||||
]
|
||||
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 64, 64],
|
||||
opt_shape=[1, 3, 800, 800],
|
||||
max_shape=[1, 3, 800, 800])))
|
||||
])
|
|
@ -0,0 +1,14 @@
|
|||
_base_ = [
|
||||
'../_base_/base_dynamic.py', '../../_base_/backends/tensorrt-int8.py'
|
||||
]
|
||||
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 64, 64],
|
||||
opt_shape=[1, 3, 800, 800],
|
||||
max_shape=[1, 3, 800, 800])))
|
||||
])
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/tensorrt.py']
|
||||
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 64, 64],
|
||||
opt_shape=[1, 3, 800, 800],
|
||||
max_shape=[1, 3, 800, 800])))
|
||||
])
|
|
@ -129,6 +129,10 @@ Result<Detections> ResizeBBox::GetBBoxes(const Value& prep_res, const Tensor& de
|
|||
|
||||
float w_offset = 0.f;
|
||||
float h_offset = 0.f;
|
||||
if (prep_res.contains("border")) {
|
||||
w_offset = -prep_res["border"][1].get<int>();
|
||||
h_offset = -prep_res["border"][0].get<int>();
|
||||
}
|
||||
int ori_width = prep_res["ori_shape"][2].get<int>();
|
||||
int ori_height = prep_res["ori_shape"][1].get<int>();
|
||||
|
||||
|
|
|
@ -30,8 +30,10 @@ Result<Value> DefaultFormatBundleImpl::Process(const Value& input) {
|
|||
}
|
||||
}
|
||||
if (!output.contains("scale_factor")) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
output["scale_factor"].push_back(1.0);
|
||||
}
|
||||
}
|
||||
if (!output.contains("img_norm_cfg")) {
|
||||
int channel = tensor.shape()[3];
|
||||
for (int i = 0; i < channel; i++) {
|
||||
|
|
|
@ -34,6 +34,11 @@ PadImpl::PadImpl(const Value& args) : TransformImpl(args) {
|
|||
} else {
|
||||
arg_.pad_val = 0.0f;
|
||||
}
|
||||
if (args.contains("logical_or_val")) {
|
||||
// logical_or mode support.
|
||||
arg_.logical_or_val = args["logical_or_val"].get<int>();
|
||||
arg_.add_pix_val = args.value("add_pix_val", 0);
|
||||
}
|
||||
arg_.pad_to_square = args.value("pad_to_square", false);
|
||||
arg_.padding_mode = args.value("padding_mode", std::string("constant"));
|
||||
arg_.orientation_agnostic = args.value("orientation_agnostic", false);
|
||||
|
@ -80,6 +85,16 @@ Result<Value> PadImpl::Process(const Value& input) {
|
|||
output["pad_size_divisor"] = arg_.size_divisor;
|
||||
output["pad_fixed_size"].push_back(pad_h);
|
||||
output["pad_fixed_size"].push_back(pad_w);
|
||||
} else if (arg_.logical_or_val > 0) {
|
||||
int pad_h = (height | arg_.logical_or_val) + arg_.add_pix_val;
|
||||
int pad_w = (width | arg_.logical_or_val) + arg_.add_pix_val;
|
||||
int offset_h = pad_h / 2 - height / 2;
|
||||
int offset_w = pad_w / 2 - width / 2;
|
||||
padding = {offset_w, offset_h, pad_w - width - offset_w, pad_h - height - offset_h};
|
||||
output["border"].push_back(offset_h);
|
||||
output["border"].push_back(offset_w);
|
||||
output["border"].push_back(offset_h + height);
|
||||
output["border"].push_back(offset_w + width);
|
||||
} else {
|
||||
output_tensor = tensor;
|
||||
output["pad_fixed_size"].push_back(height);
|
||||
|
|
|
@ -26,6 +26,8 @@ class MMDEPLOY_API PadImpl : public TransformImpl {
|
|||
protected:
|
||||
struct pad_arg_t {
|
||||
std::array<int, 2> size;
|
||||
int logical_or_val;
|
||||
int add_pix_val;
|
||||
int size_divisor;
|
||||
float pad_val;
|
||||
bool pad_to_square;
|
||||
|
|
|
@ -725,6 +725,20 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../
|
|||
<td align="center">37.4</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><a href="https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet/centernet_r18_8xb16-crop512-140e_coco.py">CenterNet</a></td>
|
||||
<td align="center">Object Detection</td>
|
||||
<td align="center">COCO2017</td>
|
||||
<td align="center">box AP</td>
|
||||
<td align="center">25.9</td>
|
||||
<td align="center">26.0</td>
|
||||
<td align="center">26.0</td>
|
||||
<td align="center">26.0</td>
|
||||
<td align="center">25.8</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><a href="https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox/yolox_s_8x8_300e_coco.py">YOLOX</a></td>
|
||||
<td align="center">Object Detection</td>
|
||||
|
|
|
@ -21,6 +21,7 @@ The table below lists the models that are guaranteed to be exportable to other b
|
|||
| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | MMDetection | N | N | N | N | N | Y | N | N |
|
||||
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | MMDetection | N | N | Y | N | ? | Y | N | N |
|
||||
| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | MMDetection | N | Y | Y | N | ? | N | N | N |
|
||||
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | MMDetection | N | Y | Y | N | ? | N | N | N |
|
||||
| [ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
|
||||
| [ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnext) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
|
||||
| [SE-ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
|
||||
|
|
|
@ -206,6 +206,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
|
|||
| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
|
||||
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
|
||||
| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | Object Detection | Y | Y | N | ? | Y |
|
||||
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | ? |
|
||||
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Instance Segmentation | Y | N | N | N | Y |
|
||||
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
|
||||
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | Instance Segmentation | Y | Y | N | N | N |
|
||||
|
|
|
@ -720,6 +720,20 @@ GPU: ncnn, TensorRT, PPLNN
|
|||
<td align="center">37.4</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><a href="https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet/centernet_r18_8xb16-crop512-140e_coco.py">CenterNet</a></td>
|
||||
<td align="center">Object Detection</td>
|
||||
<td align="center">COCO2017</td>
|
||||
<td align="center">box AP</td>
|
||||
<td align="center">25.9</td>
|
||||
<td align="center">26.0</td>
|
||||
<td align="center">26.0</td>
|
||||
<td align="center">26.0</td>
|
||||
<td align="center">25.8</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><a href="https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox/yolox_s_8x8_300e_coco.py">YOLOX</a></td>
|
||||
<td align="center">Object Detection</td>
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | MMDetection | N | N | N | N | N | Y | N | N |
|
||||
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | MMDetection | N | N | Y | N | ? | Y | N | N |
|
||||
| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | MMDetection | N | Y | Y | N | ? | N | N | N |
|
||||
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | MMDetection | N | Y | Y | N | ? | N | N | N |
|
||||
| [ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
|
||||
| [ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnext) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
|
||||
| [SE-ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
|
||||
|
|
|
@ -209,6 +209,7 @@ cv2.imwrite('output_detection.png', img)
|
|||
| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | ObjectDetection | Y | Y | N | ? | Y |
|
||||
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | ObjectDetection | N | Y | N | ? | Y |
|
||||
| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | ObjectDetection | Y | Y | N | ? | Y |
|
||||
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | ? |
|
||||
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | InstanceSegmentation | Y | N | N | N | Y |
|
||||
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | InstanceSegmentation | Y | Y | N | N | Y |
|
||||
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | InstanceSegmentation | Y | Y | N | N | N |
|
||||
|
|
|
@ -191,10 +191,10 @@ class BaseTask(metaclass=ABCMeta):
|
|||
' field of config. Please set '\
|
||||
'`visualization=dict(type="VisualizationHook")`'
|
||||
|
||||
cfg.default_hooks.visualization.enable = True
|
||||
cfg.default_hooks.visualization.draw = True
|
||||
cfg.default_hooks.visualization.show = show
|
||||
cfg.default_hooks.visualization.wait_time = wait_time
|
||||
cfg.default_hooks.visualization.out_dir = show_dir
|
||||
cfg.default_hooks.visualization.test_out_dir = show_dir
|
||||
cfg.default_hooks.visualization.interval = interval
|
||||
|
||||
return cfg
|
||||
|
@ -211,7 +211,6 @@ class BaseTask(metaclass=ABCMeta):
|
|||
model_cfg = _merge_cfg(model_cfg)
|
||||
|
||||
visualizer = self.get_visualizer(work_dir, work_dir)
|
||||
|
||||
from .runner import DeployTestRunner
|
||||
runner = DeployTestRunner(
|
||||
model=model,
|
||||
|
|
|
@ -247,6 +247,16 @@ class ObjectDetection(BaseTask):
|
|||
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
|
||||
'valid_ratio'
|
||||
]
|
||||
# Extra pad outside datapreprocessor for CenterNet, CornerNet, etc.
|
||||
for i, transform in enumerate(pipeline):
|
||||
if transform['type'] == 'RandomCenterCropPad':
|
||||
if transform['test_pad_mode'][0] == 'logical_or':
|
||||
extra_pad = dict(
|
||||
type='Pad',
|
||||
logical_or_val=transform['test_pad_mode'][1],
|
||||
add_pix_val=transform['test_pad_add_pix'],
|
||||
)
|
||||
pipeline[i] = extra_pad
|
||||
transforms = [
|
||||
item for item in pipeline if 'Random' not in item['type']
|
||||
and 'Annotation' not in item['type']
|
||||
|
@ -262,6 +272,7 @@ class ObjectDetection(BaseTask):
|
|||
transforms[i]['size'] = transforms[i].pop('scale')
|
||||
|
||||
data_preprocessor = model_cfg.model.data_preprocessor
|
||||
|
||||
transforms.insert(-1, dict(type='DefaultFormatBundle'))
|
||||
transforms.insert(
|
||||
-2,
|
||||
|
|
|
@ -199,9 +199,8 @@ class End2EndModel(BaseBackendModel):
|
|||
|
||||
bboxes = dets[:, :4]
|
||||
scores = dets[:, 4]
|
||||
|
||||
# perform rescale
|
||||
if rescale:
|
||||
if rescale and 'scale_factor' in img_metas[i]:
|
||||
scale_factor = img_metas[i]['scale_factor']
|
||||
if isinstance(scale_factor, (list, tuple, np.ndarray)):
|
||||
if len(scale_factor) == 2:
|
||||
|
@ -212,12 +211,19 @@ class End2EndModel(BaseBackendModel):
|
|||
scale_factor = torch.from_numpy(scale_factor).to(dets)
|
||||
bboxes /= scale_factor
|
||||
|
||||
if 'pad_param' in img_metas[i]:
|
||||
# Most of models in mmdetection 3.x use `pad_param`, but some
|
||||
# models like CenterNet uses `border`.
|
||||
# offset pixel of the top-left corners between original image
|
||||
# and padded/enlarged image, 'pad_param' is used when exporting
|
||||
# CornerNet and CentripetalNet to onnx
|
||||
x_off = img_metas[i]['pad_param'][2]
|
||||
y_off = img_metas[i]['pad_param'][0]
|
||||
pad_key = None
|
||||
if 'pad_param' in img_metas[i]:
|
||||
pad_key = 'pad_param'
|
||||
elif 'border' in img_metas[i]:
|
||||
pad_key = 'border'
|
||||
if pad_key is not None:
|
||||
x_off = img_metas[i][pad_key][2]
|
||||
y_off = img_metas[i][pad_key][0]
|
||||
bboxes[:, ::2] -= x_off
|
||||
bboxes[:, 1::2] -= y_off
|
||||
bboxes *= (bboxes > 0)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import base_dense_head # noqa: F401,F403
|
||||
from . import centernet_head # noqa: F401,F403
|
||||
from . import detr_head # noqa: F401,F403
|
||||
from . import fovea_head # noqa: F401,F403
|
||||
from . import gfl_head # noqa: F401,F403
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.dense_heads.centernet_head.CenterNetHead.predict_by_feat')
|
||||
def centernet_head__predict_by_feat__default(
|
||||
ctx,
|
||||
self,
|
||||
center_heatmap_preds: List[Tensor],
|
||||
wh_preds: List[Tensor],
|
||||
offset_preds: List[Tensor],
|
||||
batch_img_metas: List[dict],
|
||||
rescale: bool = True,
|
||||
with_nms: bool = False):
|
||||
"""Rewrite `centernethead` of `CenterNetHead` for default backend."""
|
||||
|
||||
# The dynamic shape deploy of CenterNet get wrong result on TensorRT-8.4.x
|
||||
# because of TensorRT bugs, https://github.com/NVIDIA/TensorRT/issues/2299,
|
||||
# FYI.
|
||||
|
||||
assert len(center_heatmap_preds) == len(wh_preds) == len(offset_preds) == 1
|
||||
batch_center_heatmap_preds = center_heatmap_preds[0]
|
||||
batch_wh_preds = wh_preds[0]
|
||||
batch_offset_preds = offset_preds[0]
|
||||
batch_size = batch_center_heatmap_preds.shape[0]
|
||||
img_shape = batch_img_metas[0]['img_shape']
|
||||
batch_det_bboxes, batch_labels = self._decode_heatmap(
|
||||
batch_center_heatmap_preds,
|
||||
batch_wh_preds,
|
||||
batch_offset_preds,
|
||||
img_shape,
|
||||
k=self.test_cfg.topk,
|
||||
kernel=self.test_cfg.local_maximum_kernel)
|
||||
det_bboxes = batch_det_bboxes.reshape([batch_size, -1, 5])
|
||||
det_labels = batch_labels.reshape(batch_size, -1)
|
||||
|
||||
if with_nms:
|
||||
det_bboxes, det_labels = self._bboxes_nms(det_bboxes, det_labels,
|
||||
self.test_cfg)
|
||||
return det_bboxes, det_labels
|
|
@ -300,6 +300,17 @@ models:
|
|||
- *pipeline_ort_dynamic_fp32
|
||||
- *pipeline_trt_dynamic_fp16
|
||||
|
||||
- name: CenterNet
|
||||
metafile: configs/centernet/metafile.yml
|
||||
model_configs:
|
||||
- configs/centernet/centernet_r18_8xb16-crop512-140e_coco.py
|
||||
pipelines:
|
||||
- *pipeline_ort_dynamic_fp32
|
||||
- deploy_config: configs/mmdet/detection/detection_tensorrt-fp16_dynamic-64x64-800x800.py
|
||||
convert_image: *convert_image
|
||||
backend_test: *default_backend_test
|
||||
sdk_config: *sdk_dynamic
|
||||
|
||||
- name: Mask R-CNN
|
||||
metafile: configs/mask_rcnn/metafile.yml
|
||||
model_configs:
|
||||
|
|
|
@ -1482,6 +1482,111 @@ def test_yolov3_head_predict_by_feat_ncnn():
|
|||
assert rewrite_outputs.shape[-1] == 6
|
||||
|
||||
|
||||
def get_centernet_head_model():
|
||||
"""CenterNet Head Config."""
|
||||
test_cfg = Config(dict(topk=100, local_maximum_kernel=3, max_per_img=100))
|
||||
|
||||
from mmdet.models.dense_heads import CenterNetHead
|
||||
model = CenterNetHead(8, 8, 4, test_cfg=test_cfg)
|
||||
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
def test_centernet_head_predict_by_feat(backend_type: Backend):
|
||||
"""Test predict_by_feat rewrite of CenterNetHead."""
|
||||
check_backend(backend_type)
|
||||
centernet_head = get_centernet_head_model()
|
||||
centernet_head.cpu().eval()
|
||||
s = 128
|
||||
batch_img_metas = [{
|
||||
'border':
|
||||
np.array([11., 99., 11., 99.], dtype=np.float32),
|
||||
'img_shape': (s, s),
|
||||
'batch_input_shape': (s, s)
|
||||
}]
|
||||
|
||||
output_names = ['dets', 'labels']
|
||||
deploy_cfg = Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
onnx_config=dict(output_names=output_names, input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmdet',
|
||||
task='ObjectDetection',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.5,
|
||||
max_output_boxes_per_class=20,
|
||||
pre_top_k=-1,
|
||||
keep_top_k=10,
|
||||
background_label_id=-1,
|
||||
))))
|
||||
seed_everything(1234)
|
||||
center_heatmap_preds = [
|
||||
torch.rand(1, centernet_head.num_classes, s // 4, s // 4)
|
||||
]
|
||||
seed_everything(5678)
|
||||
wh_preds = [torch.rand(1, 2, s // 4, s // 4)]
|
||||
seed_everything(9101)
|
||||
offset_preds = [torch.rand(1, 2, s // 4, s // 4)]
|
||||
|
||||
# to get outputs of pytorch model
|
||||
model_inputs = {
|
||||
'center_heatmap_preds': center_heatmap_preds,
|
||||
'wh_preds': wh_preds,
|
||||
'offset_preds': offset_preds,
|
||||
'batch_img_metas': batch_img_metas,
|
||||
'with_nms': False
|
||||
}
|
||||
model_outputs = get_model_outputs(centernet_head, 'predict_by_feat',
|
||||
model_inputs)
|
||||
|
||||
# to get outputs of onnx model after rewrite
|
||||
wrapped_model = WrapModel(
|
||||
centernet_head, 'predict_by_feat', batch_img_metas=batch_img_metas)
|
||||
rewrite_inputs = {
|
||||
'center_heatmap_preds': center_heatmap_preds,
|
||||
'wh_preds': wh_preds,
|
||||
'offset_preds': offset_preds,
|
||||
}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
if is_backend_output:
|
||||
# hard code to make two tensors with the same shape
|
||||
# rewrite and original codes applied different nms strategy
|
||||
min_shape = min(model_outputs[0].bboxes.shape[0],
|
||||
rewrite_outputs[0].shape[1], 5)
|
||||
for i in range(len(model_outputs)):
|
||||
border = batch_img_metas[i]['border']
|
||||
|
||||
rewrite_outputs[0][i, :, 0] -= border[2]
|
||||
rewrite_outputs[0][i, :, 1] -= border[0]
|
||||
rewrite_outputs[0][i, :, 2] -= border[2]
|
||||
rewrite_outputs[0][i, :, 3] -= border[0]
|
||||
assert np.allclose(
|
||||
model_outputs[i].bboxes[:min_shape],
|
||||
rewrite_outputs[0][i, :min_shape, :4],
|
||||
rtol=1e-03,
|
||||
atol=1e-05)
|
||||
assert np.allclose(
|
||||
model_outputs[i].scores[:min_shape],
|
||||
rewrite_outputs[0][i, :min_shape, 4],
|
||||
rtol=1e-03,
|
||||
atol=1e-05)
|
||||
assert np.allclose(
|
||||
model_outputs[i].labels[:min_shape],
|
||||
rewrite_outputs[1][i, :min_shape],
|
||||
rtol=1e-03,
|
||||
atol=1e-05)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
def get_yolox_head_model():
|
||||
"""YOLOX Head Config."""
|
||||
test_cfg = Config(
|
||||
|
|
|
@ -1,52 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import yaml
|
||||
|
||||
assert os.path.exists('checkpoints')
|
||||
assert os.path.exists('images')
|
||||
if os.path.isfile('report.txt'):
|
||||
with open('report.txt', 'a') as f:
|
||||
sys.stdout = f
|
||||
for i in range(50):
|
||||
print('')
|
||||
f.close()
|
||||
yaml_file = 'tools/benchmark_test.yml'
|
||||
with open(yaml_file) as f:
|
||||
benchmark_test_info = yaml.load(f, Loader=yaml.FullLoader)
|
||||
for model in benchmark_test_info['models']:
|
||||
for model_info in model['model_info']:
|
||||
for deploy_cfg in model['deploy_cfg']:
|
||||
if 'int8' in deploy_cfg:
|
||||
assert os.path.exists('data/coco')
|
||||
model_cfg = model_info['model_cfg']
|
||||
checkpoint = model_info['checkpoint']
|
||||
shape = model_info['shape']
|
||||
img_path = f'images/{shape}.jpg'
|
||||
backend_model = model['backend_model']
|
||||
device = model['device']
|
||||
img_folder = 'images/data'
|
||||
assert os.path.exists(img_path)
|
||||
assert os.path.exists(img_folder)
|
||||
convert_cmd = (f'python tools/deploy.py {deploy_cfg} ' +
|
||||
f'{model_cfg} {checkpoint} {img_path} ' +
|
||||
f' --work-dir tools/ --device {device}')
|
||||
os.system(f'rm -rf tools/{backend_model}')
|
||||
with open('report.txt', 'a') as f:
|
||||
sys.stdout = f
|
||||
print(convert_cmd)
|
||||
os.system('nvidia-smi >> report.txt')
|
||||
f.close()
|
||||
os.system(convert_cmd)
|
||||
profile_cmd = (f'python tools/profiler.py {deploy_cfg} ' +
|
||||
f'{model_cfg} {img_folder} --model tools/' +
|
||||
f'{backend_model} --device {device} --shape ' +
|
||||
f'{shape}')
|
||||
with open('report.txt', 'a') as f:
|
||||
sys.stdout = f
|
||||
print(profile_cmd)
|
||||
time.sleep(3)
|
||||
os.system(profile_cmd)
|
||||
f.close()
|
|
@ -1,33 +0,0 @@
|
|||
deploy_cfgs:
|
||||
- &ObjectDetection-0
|
||||
- configs/mmdet/detection/detection_tensorrt_dynamic-320x320-1344x1344.py
|
||||
- configs/mmdet/detection/detection_tensorrt-fp16_dynamic-320x320-1344x1344.py
|
||||
- configs/mmdet/detection/detection_tensorrt-int8_dynamic-320x320-1344x1344.py
|
||||
- &InstanceSegmentation-0
|
||||
- configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py
|
||||
- configs/mmdet/instance-seg/instance-seg_tensorrt-fp16_dynamic-320x320-1344x1344.py
|
||||
- configs/mmdet/instance-seg/instance-seg_tensorrt-int8_dynamic-320x320-1344x1344.py
|
||||
|
||||
img_dir: images
|
||||
|
||||
models:
|
||||
-
|
||||
name: YOLOV3
|
||||
deploy_cfg: *ObjectDetection-0
|
||||
model_info:
|
||||
-
|
||||
model_cfg: ../mmdetection/configs/yolo/yolov3_mobilenetv2_320_300e_coco.py
|
||||
checkpoint: checkpoints/yolov3_mobilenetv2_320_300e_coco_20210719_215349-d18dff72.pth
|
||||
shape: 320x320
|
||||
device: cuda:0
|
||||
backend_model: end2end.engine
|
||||
-
|
||||
name: RetinaNet
|
||||
deploy_cfg: *ObjectDetection-0
|
||||
model_info:
|
||||
-
|
||||
model_cfg: ../mmdetection/configs/retinanet/retinanet_r18_fpn_1x_coco.py
|
||||
checkpoint: checkpoints/retinanet_r18_fpn_1x_coco_20220407_171055-614fd399.pth
|
||||
shape: 800x1344
|
||||
device: cuda:0
|
||||
backend_model: end2end.engine
|
Loading…
Reference in New Issue