From 22fadceecd525942f4ae0cbfac97d14bc6858b11 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Fri, 22 Jul 2022 19:30:01 +0800 Subject: [PATCH] [Feature] Add MPS bbox overlap (#2123) * add mps bbox overlap * format * update document and manifest * update readme --- MANIFEST.in | 3 +- docs/en/understand_mmcv/ops.md | 110 +++++++-------- docs/zh_cn/understand_mmcv/ops.md | 110 +++++++-------- mmcv/ops/csrc/README.md | 37 +++-- mmcv/ops/csrc/common/mps/MPSDevice.h | 64 +++++++++ mmcv/ops/csrc/common/mps/MPSLibrary.h | 61 ++++++++ mmcv/ops/csrc/common/mps/MPSLibrary.mm | 110 +++++++++++++++ mmcv/ops/csrc/common/mps/MPSStream.h | 132 ++++++++++++++++++ mmcv/ops/csrc/common/mps/MPSUtils.h | 51 +++++++ .../ops/csrc/pytorch/mps/bbox_overlaps_mps.mm | 99 +++++++++++++ setup.py | 24 ++++ tests/test_ops/test_bbox.py | 8 +- 12 files changed, 687 insertions(+), 122 deletions(-) create mode 100644 mmcv/ops/csrc/common/mps/MPSDevice.h create mode 100644 mmcv/ops/csrc/common/mps/MPSLibrary.h create mode 100644 mmcv/ops/csrc/common/mps/MPSLibrary.mm create mode 100644 mmcv/ops/csrc/common/mps/MPSStream.h create mode 100644 mmcv/ops/csrc/common/mps/MPSUtils.h create mode 100644 mmcv/ops/csrc/pytorch/mps/bbox_overlaps_mps.mm diff --git a/MANIFEST.in b/MANIFEST.in index 23b6256e5..5de8494b5 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,4 +3,5 @@ include mmcv/model_zoo/open_mmlab.json mmcv/model_zoo/deprecated.json mmcv/model include mmcv/ops/csrc/common/cuda/*.cuh mmcv/ops/csrc/common/cuda/*.hpp mmcv/ops/csrc/common/*.hpp include mmcv/ops/csrc/pytorch/*.cpp mmcv/ops/csrc/pytorch/cuda/*.cu mmcv/ops/csrc/pytorch/cuda/*.cpp mmcv/ops/csrc/pytorch/cpu/*.cpp include mmcv/ops/csrc/parrots/*.h mmcv/ops/csrc/parrots/*.cpp -recursive-include mmcv/ops/csrc/ *.h *.hpp *.cpp *.cuh *.cu +include mmcv/ops/csrc/pytorch/mps/*.mm mmcv/ops/csrc/common/mps/*.h mmcv/ops/csrc/common/mps/*.mm +recursive-include mmcv/ops/csrc/ *.h *.hpp *.cpp *.cuh *.cu *.mm diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index aa44b5937..3565854d5 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -2,58 +2,58 @@ We implement common ops used in detection, segmentation, etc. -| Device | CPU | CUDA | MLU | -| ---------------------------- | --- | ---- | --- | -| ActiveRotatedFilter | √ | √ | | -| AssignScoreWithK | | √ | | -| BallQuery | | √ | | -| BBoxOverlaps | | √ | √ | -| BorderAlign | | √ | | -| BoxIouRotated | √ | √ | | -| CARAFE | | √ | | -| ChamferDistance | | √ | | -| CrissCrossAttention | | √ | | -| ContourExpand | √ | | | -| ConvexIoU | | √ | | -| CornerPool | | √ | | -| Correlation | | √ | | -| Deformable Convolution v1/v2 | √ | √ | | -| Deformable RoIPool | | √ | | -| DiffIoURotated | | √ | | -| DynamicScatter | | √ | | -| FurthestPointSample | | √ | | -| FurthestPointSampleWithDist | | √ | | -| FusedBiasLeakyrelu | | √ | | -| GatherPoints | | √ | | -| GroupPoints | | √ | | -| Iou3d | | √ | | -| KNN | | √ | | -| MaskedConv | | √ | | -| MergeCells | | √ | | -| MinAreaPolygon | | √ | | -| ModulatedDeformConv2d | √ | √ | | -| MultiScaleDeformableAttn | | √ | | -| NMS | √ | √ | √ | -| NMSRotated | √ | √ | | -| PixelGroup | √ | | | -| PointsInBoxes | √ | √ | | -| PointsInPolygons | | √ | | -| PSAMask | √ | √ | √ | -| RotatedFeatureAlign | √ | √ | | -| RoIPointPool3d | | √ | | -| RoIPool | | √ | √ | -| RoIAlignRotated | √ | √ | √ | -| RiRoIAlignRotated | | √ | | -| RoIAlign | √ | √ | √ | -| RoIAwarePool3d | | √ | | -| SAConv2d | | √ | | -| SigmoidFocalLoss | | √ | √ | -| SoftmaxFocalLoss | | √ | | -| SoftNMS | | √ | | -| Sparse Convolution | | √ | | -| Synchronized BatchNorm | | √ | | -| ThreeInterpolate | | √ | | -| ThreeNN | | √ | | -| TINShift | | √ | √ | -| UpFirDn2d | | √ | | -| Voxelization | √ | √ | | +| Device | CPU | CUDA | MLU | MPS | +| ---------------------------- | --- | ---- | --- | --- | +| ActiveRotatedFilter | √ | √ | | | +| AssignScoreWithK | | √ | | | +| BallQuery | | √ | | | +| BBoxOverlaps | | √ | √ | √ | +| BorderAlign | | √ | | | +| BoxIouRotated | √ | √ | | | +| CARAFE | | √ | | | +| ChamferDistance | | √ | | | +| CrissCrossAttention | | √ | | | +| ContourExpand | √ | | | | +| ConvexIoU | | √ | | | +| CornerPool | | √ | | | +| Correlation | | √ | | | +| Deformable Convolution v1/v2 | √ | √ | | | +| Deformable RoIPool | | √ | | | +| DiffIoURotated | | √ | | | +| DynamicScatter | | √ | | | +| FurthestPointSample | | √ | | | +| FurthestPointSampleWithDist | | √ | | | +| FusedBiasLeakyrelu | | √ | | | +| GatherPoints | | √ | | | +| GroupPoints | | √ | | | +| Iou3d | | √ | | | +| KNN | | √ | | | +| MaskedConv | | √ | | | +| MergeCells | | √ | | | +| MinAreaPolygon | | √ | | | +| ModulatedDeformConv2d | √ | √ | | | +| MultiScaleDeformableAttn | | √ | | | +| NMS | √ | √ | √ | | +| NMSRotated | √ | √ | | | +| PixelGroup | √ | | | | +| PointsInBoxes | √ | √ | | | +| PointsInPolygons | | √ | | | +| PSAMask | √ | √ | √ | | +| RotatedFeatureAlign | √ | √ | | | +| RoIPointPool3d | | √ | | | +| RoIPool | | √ | √ | | +| RoIAlignRotated | √ | √ | √ | | +| RiRoIAlignRotated | | √ | | | +| RoIAlign | √ | √ | √ | | +| RoIAwarePool3d | | √ | | | +| SAConv2d | | √ | | | +| SigmoidFocalLoss | | √ | √ | | +| SoftmaxFocalLoss | | √ | | | +| SoftNMS | | √ | | | +| Sparse Convolution | | √ | | | +| Synchronized BatchNorm | | √ | | | +| ThreeInterpolate | | √ | | | +| ThreeNN | | √ | | | +| TINShift | | √ | √ | | +| UpFirDn2d | | √ | | | +| Voxelization | √ | √ | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 9a83fcbdc..94a77218d 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -2,58 +2,58 @@ MMCV 提供了检测、分割等任务中常用的算子 -| Device | CPU | CUDA | MLU | -| ---------------------------- | --- | ---- | --- | -| ActiveRotatedFilter | √ | √ | | -| AssignScoreWithK | | √ | | -| BallQuery | | √ | | -| BBoxOverlaps | | √ | √ | -| BorderAlign | | √ | | -| BoxIouRotated | √ | √ | | -| CARAFE | | √ | | -| ChamferDistance | | √ | | -| CrissCrossAttention | | √ | | -| ContourExpand | √ | | | -| ConvexIoU | | √ | | -| CornerPool | | √ | | -| Correlation | | √ | | -| Deformable Convolution v1/v2 | √ | √ | | -| Deformable RoIPool | | √ | | -| DiffIoURotated | | √ | | -| DynamicScatter | | √ | | -| FurthestPointSample | | √ | | -| FurthestPointSampleWithDist | | √ | | -| FusedBiasLeakyrelu | | √ | | -| GatherPoints | | √ | | -| GroupPoints | | √ | | -| Iou3d | | √ | | -| KNN | | √ | | -| MaskedConv | | √ | | -| MergeCells | | √ | | -| MinAreaPolygon | | √ | | -| ModulatedDeformConv2d | √ | √ | | -| MultiScaleDeformableAttn | | √ | | -| NMS | √ | √ | √ | -| NMSRotated | √ | √ | | -| PixelGroup | √ | | | -| PointsInBoxes | √ | √ | | -| PointsInPolygons | | √ | | -| PSAMask | √ | √ | √ | -| RotatedFeatureAlign | √ | √ | | -| RoIPointPool3d | | √ | | -| RoIPool | | √ | √ | -| RoIAlignRotated | √ | √ | √ | -| RiRoIAlignRotated | | √ | | -| RoIAlign | √ | √ | √ | -| RoIAwarePool3d | | √ | | -| SAConv2d | | √ | | -| SigmoidFocalLoss | | √ | √ | -| SoftmaxFocalLoss | | √ | | -| SoftNMS | | √ | | -| Sparse Convolution | | √ | | -| Synchronized BatchNorm | | √ | | -| ThreeInterpolate | | √ | | -| ThreeNN | | √ | | -| TINShift | | √ | √ | -| UpFirDn2d | | √ | | -| Voxelization | √ | √ | | +| Device | CPU | CUDA | MLU | MPS | +| ---------------------------- | --- | ---- | --- | --- | +| ActiveRotatedFilter | √ | √ | | | +| AssignScoreWithK | | √ | | | +| BallQuery | | √ | | | +| BBoxOverlaps | | √ | √ | √ | +| BorderAlign | | √ | | | +| BoxIouRotated | √ | √ | | | +| CARAFE | | √ | | | +| ChamferDistance | | √ | | | +| CrissCrossAttention | | √ | | | +| ContourExpand | √ | | | | +| ConvexIoU | | √ | | | +| CornerPool | | √ | | | +| Correlation | | √ | | | +| Deformable Convolution v1/v2 | √ | √ | | | +| Deformable RoIPool | | √ | | | +| DiffIoURotated | | √ | | | +| DynamicScatter | | √ | | | +| FurthestPointSample | | √ | | | +| FurthestPointSampleWithDist | | √ | | | +| FusedBiasLeakyrelu | | √ | | | +| GatherPoints | | √ | | | +| GroupPoints | | √ | | | +| Iou3d | | √ | | | +| KNN | | √ | | | +| MaskedConv | | √ | | | +| MergeCells | | √ | | | +| MinAreaPolygon | | √ | | | +| ModulatedDeformConv2d | √ | √ | | | +| MultiScaleDeformableAttn | | √ | | | +| NMS | √ | √ | √ | | +| NMSRotated | √ | √ | | | +| PixelGroup | √ | | | | +| PointsInBoxes | √ | √ | | | +| PointsInPolygons | | √ | | | +| PSAMask | √ | √ | √ | | +| RotatedFeatureAlign | √ | √ | | | +| RoIPointPool3d | | √ | | | +| RoIPool | | √ | √ | | +| RoIAlignRotated | √ | √ | √ | | +| RiRoIAlignRotated | | √ | | | +| RoIAlign | √ | √ | √ | | +| RoIAwarePool3d | | √ | | | +| SAConv2d | | √ | | | +| SigmoidFocalLoss | | √ | √ | | +| SoftmaxFocalLoss | | √ | | | +| SoftNMS | | √ | | | +| Sparse Convolution | | √ | | | +| Synchronized BatchNorm | | √ | | | +| ThreeInterpolate | | √ | | | +| ThreeNN | | √ | | | +| TINShift | | √ | √ | | +| UpFirDn2d | | √ | | | +| Voxelization | √ | √ | | | diff --git a/mmcv/ops/csrc/README.md b/mmcv/ops/csrc/README.md index 317b8fb3d..dbc82b534 100644 --- a/mmcv/ops/csrc/README.md +++ b/mmcv/ops/csrc/README.md @@ -13,11 +13,19 @@ This folder contains all non-python code for MMCV custom ops. Please follow the │ ├── pytorch_cpp_helper.hpp │ ├── pytorch_cuda_helper.hpp │ ├── pytorch_device_registry.hpp -│   └── cuda -│   ├── common_cuda_helper.hpp -│   ├── parrots_cudawarpfunction.cuh -│   ├── ... -│   └── ops_cuda_kernel.cuh +│   ├── cuda +│   │ ├── common_cuda_helper.hpp +│   │ ├── parrots_cudawarpfunction.cuh +│   │ ├── ... +│   │ └── ops_cuda_kernel.cuh +|   ├── mps +│   │ ├── MPSLibrary.h +│   │ ├── ... +│   │ └── MPSUtils.h +|   ├── mlu +│   │ └── ... +|   └── utils +│   │ └── ... ├── onnxruntime │   ├── onnxruntime_register.h │   ├── onnxruntime_session_options_config_keys.h @@ -41,9 +49,15 @@ This folder contains all non-python code for MMCV custom ops. Please follow the │   ├── cuda │   │   ├── ... │   │   └── ops_cuda.cu -│   └── cpu +│   ├── cpu +│   │   ├── ... +│   │   └── ops.cpp +│   ├── mps +│   │   ├── ... +│   |   └── op_mps.mm +│   └── mlu │      ├── ... -│      └── ops.cpp +│      └── op_mlu.cpp └── tensorrt ├── trt_cuda_helper.cuh ├── trt_plugin_helper.hpp @@ -63,13 +77,18 @@ This folder contains all non-python code for MMCV custom ops. Please follow the - `common`: This directory contains all tools and shared codes. - `cuda`: The cuda kernels which can be shared by all backends. **HIP** kernel is also here since they have similar syntax. -- `onnxruntime`: **ONNX Runtime** support for custom ops. + - `mps`: The tools used to support MPS ops. **NOTE** that MPS support is **experimental**. + - `mlu`: The MLU kernels used to support [Cambricon](https://www.cambricon.com/) device. + - `utils`: The kernels and utils of spconv. +- `onnxruntime`: **ONNX Runtime** support for custom ops. Has been deprecated, please try the latest custom ops in [MMDeploy](https://github.com/open-mmlab/mmdeploy). - `cpu`: CPU implementation of supported ops. - `parrots`: **Parrots** is a deep learning frame for model training and inference. Parrots custom ops are placed in this directory. - `pytorch`: **PyTorch** custom ops are supported by binding C++ to Python with **pybind11**. The ops implementation and binding codes are placed in this directory. - `cuda`: This directory contains cuda kernel launchers, which feed memory pointers of tensor to the cuda kernel in `common/cuda`. The launchers provide c++ interface of cuda implementation of corresponding custom ops. - `cpu`: This directory contain cpu implementations of corresponding custom ops. -- `tensorrt`: **TensorRT** support for custom ops. + - `mlu`: This directory contain launchers of each MLU kernels. + - `mps`: MPS ops implementation and launchers. +- `tensorrt`: **TensorRT** support for custom ops. Has been deprecated, please try the latest custom ops in [MMDeploy](https://github.com/open-mmlab/mmdeploy). - `plugins`: This directory contains the implementation of the supported custom ops. Some ops might also use shared cuda kernel in `common/cuda`. ## How to add new PyTorch ops? diff --git a/mmcv/ops/csrc/common/mps/MPSDevice.h b/mmcv/ops/csrc/common/mps/MPSDevice.h new file mode 100644 index 000000000..e1d9d4961 --- /dev/null +++ b/mmcv/ops/csrc/common/mps/MPSDevice.h @@ -0,0 +1,64 @@ +// Copyright © 2022 Apple Inc. + +// This file is modify from: +// https://github.com/pytorch/pytorch/blob/a85d1f0bcdd02cf18d3b0517337458cb51a18cdb/aten/src/ATen/mps/MPSDevice.h + +#pragma once +#include +#include +#include + +#ifdef __OBJC__ +#include +#include +#include +typedef id MTLDevice_t; +#else +typedef void* MTLDevice; +typedef void* MTLDevice_t; +#endif + +using namespace std; + +namespace at { +namespace mps { + +//----------------------------------------------------------------- +// MPSDevice +// +// MPSDevice is a singleton class that returns the default device +//----------------------------------------------------------------- + +class TORCH_API MPSDevice { + public: + /** + * MPSDevice should not be cloneable. + */ + MPSDevice(MPSDevice& other) = delete; + /** + * MPSDevice should not be assignable. + */ + void operator=(const MPSDevice&) = delete; + /** + * Gets single instance of the Device. + */ + static MPSDevice* getInstance(); + /** + * Returns the single device. + */ + MTLDevice_t device() { return _mtl_device; } + + ~MPSDevice(); + + private: + static MPSDevice* _device; + MTLDevice_t _mtl_device; + MPSDevice(); +}; + +TORCH_API bool is_available(); + +TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false); + +} // namespace mps +} // namespace at diff --git a/mmcv/ops/csrc/common/mps/MPSLibrary.h b/mmcv/ops/csrc/common/mps/MPSLibrary.h new file mode 100644 index 000000000..41c33fba8 --- /dev/null +++ b/mmcv/ops/csrc/common/mps/MPSLibrary.h @@ -0,0 +1,61 @@ +#ifndef _MPS_LIBRARY_H_ +#define _MPS_LIBRARY_H_ + +#include +#include + +#ifdef __OBJC__ +#include +#include +#include + +typedef id MTLComputePipelineState_t; +typedef id MTLLibrary_t; +#else +typedef void* MTLComputePipelineState; +typedef void* MTLComputePipelineState_t; +typedef void* MTLLibrary; +typedef void* MTLLibrary_t; +#endif + +class MPSLibrary { + public: + // disable constructor for singleton + static MPSLibrary* createFromUrl(const std::string& library_url); + static MPSLibrary* createFromSource(const std::string& source); + ~MPSLibrary(); + + MTLLibrary_t library() { return _library; } + + MTLComputePipelineState_t getComputePipelineState( + const std::string& function_name); + + private: + MTLLibrary_t _library; + std::unordered_map _pso_map; +}; + +class MPSLibraryManager { + public: + // disable constructor for singleton + MPSLibraryManager(const MPSLibraryManager&) = delete; + MPSLibraryManager& operator=(const MPSLibraryManager&) = delete; + MPSLibraryManager(MPSLibraryManager&&) = delete; + MPSLibraryManager& operator=(MPSLibraryManager&&) = delete; + + static MPSLibraryManager* getInstance(); + + bool hasLibrary(const std::string& name); + + MPSLibrary* getLibrary(const std::string& library_url); + + MPSLibrary* createLibraryFromSouce(const std::string& name, + const std::string& sources); + + ~MPSLibraryManager(); + + private: + MPSLibraryManager(); + std::unordered_map> _library_map; +}; +#endif diff --git a/mmcv/ops/csrc/common/mps/MPSLibrary.mm b/mmcv/ops/csrc/common/mps/MPSLibrary.mm new file mode 100644 index 000000000..1a3d635ca --- /dev/null +++ b/mmcv/ops/csrc/common/mps/MPSLibrary.mm @@ -0,0 +1,110 @@ +#include "MPSLibrary.h" +#include +#include "MPSDevice.h" + +static std::unique_ptr mps_library_manager; +static c10::once_flag mpsdev_init; + +MPSLibraryManager* MPSLibraryManager::getInstance() { + c10::call_once(mpsdev_init, [] { + mps_library_manager = std::unique_ptr(new MPSLibraryManager()); + }); + return mps_library_manager.get(); +} + +MPSLibraryManager::~MPSLibraryManager() {} + +MPSLibraryManager::MPSLibraryManager() {} + +bool MPSLibraryManager::hasLibrary(const std::string& name) { + return _library_map.find(name) != _library_map.end(); +} + +MPSLibrary* MPSLibraryManager::getLibrary(const std::string& library_url) { + if (_library_map.find(library_url) != _library_map.end()) { + return _library_map[library_url].get(); + } + _library_map.emplace(std::make_pair( + library_url, std::unique_ptr(MPSLibrary::createFromUrl(library_url)))); + return _library_map[library_url].get(); +} + +MPSLibrary* MPSLibraryManager::createLibraryFromSouce(const std::string& name, + const std::string& source) { + NSString* ns_name = [NSString stringWithCString:name.c_str()]; + if (_library_map.find(name) != _library_map.end()) { + NSLog(@"Library %@ already exist.", ns_name); + return nullptr; + } + + _library_map.emplace( + std::make_pair(name, std::unique_ptr(MPSLibrary::createFromSource(source)))); + return _library_map[name].get(); +} + +MPSLibrary* MPSLibrary::createFromUrl(const std::string& library_url) { + MPSLibrary* library = new MPSLibrary(); + @autoreleasepool { + NSError* error = nil; + + // load library and func + NSString* utl_str = [NSString stringWithCString:library_url.c_str()]; + NSURL* metal_url = [NSURL fileURLWithPath:utl_str]; + library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithURL:metal_url + error:&error]; + if (library->_library == nil) { + NSLog(@"Failed to find library, error %@.", error); + exit(1); + } + } + + return library; +} + +MPSLibrary* MPSLibrary::createFromSource(const std::string& sources) { + MPSLibrary* library = new MPSLibrary(); + @autoreleasepool { + NSError* error = nil; + + // load library and func + NSString* code_str = [NSString stringWithCString:sources.c_str()]; + library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithSource:code_str + options:nil + error:&error]; + if (library->_library == nil) { + NSLog(@"Failed to find library, error %@.", error); + exit(1); + } + } + + return library; +} + +MPSLibrary::~MPSLibrary() { + [_library release]; + _library = nil; +} + +MTLComputePipelineState_t MPSLibrary::getComputePipelineState(const std::string& function_name) { + if (_pso_map.find(function_name) != _pso_map.end()) { + return _pso_map[function_name]; + } + + MTLComputePipelineState_t pso; + @autoreleasepool { + NSError* error = nil; + + // create function + NSString* function_name_str = [NSString stringWithCString:function_name.c_str()]; + id func = [_library newFunctionWithName:function_name_str]; + if (func == nil) { + NSLog(@"Failed to created pipeline state object, error %@.", error); + exit(1); + } + // create pipeline + pso = [at::mps::MPSDevice::getInstance()->device() newComputePipelineStateWithFunction:func + error:&error]; + _pso_map.emplace(std::make_pair(function_name, pso)); + } + return _pso_map[function_name]; +} diff --git a/mmcv/ops/csrc/common/mps/MPSStream.h b/mmcv/ops/csrc/common/mps/MPSStream.h new file mode 100644 index 000000000..54cd38849 --- /dev/null +++ b/mmcv/ops/csrc/common/mps/MPSStream.h @@ -0,0 +1,132 @@ +// Copyright © 2022 Apple Inc. + +// This file is modify from: +// https://github.com/pytorch/pytorch/blob/a85d1f0bcdd02cf18d3b0517337458cb51a18cdb/aten/src/ATen/mps/MPSStream.h + +#pragma once + +#include +#include + +#include +#include +#include +#include "MPSDevice.h" + +#ifdef __OBJC__ +#include +#include +#include +#include +typedef id MTLCommandQueue_t; +typedef id MTLCommandBuffer_t; +typedef id MTLSharedEvent_t; +typedef id MTLDevice_t; +#else +typedef void* MTLCommandQueue_t; +typedef void* MTLCommandQueue; +typedef void* MTLCommandBuffer_t; +typedef void* MTLCommandBuffer; +typedef void* MTLSharedEvent_t; +typedef void* dispatch_queue_t; +typedef void* MTLDevice_t; +#define nil NULL; +#endif + +namespace at { +namespace mps { + +//----------------------------------------------------------------- +// MPSStream +//----------------------------------------------------------------- + +class TORCH_API MPSStream { + public: + enum Unchecked { UNCHECKED }; + /// Construct a MPSStream from a Stream. This construction is checked, + /// and will raise an error if the Stream is not, in fact, a MPS stream. + explicit MPSStream(Stream stream); + + ~MPSStream(); + MTLCommandQueue_t commandQueue() const { return _commandQueue; }; + dispatch_queue_t queue() const { return _serialQueue; } + + MTLCommandBuffer_t commandBuffer(); + void commit(bool flush); + void commitAndWait(); + void synchronize(); + + void flush(); + + /// Get the MPS device index that this stream is associated with. + c10::DeviceIndex device_index() const { return _stream.device_index(); } + + MTLCommandQueue_t stream() const { return _commandQueue; }; + + MTLDevice_t device() const { return [_commandQueue device]; } + + /// Explicit conversion to Stream. + Stream unwrap() const { return _stream; } + + private: + Stream _stream; + MTLCommandQueue_t _commandQueue = nil; + MTLCommandBuffer_t _commandBuffer = nil; + void _flush(bool commitAndWait) const; + + dispatch_queue_t _serialQueue = nullptr; +}; + +/** + * Get the current MPS stream + */ +TORCH_API MPSStream* getCurrentMPSStream(); + +/** + * Get the default MPS stream + */ +TORCH_API MPSStream* getDefaultMPSStream(); + +//----------------------------------------------------------------- +// MPSStreamImpl +//----------------------------------------------------------------- + +class TORCH_API MPSStreamImpl { + public: + /** + * Gets single instance of the MPSStream. + */ + static MPSStream* getInstance(); + + private: + static MPSStream* _stream; + MPSStreamImpl(); +}; + +//----------------------------------------------------------------- +// MPSEvent +//----------------------------------------------------------------- + +struct TORCH_API MPSEvent { + MPSEvent(); + // MPSEvent(id device); + + ~MPSEvent(); + MTLSharedEvent_t event() const { return _event; } + + void recordEvent(MPSStream* stream); + void waitForEvent(MPSStream* queue); // waits on the cpu + bool queryEvent(); + uint64_t getCurrentValue() { return _currentValue; } + void setCurrentValue(uint64_t currValue) { _currentValue = currValue; } + + private: + bool _isRecorded = false; + uint64_t _currentValue = 0; + MTLSharedEvent_t _event; +}; + +typedef MPSEvent* mpsEvent_t; + +} // namespace mps +} // namespace at diff --git a/mmcv/ops/csrc/common/mps/MPSUtils.h b/mmcv/ops/csrc/common/mps/MPSUtils.h new file mode 100644 index 000000000..2a4ce6d79 --- /dev/null +++ b/mmcv/ops/csrc/common/mps/MPSUtils.h @@ -0,0 +1,51 @@ +#ifndef _MPS_UTILS_H_ +#define _MPS_UTILS_H_ +#include +#ifdef __OBJC__ +#include +#include +#include + +typedef id MTLBuffer_t; +typedef id MTLComputeCommandEncoder_t; +#else +typedef void* MTLBuffer; +typedef void* MTLBuffer_t; +typedef void* MTLComputeCommandEncoder; +typedef void* MTLComputeCommandEncoder_t; +#endif + +// utils +static inline MTLBuffer_t getMTLBufferStorage(const at::Tensor& tensor) { + return __builtin_bit_cast(MTLBuffer_t, tensor.storage().data()); +} + +template , at::Tensor>::value, bool> = true> +void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t); + +template , at::Tensor>::value, bool> = true> +void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t) { + [encoder setBuffer:getMTLBufferStorage(t) offset:0 atIndex:index]; +} + +template , at::Tensor>::value, bool>> +void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t) { + [encoder setBytes:&t length:sizeof(t) atIndex:index]; +} + +inline void setMTLArgsImpl(MTLComputeCommandEncoder_t, int) {} + +template +void setMTLArgsImpl(MTLComputeCommandEncoder_t encoder, int index, T&& t, Args&&... args) { + setMTLArg(encoder, index, std::forward(t)); + setMTLArgsImpl(encoder, index + 1, std::forward(args)...); +} + +template +void setMTLArgs(MTLComputeCommandEncoder_t encoder, MTLComputePipelineState_t pso, Args&&... args) { + [encoder setComputePipelineState:pso]; + setMTLArgsImpl(encoder, 0, std::forward(args)...); +} +#endif diff --git a/mmcv/ops/csrc/pytorch/mps/bbox_overlaps_mps.mm b/mmcv/ops/csrc/pytorch/mps/bbox_overlaps_mps.mm new file mode 100644 index 000000000..cad6a41a0 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mps/bbox_overlaps_mps.mm @@ -0,0 +1,99 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +#include "pytorch_device_registry.hpp" + +#include "MPSLibrary.h" +#include "MPSStream.h" +#include "MPSUtils.h" + +using at::Tensor; + +const static std::string kSourceCode = R"( +#include +#include +using namespace metal; + +kernel void bbox_overlap_mps_kernel(constant const float4* bboxes1, + constant const float4* bboxes2, + device float* ious, + constant int& num_bbox1, + constant int& num_bbox2, + constant int& mode, + constant bool& aligned, + constant int& offset, + uint index [[thread_position_in_grid]]) +{ + int base1 = index; + int base2 = index; + if(!aligned){ + base1 = index / num_bbox2; + base2 = index % num_bbox2; + } + + const float f_offset = float(offset); + + const float4 b1 = bboxes1[base1]; + const float b1_area = (b1[2]-b1[0]+f_offset)*(b1[3]-b1[1]+f_offset); + + const float4 b2 = bboxes2[base2]; + const float b2_area = (b2[2]-b2[0]+f_offset)*(b2[3]-b2[1]+f_offset); + + const float2 left_top = fmax(b1.xy, b2.xy); + const float2 right_bottom = fmin(b1.zw, b2.zw); + const float2 wh = fmax(right_bottom - left_top + f_offset, 0.0f); + const float interS = wh.x * wh.y; + + const float baseS = + fmax(mode == 0 ? b1_area + b2_area - interS : b1_area, f_offset); + ious[index] = interS / baseS; +} +)"; + +void BBoxOverlapsMPSKernelLauncher(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, + const int mode, const bool aligned, const int offset) { + // get stream + auto stream = at::mps::getCurrentMPSStream(); + auto library_manager = MPSLibraryManager::getInstance(); + MPSLibrary* library; + const static std::string kLibraryName = "bbox_overlap"; + if (library_manager->hasLibrary(kLibraryName)) + library = library_manager->getLibrary(kLibraryName); + else + library = library_manager->createLibraryFromSouce(kLibraryName, kSourceCode); + auto func_pso = library->getComputePipelineState("bbox_overlap_mps_kernel"); + + // create command buffer and encoder + MTLCommandBuffer_t command_buffer = stream->commandBuffer(); + MTLComputeCommandEncoder_t compute_encoder = [command_buffer computeCommandEncoder]; + + // set pso and buffer + int output_size = ious.numel(); + int num_bbox1 = bboxes1.size(0); + int num_bbox2 = bboxes2.size(0); + int num_elements = output_size; + setMTLArgs(compute_encoder, func_pso, bboxes1, bboxes2, ious, num_bbox1, num_bbox2, mode, aligned, + offset); + + // set grid size + MTLSize grid_size = MTLSizeMake(num_elements, 1, 1); + NSUInteger thread_group_size_x = func_pso.maxTotalThreadsPerThreadgroup; + if (thread_group_size_x > num_elements) { + thread_group_size_x = num_elements; + } + MTLSize thread_group_size = MTLSizeMake(thread_group_size_x, 1, 1); + + // encoding + [compute_encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; + [compute_encoder endEncoding]; + + // commit, not sure if flush is required + stream->commit(false); +} + +void bbox_overlaps_mps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode, + const bool aligned, const int offset) { + BBoxOverlapsMPSKernelLauncher(bboxes1, bboxes2, ious, mode, aligned, offset); +} + +void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode, + const bool aligned, const int offset); +REGISTER_DEVICE_IMPL(bbox_overlaps_impl, MPS, bbox_overlaps_mps); diff --git a/setup.py b/setup.py index 8d836534f..274c13de3 100644 --- a/setup.py +++ b/setup.py @@ -305,6 +305,30 @@ def get_extensions(): extension = MLUExtension include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu')) + elif (hasattr(torch.backends, 'mps') + and torch.backends.mps.is_available()) or os.getenv( + 'FORCE_MPS', '0') == '1': + # objc compiler support + from distutils.unixccompiler import UnixCCompiler + if '.mm' not in UnixCCompiler.src_extensions: + UnixCCompiler.src_extensions.append('.mm') + UnixCCompiler.language_map['.mm'] = 'objc' + + define_macros += [('MMCV_WITH_MPS', None)] + extra_compile_args = {} + extra_compile_args['cxx'] = ['-Wall', '-std=c++17'] + extra_compile_args['cxx'] += [ + '-framework', 'Metal', '-framework', 'Foundation' + ] + extra_compile_args['cxx'] += ['-ObjC++'] + # src + op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/common/mps/*.mm') + \ + glob.glob('./mmcv/ops/csrc/pytorch/mps/*.mm') + extension = CppExtension + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mps')) else: print(f'Compiling {ext_name} only with CPU') op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ diff --git a/tests/test_ops/test_bbox.py b/tests/test_ops/test_bbox.py index f276d0d58..7123b1ee1 100644 --- a/tests/test_ops/test_bbox.py +++ b/tests/test_ops/test_bbox.py @@ -3,7 +3,7 @@ import numpy as np import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE class TestBBox: @@ -43,7 +43,11 @@ class TestBBox: pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'mps', + marks=pytest.mark.skipif( + not IS_MPS_AVAILABLE, reason='requires MPS support')) ]) def test_bbox_overlaps_float(self, device): self._test_bbox_overlaps(device, dtype=torch.float)