support coreml (#760)

* sdk inference

* fix typo

* fix typo

* add convert things

* fix missing name

* add cls support

* add more pytorch rewriter

* add det support

* support det wip

* make Model export model_path

* fix nms

* add output back

* add docstring

* fix lint

* add coreml build action

* add zh docs

* add coreml backend check

* update ci

* update

* update

* update

* update

* update

* fix lint

* update configs

* add return value when error occurred

* update docs

* update docs

* update docs

* fix lint

* update docs

* update docs

* update

Co-authored-by: grimoire <streetyao@live.com>
Chen Xin 2022-09-05 19:55:47 +08:00 committed by GitHub
parent 3d092bfb39
commit a0fb3be0df
43 changed files with 1439 additions and 52 deletions

View File

@ -0,0 +1,71 @@
name: backend-coreml
on:
push:
paths:
- "csrc/**"
- "demo/csrc/**"
- "CMakeLists.txt"
pull_request:
paths:
- "csrc/**"
- "demo/csrc/**"
- "CMakeLists.txt"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env:
DEVELOPER_DIR: /Applications/Xcode_13.4.1.app/Contents/Developer
permissions:
contents: read
jobs:
build_macos_arm64:
runs-on: macos-12
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
submodules: 'recursive'
- name: install opencv
run: |
wget https://github.com/irexyc/mmdeploy-ci-resource/releases/download/opencv/opencv-osx-arm64-4.6.0.tar.gz
mkdir $GITHUB_WORKSPACE/opencv-install
tar xf opencv-osx-arm64-4.6.0.tar.gz -C $GITHUB_WORKSPACE/opencv-install
- name: install libtorch
run: |
wget https://github.com/irexyc/mmdeploy-ci-resource/releases/download/libtorch/libtorch-osx-arm64-1.8.0.tar.gz
mkdir $GITHUB_WORKSPACE/libtorch-install
tar xf libtorch-osx-arm64-1.8.0.tar.gz -C $GITHUB_WORKSPACE/libtorch-install
- name: build
run: |
mkdir build && cd build
cmake .. -DCMAKE_OSX_ARCHITECTURES="arm64" \
-DCMAKE_SYSTEM_PROCESSOR="arm64" \
-DMMDEPLOY_BUILD_SDK=ON \
-DMMDEPLOY_TARGET_DEVICES="cpu" \
-DMMDEPLOY_CODEBASES=all \
-DOpenCV_DIR=$GITHUB_WORKSPACE/opencv-install/lib/cmake/opencv4 \
-DTorch_DIR=$GITHUB_WORKSPACE/libtorch-install/share/cmake/Torch \
-DMMDEPLOY_TARGET_BACKENDS="coreml" \
-DMMDEPLOY_BUILD_EXAMPLES=ON \
-DMMDEPLOY_SHARED_LIBS=OFF
cmake --build . -j 3
cmake --build . --target install
- name: build-shared
run: |
mkdir build-shared && cd build-shared
cmake .. -DCMAKE_OSX_ARCHITECTURES="arm64" \
-DCMAKE_SYSTEM_PROCESSOR="arm64" \
-DMMDEPLOY_BUILD_SDK=ON \
-DMMDEPLOY_TARGET_DEVICES="cpu" \
-DMMDEPLOY_CODEBASES=all \
-DOpenCV_DIR=$GITHUB_WORKSPACE/opencv-install/lib/cmake/opencv4 \
-DTorch_DIR=$GITHUB_WORKSPACE/libtorch-install/share/cmake/Torch \
-DMMDEPLOY_TARGET_BACKENDS="coreml" \
-DMMDEPLOY_BUILD_EXAMPLES=ON \
-DMMDEPLOY_SHARED_LIBS=ON
cmake --build . -j 3
cmake --build . --target install

View File

@ -77,6 +77,10 @@ if (MSVC)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/wd4251>)
endif ()
if(APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fobjc-arc")
endif()
add_library(MMDeployStaticModules INTERFACE)
add_library(MMDeployDynamicModules INTERFACE)
add_library(MMDeployLibs INTERFACE)

View File

@ -55,9 +55,9 @@ The currently supported codebases and models are as follows, and more will be in
Models can be exported and run in the following backends, and more will be compatible
| ONNX Runtime | TensorRT | ppl.nn | ncnn | OpenVINO | LibTorch | snpe | Ascend | more |
| ------------ | -------- | ------ | ---- | -------- | -------- | ---- | ------ | ---------------------------------------------- |
| ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | [benchmark](docs/en/03-benchmark/benchmark.md) |
| ONNX Runtime | TensorRT | ppl.nn | ncnn | OpenVINO | LibTorch | snpe | Ascend | Core ML | more |
| ------------ | -------- | ------ | ---- | -------- | -------- | ---- | ------ | ------- | ---------------------------------------------- |
| ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | [benchmark](docs/en/03-benchmark/benchmark.md) |
### Efficient and scalable C/C++ SDK Framework
@ -71,6 +71,7 @@ Please read [getting_started](docs/en/get_started.md) for the basic usage of MMD
- [Build from Docker](docs/en/01-how-to-build/build_from_docker.md)
- [Build from Script](docs/en/01-how-to-build/build_from_script.md)
- [Build for Linux](docs/en/01-how-to-build/linux-x86_64.md)
- [Build for macOS](docs/en/01-how-to-build/macos-arm64.md)
- [Build for Win10](docs/en/01-how-to-build/windows.md)
- [Build for Android](docs/en/01-how-to-build/android.md)
- [Build for Jetson](docs/en/01-how-to-build/jetsons.md)

View File

@ -53,9 +53,9 @@ MMDeploy 是 [OpenMMLab](https://openmmlab.com/) 模型部署工具箱,**为
### Multiple inference backends are supported
| ONNX Runtime | TensorRT | ppl.nn | ncnn | OpenVINO | LibTorch | snpe | Ascend | more |
| ------------ | -------- | ------ | ---- | -------- | -------- | ---- | ------ | ------------------------------------------------- |
| ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | [benchmark](docs/zh_cn/03-benchmark/benchmark.md) |
| ONNX Runtime | TensorRT | ppl.nn | ncnn | OpenVINO | LibTorch | snpe | Ascend | Core ML | more |
| ------------ | -------- | ------ | ---- | -------- | -------- | ---- | ------ | ------- | ---------------------------------------------- |
| ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | [benchmark](docs/zh_cn/03-benchmark/benchmark.md) |
### Highly customizable SDK
@ -70,6 +70,7 @@ MMDeploy 是 [OpenMMLab](https://openmmlab.com/) 模型部署工具箱,**为
- [One-click installation script](docs/zh_cn/01-how-to-build/build_from_script.md)
- [Build from Docker](docs/zh_cn/01-how-to-build/build_from_docker.md)
- [Build for Linux](docs/zh_cn/01-how-to-build/linux-x86_64.md)
- [Build for macOS](docs/zh_cn/01-how-to-build/macos-arm64.md)
- [Build for Win10](docs/zh_cn/01-how-to-build/windows.md)
- [Build for Android](docs/zh_cn/01-how-to-build/android.md)
- [Build for Jetson](docs/en/01-how-to-build/jetsons.md)

View File

@ -0,0 +1 @@
backend_config = dict(type='coreml', convert_to='mlprogram')

View File

@ -0,0 +1,12 @@
_base_ = ['../_base_/torchscript_config.py', '../_base_/backends/coreml.py']
codebase_config = dict(type='mmcls', task='Classification')
backend_config = dict(model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 224, 224],
max_shape=[8, 3, 224, 224],
default_shape=[1, 3, 224, 224])))
])

View File

@ -0,0 +1,11 @@
_base_ = ['./base_torchscript.py', '../../_base_/backends/coreml.py']
ir_config = dict(input_shape=(1344, 800))
backend_config = dict(model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 800, 1344],
max_shape=[1, 3, 800, 1344],
default_shape=[1, 3, 800, 1344])))
])

View File

@ -0,0 +1 @@
_base_ = ['../_base_/base_coreml_static-800x1344.py']

View File

@ -0,0 +1,14 @@
_base_ = [
'../_base_/torchscript_config.py', '../_base_/backends/coreml.py',
'./segmentation_static.py'
]
ir_config = dict(input_shape=[1024, 512])
backend_config = dict(model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 512, 1024],
max_shape=[1, 3, 512, 1024],
default_shape=[1, 3, 512, 1024])))
])

View File

@ -32,7 +32,8 @@ if ("ncnn" IN_LIST MMDEPLOY_TARGET_BACKENDS)
endif ()
# build TorchScript ops
if ("torchscript" IN_LIST MMDEPLOY_TARGET_BACKENDS)
message(STATUS "Build torchsciprt custom ops")
if ("torchscript" IN_LIST MMDEPLOY_TARGET_BACKENDS
OR "coreml" IN_LIST MMDEPLOY_TARGET_BACKENDS)
message(STATUS "Build torchscript custom ops")
add_subdirectory(torchscript)
endif ()

View File

@ -1,10 +0,0 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "torch/script.h"
TORCH_LIBRARY(mmdeploy, m) {
m.def(
"modulated_deform_conv(Tensor input, Tensor weight, Tensor bias, Tensor offset, Tensor "
"mask, "
"int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int "
"dilation_h,int dilation_w, int groups, int deform_groups, bool with_bias) -> Tensor");
}

View File

@ -0,0 +1,13 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "torch/script.h"
TORCH_LIBRARY(mmdeploy, m) {
m.def(
"modulated_deform_conv(Tensor input, Tensor weight, Tensor bias, Tensor offset, Tensor "
"mask, "
"int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int "
"dilation_h,int dilation_w, int groups, int deform_groups, bool with_bias) -> Tensor")
.def(
"coreml_nms(Tensor boxes, Tensor scores, float iou_threshold, "
"float score_threshold, int max_boxes) -> Tensor[]");
}

View File

@ -0,0 +1,31 @@
#include <assert.h>
#include <vector>
#include "torch/script.h"
namespace mmdeploy {
using at::Tensor;
std::vector<Tensor> coreml_nms_cpu(Tensor boxes, Tensor scores, double iou_threshold,
double score_threshold, int64_t max_boxes) {
assert(boxes.dim() == 3); // bboxes with shape (batch_size, num_bboxes, 4)
assert(boxes.size(2) == 4);
assert(boxes.size(0) == scores.size(0)); // check batch size
assert(boxes.size(1) == scores.size(1)); // check num boxes
auto batch_size = boxes.size(0);
auto num_boxes = boxes.size(1);
auto num_classes = scores.size(2);
Tensor ret_boxes = at::zeros({batch_size, max_boxes, 4});
Tensor ret_scores = at::zeros({batch_size, max_boxes, num_classes});
Tensor indices = at::zeros({batch_size, max_boxes}, at::kInt);
Tensor num_outputs = at::zeros({batch_size}, at::kInt);
return std::vector<Tensor>({ret_boxes, ret_scores, indices, num_outputs});
}
TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) { m.impl("coreml_nms", coreml_nms_cpu); }
} // namespace mmdeploy

View File

@ -20,6 +20,7 @@ Model::Model(const std::string& model_path) {
Model::Model(const void* buffer, size_t size) { Init(buffer, size).value(); }
Result<void> Model::Init(const std::string& model_path) {
model_path_ = model_path;
if (!fs::exists(model_path)) {
MMDEPLOY_ERROR("'{}' doesn't exist", model_path);
return Status(eFileNotExist);
@ -45,6 +46,8 @@ Result<void> Model::Init(const std::string& model_path) {
return Status(eNotSupported);
}
const std::string& Model::GetModelPath() const { return model_path_; }
Result<void> Model::Init(const void* buffer, size_t size) {
auto registry = ModelRegistry::Get();
auto entries = registry.ListEntries();

View File

@ -94,7 +94,14 @@ class MMDEPLOY_API Model {
*/
explicit operator bool() const { return impl_ != nullptr; }
/**
* @brief get model_path that init with DirectoryModel
* @return file path of an sdk model
*/
const std::string& GetModelPath() const;
private:
std::string model_path_;
std::shared_ptr<ModelImpl> impl_;
deploy_meta_info_t meta_;
};

View File

@ -34,5 +34,9 @@ if ("torchscript" IN_LIST MMDEPLOY_TARGET_BACKENDS)
add_subdirectory(torchscript)
endif ()
if ("coreml" IN_LIST MMDEPLOY_TARGET_BACKENDS)
add_subdirectory(coreml)
endif ()
mmdeploy_add_module(${PROJECT_NAME} net_module.cpp)
add_library(mmdeploy::net_module ALIAS ${PROJECT_NAME})

View File

@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
project(mmdeploy_coreml_net)
if ("cpu" IN_LIST MMDEPLOY_TARGET_DEVICES)
find_library(CORE_ML CoreML)
find_library(FOUNDATION Foundation)
mmdeploy_add_module(${PROJECT_NAME} coreml_net.mm)
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(${PROJECT_NAME} PRIVATE ${CORE_ML} ${FOUNDATION})
add_library(mmdeploy::coreml_net ALIAS ${PROJECT_NAME})
else ()
message(ERROR "'coreml_net' is NOT supported in target devices: ${MMDEPLOY_TARGET_DEVICES}")
endif ()

View File

@ -0,0 +1,37 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_SRC_NET_COREML_COREML_NET_H_
#define MMDEPLOY_SRC_NET_COREML_COREML_NET_H_
#include "mmdeploy/core/net.h"
namespace mmdeploy {
namespace coreml {
class Execution;
} // namespace coreml
class CoreMLNet : public Net {
public:
~CoreMLNet() override = default;
Result<void> Init(const Value& cfg) override;
Result<void> Deinit() override;
Result<Span<Tensor>> GetInputTensors() override;
Result<Span<Tensor>> GetOutputTensors() override;
Result<void> Reshape(Span<TensorShape> input_shapes) override;
Result<void> Forward() override;
Result<void> ForwardAsync(Event* event) override;
private:
std::unique_ptr<coreml::Execution> execution_;
std::vector<Tensor> input_tensors_;
std::vector<Tensor> output_tensors_;
Device device_;
Stream stream_;
friend class coreml::Execution;
};
} // namespace mmdeploy
#endif  // MMDEPLOY_SRC_NET_COREML_COREML_NET_H_

View File

@ -0,0 +1,326 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "coreml_net.h"
#include "mmdeploy/core/model.h"
#include "mmdeploy/core/status_code.h"
#include "mmdeploy/core/utils/filesystem.h"
#include <fstream>
#import <CoreML/CoreML.h>
#import <Foundation/Foundation.h>
#include <memory>
@interface MMBatchTensorFeatureProvider : NSObject <MLBatchProvider> {
const std::vector<mmdeploy::Tensor> *inputs_;
}
- (instancetype)initWithInputs:(const std::vector<mmdeploy::Tensor> &)inputs;
- (NSInteger)count;
- (id<MLFeatureProvider>)featuresAtIndex:(NSInteger)index;
@end
@implementation MMBatchTensorFeatureProvider
- (instancetype)initWithInputs:(const std::vector<mmdeploy::Tensor> &)inputs {
inputs_ = &inputs;
return self;
}
- (NSInteger)count {
return (*inputs_)[0].shape(0);
}
- (id<MLFeatureProvider>)featuresAtIndex:(NSInteger)index {
MLDictionaryFeatureProvider *feature = nil;
NSMutableDictionary<NSString *, id> *input_dict =
[[NSMutableDictionary<NSString *, id> alloc] init];
for (auto x : *inputs_) {
auto in = x.Slice(index);
NSMutableArray *shape = [[NSMutableArray alloc] init];
for (const auto dim : in.shape()) {
[shape addObject:[NSNumber numberWithLongLong:dim]];
}
NSMutableArray *strides = [[NSMutableArray alloc] init];
int64_t stride = 1;
for (int i = in.shape().size() - 1; i >= 0; i--) {
[strides insertObject:[NSNumber numberWithLongLong:stride] atIndex:0];
stride *= in.shape()[i];
}
MLMultiArrayDataType data_type = MLMultiArrayDataTypeFloat32;
NSError *error = nil;
MLMultiArray *mlArray =
[[MLMultiArray alloc] initWithDataPointer:in.data()
shape:shape
dataType:data_type
strides:strides
deallocator:(^(void *){
})error:&error];
if (error != nil) {
MMDEPLOY_ERROR("init MLMultiArray failed with key: {}, error message: {}",
in.name(), [[error localizedDescription] UTF8String]);
return nil;
}
NSString *key = [NSString stringWithUTF8String:in.name()];
input_dict[key] = mlArray;
}
NSError *error = nil;
feature = [[MLDictionaryFeatureProvider alloc] initWithDictionary:input_dict
error:&error];
if (error != nil) {
MMDEPLOY_ERROR("init MLDictionaryFeatureProvider failed with index: {}, "
"error message: {}",
index, [[error localizedDescription] UTF8String]);
return nil;
}
return feature;
}
@end
namespace mmdeploy {
namespace coreml {
static Result<void> CheckInputOutputFeatureType(MLFeatureType type) {
if (type != MLFeatureTypeMultiArray) {
MMDEPLOY_ERROR("unsupported feature type: {}", type);
return Status(eInvalidArgument);
}
return success();
}
static TensorShape to_shape(NSArray<NSNumber *> *shape) {
TensorShape _shape;
for (int i = 0; i < shape.count; i++) {
_shape.push_back(shape[i].intValue);
}
return _shape;
}
static Result<DataType> ConvertElementType(MLMultiArrayDataType type) {
switch (type) {
case MLMultiArrayDataTypeFloat32:
return DataType::kFLOAT;
case MLMultiArrayDataTypeFloat16:
return DataType::kHALF;
case MLMultiArrayDataTypeInt32:
return DataType::kINT32;
default:
MMDEPLOY_ERROR("unsupported MLMultiArrayDataType: {}",
static_cast<int>(type));
return Status(eNotSupported);
}
}
static Result<Tensor> AsTensor(MLMultiArray *mlArray, const Device &device) {
TensorDesc desc;
desc.device = device;
desc.shape = to_shape(mlArray.shape);
OUTCOME_TRY(desc.data_type, ConvertElementType(mlArray.dataType));
std::shared_ptr<void> data(const_cast<void *>(mlArray.dataPointer),
[](void *) {});
return Tensor(desc, data);
}
class Execution {
public:
Execution(const std::string &path, CoreMLNet *net) : path_(path), net_(net) {}
~Execution() { RemoveModel(); }
Result<void> Init() {
OUTCOME_TRY(LoadModel());
OUTCOME_TRY(SetInputOutputTensor());
return success();
}
Result<void> Forward() {
int batch_size = net_->input_tensors_[0].shape(0);
// prepare input
NSError *error = nil;
MMBatchTensorFeatureProvider *input_feature =
[[MMBatchTensorFeatureProvider alloc]
initWithInputs:net_->input_tensors_];
id<MLBatchProvider> output_feature =
[model_ predictionsFromBatch:input_feature error:&error];
if (error != nil) {
MMDEPLOY_ERROR("coreml forward failed, error message: {}",
[[error localizedDescription] UTF8String]);
return Status(eFail);
}
// extract output
for (size_t i = 0; i < net_->output_tensors_.size(); ++i) {
auto &out = net_->output_tensors_[i];
for (int bid = 0; bid < output_feature.count; bid++) {
NSString *name =
[NSString stringWithCString:out.name()
encoding:[NSString defaultCStringEncoding]];
if (name == nil) {
MMDEPLOY_ERROR("output name must not be nil");
return Status(eFail);
}
MLFeatureValue *output_value =
[[output_feature featuresAtIndex:bid] featureValueForName:name];
if (output_value == nil) {
MMDEPLOY_ERROR("model output doesn't have name tensort: {}",
out.name());
return Status(eFail);
}
MLMultiArray *mlArray = [output_value multiArrayValue];
OUTCOME_TRY(auto tmp, AsTensor(mlArray, out.device()));
if (bid == 0) {
TensorShape batch_shape = tmp.shape();
batch_shape[0] = batch_size;
out.Reshape(batch_shape);
}
auto slice = out.Slice(bid);
OUTCOME_TRY(tmp.CopyTo(slice, net_->stream_));
}
}
return success();
}
Result<void> SetInputOutputTensor() {
// input
auto input_desc = model_.modelDescription.inputDescriptionsByName;
for (NSString *name in input_desc) {
MLFeatureDescription *value = input_desc[name];
OUTCOME_TRY(CheckInputOutputFeatureType(value.type));
// use default shape
auto shape = to_shape(value.multiArrayConstraint.shape);
OUTCOME_TRY(auto data_type,
ConvertElementType(value.multiArrayConstraint.dataType));
net_->input_tensors_.emplace_back(
TensorDesc{net_->device_, data_type, shape, [name UTF8String]});
}
// output
auto output_desc = model_.modelDescription.outputDescriptionsByName;
for (NSString *name in output_desc) {
MLFeatureDescription *value = output_desc[name];
OUTCOME_TRY(auto data_type,
ConvertElementType(value.multiArrayConstraint.dataType));
// can't get output shape
net_->output_tensors_.emplace_back(
TensorDesc{net_->device_, data_type, {}, [name UTF8String]});
}
return success();
}
Result<void> Reshape(Span<TensorShape> input_shapes) {
for (size_t i = 0; i < input_shapes.size(); ++i) {
net_->input_tensors_[i].Reshape(input_shapes[i]);
}
return success();
}
Result<void> LoadModel() {
NSString *model_path = [NSString stringWithUTF8String:path_.c_str()];
NSError *error = nil;
NSURL *model_url = [NSURL URLWithString:model_path];
compiled_model_url_ = [MLModel compileModelAtURL:model_url error:&error];
if (error != nil) {
MMDEPLOY_ERROR("failed to compile model, error message: {}",
[[error localizedDescription] UTF8String]);
return Status(eFail);
}
MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
config.computeUnits = MLComputeUnitsAll;
model_ = [MLModel modelWithContentsOfURL:compiled_model_url_
configuration:config
error:&error];
if (error != nil) {
MMDEPLOY_ERROR("failed to construct model, error message: {}",
[[error localizedDescription] UTF8String]);
return Status(eFail);
}
return success();
}
void RemoveModel() {
NSError *error = nil;
if (compiled_model_url_ != nil) {
[[NSFileManager defaultManager] removeItemAtURL:compiled_model_url_
error:&error];
if (error != nil) {
MMDEPLOY_ERROR("failed to remove compiled model, error message: {}",
[[error localizedDescription] UTF8String]);
}
compiled_model_url_ = nil;
}
}
NSURL *compiled_model_url_{nil};
MLModel *model_{nil};
std::string path_;
CoreMLNet *net_{nullptr};
};
} // namespace coreml
Result<void> CoreMLNet::Init(const Value &cfg) {
auto &context = cfg["context"];
device_ = context["device"].get<Device>();
stream_ = context["stream"].get<Stream>();
auto name = cfg["name"].get<std::string>();
auto model = context["model"].get<Model>();
OUTCOME_TRY(auto config, model.GetModelConfig(name));
std::string coreml_tmp_path =
(fs::path(model.GetModelPath()) / config.net).string();
execution_ = std::make_unique<coreml::Execution>(coreml_tmp_path, this);
OUTCOME_TRY(execution_->Init());
return success();
}
Result<void> CoreMLNet::Deinit() { return success(); }
Result<Span<Tensor>> CoreMLNet::GetInputTensors() { return input_tensors_; }
Result<Span<Tensor>> CoreMLNet::GetOutputTensors() { return output_tensors_; }
Result<void> CoreMLNet::Reshape(Span<TensorShape> input_shapes) {
return execution_->Reshape(input_shapes);
}
Result<void> CoreMLNet::Forward() { return execution_->Forward(); }
Result<void> CoreMLNet::ForwardAsync(Event *event) {
return Status(eNotSupported);
}
class CoreMLNetCreator : public Creator<Net> {
public:
const char *GetName() const override { return "coreml"; }
int GetVersion() const override { return 0; }
std::unique_ptr<Net> Create(const Value &args) override {
auto p = std::make_unique<CoreMLNet>();
if (auto r = p->Init(args)) {
return p;
} else {
MMDEPLOY_ERROR("error creating CoreMLNet: {}",
r.error().message().c_str());
return nullptr;
}
}
};
REGISTER_MODULE(Net, CoreMLNetCreator);
} // namespace mmdeploy

View File

@ -0,0 +1,175 @@
# Build for macOS-arm64
- [Build for macOS-arm64](#build-for-macos-arm64)
- [Install Toolchains](#install-toolchains)
- [Install Dependencies](#install-dependencies)
- [Install Dependencies for Model Converter](#install-dependencies-for-model-converter)
- [Install Dependencies for SDK](#install-dependencies-for-sdk)
- [Install Inference Engines for MMDeploy](#install-inference-engines-for-mmdeploy)
- [Build MMDeploy](#build-mmdeploy)
- [Build Model Converter](#build-model-converter)
- [Install Model Converter](#install-model-converter)
- [Build SDK and Demo](#build-sdk-and-demo)
## Install Toolchains
- cmake
```
brew install cmake
```
- clang
Install Xcode or the Command Line Tools
```
xcode-select --install
```
## Install Dependencies
### Install Dependencies for Model Converter
Please refer to [get_started](../get_started.md) to install conda.
```bash
# install pytorch & mmcv-full
conda install pytorch==1.9.0 torchvision==0.10.0 -c pytorch
pip install mmcv-full==1.6.0 -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.9.0/index.html
```
### Install Dependencies for SDK
You can skip this chapter if you are only interested in the model converter.
<table class="docutils">
<thead>
<tr>
<th>NAME </th>
<th>INSTALLATION </th>
</tr>
</thead>
<tbody>
<tr>
<td>OpenCV<br>(>=3.0) </td>
<td>
<pre><code>
brew install opencv
</code></pre>
</td>
</tbody>
</table>
### Install Inference Engines for MMDeploy
MMDeploy's model converter and SDK share the same inference engines.
You can select the inference engines you are interested in and install them by following the commands below.
This document focuses on Core ML. The installation of ONNX Runtime, ncnn and TorchScript is similar to that on the Linux platform; please refer to [linux-x86_64](linux-x86_64.md).
The TorchScript model is used as the IR in the conversion process of the Core ML model. In order to support custom operators in some models, such as the detection models in mmdet, libtorch needs to be installed.
<table class="docutils">
<thead>
<tr>
<th>NAME</th>
<th>PACKAGE</th>
<th>INSTALLATION</th>
</tr>
</thead>
<tbody>
<tr>
<td>Core ML</td>
<td>coremltools</td>
<td>
<pre><code>
pip install coremltools==6.0b2
</code></pre>
</td>
</tr>
<tr>
<td>TorchScript</td>
<td>libtorch</td>
<td>
1. LibTorch doesn't provide a prebuilt arm library for macOS, so you need to compile it yourself. Please note that the LibTorch version must be consistent with the PyTorch version. <br>
2. Take LibTorch 1.9.0 as an example. You can install it like this:
<pre><code>
git clone -b v1.9.0 --recursive https://github.com/pytorch/pytorch.git
cd pytorch
mkdir build && cd build
cmake .. \
-DCMAKE_BUILD_TYPE=Release \
-DPYTHON_EXECUTABLE=`which python` \
-DCMAKE_INSTALL_PREFIX=install \
-DDISABLE_SVE=ON # older PyTorch versions such as 1.9.0 need the DISABLE_SVE option
make -j4 && make install
export Torch_DIR=$(pwd)/install/share/cmake/Torch
</code></pre>
</td>
</tr>
</tbody>
</table>
## Build MMDeploy
```bash
cd /the/root/path/of/MMDeploy
export MMDEPLOY_DIR=$(pwd)
```
### Build Model Converter
- **Core ML**
Core ML uses TorchScript as its IR. To convert models from some codebases, such as mmdet, you need to compile the TorchScript custom operators first.
- **torchscript** custom operators
```bash
cd ${MMDEPLOY_DIR}
mkdir -p build && cd build
cmake -DMMDEPLOY_TARGET_BACKENDS=coreml -DTorch_DIR=${Torch_DIR} ..
make -j4 && make install
```
Please check [cmake build option](cmake_option.md).
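Once the custom ops are built, a quick sanity check can save time before running a conversion. The snippet below is a minimal sketch added for illustration (it assumes coremltools is installed and the ops library was built as above); `get_ops_path` and `ops_available` come from `mmdeploy.backend.torchscript`.

```python
# Editor's sketch: quick environment check before converting to Core ML.
import coremltools as ct
import torch

print('coremltools:', ct.__version__)
print('torch:', torch.__version__)

from mmdeploy.backend.torchscript import get_ops_path, ops_available
if ops_available():
    # Load the compiled custom ops so traced models from mmdet can be converted.
    torch.ops.load_library(get_ops_path())
    print('custom torchscript ops loaded from', get_ops_path())
```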
### Install Model Converter
```bash
# You should use `conda install` to install grpcio from requirements/runtime.txt
conda install grpcio
```
```bash
cd ${MMDEPLOY_DIR}
pip install -v -e .
```
**Note**
- Some dependencies are optional. Simply running `pip install -e .` will only install the minimum runtime requirements.
To use optional dependencies, install them manually with `pip install -r requirements/optional.txt` or specify desired extras when calling `pip` (e.g. `pip install -e .[optional]`).
Valid keys for the extras field are: `all`, `tests`, `build`, `optional`.
### Build SDK and Demo
The following shows an example of building an SDK using Core ML as the inference engine.
- cpu + Core ML
```Bash
cd ${MMDEPLOY_DIR}
mkdir -p build && cd build
cmake .. \
-DMMDEPLOY_BUILD_SDK=ON \
-DMMDEPLOY_BUILD_EXAMPLES=ON \
-DMMDEPLOY_BUILD_SDK_PYTHON_API=ON \
-DMMDEPLOY_TARGET_DEVICES=cpu \
-DMMDEPLOY_TARGET_BACKENDS=coreml \
-DTorch_DIR=${Torch_DIR}
make -j4 && make install
```

View File

@ -0,0 +1,31 @@
# Core ML feature support
MMDeploy supports converting PyTorch models to Core ML and running inference with them.
## Installation
To convert models from mmdet, you need to compile libtorch with support for custom operators such as nms.
```bash
cd ${PYTORCH_DIR}
mkdir build && cd build
cmake .. \
-DCMAKE_BUILD_TYPE=Release \
-DPYTHON_EXECUTABLE=`which python` \
-DCMAKE_INSTALL_PREFIX=install \
-DDISABLE_SVE=ON # older PyTorch versions such as 1.8.0 need this option
make install
```
## Usage
```bash
python tools/deploy.py \
configs/mmdet/detection/detection_coreml_static-800x1344.py \
/mmdetection_dir/configs/retinanet/retinanet_r18_fpn_1x_coco.py \
/checkpoint/retinanet_r18_fpn_1x_coco_20220407_171055-614fd399.pth \
/mmdetection_dir/demo/demo.jpg \
--work-dir work_dir/retinanet \
--device cpu \
--dump-info
```
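After conversion, the Core ML model is saved under `--work-dir`. As a rough, hedged illustration (not part of the original guide), the converted model can also be loaded directly with the `CoreMLWrapper` introduced in this PR; the `end2end.mlpackage` file name and the `input` tensor name are assumptions based on the default IR name and the deploy config above.

```python
# Editor's sketch: run the converted Core ML model via CoreMLWrapper.
# The path and tensor name below are assumptions; adjust them to the files
# actually produced by tools/deploy.py.
import torch
from mmdeploy.backend.coreml import CoreMLWrapper

model = CoreMLWrapper('work_dir/retinanet/end2end.mlpackage')
outputs = model(dict(input=torch.randn(1, 3, 800, 1344)))
print({name: value.shape for name, value in outputs.items()})
```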

View File

@ -104,7 +104,7 @@
<tr>
<td>MMDEPLOY_TARGET_BACKENDS</td>
<td>{"trt", "ort", "pplnn", "ncnn", "openvino", "torchscript", "snpe"}</td>
<td>{"trt", "ort", "pplnn", "ncnn", "openvino", "torchscript", "snpe", "coreml"}</td>
<td>N/A</td>
<td> <b>By default, the SDK enables no backend</b>, since the choice is highly application-dependent. When selecting multiple backends, separate them with semicolons, for example <pre><code>-DMMDEPLOY_TARGET_BACKENDS="trt;ort;pplnn;ncnn;openvino"</code></pre>
At build time, almost every backend requires some path variables to be set so that its dependencies can be located.<br>
@ -120,6 +120,7 @@
5. <b>openvino</b>: enables OpenVINO. <code>InferenceEngine_DIR</code> needs to be set.<br>
6. <b>torchscript</b>: enables TorchScript. Currently only model conversion supports the torchscript format; the SDK does not support it yet.<br>
7. <b>snpe</b>: enables Qualcomm SNPE. The environment variable SNPE_ROOT needs to be set.<br>
8. <b>coreml</b>: enables Core ML. Currently <code>Torch_DIR</code> needs to be set for model conversion.<br>
</td>
</tr>

View File

@ -0,0 +1,175 @@
# Build for macOS-arm64
- [Build for macOS-arm64](#macos-arm64-下构建方式)
- [Build from Source](#源码安装)
- [Install Build Toolchains](#安装构建和编译工具链)
- [Install Dependencies](#安装依赖包)
- [Install Dependencies for MMDeploy Converter](#安装-mmdeploy-converter-依赖)
- [Install Dependencies for MMDeploy SDK](#安装-mmdeploy-sdk-依赖)
- [Install Inference Engines](#安装推理引擎)
- [Build MMDeploy](#编译-mmdeploy)
- [Build Model Converter](#编译-model-converter)
- [Install Model Converter](#安装-model-converter)
- [Build SDK and Demos](#编译-sdk-和-demos)
## Build from Source
### Install Build Toolchains
- cmake
```
brew install cmake
```
- clang
Install Xcode, or install the Command Line Tools with the following command
```
xcode-select --install
```
### Install Dependencies
#### Install Dependencies for MMDeploy Converter
Please refer to [get_started](../get_started.md) to install conda.
```bash
# install pytorch & mmcv-full
conda install pytorch==1.9.0 torchvision==0.10.0 -c pytorch
pip install mmcv-full==1.6.0 -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.9.0/index.html
```
#### Install Dependencies for MMDeploy SDK
You can skip this section if you are only interested in model conversion.
<table class="docutils">
<thead>
<tr>
<th>NAME </th>
<th>INSTALLATION </th>
</tr>
</thead>
<tbody>
<tr>
<td>OpenCV<br>(>=3.0) </td>
<td>
<pre><code>
brew install opencv
</code></pre>
</td>
</tbody>
</table>
#### Install Inference Engines
MMDeploy's Model Converter and SDK share the same inference engines. You can refer to the following and install the ones you are interested in. This section focuses on Core ML. The installation of ONNX Runtime, ncnn and TorchScript is similar to that on the Linux platform; please refer to [linux-x86_64](linux-x86_64.md).
The TorchScript model is used as the IR during the conversion to Core ML. To support models with custom operators, such as the detection models in mmdet, libtorch needs to be installed; a brief description is given below.
<table class="docutils">
<thead>
<tr>
<th>NAME</th>
<th>PACKAGE</th>
<th>INSTALLATION</th>
</tr>
</thead>
<tbody>
<tr>
<td>Core ML</td>
<td>coremltools</td>
<td>
<pre><code>
pip install coremltools==6.0b2
</code></pre>
</td>
</tr>
<tr>
<td>TorchScript</td>
<td>libtorch</td>
<td>
1. LibTorch doesn't provide a prebuilt arm library yet, so you need to compile it yourself. Note that the LibTorch version must be consistent with the PyTorch version; otherwise the compiled custom operators cannot be loaded.<br>
2. Taking LibTorch 1.9.0 as an example, it can be installed with the following commands:
<pre><code>
git clone -b v1.9.0 --recursive https://github.com/pytorch/pytorch.git
cd pytorch
mkdir build && cd build
cmake .. \
-DCMAKE_BUILD_TYPE=Release \
-DPYTHON_EXECUTABLE=`which python` \
-DCMAKE_INSTALL_PREFIX=install \
-DDISABLE_SVE=ON # older PyTorch versions such as 1.9.0 need the DISABLE_SVE option
make -j4 && make install
export Torch_DIR=$(pwd)/install/share/cmake/Torch
</code></pre>
</td>
</tr>
</tbody>
</table>
### Build MMDeploy
```bash
cd /the/root/path/of/MMDeploy
export MMDEPLOY_DIR=$(pwd)
```
#### Build Model Converter
This section describes the steps required to use Core ML as the inference backend.
- **Core ML**
Core ML uses TorchScript as its IR; for some codebases such as mmdet, the TorchScript custom operators need to be compiled.
- **torchscript** custom operators
```bash
cd ${MMDEPLOY_DIR}
mkdir -p build && cd build
cmake -DMMDEPLOY_TARGET_BACKENDS=coreml -DTorch_DIR=${Torch_DIR} ..
make -j4 && make install
```
Please check the [cmake option description](cmake_option.md).
#### Install Model Converter
```bash
# The dependency grpcio in requirements/runtime.txt cannot be imported correctly when installed with pip; install it with conda instead
conda install grpcio
```
```bash
cd ${MMDEPLOY_DIR}
pip install -v -e .
```
**Note**
- Some dependencies are optional. Running `pip install -e .` will only install the minimum runtime requirements. To install the optional dependencies, run `pip install -r requirements/optional.txt`,
or specify the desired extras when calling pip, e.g. `pip install -e .[optional]`, where `[optional]` can be replaced with `all`, `tests`, `build` or `optional`.
#### Build SDK and Demos
The following shows an example of building the SDK with Core ML as the inference engine.
- cpu + Core ML
```Bash
cd ${MMDEPLOY_DIR}
mkdir -p build && cd build
cmake .. \
-DMMDEPLOY_BUILD_SDK=ON \
-DMMDEPLOY_BUILD_EXAMPLES=ON \
-DMMDEPLOY_BUILD_SDK_PYTHON_API=ON \
-DMMDEPLOY_TARGET_DEVICES=cpu \
-DMMDEPLOY_TARGET_BACKENDS=coreml \
-DTorch_DIR=${Torch_DIR}
make -j4 && make install
```

View File

@ -0,0 +1,31 @@
# Core ML Support
Currently mmdeploy supports converting PyTorch models from the OpenMMLab codebases to Core ML and running inference with them.
## Installation
To convert models from mmdet, libtorch needs to be compiled to support custom operators such as nms.
```bash
cd ${PYTORCH_DIR}
mkdir build && cd build
cmake .. \
-DCMAKE_BUILD_TYPE=Release \
-DPYTHON_EXECUTABLE=`which python` \
-DCMAKE_INSTALL_PREFIX=install \
-DDISABLE_SVE=ON # older PyTorch versions such as 1.8.0 need this option
make install
```
## Usage
```bash
python tools/deploy.py \
configs/mmdet/detection/detection_coreml_static-800x1344.py \
/mmdetection_dir/configs/retinanet/retinanet_r18_fpn_1x_coco.py \
/checkpoint/retinanet_r18_fpn_1x_coco_20220407_171055-614fd399.pth \
/mmdetection_dir/demo/demo.jpg \
--work-dir work_dir/retinanet \
--device cpu \
--dump-info
```

View File

@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.coreml import is_available
__all__ = ['is_available']
if is_available():
from mmdeploy.backend.coreml.torchscript2coreml import \
from_torchscript as _from_torchscript
from mmdeploy.backend.coreml.torchscript2coreml import get_model_suffix
from ..core import PIPELINE_MANAGER
from_torchscript = PIPELINE_MANAGER.register_pipeline()(_from_torchscript)
__all__ += ['from_torchscript', 'get_model_suffix']

View File

@ -3,7 +3,6 @@ from copy import deepcopy
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import torch
from packaging.version import parse as version_parse
from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import IR, Backend, get_ir_config, get_root_logger
@ -100,20 +99,6 @@ def trace(func: torch.nn.Module,
check_trace=check_trace,
check_tolerance=check_tolerance)
logger.info('perform torchscript optimizer.')
try:
# custom optimizer
from mmdeploy.backend.torchscript import ts_optimizer
logger = get_root_logger()
ts_optimizer.optimize_for_backend(
ts_model._c, ir=IR.TORCHSCRIPT.value, backend=backend)
except Exception:
# use pytorch builtin optimizer
ts_model = torch.jit.freeze(ts_model)
torch_version = version_parse(torch.__version__)
if torch_version.minor >= 9:
ts_model = torch.jit.optimize_for_inference(ts_model)
# save model
if output_path_prefix is not None:
output_path = output_path_prefix + '.pt'

View File

@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
def is_available():
"""Check whether coremltools is installed.
Returns:
bool: True if coremltools package is installed.
"""
return importlib.util.find_spec('coremltools') is not None
__all__ = []
if is_available():
from . import ops
from .torchscript2coreml import get_model_suffix
from .wrapper import CoreMLWrapper
__all__ += ['CoreMLWrapper', 'get_model_suffix', 'ops']

View File

@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import \
register_torch_op
@register_torch_op
def coreml_nms(context, node):
"""bind CoreML NMS op."""
inputs = _get_inputs(context, node)
boxes = inputs[0]
scores = inputs[1]
iou_threshold = inputs[2]
score_threshold = inputs[3]
max_boxes = inputs[4]
results = mb.non_maximum_suppression(
boxes=boxes,
scores=scores,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
max_boxes=max_boxes)
context.add(tuple(results), torch_name=node.outputs[0])

View File

@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence, Union
import coremltools as ct
import torch
from mmdeploy.utils import get_root_logger
try:
# user might need ops from torchvision
import torchvision # noqa
except ImportError:
pass
def get_model_suffix(convert_to: str) -> str:
assert convert_to == 'neuralnetwork' or convert_to == 'mlprogram'
suffix = ''
if convert_to == 'neuralnetwork':
suffix = '.mlmodel'
if convert_to == 'mlprogram':
suffix = '.mlpackage'
return suffix
def create_shape(name: str, input_shapes: Dict) -> ct.Shape:
"""Create input shape."""
min_shape = input_shapes['min_shape']
max_shape = input_shapes['max_shape']
default_shape = input_shapes['default_shape']
assert len(min_shape) == len(max_shape) == len(default_shape)
shape = []
n_dim = len(min_shape)
for i in range(n_dim):
low = min_shape[i]
high = max_shape[i]
assert low <= high
if low == -1 or high == -1:
shape.append(ct.RangeDim())
elif low == high:
shape.append(low)
else:
shape.append(ct.RangeDim(low, high))
shape = ct.Shape(shape=shape, default=default_shape)
return ct.TensorType(shape=shape, name=name)
def from_torchscript(torchscript_model: Union[str,
torch.jit.RecursiveScriptModule],
output_file_prefix: str,
input_names: Sequence[str],
output_names: Sequence[str],
input_shapes: Dict,
convert_to: str = 'neuralnetwork',
fp16_mode: bool = False,
skip_model_load: bool = True,
**kwargs):
"""Create a coreml engine from torchscript.
Args:
torchscript_model (Union[str, torch.jit.RecursiveScriptModule]):
The torchscript model to be converted.
output_file_prefix (str): The output file prefix.
input_names (Sequence[str]): The input names of the model.
output_names (Sequence[str]): The output names of the model.
input_shapes (Dict): The input shapes include max_shape, min_shape and
default_shape
convert_to (str, optional): The converted model type, can be
'neuralnetwork' or 'mlprogram'. Defaults to 'neuralnetwork'.
fp16_mode (bool, optional): Convert to fp16 model. Defaults to False.
skip_model_load (bool, optional): Skip model load. Defaults to True.
"""
try:
from mmdeploy.backend.torchscript import get_ops_path
torch.ops.load_library(get_ops_path())
except Exception as e:
get_root_logger().warning(
'Can not load custom ops because:\n'
f'{e}\n'
'Some model might not be able to be converted.')
if isinstance(torchscript_model, str):
torchscript_model = torch.jit.load(torchscript_model)
inputs = []
outputs = []
for name in input_names:
shape = create_shape(name, input_shapes[name])
inputs.append(shape)
for name in output_names:
outputs.append(ct.TensorType(name=name))
if convert_to == 'neuralnetwork':
compute_precision = None
else:
if fp16_mode:
compute_precision = ct.precision.FLOAT16
else:
compute_precision = ct.precision.FLOAT32
mlmodel = ct.convert(
model=torchscript_model,
inputs=inputs,
outputs=outputs,
compute_precision=compute_precision,
convert_to=convert_to,
skip_model_load=skip_model_load)
suffix = get_model_suffix(convert_to)
output_path = output_file_prefix + suffix
mlmodel.save(output_path)
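For reference, a hypothetical call to `from_torchscript` (an editor's sketch, not part of the committed file) showing the expected `input_shapes` layout, which mirrors the coreml deploy configs added in this PR; the model path and output names below are assumptions:
# Editor's sketch (assumed names): convert a traced model to an mlprogram.
# from mmdeploy.backend.coreml.torchscript2coreml import from_torchscript
from_torchscript(
    'end2end.pt',
    output_file_prefix='end2end',
    input_names=['input'],
    output_names=['output'],
    input_shapes=dict(
        input=dict(
            min_shape=[1, 3, 224, 224],
            max_shape=[8, 3, 224, 224],
            default_shape=[1, 3, 224, 224])),
    convert_to='mlprogram')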

View File

@ -0,0 +1,67 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
import coremltools as ct
import numpy as np
import torch
from mmdeploy.utils import Backend
from mmdeploy.utils.timer import TimeCounter
from ..base import BACKEND_WRAPPER, BaseWrapper
@BACKEND_WRAPPER.register_module(Backend.COREML.value)
class CoreMLWrapper(BaseWrapper):
"""CoreML wrapper class for inference.
Args:
model_file (str): Path of a mlpackage file.
bin_file (str): Path of a binary file.
Examples:
>>> from mmdeploy.backend.coreml import CoreMLWrapper
>>> import torch
>>>
>>> model_file = 'model.mlpackage'
>>> model = CoreMLWrapper(model_file)
>>> inputs = dict(input=torch.randn(1, 3, 224, 224))
>>> outputs = model(inputs)
>>> print(outputs)
"""
def __init__(self, model_file: str):
self.model = ct.models.model.MLModel(
model_file, compute_units=ct.ComputeUnit.ALL)
spec = self.model.get_spec()
output_names = [out.name for out in spec.description.output]
super().__init__(output_names)
def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Run forward inference.
Args:
inputs (Dict[str, torch.Tensor]): Key-value pairs of model inputs.
Returns:
Dict[str, torch.Tensor]: Key-value pairs of model outputs.
"""
model_inputs = dict(
(k, v.detach().cpu().numpy()) for k, v in inputs.items())
output = self.__execute(model_inputs)
for name, tensor in output.items():
output[name] = torch.from_numpy(tensor)
return output
@TimeCounter.count_time(Backend.COREML.value)
def __execute(self, inputs: Dict[str, np.ndarray]) -> Dict:
"""Run inference with CoreML.
Args:
inputs (Dict[str, np.ndarray]): Input data with keys.
Returns:
Dict[str, np.ndarray]: Inference results with keys.
"""
return self.model.predict(inputs)

View File

@ -136,6 +136,11 @@ def get_models(deploy_cfg: Union[str, mmcv.Config],
net = replace_suffix(ir_name, '.dlc')
elif backend in [Backend.ONNXRUNTIME, Backend.TORCHSCRIPT]:
pass
elif backend == Backend.COREML:
from mmdeploy.backend.coreml import get_model_suffix
convert_to = deploy_cfg.backend_config.convert_to
suffix = get_model_suffix(convert_to)
net = replace_suffix(ir_name, suffix)
else:
raise NotImplementedError(f'Not supported backend: {backend.value}.')

View File

@ -116,6 +116,9 @@ class BaseBackendModel(torch.nn.Module, metaclass=ABCMeta):
uri = kwargs['uri']
return SNPEWrapper(
dlc_file=backend_files[0], uri=uri, output_names=output_names)
elif backend == Backend.COREML:
from mmdeploy.backend.coreml import CoreMLWrapper
return CoreMLWrapper(model_file=backend_files[0])
else:
raise NotImplementedError(f'Unknown backend type: {backend.value}')

View File

@ -1,11 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .core import * # noqa: F401,F403
from .deploy import (MMDetection, ObjectDetection, clip_bboxes,
from .deploy import (MMDetection, ObjectDetection, clip_bboxes, gather_topk,
get_post_processing_params, pad_with_value,
pad_with_value_if_necessary)
from .models import * # noqa: F401,F403
__all__ = [
'get_post_processing_params', 'clip_bboxes', 'pad_with_value',
'pad_with_value_if_necessary', 'MMDetection', 'ObjectDetection'
'pad_with_value_if_necessary', 'gather_topk', 'MMDetection',
'ObjectDetection'
]

View File

@ -5,7 +5,8 @@ from torch import Tensor
import mmdeploy
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.mmcv.ops import ONNXNMSop, TRTBatchedNMSop
from mmdeploy.utils import Backend, is_dynamic_batch
from mmdeploy.utils import IR, is_dynamic_batch
from mmdeploy.utils.constants import Backend
def select_nms_index(scores: torch.Tensor,
@ -272,7 +273,68 @@ def multiclass_nms(*args, **kwargs):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms',
backend=Backend.TORCHSCRIPT.value)
backend=Backend.COREML.value)
def multiclass_nms__coreml(ctx,
boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = -1):
"""rewrite for coreml batched nms.
Use coreml_nms from custom ops.
"""
# load custom nms
from mmdeploy.backend.torchscript import get_ops_path, ops_available
assert ops_available(), 'coreml require custom torchscript ops support.'
torch.ops.load_library(get_ops_path())
try:
coreml_nms = torch.ops.mmdeploy.coreml_nms
except Exception:
raise Exception(
'Can not use coreml_nms. Please build torchscript custom ops.')
batch_size = scores.shape[0]
assert batch_size == 1, 'batched nms is not supported for now.'
# pre-topk
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(pre_top_k)
boxes = boxes[:, topk_inds.squeeze(), ...]
scores = scores[:, topk_inds.squeeze(), ...]
def _xyxy2xywh(boxes):
xy0 = boxes[..., :2]
xy1 = boxes[..., 2:]
xy = (xy0 + xy1) / 2
wh = xy1 - xy0
return torch.cat([xy, wh], dim=-1)
def _xywh2xyxy(boxes):
xy = boxes[..., :2]
half_wh = boxes[..., 2:] / 2
return torch.cat([xy - half_wh, xy + half_wh], dim=-1)
boxes = _xyxy2xywh(boxes)
keep_top_k = keep_top_k if keep_top_k > 0 else max_output_boxes_per_class
boxes, scores, _, _ = coreml_nms(
boxes, scores, iou_threshold, score_threshold,
min(keep_top_k, max_output_boxes_per_class))
scores, labels = scores.max(-1)
boxes = _xywh2xyxy(boxes)
dets = torch.cat([boxes, scores.unsqueeze(-1)], dim=-1)
return dets, labels
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms',
ir=IR.TORCHSCRIPT)
def multiclass_nms__torchscript(ctx,
boxes: Tensor,
scores: Tensor,

View File

@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mmdetection import MMDetection
from .object_detection import ObjectDetection
from .utils import (clip_bboxes, get_post_processing_params, pad_with_value,
pad_with_value_if_necessary)
from .utils import (clip_bboxes, gather_topk, get_post_processing_params,
pad_with_value, pad_with_value_if_necessary)
__all__ = [
'get_post_processing_params', 'clip_bboxes', 'pad_with_value',
'pad_with_value_if_necessary', 'MMDetection', 'ObjectDetection'
'pad_with_value_if_necessary', 'gather_topk', 'MMDetection',
'ObjectDetection'
]

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Sequence, Union
from typing import Any, Optional, Sequence, Tuple, Union
import mmcv
import torch
@ -179,3 +179,62 @@ def __pad_with_value_if_necessary__tensorrt(ctx,
Tensor: Padded tensor.
"""
return pad_with_value(x, pad_dim, pad_size=pad_size, pad_value=pad_value)
def __gather_topk(*inputs: Sequence[torch.Tensor],
inds: torch.Tensor,
batch_size: int,
is_batched: bool = True) -> Tuple[torch.Tensor]:
"""The default implementation of gather_topk."""
if is_batched:
batch_inds = torch.arange(batch_size, device=inds.device).unsqueeze(-1)
outputs = [
x[batch_inds, inds, ...] if x is not None else None for x in inputs
]
else:
prior_inds = inds.new_zeros((1, 1))
outputs = [
x[prior_inds, inds, ...] if x is not None else None for x in inputs
]
return outputs
@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
backend=Backend.COREML.value)
def __gather_topk__nonbatch(ctx,
*inputs: Sequence[torch.Tensor],
inds: torch.Tensor,
batch_size: int,
is_batched: bool = True) -> Tuple[torch.Tensor]:
"""Single batch gather_topk."""
assert batch_size == 1
inds = inds.squeeze(0)
outputs = [x[:, inds, ...] if x is not None else None for x in inputs]
return outputs
def gather_topk(*inputs: Sequence[torch.Tensor],
inds: torch.Tensor,
batch_size: int,
is_batched: bool = True) -> Tuple[torch.Tensor]:
"""Gather topk of each tensor.
Args:
inputs (Sequence[torch.Tensor]): Tensors to be gathered.
inds (torch.Tensor): Topk index.
batch_size (int): batch_size.
is_batched (bool): Inputs is batched or not.
Returns:
Tuple[torch.Tensor]: Gathered tensors.
"""
import mmdeploy
outputs = mmdeploy.codebase.mmdet.deploy.utils.__gather_topk(
*inputs, inds=inds, batch_size=batch_size, is_batched=is_batched)
if len(outputs) == 1:
outputs = outputs[0]
return outputs

View File

@ -5,7 +5,7 @@ from mmdet.core.bbox.coder import (DeltaXYWHBBoxCoder, DistancePointBBoxCoder,
from mmdet.core.bbox.transforms import distance2bbox
from mmdet.models.dense_heads import PAAHead
from mmdeploy.codebase.mmdet import (get_post_processing_params,
from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
multiclass_nms,
pad_with_value_if_necessary)
from mmdeploy.codebase.mmdet.core.ops import ncnn_detection_output_forward
@ -129,14 +129,18 @@ def base_dense_head__get_bbox(ctx,
else:
max_scores, _ = nms_pre_score[..., :-1].max(-1)
_, topk_inds = max_scores.topk(pre_topk)
batch_inds = torch.arange(
batch_size, device=bbox_pred.device).unsqueeze(-1)
prior_inds = batch_inds.new_zeros((1, 1))
priors = priors[prior_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
if with_score_factors:
score_factors = score_factors[batch_inds, topk_inds, :]
bbox_pred, scores, score_factors = gather_topk(
bbox_pred,
scores,
score_factors,
inds=topk_inds,
batch_size=batch_size,
is_batched=True)
priors = gather_topk(
priors,
inds=topk_inds,
batch_size=batch_size,
is_batched=False)
mlvl_valid_bboxes.append(bbox_pred)
mlvl_valid_scores.append(scores)

View File

@ -4,7 +4,9 @@ from .adaptive_pool import (adaptive_avg_pool2d__default,
adaptive_avg_pool2d__ncnn)
from .atan2 import atan2__default
from .chunk import chunk__ncnn, chunk__torchscript
from .clip import clip__coreml
from .expand import expand__ncnn
from .flatten import flatten__coreml
from .getattribute import tensor__getattribute__ncnn
from .group_norm import group_norm__ncnn
from .interpolate import interpolate__ncnn, interpolate__tensorrt
@ -26,5 +28,5 @@ __all__ = [
'chunk__torchscript', 'masked_fill__onnxruntime',
'tensor__setitem__default', 'tensor__getitem__ascend',
'adaptive_avg_pool2d__default', 'adaptive_avg_pool2d__ncnn',
'multi_head_attention_forward'
'multi_head_attention_forward', 'flatten__coreml', 'clip__coreml'
]

View File

@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.clip', backend=Backend.COREML.value)
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.clip', backend=Backend.COREML.value)
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.clamp', backend=Backend.COREML.value)
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.clamp', backend=Backend.COREML.value)
def clip__coreml(ctx, input, min=None, max=None, **kwargs) -> torch.Tensor:
"""Rewrite `clip` for coreml backend.
Cast data type.
"""
if min is not None and not isinstance(min, torch.Tensor):
min = input.new_tensor(min)
if max is not None and not isinstance(max, torch.Tensor):
max = input.new_tensor(max)
return ctx.origin_func(input, min=min, max=max, **kwargs)

View File

@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.flatten', backend=Backend.COREML.value)
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.flatten', backend=Backend.COREML.value)
def flatten__coreml(ctx, input, start_dim=0, end_dim=-1) -> torch.Tensor:
"""Rewrite `flatten` for coreml backend.
Use reshape instead of flatten
"""
shape = input.shape
end_dim = end_dim if end_dim > 0 else len(shape) + end_dim
shape1 = list(shape[:start_dim])
shape3 = list(shape[end_dim + 1:])
return input.reshape(shape1 + [-1] + shape3)

View File

@ -60,6 +60,7 @@ class Backend(AdvancedEnum):
SDK = 'sdk'
TORCHSCRIPT = 'torchscript'
ASCEND = 'ascend'
COREML = 'coreml'
DEFAULT = 'default'

View File

@ -47,6 +47,9 @@ def check_backend():
import mmdeploy.apis.ascend as ascend_apis
logger.info(f'ascend_is_available: {ascend_apis.is_available()}')
import mmdeploy.apis.coreml as coreml_apis
logger.info(f'coreml_is_available: {coreml_apis.is_available()}')
def check_codebase():
codebase_versions = get_codebase_version()

View File

@ -370,6 +370,23 @@ def main():
from mmdeploy.backend.ascend import update_sdk_pipeline
update_sdk_pipeline(args.work_dir)
elif backend == Backend.COREML:
from mmdeploy.apis.coreml import from_torchscript, get_model_suffix
coreml_pipeline_funcs = [from_torchscript]
PIPELINE_MANAGER.set_log_level(log_level, coreml_pipeline_funcs)
model_inputs = get_model_inputs(deploy_cfg)
coreml_files = []
for model_id, torchscript_path in enumerate(ir_files):
torchscript_name = osp.splitext(osp.split(torchscript_path)[1])[0]
output_file_prefix = osp.join(args.work_dir, torchscript_name)
convert_to = deploy_cfg.backend_config.convert_to
from_torchscript(torchscript_path, output_file_prefix,
ir_config.input_names, ir_config.output_names,
model_inputs[model_id].input_shapes, convert_to)
suffix = get_model_suffix(convert_to)
coreml_files.append(output_file_prefix + suffix)
backend_files = coreml_files
if args.test_img is None:
args.test_img = args.img