mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add MPS bbox overlap (#2123)
* add mps bbox overlap * format * update document and manifest * update readmepull/2142/head
parent
73066430be
commit
22fadceecd
|
@ -3,4 +3,5 @@ include mmcv/model_zoo/open_mmlab.json mmcv/model_zoo/deprecated.json mmcv/model
|
|||
include mmcv/ops/csrc/common/cuda/*.cuh mmcv/ops/csrc/common/cuda/*.hpp mmcv/ops/csrc/common/*.hpp
|
||||
include mmcv/ops/csrc/pytorch/*.cpp mmcv/ops/csrc/pytorch/cuda/*.cu mmcv/ops/csrc/pytorch/cuda/*.cpp mmcv/ops/csrc/pytorch/cpu/*.cpp
|
||||
include mmcv/ops/csrc/parrots/*.h mmcv/ops/csrc/parrots/*.cpp
|
||||
recursive-include mmcv/ops/csrc/ *.h *.hpp *.cpp *.cuh *.cu
|
||||
include mmcv/ops/csrc/pytorch/mps/*.mm mmcv/ops/csrc/common/mps/*.h mmcv/ops/csrc/common/mps/*.mm
|
||||
recursive-include mmcv/ops/csrc/ *.h *.hpp *.cpp *.cuh *.cu *.mm
|
||||
|
|
|
@ -2,58 +2,58 @@
|
|||
|
||||
We implement common ops used in detection, segmentation, etc.
|
||||
|
||||
| Device | CPU | CUDA | MLU |
|
||||
| ---------------------------- | --- | ---- | --- |
|
||||
| ActiveRotatedFilter | √ | √ | |
|
||||
| AssignScoreWithK | | √ | |
|
||||
| BallQuery | | √ | |
|
||||
| BBoxOverlaps | | √ | √ |
|
||||
| BorderAlign | | √ | |
|
||||
| BoxIouRotated | √ | √ | |
|
||||
| CARAFE | | √ | |
|
||||
| ChamferDistance | | √ | |
|
||||
| CrissCrossAttention | | √ | |
|
||||
| ContourExpand | √ | | |
|
||||
| ConvexIoU | | √ | |
|
||||
| CornerPool | | √ | |
|
||||
| Correlation | | √ | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | |
|
||||
| Deformable RoIPool | | √ | |
|
||||
| DiffIoURotated | | √ | |
|
||||
| DynamicScatter | | √ | |
|
||||
| FurthestPointSample | | √ | |
|
||||
| FurthestPointSampleWithDist | | √ | |
|
||||
| FusedBiasLeakyrelu | | √ | |
|
||||
| GatherPoints | | √ | |
|
||||
| GroupPoints | | √ | |
|
||||
| Iou3d | | √ | |
|
||||
| KNN | | √ | |
|
||||
| MaskedConv | | √ | |
|
||||
| MergeCells | | √ | |
|
||||
| MinAreaPolygon | | √ | |
|
||||
| ModulatedDeformConv2d | √ | √ | |
|
||||
| MultiScaleDeformableAttn | | √ | |
|
||||
| NMS | √ | √ | √ |
|
||||
| NMSRotated | √ | √ | |
|
||||
| PixelGroup | √ | | |
|
||||
| PointsInBoxes | √ | √ | |
|
||||
| PointsInPolygons | | √ | |
|
||||
| PSAMask | √ | √ | √ |
|
||||
| RotatedFeatureAlign | √ | √ | |
|
||||
| RoIPointPool3d | | √ | |
|
||||
| RoIPool | | √ | √ |
|
||||
| RoIAlignRotated | √ | √ | √ |
|
||||
| RiRoIAlignRotated | | √ | |
|
||||
| RoIAlign | √ | √ | √ |
|
||||
| RoIAwarePool3d | | √ | |
|
||||
| SAConv2d | | √ | |
|
||||
| SigmoidFocalLoss | | √ | √ |
|
||||
| SoftmaxFocalLoss | | √ | |
|
||||
| SoftNMS | | √ | |
|
||||
| Sparse Convolution | | √ | |
|
||||
| Synchronized BatchNorm | | √ | |
|
||||
| ThreeInterpolate | | √ | |
|
||||
| ThreeNN | | √ | |
|
||||
| TINShift | | √ | √ |
|
||||
| UpFirDn2d | | √ | |
|
||||
| Voxelization | √ | √ | |
|
||||
| Device | CPU | CUDA | MLU | MPS |
|
||||
| ---------------------------- | --- | ---- | --- | --- |
|
||||
| ActiveRotatedFilter | √ | √ | | |
|
||||
| AssignScoreWithK | | √ | | |
|
||||
| BallQuery | | √ | | |
|
||||
| BBoxOverlaps | | √ | √ | √ |
|
||||
| BorderAlign | | √ | | |
|
||||
| BoxIouRotated | √ | √ | | |
|
||||
| CARAFE | | √ | | |
|
||||
| ChamferDistance | | √ | | |
|
||||
| CrissCrossAttention | | √ | | |
|
||||
| ContourExpand | √ | | | |
|
||||
| ConvexIoU | | √ | | |
|
||||
| CornerPool | | √ | | |
|
||||
| Correlation | | √ | | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | | |
|
||||
| Deformable RoIPool | | √ | | |
|
||||
| DiffIoURotated | | √ | | |
|
||||
| DynamicScatter | | √ | | |
|
||||
| FurthestPointSample | | √ | | |
|
||||
| FurthestPointSampleWithDist | | √ | | |
|
||||
| FusedBiasLeakyrelu | | √ | | |
|
||||
| GatherPoints | | √ | | |
|
||||
| GroupPoints | | √ | | |
|
||||
| Iou3d | | √ | | |
|
||||
| KNN | | √ | | |
|
||||
| MaskedConv | | √ | | |
|
||||
| MergeCells | | √ | | |
|
||||
| MinAreaPolygon | | √ | | |
|
||||
| ModulatedDeformConv2d | √ | √ | | |
|
||||
| MultiScaleDeformableAttn | | √ | | |
|
||||
| NMS | √ | √ | √ | |
|
||||
| NMSRotated | √ | √ | | |
|
||||
| PixelGroup | √ | | | |
|
||||
| PointsInBoxes | √ | √ | | |
|
||||
| PointsInPolygons | | √ | | |
|
||||
| PSAMask | √ | √ | √ | |
|
||||
| RotatedFeatureAlign | √ | √ | | |
|
||||
| RoIPointPool3d | | √ | | |
|
||||
| RoIPool | | √ | √ | |
|
||||
| RoIAlignRotated | √ | √ | √ | |
|
||||
| RiRoIAlignRotated | | √ | | |
|
||||
| RoIAlign | √ | √ | √ | |
|
||||
| RoIAwarePool3d | | √ | | |
|
||||
| SAConv2d | | √ | | |
|
||||
| SigmoidFocalLoss | | √ | √ | |
|
||||
| SoftmaxFocalLoss | | √ | | |
|
||||
| SoftNMS | | √ | | |
|
||||
| Sparse Convolution | | √ | | |
|
||||
| Synchronized BatchNorm | | √ | | |
|
||||
| ThreeInterpolate | | √ | | |
|
||||
| ThreeNN | | √ | | |
|
||||
| TINShift | | √ | √ | |
|
||||
| UpFirDn2d | | √ | | |
|
||||
| Voxelization | √ | √ | | |
|
||||
|
|
|
@ -2,58 +2,58 @@
|
|||
|
||||
MMCV 提供了检测、分割等任务中常用的算子
|
||||
|
||||
| Device | CPU | CUDA | MLU |
|
||||
| ---------------------------- | --- | ---- | --- |
|
||||
| ActiveRotatedFilter | √ | √ | |
|
||||
| AssignScoreWithK | | √ | |
|
||||
| BallQuery | | √ | |
|
||||
| BBoxOverlaps | | √ | √ |
|
||||
| BorderAlign | | √ | |
|
||||
| BoxIouRotated | √ | √ | |
|
||||
| CARAFE | | √ | |
|
||||
| ChamferDistance | | √ | |
|
||||
| CrissCrossAttention | | √ | |
|
||||
| ContourExpand | √ | | |
|
||||
| ConvexIoU | | √ | |
|
||||
| CornerPool | | √ | |
|
||||
| Correlation | | √ | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | |
|
||||
| Deformable RoIPool | | √ | |
|
||||
| DiffIoURotated | | √ | |
|
||||
| DynamicScatter | | √ | |
|
||||
| FurthestPointSample | | √ | |
|
||||
| FurthestPointSampleWithDist | | √ | |
|
||||
| FusedBiasLeakyrelu | | √ | |
|
||||
| GatherPoints | | √ | |
|
||||
| GroupPoints | | √ | |
|
||||
| Iou3d | | √ | |
|
||||
| KNN | | √ | |
|
||||
| MaskedConv | | √ | |
|
||||
| MergeCells | | √ | |
|
||||
| MinAreaPolygon | | √ | |
|
||||
| ModulatedDeformConv2d | √ | √ | |
|
||||
| MultiScaleDeformableAttn | | √ | |
|
||||
| NMS | √ | √ | √ |
|
||||
| NMSRotated | √ | √ | |
|
||||
| PixelGroup | √ | | |
|
||||
| PointsInBoxes | √ | √ | |
|
||||
| PointsInPolygons | | √ | |
|
||||
| PSAMask | √ | √ | √ |
|
||||
| RotatedFeatureAlign | √ | √ | |
|
||||
| RoIPointPool3d | | √ | |
|
||||
| RoIPool | | √ | √ |
|
||||
| RoIAlignRotated | √ | √ | √ |
|
||||
| RiRoIAlignRotated | | √ | |
|
||||
| RoIAlign | √ | √ | √ |
|
||||
| RoIAwarePool3d | | √ | |
|
||||
| SAConv2d | | √ | |
|
||||
| SigmoidFocalLoss | | √ | √ |
|
||||
| SoftmaxFocalLoss | | √ | |
|
||||
| SoftNMS | | √ | |
|
||||
| Sparse Convolution | | √ | |
|
||||
| Synchronized BatchNorm | | √ | |
|
||||
| ThreeInterpolate | | √ | |
|
||||
| ThreeNN | | √ | |
|
||||
| TINShift | | √ | √ |
|
||||
| UpFirDn2d | | √ | |
|
||||
| Voxelization | √ | √ | |
|
||||
| Device | CPU | CUDA | MLU | MPS |
|
||||
| ---------------------------- | --- | ---- | --- | --- |
|
||||
| ActiveRotatedFilter | √ | √ | | |
|
||||
| AssignScoreWithK | | √ | | |
|
||||
| BallQuery | | √ | | |
|
||||
| BBoxOverlaps | | √ | √ | √ |
|
||||
| BorderAlign | | √ | | |
|
||||
| BoxIouRotated | √ | √ | | |
|
||||
| CARAFE | | √ | | |
|
||||
| ChamferDistance | | √ | | |
|
||||
| CrissCrossAttention | | √ | | |
|
||||
| ContourExpand | √ | | | |
|
||||
| ConvexIoU | | √ | | |
|
||||
| CornerPool | | √ | | |
|
||||
| Correlation | | √ | | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | | |
|
||||
| Deformable RoIPool | | √ | | |
|
||||
| DiffIoURotated | | √ | | |
|
||||
| DynamicScatter | | √ | | |
|
||||
| FurthestPointSample | | √ | | |
|
||||
| FurthestPointSampleWithDist | | √ | | |
|
||||
| FusedBiasLeakyrelu | | √ | | |
|
||||
| GatherPoints | | √ | | |
|
||||
| GroupPoints | | √ | | |
|
||||
| Iou3d | | √ | | |
|
||||
| KNN | | √ | | |
|
||||
| MaskedConv | | √ | | |
|
||||
| MergeCells | | √ | | |
|
||||
| MinAreaPolygon | | √ | | |
|
||||
| ModulatedDeformConv2d | √ | √ | | |
|
||||
| MultiScaleDeformableAttn | | √ | | |
|
||||
| NMS | √ | √ | √ | |
|
||||
| NMSRotated | √ | √ | | |
|
||||
| PixelGroup | √ | | | |
|
||||
| PointsInBoxes | √ | √ | | |
|
||||
| PointsInPolygons | | √ | | |
|
||||
| PSAMask | √ | √ | √ | |
|
||||
| RotatedFeatureAlign | √ | √ | | |
|
||||
| RoIPointPool3d | | √ | | |
|
||||
| RoIPool | | √ | √ | |
|
||||
| RoIAlignRotated | √ | √ | √ | |
|
||||
| RiRoIAlignRotated | | √ | | |
|
||||
| RoIAlign | √ | √ | √ | |
|
||||
| RoIAwarePool3d | | √ | | |
|
||||
| SAConv2d | | √ | | |
|
||||
| SigmoidFocalLoss | | √ | √ | |
|
||||
| SoftmaxFocalLoss | | √ | | |
|
||||
| SoftNMS | | √ | | |
|
||||
| Sparse Convolution | | √ | | |
|
||||
| Synchronized BatchNorm | | √ | | |
|
||||
| ThreeInterpolate | | √ | | |
|
||||
| ThreeNN | | √ | | |
|
||||
| TINShift | | √ | √ | |
|
||||
| UpFirDn2d | | √ | | |
|
||||
| Voxelization | √ | √ | | |
|
||||
|
|
|
@ -13,11 +13,19 @@ This folder contains all non-python code for MMCV custom ops. Please follow the
|
|||
│ ├── pytorch_cpp_helper.hpp
|
||||
│ ├── pytorch_cuda_helper.hpp
|
||||
│ ├── pytorch_device_registry.hpp
|
||||
│ └── cuda
|
||||
│ ├── common_cuda_helper.hpp
|
||||
│ ├── parrots_cudawarpfunction.cuh
|
||||
│ ├── ...
|
||||
│ └── ops_cuda_kernel.cuh
|
||||
│ ├── cuda
|
||||
│ │ ├── common_cuda_helper.hpp
|
||||
│ │ ├── parrots_cudawarpfunction.cuh
|
||||
│ │ ├── ...
|
||||
│ │ └── ops_cuda_kernel.cuh
|
||||
│ ├── mps
|
||||
│ │ ├── MPSLibrary.h
|
||||
│ │ ├── ...
|
||||
│ │ └── MPSUtils.h
|
||||
│ ├── mlu
|
||||
│ │ └── ...
|
||||
│ └── utils
|
||||
│ │ └── ...
|
||||
├── onnxruntime
|
||||
│ ├── onnxruntime_register.h
|
||||
│ ├── onnxruntime_session_options_config_keys.h
|
||||
|
@ -41,9 +49,15 @@ This folder contains all non-python code for MMCV custom ops. Please follow the
|
|||
│ ├── cuda
|
||||
│ │ ├── ...
|
||||
│ │ └── ops_cuda.cu
|
||||
│ └── cpu
|
||||
│ ├── cpu
|
||||
│ │ ├── ...
|
||||
│ │ └── ops.cpp
|
||||
│ ├── mps
|
||||
│ │ ├── ...
|
||||
│ │ └── op_mps.mm
|
||||
│ └── mlu
|
||||
│ ├── ...
|
||||
│ └── ops.cpp
|
||||
│ └── op_mlu.cpp
|
||||
└── tensorrt
|
||||
├── trt_cuda_helper.cuh
|
||||
├── trt_plugin_helper.hpp
|
||||
|
@ -63,13 +77,18 @@ This folder contains all non-python code for MMCV custom ops. Please follow the
|
|||
|
||||
- `common`: This directory contains all tools and shared codes.
|
||||
- `cuda`: The cuda kernels which can be shared by all backends. **HIP** kernel is also here since they have similar syntax.
|
||||
- `onnxruntime`: **ONNX Runtime** support for custom ops.
|
||||
- `mps`: The tools used to support MPS ops. **NOTE** that MPS support is **experimental**.
|
||||
- `mlu`: The MLU kernels used to support [Cambricon](https://www.cambricon.com/) device.
|
||||
- `utils`: The kernels and utils of spconv.
|
||||
- `onnxruntime`: **ONNX Runtime** support for custom ops. Has been deprecated, please try the latest custom ops in [MMDeploy](https://github.com/open-mmlab/mmdeploy).
|
||||
- `cpu`: CPU implementation of supported ops.
|
||||
- `parrots`: **Parrots** is a deep learning framework for model training and inference. Parrots custom ops are placed in this directory.
|
||||
- `pytorch`: **PyTorch** custom ops are supported by binding C++ to Python with **pybind11**. The ops implementation and binding codes are placed in this directory.
|
||||
- `cuda`: This directory contains cuda kernel launchers, which feed memory pointers of tensor to the cuda kernel in `common/cuda`. The launchers provide c++ interface of cuda implementation of corresponding custom ops.
|
||||
- `cpu`: This directory contains cpu implementations of corresponding custom ops.
|
||||
- `tensorrt`: **TensorRT** support for custom ops.
|
||||
- `mlu`: This directory contains launchers of each MLU kernel.
|
||||
- `mps`: MPS ops implementation and launchers.
|
||||
- `tensorrt`: **TensorRT** support for custom ops. Has been deprecated, please try the latest custom ops in [MMDeploy](https://github.com/open-mmlab/mmdeploy).
|
||||
- `plugins`: This directory contains the implementation of the supported custom ops. Some ops might also use shared cuda kernel in `common/cuda`.
|
||||
|
||||
## How to add new PyTorch ops?
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
// This file is modified from:
|
||||
// https://github.com/pytorch/pytorch/blob/a85d1f0bcdd02cf18d3b0517337458cb51a18cdb/aten/src/ATen/mps/MPSDevice.h
|
||||
|
||||
#pragma once
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#ifdef __OBJC__
|
||||
#include <Foundation/Foundation.h>
|
||||
#include <Metal/Metal.h>
|
||||
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
typedef id<MTLDevice> MTLDevice_t;
|
||||
#else
|
||||
typedef void* MTLDevice;
|
||||
typedef void* MTLDevice_t;
|
||||
#endif
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace at {
|
||||
namespace mps {
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
// MPSDevice
|
||||
//
|
||||
// MPSDevice is a singleton class that returns the default device
|
||||
//-----------------------------------------------------------------
|
||||
|
||||
class TORCH_API MPSDevice {
|
||||
public:
|
||||
/**
|
||||
* MPSDevice should not be cloneable.
|
||||
*/
|
||||
MPSDevice(MPSDevice& other) = delete;
|
||||
/**
|
||||
* MPSDevice should not be assignable.
|
||||
*/
|
||||
void operator=(const MPSDevice&) = delete;
|
||||
/**
|
||||
* Gets single instance of the Device.
|
||||
*/
|
||||
static MPSDevice* getInstance();
|
||||
/**
|
||||
* Returns the single device.
|
||||
*/
|
||||
MTLDevice_t device() { return _mtl_device; }
|
||||
|
||||
~MPSDevice();
|
||||
|
||||
private:
|
||||
static MPSDevice* _device;
|
||||
MTLDevice_t _mtl_device;
|
||||
MPSDevice();
|
||||
};
|
||||
|
||||
TORCH_API bool is_available();
|
||||
|
||||
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
|
||||
|
||||
} // namespace mps
|
||||
} // namespace at
|
|
@ -0,0 +1,61 @@
|
|||
#ifndef _MPS_LIBRARY_H_
|
||||
#define _MPS_LIBRARY_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#ifdef __OBJC__
|
||||
#include <Foundation/Foundation.h>
|
||||
#include <Metal/Metal.h>
|
||||
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
|
||||
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
|
||||
typedef id<MTLLibrary> MTLLibrary_t;
|
||||
#else
|
||||
typedef void* MTLComputePipelineState;
|
||||
typedef void* MTLComputePipelineState_t;
|
||||
typedef void* MTLLibrary;
|
||||
typedef void* MTLLibrary_t;
|
||||
#endif
|
||||
|
||||
class MPSLibrary {
|
||||
public:
|
||||
// factory functions: each returns a newly allocated MPSLibrary owned by the caller
|
||||
static MPSLibrary* createFromUrl(const std::string& library_url);
|
||||
static MPSLibrary* createFromSource(const std::string& source);
|
||||
~MPSLibrary();
|
||||
|
||||
MTLLibrary_t library() { return _library; }
|
||||
|
||||
MTLComputePipelineState_t getComputePipelineState(
|
||||
const std::string& function_name);
|
||||
|
||||
private:
|
||||
MTLLibrary_t _library;
|
||||
std::unordered_map<std::string, MTLComputePipelineState_t> _pso_map;
|
||||
};
|
||||
|
||||
class MPSLibraryManager {
|
||||
public:
|
||||
// disable constructor for singleton
|
||||
MPSLibraryManager(const MPSLibraryManager&) = delete;
|
||||
MPSLibraryManager& operator=(const MPSLibraryManager&) = delete;
|
||||
MPSLibraryManager(MPSLibraryManager&&) = delete;
|
||||
MPSLibraryManager& operator=(MPSLibraryManager&&) = delete;
|
||||
|
||||
static MPSLibraryManager* getInstance();
|
||||
|
||||
bool hasLibrary(const std::string& name);
|
||||
|
||||
MPSLibrary* getLibrary(const std::string& library_url);
|
||||
|
||||
MPSLibrary* createLibraryFromSouce(const std::string& name,
|
||||
const std::string& sources);
|
||||
|
||||
~MPSLibraryManager();
|
||||
|
||||
private:
|
||||
MPSLibraryManager();
|
||||
std::unordered_map<std::string, std::unique_ptr<MPSLibrary>> _library_map;
|
||||
};
|
||||
#endif
|
|
@ -0,0 +1,110 @@
|
|||
#include "MPSLibrary.h"
|
||||
#include <c10/util/CallOnce.h>
|
||||
#include "MPSDevice.h"
|
||||
|
||||
static std::unique_ptr<MPSLibraryManager> mps_library_manager;
|
||||
static c10::once_flag mpsdev_init;
|
||||
|
||||
// Returns the process-wide MPSLibraryManager, constructing it on first use.
// The construction is funnelled through c10::call_once on `mpsdev_init`, so
// the manager is built a single time no matter how often this is called.
MPSLibraryManager* MPSLibraryManager::getInstance() {
  c10::call_once(mpsdev_init, [] { mps_library_manager.reset(new MPSLibraryManager()); });
  return mps_library_manager.get();
}
|
||||
|
||||
// Nothing to tear down by hand: _library_map holds every MPSLibrary through
// std::unique_ptr, so destroying the map destroys the libraries.
MPSLibraryManager::~MPSLibraryManager() = default;

// Starts with an empty library map.
MPSLibraryManager::MPSLibraryManager() = default;
|
||||
|
||||
// Reports whether a library has already been registered under `name`.
bool MPSLibraryManager::hasLibrary(const std::string& name) {
  return _library_map.count(name) != 0;
}
|
||||
|
||||
// Returns the library registered under `library_url`, loading it from that
// path via MPSLibrary::createFromUrl on the first request and caching it.
//
// The returned pointer is owned by the manager (stored in a unique_ptr inside
// _library_map); callers must not delete it.
MPSLibrary* MPSLibraryManager::getLibrary(const std::string& library_url) {
  // A single find() replaces the former find + operator[] pairs: fewer hash
  // lookups and no risk of operator[] default-inserting a null entry.
  auto it = _library_map.find(library_url);
  if (it == _library_map.end()) {
    it = _library_map
             .emplace(library_url,
                      std::unique_ptr<MPSLibrary>(MPSLibrary::createFromUrl(library_url)))
             .first;
  }
  return it->second.get();
}
|
||||
|
||||
// Compiles `source` into a new library and registers it under `name`.
//
// Returns the newly created library (owned by the manager), or nullptr when a
// library with the same name is already registered; in that case nothing is
// compiled and a message is logged. Callers that may hit the duplicate case
// should check hasLibrary()/getLibrary() first.
MPSLibrary* MPSLibraryManager::createLibraryFromSouce(const std::string& name,
                                                      const std::string& source) {
  if (_library_map.find(name) != _library_map.end()) {
    // Only build the NSString on the path that actually logs it, and decode it
    // explicitly as UTF-8 (the encoding-less stringWithCString: is deprecated).
    NSString* ns_name = [NSString stringWithUTF8String:name.c_str()];
    NSLog(@"Library %@ already exist.", ns_name);
    return nullptr;
  }

  auto inserted = _library_map.emplace(
      name, std::unique_ptr<MPSLibrary>(MPSLibrary::createFromSource(source)));
  return inserted.first->second.get();
}
|
||||
|
||||
// Loads a precompiled Metal library from the file at `library_url`.
//
// Returns a heap-allocated MPSLibrary that the caller owns. Raises a c10::Error
// (via TORCH_CHECK) when the path is not valid UTF-8 or Metal cannot load the
// library, instead of terminating the whole process with exit(1).
MPSLibrary* MPSLibrary::createFromUrl(const std::string& library_url) {
  // Held in a unique_ptr so the object is freed if a check below throws.
  // `new MPSLibrary()` value-initialises the members, so _library starts nil.
  std::unique_ptr<MPSLibrary> library(new MPSLibrary());
  @autoreleasepool {
    NSError* error = nil;

    // Decode explicitly as UTF-8; the encoding-less stringWithCString: is
    // deprecated and depends on the default C-string encoding.
    NSString* url_str = [NSString stringWithUTF8String:library_url.c_str()];
    TORCH_CHECK(url_str != nil, "Metal library path is not valid UTF-8: ", library_url);

    NSURL* metal_url = [NSURL fileURLWithPath:url_str];
    library->_library =
        [at::mps::MPSDevice::getInstance()->device() newLibraryWithURL:metal_url error:&error];
    if (library->_library == nil) {
      // Copy the description out while the autorelease pool is still alive.
      const char* reason = error ? [[error localizedDescription] UTF8String] : "unknown error";
      TORCH_CHECK(false, "Failed to load Metal library from '", library_url,
                  "': ", reason ? reason : "unknown error");
    }
  }

  return library.release();
}
|
||||
|
||||
// Compiles Metal shader source text into a library at runtime.
//
// Returns a heap-allocated MPSLibrary that the caller owns. Raises a c10::Error
// (via TORCH_CHECK) carrying the compiler diagnostics when the source is not
// valid UTF-8 or fails to compile, instead of terminating the process with
// exit(1).
MPSLibrary* MPSLibrary::createFromSource(const std::string& sources) {
  // Held in a unique_ptr so the object is freed if a check below throws.
  // `new MPSLibrary()` value-initialises the members, so _library starts nil.
  std::unique_ptr<MPSLibrary> library(new MPSLibrary());
  @autoreleasepool {
    NSError* error = nil;

    // Decode explicitly as UTF-8; the encoding-less stringWithCString: is
    // deprecated and depends on the default C-string encoding.
    NSString* code_str = [NSString stringWithUTF8String:sources.c_str()];
    TORCH_CHECK(code_str != nil, "Metal shader source is not valid UTF-8.");

    library->_library =
        [at::mps::MPSDevice::getInstance()->device() newLibraryWithSource:code_str
                                                                  options:nil
                                                                    error:&error];
    if (library->_library == nil) {
      // Copy the description out while the autorelease pool is still alive.
      const char* reason = error ? [[error localizedDescription] UTF8String] : "unknown error";
      TORCH_CHECK(false, "Failed to compile Metal library from source: ",
                  reason ? reason : "unknown error");
    }
  }

  return library.release();
}
|
||||
|
||||
// Releases the Metal objects this wrapper owns.
//
// This file uses manual retain/release. Every cached pipeline state came from
// a `new...` call in getComputePipelineState and is therefore owned here, just
// like _library; without releasing them each MPSLibrary leaked its pipelines.
MPSLibrary::~MPSLibrary() {
  for (auto& entry : _pso_map) {
    [entry.second release];  // messaging nil is a no-op, so nil entries are safe
  }
  _pso_map.clear();

  [_library release];
  _library = nil;
}
|
||||
|
||||
// Returns the compute pipeline state for the kernel named `function_name`,
// building it from this library on first use and caching it in _pso_map.
//
// Raises a c10::Error (via TORCH_CHECK) when the function does not exist in the
// library or the pipeline cannot be created. Previously the missing-function
// path logged a still-nil NSError and called exit(1), and a failed pipeline
// creation was never checked, so a nil pipeline state was cached and returned.
MTLComputePipelineState_t MPSLibrary::getComputePipelineState(const std::string& function_name) {
  auto cached = _pso_map.find(function_name);
  if (cached != _pso_map.end()) {
    return cached->second;
  }

  MTLComputePipelineState_t pso = nil;
  @autoreleasepool {
    NSError* error = nil;

    // Look the kernel up by name. Decode explicitly as UTF-8; the
    // encoding-less stringWithCString: is deprecated.
    NSString* function_name_str = [NSString stringWithUTF8String:function_name.c_str()];
    id<MTLFunction> func = [_library newFunctionWithName:function_name_str];
    TORCH_CHECK(func != nil, "Metal function '", function_name,
                "' was not found in the library.");

    // Build the pipeline from the function.
    pso = [at::mps::MPSDevice::getInstance()->device() newComputePipelineStateWithFunction:func
                                                                                   error:&error];
    // `func` came from a `new...` call, so this code owns it under manual
    // retain/release and must drop its reference once the pipeline exists.
    [func release];

    if (pso == nil) {
      // Copy the description out while the autorelease pool is still alive.
      const char* reason = error ? [[error localizedDescription] UTF8String] : "unknown error";
      TORCH_CHECK(false, "Failed to create pipeline state object for '", function_name,
                  "': ", reason ? reason : "unknown error");
    }
  }

  _pso_map.emplace(function_name, pso);
  return pso;
}
|
|
@ -0,0 +1,132 @@
|
|||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
// This file is modified from:
|
||||
// https://github.com/pytorch/pytorch/blob/a85d1f0bcdd02cf18d3b0517337458cb51a18cdb/aten/src/ATen/mps/MPSStream.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
|
||||
#include <c10/core/DeviceGuard.h>
|
||||
#include <c10/core/Stream.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include "MPSDevice.h"
|
||||
|
||||
#ifdef __OBJC__
|
||||
#include <Foundation/Foundation.h>
|
||||
#include <Metal/Metal.h>
|
||||
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
|
||||
typedef id<MTLCommandQueue> MTLCommandQueue_t;
|
||||
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
|
||||
typedef id<MTLSharedEvent> MTLSharedEvent_t;
|
||||
typedef id<MTLDevice> MTLDevice_t;
|
||||
#else
|
||||
typedef void* MTLCommandQueue_t;
|
||||
typedef void* MTLCommandQueue;
|
||||
typedef void* MTLCommandBuffer_t;
|
||||
typedef void* MTLCommandBuffer;
|
||||
typedef void* MTLSharedEvent_t;
|
||||
typedef void* dispatch_queue_t;
|
||||
typedef void* MTLDevice_t;
|
||||
#define nil NULL;
|
||||
#endif
|
||||
|
||||
namespace at {
|
||||
namespace mps {
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
// MPSStream
|
||||
//-----------------------------------------------------------------
|
||||
|
||||
class TORCH_API MPSStream {
|
||||
public:
|
||||
enum Unchecked { UNCHECKED };
|
||||
/// Construct a MPSStream from a Stream. This construction is checked,
|
||||
/// and will raise an error if the Stream is not, in fact, a MPS stream.
|
||||
explicit MPSStream(Stream stream);
|
||||
|
||||
~MPSStream();
|
||||
MTLCommandQueue_t commandQueue() const { return _commandQueue; };
|
||||
dispatch_queue_t queue() const { return _serialQueue; }
|
||||
|
||||
MTLCommandBuffer_t commandBuffer();
|
||||
void commit(bool flush);
|
||||
void commitAndWait();
|
||||
void synchronize();
|
||||
|
||||
void flush();
|
||||
|
||||
/// Get the MPS device index that this stream is associated with.
|
||||
c10::DeviceIndex device_index() const { return _stream.device_index(); }
|
||||
|
||||
MTLCommandQueue_t stream() const { return _commandQueue; };
|
||||
|
||||
MTLDevice_t device() const { return [_commandQueue device]; }
|
||||
|
||||
/// Explicit conversion to Stream.
|
||||
Stream unwrap() const { return _stream; }
|
||||
|
||||
private:
|
||||
Stream _stream;
|
||||
MTLCommandQueue_t _commandQueue = nil;
|
||||
MTLCommandBuffer_t _commandBuffer = nil;
|
||||
void _flush(bool commitAndWait) const;
|
||||
|
||||
dispatch_queue_t _serialQueue = nullptr;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the current MPS stream
|
||||
*/
|
||||
TORCH_API MPSStream* getCurrentMPSStream();
|
||||
|
||||
/**
|
||||
* Get the default MPS stream
|
||||
*/
|
||||
TORCH_API MPSStream* getDefaultMPSStream();
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
// MPSStreamImpl
|
||||
//-----------------------------------------------------------------
|
||||
|
||||
class TORCH_API MPSStreamImpl {
|
||||
public:
|
||||
/**
|
||||
* Gets single instance of the MPSStream.
|
||||
*/
|
||||
static MPSStream* getInstance();
|
||||
|
||||
private:
|
||||
static MPSStream* _stream;
|
||||
MPSStreamImpl();
|
||||
};
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
// MPSEvent
|
||||
//-----------------------------------------------------------------
|
||||
|
||||
struct TORCH_API MPSEvent {
|
||||
MPSEvent();
|
||||
// MPSEvent(id<MTLDevice> device);
|
||||
|
||||
~MPSEvent();
|
||||
MTLSharedEvent_t event() const { return _event; }
|
||||
|
||||
void recordEvent(MPSStream* stream);
|
||||
void waitForEvent(MPSStream* queue); // waits on the cpu
|
||||
bool queryEvent();
|
||||
uint64_t getCurrentValue() { return _currentValue; }
|
||||
void setCurrentValue(uint64_t currValue) { _currentValue = currValue; }
|
||||
|
||||
private:
|
||||
bool _isRecorded = false;
|
||||
uint64_t _currentValue = 0;
|
||||
MTLSharedEvent_t _event;
|
||||
};
|
||||
|
||||
typedef MPSEvent* mpsEvent_t;
|
||||
|
||||
} // namespace mps
|
||||
} // namespace at
|
|
@ -0,0 +1,51 @@
|
|||
#ifndef _MPS_UTILS_H_
|
||||
#define _MPS_UTILS_H_
|
||||
#include <torch/extension.h>
|
||||
#ifdef __OBJC__
|
||||
#include <Foundation/Foundation.h>
|
||||
#include <Metal/Metal.h>
|
||||
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
|
||||
typedef id<MTLBuffer> MTLBuffer_t;
|
||||
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
|
||||
#else
|
||||
typedef void* MTLBuffer;
|
||||
typedef void* MTLBuffer_t;
|
||||
typedef void* MTLComputeCommandEncoder;
|
||||
typedef void* MTLComputeCommandEncoder_t;
|
||||
#endif
|
||||
|
||||
// utils
|
||||
// Reinterprets the tensor's storage data pointer as a Metal buffer handle.
// NOTE(review): this presumes the tensor lives on the MPS device, where the
// storage pointer is expected to be the id<MTLBuffer> itself; confirm with callers.
static inline MTLBuffer_t getMTLBufferStorage(const at::Tensor& tensor) {
  auto raw_storage = tensor.storage().data();
  return __builtin_bit_cast(MTLBuffer_t, raw_storage);
}
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<!std::is_same<std::decay_t<T>, at::Tensor>::value, bool> = true>
|
||||
void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t);
|
||||
|
||||
// Binds a tensor as the buffer argument at `index` of the compute encoder.
//
// The buffer is bound at the tensor's own byte offset into its storage
// (storage_offset * element_size). Binding at offset 0, as before, made a view
// or slice that does not start at the beginning of its storage read and write
// the wrong memory.
// NOTE(review): strides are still not communicated to the kernel, so callers
// presumably must pass contiguous tensors; verify at the call sites.
template <typename T,
          std::enable_if_t<std::is_same<std::decay_t<T>, at::Tensor>::value, bool> = true>
void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t) {
  const NSUInteger byte_offset = static_cast<NSUInteger>(t.storage_offset() * t.element_size());
  [encoder setBuffer:getMTLBufferStorage(t) offset:byte_offset atIndex:index];
}
|
||||
|
||||
// Binds any non-tensor value as the argument at `index` by copying its raw
// bytes into the encoder; sizeof(t) bytes are passed starting at the address
// of `t`.
template <typename T, std::enable_if_t<!std::is_same<std::decay_t<T>, at::Tensor>::value, bool>>
void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t) {
  const void* value_bytes = &t;
  const NSUInteger value_size = sizeof(t);
  [encoder setBytes:value_bytes length:value_size atIndex:index];
}
|
||||
|
||||
// Recursion terminator: no arguments left to bind.
inline void setMTLArgsImpl(MTLComputeCommandEncoder_t, int) {}

// Binds `t` at `index`, then binds each remaining argument at the following
// consecutive indices, so argument order matches the kernel's parameter order.
template <typename T, typename... Args>
void setMTLArgsImpl(MTLComputeCommandEncoder_t encoder, int index, T&& t, Args&&... args) {
  setMTLArg(encoder, index, std::forward<T>(t));
  const int next_index = index + 1;
  setMTLArgsImpl(encoder, next_index, std::forward<Args>(args)...);
}
|
||||
|
||||
// Selects `pso` as the encoder's pipeline and binds `args` as kernel arguments
// 0, 1, 2, ... in the order given.
template <typename... Args>
void setMTLArgs(MTLComputeCommandEncoder_t encoder, MTLComputePipelineState_t pso, Args&&... args) {
  constexpr int kFirstArgIndex = 0;
  [encoder setComputePipelineState:pso];
  setMTLArgsImpl(encoder, kFirstArgIndex, std::forward<Args>(args)...);
}
|
||||
#endif
|
|
@ -0,0 +1,99 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
#include "pytorch_device_registry.hpp"
|
||||
|
||||
#include "MPSLibrary.h"
|
||||
#include "MPSStream.h"
|
||||
#include "MPSUtils.h"
|
||||
|
||||
using at::Tensor;
|
||||
|
||||
const static std::string kSourceCode = R"(
|
||||
#include <metal_math>
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
kernel void bbox_overlap_mps_kernel(constant const float4* bboxes1,
|
||||
constant const float4* bboxes2,
|
||||
device float* ious,
|
||||
constant int& num_bbox1,
|
||||
constant int& num_bbox2,
|
||||
constant int& mode,
|
||||
constant bool& aligned,
|
||||
constant int& offset,
|
||||
uint index [[thread_position_in_grid]])
|
||||
{
|
||||
int base1 = index;
|
||||
int base2 = index;
|
||||
if(!aligned){
|
||||
base1 = index / num_bbox2;
|
||||
base2 = index % num_bbox2;
|
||||
}
|
||||
|
||||
const float f_offset = float(offset);
|
||||
|
||||
const float4 b1 = bboxes1[base1];
|
||||
const float b1_area = (b1[2]-b1[0]+f_offset)*(b1[3]-b1[1]+f_offset);
|
||||
|
||||
const float4 b2 = bboxes2[base2];
|
||||
const float b2_area = (b2[2]-b2[0]+f_offset)*(b2[3]-b2[1]+f_offset);
|
||||
|
||||
const float2 left_top = fmax(b1.xy, b2.xy);
|
||||
const float2 right_bottom = fmin(b1.zw, b2.zw);
|
||||
const float2 wh = fmax(right_bottom - left_top + f_offset, 0.0f);
|
||||
const float interS = wh.x * wh.y;
|
||||
|
||||
const float baseS =
|
||||
fmax(mode == 0 ? b1_area + b2_area - interS : b1_area, f_offset);
|
||||
ious[index] = interS / baseS;
|
||||
}
|
||||
)";
|
||||
|
||||
// Encodes and commits the `bbox_overlap_mps_kernel` Metal kernel on the current
// MPS stream, launching one GPU thread per element of `ious`.
//
// bboxes1 / bboxes2: box tensors; size(0) is passed to the kernel as the box
//     counts. The kernel reads each row as a float4.
//     NOTE(review): this presumes contiguous float32 tensors with 4 columns;
//     nothing here enforces that; confirm against the Python wrapper.
// ious:    output tensor, written in place.
// mode:    forwarded to the kernel, which uses the union of the two areas as
//          the denominator when mode == 0 and the first box's area otherwise.
// aligned: when true, output i pairs bboxes1[i] with bboxes2[i]; otherwise
//          output i pairs bboxes1[i / num_bbox2] with bboxes2[i % num_bbox2].
// offset:  added to each box width/height inside the kernel.
void BBoxOverlapsMPSKernelLauncher(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
                                   const int mode, const bool aligned, const int offset) {
  // Nothing to compute for an empty output. Bail out before encoding, because
  // the threadgroup width below is clamped to the element count and would
  // become zero.
  const int num_elements = ious.numel();
  if (num_elements <= 0) {
    return;
  }

  // get stream
  auto stream = at::mps::getCurrentMPSStream();

  // Fetch the compiled kernel library, compiling it from source on first use.
  auto library_manager = MPSLibraryManager::getInstance();
  const static std::string kLibraryName = "bbox_overlap";
  MPSLibrary* library = library_manager->hasLibrary(kLibraryName)
                            ? library_manager->getLibrary(kLibraryName)
                            : library_manager->createLibraryFromSouce(kLibraryName, kSourceCode);
  // createLibraryFromSouce returns nullptr when the name is already taken.
  TORCH_CHECK(library != nullptr, "Failed to obtain the MPS library '", kLibraryName, "'.");
  auto func_pso = library->getComputePipelineState("bbox_overlap_mps_kernel");

  // create command buffer and encoder
  MTLCommandBuffer_t command_buffer = stream->commandBuffer();
  MTLComputeCommandEncoder_t compute_encoder = [command_buffer computeCommandEncoder];

  // set pso and buffer; the argument order must match the kernel signature
  int num_bbox1 = bboxes1.size(0);
  int num_bbox2 = bboxes2.size(0);
  setMTLArgs(compute_encoder, func_pso, bboxes1, bboxes2, ious, num_bbox1, num_bbox2, mode, aligned,
             offset);

  // set grid size: a 1-D grid with one thread per output element, with the
  // threadgroup width capped at the element count
  const NSUInteger total_threads = static_cast<NSUInteger>(num_elements);
  MTLSize grid_size = MTLSizeMake(total_threads, 1, 1);
  NSUInteger thread_group_size_x = func_pso.maxTotalThreadsPerThreadgroup;
  if (thread_group_size_x > total_threads) {
    thread_group_size_x = total_threads;
  }
  MTLSize thread_group_size = MTLSizeMake(thread_group_size_x, 1, 1);

  // encoding
  [compute_encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size];
  [compute_encoder endEncoding];

  // commit, not sure if flush is required
  stream->commit(false);
}
|
||||
|
||||
// MPS entry point for bbox_overlaps: forwards every argument unchanged to the
// Metal kernel launcher.
void bbox_overlaps_mps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode,
                       const bool aligned, const int offset) {
  return BBoxOverlapsMPSKernelLauncher(bboxes1, bboxes2, ious, mode, aligned, offset);
}
|
||||
|
||||
void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode,
|
||||
const bool aligned, const int offset);
|
||||
REGISTER_DEVICE_IMPL(bbox_overlaps_impl, MPS, bbox_overlaps_mps);
|
24
setup.py
24
setup.py
|
@ -305,6 +305,30 @@ def get_extensions():
|
|||
extension = MLUExtension
|
||||
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
|
||||
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
|
||||
elif (hasattr(torch.backends, 'mps')
|
||||
and torch.backends.mps.is_available()) or os.getenv(
|
||||
'FORCE_MPS', '0') == '1':
|
||||
# objc compiler support
|
||||
from distutils.unixccompiler import UnixCCompiler
|
||||
if '.mm' not in UnixCCompiler.src_extensions:
|
||||
UnixCCompiler.src_extensions.append('.mm')
|
||||
UnixCCompiler.language_map['.mm'] = 'objc'
|
||||
|
||||
define_macros += [('MMCV_WITH_MPS', None)]
|
||||
extra_compile_args = {}
|
||||
extra_compile_args['cxx'] = ['-Wall', '-std=c++17']
|
||||
extra_compile_args['cxx'] += [
|
||||
'-framework', 'Metal', '-framework', 'Foundation'
|
||||
]
|
||||
extra_compile_args['cxx'] += ['-ObjC++']
|
||||
# src
|
||||
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
|
||||
glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
|
||||
glob.glob('./mmcv/ops/csrc/common/mps/*.mm') + \
|
||||
glob.glob('./mmcv/ops/csrc/pytorch/mps/*.mm')
|
||||
extension = CppExtension
|
||||
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
|
||||
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mps'))
|
||||
else:
|
||||
print(f'Compiling {ext_name} only with CPU')
|
||||
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
|
||||
|
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
|
||||
|
||||
|
||||
class TestBBox:
|
||||
|
@ -43,7 +43,11 @@ class TestBBox:
|
|||
pytest.param(
|
||||
'mlu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_MLU_AVAILABLE, reason='requires MLU support'))
|
||||
not IS_MLU_AVAILABLE, reason='requires MLU support')),
|
||||
pytest.param(
|
||||
'mps',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_MPS_AVAILABLE, reason='requires MPS support'))
|
||||
])
|
||||
def test_bbox_overlaps_float(self, device):
|
||||
self._test_bbox_overlaps(device, dtype=torch.float)
|
||||
|
|
Loading…
Reference in New Issue