Change op domain (#262)

* change domain to mmdeploy

* update tests

* resolve comments
This commit is contained in:
RunningLeon 2021-12-08 15:06:41 +08:00 committed by GitHub
parent 03c95a1149
commit d96ee9e9f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 38 additions and 37 deletions

View File

@ -3,13 +3,13 @@
#include "ort_utils.h" #include "ort_utils.h"
const char *c_MMCVOpDomain = "mmcv"; const char *c_MMDeployOpDomain = "mmdeploy";
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api) { OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api) {
OrtCustomOpDomain *domain = nullptr; OrtCustomOpDomain *domain = nullptr;
const OrtApi *kOrtApi = api->GetApi(ORT_API_VERSION); const OrtApi *kOrtApi = api->GetApi(ORT_API_VERSION);
if (auto status = kOrtApi->CreateCustomOpDomain(c_MMCVOpDomain, &domain)) { if (auto status = kOrtApi->CreateCustomOpDomain(c_MMDeployOpDomain, &domain)) {
return status; return status;
} }

View File

@ -22,6 +22,7 @@
# go back to third_party directory and git clone pybind11 # go back to third_party directory and git clone pybind11
cd .. cd ..
git clone git@github.com:pybind/pybind11.git pybind11 git clone git@github.com:pybind/pybind11.git pybind11
cd pybind11
git checkout 70a58c5 git checkout 70a58c5
``` ```

View File

@ -28,7 +28,7 @@ class MultiLevelRoiAlign(Function):
inputs = args[:len(featmap_strides)] inputs = args[:len(featmap_strides)]
rois = args[len(featmap_strides)] rois = args[len(featmap_strides)]
return g.op( return g.op(
'mmcv::MMCVMultiLevelRoiAlign', 'mmdeploy::MMCVMultiLevelRoiAlign',
rois, rois,
*inputs, *inputs,
output_height_i=output_size[1], output_height_i=output_size[1],

View File

@ -40,9 +40,9 @@ class Mark(torch.autograd.Function):
@staticmethod @staticmethod
def symbolic(g, x, dtype, shape, func, func_id, type, name, id, attrs): def symbolic(g, x, dtype, shape, func, func_id, type, name, id, attrs):
"""Symbolic function for mmcv::Mark op.""" """Symbolic function for mmdeploy::Mark op."""
n = g.op( n = g.op(
'mmcv::Mark', 'mmdeploy::Mark',
x, x,
dtype_i=TORCH_DTYPE_TO_ONNX[dtype], dtype_i=TORCH_DTYPE_TO_ONNX[dtype],
shape_i=shape, shape_i=shape,

View File

@ -132,10 +132,10 @@ def nms_dynamic(ctx, g, boxes: Tensor, scores: Tensor,
class TRTBatchedNMSop(torch.autograd.Function): class TRTBatchedNMSop(torch.autograd.Function):
"""Create mmcv::TRTBatchedNMS op for TensorRT backend. """Create mmdeploy::TRTBatchedNMS op for TensorRT backend.
NMS in ONNX supports dynamic outputs. This class helps replace NMS in ONNX supports dynamic outputs. This class helps replace
onnx::NonMaxSuppression with mmcv::TRTBatchedNMS. onnx::NonMaxSuppression with mmdeploy::TRTBatchedNMS.
""" """
@staticmethod @staticmethod
@ -190,9 +190,9 @@ class TRTBatchedNMSop(torch.autograd.Function):
iou_threshold: float, iou_threshold: float,
score_threshold: float, score_threshold: float,
background_label_id: int = -1): background_label_id: int = -1):
"""Symbolic function for mmcv::TRTBatchedNMS.""" """Symbolic function for mmdeploy::TRTBatchedNMS."""
return g.op( return g.op(
'mmcv::TRTBatchedNMS', 'mmdeploy::TRTBatchedNMS',
boxes, boxes,
scores, scores,
num_classes_i=num_classes, num_classes_i=num_classes,

View File

@ -16,7 +16,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
sampling_ratio: int, pool_mode: str, aligned: bool): sampling_ratio: int, pool_mode: str, aligned: bool):
"""Rewrite symbolic function for default backend. """Rewrite symbolic function for default backend.
Replace onnx::RoiAlign with mmcv::MMCVRoiAlign. Replace onnx::RoiAlign with mmdeploy::MMCVRoiAlign.
Args: Args:
ctx (ContextCaller): The context with additional information. ctx (ContextCaller): The context with additional information.
@ -38,7 +38,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
""" """
return g.op( return g.op(
'mmcv::MMCVRoiAlign', 'mmdeploy::MMCVRoiAlign',
input, input,
rois, rois,
output_height_i=output_size[0], output_height_i=output_size[0],

View File

@ -15,10 +15,10 @@ def grid_sampler(g,
PyTorch does not support export grid_sampler to ONNX by default. We add the PyTorch does not support export grid_sampler to ONNX by default. We add the
support here. `grid_sampler` will be exported as ONNX node support here. `grid_sampler` will be exported as ONNX node
'mmcv::grid_sampler' 'mmdeploy::grid_sampler'
""" """
return g.op( return g.op(
'mmcv::grid_sampler', 'mmdeploy::grid_sampler',
input, input,
grid, grid,
interpolation_mode_i=interpolation_mode, interpolation_mode_i=interpolation_mode,

View File

@ -39,7 +39,7 @@ def instance_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
'Tensor')) 'Tensor'))
norm_reshaped = g.op( norm_reshaped = g.op(
'mmcv::TRTInstanceNormalization', 'mmdeploy::TRTInstanceNormalization',
input_reshaped, input_reshaped,
weight_, weight_,
bias_, bias_,

View File

@ -35,7 +35,7 @@ def test_mark():
nodes = onnx_model.graph.node nodes = onnx_model.graph.node
assert nodes[0].op_type == 'Mark' assert nodes[0].op_type == 'Mark'
assert nodes[0].domain == 'mmcv' assert nodes[0].domain == 'mmdeploy'
assert attribute_to_dict(nodes[0].attribute) == dict( assert attribute_to_dict(nodes[0].attribute) == dict(
dtype=1, dtype=1,
func='add', func='add',
@ -46,7 +46,7 @@ def test_mark():
shape=[2, 3, 4]) shape=[2, 3, 4])
assert nodes[1].op_type == 'Mark' assert nodes[1].op_type == 'Mark'
assert nodes[1].domain == 'mmcv' assert nodes[1].domain == 'mmdeploy'
assert attribute_to_dict(nodes[1].attribute) == dict( assert attribute_to_dict(nodes[1].attribute) == dict(
dtype=1, dtype=1,
func='add', func='add',
@ -59,7 +59,7 @@ def test_mark():
assert nodes[2].op_type == 'Add' assert nodes[2].op_type == 'Add'
assert nodes[3].op_type == 'Mark' assert nodes[3].op_type == 'Mark'
assert nodes[3].domain == 'mmcv' assert nodes[3].domain == 'mmdeploy'
assert attribute_to_dict(nodes[3].attribute) == dict( assert attribute_to_dict(nodes[3].attribute) == dict(
dtype=1, dtype=1,
func='add', func='add',

View File

@ -20,7 +20,7 @@ def create_custom_module():
@staticmethod @staticmethod
def symbolic(g, x, val): def symbolic(g, x, val):
return g.op('mmcv::symbolic_old', x, val_i=val) return g.op('mmdeploy::symbolic_old', x, val_i=val)
@staticmethod @staticmethod
def forward(ctx, x, val): def forward(ctx, x, val):
@ -42,17 +42,17 @@ def test_symbolic_rewriter():
@SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc') @SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc')
def symbolic_testfunc_default(symbolic_wrapper, g, x, val): def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
assert hasattr(symbolic_wrapper, 'cfg') assert hasattr(symbolic_wrapper, 'cfg')
return g.op('mmcv::symbolic_testfunc_default', x, val_i=val) return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val)
@SYMBOLIC_REWRITER.register_symbolic( @SYMBOLIC_REWRITER.register_symbolic(
'mmdeploy.TestFunc', backend='tensorrt') 'mmdeploy.TestFunc', backend='tensorrt')
def symbolic_testfunc_tensorrt(symbolic_wrapper, g, x, val): def symbolic_testfunc_tensorrt(symbolic_wrapper, g, x, val):
return g.op('mmcv::symbolic_testfunc_tensorrt', x, val_i=val) return g.op('mmdeploy::symbolic_testfunc_tensorrt', x, val_i=val)
@SYMBOLIC_REWRITER.register_symbolic( @SYMBOLIC_REWRITER.register_symbolic(
'cummax', is_pytorch=True, arg_descriptors=['v', 'i']) 'cummax', is_pytorch=True, arg_descriptors=['v', 'i'])
def symbolic_cummax(symbolic_wrapper, g, input, dim): def symbolic_cummax(symbolic_wrapper, g, input, dim):
return g.op('mmcv::cummax_default', input, dim_i=dim, outputs=2) return g.op('mmdeploy::cummax_default', input, dim_i=dim, outputs=2)
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
@ -74,9 +74,9 @@ def test_symbolic_rewriter():
onnx_model = onnx.load(output_file) onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_testfunc_default' assert nodes[0].op_type == 'symbolic_testfunc_default'
assert nodes[0].domain == 'mmcv' assert nodes[0].domain == 'mmdeploy'
assert nodes[1].op_type == 'cummax_default' assert nodes[1].op_type == 'cummax_default'
assert nodes[1].domain == 'mmcv' assert nodes[1].domain == 'mmdeploy'
# ncnn # ncnn
with RewriterContext(cfg=cfg, backend='ncnn', opset=11): with RewriterContext(cfg=cfg, backend='ncnn', opset=11):
@ -84,9 +84,9 @@ def test_symbolic_rewriter():
onnx_model = onnx.load(output_file) onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_testfunc_default' assert nodes[0].op_type == 'symbolic_testfunc_default'
assert nodes[0].domain == 'mmcv' assert nodes[0].domain == 'mmdeploy'
assert nodes[1].op_type == 'cummax_default' assert nodes[1].op_type == 'cummax_default'
assert nodes[1].domain == 'mmcv' assert nodes[1].domain == 'mmdeploy'
# tensorrt # tensorrt
with RewriterContext(cfg=cfg, backend='tensorrt', opset=11): with RewriterContext(cfg=cfg, backend='tensorrt', opset=11):
@ -94,9 +94,9 @@ def test_symbolic_rewriter():
onnx_model = onnx.load(output_file) onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_testfunc_tensorrt' assert nodes[0].op_type == 'symbolic_testfunc_tensorrt'
assert nodes[0].domain == 'mmcv' assert nodes[0].domain == 'mmdeploy'
assert nodes[1].op_type == 'cummax_default' assert nodes[1].op_type == 'cummax_default'
assert nodes[1].domain == 'mmcv' assert nodes[1].domain == 'mmdeploy'
def test_unregister(): def test_unregister():
@ -104,12 +104,12 @@ def test_unregister():
@SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc') @SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc')
def symbolic_testfunc_default(symbolic_wrapper, g, x, val): def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
return g.op('mmcv::symbolic_testfunc_default', x, val_i=val) return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val)
@SYMBOLIC_REWRITER.register_symbolic( @SYMBOLIC_REWRITER.register_symbolic(
'cummax', is_pytorch=True, arg_descriptors=['v', 'i']) 'cummax', is_pytorch=True, arg_descriptors=['v', 'i'])
def symbolic_cummax(symbolic_wrapper, g, input, dim): def symbolic_cummax(symbolic_wrapper, g, input, dim):
return g.op('mmcv::cummax_default', input, dim_i=dim, outputs=2) return g.op('mmdeploy::cummax_default', input, dim_i=dim, outputs=2)
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
@ -135,7 +135,7 @@ def test_unregister():
onnx_model = onnx.load(output_file) onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node nodes = onnx_model.graph.node
assert nodes[0].op_type == 'cummax_default' assert nodes[0].op_type == 'cummax_default'
assert nodes[0].domain == 'mmcv' assert nodes[0].domain == 'mmdeploy'
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
torch.onnx.export(model, x, output_file, opset_version=11) torch.onnx.export(model, x, output_file, opset_version=11)
@ -146,13 +146,13 @@ def test_unregister():
onnx_model = onnx.load(output_file) onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_testfunc_default' assert nodes[0].op_type == 'symbolic_testfunc_default'
assert nodes[0].domain == 'mmcv' assert nodes[0].domain == 'mmdeploy'
torch.onnx.export(model, x, output_file, opset_version=11) torch.onnx.export(model, x, output_file, opset_version=11)
onnx_model = onnx.load(output_file) onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_old' assert nodes[0].op_type == 'symbolic_old'
assert nodes[0].domain == 'mmcv' assert nodes[0].domain == 'mmdeploy'
def test_register_empty_symbolic(): def test_register_empty_symbolic():
@ -160,7 +160,7 @@ def test_register_empty_symbolic():
@symbolic_rewriter.register_symbolic('mmdeploy.EmptyFunction') @symbolic_rewriter.register_symbolic('mmdeploy.EmptyFunction')
def symbolic_testfunc_default(symbolic_wrapper, g, x, val): def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
return g.op('mmcv::symbolic_testfunc_default', x, val_i=val) return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val)
symbolic_rewriter.enter() symbolic_rewriter.enter()
assert len(symbolic_rewriter._extra_symbolic) == 0 assert len(symbolic_rewriter._extra_symbolic) == 0

View File

@ -384,7 +384,7 @@ def test_multi_level_roi_align(backend,
input_name, ['bbox_feats'], input_name, ['bbox_feats'],
'MMCVMultiLevelRoiAlign_0', 'MMCVMultiLevelRoiAlign_0',
None, None,
'mmlab', 'mmdeploy',
aligned=aligned, aligned=aligned,
featmap_strides=featmap_strides, featmap_strides=featmap_strides,
finest_scale=finest_scale, finest_scale=finest_scale,
@ -397,7 +397,7 @@ def test_multi_level_roi_align(backend,
graph, producer_name='pytorch', producer_version='1.8') graph, producer_name='pytorch', producer_version='1.8')
onnx_model.opset_import[0].version = 11 onnx_model.opset_import[0].version = 11
onnx_model.opset_import.append( onnx_model.opset_import.append(
onnx.onnx_ml_pb2.OperatorSetIdProto(domain='mmlab', version=1)) onnx.onnx_ml_pb2.OperatorSetIdProto(domain='mmdeploy', version=1))
backend.run_and_validate( backend.run_and_validate(
onnx_model, [rois, *input], onnx_model, [rois, *input],

View File

@ -85,7 +85,7 @@ def test_grid_sampler():
model = OpModel(torch.grid_sampler, flow, 0, 0, False).eval() model = OpModel(torch.grid_sampler, flow, 0, 0, False).eval()
nodes = get_model_onnx_nodes(model, x) nodes = get_model_onnx_nodes(model, x)
assert nodes[1].op_type == 'grid_sampler' assert nodes[1].op_type == 'grid_sampler'
assert nodes[1].domain == 'mmcv' assert nodes[1].domain == 'mmdeploy'
def test_instance_norm(): def test_instance_norm():
@ -94,7 +94,7 @@ def test_instance_norm():
1e-05).eval() 1e-05).eval()
nodes = get_model_onnx_nodes(model, x) nodes = get_model_onnx_nodes(model, x)
assert nodes[4].op_type == 'TRTInstanceNormalization' assert nodes[4].op_type == 'TRTInstanceNormalization'
assert nodes[4].domain == 'mmcv' assert nodes[4].domain == 'mmdeploy'
class TestSqueeze: class TestSqueeze: