Change op domain (#262)

* change domain to mmdeploy

* update tests

* resolve comments
pull/12/head
RunningLeon 2021-12-08 15:06:41 +08:00 committed by GitHub
parent 03c95a1149
commit d96ee9e9f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 38 additions and 37 deletions

View File

@ -3,13 +3,13 @@
#include "ort_utils.h"
const char *c_MMCVOpDomain = "mmcv";
const char *c_MMDeployOpDomain = "mmdeploy";
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api) {
OrtCustomOpDomain *domain = nullptr;
const OrtApi *kOrtApi = api->GetApi(ORT_API_VERSION);
if (auto status = kOrtApi->CreateCustomOpDomain(c_MMCVOpDomain, &domain)) {
if (auto status = kOrtApi->CreateCustomOpDomain(c_MMDeployOpDomain, &domain)) {
return status;
}

View File

@ -22,6 +22,7 @@
# go back to third_party directory and git clone pybind11
cd ..
git clone git@github.com:pybind/pybind11.git pybind11
cd pybind11
git checkout 70a58c5
```

View File

@ -28,7 +28,7 @@ class MultiLevelRoiAlign(Function):
inputs = args[:len(featmap_strides)]
rois = args[len(featmap_strides)]
return g.op(
'mmcv::MMCVMultiLevelRoiAlign',
'mmdeploy::MMCVMultiLevelRoiAlign',
rois,
*inputs,
output_height_i=output_size[1],

View File

@ -40,9 +40,9 @@ class Mark(torch.autograd.Function):
@staticmethod
def symbolic(g, x, dtype, shape, func, func_id, type, name, id, attrs):
"""Symbolic function for mmcv::Mark op."""
"""Symbolic function for mmdeploy::Mark op."""
n = g.op(
'mmcv::Mark',
'mmdeploy::Mark',
x,
dtype_i=TORCH_DTYPE_TO_ONNX[dtype],
shape_i=shape,

View File

@ -132,10 +132,10 @@ def nms_dynamic(ctx, g, boxes: Tensor, scores: Tensor,
class TRTBatchedNMSop(torch.autograd.Function):
"""Create mmcv::TRTBatchedNMS op for TensorRT backend.
"""Create mmdeploy::TRTBatchedNMS op for TensorRT backend.
NMS in ONNX supports dynamic outputs. This class helps replace
onnx::NonMaxSuppression with mmcv::TRTBatchedNMS.
onnx::NonMaxSuppression with mmdeploy::TRTBatchedNMS.
"""
@staticmethod
@ -190,9 +190,9 @@ class TRTBatchedNMSop(torch.autograd.Function):
iou_threshold: float,
score_threshold: float,
background_label_id: int = -1):
"""Symbolic function for mmcv::TRTBatchedNMS."""
"""Symbolic function for mmdeploy::TRTBatchedNMS."""
return g.op(
'mmcv::TRTBatchedNMS',
'mmdeploy::TRTBatchedNMS',
boxes,
scores,
num_classes_i=num_classes,

View File

@ -16,7 +16,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
sampling_ratio: int, pool_mode: str, aligned: bool):
"""Rewrite symbolic function for default backend.
Replace onnx::RoiAlign with mmcv::MMCVRoiAlign.
Replace onnx::RoiAlign with mmdeploy::MMCVRoiAlign.
Args:
ctx (ContextCaller): The context with additional information.
@ -38,7 +38,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
"""
return g.op(
'mmcv::MMCVRoiAlign',
'mmdeploy::MMCVRoiAlign',
input,
rois,
output_height_i=output_size[0],

View File

@ -15,10 +15,10 @@ def grid_sampler(g,
PyTorch does not support export grid_sampler to ONNX by default. We add the
support here. `grid_sampler` will be exported as ONNX node
'mmcv::grid_sampler'
'mmdeploy::grid_sampler'
"""
return g.op(
'mmcv::grid_sampler',
'mmdeploy::grid_sampler',
input,
grid,
interpolation_mode_i=interpolation_mode,

View File

@ -39,7 +39,7 @@ def instance_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
'Tensor'))
norm_reshaped = g.op(
'mmcv::TRTInstanceNormalization',
'mmdeploy::TRTInstanceNormalization',
input_reshaped,
weight_,
bias_,

View File

@ -35,7 +35,7 @@ def test_mark():
nodes = onnx_model.graph.node
assert nodes[0].op_type == 'Mark'
assert nodes[0].domain == 'mmcv'
assert nodes[0].domain == 'mmdeploy'
assert attribute_to_dict(nodes[0].attribute) == dict(
dtype=1,
func='add',
@ -46,7 +46,7 @@ def test_mark():
shape=[2, 3, 4])
assert nodes[1].op_type == 'Mark'
assert nodes[1].domain == 'mmcv'
assert nodes[1].domain == 'mmdeploy'
assert attribute_to_dict(nodes[1].attribute) == dict(
dtype=1,
func='add',
@ -59,7 +59,7 @@ def test_mark():
assert nodes[2].op_type == 'Add'
assert nodes[3].op_type == 'Mark'
assert nodes[3].domain == 'mmcv'
assert nodes[3].domain == 'mmdeploy'
assert attribute_to_dict(nodes[3].attribute) == dict(
dtype=1,
func='add',

View File

@ -20,7 +20,7 @@ def create_custom_module():
@staticmethod
def symbolic(g, x, val):
return g.op('mmcv::symbolic_old', x, val_i=val)
return g.op('mmdeploy::symbolic_old', x, val_i=val)
@staticmethod
def forward(ctx, x, val):
@ -42,17 +42,17 @@ def test_symbolic_rewriter():
@SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc')
def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
assert hasattr(symbolic_wrapper, 'cfg')
return g.op('mmcv::symbolic_testfunc_default', x, val_i=val)
return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val)
@SYMBOLIC_REWRITER.register_symbolic(
'mmdeploy.TestFunc', backend='tensorrt')
def symbolic_testfunc_tensorrt(symbolic_wrapper, g, x, val):
return g.op('mmcv::symbolic_testfunc_tensorrt', x, val_i=val)
return g.op('mmdeploy::symbolic_testfunc_tensorrt', x, val_i=val)
@SYMBOLIC_REWRITER.register_symbolic(
'cummax', is_pytorch=True, arg_descriptors=['v', 'i'])
def symbolic_cummax(symbolic_wrapper, g, input, dim):
return g.op('mmcv::cummax_default', input, dim_i=dim, outputs=2)
return g.op('mmdeploy::cummax_default', input, dim_i=dim, outputs=2)
class TestModel(torch.nn.Module):
@ -74,9 +74,9 @@ def test_symbolic_rewriter():
onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_testfunc_default'
assert nodes[0].domain == 'mmcv'
assert nodes[0].domain == 'mmdeploy'
assert nodes[1].op_type == 'cummax_default'
assert nodes[1].domain == 'mmcv'
assert nodes[1].domain == 'mmdeploy'
# ncnn
with RewriterContext(cfg=cfg, backend='ncnn', opset=11):
@ -84,9 +84,9 @@ def test_symbolic_rewriter():
onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_testfunc_default'
assert nodes[0].domain == 'mmcv'
assert nodes[0].domain == 'mmdeploy'
assert nodes[1].op_type == 'cummax_default'
assert nodes[1].domain == 'mmcv'
assert nodes[1].domain == 'mmdeploy'
# tensorrt
with RewriterContext(cfg=cfg, backend='tensorrt', opset=11):
@ -94,9 +94,9 @@ def test_symbolic_rewriter():
onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_testfunc_tensorrt'
assert nodes[0].domain == 'mmcv'
assert nodes[0].domain == 'mmdeploy'
assert nodes[1].op_type == 'cummax_default'
assert nodes[1].domain == 'mmcv'
assert nodes[1].domain == 'mmdeploy'
def test_unregister():
@ -104,12 +104,12 @@ def test_unregister():
@SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc')
def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
return g.op('mmcv::symbolic_testfunc_default', x, val_i=val)
return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val)
@SYMBOLIC_REWRITER.register_symbolic(
'cummax', is_pytorch=True, arg_descriptors=['v', 'i'])
def symbolic_cummax(symbolic_wrapper, g, input, dim):
return g.op('mmcv::cummax_default', input, dim_i=dim, outputs=2)
return g.op('mmdeploy::cummax_default', input, dim_i=dim, outputs=2)
class TestModel(torch.nn.Module):
@ -135,7 +135,7 @@ def test_unregister():
onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node
assert nodes[0].op_type == 'cummax_default'
assert nodes[0].domain == 'mmcv'
assert nodes[0].domain == 'mmdeploy'
with pytest.raises(RuntimeError):
torch.onnx.export(model, x, output_file, opset_version=11)
@ -146,13 +146,13 @@ def test_unregister():
onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_testfunc_default'
assert nodes[0].domain == 'mmcv'
assert nodes[0].domain == 'mmdeploy'
torch.onnx.export(model, x, output_file, opset_version=11)
onnx_model = onnx.load(output_file)
nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_old'
assert nodes[0].domain == 'mmcv'
assert nodes[0].domain == 'mmdeploy'
def test_register_empty_symbolic():
@ -160,7 +160,7 @@ def test_register_empty_symbolic():
@symbolic_rewriter.register_symbolic('mmdeploy.EmptyFunction')
def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
return g.op('mmcv::symbolic_testfunc_default', x, val_i=val)
return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val)
symbolic_rewriter.enter()
assert len(symbolic_rewriter._extra_symbolic) == 0

View File

@ -384,7 +384,7 @@ def test_multi_level_roi_align(backend,
input_name, ['bbox_feats'],
'MMCVMultiLevelRoiAlign_0',
None,
'mmlab',
'mmdeploy',
aligned=aligned,
featmap_strides=featmap_strides,
finest_scale=finest_scale,
@ -397,7 +397,7 @@ def test_multi_level_roi_align(backend,
graph, producer_name='pytorch', producer_version='1.8')
onnx_model.opset_import[0].version = 11
onnx_model.opset_import.append(
onnx.onnx_ml_pb2.OperatorSetIdProto(domain='mmlab', version=1))
onnx.onnx_ml_pb2.OperatorSetIdProto(domain='mmdeploy', version=1))
backend.run_and_validate(
onnx_model, [rois, *input],

View File

@ -85,7 +85,7 @@ def test_grid_sampler():
model = OpModel(torch.grid_sampler, flow, 0, 0, False).eval()
nodes = get_model_onnx_nodes(model, x)
assert nodes[1].op_type == 'grid_sampler'
assert nodes[1].domain == 'mmcv'
assert nodes[1].domain == 'mmdeploy'
def test_instance_norm():
@ -94,7 +94,7 @@ def test_instance_norm():
1e-05).eval()
nodes = get_model_onnx_nodes(model, x)
assert nodes[4].op_type == 'TRTInstanceNormalization'
assert nodes[4].domain == 'mmcv'
assert nodes[4].domain == 'mmdeploy'
class TestSqueeze: