[Refactor] Replace tin_shift op of MLU backend with mlu-ops (#2911)

pull/2922/head
Chris Jiang 2023-08-28 17:02:51 +08:00 committed by GitHub
parent a0a17050f2
commit 2491dbb7d0
5 changed files with 74 additions and 442 deletions
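In short, this change deletes the hand-written BANG kernel for tin_shift together with its launch-policy code, and makes the PyTorch launchers call the tin_shift operators provided by the mlu-ops library instead. For orientation, a minimal sketch of the new forward path, condensed from the diff below (every name is taken from that diff):

// Condensed from TINShiftForwardMLUKernelLauncher below: a descriptor-based
// mlu-ops call replaces the hand-written MLUUnion1KernelTinShift launch.
MluOpTensorDescriptor input_desc, shift_desc, output_desc;
input_desc.set(input_contiguous);
shift_desc.set(shift_contiguous);
output_desc.set(output_contiguous);
auto handle = mluOpGetCurrentHandle();
TORCH_MLUOP_CHECK(mluOpTinShiftForward(handle, input_desc.desc(), input_ptr,
                                        shift_desc.desc(), shift_ptr,
                                        output_desc.desc(), output_ptr));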

@@ -1,307 +0,0 @@
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
__nram__ char data_nram[MAX_NRAM_SIZE];
template <typename T>
__mlu_func__ void mluMultiKernelTinShift(
const T *input, const int *shifts, T *output, const int batch_size,
const int time_size, const int channel_size, const int hw_size,
const int group_size, const int group_channel) {
for (int cur_channel_index = taskId;
cur_channel_index < batch_size * channel_size;
cur_channel_index += taskDim) {
int n_index = cur_channel_index / channel_size;
int group_id = cur_channel_index % channel_size / group_channel;
int t_shift = shifts[n_index * group_size + group_id];
int index = cur_channel_index % channel_size * hw_size +
n_index * time_size * channel_size * hw_size;
__bang_write_value(data_nram, MAX_NRAM_SIZE, (char)0);
__asm__ volatile("sync;");
if (abs(t_shift) >= time_size) {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
time_size - 1);
} else {
if (t_shift > 0) {
__memcpy(data_nram + t_shift * hw_size * sizeof(T), input + index,
hw_size * sizeof(T), GDRAM2NRAM, hw_size * sizeof(T),
channel_size * hw_size * sizeof(T), time_size - 1 - t_shift);
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
time_size - 1);
} else {
__memcpy(data_nram, input + (index - t_shift * channel_size * hw_size),
hw_size * sizeof(T), GDRAM2NRAM, hw_size * sizeof(T),
channel_size * hw_size * sizeof(T), time_size - 1 + t_shift);
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
time_size - 1);
}
}
__asm__ volatile("sync;");
}
}
template <typename T>
__mlu_func__ void mluHwSplit(const T *input, const int t_shift,
const int time_size, const int hw_size,
const int channel_size, const int index,
const int cur_sequence_index,
const int max_length_per_core, T *output) {
for (int cur_index = index; cur_index < index + hw_size;
cur_index += max_length_per_core) {
int memcpy_size = max_length_per_core;
if (cur_index + max_length_per_core > index + hw_size) {
memcpy_size = index + hw_size - cur_index;
}
if (cur_sequence_index - t_shift < 0 ||
cur_sequence_index - t_shift >= time_size) {
__memcpy(output + cur_index, data_nram, memcpy_size * sizeof(T),
NRAM2GDRAM);
} else {
__memcpy(data_nram, input + cur_index - t_shift * channel_size * hw_size,
memcpy_size * sizeof(T), GDRAM2NRAM);
__memcpy(output + cur_index, data_nram, memcpy_size * sizeof(T),
NRAM2GDRAM);
}
__asm__ volatile("sync;");
}
}
template <typename T>
__mlu_func__ void mluMultiKernelTinShiftSplitSequence(
const T *input, const int *shifts, T *output, const int batch_size,
const int time_size, const int channel_size, const int hw_size,
const int group_size, const int group_channel,
const int max_number_hw_per_core, const int max_length_per_core) {
const int tmp_max_number_hw_per_core =
max_number_hw_per_core > 0 ? max_number_hw_per_core : 1;
const int loop_time = time_size / tmp_max_number_hw_per_core +
((time_size % tmp_max_number_hw_per_core) > 0 ? 1 : 0);
int segmentime_size = tmp_max_number_hw_per_core;
int res_segment = time_size % tmp_max_number_hw_per_core;
for (int cur_segment_index = taskId;
cur_segment_index < loop_time * batch_size * channel_size;
cur_segment_index += taskDim) {
int n_index = cur_segment_index / loop_time / channel_size;
int group_id = cur_segment_index / loop_time % channel_size / group_channel;
int t_shift = shifts[n_index * group_size + group_id];
int index = n_index * time_size * channel_size * hw_size +
(cur_segment_index / loop_time % channel_size) * hw_size +
cur_segment_index % loop_time * segmentime_size * hw_size *
channel_size;
char *dst_gdram2nram = data_nram;
const T *src_gdram2nram = input + index;
int count_gdram2nram = -1;
int count_nram2gdram = -1;
int next_sequence_index =
index / hw_size / channel_size % time_size + segmentime_size;
int cur_sequence_index = index / hw_size / channel_size % time_size;
__bang_write_value(data_nram, MAX_NRAM_SIZE, (char)0);
__asm__ volatile("sync;");
if (max_number_hw_per_core == 0) {
mluHwSplit(input, t_shift, time_size, hw_size, channel_size, index,
cur_sequence_index, max_length_per_core, output);
continue;
}
if (abs(t_shift) >= time_size) {
if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
res_segment - 1);
} else {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
segmentime_size - 1);
}
continue;
}
if (t_shift == 0) {
if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
dst_gdram2nram = data_nram;
src_gdram2nram = input + index;
count_gdram2nram = res_segment - 1;
count_nram2gdram = res_segment - 1;
} else {
dst_gdram2nram = data_nram;
src_gdram2nram = input + index;
count_gdram2nram = segmentime_size - 1;
count_nram2gdram = segmentime_size - 1;
}
} else if (t_shift > 0) {
int first_index_cur_channel =
n_index * time_size * channel_size * hw_size +
(cur_segment_index / loop_time % channel_size) * hw_size;
if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
dst_gdram2nram = data_nram;
src_gdram2nram =
input +
(index - t_shift * channel_size * hw_size < first_index_cur_channel
? first_index_cur_channel
: index - t_shift * channel_size * hw_size);
count_gdram2nram = res_segment - 1;
count_nram2gdram = res_segment - 1;
if (cur_sequence_index < t_shift && t_shift < next_sequence_index) {
dst_gdram2nram =
data_nram + t_shift % segmentime_size * hw_size * sizeof(T);
count_gdram2nram = res_segment - (t_shift - cur_sequence_index) - 1;
}
} else {
if (t_shift >= next_sequence_index) {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
segmentime_size - 1);
continue;
} else if (cur_sequence_index < t_shift &&
t_shift < next_sequence_index) {
dst_gdram2nram =
data_nram + t_shift % segmentime_size * hw_size * sizeof(T);
src_gdram2nram = input + first_index_cur_channel;
count_gdram2nram = segmentime_size - (t_shift % segmentime_size) - 1;
count_nram2gdram = segmentime_size - 1;
} else {
dst_gdram2nram = data_nram;
src_gdram2nram = input + index - t_shift * channel_size * hw_size;
count_gdram2nram = segmentime_size - 1;
count_nram2gdram = segmentime_size - 1;
}
}
} else {
int offset_index = time_size + t_shift;
if (cur_sequence_index >= offset_index) {
if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
res_segment - 1);
continue;
} else {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
segmentime_size - 1);
continue;
}
} else {
dst_gdram2nram = data_nram;
src_gdram2nram = input + index - t_shift * channel_size * hw_size;
if (cur_sequence_index - t_shift + segmentime_size < time_size) {
count_gdram2nram = segmentime_size - 1;
count_nram2gdram = segmentime_size - 1;
} else {
count_gdram2nram = time_size - (cur_sequence_index - t_shift) - 1;
count_nram2gdram =
(segmentime_size - 1) < (time_size - cur_sequence_index - 1)
? (segmentime_size - 1)
: (time_size - cur_sequence_index - 1);
}
}
}
__memcpy(dst_gdram2nram, src_gdram2nram, hw_size * sizeof(T), GDRAM2NRAM,
hw_size * sizeof(T), channel_size * hw_size * sizeof(T),
count_gdram2nram);
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
count_nram2gdram);
__asm__ volatile("sync;");
}
}
__mlu_entry__ void MLUUnion1KernelTinShift(
const void *input, const void *shifts, void *output, const int batch_size,
const int time_size, const int channel_size, const int hw_size,
const int group_size, const int group_channel,
const cnrtDataType_t data_dtype) {
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
switch (data_dtype) {
case CNRT_FLOAT16: {
mluMultiKernelTinShift((half *)input, (const int *)shifts, (half *)output,
batch_size, time_size, channel_size, hw_size,
group_size, group_channel);
}; break;
case CNRT_FLOAT32: {
mluMultiKernelTinShift((float *)input, (const int *)shifts,
(float *)output, batch_size, time_size,
channel_size, hw_size, group_size, group_channel);
}; break;
default: { return; }
}
}
__mlu_entry__ void MLUUnion1KernelTinShiftSplitSequence(
const void *input, const void *shifts, void *output, const int batch_size,
const int time_size, const int channel_size, const int hw_size,
const int group_size, const int group_channel,
const int max_number_hw_per_core, const int max_length_per_core,
const cnrtDataType_t data_dtype) {
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
switch (data_dtype) {
case CNRT_FLOAT16: {
mluMultiKernelTinShiftSplitSequence(
(half *)input, (const int *)shifts, (half *)output, batch_size,
time_size, channel_size, hw_size, group_size, group_channel,
max_number_hw_per_core, max_length_per_core);
}; break;
case CNRT_FLOAT32: {
mluMultiKernelTinShiftSplitSequence(
(float *)input, (const int *)shifts, (float *)output, batch_size,
time_size, channel_size, hw_size, group_size, group_channel,
max_number_hw_per_core, max_length_per_core);
}; break;
default: { return; }
}
}
void KernelTinShiftForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *input, const void *shifts, void *output, const int batch_size,
const int time_size, const int channel_size, const int hw_size,
const int group_size, const int group_channel,
const cnrtDataType_t data_dtype, const int channel_per_core,
const int max_number_hw_per_core, const int max_length_per_core) {
if (channel_per_core >= 1) {
MLUUnion1KernelTinShift<<<k_dim, k_type, queue>>>(
input, shifts, output, batch_size, time_size, channel_size, hw_size,
group_size, group_channel, data_dtype);
} else {
MLUUnion1KernelTinShiftSplitSequence<<<k_dim, k_type, queue>>>(
input, shifts, output, batch_size, time_size, channel_size, hw_size,
group_size, group_channel, max_number_hw_per_core, max_length_per_core,
data_dtype);
}
}
void KernelTinShiftBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *grad_output, const void *shifts, void *grad_input,
const int batch_size, const int time_size, const int channel_size,
const int hw_size, const int group_size, const int group_channel,
const cnrtDataType_t data_dtype, const int channel_per_core,
const int max_number_hw_per_core, const int max_length_per_core) {
if (channel_per_core >= 1) {
MLUUnion1KernelTinShift<<<k_dim, k_type, queue>>>(
grad_output, shifts, grad_input, batch_size, time_size, channel_size,
hw_size, group_size, group_channel, data_dtype);
} else {
MLUUnion1KernelTinShiftSplitSequence<<<k_dim, k_type, queue>>>(
grad_output, shifts, grad_input, batch_size, time_size, channel_size,
hw_size, group_size, group_channel, max_number_hw_per_core,
max_length_per_core, data_dtype);
}
}
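For reference, here is my reading of what the kernel deleted above computes, written as a plain CPU loop (an illustrative sketch, not part of this commit): every channel group is shifted along the time axis by its per-sample offset from shifts, and frames whose source index falls outside [0, time_size) are zero-filled.

// CPU reference sketch for the deleted MLU kernel (illustrative only).
// input/output are flattened [batch, time, channel, hw] buffers.
#include <cstddef>
#include <vector>

template <typename T>
void tin_shift_reference(const std::vector<T> &input,
                         const std::vector<int> &shifts, std::vector<T> &output,
                         int batch_size, int time_size, int channel_size,
                         int hw_size, int group_size) {
  const int group_channel = channel_size / group_size;
  output.assign(input.size(), T(0));
  for (int n = 0; n < batch_size; ++n) {
    for (int t = 0; t < time_size; ++t) {
      for (int c = 0; c < channel_size; ++c) {
        const int t_shift = shifts[n * group_size + c / group_channel];
        const int t_src = t - t_shift;
        if (t_src < 0 || t_src >= time_size) continue;  // stays zero
        const std::size_t dst =
            ((static_cast<std::size_t>(n) * time_size + t) * channel_size + c) *
            hw_size;
        const std::size_t src =
            ((static_cast<std::size_t>(n) * time_size + t_src) * channel_size +
             c) * hw_size;
        for (int i = 0; i < hw_size; ++i) output[dst + i] = input[src + i];
      }
    }
  }
}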

@@ -14,9 +14,9 @@
#include "mlu_common_helper.h"
void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target,
Tensor weight, Tensor output,
const float gamma, const float alpha) {
void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target, Tensor weight,
Tensor output, const float gamma,
const float alpha) {
// params check
TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
"But now gamma is ", gamma, ".");
@@ -82,15 +82,15 @@ void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target,
auto handle = mluOpGetCurrentHandle();
// launch kernel
TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidForward(handle, prefer, reduction, input_desc.desc(),
input_ptr, target_desc.desc(), target_ptr,
weight_desc.desc(), weight_ptr, alpha, gamma,
output_desc.desc(), output_ptr));
TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidForward(
handle, prefer, reduction, input_desc.desc(), input_ptr,
target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
gamma, output_desc.desc(), output_ptr));
}
void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target,
Tensor weight, Tensor output,
const float gamma, const float alpha) {
void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target, Tensor weight,
Tensor output, const float gamma,
const float alpha) {
// params check
TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
"But now gamma is ", gamma, ".");
@@ -158,10 +158,10 @@ void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target,
auto handle = mluOpGetCurrentHandle();
// launch kernel
TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidBackward(handle, prefer, reduction, input_desc.desc(),
input_ptr, target_desc.desc(), target_ptr,
weight_desc.desc(), weight_ptr, alpha, gamma,
output_desc.desc(), output_ptr));
TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidBackward(
handle, prefer, reduction, input_desc.desc(), input_ptr,
target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
gamma, output_desc.desc(), output_ptr));
}
void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,

@@ -18,8 +18,8 @@
#include "pytorch_device_registry.hpp"
#define MLUOP_MAJOR 0
#define MLUOP_MINOR 7
#define MLUOP_PATCHLEVEL 1
#define MLUOP_MINOR 8
#define MLUOP_PATCHLEVEL 0
/*************************************************************************
* This MACRO contains operations of simple tensor to mlu-tensor.
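The two defines above bump the tracked mlu-ops version from 0.7.1 to 0.8.0, presumably the release that provides the mluOpTinShiftForward/Backward entry points used later in this commit. A hypothetical compile-time guard built on these macros (not part of the commit) could look like:

// Hypothetical guard, assuming the MLUOP_* macros reflect the linked
// mlu-ops release: fail the build if it is older than 0.8.0.
#if (MLUOP_MAJOR * 10000 + MLUOP_MINOR * 100 + MLUOP_PATCHLEVEL) < 800
#error "tin_shift on MLU requires mlu-ops >= 0.8.0"
#endif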

@@ -74,8 +74,8 @@ void RoIPointPool3dForwardMLUKernelLauncher(
pts_feature.numel(), ".");
// set contiguous
auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
xyz, xyz.suggest_memory_format());
auto xyz_contiguous =
torch_mlu::cnnl::ops::cnnl_contiguous(xyz, xyz.suggest_memory_format());
auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
pts_feature, pts_feature.suggest_memory_format());
auto boxes3d_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
@@ -92,13 +92,16 @@ void RoIPointPool3dForwardMLUKernelLauncher(
auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
auto boxes3d_impl = torch_mlu::getMluTensorImpl(boxes3d_contiguous);
auto boxes3d_ptr = boxes3d_impl->cnnlMalloc();
auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features_contiguous);
auto pooled_features_impl =
torch_mlu::getMluTensorImpl(pooled_features_contiguous);
auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
auto pooled_empty_flag_impl = torch_mlu::getMluTensorImpl(pooled_empty_flag_contiguous);
auto pooled_empty_flag_impl =
torch_mlu::getMluTensorImpl(pooled_empty_flag_contiguous);
auto pooled_empty_flag_ptr = pooled_empty_flag_impl->cnnlMalloc();
// create tensor descriptors
MluOpTensorDescriptor xyz_desc, pts_feature_desc, boxes3d_desc, pooled_features_desc, pooled_empty_flag_desc;
MluOpTensorDescriptor xyz_desc, pts_feature_desc, boxes3d_desc,
pooled_features_desc, pooled_empty_flag_desc;
xyz_desc.set(xyz_contiguous);
pts_feature_desc.set(pts_feature_contiguous);
boxes3d_desc.set(boxes3d_contiguous);
@@ -108,10 +111,11 @@ void RoIPointPool3dForwardMLUKernelLauncher(
// get workspace
size_t workspace_size = 0;
auto handle = mluOpGetCurrentHandle();
TORCH_MLUOP_CHECK(mluOpGetRoiPointPool3dWorkspaceSize(handle, batch_size,
pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz_desc.desc(),
pts_feature_desc.desc(), boxes3d_desc.desc(), pooled_features_desc.desc(),
pooled_empty_flag_desc.desc(), &workspace_size));
TORCH_MLUOP_CHECK(mluOpGetRoiPointPool3dWorkspaceSize(
handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
xyz_desc.desc(), pts_feature_desc.desc(), boxes3d_desc.desc(),
pooled_features_desc.desc(), pooled_empty_flag_desc.desc(),
&workspace_size));
auto workspace = at::empty(workspace_size, xyz.options().dtype(at::kByte));
auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
@@ -120,8 +124,8 @@ void RoIPointPool3dForwardMLUKernelLauncher(
handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
xyz_desc.desc(), xyz_ptr, pts_feature_desc.desc(), pts_feature_ptr,
boxes3d_desc.desc(), boxes3d_ptr, workspace_ptr, workspace_size,
pooled_features_desc.desc(), pooled_features_ptr, pooled_empty_flag_desc.desc(),
(int *)pooled_empty_flag_ptr));
pooled_features_desc.desc(), pooled_features_ptr,
pooled_empty_flag_desc.desc(), (int *)pooled_empty_flag_ptr));
}
void roipoint_pool3d_forward_mlu(int batch_size, int pts_num, int boxes_num,

@@ -9,65 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelTinShiftForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *input, const void *shifts, void *output, const int batch_size,
const int time_size, const int channel_size, const int hw_size,
const int group_size, const int group_channel,
const cnrtDataType_t data_dtype, const int channel_per_core,
const int max_number_hw_per_core, const int max_length_per_core);
void KernelTinShiftBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *grad_output, const void *shifts, void *grad_input,
const int batch_size, const int time_size, const int channel_size,
const int hw_size, const int group_size, const int group_channel,
const cnrtDataType_t data_dtype, const int channel_per_core,
const int max_number_hw_per_core, const int max_length_per_core);
// policy function
static void policyFunc(const Tensor &input, cnrtDim3_t *k_dim,
cnrtFunctionType_t *k_type, int *channel_per_core,
int *max_number_hw_per_core, int *max_length_per_core) {
const int32_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
const int32_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
const int core_num = core_limit * cluster_limit;
const int batch_size = input.size(0);
const int time_size = input.size(1);
const int channel_size = input.size(2);
const int hw_size = input.size(3);
const size_t size_per_channel = time_size * hw_size * input.itemsize();
*channel_per_core = nram_size / size_per_channel;
int task_dim = 0;
if (*channel_per_core == 0) {
const size_t size_per_hw = hw_size * input.itemsize();
*max_number_hw_per_core = nram_size / size_per_hw;
if (*max_number_hw_per_core <= 0) {
*max_length_per_core = nram_size / input.itemsize();
}
int tmp_max_number_hw_per_core =
*max_number_hw_per_core > 0 ? *max_number_hw_per_core : 1;
const int loop_time =
(time_size / (tmp_max_number_hw_per_core)) +
((time_size % (tmp_max_number_hw_per_core)) > 0 ? 1 : 0);
task_dim = batch_size * channel_size * loop_time < core_num
? batch_size * channel_size * loop_time
: core_num;
} else {
task_dim = batch_size * channel_size < core_num ? batch_size * channel_size
: core_num;
}
k_dim->x = core_limit;
k_dim->y = (task_dim / core_limit) > 0 ? (task_dim / core_limit) : 1;
k_dim->z = 1;
*k_type = CNRT_FUNC_TYPE_UNION1;
}
#include "mlu_common_helper.h"
void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift,
Tensor output) {
@@ -89,40 +31,37 @@ void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift,
if (input.size(1) == 0) {
return;
}
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
int channel_per_core = 0;
int max_number_hw_per_core = 0;
int max_length_per_core = 0;
policyFunc(input, &k_dim, &k_type, &channel_per_core, &max_number_hw_per_core,
&max_length_per_core);
const int batch_size = input.size(0);
const int time_size = input.size(1);
const int channel_size = input.size(2);
const int hw_size = input.size(3);
const int group_size = shift.size(1);
int group_channel = channel_size / group_size;
// set contiguous
auto input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
input, input.suggest_memory_format());
auto shift_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
shift, shift.suggest_memory_format());
auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
output, output.suggest_memory_format());
// get tensor impl
auto input_impl = torch_mlu::getMluTensorImpl(input);
auto shift_impl = torch_mlu::getMluTensorImpl(shift);
auto output_impl = torch_mlu::getMluTensorImpl(output);
// get compute queue
auto queue = torch_mlu::getCurQueue();
auto input_impl = torch_mlu::getMluTensorImpl(input_contiguous);
auto shift_impl = torch_mlu::getMluTensorImpl(shift_contiguous);
auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
// get the mlu ptr
auto input_ptr = input_impl->cnnlMalloc();
auto shift_ptr = shift_impl->cnnlMalloc();
auto output_ptr = output_impl->cnnlMalloc();
cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(input.dtype());
// set tensor descriptor
MluOpTensorDescriptor input_desc, shift_desc, output_desc;
input_desc.set(input_contiguous);
shift_desc.set(shift_contiguous);
output_desc.set(output_contiguous);
KernelTinShiftForward(k_dim, k_type, queue, input_ptr, shift_ptr, output_ptr,
batch_size, time_size, channel_size, hw_size,
group_size, group_channel, data_dtype, channel_per_core,
max_number_hw_per_core, max_length_per_core);
// get current handle
auto handle = mluOpGetCurrentHandle();
TORCH_MLUOP_CHECK(mluOpTinShiftForward(handle, input_desc.desc(), input_ptr,
shift_desc.desc(), shift_ptr,
output_desc.desc(), output_ptr));
}
void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift,
@@ -148,41 +87,37 @@ void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift,
if (grad_output.size(1) == 0) {
return;
}
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
int channel_per_core = 0;
int max_number_hw_per_core = 0;
int max_length_per_core = 0;
policyFunc(grad_output, &k_dim, &k_type, &channel_per_core,
&max_number_hw_per_core, &max_length_per_core);
const int batch_size = grad_output.size(0);
const int time_size = grad_output.size(1);
const int channel_size = grad_output.size(2);
const int hw_size = grad_output.size(3);
const int group_size = shift.size(1);
int group_channel = channel_size / group_size;
// set contiguous
auto grad_output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
grad_output, grad_output.suggest_memory_format());
auto shift_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
shift, shift.suggest_memory_format());
auto grad_input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
grad_input, grad_input.suggest_memory_format());
// get tensor impl
auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output);
auto shift_impl = torch_mlu::getMluTensorImpl(shift);
auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input);
// get compute queue
auto queue = torch_mlu::getCurQueue();
auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_contiguous);
auto shift_impl = torch_mlu::getMluTensorImpl(shift_contiguous);
auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_contiguous);
// get the mlu ptr
auto grad_output_ptr = grad_output_impl->cnnlMalloc();
auto shift_ptr = shift_impl->cnnlMalloc();
auto grad_input_ptr = grad_input_impl->cnnlMalloc();
cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(grad_output.dtype());
// set tensor descriptor
MluOpTensorDescriptor grad_output_desc, shift_desc, grad_input_desc;
grad_output_desc.set(grad_output_contiguous);
shift_desc.set(shift_contiguous);
grad_input_desc.set(grad_input_contiguous);
KernelTinShiftBackward(k_dim, k_type, queue, grad_output_ptr, shift_ptr,
grad_input_ptr, batch_size, time_size, channel_size,
hw_size, group_size, group_channel, data_dtype,
channel_per_core, max_number_hw_per_core,
max_length_per_core);
// get current handle
auto handle = mluOpGetCurrentHandle();
TORCH_MLUOP_CHECK(mluOpTinShiftBackward(
handle, grad_output_desc.desc(), grad_output_ptr, shift_desc.desc(),
shift_ptr, grad_input_desc.desc(), grad_input_ptr));
}
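The wrapper that begins below (tin_shift_forward_mlu) and its backward counterpart are what mmcv's generic dispatch calls into; assuming the usual pytorch_device_registry.hpp pattern, the registration at the bottom of this file (not shown in this hunk) would look roughly like:

// Assumed registration sketch: route the generic *_impl entry points to the
// MLU wrappers defined in this file.
void tin_shift_forward_impl(Tensor input, Tensor shift, Tensor output);
void tin_shift_backward_impl(Tensor grad_output, Tensor shift,
                             Tensor grad_input);
REGISTER_DEVICE_IMPL(tin_shift_forward_impl, MLU, tin_shift_forward_mlu);
REGISTER_DEVICE_IMPL(tin_shift_backward_impl, MLU, tin_shift_backward_mlu);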
void tin_shift_forward_mlu(Tensor input, Tensor shift, Tensor output) {