mirror of https://github.com/open-mmlab/mmcv.git
[Refactor] Replace tin_shift op of MLU backend with mlu-ops (#2911)
parent a0a17050f2
commit 2491dbb7d0
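This commit deletes the hand-written BANG kernel for tin_shift (the removed .mlu file below, together with its policyFunc and launch glue in the tin_shift pytorch binding) and routes the op through the mlu-ops library instead; the focal-loss and roipoint_pool3d files only pick up argument re-wrapping, and the minimum MLUOP version is bumped from 0.7.1 to 0.8.0. For orientation, the call path after the refactor looks roughly like the sketch below, condensed from the new TINShiftForwardMLUKernelLauncher hunks further down; the function name here is illustrative and the includes from mlu_common_helper.h are assumed.

// Sketch of the refactored forward launcher (condensed from the hunks below).
void TinShiftForwardViaMluOpsSketch(Tensor input, Tensor shift, Tensor output) {
  // make the tensors contiguous in their suggested memory format
  auto input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      input, input.suggest_memory_format());
  auto shift_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      shift, shift.suggest_memory_format());
  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      output, output.suggest_memory_format());

  // raw MLU pointers
  auto input_ptr = torch_mlu::getMluTensorImpl(input_contiguous)->cnnlMalloc();
  auto shift_ptr = torch_mlu::getMluTensorImpl(shift_contiguous)->cnnlMalloc();
  auto output_ptr = torch_mlu::getMluTensorImpl(output_contiguous)->cnnlMalloc();

  // mlu-ops tensor descriptors
  MluOpTensorDescriptor input_desc, shift_desc, output_desc;
  input_desc.set(input_contiguous);
  shift_desc.set(shift_contiguous);
  output_desc.set(output_contiguous);

  // the library now owns the launch policy and NRAM tiling
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpTinShiftForward(handle, input_desc.desc(), input_ptr,
                                         shift_desc.desc(), shift_ptr,
                                         output_desc.desc(), output_ptr));
}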
@ -1,307 +0,0 @@
/*************************************************************************
 * Copyright (C) 2022 Cambricon.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "common_mlu_helper.hpp"

__nram__ char data_nram[MAX_NRAM_SIZE];

template <typename T>
__mlu_func__ void mluMultiKernelTinShift(
    const T *input, const int *shifts, T *output, const int batch_size,
    const int time_size, const int channel_size, const int hw_size,
    const int group_size, const int group_channel) {
  for (int cur_channel_index = taskId;
       cur_channel_index < batch_size * channel_size;
       cur_channel_index += taskDim) {
    int n_index = cur_channel_index / channel_size;
    int group_id = cur_channel_index % channel_size / group_channel;
    int t_shift = shifts[n_index * group_size + group_id];
    int index = cur_channel_index % channel_size * hw_size +
                n_index * time_size * channel_size * hw_size;
    __bang_write_value(data_nram, MAX_NRAM_SIZE, (char)0);
    __asm__ volatile("sync;");
    if (abs(t_shift) >= time_size) {
      __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
               channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
               time_size - 1);
    } else {
      if (t_shift > 0) {
        __memcpy(data_nram + t_shift * hw_size * sizeof(T), input + index,
                 hw_size * sizeof(T), GDRAM2NRAM, hw_size * sizeof(T),
                 channel_size * hw_size * sizeof(T), time_size - 1 - t_shift);
        __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
                 channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
                 time_size - 1);
      } else {
        __memcpy(data_nram, input + (index - t_shift * channel_size * hw_size),
                 hw_size * sizeof(T), GDRAM2NRAM, hw_size * sizeof(T),
                 channel_size * hw_size * sizeof(T), time_size - 1 + t_shift);
        __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
                 channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
                 time_size - 1);
      }
    }
    __asm__ volatile("sync;");
  }
}

template <typename T>
__mlu_func__ void mluHwSplit(const T *input, const int t_shift,
                             const int time_size, const int hw_size,
                             const int channel_size, const int index,
                             const int cur_sequence_index,
                             const int max_length_per_core, T *output) {
  for (int cur_index = index; cur_index < index + hw_size;
       cur_index += max_length_per_core) {
    int memcpy_size = max_length_per_core;
    if (cur_index + max_length_per_core > index + hw_size) {
      memcpy_size = index + hw_size - cur_index;
    }
    if (cur_sequence_index - t_shift < 0 ||
        cur_sequence_index - t_shift >= time_size) {
      __memcpy(output + cur_index, data_nram, memcpy_size * sizeof(T),
               NRAM2GDRAM);
    } else {
      __memcpy(data_nram, input + cur_index - t_shift * channel_size * hw_size,
               memcpy_size * sizeof(T), GDRAM2NRAM);
      __memcpy(output + cur_index, data_nram, memcpy_size * sizeof(T),
               NRAM2GDRAM);
    }
    __asm__ volatile("sync;");
  }
}

template <typename T>
__mlu_func__ void mluMultiKernelTinShiftSplitSequence(
    const T *input, const int *shifts, T *output, const int batch_size,
    const int time_size, const int channel_size, const int hw_size,
    const int group_size, const int group_channel,
    const int max_number_hw_per_core, const int max_length_per_core) {
  const int tmp_max_number_hw_per_core =
      max_number_hw_per_core > 0 ? max_number_hw_per_core : 1;
  const int loop_time = time_size / tmp_max_number_hw_per_core +
                        ((time_size % tmp_max_number_hw_per_core) > 0 ? 1 : 0);
  int segmentime_size = tmp_max_number_hw_per_core;
  int res_segment = time_size % tmp_max_number_hw_per_core;

  for (int cur_segment_index = taskId;
       cur_segment_index < loop_time * batch_size * channel_size;
       cur_segment_index += taskDim) {
    int n_index = cur_segment_index / loop_time / channel_size;
    int group_id = cur_segment_index / loop_time % channel_size / group_channel;
    int t_shift = shifts[n_index * group_size + group_id];
    int index = n_index * time_size * channel_size * hw_size +
                (cur_segment_index / loop_time % channel_size) * hw_size +
                cur_segment_index % loop_time * segmentime_size * hw_size *
                    channel_size;
    char *dst_gdram2nram = data_nram;
    const T *src_gdram2nram = input + index;
    int count_gdram2nram = -1;
    int count_nram2gdram = -1;
    int next_sequence_index =
        index / hw_size / channel_size % time_size + segmentime_size;
    int cur_sequence_index = index / hw_size / channel_size % time_size;
    __bang_write_value(data_nram, MAX_NRAM_SIZE, (char)0);
    __asm__ volatile("sync;");
    if (max_number_hw_per_core == 0) {
      mluHwSplit(input, t_shift, time_size, hw_size, channel_size, index,
                 cur_sequence_index, max_length_per_core, output);
      continue;
    }
    if (abs(t_shift) >= time_size) {
      if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
        __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
                 channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
                 res_segment - 1);
      } else {
        __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
                 channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
                 segmentime_size - 1);
      }
      continue;
    }
    if (t_shift == 0) {
      if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
        dst_gdram2nram = data_nram;
        src_gdram2nram = input + index;
        count_gdram2nram = res_segment - 1;
        count_nram2gdram = res_segment - 1;
      } else {
        dst_gdram2nram = data_nram;
        src_gdram2nram = input + index;
        count_gdram2nram = segmentime_size - 1;
        count_nram2gdram = segmentime_size - 1;
      }
    } else if (t_shift > 0) {
      int first_index_cur_channel =
          n_index * time_size * channel_size * hw_size +
          (cur_segment_index / loop_time % channel_size) * hw_size;
      if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
        dst_gdram2nram = data_nram;
        src_gdram2nram =
            input +
            (index - t_shift * channel_size * hw_size < first_index_cur_channel
                 ? first_index_cur_channel
                 : index - t_shift * channel_size * hw_size);
        count_gdram2nram = res_segment - 1;
        count_nram2gdram = res_segment - 1;
        if (cur_sequence_index < t_shift && t_shift < next_sequence_index) {
          dst_gdram2nram =
              data_nram + t_shift % segmentime_size * hw_size * sizeof(T);
          count_gdram2nram = res_segment - (t_shift - cur_sequence_index) - 1;
        }
      } else {
        if (t_shift >= next_sequence_index) {
          __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
                   channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
                   segmentime_size - 1);
          continue;
        } else if (cur_sequence_index < t_shift &&
                   t_shift < next_sequence_index) {
          dst_gdram2nram =
              data_nram + t_shift % segmentime_size * hw_size * sizeof(T);
          src_gdram2nram = input + first_index_cur_channel;
          count_gdram2nram = segmentime_size - (t_shift % segmentime_size) - 1;
          count_nram2gdram = segmentime_size - 1;
        } else {
          dst_gdram2nram = data_nram;
          src_gdram2nram = input + index - t_shift * channel_size * hw_size;
          count_gdram2nram = segmentime_size - 1;
          count_nram2gdram = segmentime_size - 1;
        }
      }
    } else {
      int offset_index = time_size + t_shift;
      if (cur_sequence_index >= offset_index) {
        if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
          __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
                   channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
                   res_segment - 1);
          continue;
        } else {
          __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
                   channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
                   segmentime_size - 1);
          continue;
        }
      } else {
        dst_gdram2nram = data_nram;
        src_gdram2nram = input + index - t_shift * channel_size * hw_size;
        if (cur_sequence_index - t_shift + segmentime_size < time_size) {
          count_gdram2nram = segmentime_size - 1;
          count_nram2gdram = segmentime_size - 1;
        } else {
          count_gdram2nram = time_size - (cur_sequence_index - t_shift) - 1;
          count_nram2gdram =
              (segmentime_size - 1) < (time_size - cur_sequence_index - 1)
                  ? (segmentime_size - 1)
                  : (time_size - cur_sequence_index - 1);
        }
      }
    }
    __memcpy(dst_gdram2nram, src_gdram2nram, hw_size * sizeof(T), GDRAM2NRAM,
             hw_size * sizeof(T), channel_size * hw_size * sizeof(T),
             count_gdram2nram);
    __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
             channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
             count_nram2gdram);
    __asm__ volatile("sync;");
  }
}

__mlu_entry__ void MLUUnion1KernelTinShift(
    const void *input, const void *shifts, void *output, const int batch_size,
    const int time_size, const int channel_size, const int hw_size,
    const int group_size, const int group_channel,
    const cnrtDataType_t data_dtype) {
  // make sure that memcore is not used
  if (coreId == 0x80) {
    return;
  }
  switch (data_dtype) {
    case CNRT_FLOAT16: {
      mluMultiKernelTinShift((half *)input, (const int *)shifts, (half *)output,
                             batch_size, time_size, channel_size, hw_size,
                             group_size, group_channel);
    }; break;
    case CNRT_FLOAT32: {
      mluMultiKernelTinShift((float *)input, (const int *)shifts,
                             (float *)output, batch_size, time_size,
                             channel_size, hw_size, group_size, group_channel);
    }; break;
    default: { return; }
  }
}

__mlu_entry__ void MLUUnion1KernelTinShiftSplitSequence(
    const void *input, const void *shifts, void *output, const int batch_size,
    const int time_size, const int channel_size, const int hw_size,
    const int group_size, const int group_channel,
    const int max_number_hw_per_core, const int max_length_per_core,
    const cnrtDataType_t data_dtype) {
  // make sure that memcore is not used
  if (coreId == 0x80) {
    return;
  }
  switch (data_dtype) {
    case CNRT_FLOAT16: {
      mluMultiKernelTinShiftSplitSequence(
          (half *)input, (const int *)shifts, (half *)output, batch_size,
          time_size, channel_size, hw_size, group_size, group_channel,
          max_number_hw_per_core, max_length_per_core);
    }; break;
    case CNRT_FLOAT32: {
      mluMultiKernelTinShiftSplitSequence(
          (float *)input, (const int *)shifts, (float *)output, batch_size,
          time_size, channel_size, hw_size, group_size, group_channel,
          max_number_hw_per_core, max_length_per_core);
    }; break;
    default: { return; }
  }
}

void KernelTinShiftForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const void *input, const void *shifts, void *output, const int batch_size,
    const int time_size, const int channel_size, const int hw_size,
    const int group_size, const int group_channel,
    const cnrtDataType_t data_dtype, const int channel_per_core,
    const int max_number_hw_per_core, const int max_length_per_core) {
  if (channel_per_core >= 1) {
    MLUUnion1KernelTinShift<<<k_dim, k_type, queue>>>(
        input, shifts, output, batch_size, time_size, channel_size, hw_size,
        group_size, group_channel, data_dtype);
  } else {
    MLUUnion1KernelTinShiftSplitSequence<<<k_dim, k_type, queue>>>(
        input, shifts, output, batch_size, time_size, channel_size, hw_size,
        group_size, group_channel, max_number_hw_per_core, max_length_per_core,
        data_dtype);
  }
}

void KernelTinShiftBackward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const void *grad_output, const void *shifts, void *grad_input,
    const int batch_size, const int time_size, const int channel_size,
    const int hw_size, const int group_size, const int group_channel,
    const cnrtDataType_t data_dtype, const int channel_per_core,
    const int max_number_hw_per_core, const int max_length_per_core) {
  if (channel_per_core >= 1) {
    MLUUnion1KernelTinShift<<<k_dim, k_type, queue>>>(
        grad_output, shifts, grad_input, batch_size, time_size, channel_size,
        hw_size, group_size, group_channel, data_dtype);
  } else {
    MLUUnion1KernelTinShiftSplitSequence<<<k_dim, k_type, queue>>>(
        grad_output, shifts, grad_input, batch_size, time_size, channel_size,
        hw_size, group_size, group_channel, max_number_hw_per_core,
        max_length_per_core, data_dtype);
  }
}
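The roughly 300 lines deleted above implemented the temporal shift by hand: each MLU core zero-fills an NRAM buffer, copies the valid time slices of its (batch, channel) slab with strided __memcpy, and writes the buffer back, with a split-sequence fallback when a whole channel does not fit in NRAM. For orientation, the effect of the op on an (N, T, C, HW) tensor can be summarized by the plain CPU reference below; this is an illustrative sketch following the deleted kernel's zero-padding behaviour, not code from this change.

#include <vector>

// Illustrative CPU reference of tin_shift on an (N, T, C, HW) tensor.
// shifts has shape (N, group_size); channels are split into equal groups.
std::vector<float> tin_shift_reference(const std::vector<float> &input,
                                       const std::vector<int> &shifts, int N,
                                       int T, int C, int HW, int group_size) {
  std::vector<float> output(input.size(), 0.f);  // out-of-range slots stay 0
  const int group_channel = C / group_size;
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      const int t_shift = shifts[n * group_size + c / group_channel];
      for (int t = 0; t < T; ++t) {
        const int src_t = t - t_shift;
        if (src_t < 0 || src_t >= T) continue;  // zero padding at the edges
        for (int i = 0; i < HW; ++i) {
          output[((n * T + t) * C + c) * HW + i] =
              input[((n * T + src_t) * C + c) * HW + i];
        }
      }
    }
  }
  return output;
}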
@ -14,9 +14,9 @@

#include "mlu_common_helper.h"

void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target,
                                    Tensor weight, Tensor output,
                                    const float gamma, const float alpha) {
void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target, Tensor weight,
                                    Tensor output, const float gamma,
                                    const float alpha) {
  // params check
  TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
              "But now gamma is ", gamma, ".");

@ -82,15 +82,15 @@ void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target,
  auto handle = mluOpGetCurrentHandle();

  // launch kernel
  TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidForward(handle, prefer, reduction, input_desc.desc(),
                                input_ptr, target_desc.desc(), target_ptr,
                                weight_desc.desc(), weight_ptr, alpha, gamma,
                                output_desc.desc(), output_ptr));
  TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidForward(
      handle, prefer, reduction, input_desc.desc(), input_ptr,
      target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
      gamma, output_desc.desc(), output_ptr));
}

void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target,
                                     Tensor weight, Tensor output,
                                     const float gamma, const float alpha) {
void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, const float gamma,
                                     const float alpha) {
  // params check
  TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
              "But now gamma is ", gamma, ".");

@ -158,10 +158,10 @@ void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target,
  auto handle = mluOpGetCurrentHandle();

  // launch kernel
  TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidBackward(handle, prefer, reduction, input_desc.desc(),
                                 input_ptr, target_desc.desc(), target_ptr,
                                 weight_desc.desc(), weight_ptr, alpha, gamma,
                                 output_desc.desc(), output_ptr));
  TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidBackward(
      handle, prefer, reduction, input_desc.desc(), input_ptr,
      target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
      gamma, output_desc.desc(), output_ptr));
}

void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
@ -18,8 +18,8 @@
#include "pytorch_device_registry.hpp"

#define MLUOP_MAJOR 0
#define MLUOP_MINOR 7
#define MLUOP_PATCHLEVEL 1
#define MLUOP_MINOR 8
#define MLUOP_PATCHLEVEL 0

/*************************************************************************
 * This MACRO contains operations of simple tensor to mlu-tensor.
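The hunk above raises the mlu-ops version the bindings are built against from 0.7.1 to 0.8.0. As an illustration of how such macros are commonly consumed (an assumption about typical usage, not code taken from this diff), a runtime guard could compare the compile-time minimum against the library that is actually loaded:

// Hypothetical sketch: compare the compile-time minimum (the macros above)
// with the mlu-ops library loaded at runtime; names here are illustrative.
#include <mlu_op.h>
#include <stdexcept>
#include <string>

inline void check_mluop_version() {
  int major = 0, minor = 0, patch = 0;
  mluOpGetLibVersion(&major, &minor, &patch);  // query the loaded library
  const int built = MLUOP_MAJOR * 10000 + MLUOP_MINOR * 100 + MLUOP_PATCHLEVEL;
  const int loaded = major * 10000 + minor * 100 + patch;
  if (loaded < built) {
    throw std::runtime_error("mlu-ops " + std::to_string(major) + "." +
                             std::to_string(minor) + "." +
                             std::to_string(patch) +
                             " is older than the version MMCV was built for.");
  }
}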
@ -74,8 +74,8 @@ void RoIPointPool3dForwardMLUKernelLauncher(
              pts_feature.numel(), ".");

  // set contiguous
  auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      xyz, xyz.suggest_memory_format());
  auto xyz_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(xyz, xyz.suggest_memory_format());
  auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pts_feature, pts_feature.suggest_memory_format());
  auto boxes3d_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(

@ -92,13 +92,16 @@ void RoIPointPool3dForwardMLUKernelLauncher(
  auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
  auto boxes3d_impl = torch_mlu::getMluTensorImpl(boxes3d_contiguous);
  auto boxes3d_ptr = boxes3d_impl->cnnlMalloc();
  auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features_contiguous);
  auto pooled_features_impl =
      torch_mlu::getMluTensorImpl(pooled_features_contiguous);
  auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
  auto pooled_empty_flag_impl = torch_mlu::getMluTensorImpl(pooled_empty_flag_contiguous);
  auto pooled_empty_flag_impl =
      torch_mlu::getMluTensorImpl(pooled_empty_flag_contiguous);
  auto pooled_empty_flag_ptr = pooled_empty_flag_impl->cnnlMalloc();

  // create tensor descriptors
  MluOpTensorDescriptor xyz_desc, pts_feature_desc, boxes3d_desc, pooled_features_desc, pooled_empty_flag_desc;
  MluOpTensorDescriptor xyz_desc, pts_feature_desc, boxes3d_desc,
      pooled_features_desc, pooled_empty_flag_desc;
  xyz_desc.set(xyz_contiguous);
  pts_feature_desc.set(pts_feature_contiguous);
  boxes3d_desc.set(boxes3d_contiguous);

@ -108,10 +111,11 @@ void RoIPointPool3dForwardMLUKernelLauncher(
  // get workspace
  size_t workspace_size = 0;
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpGetRoiPointPool3dWorkspaceSize(handle, batch_size,
      pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz_desc.desc(),
      pts_feature_desc.desc(), boxes3d_desc.desc(), pooled_features_desc.desc(),
      pooled_empty_flag_desc.desc(), &workspace_size));
  TORCH_MLUOP_CHECK(mluOpGetRoiPointPool3dWorkspaceSize(
      handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
      xyz_desc.desc(), pts_feature_desc.desc(), boxes3d_desc.desc(),
      pooled_features_desc.desc(), pooled_empty_flag_desc.desc(),
      &workspace_size));

  auto workspace = at::empty(workspace_size, xyz.options().dtype(at::kByte));
  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);

@ -120,8 +124,8 @@ void RoIPointPool3dForwardMLUKernelLauncher(
      handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
      xyz_desc.desc(), xyz_ptr, pts_feature_desc.desc(), pts_feature_ptr,
      boxes3d_desc.desc(), boxes3d_ptr, workspace_ptr, workspace_size,
      pooled_features_desc.desc(), pooled_features_ptr, pooled_empty_flag_desc.desc(),
      (int *)pooled_empty_flag_ptr));
      pooled_features_desc.desc(), pooled_features_ptr,
      pooled_empty_flag_desc.desc(), (int *)pooled_empty_flag_ptr));
}

void roipoint_pool3d_forward_mlu(int batch_size, int pts_num, int boxes_num,
@ -9,65 +9,7 @@
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

void KernelTinShiftForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const void *input, const void *shifts, void *output, const int batch_size,
    const int time_size, const int channel_size, const int hw_size,
    const int group_size, const int group_channel,
    const cnrtDataType_t data_dtype, const int channel_per_core,
    const int max_number_hw_per_core, const int max_length_per_core);

void KernelTinShiftBackward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const void *grad_output, const void *shifts, void *grad_input,
    const int batch_size, const int time_size, const int channel_size,
    const int hw_size, const int group_size, const int group_channel,
    const cnrtDataType_t data_dtype, const int channel_per_core,
    const int max_number_hw_per_core, const int max_length_per_core);

// policy function
static void policyFunc(const Tensor &input, cnrtDim3_t *k_dim,
                       cnrtFunctionType_t *k_type, int *channel_per_core,
                       int *max_number_hw_per_core, int *max_length_per_core) {
  const int32_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  const int32_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
  const int core_num = core_limit * cluster_limit;
  const int batch_size = input.size(0);
  const int time_size = input.size(1);
  const int channel_size = input.size(2);
  const int hw_size = input.size(3);

  const size_t size_per_channel = time_size * hw_size * input.itemsize();
  *channel_per_core = nram_size / size_per_channel;
  int task_dim = 0;
  if (*channel_per_core == 0) {
    const size_t size_per_hw = hw_size * input.itemsize();
    *max_number_hw_per_core = nram_size / size_per_hw;
    if (*max_number_hw_per_core <= 0) {
      *max_length_per_core = nram_size / input.itemsize();
    }
    int tmp_max_number_hw_per_core =
        *max_number_hw_per_core > 0 ? *max_number_hw_per_core : 1;
    const int loop_time =
        (time_size / (tmp_max_number_hw_per_core)) +
        ((time_size % (tmp_max_number_hw_per_core)) > 0 ? 1 : 0);
    task_dim = batch_size * channel_size * loop_time < core_num
                   ? batch_size * channel_size * loop_time
                   : core_num;
  } else {
    task_dim = batch_size * channel_size < core_num ? batch_size * channel_size
                                                    : core_num;
  }

  k_dim->x = core_limit;
  k_dim->y = (task_dim / core_limit) > 0 ? (task_dim / core_limit) : 1;
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
}
#include "mlu_common_helper.h"

void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift,
                                      Tensor output) {
@ -89,40 +31,37 @@ void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift,
  if (input.size(1) == 0) {
    return;
  }
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  int channel_per_core = 0;
  int max_number_hw_per_core = 0;
  int max_length_per_core = 0;
  policyFunc(input, &k_dim, &k_type, &channel_per_core, &max_number_hw_per_core,
             &max_length_per_core);

  const int batch_size = input.size(0);
  const int time_size = input.size(1);
  const int channel_size = input.size(2);
  const int hw_size = input.size(3);
  const int group_size = shift.size(1);
  int group_channel = channel_size / group_size;
  // set contiguous
  auto input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      input, input.suggest_memory_format());
  auto shift_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      shift, shift.suggest_memory_format());
  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      output, output.suggest_memory_format());

  // get tensor impl
  auto input_impl = torch_mlu::getMluTensorImpl(input);
  auto shift_impl = torch_mlu::getMluTensorImpl(shift);
  auto output_impl = torch_mlu::getMluTensorImpl(output);

  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  auto input_impl = torch_mlu::getMluTensorImpl(input_contiguous);
  auto shift_impl = torch_mlu::getMluTensorImpl(shift_contiguous);
  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);

  // get the mlu ptr
  auto input_ptr = input_impl->cnnlMalloc();
  auto shift_ptr = shift_impl->cnnlMalloc();
  auto output_ptr = output_impl->cnnlMalloc();

  cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(input.dtype());
  // set tensor descriptor
  MluOpTensorDescriptor input_desc, shift_desc, output_desc;
  input_desc.set(input_contiguous);
  shift_desc.set(shift_contiguous);
  output_desc.set(output_contiguous);

  KernelTinShiftForward(k_dim, k_type, queue, input_ptr, shift_ptr, output_ptr,
                        batch_size, time_size, channel_size, hw_size,
                        group_size, group_channel, data_dtype, channel_per_core,
                        max_number_hw_per_core, max_length_per_core);
  // get current handle
  auto handle = mluOpGetCurrentHandle();

  TORCH_MLUOP_CHECK(mluOpTinShiftForward(handle, input_desc.desc(), input_ptr,
                                         shift_desc.desc(), shift_ptr,
                                         output_desc.desc(), output_ptr));
}

void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift,
@ -148,41 +87,37 @@ void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift,
  if (grad_output.size(1) == 0) {
    return;
  }
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  int channel_per_core = 0;
  int max_number_hw_per_core = 0;
  int max_length_per_core = 0;
  policyFunc(grad_output, &k_dim, &k_type, &channel_per_core,
             &max_number_hw_per_core, &max_length_per_core);

  const int batch_size = grad_output.size(0);
  const int time_size = grad_output.size(1);
  const int channel_size = grad_output.size(2);
  const int hw_size = grad_output.size(3);
  const int group_size = shift.size(1);
  int group_channel = channel_size / group_size;
  // set contiguous
  auto grad_output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      grad_output, grad_output.suggest_memory_format());
  auto shift_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      shift, shift.suggest_memory_format());
  auto grad_input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      grad_input, grad_input.suggest_memory_format());

  // get tensor impl
  auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output);
  auto shift_impl = torch_mlu::getMluTensorImpl(shift);
  auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input);

  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_contiguous);
  auto shift_impl = torch_mlu::getMluTensorImpl(shift_contiguous);
  auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_contiguous);

  // get the mlu ptr
  auto grad_output_ptr = grad_output_impl->cnnlMalloc();
  auto shift_ptr = shift_impl->cnnlMalloc();
  auto grad_input_ptr = grad_input_impl->cnnlMalloc();

  cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(grad_output.dtype());
  // set tensor descriptor
  MluOpTensorDescriptor grad_output_desc, shift_desc, grad_input_desc;
  grad_output_desc.set(grad_output_contiguous);
  shift_desc.set(shift_contiguous);
  grad_input_desc.set(grad_input_contiguous);

  KernelTinShiftBackward(k_dim, k_type, queue, grad_output_ptr, shift_ptr,
                         grad_input_ptr, batch_size, time_size, channel_size,
                         hw_size, group_size, group_channel, data_dtype,
                         channel_per_core, max_number_hw_per_core,
                         max_length_per_core);
  // get current handle
  auto handle = mluOpGetCurrentHandle();

  TORCH_MLUOP_CHECK(mluOpTinShiftBackward(
      handle, grad_output_desc.desc(), grad_output_ptr, shift_desc.desc(),
      shift_ptr, grad_input_desc.desc(), grad_input_ptr));
}

void tin_shift_forward_mlu(Tensor input, Tensor shift, Tensor output) {