[Feature] Add MPS bbox overlap (#2123)

* add mps bbox overlap

* format

* update document and manifest

* update readme
pull/2142/head
q.yao 2022-07-22 19:30:01 +08:00 committed by GitHub
parent 73066430be
commit 22fadceecd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 687 additions and 122 deletions

View File

@ -3,4 +3,5 @@ include mmcv/model_zoo/open_mmlab.json mmcv/model_zoo/deprecated.json mmcv/model
include mmcv/ops/csrc/common/cuda/*.cuh mmcv/ops/csrc/common/cuda/*.hpp mmcv/ops/csrc/common/*.hpp
include mmcv/ops/csrc/pytorch/*.cpp mmcv/ops/csrc/pytorch/cuda/*.cu mmcv/ops/csrc/pytorch/cuda/*.cpp mmcv/ops/csrc/pytorch/cpu/*.cpp
include mmcv/ops/csrc/parrots/*.h mmcv/ops/csrc/parrots/*.cpp
recursive-include mmcv/ops/csrc/ *.h *.hpp *.cpp *.cuh *.cu
include mmcv/ops/csrc/pytorch/mps/*.mm mmcv/ops/csrc/common/mps/*.h mmcv/ops/csrc/common/mps/*.mm
recursive-include mmcv/ops/csrc/ *.h *.hpp *.cpp *.cuh *.cu *.mm

View File

@ -2,58 +2,58 @@
We implement common ops used in detection, segmentation, etc.
| Device | CPU | CUDA | MLU |
| ---------------------------- | --- | ---- | --- |
| ActiveRotatedFilter | √ | √ | |
| AssignScoreWithK | | √ | |
| BallQuery | | √ | |
| BBoxOverlaps | | √ | √ |
| BorderAlign | | √ | |
| BoxIouRotated | √ | √ | |
| CARAFE | | √ | |
| ChamferDistance | | √ | |
| CrissCrossAttention | | √ | |
| ContourExpand | √ | | |
| ConvexIoU | | √ | |
| CornerPool | | √ | |
| Correlation | | √ | |
| Deformable Convolution v1/v2 | √ | √ | |
| Deformable RoIPool | | √ | |
| DiffIoURotated | | √ | |
| DynamicScatter | | √ | |
| FurthestPointSample | | √ | |
| FurthestPointSampleWithDist | | √ | |
| FusedBiasLeakyrelu | | √ | |
| GatherPoints | | √ | |
| GroupPoints | | √ | |
| Iou3d | | √ | |
| KNN | | √ | |
| MaskedConv | | √ | |
| MergeCells | | √ | |
| MinAreaPolygon | | √ | |
| ModulatedDeformConv2d | √ | √ | |
| MultiScaleDeformableAttn | | √ | |
| NMS | √ | √ | √ |
| NMSRotated | √ | √ | |
| PixelGroup | √ | | |
| PointsInBoxes | √ | √ | |
| PointsInPolygons | | √ | |
| PSAMask | √ | √ | √ |
| RotatedFeatureAlign | √ | √ | |
| RoIPointPool3d | | √ | |
| RoIPool | | √ | √ |
| RoIAlignRotated | √ | √ | √ |
| RiRoIAlignRotated | | √ | |
| RoIAlign | √ | √ | √ |
| RoIAwarePool3d | | √ | |
| SAConv2d | | √ | |
| SigmoidFocalLoss | | √ | √ |
| SoftmaxFocalLoss | | √ | |
| SoftNMS | | √ | |
| Sparse Convolution | | √ | |
| Synchronized BatchNorm | | √ | |
| ThreeInterpolate | | √ | |
| ThreeNN | | √ | |
| TINShift | | √ | √ |
| UpFirDn2d | | √ | |
| Voxelization | √ | √ | |
| Device | CPU | CUDA | MLU | MPS |
| ---------------------------- | --- | ---- | --- | --- |
| ActiveRotatedFilter | √ | √ | | |
| AssignScoreWithK | | √ | | |
| BallQuery | | √ | | |
| BBoxOverlaps | | √ | √ | √ |
| BorderAlign | | √ | | |
| BoxIouRotated | √ | √ | | |
| CARAFE | | √ | | |
| ChamferDistance | | √ | | |
| CrissCrossAttention | | √ | | |
| ContourExpand | √ | | | |
| ConvexIoU | | √ | | |
| CornerPool | | √ | | |
| Correlation | | √ | | |
| Deformable Convolution v1/v2 | √ | √ | | |
| Deformable RoIPool | | √ | | |
| DiffIoURotated | | √ | | |
| DynamicScatter | | √ | | |
| FurthestPointSample | | √ | | |
| FurthestPointSampleWithDist | | √ | | |
| FusedBiasLeakyrelu | | √ | | |
| GatherPoints | | √ | | |
| GroupPoints | | √ | | |
| Iou3d | | √ | | |
| KNN | | √ | | |
| MaskedConv | | √ | | |
| MergeCells | | √ | | |
| MinAreaPolygon | | √ | | |
| ModulatedDeformConv2d | √ | √ | | |
| MultiScaleDeformableAttn | | √ | | |
| NMS | √ | √ | √ | |
| NMSRotated | √ | √ | | |
| PixelGroup | √ | | | |
| PointsInBoxes | √ | √ | | |
| PointsInPolygons | | √ | | |
| PSAMask | √ | √ | √ | |
| RotatedFeatureAlign | √ | √ | | |
| RoIPointPool3d | | √ | | |
| RoIPool | | √ | √ | |
| RoIAlignRotated | √ | √ | √ | |
| RiRoIAlignRotated | | √ | | |
| RoIAlign | √ | √ | √ | |
| RoIAwarePool3d | | √ | | |
| SAConv2d | | √ | | |
| SigmoidFocalLoss | | √ | √ | |
| SoftmaxFocalLoss | | √ | | |
| SoftNMS | | √ | | |
| Sparse Convolution | | √ | | |
| Synchronized BatchNorm | | √ | | |
| ThreeInterpolate | | √ | | |
| ThreeNN | | √ | | |
| TINShift | | √ | √ | |
| UpFirDn2d | | √ | | |
| Voxelization | √ | √ | | |

View File

@ -2,58 +2,58 @@
MMCV 提供了检测、分割等任务中常用的算子
| Device | CPU | CUDA | MLU |
| ---------------------------- | --- | ---- | --- |
| ActiveRotatedFilter | √ | √ | |
| AssignScoreWithK | | √ | |
| BallQuery | | √ | |
| BBoxOverlaps | | √ | √ |
| BorderAlign | | √ | |
| BoxIouRotated | √ | √ | |
| CARAFE | | √ | |
| ChamferDistance | | √ | |
| CrissCrossAttention | | √ | |
| ContourExpand | √ | | |
| ConvexIoU | | √ | |
| CornerPool | | √ | |
| Correlation | | √ | |
| Deformable Convolution v1/v2 | √ | √ | |
| Deformable RoIPool | | √ | |
| DiffIoURotated | | √ | |
| DynamicScatter | | √ | |
| FurthestPointSample | | √ | |
| FurthestPointSampleWithDist | | √ | |
| FusedBiasLeakyrelu | | √ | |
| GatherPoints | | √ | |
| GroupPoints | | √ | |
| Iou3d | | √ | |
| KNN | | √ | |
| MaskedConv | | √ | |
| MergeCells | | √ | |
| MinAreaPolygon | | √ | |
| ModulatedDeformConv2d | √ | √ | |
| MultiScaleDeformableAttn | | √ | |
| NMS | √ | √ | √ |
| NMSRotated | √ | √ | |
| PixelGroup | √ | | |
| PointsInBoxes | √ | √ | |
| PointsInPolygons | | √ | |
| PSAMask | √ | √ | √ |
| RotatedFeatureAlign | √ | √ | |
| RoIPointPool3d | | √ | |
| RoIPool | | √ | √ |
| RoIAlignRotated | √ | √ | √ |
| RiRoIAlignRotated | | √ | |
| RoIAlign | √ | √ | √ |
| RoIAwarePool3d | | √ | |
| SAConv2d | | √ | |
| SigmoidFocalLoss | | √ | √ |
| SoftmaxFocalLoss | | √ | |
| SoftNMS | | √ | |
| Sparse Convolution | | √ | |
| Synchronized BatchNorm | | √ | |
| ThreeInterpolate | | √ | |
| ThreeNN | | √ | |
| TINShift | | √ | √ |
| UpFirDn2d | | √ | |
| Voxelization | √ | √ | |
| Device | CPU | CUDA | MLU | MPS |
| ---------------------------- | --- | ---- | --- | --- |
| ActiveRotatedFilter | √ | √ | | |
| AssignScoreWithK | | √ | | |
| BallQuery | | √ | | |
| BBoxOverlaps | | √ | √ | √ |
| BorderAlign | | √ | | |
| BoxIouRotated | √ | √ | | |
| CARAFE | | √ | | |
| ChamferDistance | | √ | | |
| CrissCrossAttention | | √ | | |
| ContourExpand | √ | | | |
| ConvexIoU | | √ | | |
| CornerPool | | √ | | |
| Correlation | | √ | | |
| Deformable Convolution v1/v2 | √ | √ | | |
| Deformable RoIPool | | √ | | |
| DiffIoURotated | | √ | | |
| DynamicScatter | | √ | | |
| FurthestPointSample | | √ | | |
| FurthestPointSampleWithDist | | √ | | |
| FusedBiasLeakyrelu | | √ | | |
| GatherPoints | | √ | | |
| GroupPoints | | √ | | |
| Iou3d | | √ | | |
| KNN | | √ | | |
| MaskedConv | | √ | | |
| MergeCells | | √ | | |
| MinAreaPolygon | | √ | | |
| ModulatedDeformConv2d | √ | √ | | |
| MultiScaleDeformableAttn | | √ | | |
| NMS | √ | √ | √ | |
| NMSRotated | √ | √ | | |
| PixelGroup | √ | | | |
| PointsInBoxes | √ | √ | | |
| PointsInPolygons | | √ | | |
| PSAMask | √ | √ | √ | |
| RotatedFeatureAlign | √ | √ | | |
| RoIPointPool3d | | √ | | |
| RoIPool | | √ | √ | |
| RoIAlignRotated | √ | √ | √ | |
| RiRoIAlignRotated | | √ | | |
| RoIAlign | √ | √ | √ | |
| RoIAwarePool3d | | √ | | |
| SAConv2d | | √ | | |
| SigmoidFocalLoss | | √ | √ | |
| SoftmaxFocalLoss | | √ | | |
| SoftNMS | | √ | | |
| Sparse Convolution | | √ | | |
| Synchronized BatchNorm | | √ | | |
| ThreeInterpolate | | √ | | |
| ThreeNN | | √ | | |
| TINShift | | √ | √ | |
| UpFirDn2d | | √ | | |
| Voxelization | √ | √ | | |

View File

@ -13,11 +13,19 @@ This folder contains all non-python code for MMCV custom ops. Please follow the
│ ├── pytorch_cpp_helper.hpp
│ ├── pytorch_cuda_helper.hpp
│ ├── pytorch_device_registry.hpp
│   └── cuda
│   ├── common_cuda_helper.hpp
│   ├── parrots_cudawarpfunction.cuh
│   ├── ...
│   └── ops_cuda_kernel.cuh
│   ├── cuda
│   │ ├── common_cuda_helper.hpp
│   │ ├── parrots_cudawarpfunction.cuh
│   │ ├── ...
│   │ └── ops_cuda_kernel.cuh
|   ├── mps
│   │ ├── MPSLibrary.h
│   │ ├── ...
│   │ └── MPSUtils.h
|   ├── mlu
│   │ └── ...
|   └── utils
│   │ └── ...
├── onnxruntime
│   ├── onnxruntime_register.h
│   ├── onnxruntime_session_options_config_keys.h
@ -41,9 +49,15 @@ This folder contains all non-python code for MMCV custom ops. Please follow the
│   ├── cuda
│   │   ├── ...
│   │   └── ops_cuda.cu
│   └── cpu
│   ├── cpu
│   │   ├── ...
│   │   └── ops.cpp
│   ├── mps
│   │   ├── ...
│   |   └── op_mps.mm
│   └── mlu
│      ├── ...
│      └── ops.cpp
│      └── op_mlu.cpp
└── tensorrt
├── trt_cuda_helper.cuh
├── trt_plugin_helper.hpp
@ -63,13 +77,18 @@ This folder contains all non-python code for MMCV custom ops. Please follow the
- `common`: This directory contains all tools and shared codes.
- `cuda`: The cuda kernels which can be shared by all backends. **HIP** kernel is also here since they have similar syntax.
- `onnxruntime`: **ONNX Runtime** support for custom ops.
- `mps`: The tools used to support MPS ops. **NOTE** that MPS support is **experimental**.
- `mlu`: The MLU kernels used to support [Cambricon](https://www.cambricon.com/) device.
- `utils`: The kernels and utils of spconv.
- `onnxruntime`: **ONNX Runtime** support for custom ops. Has been deprecated, please try the latest custom ops in [MMDeploy](https://github.com/open-mmlab/mmdeploy).
- `cpu`: CPU implementation of supported ops.
- `parrots`: **Parrots** is a deep learning framework for model training and inference. Parrots custom ops are placed in this directory.
- `pytorch`: **PyTorch** custom ops are supported by binding C++ to Python with **pybind11**. The ops implementation and binding codes are placed in this directory.
- `cuda`: This directory contains cuda kernel launchers, which feed memory pointers of tensor to the cuda kernel in `common/cuda`. The launchers provide c++ interface of cuda implementation of corresponding custom ops.
- `cpu`: This directory contains the CPU implementations of the corresponding custom ops.
- `tensorrt`: **TensorRT** support for custom ops.
- `mlu`: This directory contains the launchers of the MLU kernels.
- `mps`: MPS ops implementation and launchers.
- `tensorrt`: **TensorRT** support for custom ops. Has been deprecated, please try the latest custom ops in [MMDeploy](https://github.com/open-mmlab/mmdeploy).
- `plugins`: This directory contains the implementation of the supported custom ops. Some ops might also use shared cuda kernel in `common/cuda`.
## How to add new PyTorch ops?

View File

@ -0,0 +1,64 @@
// Copyright © 2022 Apple Inc.
// This file is modified from:
// https://github.com/pytorch/pytorch/blob/a85d1f0bcdd02cf18d3b0517337458cb51a18cdb/aten/src/ATen/mps/MPSDevice.h
#pragma once
#include <ATen/ATen.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
typedef id<MTLDevice> MTLDevice_t;
#else
typedef void* MTLDevice;
typedef void* MTLDevice_t;
#endif
using namespace std;
namespace at {
namespace mps {
//-----------------------------------------------------------------
// MPSDevice
//
// MPSDevice is a singleton class that returns the default device
//-----------------------------------------------------------------
// Singleton that owns the handle to the default Metal device, mirroring
// PyTorch's at::mps::MPSDevice so MMCV ops use the same device object.
class TORCH_API MPSDevice {
 public:
  /**
   * MPSDevice should not be cloneable.
   */
  MPSDevice(MPSDevice& other) = delete;
  /**
   * MPSDevice should not be assignable.
   */
  void operator=(const MPSDevice&) = delete;
  /**
   * Gets single instance of the Device.
   */
  static MPSDevice* getInstance();
  /**
   * Returns the single device.
   */
  MTLDevice_t device() { return _mtl_device; }

  ~MPSDevice();

 private:
  // Lazily-created singleton instance returned by getInstance().
  static MPSDevice* _device;
  // Default Metal device handle; an opaque void* when not compiled as ObjC.
  MTLDevice_t _mtl_device;
  // Private: instances are only created through getInstance().
  MPSDevice();
};
TORCH_API bool is_available();
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
} // namespace mps
} // namespace at

View File

@ -0,0 +1,61 @@
#ifndef _MPS_LIBRARY_H_
#define _MPS_LIBRARY_H_
#include <string>
#include <unordered_map>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
typedef id<MTLLibrary> MTLLibrary_t;
#else
typedef void* MTLComputePipelineState;
typedef void* MTLComputePipelineState_t;
typedef void* MTLLibrary;
typedef void* MTLLibrary_t;
#endif
// Wraps one compiled Metal library (MTLLibrary) together with a cache of
// compute pipeline states keyed by kernel function name.
class MPSLibrary {
 public:
  // Load a pre-compiled .metallib from a filesystem path.
  static MPSLibrary* createFromUrl(const std::string& library_url);
  // Compile a library at runtime from Metal source code.
  static MPSLibrary* createFromSource(const std::string& source);
  ~MPSLibrary();

  // Underlying Metal library handle (opaque void* outside ObjC).
  MTLLibrary_t library() { return _library; }

  // Return the pipeline state for `function_name`, creating and caching it
  // on first use.
  MTLComputePipelineState_t getComputePipelineState(
      const std::string& function_name);

 private:
  MTLLibrary_t _library;
  // kernel function name -> cached pipeline state
  std::unordered_map<std::string, MTLComputePipelineState_t> _pso_map;
};
// Process-wide singleton that owns every MPSLibrary, cached by name or URL.
class MPSLibraryManager {
 public:
  // Singleton: non-copyable, non-movable.
  MPSLibraryManager(const MPSLibraryManager&) = delete;
  MPSLibraryManager& operator=(const MPSLibraryManager&) = delete;
  MPSLibraryManager(MPSLibraryManager&&) = delete;
  MPSLibraryManager& operator=(MPSLibraryManager&&) = delete;

  // Access the singleton instance (created on first call).
  static MPSLibraryManager* getInstance();

  // True if a library has already been registered under `name`.
  bool hasLibrary(const std::string& name);

  // Get (and lazily load from disk) the library at `library_url`.
  MPSLibrary* getLibrary(const std::string& library_url);

  // Compile `sources` into a new library registered under `name`.
  // NOTE(review): "Souce" is a typo, kept because callers depend on the name.
  MPSLibrary* createLibraryFromSouce(const std::string& name,
                                     const std::string& sources);
  ~MPSLibraryManager();

 private:
  MPSLibraryManager();
  // library name/url -> owned library
  std::unordered_map<std::string, std::unique_ptr<MPSLibrary>> _library_map;
};
#endif

View File

@ -0,0 +1,110 @@
#include "MPSLibrary.h"
#include <c10/util/CallOnce.h>
#include "MPSDevice.h"
// Process-wide cache of compiled MPS libraries, created lazily.
static std::unique_ptr<MPSLibraryManager> mps_library_manager;
// Guards one-time creation of the manager (thread-safe via c10::call_once).
static c10::once_flag mpsdev_init;

// Return the singleton MPSLibraryManager, creating it on first call.
MPSLibraryManager* MPSLibraryManager::getInstance() {
  c10::call_once(mpsdev_init, [] {
    mps_library_manager = std::unique_ptr<MPSLibraryManager>(new MPSLibraryManager());
  });
  return mps_library_manager.get();
}

// Nothing to clean up explicitly: the unique_ptr map releases its libraries.
MPSLibraryManager::~MPSLibraryManager() {}
MPSLibraryManager::MPSLibraryManager() {}

// True if a library has already been registered under `name`.
bool MPSLibraryManager::hasLibrary(const std::string& name) {
  return _library_map.find(name) != _library_map.end();
}
// Return the library registered under `library_url`, loading the .metallib
// from disk and caching it on first use.
//
// The original performed up to three hash lookups (find + emplace +
// operator[]); use the iterator returned by find/emplace instead.
MPSLibrary* MPSLibraryManager::getLibrary(const std::string& library_url) {
  auto iter = _library_map.find(library_url);
  if (iter != _library_map.end()) {
    return iter->second.get();
  }
  // emplace returns {iterator, inserted}; reuse the iterator so the map is
  // only searched once more.
  auto result = _library_map.emplace(
      library_url,
      std::unique_ptr<MPSLibrary>(MPSLibrary::createFromUrl(library_url)));
  return result.first->second.get();
}
// Compile `source` (Metal source code) into a new library registered under
// `name`. Returns nullptr (after logging) if `name` is already registered.
// NOTE(review): "Souce" typo kept — callers elsewhere use this name.
MPSLibrary* MPSLibraryManager::createLibraryFromSouce(const std::string& name,
                                                      const std::string& source) {
  if (_library_map.find(name) != _library_map.end()) {
    // Build the NSString only on the failure path; use the explicit UTF-8
    // constructor — plain stringWithCString: (no encoding) is deprecated and
    // interprets the bytes in the system default encoding.
    NSString* ns_name = [NSString stringWithUTF8String:name.c_str()];
    NSLog(@"Library %@ already exist.", ns_name);
    return nullptr;
  }
  auto result = _library_map.emplace(
      name, std::unique_ptr<MPSLibrary>(MPSLibrary::createFromSource(source)));
  return result.first->second.get();
}
// Load a pre-compiled .metallib from the filesystem path `library_url`.
// Exits the process if the library cannot be loaded (matches the project's
// existing fail-fast behavior for MPS setup errors).
MPSLibrary* MPSLibrary::createFromUrl(const std::string& library_url) {
  MPSLibrary* library = new MPSLibrary();
  @autoreleasepool {
    NSError* error = nil;
    // load library and func
    // stringWithCString: without an encoding is deprecated and assumes the
    // system encoding; use the explicit UTF-8 variant.
    NSString* url_str = [NSString stringWithUTF8String:library_url.c_str()];
    NSURL* metal_url = [NSURL fileURLWithPath:url_str];
    library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithURL:metal_url
                                                                                 error:&error];
    if (library->_library == nil) {
      NSLog(@"Failed to find library, error %@.", error);
      exit(1);
    }
  }
  return library;
}
// Compile `sources` (Metal source code) into a new library at runtime.
// Exits the process if compilation fails (fail-fast, same as createFromUrl).
MPSLibrary* MPSLibrary::createFromSource(const std::string& sources) {
  MPSLibrary* library = new MPSLibrary();
  @autoreleasepool {
    NSError* error = nil;
    // load library and func
    // Explicit UTF-8 encoding; stringWithCString: without an encoding is
    // deprecated and depends on the system default encoding.
    NSString* code_str = [NSString stringWithUTF8String:sources.c_str()];
    library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithSource:code_str
                                                                                  options:nil
                                                                                    error:&error];
    if (library->_library == nil) {
      NSLog(@"Failed to find library, error %@.", error);
      exit(1);
    }
  }
  return library;
}
// Release the Metal library and every cached pipeline state.
//
// The pipeline states were obtained from a new* method (ownership transferred
// under the Create Rule) and this file uses manual reference counting, so the
// original destructor leaked every entry in _pso_map.
MPSLibrary::~MPSLibrary() {
  for (auto& entry : _pso_map) {
    [entry.second release];
  }
  _pso_map.clear();
  [_library release];
  _library = nil;
}
// Return the compute pipeline state for `function_name`, creating and caching
// it on first use. Exits the process if the function cannot be found or the
// pipeline cannot be created (fail-fast, as elsewhere in this file).
//
// Fixes vs. original: explicit UTF-8 NSString constructor (stringWithCString:
// without an encoding is deprecated); accurate error message when the kernel
// function lookup fails (the old one claimed pipeline creation failed);
// release of the retained MTLFunction; nil check on the created pipeline.
MTLComputePipelineState_t MPSLibrary::getComputePipelineState(const std::string& function_name) {
  auto iter = _pso_map.find(function_name);
  if (iter != _pso_map.end()) {
    return iter->second;
  }
  MTLComputePipelineState_t pso;
  @autoreleasepool {
    NSError* error = nil;
    // create function
    NSString* function_name_str = [NSString stringWithUTF8String:function_name.c_str()];
    id<MTLFunction> func = [_library newFunctionWithName:function_name_str];
    if (func == nil) {
      NSLog(@"Failed to find kernel function %@.", function_name_str);
      exit(1);
    }
    // create pipeline
    pso = [at::mps::MPSDevice::getInstance()->device() newComputePipelineStateWithFunction:func
                                                                                     error:&error];
    // newFunctionWithName: returns a retained object (Create Rule); the
    // pipeline state keeps its own reference, so drop ours.
    [func release];
    if (pso == nil) {
      NSLog(@"Failed to create pipeline state object, error %@.", error);
      exit(1);
    }
    _pso_map.emplace(function_name, pso);
  }
  return pso;
}

View File

@ -0,0 +1,132 @@
// Copyright © 2022 Apple Inc.
// This file is modified from:
// https://github.com/pytorch/pytorch/blob/a85d1f0bcdd02cf18d3b0517337458cb51a18cdb/aten/src/ATen/mps/MPSStream.h
#pragma once
#include <cstdint>
#include <utility>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>
#include "MPSDevice.h"
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
typedef id<MTLSharedEvent> MTLSharedEvent_t;
typedef id<MTLDevice> MTLDevice_t;
#else
typedef void* MTLCommandQueue_t;
typedef void* MTLCommandQueue;
typedef void* MTLCommandBuffer_t;
typedef void* MTLCommandBuffer;
typedef void* MTLSharedEvent_t;
typedef void* dispatch_queue_t;
typedef void* MTLDevice_t;
#define nil NULL;
#endif
namespace at {
namespace mps {
//-----------------------------------------------------------------
// MPSStream
//-----------------------------------------------------------------
// Wraps a Metal command queue plus the command buffer currently being
// recorded; work is serialized through a GCD dispatch queue.
class TORCH_API MPSStream {
 public:
  enum Unchecked { UNCHECKED };

  /// Construct a MPSStream from a Stream. This construction is checked,
  /// and will raise an error if the Stream is not, in fact, a MPS stream.
  explicit MPSStream(Stream stream);

  ~MPSStream();
  // Underlying Metal command queue.
  MTLCommandQueue_t commandQueue() const { return _commandQueue; };
  // Serial dispatch queue used to order submissions.
  dispatch_queue_t queue() const { return _serialQueue; }

  // Command buffer currently being recorded.
  MTLCommandBuffer_t commandBuffer();
  // Submit recorded work. NOTE(review): the exact semantics of `flush` are
  // defined in the .mm implementation, which is not visible here.
  void commit(bool flush);
  // Presumably submits and blocks until the GPU finishes -- confirm in .mm.
  void commitAndWait();
  void synchronize();
  void flush();

  /// Get the MPS device index that this stream is associated with.
  c10::DeviceIndex device_index() const { return _stream.device_index(); }

  // Alias of commandQueue().
  MTLCommandQueue_t stream() const { return _commandQueue; };

  // Metal device that owns the command queue.
  MTLDevice_t device() const { return [_commandQueue device]; }

  /// Explicit conversion to Stream.
  Stream unwrap() const { return _stream; }

 private:
  Stream _stream;
  MTLCommandQueue_t _commandQueue = nil;
  MTLCommandBuffer_t _commandBuffer = nil;
  void _flush(bool commitAndWait) const;
  dispatch_queue_t _serialQueue = nullptr;
};
/**
* Get the current MPS stream
*/
TORCH_API MPSStream* getCurrentMPSStream();
/**
* Get the default MPS stream
*/
TORCH_API MPSStream* getDefaultMPSStream();
//-----------------------------------------------------------------
// MPSStreamImpl
//-----------------------------------------------------------------
// Holder for the process-wide default MPSStream singleton.
class TORCH_API MPSStreamImpl {
 public:
  /**
   * Gets single instance of the MPSStream.
   */
  static MPSStream* getInstance();

 private:
  // Lazily-created singleton stream returned by getInstance().
  static MPSStream* _stream;
  // Private: instances are only created through getInstance().
  MPSStreamImpl();
};
//-----------------------------------------------------------------
// MPSEvent
//-----------------------------------------------------------------
// Thin wrapper around a Metal shared event (MTLSharedEvent) used for
// synchronization between MPS streams and the CPU.
struct TORCH_API MPSEvent {
  MPSEvent();
  // MPSEvent(id<MTLDevice> device);
  ~MPSEvent();
  // Underlying shared-event handle.
  MTLSharedEvent_t event() const { return _event; }

  // Record this event on the given stream.
  void recordEvent(MPSStream* stream);
  void waitForEvent(MPSStream* queue);  // waits on the cpu
  bool queryEvent();
  // Counter value associated with the shared event; how it is advanced is
  // defined in the .mm implementation, not visible here.
  uint64_t getCurrentValue() { return _currentValue; }
  void setCurrentValue(uint64_t currValue) { _currentValue = currValue; }

 private:
  // Whether recordEvent() has been called at least once.
  bool _isRecorded = false;
  uint64_t _currentValue = 0;
  MTLSharedEvent_t _event;
};
typedef MPSEvent* mpsEvent_t;
} // namespace mps
} // namespace at

View File

@ -0,0 +1,51 @@
#ifndef _MPS_UTILS_H_
#define _MPS_UTILS_H_
#include <torch/extension.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
typedef id<MTLBuffer> MTLBuffer_t;
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
#else
typedef void* MTLBuffer;
typedef void* MTLBuffer_t;
typedef void* MTLComputeCommandEncoder;
typedef void* MTLComputeCommandEncoder_t;
#endif
// utils
// Extract the underlying MTLBuffer from a torch MPS tensor's storage pointer.
static inline MTLBuffer_t getMTLBufferStorage(const at::Tensor& tensor) {
  return __builtin_bit_cast(MTLBuffer_t, tensor.storage().data());
}

// setMTLArg binds one kernel argument at slot `index` on the encoder:
// at::Tensor arguments are bound as buffers, everything else by value.
// Declaration of the non-tensor overload (defined below).
template <typename T,
          std::enable_if_t<!std::is_same<std::decay_t<T>, at::Tensor>::value, bool> = true>
void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t);

// Tensor overload: bind the tensor's MTLBuffer (offset 0) at the given slot.
template <typename T,
          std::enable_if_t<std::is_same<std::decay_t<T>, at::Tensor>::value, bool> = true>
void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t) {
  [encoder setBuffer:getMTLBufferStorage(t) offset:0 atIndex:index];
}

// Scalar overload: copy the value inline into the command stream.
// NOTE(review): assumes T is safe to memcpy (trivially copyable) -- confirm
// when passing new argument types.
template <typename T, std::enable_if_t<!std::is_same<std::decay_t<T>, at::Tensor>::value, bool>>
void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t) {
  [encoder setBytes:&t length:sizeof(t) atIndex:index];
}

// Recursion terminator for setMTLArgsImpl.
inline void setMTLArgsImpl(MTLComputeCommandEncoder_t, int) {}

// Bind `t` at `index`, then recurse on the remaining arguments.
template <typename T, typename... Args>
void setMTLArgsImpl(MTLComputeCommandEncoder_t encoder, int index, T&& t, Args&&... args) {
  setMTLArg(encoder, index, std::forward<T>(t));
  setMTLArgsImpl(encoder, index + 1, std::forward<Args>(args)...);
}

// Set the pipeline state and bind all kernel arguments in call order,
// starting at slot 0. Argument order must match the kernel's parameter list.
template <typename... Args>
void setMTLArgs(MTLComputeCommandEncoder_t encoder, MTLComputePipelineState_t pso, Args&&... args) {
  [encoder setComputePipelineState:pso];
  setMTLArgsImpl(encoder, 0, std::forward<Args>(args)...);
}
#endif

View File

@ -0,0 +1,99 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "pytorch_device_registry.hpp"
#include "MPSLibrary.h"
#include "MPSStream.h"
#include "MPSUtils.h"
using at::Tensor;

// Metal source for the bbox-overlap kernel, compiled at runtime through
// MPSLibraryManager. One GPU thread computes one output IoU value; in
// non-aligned mode the flat thread index is decomposed into a
// (bbox1, bbox2) pair. mode == 0 divides by the union area (IoU), otherwise
// by the area of bboxes1 (IoF); `offset` implements the pixel-inclusive
// convention (+0 or +1 on widths/heights).
// NOTE: this string literal is runtime data -- its parameter list must stay
// in sync with the setMTLArgs() call in the launcher below.
const static std::string kSourceCode = R"(
#include <metal_math>
#include <metal_stdlib>
using namespace metal;
kernel void bbox_overlap_mps_kernel(constant const float4* bboxes1,
constant const float4* bboxes2,
device float* ious,
constant int& num_bbox1,
constant int& num_bbox2,
constant int& mode,
constant bool& aligned,
constant int& offset,
uint index [[thread_position_in_grid]])
{
int base1 = index;
int base2 = index;
if(!aligned){
base1 = index / num_bbox2;
base2 = index % num_bbox2;
}
const float f_offset = float(offset);
const float4 b1 = bboxes1[base1];
const float b1_area = (b1[2]-b1[0]+f_offset)*(b1[3]-b1[1]+f_offset);
const float4 b2 = bboxes2[base2];
const float b2_area = (b2[2]-b2[0]+f_offset)*(b2[3]-b2[1]+f_offset);
const float2 left_top = fmax(b1.xy, b2.xy);
const float2 right_bottom = fmin(b1.zw, b2.zw);
const float2 wh = fmax(right_bottom - left_top + f_offset, 0.0f);
const float interS = wh.x * wh.y;
const float baseS =
fmax(mode == 0 ? b1_area + b2_area - interS : b1_area, f_offset);
ious[index] = interS / baseS;
}
)";
// Launch the Metal bbox-overlap kernel on the current MPS stream.
//
// bboxes1 / bboxes2: float tensors of boxes (x1, y1, x2, y2); the kernel
//   indexes them as float4, so each box occupies 4 contiguous floats.
// ious: output tensor; one thread is dispatched per element of `ious`.
//   NOTE(review): expected shape (num_bbox1,) when aligned, else
//   (num_bbox1, num_bbox2) -- inferred from the kernel indexing, confirm
//   against the caller.
// mode: 0 -> IoU, otherwise IoF (see kernel source above).
// aligned: pairwise comparison instead of all-pairs.
// offset: 0 or 1, added to widths/heights.
void BBoxOverlapsMPSKernelLauncher(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
                                   const int mode, const bool aligned, const int offset) {
  // get stream
  auto stream = at::mps::getCurrentMPSStream();
  auto library_manager = MPSLibraryManager::getInstance();
  MPSLibrary* library;
  const static std::string kLibraryName = "bbox_overlap";
  // Compile the kernel source on first use; reuse the cached library after.
  if (library_manager->hasLibrary(kLibraryName))
    library = library_manager->getLibrary(kLibraryName);
  else
    library = library_manager->createLibraryFromSouce(kLibraryName, kSourceCode);
  auto func_pso = library->getComputePipelineState("bbox_overlap_mps_kernel");
  // create command buffer and encoder
  MTLCommandBuffer_t command_buffer = stream->commandBuffer();
  MTLComputeCommandEncoder_t compute_encoder = [command_buffer computeCommandEncoder];
  // set pso and buffer
  int output_size = ious.numel();
  int num_bbox1 = bboxes1.size(0);
  int num_bbox2 = bboxes2.size(0);
  int num_elements = output_size;  // one GPU thread per output value
  setMTLArgs(compute_encoder, func_pso, bboxes1, bboxes2, ious, num_bbox1, num_bbox2, mode, aligned,
             offset);
  // set grid size (non-uniform dispatch, exactly num_elements threads)
  MTLSize grid_size = MTLSizeMake(num_elements, 1, 1);
  NSUInteger thread_group_size_x = func_pso.maxTotalThreadsPerThreadgroup;
  if (thread_group_size_x > num_elements) {
    thread_group_size_x = num_elements;
  }
  // NOTE(review): when `ious` is empty this dispatches a zero-sized grid --
  // confirm Metal accepts that, or guard with an early return.
  MTLSize thread_group_size = MTLSizeMake(thread_group_size_x, 1, 1);
  // encoding
  [compute_encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size];
  [compute_encoder endEncoding];
  // Commit without waiting for completion.
  // NOTE(review): original author was unsure whether a flush is required here.
  stream->commit(false);
}
// MPS backend entry point: forwards directly to the kernel launcher.
void bbox_overlaps_mps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode,
                       const bool aligned, const int offset) {
  BBoxOverlapsMPSKernelLauncher(bboxes1, bboxes2, ious, mode, aligned, offset);
}

// Forward declaration of the device-dispatch stub defined in the generic op
// code; needed only so it can be named in the registration macro below.
void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode,
                        const bool aligned, const int offset);

// Register bbox_overlaps_mps as the MPS implementation of bbox_overlaps_impl.
REGISTER_DEVICE_IMPL(bbox_overlaps_impl, MPS, bbox_overlaps_mps);

View File

@ -305,6 +305,30 @@ def get_extensions():
extension = MLUExtension
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
elif (hasattr(torch.backends, 'mps')
and torch.backends.mps.is_available()) or os.getenv(
'FORCE_MPS', '0') == '1':
# objc compiler support
from distutils.unixccompiler import UnixCCompiler
if '.mm' not in UnixCCompiler.src_extensions:
UnixCCompiler.src_extensions.append('.mm')
UnixCCompiler.language_map['.mm'] = 'objc'
define_macros += [('MMCV_WITH_MPS', None)]
extra_compile_args = {}
extra_compile_args['cxx'] = ['-Wall', '-std=c++17']
extra_compile_args['cxx'] += [
'-framework', 'Metal', '-framework', 'Foundation'
]
extra_compile_args['cxx'] += ['-ObjC++']
# src
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
glob.glob('./mmcv/ops/csrc/common/mps/*.mm') + \
glob.glob('./mmcv/ops/csrc/pytorch/mps/*.mm')
extension = CppExtension
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mps'))
else:
print(f'Compiling {ext_name} only with CPU')
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \

View File

@ -3,7 +3,7 @@ import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
class TestBBox:
@ -43,7 +43,11 @@ class TestBBox:
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'mps',
marks=pytest.mark.skipif(
not IS_MPS_AVAILABLE, reason='requires MPS support'))
])
def test_bbox_overlaps_float(self, device):
self._test_bbox_overlaps(device, dtype=torch.float)