mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* feat: support solo mmdet3.x openvino 2.0 api * feat: support solo mmdet3.x openvino * fix: lint * fix: add solo head test * docs: add supported_modesl * docs: add supported_models * fix: fix unreasonable code * fix: fix ci failed * feat: add linspace func rewrite * fix: fix unreasonable rewrite linspace__onnx * fix: change func name from __onnx to __default * feat: add solo test regression
21 lines
722 B
Python
21 lines
722 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
from torch.types import Number
|
|
|
|
from mmdeploy.core import FUNCTION_REWRITER
|
|
|
|
|
|
@FUNCTION_REWRITER.register_rewriter(func_name='torch.linspace')
def linspace__default(start: Number,
                      end: Number,
                      steps: int = None,
                      **kwargs):
    """Rewrite `torch.linspace` in terms of `torch.arange`.

    Produces exactly ``steps`` evenly spaced values from ``start`` to
    ``end`` (both inclusive when ``steps > 1``). Ascending ranges,
    descending ranges and ``start == end`` are all supported.

    Args:
        start (Number): First value of the sequence.
        end (Number): Last value of the sequence.
        steps (int, optional): Number of values to generate. ``None`` is
            treated as 100. Defaults to None.
        **kwargs: Extra keyword arguments of ``torch.linspace``. ``dtype``
            is consumed here (missing or ``None`` means ``torch.float32``).
            The remaining ones (e.g. ``device``) are forwarded to
            ``torch.arange``.

    Returns:
        torch.Tensor: 1-D tensor with ``steps`` elements of type ``dtype``.
    """
    steps = 100 if steps is None else steps
    dtype = kwargs.pop('dtype', None) or torch.float32
    # Values are computed in floating point and cast to ``dtype`` at the end.
    # ``Tensor.to`` truncates toward zero when ``dtype`` is an integer type.
    compute_dtype = torch.float64 if dtype == torch.float64 else torch.float32
    # An index of exactly ``steps`` elements fixes the output length,
    # independent of the sign or size of the spacing.
    index = torch.arange(steps, dtype=compute_dtype, **kwargs)

    if steps <= 1:
        # Zero steps -> empty tensor; one step -> ``[start]``.
        return (index + start).to(dtype)

    step = (end - start) / (steps - 1)
    # Count the first half up from ``start`` and the second half down from
    # ``end`` so both endpoints come out exact. Exact endpoints matter once
    # the result is truncated to an integer dtype.
    forward = start + index * step
    backward = end - (steps - 1 - index) * step
    return torch.where(index < steps // 2, forward, backward).to(dtype)
|