mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
* Add score_threshold and max_num to NMS * Fix codestyle * Fix codestyle * Fix inds in nms * Update nms docstring * Move score_threshold and max_num arguments * Fix args order in docstring * Fix lint of C++ file * Remove torch.onnx.is_in_onnx_export() and add max_num to batched_nms for separate classes * Rewrite max_num handling in NMSop.symbolic * Add processing of max_output_boxes_per_class when exporting to TensorRT * Add score_threshold and max_num for NMS in test_onnx.py and test_tensorrt.py * Remove _is_value(max_num) * Fix CI errors with torch==1.3.1 * Update test_batched_nms in test_nms.py * Add tests for preprocess_onnx * Move 'test_tensorrt_preprocess.py' and 'preprocess', update 'remove_tmp_file' * Update mmcv/tensorrt/__init__.py * Fix segfault with torch==1.3.1 (remove onnx.checker.check_model) * Restore 'onnx.checker.check_model' with a torch version check * Change torch version from 1.3.1 to 1.4.0 * Update version check * Remove check for ONNX Co-authored-by: maningsheng <maningsheng@sensetime.com>
76 lines
2.0 KiB
Python
76 lines
2.0 KiB
Python
import os
import tempfile
from functools import wraps

import onnx
import torch

from mmcv.ops import nms
from mmcv.tensorrt.preprocess import preprocess_onnx
|
|
|
|
|
|
def remove_tmp_file(func):
    """Decorate ``func`` so it writes to a throw-away ONNX file.

    Each call of the wrapped function receives an ``onnx_file`` keyword
    argument. This is a path named ``tmp.onnx`` inside a freshly created
    temporary directory, and it overrides any ``onnx_file`` the caller
    passed. The directory, including the file if ``func`` created it, is
    removed when ``func`` returns or raises.

    Args:
        func: Callable that accepts an ``onnx_file`` keyword argument.

    Returns:
        The wrapped callable, which returns whatever ``func`` returns.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Use a private directory per call instead of a fixed 'tmp.onnx' in
        # the CWD. This keeps parallel test runs from clobbering each other,
        # and it never deletes an unrelated file that happens to share the
        # name.
        with tempfile.TemporaryDirectory() as tmp_dir:
            kwargs['onnx_file'] = os.path.join(tmp_dir, 'tmp.onnx')
            return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
@remove_tmp_file
def export_nms_module_to_onnx(module, onnx_file):
    """Export an NMS module to ONNX (opset 11) and load the result back.

    Args:
        module: An ``nn.Module`` class, not an instance. It is instantiated
            with no arguments, and its ``forward`` is traced with
            ``(boxes, scores)`` inputs.
        onnx_file: Path the exported graph is written to. The
            ``remove_tmp_file`` decorator injects it, so callers omit it.

    Returns:
        The exported model as returned by ``onnx.load``.
    """
    torch_model = module()
    torch_model.eval()

    # Dummy trace inputs: 100 boxes of 4 coordinates each, plus one score per
    # box, filled with random values. Named so the builtin ``input`` is not
    # shadowed.
    dummy_inputs = (torch.rand([100, 4], dtype=torch.float32),
                    torch.rand([100], dtype=torch.float32))

    torch.onnx.export(
        torch_model,
        dummy_inputs,
        onnx_file,
        opset_version=11,
        input_names=['boxes', 'scores'],
        output_names=['output'])

    # Load before returning: the decorator deletes the file afterwards.
    return onnx.load(onnx_file)
|
|
|
|
|
|
def test_can_handle_nms_with_constant_maxnum():
    """``preprocess_onnx`` handles an NMS exported with ``max_num=10``.

    After preprocessing, the graph must contain at least one
    NonMaxSuppression node, and every such node must carry exactly five
    attributes.
    """

    class ModuleNMS(torch.nn.Module):

        def forward(self, boxes, scores):
            return nms(boxes, scores, iou_threshold=0.4, max_num=10)

    onnx_model = export_nms_module_to_onnx(ModuleNMS)
    preprocess_onnx_model = preprocess_onnx(onnx_model)

    # Match on op_type as well as name. Node names may be empty depending on
    # the exporter version.
    nms_nodes = [
        node for node in preprocess_onnx_model.graph.node
        if node.op_type == 'NonMaxSuppression'
        or 'NonMaxSuppression' in node.name
    ]
    # Without this check the loop below runs zero times and the test passes
    # even when no NMS node survived preprocessing.
    assert nms_nodes, 'No NonMaxSuppression node found after preprocessing.'
    for node in nms_nodes:
        assert len(node.attribute) == 5, 'The NMS must have 5 attributes.'
|
|
|
|
|
|
def test_can_handle_nms_with_undefined_maxnum():
    """``preprocess_onnx`` handles an NMS exported without ``max_num``.

    After preprocessing, the graph must contain at least one
    NonMaxSuppression node. Every such node must have five attributes, with
    ``max_output_boxes_per_class`` set to a positive value even though the
    model never specified one.
    """

    class ModuleNMS(torch.nn.Module):

        def forward(self, boxes, scores):
            return nms(boxes, scores, iou_threshold=0.4)

    onnx_model = export_nms_module_to_onnx(ModuleNMS)
    preprocess_onnx_model = preprocess_onnx(onnx_model)

    # Match on op_type as well as name. Node names may be empty depending on
    # the exporter version.
    nms_nodes = [
        node for node in preprocess_onnx_model.graph.node
        if node.op_type == 'NonMaxSuppression'
        or 'NonMaxSuppression' in node.name
    ]
    # Guard against a vacuous pass when the graph has no NMS node at all.
    assert nms_nodes, 'No NonMaxSuppression node found after preprocessing.'
    for node in nms_nodes:
        assert len(node.attribute) == 5, \
            'The NMS must have 5 attributes.'
        # Look the attribute up by name rather than by position, so the test
        # does not depend on attribute ordering.
        attrs = {attr.name: attr for attr in node.attribute}
        assert 'max_output_boxes_per_class' in attrs, \
            'The max_output_boxes_per_class is not defined correctly.'
        assert attrs['max_output_boxes_per_class'].i > 0, \
            'The max_output_boxes_per_class is not defined correctly.'
|