mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] ms_deform_attn performance optimization (#2616)
* ms_opt_v2
* ms_opt_v2_1
* optimize MultiScaleDeformableAttention ops for MLU
* [Feature] ms_deform_attn performance optimization V2

---------

Co-authored-by: dongchengwei <dongchengwei@cambricon.com>
parent ec63932333
commit 5f122293bf
File diff suppressed because it is too large
@@ -47,17 +47,17 @@ typedef enum {
 } MsDeformAttnBackwardKernelPolicy;
 
 MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
-    const int32_t channels, const int32_t num_levels,
-    const int32_t num_points) {
+    const int32_t channels, const int32_t num_levels, const int32_t num_points,
+    const int32_t num_heads) {
   const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
-  const uint64_t max_num = nram_size / sizeof(float);
-  const uint64_t deal_num =
-      12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points;
-
-  if (max_num >= deal_num) {
+  const int num_hlp = num_heads * num_levels * num_points;
+  int num_per_time_theory = (nram_size - num_levels * sizeof(float) -
+                             3 * num_levels * sizeof(int32_t)) /
+                            sizeof(float) / (8 * PAD_UP(channels, 32) + 28) /
+                            PAD_UP((num_hlp), 32);
+  if (num_per_time_theory >= 1) {
     return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
   }
 
   return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
 }
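The first hunk replaces the old fixed working-set check (a float count compared against NRAM capacity) with a per-iteration estimate: how many (head, level, point) groups fit into NRAM at once, requiring at least one for the small-channel kernel. Below is a minimal host-side sketch of that heuristic, not the mmcv source: PAD_UP is assumed to be the usual round-up-to-multiple helper macro, and nram_size is passed in explicitly as a stand-in for torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore).

```cpp
#include <cstdint>
#include <cstdio>

// Assumed definition: round x up to the next multiple of y, matching the
// conventional round-up helper used by the MLU kernels.
#define PAD_UP(x, y) ((((x) + (y)-1) / (y)) * (y))

enum MsDeformAttnBackwardKernelPolicy {
  MS_DEFORM_ATTN_BACKWARD_DEFAULT,
  MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL,
};

// Hypothetical standalone restatement of the new heuristic: estimate how
// many (head, level, point) groups fit into NRAM per iteration; at least
// one must fit for the small-channel kernel to be viable.
MsDeformAttnBackwardKernelPolicy pickBackwardPolicy(
    int32_t channels, int32_t num_levels, int32_t num_points,
    int32_t num_heads, int32_t nram_size) {
  const int num_hlp = num_heads * num_levels * num_points;
  const int num_per_time_theory =
      (nram_size - num_levels * (int)sizeof(float) -
       3 * num_levels * (int)sizeof(int32_t)) /
      (int)sizeof(float) / (8 * PAD_UP(channels, 32) + 28) /
      PAD_UP(num_hlp, 32);
  return num_per_time_theory >= 1 ? MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL
                                  : MS_DEFORM_ATTN_BACKWARD_DEFAULT;
}

int main() {
  // Deformable-DETR-like shape (8 heads, 4 levels, 4 points, 32 channels)
  // with a hypothetical 512 KB of NRAM per core: the estimate works out to
  // 3 groups per iteration, so the small-channel kernel is selected.
  MsDeformAttnBackwardKernelPolicy p =
      pickBackwardPolicy(32, 4, 4, 8, 512 * 1024);
  std::printf("%s\n", p == MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL
                          ? "SMALL_CHANNEL" : "DEFAULT");
}
```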
@@ -101,7 +101,8 @@ MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
   int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
   if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
     return MS_DEFORM_ATTN_FORWARD_DEFAULT;
-  } else if (channels > nram_size / 12 / sizeof(float)) {
+  } else if (channels > nram_size / 12 / sizeof(float) || channels > 96 ||
+             channels < 16) {
     return MS_DEFORM_ATTN_FORWARD_DEFAULT;
   } else {
     return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
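The second hunk tightens the forward gate: the small-channel kernel is now chosen only when channels lies in [16, 96], on top of the existing NRAM-capacity bound. A standalone sketch under the same assumptions as above (explicit nram_size parameter instead of the real device query):

```cpp
#include <cstdint>

enum MsDeformAttnForwardPolicy {
  MS_DEFORM_ATTN_FORWARD_DEFAULT,
  MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL,
};

// Hypothetical standalone restatement of the tightened forward gate.
MsDeformAttnForwardPolicy pickForwardPolicy(int32_t channels,
                                            int32_t num_levels,
                                            int32_t num_points,
                                            int32_t nram_size) {
  // The per-(level, point) index tables must fit in NRAM at all.
  if (num_levels * num_points * 3 * (int32_t)sizeof(int32_t) > nram_size)
    return MS_DEFORM_ATTN_FORWARD_DEFAULT;
  // New in this patch: besides the capacity bound, channels must lie in
  // [16, 96] for the small-channel layout to be worthwhile.
  if (channels > nram_size / 12 / (int32_t)sizeof(float) ||
      channels > 96 || channels < 16)
    return MS_DEFORM_ATTN_FORWARD_DEFAULT;
  return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
}
```

The [16, 96] band presumably marks the range where the small-channel kernel was measured to win; outside it the default kernel is kept.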
@@ -472,7 +473,8 @@ void ms_deform_attn_mlu_backward(
   CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
               << ", " << k_dim.y << ", " << k_dim.z << ">>>";
   MsDeformAttnBackwardKernelPolicy kernelPolicy =
-      msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points);
+      msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points,
+                                     num_heads);
   switch (kernelPolicy) {
     default: {
       VLOG(5) << "NotImplemented.";
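The third hunk threads num_heads through to the backward policy chooser so the capacity estimate reflects the full head x level x point working set. Reusing the hypothetical pickBackwardPolicy sketch above, the head count alone can flip the decision:

```cpp
// Same channels/levels/points; only the head count differs. Under the
// hypothetical 512 KB NRAM figure used above, 8 heads still fit at least
// one (head, level, point) group per iteration; 32 heads no longer do.
auto p8  = pickBackwardPolicy(32, 4, 4, /*num_heads=*/8,  512 * 1024);  // SMALL_CHANNEL
auto p32 = pickBackwardPolicy(32, 4, 4, /*num_heads=*/32, 512 * 1024);  // DEFAULT
```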