[Enhance] Optimize the performace of ms_deform_attn for MLU device (#2510)

* ms_opt

* ms_opt

* ms_opt

* ms_opt

* ms_opt

* [Feature] ms_deform_attn performance optimization

* [Feature] ms_deform_attn performance optimization

* [Feature] ms_deform_attn performance optimization
pull/2528/head
BinZheng 2023-01-06 15:17:34 +08:00 committed by GitHub
parent f76de9077b
commit fdc052e84b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 1137 additions and 111 deletions

File diff suppressed because it is too large Load Diff

View File

@ -14,7 +14,15 @@
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
void KernelMsDeformAttnForward(
typedef enum {
MS_DEFORM_ATTN_FORWARD_INVALID = 0, /*!< Index is invalid. */
MS_DEFORM_ATTN_FORWARD_DEFAULT =
1, /*!< MLUKernelMsDeformAttnForwardDefault */
MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL =
2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */
} MsDeformAttnForwardPolicy;
void KernelMsDeformAttnForwardDefault(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char* data_value_gdram,
const char* data_spatial_shapes_gdram,
@ -23,7 +31,37 @@ void KernelMsDeformAttnForward(
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram);
void KernelMsDeformAttnBackward(
void KernelMsDeformAttnForwardSmallChannel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char* data_value_gdram,
const char* data_spatial_shapes_gdram,
const char* data_level_start_index_gdram,
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram);
typedef enum {
MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0,
MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1,
} MsDeformAttnBackwardKernelPolicy;
MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
const int32_t channels, const int32_t num_levels,
const int32_t num_points) {
const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
const uint64_t max_num = nram_size / sizeof(float);
const uint64_t deal_num =
12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points;
if (max_num >= deal_num) {
return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
}
return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
}
void KernelMsDeformAttnBackwardDefaultKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float* data_value,
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
@ -32,10 +70,23 @@ void KernelMsDeformAttnBackward(
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_queries, const int32_t num_points, float* grad_value,
float* grad_sampling_loc, float* grad_attn_weight);
void KernelMsDeformAttnBackwardSmallChannelsKernel(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float* data_value,
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
const float* data_sampling_loc, const float* data_attn_weight,
const float* grad_output, const int32_t batch, const int32_t spatial_size,
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float* grad_value,
float* grad_sampling_loc, float* grad_attn_weight);
// policy function
static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
const int batch_size, const int num_queries,
const int num_heads) {
MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, const int32_t batch_size,
const int32_t num_keys, const int32_t num_heads, const int32_t channels,
const int32_t num_levels, const int32_t num_queries,
const int32_t num_points) {
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim->y =
MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x,
@ -46,6 +97,15 @@ static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
#else
*k_type = CNRT_FUNC_TYPE_UNION1;
#endif
int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
} else if (channels > nram_size / 12 / sizeof(float)) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
} else {
return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
}
}
// policy function for backward
@ -196,7 +256,9 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
// calculate task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFuncForward(&k_dim, &k_type, batch_size, num_queries, num_heads);
MsDeformAttnForwardPolicy policy = msDeformAttnForwardPolicyFunc(
&k_dim, &k_type, batch_size, num_keys, num_heads, channels, num_levels,
num_queries, num_points);
// get compute queue
auto queue = torch_mlu::getCurQueue();
@ -222,15 +284,33 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
// launch kernel
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnForward(
switch (policy) {
default: {
VLOG(5) << "MsDeformAttnForward Policy not supported";
}; break;
case MS_DEFORM_ATTN_FORWARD_DEFAULT: {
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardDefault<<<"
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnForwardDefault(
k_dim, k_type, queue, data_type, (char*)value_ptr,
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points,
(char*)output_ptr);
break;
}
case MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL: {
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardSmallChannel<<<"
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnForwardSmallChannel(
k_dim, k_type, queue, data_type, (char*)value_ptr,
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points,
(char*)output_ptr);
break;
}
}
output = output.view({batch_size, num_queries, num_heads * channels});
return output;
@ -391,14 +471,31 @@ void ms_deform_attn_mlu_backward(
// launch kernel
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnBackward(
MsDeformAttnBackwardKernelPolicy kernelPolicy =
msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points);
switch (kernelPolicy) {
default: {
VLOG(5) << "NotImplemented.";
} break;
case MS_DEFORM_ATTN_BACKWARD_DEFAULT: {
KernelMsDeformAttnBackwardDefaultKernel(
k_dim, k_type, queue, data_type, (float*)value_ptr,
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
num_levels, num_queries, num_points, (float*)grad_value_ptr,
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
} break;
case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL: {
KernelMsDeformAttnBackwardSmallChannelsKernel(
k_dim, k_type, queue, data_type, (float*)value_ptr,
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
num_levels, num_queries, num_points, (float*)grad_value_ptr,
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
} break;
}
}
Tensor ms_deform_attn_impl_forward(const Tensor& value,