mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support NMS with cambricon MLU590 backend (#2401)
* [Feature] Support Nms with cambricon MLU590 backend support 590 for nms
* add blank
pull/2374/merge
parent
e847cf8ad4
commit
193de43bc8
|
@ -234,7 +234,7 @@ __mlu_func__ void nms_detection_ux(
|
||||||
IN_DT *score_data, const IN_DT *boxes_data, const Addr input_ram,
|
IN_DT *score_data, const IN_DT *boxes_data, const Addr input_ram,
|
||||||
const int input_num_boxes, const int max_output_size,
|
const int input_num_boxes, const int max_output_size,
|
||||||
const float thresh_iou, const float thresh_score, const float offset,
|
const float thresh_iou, const float thresh_score, const float offset,
|
||||||
const int output_mode, const int algo) {
|
const int output_mode, const int algo, char *cdma_gdram) {
|
||||||
exit_flag[0] = 0;
|
exit_flag[0] = 0;
|
||||||
|
|
||||||
IN_DT *sram = (IN_DT *)sram_buffer;
|
IN_DT *sram = (IN_DT *)sram_buffer;
|
||||||
|
@ -321,7 +321,25 @@ __mlu_func__ void nms_detection_ux(
|
||||||
__memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM);
|
__memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM);
|
||||||
}
|
}
|
||||||
__sync_all();
|
__sync_all();
|
||||||
#if __BANG_ARCH__ <= 372
|
#if __BANG_ARCH__ >= 590
|
||||||
|
__memcpy((char *)cdma_gdram + REDUCE_NUM * clusterId * sizeof(IN_DT), sram,
|
||||||
|
REDUCE_NUM * sizeof(IN_DT), SRAM2GDRAM);
|
||||||
|
__sync_all();
|
||||||
|
if (clusterId == 0 && coreId == 0) {
|
||||||
|
__bang_write_zero(inter_x1, NMS_SIZE);
|
||||||
|
__memcpy((char *)inter_x1, (char *)cdma_gdram, sizeof(IN_DT), GDRAM2NRAM,
|
||||||
|
sizeof(IN_DT), REDUCE_NUM * sizeof(IN_DT), clusterDim - 1);
|
||||||
|
__bang_max(max_box, inter_x1, NMS_SIZE);
|
||||||
|
int max_cluster = (sizeof(IN_DT) == sizeof(half))
|
||||||
|
? ((uint16_t *)max_box)[1]
|
||||||
|
: ((uint32_t *)max_box)[1];
|
||||||
|
__memcpy((char *)cdma_gdram,
|
||||||
|
(char *)cdma_gdram + max_cluster * REDUCE_NUM * sizeof(IN_DT),
|
||||||
|
REDUCE_NUM * sizeof(IN_DT), GDRAM2GDRAM);
|
||||||
|
}
|
||||||
|
__sync_all();
|
||||||
|
__memcpy(max_box, cdma_gdram, REDUCE_NUM * sizeof(IN_DT), GDRAM2NRAM);
|
||||||
|
#else
|
||||||
findGlobalMaxBox(max_box, sram, inter_x1);
|
findGlobalMaxBox(max_box, sram, inter_x1);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -380,6 +398,7 @@ __mlu_global__ void MLUUionXKernelNMS(
|
||||||
int input_dwidth = (data_type_input == CNRT_FLOAT32) ? 4 : 2;
|
int input_dwidth = (data_type_input == CNRT_FLOAT32) ? 4 : 2;
|
||||||
int32_t *exit_flag = (int32_t *)((char *)workspace +
|
int32_t *exit_flag = (int32_t *)((char *)workspace +
|
||||||
INFO_NUM * input_num_boxes * input_dwidth);
|
INFO_NUM * input_num_boxes * input_dwidth);
|
||||||
|
char *cdma_addr = (char *)exit_flag + sizeof(int32_t);
|
||||||
int reduce_sram_size = NFU_ALIGN_SIZE * REDUCE_NUM * input_dwidth;
|
int reduce_sram_size = NFU_ALIGN_SIZE * REDUCE_NUM * input_dwidth;
|
||||||
int availbale_sram_size = SIZE_SRAM_BUF - reduce_sram_size;
|
int availbale_sram_size = SIZE_SRAM_BUF - reduce_sram_size;
|
||||||
|
|
||||||
|
@ -409,24 +428,26 @@ __mlu_global__ void MLUUionXKernelNMS(
|
||||||
nms_detection_ux(exit_flag, output_box_num, (uint32_t *)output,
|
nms_detection_ux(exit_flag, output_box_num, (uint32_t *)output,
|
||||||
score_data, boxes_data, input_ram, input_num_boxes,
|
score_data, boxes_data, input_ram, input_num_boxes,
|
||||||
max_output_size, iou_threshold, confidence_threshold,
|
max_output_size, iou_threshold, confidence_threshold,
|
||||||
offset, output_mode, algo);
|
offset, output_mode, algo, cdma_addr);
|
||||||
} else {
|
} else {
|
||||||
nms_detection_ux(exit_flag, output_box_num, (uint32_t *)output,
|
nms_detection_ux(exit_flag, output_box_num, (uint32_t *)output,
|
||||||
(half *)score_data, (half *)boxes_data, input_ram,
|
(half *)score_data, (half *)boxes_data, input_ram,
|
||||||
input_num_boxes, max_output_size, iou_threshold,
|
input_num_boxes, max_output_size, iou_threshold,
|
||||||
confidence_threshold, offset, output_mode, algo);
|
confidence_threshold, offset, output_mode, algo,
|
||||||
|
cdma_addr);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (data_type_input == CNRT_FLOAT32) {
|
if (data_type_input == CNRT_FLOAT32) {
|
||||||
nms_detection_ux(exit_flag, output_box_num, (float *)output, score_data,
|
nms_detection_ux(exit_flag, output_box_num, (float *)output, score_data,
|
||||||
boxes_data, input_ram, input_num_boxes, max_output_size,
|
boxes_data, input_ram, input_num_boxes, max_output_size,
|
||||||
iou_threshold, confidence_threshold, offset, output_mode,
|
iou_threshold, confidence_threshold, offset, output_mode,
|
||||||
algo);
|
algo, cdma_addr);
|
||||||
} else {
|
} else {
|
||||||
nms_detection_ux(exit_flag, output_box_num, (half *)output,
|
nms_detection_ux(exit_flag, output_box_num, (half *)output,
|
||||||
(half *)score_data, (half *)boxes_data, input_ram,
|
(half *)score_data, (half *)boxes_data, input_ram,
|
||||||
input_num_boxes, max_output_size, iou_threshold,
|
input_num_boxes, max_output_size, iou_threshold,
|
||||||
confidence_threshold, offset, output_mode, algo);
|
confidence_threshold, offset, output_mode, algo,
|
||||||
|
cdma_addr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
((uint32_t *)result_num)[0] = output_box_num;
|
((uint32_t *)result_num)[0] = output_box_num;
|
||||||
|
|
|
@ -36,6 +36,26 @@ inline int32_t getJobLimitCapability() {
|
||||||
return (int32_t)ctx_conf_param.unionLimit;
|
return (int32_t)ctx_conf_param.unionLimit;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline int32_t getCoreNumOfJobLimitCapability() {
|
||||||
|
switch (getJobLimitCapability()) {
|
||||||
|
default:
|
||||||
|
return torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster) *
|
||||||
|
getJobLimitCapability();
|
||||||
|
case CN_KERNEL_CLASS_BLOCK:
|
||||||
|
return 1;
|
||||||
|
case CN_KERNEL_CLASS_UNION:
|
||||||
|
return torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
||||||
|
case CN_KERNEL_CLASS_UNION2:
|
||||||
|
return torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster) * 2;
|
||||||
|
case CN_KERNEL_CLASS_UNION4:
|
||||||
|
return torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster) * 4;
|
||||||
|
case CN_KERNEL_CLASS_UNION8:
|
||||||
|
return torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster) * 8;
|
||||||
|
case CN_KERNEL_CLASS_UNION16:
|
||||||
|
return torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster) * 16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#endif // MMCV_WITH_MLU
|
#endif // MMCV_WITH_MLU
|
||||||
|
|
||||||
#endif // PYTORCH_MLU_HELPER_HPP_
|
#endif // PYTORCH_MLU_HELPER_HPP_
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*************************************************************************
|
/*************************************************************************
|
||||||
* Copyright (C) 2021 by Cambricon.
|
* Copyright (C) 2021 Cambricon.
|
||||||
*
|
*
|
||||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
@ -34,6 +34,7 @@ static cnnlStatus_t policyFunc(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type,
|
||||||
int &core_num_per_class,
|
int &core_num_per_class,
|
||||||
const int input_box_num) {
|
const int input_box_num) {
|
||||||
uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
||||||
|
uint32_t cluster_number = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
|
||||||
uint32_t job_limit = getJobLimitCapability();
|
uint32_t job_limit = getJobLimitCapability();
|
||||||
uint32_t core_number = job_limit;
|
uint32_t core_number = job_limit;
|
||||||
|
|
||||||
|
@ -116,7 +117,11 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
|
||||||
} else {
|
} else {
|
||||||
space_size = input_num_boxes * sizeof(float) * info_num + sizeof(float);
|
space_size = input_num_boxes * sizeof(float) * info_num + sizeof(float);
|
||||||
}
|
}
|
||||||
|
#if __BANG_ARCH__ > 370
|
||||||
|
int cluster_num = getCoreNumOfJobLimitCapability() /
|
||||||
|
torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
||||||
|
space_size += cluster_number * sizeof(float) * 7;
|
||||||
|
#endif
|
||||||
auto workspace = at::empty(space_size, boxes.options().dtype(at::kByte));
|
auto workspace = at::empty(space_size, boxes.options().dtype(at::kByte));
|
||||||
|
|
||||||
// get compute queue
|
// get compute queue
|
||||||
|
|
Loading…
Reference in New Issue