mirror of https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Support custom operators cummax and cummin for onnxruntime (#1010)
* support custom op `mmcv::cummax` for onnxruntime in mmcv
* fix clang-format lint error
* support `mmcv::cummin`, reformat codes
* fix merge from master
* add docs for `mmcv::cummax` and `mmcv::cummin`
* format doc
* add assertion for torch version, when exporting `cummax` to onnx
* add more comments for torch version
* handle exporting to onnx in `soft_nms`
* commit for test_onnx
* remove `is_in_onnx_export` in softnms
* add more comments
* fix c++ lint error
* add known issues doc for `cummax`
* fix known issues doc
This commit is contained in:
parent db6b0542c7
commit 934b549e23
@ -33,6 +33,18 @@
  - [Inputs](#inputs-4)
  - [Outputs](#outputs-4)
  - [Type Constraints](#type-constraints-4)
- [cummax](#cummax)
  - [Description](#description-5)
  - [Parameters](#parameters-5)
  - [Inputs](#inputs-5)
  - [Outputs](#outputs-5)
  - [Type Constraints](#type-constraints-5)
- [cummin](#cummin)
  - [Description](#description-6)
  - [Parameters](#parameters-6)
  - [Inputs](#inputs-6)
  - [Outputs](#outputs-6)
  - [Type Constraints](#type-constraints-6)

<!-- TOC -->

@ -207,3 +219,67 @@ Perform CornerPool on `input` features. Read [CornerNet -- Detecting Objects as

### Type Constraints

- T:tensor(float32)

## cummax

### Description

Returns a tuple (`values`, `indices`) where `values` contains the cumulative maximum elements of `input` along the dimension `dim`, and `indices` gives the index location of each maximum value found along that dimension. Read [torch.cummax](https://pytorch.org/docs/stable/generated/torch.cummax.html) for more details.
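
For example, the expected behavior mirrors the PyTorch reference (a minimal sketch):

```python
import torch

x = torch.tensor([1., 3., 2., 4.])
values, indices = torch.cummax(x, dim=0)
# values:  tensor([1., 3., 3., 4.])
# indices: tensor([0, 1, 1, 3])
```
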
### Parameters

| Type  | Parameter | Description                            |
| ----- | --------- | -------------------------------------- |
| `int` | `dim`     | the dimension to do the operation over |

### Inputs

<dl>
<dt><tt>input</tt>: T</dt>
<dd>The input tensor, of arbitrary shape. Empty tensors (with zero elements) are also supported.</dd>
</dl>

### Outputs

<dl>
<dt><tt>output</tt>: T</dt>
<dd>The cumulative maximum elements of `input` along the dimension `dim`, with the same shape and dtype as `input`.</dd>
<dt><tt>indices</tt>: tensor(int64)</dt>
<dd>The index location of each cumulative maximum value found along the dimension `dim`, with the same shape as `input`.</dd>
</dl>

### Type Constraints

- T:tensor(float32)

## cummin

### Description

Returns a tuple (`values`, `indices`) where `values` contains the cumulative minimum elements of `input` along the dimension `dim`, and `indices` gives the index location of each minimum value found along that dimension. Read [torch.cummin](https://pytorch.org/docs/stable/generated/torch.cummin.html) for more details.
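
For example, the expected behavior mirrors the PyTorch reference (a minimal sketch):

```python
import torch

x = torch.tensor([3., 1., 2., 0.])
values, indices = torch.cummin(x, dim=0)
# values:  tensor([3., 1., 1., 0.])
# indices: tensor([0, 1, 1, 3])
```
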
### Parameters

| Type  | Parameter | Description                            |
| ----- | --------- | -------------------------------------- |
| `int` | `dim`     | the dimension to do the operation over |

### Inputs

<dl>
<dt><tt>input</tt>: T</dt>
<dd>The input tensor, of arbitrary shape. Empty tensors (with zero elements) are also supported.</dd>
</dl>

### Outputs

<dl>
<dt><tt>output</tt>: T</dt>
<dd>The cumulative minimum elements of `input` along the dimension `dim`, with the same shape and dtype as `input`.</dd>
<dt><tt>indices</tt>: tensor(int64)</dt>
<dd>The index location of each cumulative minimum value found along the dimension `dim`, with the same shape as `input`.</dd>
</dl>

### Type Constraints

- T:tensor(float32)

@ -21,7 +21,9 @@
| [RoIAlign](onnxruntime_custom_ops.md#roialign)          | Y | N | 1.2.5  |
| [NMS](onnxruntime_custom_ops.md#nms)                    | Y | N | 1.2.7  |
| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler)  | Y | N | master |
| [CornerPool](onnxruntime_custom_ops.md#cornerpool)      | Y | N | master |
| [cummax](onnxruntime_custom_ops.md#cummax)              | Y | N | master |
| [cummin](onnxruntime_custom_ops.md#cummin)              | Y | N | master |

## How to build custom operators for ONNX Runtime

@ -115,7 +117,9 @@ Take custom operator `soft_nms` for example.

## Known Issues

- "RuntimeError: tuple appears in op that does not forward tuples, unsupported kind: `prim::PythonOp`."
  1. Note that `cummax` and `cummin` are generally exportable to ONNX as long as the torch version is >= 1.5.0, since `torch.cummax` is only supported with torch >= 1.5.0. But when `cummax` or `cummin` serves as an intermediate component whose outputs are used as inputs to other modules, the torch version must be >= 1.7.0. Otherwise the above error may arise when running the exported ONNX model with onnxruntime.
  2. Solution: update the torch version to 1.7.0 or higher.
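
As a guard, the torch version can be checked before export; a minimal sketch mirroring the check used in the test suite (the `packaging` dependency is an assumption):

```python
from packaging import version
import torch

# exporting cummax/cummin as an intermediate component requires torch >= 1.7.0
assert version.parse(torch.__version__) >= version.parse('1.7.0'), \
    'exporting cummax/cummin as an intermediate component requires torch >= 1.7.0'
```
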

## References

@ -396,6 +396,16 @@ def grid_sampler(g,
                 align_corners_i=align_corners)


@parse_args('v', 'i')
def cummax(g, input, dim):
    return g.op('mmcv::cummax', input, dim_i=dim, outputs=2)


@parse_args('v', 'i')
def cummin(g, input, dim):
    return g.op('mmcv::cummin', input, dim_i=dim, outputs=2)


def register_extra_symbolics(opset=11):
    register_op('one_hot', one_hot, '', opset)
    register_op('im2col', im2col, '', opset)
@ -421,3 +431,5 @@ def register_extra_symbolics(opset=11):
    register_op('upsample_bicubic2d', upsample_bicubic2d, '', opset)
    register_op('new_full', new_full, '', opset)
    register_op('grid_sampler', grid_sampler, '', opset)
    register_op('cummax', cummax, '', opset)
    register_op('cummin', cummin, '', opset)
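
With the symbolics registered, exporting a model that calls `torch.cummax` produces a graph containing the custom `mmcv::cummax` node; a minimal sketch (the wrapper module and file name are illustrative):

```python
import torch
from mmcv.onnx.symbolic import register_extra_symbolics

register_extra_symbolics(opset=11)

class CumMaxModel(torch.nn.Module):
    def forward(self, x):
        return torch.cummax(x, dim=1)

# exports a graph containing the custom `mmcv::cummax` node
torch.onnx.export(CumMaxModel(), torch.rand(2, 3), 'cummax.onnx', opset_version=11)
```
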
@ -140,6 +140,15 @@ class CornerPool(nn.Module):
    def forward(self, x):
        if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
            if torch.onnx.is_in_onnx_export():
                assert torch.__version__ >= '1.7.0', \
                    'When `cummax` serves as an intermediate component whose ' \
                    'outputs are used as inputs to other modules, the ' \
                    'pytorch version must be >= 1.7.0; otherwise an error ' \
                    'appears like: `RuntimeError: tuple appears in op that ' \
                    'does not forward tuples, unsupported ' \
                    'kind: prim::PythonOp`.'

            dim, flip = self.cummax_dim_flip[self.mode]
            if flip:
                x = x.flip(dim)
@ -4,6 +4,7 @@
#include "grid_sample.h"
#include "nms.h"
#include "ort_mmcv_utils.h"
#include "reduce_ops.h"
#include "roi_align.h"
#include "roi_align_rotated.h"
#include "soft_nms.h"
@ -14,6 +15,8 @@ NmsOp c_NmsOp;
MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp;
MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp;
GridSampleOp c_GridSampleOp;
MMCVCumMaxCustomOp c_MMCVCumMaxCustomOp;
MMCVCumMinCustomOp c_MMCVCumMinCustomOp;
MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp;

OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
@ -52,5 +55,13 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
    return status;
  }

  if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVCumMaxCustomOp)) {
    return status;
  }

  if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVCumMinCustomOp)) {
    return status;
  }

  return ortApi->AddCustomOpDomain(options, domain);
}
mmcv/ops/csrc/onnxruntime/cpu/reduce_ops.cpp (new file, 187 lines)
@ -0,0 +1,187 @@
#include "reduce_ops.h"

#include <assert.h>

#include <vector>

#include "../ort_mmcv_utils.h"

// modified from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ReduceOps.cpp

static inline int64_t maybe_wrap_dim(int64_t dim, int64_t ndims) {
  int64_t min = -ndims;
  int64_t max = ndims - 1;
  assert(dim >= min && dim <= max);
  if (dim < 0) dim += ndims;
  return dim;
}

static inline int64_t get_dim_stride(const int64_t dim, const int64_t ndims,
                                     const int64_t *reversed_dim_cumprod) {
  return dim == ndims - 1 ? 1 : reversed_dim_cumprod[dim + 1];
}

static inline int64_t get_dim_size(const int64_t dim, const int64_t ndims,
                                   const int64_t *reversed_dim_cumprod) {
  return dim == ndims - 1
             ? reversed_dim_cumprod[dim]
             : reversed_dim_cumprod[dim] / reversed_dim_cumprod[dim + 1];
}

template <typename T1, typename T2, typename Operation>
void cummax_cummin_helper(const T1 *input, T1 *output, T2 *indices,
                          const int64_t input_dim_size, const int64_t stride) {
  Operation op;
  T1 out = input[0];
  int64_t idx = 0;
  for (int64_t i = 0; i < input_dim_size; i++) {
    T1 curr_elem = input[i * stride];
    if (op(curr_elem, out)) {
      out = curr_elem;
      idx = i;
    }
    output[i * stride] = out;
    indices[i * stride] = idx;
  }
}

// modified `tensor_dim_apply3` from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorDimApply.h.
// the difference is that: (1) use `reversed_dim_cumprod` for fast computing of
// tensor `size` and `stride`. (2) the same `stride` is used for input, output,
// and indices, since it's unnecessary to use separate values. currently
// `tensor_dim_apply3` is only used for `cummax` and `cummin`, according to the
// official pytorch project: https://github.com/pytorch/pytorch.
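// Worked example (illustrative): for a 2-D tensor of shape (2, 3) with dim=1,
// reversed_dim_cumprod = {6, 3}, so stride(dim=1) == 1 and size(dim=1) == 3,
// and `func` is applied once per row; for dim=0, stride == 3 and size == 2,
// so `func` walks each column.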
template <typename T1, typename T2, typename Function>
void tensor_dim_apply3(const T1 *input, T1 *output, T2 *indices,
                       const int64_t dim, const int64_t ndims,
                       const int64_t *reversed_dim_cumprod, Function func) {
  int dim_apply_finished = 0;
  int64_t input_dim_size = get_dim_size(dim, ndims, reversed_dim_cumprod);
  // the same stride is used for input, output and indices
  int64_t stride = get_dim_stride(dim, ndims, reversed_dim_cumprod);
  std::vector<int64_t> counter(ndims, 0);

  while (!dim_apply_finished) {
    // call `func` once to update output and indices
    func(input, output, indices, input_dim_size, stride);
    if (ndims == 1) break;
    for (int64_t dim_i = 0; dim_i < ndims; dim_i++) {
      if (dim_i == dim) {
        if (dim_i == (ndims - 1)) {
          dim_apply_finished = 1;
          break;
        }
        continue;
      }
      counter[dim_i]++;

      // the same stride is used for input, output, and indices
      int64_t stride_dim_i = get_dim_stride(dim_i, ndims, reversed_dim_cumprod);
      input += stride_dim_i;
      output += stride_dim_i;
      indices += stride_dim_i;

      if (counter[dim_i] == get_dim_size(dim_i, ndims, reversed_dim_cumprod)) {
        if (dim_i == ndims - 1) {
          dim_apply_finished = 1;
          break;
        } else {
          input -= counter[dim_i] * stride_dim_i;
          output -= counter[dim_i] * stride_dim_i;
          indices -= counter[dim_i] * stride_dim_i;
          counter[dim_i] = 0;
        }
      } else {
        break;
      }  // if
    }    // for
  }      // while
}

template <typename T1, typename T2, typename Operation>
void CumMax_CumMin_CPU(const T1 *input, T1 *output, T2 *indices,
                       int64_t *reversed_dim_cumprod, const int64_t dim,
                       const OrtTensorDimensions &out_dimensions) {
  // calculate numel
  const int64_t ndims = out_dimensions.size();
  int64_t numel = 1;
  for (int64_t dim_i = 0; dim_i < ndims; dim_i++) {
    numel *= out_dimensions.data()[dim_i];
  }

  // cummax/cummin is only applied to inputs that are non-zero-dim and
  // non-empty
  if (numel) {
    // compute the cumulative product of the dimension sizes,
    // which is then used for computing the stride or size of a specific `dim`.
    reversed_dim_cumprod[ndims - 1] = out_dimensions.data()[ndims - 1];
    for (int64_t dim_i = ndims - 2; dim_i >= 0; dim_i--) {
      reversed_dim_cumprod[dim_i] =
          reversed_dim_cumprod[dim_i + 1] * out_dimensions.data()[dim_i];
    }

    // do cummax or cummin based on `Operation` type
    tensor_dim_apply3<float, int64_t>(
        input, output, indices, dim, ndims, reversed_dim_cumprod,
        cummax_cummin_helper<float, int64_t, Operation>);
  }
}

void MMCVCumMaxKernel::Compute(OrtKernelContext *context) {
  // get input
  const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
  const float *input_data =
      reinterpret_cast<const float *>(ort_.GetTensorData<float>(input));

  // get output
  OrtTensorDimensions out_dimensions(ort_, input);
  OrtValue *output = ort_.KernelContext_GetOutput(
      context, 0, out_dimensions.data(), out_dimensions.size());
  float *output_data = ort_.GetTensorMutableData<float>(output);
  OrtValue *indices = ort_.KernelContext_GetOutput(
      context, 1, out_dimensions.data(), out_dimensions.size());
  int64_t *indices_data = ort_.GetTensorMutableData<int64_t>(indices);

  // allocate tmp memory for computing the cumulative product of the
  // dimension sizes
  const int64_t ndims = out_dimensions.size();
  assert(ndims > 0);
  int64_t *reversed_dim_cumprod =
      (int64_t *)allocator_.Alloc(sizeof(int64_t) * ndims);

  // dim should be wrapped if it's negative (e.g. -1)
  const int64_t dim = maybe_wrap_dim(dim_, ndims);
  CumMax_CumMin_CPU<float, int64_t, std::greater_equal<float>>(
      input_data, output_data, indices_data, reversed_dim_cumprod, dim,
      out_dimensions);
}

void MMCVCumMinKernel::Compute(OrtKernelContext *context) {
  // get input
  const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
  const float *input_data =
      reinterpret_cast<const float *>(ort_.GetTensorData<float>(input));

  // get output
  OrtTensorDimensions out_dimensions(ort_, input);
  OrtValue *output = ort_.KernelContext_GetOutput(
      context, 0, out_dimensions.data(), out_dimensions.size());
  float *output_data = ort_.GetTensorMutableData<float>(output);
  OrtValue *indices = ort_.KernelContext_GetOutput(
      context, 1, out_dimensions.data(), out_dimensions.size());
  int64_t *indices_data = ort_.GetTensorMutableData<int64_t>(indices);

  // allocate tmp memory for computing the cumulative product of the
  // dimension sizes
  const int64_t ndims = out_dimensions.size();
  assert(ndims > 0);
  int64_t *reversed_dim_cumprod =
      (int64_t *)allocator_.Alloc(sizeof(int64_t) * ndims);

  // dim should be wrapped if it's negative (e.g. -1)
  const int64_t dim = maybe_wrap_dim(dim_, ndims);
  CumMax_CumMin_CPU<float, int64_t, std::less_equal<float>>(
      input_data, output_data, indices_data, reversed_dim_cumprod, dim,
      out_dimensions);
}
mmcv/ops/csrc/onnxruntime/reduce_ops.h (new file, 94 lines)
@ -0,0 +1,94 @@
#ifndef ONNXRUNTIME_REDUCE_OPS_H
#define ONNXRUNTIME_REDUCE_OPS_H

#include <onnxruntime_cxx_api.h>

struct MMCVCumMaxKernel {
 public:
  MMCVCumMaxKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
      : ort_(ort) {
    dim_ = ort_.KernelInfoGetAttribute<int64_t>(info, "dim");

    // create allocator
    allocator_ = Ort::AllocatorWithDefaultOptions();
  }

  void Compute(OrtKernelContext* context);

 private:
  Ort::CustomOpApi ort_;
  Ort::AllocatorWithDefaultOptions allocator_;

  int64_t dim_;
};

struct MMCVCumMinKernel {
 public:
  MMCVCumMinKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
      : ort_(ort) {
    dim_ = ort_.KernelInfoGetAttribute<int64_t>(info, "dim");

    // create allocator
    allocator_ = Ort::AllocatorWithDefaultOptions();
  }

  void Compute(OrtKernelContext* context);

 private:
  Ort::CustomOpApi ort_;
  Ort::AllocatorWithDefaultOptions allocator_;

  int64_t dim_;
};

struct MMCVCumMaxCustomOp
    : Ort::CustomOpBase<MMCVCumMaxCustomOp, MMCVCumMaxKernel> {
  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
    return new MMCVCumMaxKernel(api, info);
  }

  const char* GetName() const { return "cummax"; }

  size_t GetInputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetInputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  };

  size_t GetOutputTypeCount() const { return 2; }
  ONNXTensorElementDataType GetOutputType(size_t index) const {
    if (index == 1) return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  };

  // force cpu
  const char* GetExecutionProviderType() const {
    return "CPUExecutionProvider";
  };
};

struct MMCVCumMinCustomOp
    : Ort::CustomOpBase<MMCVCumMinCustomOp, MMCVCumMinKernel> {
  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
    return new MMCVCumMinKernel(api, info);
  }

  const char* GetName() const { return "cummin"; }

  size_t GetInputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetInputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  };

  size_t GetOutputTypeCount() const { return 2; }
  ONNXTensorElementDataType GetOutputType(size_t index) const {
    if (index == 1) return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  };

  // force cpu
  const char* GetExecutionProviderType() const {
    return "CPUExecutionProvider";
  };
};

#endif  // ONNXRUNTIME_REDUCE_OPS_H
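
These kernels are compiled into the mmcv custom-op library and loaded into an onnxruntime session at runtime; a minimal sketch mirroring the test suite (the model file name is illustrative):

```python
import onnxruntime as ort
from mmcv.ops import get_onnxruntime_op_path

# load the compiled mmcv custom-op library (contains the cummax/cummin kernels)
session_options = ort.SessionOptions()
session_options.register_custom_ops_library(get_onnxruntime_op_path())
sess = ort.InferenceSession('cummax.onnx', session_options)
```
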
@ -218,6 +218,7 @@ def soft_nms(boxes,
        float(iou_threshold), float(sigma),
        float(min_score), method_dict[method],
        int(offset))

    dets = dets[:inds.size(0)]

    if is_numpy:
@ -494,3 +494,78 @@ def test_corner_pool(mode, opset=11):
    pytorch_results = wrapped_model(input.clone())
    os.remove(onnx_file)
    assert np.allclose(pytorch_results, ort_result, atol=1e-5)


@pytest.mark.parametrize('key', ['cummax', 'cummin'])
def test_cummax_cummin(key, opset=11):
    if torch.__version__ == 'parrots':
        pytest.skip('onnx is not supported in parrots directly')

    # Note that `cummax` and `cummin` are generally exportable to ONNX
    # as long as the pytorch version >= 1.5.0, since `torch.cummax`
    # is only supported with torch >= 1.5.0.
    # But when `cummax` or `cummin` serves as an intermediate component
    # whose outputs are used as inputs to other modules, the pytorch
    # version must be >= 1.7.0. Otherwise an error appears like:
    # `RuntimeError: tuple appears in op that does not forward tuples,
    # unsupported kind: prim::PythonOp`.
    if version.parse(torch.__version__) < version.parse('1.7.0'):
        pytest.skip('test_cummax_cummin should be run with pytorch >= 1.7.0')

    # register custom ops `mmcv::cummax` and `mmcv::cummin`
    from mmcv.onnx.symbolic import register_extra_symbolics
    register_extra_symbolics(opset)

    from mmcv.ops import get_onnxruntime_op_path
    ort_custom_op_path = get_onnxruntime_op_path()
    if not os.path.exists(ort_custom_op_path):
        pytest.skip('custom ops for onnxruntime are not compiled.')

    input_list = [
        # arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
        torch.rand((2, 3, 4, 1, 5)),
        torch.rand((1)),
        torch.rand((2, 0, 1)),  # tensor.numel() is 0
        torch.FloatTensor(),  # empty tensor
    ]

    cummax_cummin_funcs = {'cummax': torch.cummax, 'cummin': torch.cummin}

    for input in input_list:
        ndims = input.dim()
        # valid dim range is [-ndims, ndims-1]
        # test for all valid `dim` values
        for dim in range(-ndims, ndims):
            cummax_func = partial(cummax_cummin_funcs[key], dim=dim)
            wrapped_model = WrapFunction(cummax_func).eval()

            with torch.no_grad():
                torch.onnx.export(
                    wrapped_model,
                    input,
                    onnx_file,
                    export_params=True,
                    keep_initializers_as_inputs=True,
                    input_names=['input'],
                    output_names=['output', 'indices'],
                    opset_version=opset)

            onnx_model = onnx.load(onnx_file)
            input_all = [node.name for node in onnx_model.graph.input]
            input_initializer = [
                node.name for node in onnx_model.graph.initializer
            ]
            net_feed_input = list(set(input_all) - set(input_initializer))
            assert (len(net_feed_input) == 1)

            session_options = rt.SessionOptions()
            session_options.register_custom_ops_library(ort_custom_op_path)
            sess = rt.InferenceSession(onnx_file, session_options)
            ort_output, ort_inds = sess.run(None,
                                            {'input': input.detach().numpy()})
            pytorch_output, pytorch_inds = wrapped_model(input.clone())
            pytorch_output = pytorch_output.detach().numpy()
            pytorch_inds = pytorch_inds.detach().numpy()
            assert np.allclose(pytorch_output, ort_output, atol=1e-5)
            assert np.all(pytorch_inds == ort_inds)
            os.remove(onnx_file)