mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Enhance] Optimize the performace of ms_deform_attn for MLU device (#2510)
* ms_opt * ms_opt * ms_opt * ms_opt * ms_opt * [Feature] ms_deform_attn performance optimization * [Feature] ms_deform_attn performance optimization * [Feature] ms_deform_attn performance optimization
This commit is contained in:
parent
f76de9077b
commit
fdc052e84b
File diff suppressed because it is too large
Load Diff
@ -14,7 +14,15 @@
|
|||||||
|
|
||||||
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
|
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
|
||||||
|
|
||||||
void KernelMsDeformAttnForward(
|
typedef enum {
|
||||||
|
MS_DEFORM_ATTN_FORWARD_INVALID = 0, /*!< Index is invalid. */
|
||||||
|
MS_DEFORM_ATTN_FORWARD_DEFAULT =
|
||||||
|
1, /*!< MLUKernelMsDeformAttnForwardDefault */
|
||||||
|
MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL =
|
||||||
|
2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */
|
||||||
|
} MsDeformAttnForwardPolicy;
|
||||||
|
|
||||||
|
void KernelMsDeformAttnForwardDefault(
|
||||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||||
const cnrtDataType_t d_type, const char* data_value_gdram,
|
const cnrtDataType_t d_type, const char* data_value_gdram,
|
||||||
const char* data_spatial_shapes_gdram,
|
const char* data_spatial_shapes_gdram,
|
||||||
@ -23,7 +31,37 @@ void KernelMsDeformAttnForward(
|
|||||||
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
||||||
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
||||||
const int32_t num_points, char* data_col_gdram);
|
const int32_t num_points, char* data_col_gdram);
|
||||||
void KernelMsDeformAttnBackward(
|
void KernelMsDeformAttnForwardSmallChannel(
|
||||||
|
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||||
|
const cnrtDataType_t d_type, const char* data_value_gdram,
|
||||||
|
const char* data_spatial_shapes_gdram,
|
||||||
|
const char* data_level_start_index_gdram,
|
||||||
|
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
|
||||||
|
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
||||||
|
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
||||||
|
const int32_t num_points, char* data_col_gdram);
|
||||||
|
|
||||||
|
typedef enum {
|
||||||
|
MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0,
|
||||||
|
MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1,
|
||||||
|
} MsDeformAttnBackwardKernelPolicy;
|
||||||
|
|
||||||
|
MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
|
||||||
|
const int32_t channels, const int32_t num_levels,
|
||||||
|
const int32_t num_points) {
|
||||||
|
const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
|
||||||
|
const uint64_t max_num = nram_size / sizeof(float);
|
||||||
|
const uint64_t deal_num =
|
||||||
|
12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points;
|
||||||
|
|
||||||
|
if (max_num >= deal_num) {
|
||||||
|
return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
|
||||||
|
}
|
||||||
|
|
||||||
|
return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
|
||||||
|
}
|
||||||
|
|
||||||
|
void KernelMsDeformAttnBackwardDefaultKernel(
|
||||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||||
const cnrtDataType_t d_type, const float* data_value,
|
const cnrtDataType_t d_type, const float* data_value,
|
||||||
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
|
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
|
||||||
@ -32,10 +70,23 @@ void KernelMsDeformAttnBackward(
|
|||||||
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
|
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
|
||||||
const int32_t num_queries, const int32_t num_points, float* grad_value,
|
const int32_t num_queries, const int32_t num_points, float* grad_value,
|
||||||
float* grad_sampling_loc, float* grad_attn_weight);
|
float* grad_sampling_loc, float* grad_attn_weight);
|
||||||
|
|
||||||
|
void KernelMsDeformAttnBackwardSmallChannelsKernel(
|
||||||
|
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||||
|
const cnrtDataType_t d_type, const float* data_value,
|
||||||
|
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
|
||||||
|
const float* data_sampling_loc, const float* data_attn_weight,
|
||||||
|
const float* grad_output, const int32_t batch, const int32_t spatial_size,
|
||||||
|
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
|
||||||
|
const int32_t num_query, const int32_t num_points, float* grad_value,
|
||||||
|
float* grad_sampling_loc, float* grad_attn_weight);
|
||||||
|
|
||||||
// policy function
|
// policy function
|
||||||
static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
|
MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
|
||||||
const int batch_size, const int num_queries,
|
cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, const int32_t batch_size,
|
||||||
const int num_heads) {
|
const int32_t num_keys, const int32_t num_heads, const int32_t channels,
|
||||||
|
const int32_t num_levels, const int32_t num_queries,
|
||||||
|
const int32_t num_points) {
|
||||||
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
||||||
k_dim->y =
|
k_dim->y =
|
||||||
MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x,
|
MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x,
|
||||||
@ -46,6 +97,15 @@ static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
|
|||||||
#else
|
#else
|
||||||
*k_type = CNRT_FUNC_TYPE_UNION1;
|
*k_type = CNRT_FUNC_TYPE_UNION1;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
|
||||||
|
if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
|
||||||
|
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
|
||||||
|
} else if (channels > nram_size / 12 / sizeof(float)) {
|
||||||
|
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
|
||||||
|
} else {
|
||||||
|
return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// policy function for backward
|
// policy function for backward
|
||||||
@ -196,7 +256,9 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
|
|||||||
// calculate task dimension
|
// calculate task dimension
|
||||||
cnrtDim3_t k_dim;
|
cnrtDim3_t k_dim;
|
||||||
cnrtFunctionType_t k_type;
|
cnrtFunctionType_t k_type;
|
||||||
policyFuncForward(&k_dim, &k_type, batch_size, num_queries, num_heads);
|
MsDeformAttnForwardPolicy policy = msDeformAttnForwardPolicyFunc(
|
||||||
|
&k_dim, &k_type, batch_size, num_keys, num_heads, channels, num_levels,
|
||||||
|
num_queries, num_points);
|
||||||
|
|
||||||
// get compute queue
|
// get compute queue
|
||||||
auto queue = torch_mlu::getCurQueue();
|
auto queue = torch_mlu::getCurQueue();
|
||||||
@ -222,15 +284,33 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
|
|||||||
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
|
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
|
||||||
|
|
||||||
// launch kernel
|
// launch kernel
|
||||||
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForward<<<" << k_dim.x
|
switch (policy) {
|
||||||
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
default: {
|
||||||
|
VLOG(5) << "MsDeformAttnForward Policy not supported";
|
||||||
KernelMsDeformAttnForward(
|
}; break;
|
||||||
|
case MS_DEFORM_ATTN_FORWARD_DEFAULT: {
|
||||||
|
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardDefault<<<"
|
||||||
|
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||||
|
KernelMsDeformAttnForwardDefault(
|
||||||
k_dim, k_type, queue, data_type, (char*)value_ptr,
|
k_dim, k_type, queue, data_type, (char*)value_ptr,
|
||||||
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
|
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
|
||||||
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
|
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
|
||||||
num_heads, channels, num_levels, num_queries, num_points,
|
num_heads, channels, num_levels, num_queries, num_points,
|
||||||
(char*)output_ptr);
|
(char*)output_ptr);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL: {
|
||||||
|
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardSmallChannel<<<"
|
||||||
|
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||||
|
KernelMsDeformAttnForwardSmallChannel(
|
||||||
|
k_dim, k_type, queue, data_type, (char*)value_ptr,
|
||||||
|
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
|
||||||
|
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
|
||||||
|
num_heads, channels, num_levels, num_queries, num_points,
|
||||||
|
(char*)output_ptr);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
output = output.view({batch_size, num_queries, num_heads * channels});
|
output = output.view({batch_size, num_queries, num_heads * channels});
|
||||||
return output;
|
return output;
|
||||||
@ -391,14 +471,31 @@ void ms_deform_attn_mlu_backward(
|
|||||||
// launch kernel
|
// launch kernel
|
||||||
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
|
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
|
||||||
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||||
|
MsDeformAttnBackwardKernelPolicy kernelPolicy =
|
||||||
KernelMsDeformAttnBackward(
|
msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points);
|
||||||
|
switch (kernelPolicy) {
|
||||||
|
default: {
|
||||||
|
VLOG(5) << "NotImplemented.";
|
||||||
|
} break;
|
||||||
|
case MS_DEFORM_ATTN_BACKWARD_DEFAULT: {
|
||||||
|
KernelMsDeformAttnBackwardDefaultKernel(
|
||||||
k_dim, k_type, queue, data_type, (float*)value_ptr,
|
k_dim, k_type, queue, data_type, (float*)value_ptr,
|
||||||
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
|
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
|
||||||
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
|
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
|
||||||
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
|
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
|
||||||
num_levels, num_queries, num_points, (float*)grad_value_ptr,
|
num_levels, num_queries, num_points, (float*)grad_value_ptr,
|
||||||
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
|
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
|
||||||
|
} break;
|
||||||
|
case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL: {
|
||||||
|
KernelMsDeformAttnBackwardSmallChannelsKernel(
|
||||||
|
k_dim, k_type, queue, data_type, (float*)value_ptr,
|
||||||
|
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
|
||||||
|
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
|
||||||
|
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
|
||||||
|
num_levels, num_queries, num_points, (float*)grad_value_ptr,
|
||||||
|
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
|
||||||
|
} break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor ms_deform_attn_impl_forward(const Tensor& value,
|
Tensor ms_deform_attn_impl_forward(const Tensor& value,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user