mirror of https://github.com/open-mmlab/mmcv.git
[Enhance] Optimize the performace of ms_deform_attn for MLU device (#2510)
* ms_opt * ms_opt * ms_opt * ms_opt * ms_opt * [Feature] ms_deform_attn performance optimization * [Feature] ms_deform_attn performance optimization * [Feature] ms_deform_attn performance optimizationpull/2528/head
parent
f76de9077b
commit
fdc052e84b
File diff suppressed because it is too large
Load Diff
|
@ -14,7 +14,15 @@
|
|||
|
||||
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
|
||||
|
||||
void KernelMsDeformAttnForward(
|
||||
typedef enum {
|
||||
MS_DEFORM_ATTN_FORWARD_INVALID = 0, /*!< Index is invalid. */
|
||||
MS_DEFORM_ATTN_FORWARD_DEFAULT =
|
||||
1, /*!< MLUKernelMsDeformAttnForwardDefault */
|
||||
MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL =
|
||||
2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */
|
||||
} MsDeformAttnForwardPolicy;
|
||||
|
||||
void KernelMsDeformAttnForwardDefault(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
const cnrtDataType_t d_type, const char* data_value_gdram,
|
||||
const char* data_spatial_shapes_gdram,
|
||||
|
@ -23,7 +31,37 @@ void KernelMsDeformAttnForward(
|
|||
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
||||
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
||||
const int32_t num_points, char* data_col_gdram);
|
||||
void KernelMsDeformAttnBackward(
|
||||
void KernelMsDeformAttnForwardSmallChannel(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
const cnrtDataType_t d_type, const char* data_value_gdram,
|
||||
const char* data_spatial_shapes_gdram,
|
||||
const char* data_level_start_index_gdram,
|
||||
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
|
||||
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
||||
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
||||
const int32_t num_points, char* data_col_gdram);
|
||||
|
||||
typedef enum {
|
||||
MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0,
|
||||
MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1,
|
||||
} MsDeformAttnBackwardKernelPolicy;
|
||||
|
||||
MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
|
||||
const int32_t channels, const int32_t num_levels,
|
||||
const int32_t num_points) {
|
||||
const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
|
||||
const uint64_t max_num = nram_size / sizeof(float);
|
||||
const uint64_t deal_num =
|
||||
12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points;
|
||||
|
||||
if (max_num >= deal_num) {
|
||||
return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
|
||||
}
|
||||
|
||||
return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
|
||||
}
|
||||
|
||||
void KernelMsDeformAttnBackwardDefaultKernel(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
const cnrtDataType_t d_type, const float* data_value,
|
||||
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
|
||||
|
@ -32,10 +70,23 @@ void KernelMsDeformAttnBackward(
|
|||
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
|
||||
const int32_t num_queries, const int32_t num_points, float* grad_value,
|
||||
float* grad_sampling_loc, float* grad_attn_weight);
|
||||
|
||||
void KernelMsDeformAttnBackwardSmallChannelsKernel(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
const cnrtDataType_t d_type, const float* data_value,
|
||||
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
|
||||
const float* data_sampling_loc, const float* data_attn_weight,
|
||||
const float* grad_output, const int32_t batch, const int32_t spatial_size,
|
||||
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
|
||||
const int32_t num_query, const int32_t num_points, float* grad_value,
|
||||
float* grad_sampling_loc, float* grad_attn_weight);
|
||||
|
||||
// policy function
|
||||
static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
|
||||
const int batch_size, const int num_queries,
|
||||
const int num_heads) {
|
||||
MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
|
||||
cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, const int32_t batch_size,
|
||||
const int32_t num_keys, const int32_t num_heads, const int32_t channels,
|
||||
const int32_t num_levels, const int32_t num_queries,
|
||||
const int32_t num_points) {
|
||||
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
||||
k_dim->y =
|
||||
MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x,
|
||||
|
@ -46,6 +97,15 @@ static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
|
|||
#else
|
||||
*k_type = CNRT_FUNC_TYPE_UNION1;
|
||||
#endif
|
||||
|
||||
int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
|
||||
if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
|
||||
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
|
||||
} else if (channels > nram_size / 12 / sizeof(float)) {
|
||||
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
|
||||
} else {
|
||||
return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
|
||||
}
|
||||
}
|
||||
|
||||
// policy function for backward
|
||||
|
@ -196,7 +256,9 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
|
|||
// calculate task dimension
|
||||
cnrtDim3_t k_dim;
|
||||
cnrtFunctionType_t k_type;
|
||||
policyFuncForward(&k_dim, &k_type, batch_size, num_queries, num_heads);
|
||||
MsDeformAttnForwardPolicy policy = msDeformAttnForwardPolicyFunc(
|
||||
&k_dim, &k_type, batch_size, num_keys, num_heads, channels, num_levels,
|
||||
num_queries, num_points);
|
||||
|
||||
// get compute queue
|
||||
auto queue = torch_mlu::getCurQueue();
|
||||
|
@ -222,15 +284,33 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
|
|||
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
|
||||
|
||||
// launch kernel
|
||||
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForward<<<" << k_dim.x
|
||||
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||
|
||||
KernelMsDeformAttnForward(
|
||||
switch (policy) {
|
||||
default: {
|
||||
VLOG(5) << "MsDeformAttnForward Policy not supported";
|
||||
}; break;
|
||||
case MS_DEFORM_ATTN_FORWARD_DEFAULT: {
|
||||
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardDefault<<<"
|
||||
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||
KernelMsDeformAttnForwardDefault(
|
||||
k_dim, k_type, queue, data_type, (char*)value_ptr,
|
||||
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
|
||||
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
|
||||
num_heads, channels, num_levels, num_queries, num_points,
|
||||
(char*)output_ptr);
|
||||
break;
|
||||
}
|
||||
case MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL: {
|
||||
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardSmallChannel<<<"
|
||||
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||
KernelMsDeformAttnForwardSmallChannel(
|
||||
k_dim, k_type, queue, data_type, (char*)value_ptr,
|
||||
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
|
||||
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
|
||||
num_heads, channels, num_levels, num_queries, num_points,
|
||||
(char*)output_ptr);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
output = output.view({batch_size, num_queries, num_heads * channels});
|
||||
return output;
|
||||
|
@ -391,14 +471,31 @@ void ms_deform_attn_mlu_backward(
|
|||
// launch kernel
|
||||
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
|
||||
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||
|
||||
KernelMsDeformAttnBackward(
|
||||
MsDeformAttnBackwardKernelPolicy kernelPolicy =
|
||||
msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points);
|
||||
switch (kernelPolicy) {
|
||||
default: {
|
||||
VLOG(5) << "NotImplemented.";
|
||||
} break;
|
||||
case MS_DEFORM_ATTN_BACKWARD_DEFAULT: {
|
||||
KernelMsDeformAttnBackwardDefaultKernel(
|
||||
k_dim, k_type, queue, data_type, (float*)value_ptr,
|
||||
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
|
||||
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
|
||||
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
|
||||
num_levels, num_queries, num_points, (float*)grad_value_ptr,
|
||||
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
|
||||
} break;
|
||||
case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL: {
|
||||
KernelMsDeformAttnBackwardSmallChannelsKernel(
|
||||
k_dim, k_type, queue, data_type, (float*)value_ptr,
|
||||
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
|
||||
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
|
||||
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
|
||||
num_levels, num_queries, num_points, (float*)grad_value_ptr,
|
||||
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor ms_deform_attn_impl_forward(const Tensor& value,
|
||||
|
|
Loading…
Reference in New Issue