mirror of https://github.com/open-mmlab/mmcv.git
[Feature] torch_npu support aclnn and add op (#2997)
parent 4c01b026f0
commit 5494299ba2
@@ -11,7 +11,7 @@ We implement common ops used in detection, segmentation, etc.
| BorderAlign | | √ | | | |
| BoxIouRotated | √ | √ | √ | | √ |
| BoxIouQuadri | √ | √ | | | |
| CARAFE | | √ | √ | | |
| CARAFE | | √ | √ | | √ |
| ChamferDistance | | √ | | | |
| CrissCrossAttention | | √ | | | |
| ContourExpand | √ | | | | |
@@ -41,7 +41,7 @@ We implement common ops used in detection, segmentation, etc.
| PointsInBoxes | √ | √ | | | |
| PointsInPolygons | | √ | | | √ |
| PSAMask | √ | √ | √ | | √ |
| RotatedFeatureAlign | √ | √ | √ | | |
| RotatedFeatureAlign | √ | √ | √ | | √ |
| RoIPointPool3d | | √ | √ | | |
| RoIPool | | √ | √ | | √ |
| RoIAlignRotated | √ | √ | √ | | |
@@ -12,7 +12,7 @@ MMCV provides common operators used in detection, segmentation, and other tasks
| BoxIouRotated | √ | √ | √ | | √ |
| BoxIouQuadri | √ | √ | | | |
| CARAFE | | √ | √ | | |
| ChamferDistance | | √ | | | |
| ChamferDistance | | √ | | | √ |
| CrissCrossAttention | | √ | | | |
| ContourExpand | √ | | | | |
| ConvexIoU | | √ | | | |
@@ -41,7 +41,7 @@ MMCV provides common operators used in detection, segmentation, and other tasks
| PointsInBoxes | √ | √ | | | |
| PointsInPolygons | | √ | | | |
| PSAMask | √ | √ | √ | | √ |
| RotatedFeatureAlign | √ | √ | √ | | |
| RotatedFeatureAlign | √ | √ | √ | | √ |
| RoIPointPool3d | | √ | √ | | |
| RoIPool | | √ | √ | | √ |
| RoIAlignRotated | √ | √ | √ | | |
@@ -44,8 +44,8 @@ class ChamferDistanceFunction(Function):
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        dist1 = torch.zeros(batch_size, n).to(device)
        dist2 = torch.zeros(batch_size, m).to(device)
        dist1 = torch.zeros(batch_size, n).type(xyz1.dtype).to(device)
        dist2 = torch.zeros(batch_size, m).type(xyz2.dtype).to(device)
        idx1 = torch.zeros(batch_size, n).type(torch.IntTensor).to(device)
        idx2 = torch.zeros(batch_size, m).type(torch.IntTensor).to(device)

@@ -81,8 +81,8 @@ class ChamferDistanceFunction(Function):
        device = grad_dist1.device
        grad_dist1 = grad_dist1.contiguous()
        grad_dist2 = grad_dist2.contiguous()
        grad_xyz1 = torch.zeros(xyz1.size()).to(device)
        grad_xyz2 = torch.zeros(xyz2.size()).to(device)
        grad_xyz1 = torch.zeros(xyz1.size()).type(xyz1.dtype).to(device)
        grad_xyz2 = torch.zeros(xyz2.size()).type(xyz2.dtype).to(device)

        ext_module.chamfer_distance_backward(xyz1, xyz2, idx1, idx2,
                                             grad_dist1, grad_dist2, grad_xyz1,
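
The replaced lines above make the distance and gradient buffers follow the input dtype instead of always being float32, so half-precision inputs stay in half precision end to end. A minimal check of that behaviour (standalone sketch, assuming stock PyTorch with the default dtype unchanged):

import torch

xyz1 = torch.rand(2, 4, 2, dtype=torch.half)
old_buf = torch.zeros(2, 4)                    # previous code: always float32
new_buf = torch.zeros(2, 4).type(xyz1.dtype)   # patched code: follows the input dtype
assert old_buf.dtype == torch.float32
assert new_buf.dtype == torch.half
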
@@ -23,6 +23,7 @@

#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#include "pytorch_npu_util.hpp"

#define NPU_NAME_SPACE at_npu::native
@@ -0,0 +1,586 @@
/******************************************************************************
 * Copyright (c) 2022 Huawei Technologies Co., Ltd
 * All rights reserved.
 *
 * Licensed under the BSD 3-Clause License (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/BSD-3-Clause
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/

#ifndef MMCV_OPS_CSRC_COMMON_PYTORCH_NPU_UTIL_HPP_
#define MMCV_OPS_CSRC_COMMON_PYTORCH_NPU_UTIL_HPP_

#include <ATen/Tensor.h>
#include <acl/acl_base.h>
#include <acl/acl_rt.h>
#include <c10/util/Exception.h>
#include <dlfcn.h>
#include <torch_npu/csrc/framework/utils/CalcuOpUtil.h>
#include <torch_npu/csrc/framework/utils/OpAdapter.h>

#include <functional>
#include <type_traits>
#include <vector>

#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
#include "torch_npu/csrc/framework/OpCommand.h"
#include "torch_npu/csrc/framework/interface/EnvVariables.h"
#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
#include "torch_npu/csrc/framework/utils/OpPreparation.h"

#define NPU_NAME_SPACE at_npu::native

// Opaque aclnn handle types and the function-pointer signatures of the
// creator/destroyer symbols resolved from the op-api library at runtime.
typedef struct aclOpExecutor aclOpExecutor;
typedef struct aclTensor aclTensor;
typedef struct aclScalar aclScalar;
typedef struct aclIntArray aclIntArray;
typedef struct aclFloatArray aclFloatArray;
typedef struct aclBoolArray aclBoolArray;
typedef struct aclTensorList aclTensorList;

typedef aclTensor *(*_aclCreateTensor)(
    const int64_t *view_dims, uint64_t view_dims_num, aclDataType data_type,
    const int64_t *stride, int64_t offset, aclFormat format,
    const int64_t *storage_dims, uint64_t storage_dims_num, void *tensor_data);
typedef aclScalar *(*_aclCreateScalar)(void *value, aclDataType data_type);
typedef aclIntArray *(*_aclCreateIntArray)(const int64_t *value, uint64_t size);
typedef aclFloatArray *(*_aclCreateFloatArray)(const float *value,
    uint64_t size);
typedef aclBoolArray *(*_aclCreateBoolArray)(const bool *value, uint64_t size);
typedef aclTensorList *(*_aclCreateTensorList)(const aclTensor *const *value,
    uint64_t size);

typedef int (*_aclDestroyTensor)(const aclTensor *tensor);
typedef int (*_aclDestroyScalar)(const aclScalar *scalar);
typedef int (*_aclDestroyIntArray)(const aclIntArray *array);
typedef int (*_aclDestroyFloatArray)(const aclFloatArray *array);
typedef int (*_aclDestroyBoolArray)(const aclBoolArray *array);
typedef int (*_aclDestroyTensorList)(const aclTensorList *array);

// Thread-local buffer used by AddParamToBuf/CalcHashId to hash op parameters.
constexpr int kHashBufSize = 8192;
constexpr int kHashBufMaxSize = kHashBufSize + 1024;
extern thread_local char g_hashBuf[kHashBufSize];
extern thread_local int g_hashOffset;

#ifdef MMCV_WITH_XLA
#define DEVICE_TYPE at_npu::key::NativeDeviceType
#else
#define DEVICE_TYPE c10::DeviceType::PrivateUse1
#endif

// Maps every at::ScalarType to the matching ACL data type
// (ACL_DT_UNDEFINED marks types the op-api does not support).
#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \
  _(at::ScalarType::Byte, ACL_UINT8) \
  _(at::ScalarType::Char, ACL_INT8) \
  _(at::ScalarType::Short, ACL_INT16) \
  _(at::ScalarType::Int, ACL_INT32) \
  _(at::ScalarType::Long, ACL_INT64) \
  _(at::ScalarType::Half, ACL_FLOAT16) \
  _(at::ScalarType::Float, ACL_FLOAT) \
  _(at::ScalarType::Double, ACL_DOUBLE) \
  _(at::ScalarType::ComplexHalf, ACL_DT_UNDEFINED) \
  _(at::ScalarType::ComplexFloat, ACL_COMPLEX64) \
  _(at::ScalarType::ComplexDouble, ACL_COMPLEX128) \
  _(at::ScalarType::Bool, ACL_BOOL) \
  _(at::ScalarType::QInt8, ACL_DT_UNDEFINED) \
  _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED) \
  _(at::ScalarType::QInt32, ACL_DT_UNDEFINED) \
  _(at::ScalarType::BFloat16, ACL_BF16) \
  _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED) \
  _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED) \
  _(at::ScalarType::Undefined, ACL_DT_UNDEFINED) \
  _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED)

constexpr aclDataType kATenScalarTypeToAclDataTypeTable
    [static_cast<int64_t>(at::ScalarType::NumOptions) + 1] = {
#define DEFINE_ENUM(_1, n) n,
        AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM)
#undef DEFINE_ENUM
};

#define GET_OP_API_FUNC(apiName) \
  reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName))

#define MEMCPY_TO_BUF(data_expression, size_expression) \
  if (g_hashOffset + (size_expression) > kHashBufSize) { \
    g_hashOffset = kHashBufMaxSize; \
    return; \
  } \
  memcpy(g_hashBuf + g_hashOffset, data_expression, size_expression); \
  g_hashOffset += size_expression;

inline const char *GetOpApiLibName(void) { return "libopapi.so"; }

inline const char *GetCustOpApiLibName(void) { return "libcust_opapi.so"; }

inline void *GetOpApiFuncAddrInLib(void *handler, const char *libName,
    const char *apiName) {
  auto funcAddr = dlsym(handler, apiName);
  if (funcAddr == nullptr) {
    ASCEND_LOGW("dlsym %s from %s failed, error:%s.", apiName, libName,
        dlerror());
  }
  return funcAddr;
}

inline void *GetOpApiLibHandler(const char *libName) {
  auto handler = dlopen(libName, RTLD_LAZY);
  if (handler == nullptr) {
    ASCEND_LOGW("dlopen %s failed, error:%s.", libName, dlerror());
  }
  return handler;
}

// Resolve an op-api symbol, preferring the custom-op library over libopapi.so.
inline void *GetOpApiFuncAddr(const char *apiName) {
  static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName());
  if (custOpApiHandler != nullptr) {
    auto funcAddr =
        GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName);
    if (funcAddr != nullptr) {
      return funcAddr;
    }
  }

  static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName());
  if (opApiHandler == nullptr) {
    return nullptr;
  }
  return GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName);
}

inline c10::Scalar ConvertTensorToScalar(const at::Tensor &tensor) {
  c10::Scalar expScalar;
  const at::Tensor *aclInput = &tensor;
  if (aclInput->scalar_type() == at::ScalarType::Double) {
    double value = *(double *)aclInput->data_ptr();
    c10::Scalar scalar(value);
    expScalar = scalar;
  } else if (aclInput->scalar_type() == at::ScalarType::Long) {
    int64_t value = *(int64_t *)aclInput->data_ptr();
    c10::Scalar scalar(value);
    expScalar = scalar;
  } else if (aclInput->scalar_type() == at::ScalarType::Float) {
    float value = *(float *)aclInput->data_ptr();
    c10::Scalar scalar(value);
    expScalar = scalar;
  } else if (aclInput->scalar_type() == at::ScalarType::Int) {
    int value = *(int *)aclInput->data_ptr();
    c10::Scalar scalar(value);
    expScalar = scalar;
  } else if (aclInput->scalar_type() == at::ScalarType::Half) {
    c10::Half value = *(c10::Half *)aclInput->data_ptr();
    c10::Scalar scalar(value);
    expScalar = scalar;
  } else if (aclInput->scalar_type() == at::ScalarType::Bool) {
    int8_t value = *(int8_t *)aclInput->data_ptr();
    c10::Scalar scalar(value);
    expScalar = scalar;
  } else if (aclInput->scalar_type() == at::ScalarType::ComplexDouble) {
    c10::complex<double> value = *(c10::complex<double> *)aclInput->data_ptr();
    c10::Scalar scalar(value);
    expScalar = scalar;
  } else if (aclInput->scalar_type() == at::ScalarType::ComplexFloat) {
    c10::complex<float> value = *(c10::complex<float> *)aclInput->data_ptr();
    c10::Scalar scalar(value);
    expScalar = scalar;
  } else if (aclInput->scalar_type() == at::ScalarType::BFloat16) {
    c10::BFloat16 value = *(c10::BFloat16 *)aclInput->data_ptr();
    c10::Scalar scalar(value);
    expScalar = scalar;
  }
  return expScalar;
}

inline at::Tensor CopyTensorHostToDevice(const at::Tensor &cpu_tensor) {
  at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory();
  int deviceIndex = 0;
  return cpuPinMemTensor.to(c10::Device(DEVICE_TYPE, deviceIndex),
      cpuPinMemTensor.scalar_type(), true, true);
}

inline at::Tensor CopyScalarToDevice(const c10::Scalar &cpu_scalar,
    at::ScalarType scalar_data_type) {
  return CopyTensorHostToDevice(
      scalar_to_tensor(cpu_scalar).to(scalar_data_type));
}

inline aclTensor *ConvertType(const at::Tensor &at_tensor) {
  static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
  if (aclCreateTensor == nullptr) {
    return nullptr;
  }

  if (!at_tensor.defined()) {
    return nullptr;
  }
  at::ScalarType scalar_data_type = at_tensor.scalar_type();
  aclDataType acl_data_type =
      kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalar_data_type)];
  TORCH_CHECK(
      acl_data_type != ACL_DT_UNDEFINED,
      std::string(c10::toString(scalar_data_type)) + " has not been supported")
  c10::SmallVector<int64_t, 5> storageDims;
  // if acl_data_type is ACL_STRING, storageDims is empty.
  auto itemsize = at_tensor.itemsize();
  if (itemsize == 0) {
    AT_ERROR("When ConvertType, tensor item size of cannot be zero.");
    return nullptr;
  }
  if (acl_data_type != ACL_STRING) {
    storageDims.push_back(at_tensor.storage().nbytes() / itemsize);
  }

  const auto dimNum = at_tensor.sizes().size();
  aclFormat format = ACL_FORMAT_ND;
  switch (dimNum) {
    case 3:
      format = ACL_FORMAT_NCL;
      break;
    case 4:
      format = ACL_FORMAT_NCHW;
      break;
    case 5:
      format = ACL_FORMAT_NCDHW;
      break;
    default:
      format = ACL_FORMAT_ND;
  }

  if (at_tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    c10::Scalar expScalar = ConvertTensorToScalar(at_tensor);
    at::Tensor aclInput = CopyScalarToDevice(expScalar, scalar_data_type);
    return aclCreateTensor(aclInput.sizes().data(), aclInput.sizes().size(),
        acl_data_type, aclInput.strides().data(),
        aclInput.storage_offset(), format,
        storageDims.data(), storageDims.size(),
        const_cast<void *>(aclInput.storage().data()));
  }

  auto acl_tensor = aclCreateTensor(
      at_tensor.sizes().data(), at_tensor.sizes().size(), acl_data_type,
      at_tensor.strides().data(), at_tensor.storage_offset(), format,
      storageDims.data(), storageDims.size(),
      const_cast<void *>(at_tensor.storage().data()));
  return acl_tensor;
}

inline aclScalar *ConvertType(const at::Scalar &at_scalar) {
  static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
  if (aclCreateScalar == nullptr) {
    return nullptr;
  }

  at::ScalarType scalar_data_type = at_scalar.type();
  aclDataType acl_data_type =
      kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalar_data_type)];
  TORCH_CHECK(
      acl_data_type != ACL_DT_UNDEFINED,
      std::string(c10::toString(scalar_data_type)) + " has not been supported")
  aclScalar *acl_scalar = nullptr;
  switch (scalar_data_type) {
    case at::ScalarType::Double: {
      double value = at_scalar.toDouble();
      acl_scalar = aclCreateScalar(&value, acl_data_type);
      break;
    }
    case at::ScalarType::Long: {
      int64_t value = at_scalar.toLong();
      acl_scalar = aclCreateScalar(&value, acl_data_type);
      break;
    }
    case at::ScalarType::Bool: {
      bool value = at_scalar.toBool();
      acl_scalar = aclCreateScalar(&value, acl_data_type);
      break;
    }
    case at::ScalarType::ComplexDouble: {
      auto value = at_scalar.toComplexDouble();
      acl_scalar = aclCreateScalar(&value, acl_data_type);
      break;
    }
    default:
      acl_scalar = nullptr;
      break;
  }
  return acl_scalar;
}

inline aclIntArray *ConvertType(const at::IntArrayRef &at_array) {
  static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
  if (aclCreateIntArray == nullptr) {
    return nullptr;
  }
  auto array = aclCreateIntArray(at_array.data(), at_array.size());
  return array;
}

template <std::size_t N>
inline aclBoolArray *ConvertType(const std::array<bool, N> &value) {
  static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
  if (aclCreateBoolArray == nullptr) {
    return nullptr;
  }

  auto array = aclCreateBoolArray(value.data(), value.size());
  return array;
}

inline aclBoolArray *ConvertType(const at::ArrayRef<bool> &value) {
  static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
  if (aclCreateBoolArray == nullptr) {
    return nullptr;
  }

  auto array = aclCreateBoolArray(value.data(), value.size());
  return array;
}

inline aclTensorList *ConvertType(const at::TensorList &at_tensor_list) {
  static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
  if (aclCreateTensorList == nullptr) {
    return nullptr;
  }

  std::vector<const aclTensor *> tensor_list(at_tensor_list.size());
  for (size_t i = 0; i < at_tensor_list.size(); i++) {
    tensor_list[i] = ConvertType(at_tensor_list[i]);
  }
  auto acl_tensor_list =
      aclCreateTensorList(tensor_list.data(), tensor_list.size());
  return acl_tensor_list;
}

inline aclTensor *ConvertType(const c10::optional<at::Tensor> &opt_tensor) {
  if (opt_tensor.has_value() && opt_tensor.value().defined()) {
    return ConvertType(opt_tensor.value());
  }
  return nullptr;
}

inline aclIntArray *ConvertType(
    const c10::optional<at::IntArrayRef> &opt_array) {
  if (opt_array.has_value()) {
    return ConvertType(opt_array.value());
  }
  return nullptr;
}

inline aclScalar *ConvertType(const c10::optional<at::Scalar> &opt_scalar) {
  if (opt_scalar.has_value()) {
    return ConvertType(opt_scalar.value());
  }
  return nullptr;
}

inline aclDataType ConvertType(const at::ScalarType scalarType) {
  return kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalarType)];
}

template <typename T>
T ConvertType(T value) {
  return value;
}

// Derive a function-pointer type whose parameters match the element types of
// the converted-argument tuple and cast the resolved symbol address to it.
template <typename Tuple, size_t... I>
auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr,
    std::index_sequence<I...>) {
  typedef int (*OpApiFunc)(
      typename std::decay<decltype(std::get<I>(params))>::type...);
  auto func = reinterpret_cast<OpApiFunc>(opApiAddr);
  return func;
}

template <typename Tuple>
auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr) {
  static constexpr auto size = std::tuple_size<Tuple>::value;
  return ConvertToOpApiFunc(params, opApiAddr,
      std::make_index_sequence<size>{});
}

inline void Release(aclTensor *p) {
  static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor);
  if (aclDestroyTensor == nullptr) {
    return;
  }
  aclDestroyTensor(p);
}

inline void Release(aclScalar *p) {
  static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar);
  if (aclDestroyScalar == nullptr) {
    return;
  }
  aclDestroyScalar(p);
}

inline void Release(aclIntArray *p) {
  static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray);
  if (aclDestroyIntArray == nullptr) {
    return;
  }

  aclDestroyIntArray(p);
}

inline void Release(aclBoolArray *p) {
  static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray);
  if (aclDestroyBoolArray == nullptr) {
    return;
  }

  aclDestroyBoolArray(p);
}

inline void Release(aclTensorList *p) {
  static const auto aclDestroyTensorList =
      GET_OP_API_FUNC(aclDestroyTensorList);
  if (aclDestroyTensorList == nullptr) {
    return;
  }

  aclDestroyTensorList(p);
}

template <typename T>
void Release(T value) {
  (void)value;
}

template <typename Tuple, size_t... I>
void CallRelease(Tuple t, std::index_sequence<I...>) {
  (void)std::initializer_list<int>{(Release(std::get<I>(t)), 0)...};
}

template <typename Tuple>
void ReleaseConvertTypes(Tuple &t) {
  static constexpr auto size = std::tuple_size<Tuple>::value;
  CallRelease(t, std::make_index_sequence<size>{});
}

// Convert every argument to its ACL counterpart and pack the results into a
// tuple; call() below unpacks that tuple into the actual invocation.
template <typename... Ts>
constexpr auto ConvertTypes(Ts &... args) {
  return std::make_tuple(ConvertType(args)...);
}

template <typename Function, typename Tuple, size_t... I>
auto call(Function f, Tuple t, std::index_sequence<I...>) {
  return f(std::get<I>(t)...);
}

template <typename Function, typename Tuple>
auto call(Function f, Tuple t) {
  static constexpr auto size = std::tuple_size<Tuple>::value;
  return call(f, t, std::make_index_sequence<size>{});
}

template <std::size_t N>
void AddParamToBuf(const std::array<bool, N> &value) {
  MEMCPY_TO_BUF(value.data(), value.size() * sizeof(bool));
}

template <typename T>
void AddParamToBuf(const T &value) {
  MEMCPY_TO_BUF(&value, sizeof(T));
}

void AddParamToBuf(const at::Tensor &);
void AddParamToBuf(const at::Scalar &);
void AddParamToBuf(const at::IntArrayRef &);
void AddParamToBuf(const at::ArrayRef<bool> &);
void AddParamToBuf(const at::TensorList &);
void AddParamToBuf(const c10::optional<at::Tensor> &);
void AddParamToBuf(const c10::optional<at::IntArrayRef> &);
void AddParamToBuf(const c10::optional<at::Scalar> &);
void AddParamToBuf(const at::ScalarType);
void AddParamToBuf(const string &);
void AddParamToBuf();

template <typename T, typename... Args>
void AddParamToBuf(const T &arg, Args &... args) {
  AddParamToBuf(arg);
  AddParamToBuf(args...);
}

uint64_t CalcHashId();
typedef int (*InitHugeMemThreadLocal)(void *, bool);
typedef void (*UnInitHugeMemThreadLocal)(void *, bool);
typedef void (*ReleaseHugeMem)(void *, bool);

// Two-phase aclnn dispatch: resolve <aclnn_api>GetWorkspaceSize and
// <aclnn_api> from the op-api library, query the workspace size, allocate the
// workspace as an NPU byte tensor, then run the kernel through OpCommand's
// custom handler.
#define EXEC_NPU_CMD(aclnn_api, ...) \
  do { \
    static const auto getWorkspaceSizeFuncAddr = \
        GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \
    static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \
    static const auto initMemAddr = \
        GetOpApiFuncAddr("InitHugeMemThreadLocal"); \
    static const auto unInitMemAddr = \
        GetOpApiFuncAddr("UnInitHugeMemThreadLocal"); \
    static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem"); \
    TORCH_CHECK( \
        getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr, \
        #aclnn_api, " or ", #aclnn_api "GetWorkspaceSize", " not in ", \
        GetOpApiLibName(), ", or ", GetOpApiLibName(), "not found."); \
    auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \
    uint64_t workspace_size = 0; \
    uint64_t *workspace_size_addr = &workspace_size; \
    aclOpExecutor *executor = nullptr; \
    aclOpExecutor **executor_addr = &executor; \
    InitHugeMemThreadLocal initMemFunc = \
        reinterpret_cast<InitHugeMemThreadLocal>(initMemAddr); \
    UnInitHugeMemThreadLocal unInitMemFunc = \
        reinterpret_cast<UnInitHugeMemThreadLocal>(unInitMemAddr); \
    if (initMemFunc) { \
      initMemFunc(nullptr, false); \
    } \
    auto converted_params = \
        ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \
    static auto getWorkspaceSizeFunc = \
        ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \
    auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \
    TORCH_CHECK(workspace_status == 0, \
        "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \
    void *workspace_addr = nullptr; \
    if (workspace_size != 0) { \
      at::TensorOptions options = \
          at::TensorOptions(torch_npu::utils::get_npu_device_type()); \
      auto workspace_tensor = \
          at::empty({workspace_size}, options.dtype(kByte)); \
      workspace_addr = const_cast<void *>(workspace_tensor.storage().data()); \
    } \
    auto acl_call = [converted_params, workspace_addr, workspace_size, \
        acl_stream, executor]() -> int { \
      typedef int (*OpApiFunc)(void *, uint64_t, aclOpExecutor *, \
          const aclrtStream); \
      OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr); \
      auto api_ret = \
          opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \
      TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:", \
          aclGetRecentErrMsg()); \
      ReleaseConvertTypes(converted_params); \
      ReleaseHugeMem releaseMemFunc = \
          reinterpret_cast<ReleaseHugeMem>(releaseMemAddr); \
      if (releaseMemFunc) { \
        releaseMemFunc(nullptr, false); \
      } \
      return api_ret; \
    }; \
    at_npu::native::OpCommand cmd; \
    cmd.Name(#aclnn_api); \
    cmd.SetCustomHandler(acl_call); \
    cmd.Run(); \
    if (unInitMemFunc) { \
      unInitMemFunc(nullptr, false); \
    } \
  } while (false)

#endif  // MMCV_OPS_CSRC_COMMON_PYTORCH_NPU_UTIL_HPP_
@@ -0,0 +1,41 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
    Tensor dist2, Tensor idx1, Tensor idx2) {
  at::Tensor xyz1 = at::ones_like(XYZ1);
  at::Tensor xyz2 = at::ones_like(XYZ2);
  xyz1 = XYZ1.transpose(1, 2).transpose(0, 1);
  xyz2 = XYZ2.transpose(1, 2).transpose(0, 1);
  OpCommand cmd;
  cmd.Name("ChamferDistance")
      .Input(xyz1)
      .Input(xyz2)
      .Output(dist1)
      .Output(dist2)
      .Output(idx1)
      .Output(idx2)
      .Run();
}

void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1,
    Tensor idx2, Tensor grad_dist1,
    Tensor grad_dist2, Tensor grad_xyz1,
    Tensor grad_xyz2) {
  EXEC_NPU_CMD(aclnnChamferDistanceBackward, xyz1, xyz2, idx1, idx2, grad_dist1,
      grad_dist2, grad_xyz1, grad_xyz2);
}

void chamfer_distance_forward_impl(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
    Tensor dist2, Tensor idx1, Tensor idx2);
REGISTER_NPU_IMPL(chamfer_distance_forward_impl, chamfer_distance_forward_npu);

void chamfer_distance_backward_impl(Tensor xyz1, Tensor xyz2, Tensor idx1,
    Tensor idx2, Tensor grad_dist1,
    Tensor grad_dist2, Tensor grad_xyz1,
    Tensor grad_xyz2);
REGISTER_NPU_IMPL(chamfer_distance_backward_impl,
    chamfer_distance_backward_npu);
@ -0,0 +1,13 @@
|
|||
#ifndef MMCV_OPS_CSRC_COMMON__UTIL_HPP_
|
||||
#define MMCV_OPS_CSRC_COMMON__UTIL_HPP_
|
||||
const int SIZE = 8;
|
||||
|
||||
c10::SmallVector<int64_t, SIZE> array_to_vector(c10::IntArrayRef shape) {
|
||||
c10::SmallVector<int64_t, SIZE> shape_small_vec;
|
||||
for (uint64_t i = 0; i < shape.size(); i++) {
|
||||
shape_small_vec.emplace_back(shape[i]);
|
||||
}
|
||||
return shape_small_vec;
|
||||
}
|
||||
|
||||
#endif // MMCV_OPS_CSRC_COMMON__UTIL_HPP_
|
|
@@ -1,5 +1,4 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

@@ -99,15 +98,18 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
  c10::SmallVector<int64_t, 2> offsets = {0, 0};
  c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
  at::IntArrayRef offset = at::IntArrayRef(offsets);
  at::IntArrayRef size = at::IntArrayRef(sizes);
  at::IntArrayRef size_array = at::IntArrayRef(sizes);
  c10::SmallVector<int64_t, N> output_size;
  for (uint64_t i = 0; i < size_array.size(); i++) {
    output_size.emplace_back(size_array[i]);
  c10::SmallVector<int64_t, 8> offsetVec;
  for (uint64_t i = 0; i < offset.size(); i++) {
    offsetVec.emplace_back(offset[i]);
  }
  at::Tensor result = at::empty(output_size, op_output.options());
  c10::SmallVector<int64_t, N> offsetVec = array_to_small_vector(offset);
  c10::SmallVector<int64_t, N> sizeVec = array_to_small_vector(size_array);
  cmd.Name("Slice")
  c10::SmallVector<int64_t, 8> sizeVec;
  for (uint64_t i = 0; i < size_array.size(); i++) {
    sizeVec.emplace_back(size_array[i]);
  }
  OpCommand cmd2;
  cmd2.Name("Slice")
      .Input(op_output)
      .Input(offsetVec)
      .Input(sizeVec)
@@ -16,7 +16,9 @@ Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias,
  auto input_size = input.sizes();
  int input_length = input_size.size();
  c10::SmallVector<int64_t, SIZE> input_size_tmp;
  input_size_tmp = array_to_small_vector(input_size);
  for (uint64_t i = 0; i < input_size.size(); i++) {
    input_size_tmp.emplace_back(input_size[i]);
  }
  if (input_length > 1) {
    for (int i = 0; i < input_length; i++) {
      if (i != 1) {
@@ -21,6 +21,50 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
      .Attr("batch_dims", batch_dims)
      .Run();
}
void gather_points_backward_npu(int b, int c, int n, int npoints,
    const Tensor grad_out, const Tensor idx,
    Tensor grad_points) {
  at::Tensor indices = idx;
  if (idx.scalar_type() != at::ScalarType::Int) {
    indices = idx.to(at::kInt);
  }
  if (idx.dim() == 0) {
    indices.unsqueeze_(0);
  }
  int64_t dim = 0;
  auto shape = idx.sizes();
  c10::SmallVector<int64_t, 8> pad_size;
  for (uint64_t i = 0; i < shape.size(); i++) {
    pad_size.emplace_back(shape[i]);
  }
  at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous();
  at::Tensor grad_points_view = trans_grad_points.view(
      {trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1],
       trans_grad_points.sizes()[2]});
  at::Tensor trans_grad_out = grad_out.transpose(1, 2).contiguous();
  trans_grad_out = trans_grad_out.view(
      {trans_grad_out.sizes()[0] * trans_grad_out.sizes()[1],
       trans_grad_out.sizes()[2]});
  auto index = at::arange(0, b);
  index = index.to(grad_out.device());
  index = at::mul(index, n);
  index = index.view({b, 1});
  index = at::broadcast_to(index, pad_size);
  indices = at::add(index, indices);
  indices = indices.view({-1});
  OpCommand cmd;
  cmd.Name("InplaceIndexAdd")
      .Input(grad_points_view)
      .Input(indices)
      .Input(trans_grad_out)
      .Output(grad_points_view)
      .Attr("axis", dim)
      .Run();
  at::Tensor grad_points_result =
      grad_points_view.view(trans_grad_points.sizes());
  grad_points_result = grad_points_result.transpose(1, 2);
  grad_points.copy_(grad_points_result);
}

void gather_points_forward_impl(int b, int c, int n, int npoints,
    const Tensor points, const Tensor idx,
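
gather_points_backward_npu above scatters the incoming gradient rows back into grad_points by folding the batch index into the gather indices (point j of batch k lands at k * n + j in a flattened (b*n, c) view) and accumulating with InplaceIndexAdd. A small PyTorch sketch of the same index arithmetic, with index_add_ standing in for the NPU op (illustrative only):

import torch

b, c, n, npoints = 2, 3, 5, 4
idx = torch.randint(0, n, (b, npoints))        # gathered point indices, shape (b, npoints)
grad_out = torch.rand(b, c, npoints)           # incoming gradient, shape (b, c, npoints)
grad_points = torch.zeros(b, c, n)             # gradient buffer to fill, shape (b, c, n)

# Flatten (batch, point) into one index: point j of batch k sits at k * n + j.
flat_idx = (idx + torch.arange(b).view(b, 1) * n).view(-1)

# Work on (b*n, c) and (b*npoints, c) row views, mirroring the transposes above.
points_view = grad_points.transpose(1, 2).contiguous().view(b * n, c)
out_rows = grad_out.transpose(1, 2).contiguous().view(b * npoints, c)

points_view.index_add_(0, flat_idx, out_rows)  # repeated indices accumulate
grad_points = points_view.view(b, n, c).transpose(1, 2)
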
@@ -35,9 +35,16 @@ void roi_align_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax_y,
  int64_t aligned_height_64 = aligned_height;
  int64_t aligned_width_64 = aligned_width;
  int64_t sampling_ratio_64 = sampling_ratio;
  int64_t roi_end_mode = 0;
  c10::SmallVector<int64_t, SIZE> xdiff_shape =
      array_to_small_vector(grad_input.sizes());
  int64_t roi_end_mode = 2;
  if (!aligned) {
    LOG(WARNING) << "The [aligned] attr in roi_align_grad op is false";
    roi_end_mode = 0;
  }
  auto shape = grad_input.sizes();
  c10::SmallVector<int64_t, SIZE> xdiff_shape;
  for (uint64_t i = 0; i < shape.size(); i++) {
    xdiff_shape.emplace_back(shape[i]);
  }
  OpCommand cmd;
  cmd.Name("ROIAlignGrad")
      .Input(grad_output)
@ -0,0 +1,52 @@
|
|||
#include "pytorch_npu_helper.hpp"
|
||||
|
||||
using namespace NPU_NAME_SPACE;
|
||||
using namespace std;
|
||||
|
||||
void rotated_feature_align_forward_impl(const Tensor features,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points, Tensor output);
|
||||
|
||||
void rotated_feature_align_backward_impl(const Tensor top_grad,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points, Tensor bottom_grad);
|
||||
|
||||
void rotated_feature_align_forward_npu(const Tensor features,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points, Tensor output) {
|
||||
int64_t points_ = (int64_t)points;
|
||||
at::Tensor best_bboxes_ = best_bboxes.transpose(2, 3).transpose(1, 2);
|
||||
OpCommand cmd;
|
||||
cmd.Name("RotatedFeatureAlign")
|
||||
.Input(features)
|
||||
.Input(best_bboxes_)
|
||||
.Output(output)
|
||||
.Attr("spatial_scale", spatial_scale)
|
||||
.Attr("points", points_)
|
||||
.Run();
|
||||
}
|
||||
|
||||
void rotated_feature_align_backward_npu(const Tensor top_grad,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points, Tensor bottom_grad) {
|
||||
int64_t points_ = (int64_t)points;
|
||||
at::Tensor best_bboxes_ = best_bboxes.transpose(2, 3).transpose(1, 2);
|
||||
OpCommand cmd;
|
||||
cmd.Name("RotatedFeatureAlignGrad")
|
||||
.Input(top_grad)
|
||||
.Input(best_bboxes_)
|
||||
.Output(bottom_grad)
|
||||
.Attr("spatial_scale", spatial_scale)
|
||||
.Attr("points", points_)
|
||||
.Run();
|
||||
}
|
||||
|
||||
REGISTER_NPU_IMPL(rotated_feature_align_forward_impl,
|
||||
rotated_feature_align_forward_npu);
|
||||
|
||||
REGISTER_NPU_IMPL(rotated_feature_align_backward_impl,
|
||||
rotated_feature_align_backward_npu);
|
|
@ -0,0 +1,59 @@
|
|||
#include "pytorch_npu_helper.hpp"
|
||||
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
|
||||
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
|
||||
|
||||
using namespace NPU_NAME_SPACE;
|
||||
using namespace std;
|
||||
|
||||
void three_interpolate_forward_npu(int b, int c, int m, int n,
|
||||
const Tensor points, const Tensor idx,
|
||||
const Tensor weight, Tensor out) {
|
||||
auto originDtype = points.scalar_type();
|
||||
TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
|
||||
"three_interpolate_forward ascend only support fp32 and fp16.");
|
||||
|
||||
auto point_c_trans = points.transpose(1, 2);
|
||||
|
||||
OpCommand cmd;
|
||||
cmd.Name("ThreeInterpolate")
|
||||
.Input(point_c_trans)
|
||||
.Input(idx)
|
||||
.Input(weight)
|
||||
.Output(out)
|
||||
.Run();
|
||||
|
||||
auto output = out.view({b, n, c}).transpose(1, 2);
|
||||
auto res = output.contiguous();
|
||||
out.copy_(res);
|
||||
}
|
||||
|
||||
void three_interpolate_backward_npu(int b, int c, int n, int m,
|
||||
const Tensor grad_out, const Tensor idx,
|
||||
const Tensor weight, Tensor grad_points) {
|
||||
auto originDtype = grad_out.scalar_type();
|
||||
TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
|
||||
"three_interpolate_backward ascend only support fp32 and fp16.");
|
||||
|
||||
auto grad_x = at::unsqueeze(grad_out, 3);
|
||||
auto grad_y = at::unsqueeze(grad_points, 3);
|
||||
|
||||
EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight, m, grad_y);
|
||||
|
||||
auto output = at::squeeze(grad_y, 3);
|
||||
auto res = output.contiguous();
|
||||
grad_points.copy_(res);
|
||||
}
|
||||
|
||||
void three_interpolate_forward_impl(int b, int c, int m, int n,
|
||||
const Tensor points, const Tensor idx,
|
||||
const Tensor weight, Tensor out);
|
||||
|
||||
void three_interpolate_backward_impl(int b, int c, int n, int m,
|
||||
const Tensor grad_out, const Tensor idx,
|
||||
const Tensor weight, Tensor grad_points);
|
||||
|
||||
REGISTER_NPU_IMPL(three_interpolate_forward_impl,
|
||||
three_interpolate_forward_npu);
|
||||
|
||||
REGISTER_NPU_IMPL(three_interpolate_backward_impl,
|
||||
three_interpolate_backward_npu);
|
setup.py
@@ -435,12 +435,22 @@ def get_extensions():
    elif (os.getenv('FORCE_NPU', '0') == '1'):
        print(f'Compiling {ext_name} only with CPU and NPU')
        try:
            import importlib

            from torch_npu.utils.cpp_extension import NpuExtension
            extra_compile_args['cxx'] += [
                '-D__FILENAME__=\"$$(notdir $$(abspath $$<))\"'
            ]
            extra_compile_args['cxx'] += [
                '-I' + importlib.util.find_spec(
                    'torch_npu').submodule_search_locations[0] +
                '/include/third_party/acl/inc'
            ]
            define_macros += [('MMCV_WITH_NPU', None)]
            extension = NpuExtension
            if parse_version(torch.__version__) <= parse_version('2.0.0'):
            if parse_version(torch.__version__) < parse_version('2.1.0'):
                define_macros += [('MMCV_WITH_XLA', None)]
            if parse_version(torch.__version__) > parse_version('2.0.0'):
            if parse_version(torch.__version__) >= parse_version('2.1.0'):
                define_macros += [('MMCV_WITH_KPRIVATE', None)]
        except Exception:
            raise ImportError('can not find any torch_npu')

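
The replaced comparisons move the torch version cut-off from <= 2.0.0 / > 2.0.0 to < 2.1.0 / >= 2.1.0, so patch releases such as 2.0.1 now define MMCV_WITH_XLA rather than MMCV_WITH_KPRIVATE. A quick check of the new boundary (sketch, assuming packaging.version.parse behaves like the parse_version helper used in setup.py):

from packaging.version import parse as parse_version  # assumed stand-in for setup.py's helper

for v in ('1.13.1', '2.0.1', '2.1.0'):
    macros = []
    if parse_version(v) < parse_version('2.1.0'):
        macros.append('MMCV_WITH_XLA')
    if parse_version(v) >= parse_version('2.1.0'):
        macros.append('MMCV_WITH_KPRIVATE')
    print(v, macros)
# 1.13.1 ['MMCV_WITH_XLA']
# 2.0.1 ['MMCV_WITH_XLA']
# 2.1.0 ['MMCV_WITH_KPRIVATE']
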
@ -1,57 +1,72 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.ops import chamfer_distance
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_chamfer_distance():
|
||||
pointset1 = torch.tensor(
|
||||
[[[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]],
|
||||
[[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]],
|
||||
[[1.6, 9.99], [2.3, 9.99], [2.3, 10.39], [1.6, 10.39]]],
|
||||
device='cuda',
|
||||
requires_grad=True)
|
||||
def chamfer_distance_forward_gloden(xyz1, xyz2, dtype):
|
||||
bs, ns, ss = xyz1.shape
|
||||
dist1 = np.zeros((bs, ns)).astype(torch_type_trans(dtype))
|
||||
dist2 = np.zeros((bs, ns)).astype(torch_type_trans(dtype))
|
||||
idx1 = np.zeros((bs, ns)).astype('int32')
|
||||
idx2 = np.zeros((bs, ns)).astype('int32')
|
||||
for b1 in range(bs):
|
||||
for n1 in range(ns):
|
||||
x1, y1 = xyz1[b1][n1]
|
||||
dist1[b1][n1] = 10000000
|
||||
for n2 in range(ns):
|
||||
x2, y2 = xyz2[b1][n2]
|
||||
dst = (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)
|
||||
if dist1[b1][n1] > dst:
|
||||
dist1[b1][n1] = dst
|
||||
idx1[b1][n1] = n2
|
||||
for b1 in range(bs):
|
||||
for n1 in range(ns):
|
||||
x1, y1 = xyz2[b1][n1]
|
||||
dist2[b1][n1] = 10000000
|
||||
for n2 in range(ns):
|
||||
x2, y2 = xyz1[b1][n2]
|
||||
dst = (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)
|
||||
if dist2[b1][n1] > dst:
|
||||
dist2[b1][n1] = dst
|
||||
idx2[b1][n1] = n2
|
||||
return [dist1, dist2, idx1, idx2]
|
||||
|
||||
pointset2 = torch.tensor(
|
||||
[[[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]],
|
||||
[[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]],
|
||||
[[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]]],
|
||||
device='cuda',
|
||||
requires_grad=True)
|
||||
|
||||
expected_dist1 = torch.tensor(
|
||||
[[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900],
|
||||
[0.5200, 0.6500, 0.4900, 0.3600]],
|
||||
device='cuda')
|
||||
expected_dist2 = torch.tensor(
|
||||
[[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900],
|
||||
[0.7200, 0.8500, 0.4900, 0.3600]],
|
||||
device='cuda')
|
||||
def torch_type_trans(dtype):
|
||||
if dtype == torch.half:
|
||||
return np.float16
|
||||
elif dtype == torch.float32:
|
||||
return np.float32
|
||||
|
||||
expected_pointset1_grad = torch.tensor(
|
||||
[[[0.6000, 0.0000], [-1.4000, 0.0000], [-1.4000, 0.0000],
|
||||
[0.6000, 0.0000]],
|
||||
[[-0.6000, 0.0000], [1.4000, 0.0000], [1.4000, 0.0000],
|
||||
[-0.6000, 0.0000]],
|
||||
[[1.2000, -0.8000], [-1.4000, -0.8000], [-1.4000, 0.0000],
|
||||
[1.2000, 0.0000]]],
|
||||
device='cuda')
|
||||
|
||||
expected_pointset2_grad = torch.tensor(
|
||||
[[[-0.6000, 0.0000], [1.4000, 0.0000], [1.4000, 0.0000],
|
||||
[-0.6000, 0.0000]],
|
||||
[[0.6000, 0.0000], [-1.4000, 0.0000], [-1.4000, 0.0000],
|
||||
[0.6000, 0.0000]],
|
||||
[[0.0000, 0.0000], [0.0000, 0.0000], [2.8000, 0.8000],
|
||||
[-2.4000, 0.8000]]],
|
||||
device='cuda')
|
||||
|
||||
dist1, dist2, idx1, idx2 = chamfer_distance(pointset1, pointset2)
|
||||
dist1.backward(torch.ones_like(dist1))
|
||||
assert torch.allclose(dist1, expected_dist1, 1e-2)
|
||||
assert torch.allclose(dist2, expected_dist2, 1e-2)
|
||||
assert torch.allclose(pointset1.grad.data, expected_pointset1_grad, 1e-2)
|
||||
assert torch.allclose(pointset2.grad.data, expected_pointset2_grad, 1e-2)
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
'cuda',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
|
||||
pytest.param(
|
||||
'npu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_NPU_AVAILABLE, reason='requires NPU support'))
|
||||
])
|
||||
@pytest.mark.parametrize('dtype', [torch.half, torch.float32])
|
||||
@pytest.mark.parametrize('shape', [(2, 600, 2), (2, 600, 2)])
|
||||
def test_chamfer_distance_npu_dynamic_shape(dtype, device, shape):
|
||||
bs = shape[0]
|
||||
ns = shape[1]
|
||||
xyz1 = np.random.uniform(-10.0, 10.0,
|
||||
(bs, ns, 2)).astype(torch_type_trans(dtype))
|
||||
xyz2 = np.random.uniform(-10.0, 10.0,
|
||||
(bs, ns, 2)).astype(torch_type_trans(dtype))
|
||||
xyz1_npu = torch.tensor(xyz1, dtype=dtype).to(device)
|
||||
xyz2_npu = torch.tensor(xyz2, dtype=dtype).to(device)
|
||||
expected_output = chamfer_distance_forward_gloden(xyz1, xyz2, dtype)
|
||||
output = chamfer_distance(xyz1_npu, xyz2_npu)
|
||||
assert np.allclose(output[0].cpu().numpy(), expected_output[0], 1e-3, 1e-4)
|
||||
assert np.allclose(output[1].cpu().numpy(), expected_output[1], 1e-3, 1e-4)
|
||||
assert np.allclose(output[2].cpu().numpy(), expected_output[2], 1e-3, 1e-4)
|
||||
assert np.allclose(output[3].cpu().numpy(), expected_output[3], 1e-3, 1e-4)
|
||||
|
|
|
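
chamfer_distance_forward_gloden above is a deliberately simple O(ns^2) loop reference. For sanity-checking larger inputs, the same forward result can be computed with a vectorized NumPy expression; the helper below (chamfer_forward_ref) is illustrative only and not part of the test:

import numpy as np

def chamfer_forward_ref(xyz1, xyz2):
    """Vectorized equivalent of the brute-force forward reference."""
    # Pairwise squared distances per batch: d[b, i, j] = ||xyz1[b, i] - xyz2[b, j]||^2
    d = ((xyz1[:, :, None, :] - xyz2[:, None, :, :]) ** 2).sum(-1)
    dist1, idx1 = d.min(axis=2), d.argmin(axis=2).astype(np.int32)
    dist2, idx2 = d.min(axis=1), d.argmin(axis=1).astype(np.int32)
    return dist1, dist2, idx1, idx2
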
@@ -3,7 +3,7 @@ import pytest
import torch

from mmcv.ops import rotated_feature_align
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.skipif(
@@ -17,6 +17,10 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support')),
    pytest.param(
        'npu',
        marks=pytest.mark.skipif(
            not IS_NPU_AVAILABLE, reason='requires NPU support')),
    pytest.param(
        'cpu',
        marks=pytest.mark.skipif(