mirror of https://github.com/open-mmlab/mmcv.git
[Refactor] Replace the implementation of psa_mask with mlu-ops. (#2756)
parent 8ca930cc72
commit 515d5416a0
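In short, the launchers stop dispatching the hand-written BANG kernel below and instead hand channels-last (NHWC) tensors to the MLU-Ops library. The fragment that follows is a condensed sketch of that new call pattern, pieced together from the hunks in this diff; the wrapper name psamask_forward_via_mluops is made up for illustration, and it assumes x and y are already channels-last MLU tensors of the expected shapes. It is not a file added by this commit.

// Condensed sketch (illustrative only) of the call pattern the refactored
// forward launcher uses; MluOpTensorDescriptor, mluOpGetCurrentHandle() and
// mluOpPsamaskForward come from the "mlu_common_helper.h" include added below.
#include "mlu_common_helper.h"

void psamask_forward_via_mluops(const Tensor x, Tensor y, int psa_type,
                                int h_mask, int w_mask) {
  // Describe both tensors to MLU-Ops in channels-last (NHWC) layout.
  MluOpTensorDescriptor x_desc, y_desc;
  x_desc.set_with_layout(x, MLUOP_LAYOUT_NHWC);
  y_desc.set_with_layout(y, MLUOP_LAYOUT_NHWC);

  // Raw device pointers for the MLU tensors.
  auto x_ptr = torch_mlu::getMluTensorImpl(x)->cnnlMalloc();
  auto y_ptr = torch_mlu::getMluTensorImpl(y)->cnnlMalloc();

  // One library call replaces the hand-written BANG kernel together with the
  // policyFunc partitioning and findLimit NRAM budgeting deleted in this diff.
  auto handle = mluOpGetCurrentHandle();
  mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr, h_mask, w_mask,
                      y_desc.desc(), y_ptr);
}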
@@ -1,615 +0,0 @@
/*************************************************************************
 * Copyright (C) 2022 Cambricon.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "common_mlu_helper.hpp"
#include "psamask_utils.hpp"

#define COMPUTE_COUNT_ALIGN 64

__nram__ char buf[MAX_NRAM_SIZE];

template <typename T>
__mlu_func__ void swap(T &a, T &b) {
  T tmp = a;
  a = b;
  b = tmp;
}

template <typename T>
__mlu_func__ void storeDataFromNramToDram(T *dst, const T *src,
                                          const PositionInCore &position,
                                          const Shape &shape_full) {
  int n_offset = shape_full.h * shape_full.w * shape_full.c;
  int h_offset = shape_full.w * shape_full.c;
  int w_offset = shape_full.c;
  int n_seg = position.n_end - position.n_start;
  int h_seg = position.h_end - position.h_start;
  int w_seg = position.w_end - position.w_start;
  int size = h_seg * w_seg * shape_full.c;

  __memcpy(dst + position.n_start * n_offset + position.h_start * h_offset +
               position.w_start * w_offset,
           src, size * sizeof(T), NRAM2GDRAM, n_offset * sizeof(T),
           size * sizeof(T), n_seg - 1);
}

template <typename T>
__mlu_func__ void loadDataFromDramToNram(T *dst, const T *src,
                                         const PositionInCore &position,
                                         const Shape &shape_full) {
  int n_offset = shape_full.h * shape_full.w * shape_full.c;
  int h_offset = shape_full.w * shape_full.c;
  int w_offset = shape_full.c;
  int n_seg = position.n_end - position.n_start;
  int h_seg = position.h_end - position.h_start;
  int w_seg = position.w_end - position.w_start;
  int size = h_seg * w_seg * shape_full.c;

  __memcpy(dst, src + position.n_start * n_offset +
                    position.h_start * h_offset + position.w_start * w_offset,
           size * sizeof(T), GDRAM2NRAM, size * sizeof(T), n_offset * sizeof(T),
           n_seg - 1);
}

// transpose the data from A*B*C*(D*E) to A*D*E*(B*C)
template <typename T>
__mlu_func__ void transposeData(T *dst, T *src, const Shape &shape_seg) {
  int align_c = CEIL_ALIGN(shape_seg.c, COMPUTE_COUNT_ALIGN / sizeof(T));
  int align_hw =
      CEIL_ALIGN(shape_seg.h * shape_seg.w, COMPUTE_COUNT_ALIGN / sizeof(T));
  for (int i = 0; i < shape_seg.n; ++i) {
    __bang_transpose(dst, src, align_hw, align_c);
    dst += align_hw * align_c;
    src += align_hw * align_c;
  }
}

template <typename T>
__mlu_func__ void psamaskCollectForward(
    const T *x_dram, T *y_dram, const PositionInCore &position,
    const Shape &x_full, const Shape &y_full, const Shape &shape_seg,
    const int h_mask, const int w_mask, const int half_h_mask,
    const int half_w_mask) {
  T *x_nram = (T *)buf;
  T *y_nram =
      x_nram + CEIL_ALIGN(shape_seg.n * shape_seg.h * shape_seg.w * x_full.c,
                          COMPUTE_COUNT_ALIGN / sizeof(T));
  loadDataFromDramToNram(x_nram, x_dram, position, x_full);

  // fill zeros to output
  int elem_count =
      CEIL_ALIGN(shape_seg.n * shape_seg.h * shape_seg.w * y_full.c,
                 NFU_ALIGN_SIZE / sizeof(T));
  __bang_write_value(y_nram, elem_count, (T)0);

  int y_n_offset = shape_seg.h * shape_seg.w * shape_seg.c;
  int y_h_offset = shape_seg.w * shape_seg.c;
  int y_w_offset = shape_seg.c;
  int x_n_offset = shape_seg.h * shape_seg.w * x_full.c;
  int y_c_offset = 1;
  int x_h_offset = shape_seg.w * x_full.c;
  int x_w_offset = x_full.c;
  int x_c_offset = 1;
  int x_start = 0;
  int y_start = 0;
  for (int nidx = 0; nidx < shape_seg.n; ++nidx) {
    for (int hidx = 0; hidx < shape_seg.h; ++hidx) {
      for (int widx = 0; widx < shape_seg.w; ++widx) {
        int h_abs = hidx + position.h_start;
        int w_abs = widx + position.w_start;
        int y_offset = y_start;
        int x_offset = x_start;
        y_offset += hidx * y_h_offset + widx * y_w_offset;
        x_offset += hidx * x_h_offset + widx * x_w_offset;

        const int hstart = half_h_mask - h_abs > 0 ? half_h_mask - h_abs : 0;
        const int hend = x_full.h + half_h_mask - h_abs < h_mask
                             ? x_full.h + half_h_mask - h_abs
                             : h_mask;
        const int wstart = half_w_mask - w_abs > 0 ? half_w_mask - w_abs : 0;
        const int wend = x_full.w + half_w_mask - w_abs < w_mask
                             ? x_full.w + half_w_mask - w_abs
                             : w_mask;
        // (h, w ) with mask-indexed
        // (h + hidx - half_h_mask, w + widx - half_w_mask) with feature-indexed
        y_offset += ((hstart + h_abs - half_h_mask) * x_full.w + wstart +
                     w_abs - half_w_mask) *
                    y_c_offset;
        x_offset += (hstart * w_mask + wstart) * x_c_offset;
        int count = wend - wstart;
        __memcpy(y_nram + y_offset, x_nram + x_offset, count * sizeof(T),
                 NRAM2NRAM, y_c_offset * x_full.w * sizeof(T),
                 x_c_offset * w_mask * sizeof(T), hend - hstart - 1);
      }
    }
    y_start += y_n_offset;
    x_start += x_n_offset;
  }
  storeDataFromNramToDram(y_dram, y_nram, position, y_full);
}

template <typename T>
__mlu_func__ void psamaskDistributeForward(
    const T *x_dram, T *y_dram, const PositionInCore &position,
    const Shape &x_full, const Shape &y_full, const Shape &shape_seg,
    const int h_mask, const int w_mask, const int half_h_mask,
    const int half_w_mask) {
  T *x_nram = (T *)buf;
  T *y_nram_temp =
      x_nram + CEIL_ALIGN(shape_seg.n * shape_seg.h * shape_seg.w * x_full.c,
                          COMPUTE_COUNT_ALIGN / sizeof(T));
  loadDataFromDramToNram(x_nram, x_dram, position, x_full);

  // fill zeros to output
  int align_c = CEIL_ALIGN(y_full.c, COMPUTE_COUNT_ALIGN / sizeof(T));
  int align_hw =
      CEIL_ALIGN(shape_seg.h * shape_seg.w, COMPUTE_COUNT_ALIGN / sizeof(T));
  int elem_count =
      CEIL_ALIGN(shape_seg.n * align_c * align_hw, NFU_ALIGN_SIZE / sizeof(T));
  __bang_write_value(y_nram_temp, elem_count, (T)0);

  int y_n_offset = align_hw * align_c;
  int y_h_offset = shape_seg.w * align_c;
  int y_w_offset = align_c;
  int y_c_offset = 1;
  int x_n_offset = shape_seg.h * shape_seg.w * x_full.c;
  int x_h_offset = shape_seg.w * x_full.c;
  int x_w_offset = x_full.c;
  int x_c_offset = 1;
  int h_feature = y_full.h;
  int w_feature = y_full.w;

  int y_start = 0;
  int x_start = 0;
  for (int nidx = 0; nidx < shape_seg.n; ++nidx) {
    for (int hidx = 0; hidx < shape_seg.h; ++hidx) {
      for (int widx = 0; widx < shape_seg.w; ++widx) {
        int h_abs = hidx + position.h_start;
        int w_abs = widx + position.w_start;
        int y_offset = y_start;
        int x_offset = x_start;
        y_offset += hidx * y_h_offset + widx * y_w_offset;
        x_offset += hidx * x_h_offset + widx * x_w_offset;
        const int hstart = half_h_mask - h_abs > 0 ? half_h_mask - h_abs : 0;
        const int hend = h_feature + half_h_mask - h_abs < h_mask
                             ? h_feature + half_h_mask - h_abs
                             : h_mask;
        const int wstart = half_w_mask - w_abs > 0 ? half_w_mask - w_abs : 0;
        const int wend = w_feature + half_w_mask - w_abs < w_mask
                             ? w_feature + half_w_mask - w_abs
                             : w_mask;
        // (h, w ) with mask-indexed
        // (h + hidx - half_h_mask, w + widx - half_w_mask) with feature-indexed
        y_offset += ((hstart + h_abs - half_h_mask) * x_full.w + wstart +
                     w_abs - half_w_mask) *
                    y_c_offset;
        x_offset += (hstart * w_mask + wstart) * x_c_offset;
        int count = wend - wstart;
        __memcpy(y_nram_temp + y_offset, x_nram + x_offset, count * sizeof(T),
                 NRAM2NRAM, y_c_offset * w_feature * sizeof(T),
                 x_c_offset * w_mask * sizeof(T), hend - hstart - 1);
      }
    }
    y_start += y_n_offset;
    x_start += x_n_offset;
  }
  // transpose y
  T *y_nram = y_nram_temp + shape_seg.n * align_hw * align_c;
  Shape y_seg{shape_seg.n, shape_seg.h, shape_seg.w, y_full.c};
  transposeData(y_nram, y_nram_temp, y_seg);
  swap(align_c, align_hw);
  // store y from nram to dram
  int y_n_offset_full = y_full.h * y_full.w * y_full.c;
  int y_w_offset_full = y_full.c;
  int y_c_offset_full = 1;

  int y_dram_start =
      position.n_start * y_n_offset_full +
      (position.h_start * y_full.w + position.w_start) * y_c_offset_full;
  int y_nram_start = 0;
  for (int nidx = 0; nidx < shape_seg.n; ++nidx) {
    int y_dram_offset = y_dram_start + nidx * y_n_offset_full;
    int y_nram_offset = y_nram_start + nidx * align_hw * align_c;
    __memcpy(y_dram + y_dram_offset, y_nram + y_nram_offset,
             shape_seg.h * shape_seg.w * sizeof(T), NRAM2GDRAM,
             y_w_offset_full * sizeof(T), align_c * sizeof(T),
             h_feature * w_feature - 1);
  }
}

template <typename T>
__mlu_func__ void psamaskCollectBackward(
    const T *dy_dram, T *dx_dram, const PositionInCore &position,
    const Shape &dy_full, const Shape &dx_full, const Shape &shape_seg,
    const int h_mask, const int w_mask, const int half_h_mask,
    const int half_w_mask) {
  T *dy_nram = (T *)buf;
  T *dx_nram =
      dy_nram + CEIL_ALIGN(shape_seg.n * shape_seg.h * shape_seg.w * dy_full.c,
                           COMPUTE_COUNT_ALIGN / sizeof(T));
  loadDataFromDramToNram(dy_nram, dy_dram, position, dy_full);

  // fill zeros to output
  int elem_count =
      CEIL_ALIGN(shape_seg.n * shape_seg.h * shape_seg.w * shape_seg.c,
                 NFU_ALIGN_SIZE / sizeof(T));
  __bang_write_value(dx_nram, elem_count, (T)0);

  int dy_n_offset = shape_seg.h * shape_seg.w * dy_full.c;
  int dy_h_offset = shape_seg.w * dy_full.c;
  int dy_w_offset = dy_full.c;
  int dy_c_offset = 1;
  int dx_n_offset = shape_seg.h * shape_seg.w * dx_full.c;
  int dx_h_offset = shape_seg.w * dx_full.c;
  int dx_w_offset = dx_full.c;
  int dx_c_offset = 1;
  int h_feature = dy_full.h;
  int w_feature = dy_full.w;

  int dy_start = 0;
  int dx_start = 0;
  for (int nidx = 0; nidx < shape_seg.n; ++nidx) {
    for (int hidx = 0; hidx < shape_seg.h; ++hidx) {
      for (int widx = 0; widx < shape_seg.w; ++widx) {
        int h_abs = hidx + position.h_start;
        int w_abs = widx + position.w_start;
        int dy_offset = dy_start;
        int dx_offset = dx_start;
        dy_offset += hidx * dy_h_offset + widx * dy_w_offset;
        dx_offset += hidx * dx_h_offset + widx * dx_w_offset;

        const int hstart = half_h_mask - h_abs > 0 ? half_h_mask - h_abs : 0;
        const int hend = h_feature + half_h_mask - h_abs < h_mask
                             ? h_feature + half_h_mask - h_abs
                             : h_mask;
        const int wstart = half_w_mask - w_abs > 0 ? half_w_mask - w_abs : 0;
        const int wend = w_feature + half_w_mask - w_abs < w_mask
                             ? w_feature + half_w_mask - w_abs
                             : w_mask;
        // (h, w ) with mask-indexed
        // (h + h_abs - half_h_mask, w + w_abs - half_w_mask) with
        // feature-indexed
        dy_offset += ((hstart + h_abs - half_h_mask) * w_feature + wstart +
                      w_abs - half_w_mask) *
                     dy_c_offset;
        dx_offset += (hstart * w_mask + wstart) * dx_c_offset;
        int count = wend - wstart;
        __memcpy(dx_nram + dx_offset, dy_nram + dy_offset, count * sizeof(T),
                 NRAM2NRAM, dx_c_offset * w_mask * sizeof(T),
                 dy_c_offset * w_feature * sizeof(T), hend - hstart - 1);
      }
    }
    dy_start += dy_n_offset;
    dx_start += dx_n_offset;
  }
  storeDataFromNramToDram(dx_dram, dx_nram, position, dx_full);
}

template <typename T>
__mlu_func__ void psamaskDistributeBackward(
    const T *dy_dram, T *dx_dram, const PositionInCore &position,
    const Shape &dy_full, const Shape &dx_full, const Shape &shape_seg,
    const int h_mask, const int w_mask, const int half_h_mask,
    const int half_w_mask) {
  // load dy from dram to nram
  T *dy_nram_temp = (T *)buf;
  int dy_n_offset_full = dy_full.h * dy_full.w * dy_full.c;
  int dy_c_offset_full = 1;
  int h_feature = dy_full.h;
  int w_feature = dy_full.w;
  int align_c =
      CEIL_ALIGN(shape_seg.h * shape_seg.w, COMPUTE_COUNT_ALIGN / sizeof(T));
  int align_hw =
      CEIL_ALIGN(h_feature * w_feature, COMPUTE_COUNT_ALIGN / sizeof(T));

  int dy_dram_start =
      position.n_start * dy_n_offset_full +
      (position.h_start * w_feature + position.w_start) * dy_c_offset_full;
  int dy_nram_start = 0;
  for (int i = 0; i < shape_seg.n; ++i) {
    int dy_nram_offset = dy_nram_start + i * (align_hw * align_c);
    int dy_dram_offset = dy_dram_start + i * dy_n_offset_full;
    __memcpy(dy_nram_temp + dy_nram_offset, dy_dram + dy_dram_offset,
             shape_seg.h * shape_seg.w * sizeof(T), GDRAM2NRAM,
             align_c * sizeof(T), dy_full.c * sizeof(T),
             h_feature * w_feature - 1);
  }
  T *dy_nram = dy_nram_temp + shape_seg.n * align_hw * align_c;
  Shape dy_seg{shape_seg.n, h_feature, w_feature, shape_seg.h * shape_seg.w};
  transposeData(dy_nram, dy_nram_temp, dy_seg);
  swap(align_c, align_hw);

  // fill zeros to dx
  T *dx_nram = dy_nram + shape_seg.n * align_hw * align_c;
  int dx_size = shape_seg.n * shape_seg.h * shape_seg.w * dx_full.c;
  __bang_write_value(dx_nram, CEIL_ALIGN(dx_size, NFU_ALIGN_SIZE / sizeof(T)),
                     (T)0);

  int dy_n_offset_seg = align_hw * align_c;
  int dy_h_offset_seg = shape_seg.w * align_c;
  int dy_w_offset_seg = align_c;
  int dy_c_offset_seg = 1;
  int dx_n_offset_seg = shape_seg.h * shape_seg.w * shape_seg.c;
  int dx_h_offset_seg = shape_seg.w * shape_seg.c;
  int dx_w_offset_seg = shape_seg.c;
  int dx_c_offset_seg = 1;

  int dy_start = 0;
  int dx_start = 0;
  for (int nidx = 0; nidx < shape_seg.n; ++nidx) {
    for (int hidx = 0; hidx < shape_seg.h; ++hidx) {
      for (int widx = 0; widx < shape_seg.w; ++widx) {
        int h_abs = hidx + position.h_start;
        int w_abs = widx + position.w_start;
        int dy_offset = dy_start;
        int dx_offset = dx_start;
        dy_offset += hidx * dy_h_offset_seg + widx * dy_w_offset_seg;
        dx_offset += hidx * dx_h_offset_seg + widx * dx_w_offset_seg;
        const int hstart = half_h_mask - h_abs > 0 ? half_h_mask - h_abs : 0;
        const int hend = h_feature + half_h_mask - h_abs < h_mask
                             ? h_feature + half_h_mask - h_abs
                             : h_mask;
        const int wstart = half_w_mask - w_abs > 0 ? half_w_mask - w_abs : 0;
        const int wend = w_feature + half_w_mask - w_abs < w_mask
                             ? w_feature + half_w_mask - w_abs
                             : w_mask;
        // (h, w ) with mask-indexed
        // (h + h_abs - half_h_mask, w + w_abs - half_w_mask) with
        // feature-indexed
        dy_offset += ((hstart + h_abs - half_h_mask) * w_feature + wstart +
                      w_abs - half_w_mask) *
                     dy_c_offset_seg;
        dx_offset += (hstart * w_mask + wstart) * dx_c_offset_seg;
        int count = wend - wstart;
        __memcpy(dx_nram + dx_offset, dy_nram + dy_offset, count * sizeof(T),
                 NRAM2NRAM, w_mask * dx_c_offset_seg * sizeof(T),
                 w_feature * dy_c_offset_seg * sizeof(T), hend - hstart - 1);
      }
    }
    dy_start += dy_n_offset_seg;
    dx_start += dx_n_offset_seg;
  }
  storeDataFromNramToDram(dx_dram, dx_nram, position, dx_full);
}

template <typename T>
__mlu_func__ void psamaskBase(const T *input_dram, T *output_dram,
                              const Shape &input_full, const Shape &output_full,
                              LimitParam &limit, const PsamaskType psa_type,
                              const DimPartitionType core_partition,
                              const DimPartitionType cluster_partition,
                              const bool is_forward, const int h_mask,
                              const int w_mask, const int half_h_mask,
                              const int half_w_mask, const int n_per_core,
                              const int h_per_core, const int n_per_cluster,
                              const int h_per_cluster) {
  PositionInCore position_full;
  PositionInCore position_seg;
  position_full.w_start = 0;
  position_full.w_end = output_full.w;
  int n_num_in_cluster = n_per_cluster;
  int h_num_in_cluster = h_per_cluster;

  switch (cluster_partition) {
    case PARTITION_N: {
      position_full.h_start = 0;
      position_full.h_end = input_full.h;
      position_full.n_start = taskIdY * n_per_cluster;
      int cluster_need = (input_full.n + n_per_cluster - 1) / n_per_cluster;
      if (taskIdY >= cluster_need) return;
      int n_remainder = input_full.n - (cluster_need - 1) * n_per_cluster;
      n_num_in_cluster =
          (taskIdY == cluster_need - 1) ? n_remainder : n_per_cluster;
      position_full.n_end = position_full.n_start + n_num_in_cluster;
    }; break;
    case PARTITION_H: {
      position_full.n_start = 0;
      position_full.n_end = input_full.n;
      position_full.h_start = taskIdY * h_per_cluster;
      int cluster_need = (input_full.h + h_per_cluster - 1) / h_per_cluster;
      if (taskIdY >= cluster_need) return;
      int h_remainder = input_full.h - (cluster_need - 1) * h_per_cluster;
      h_num_in_cluster =
          (taskIdY == cluster_need - 1) ? h_remainder : h_per_cluster;
      position_full.h_end = position_full.h_start + h_num_in_cluster;
    }; break;
  }
  switch (core_partition) {
    case PARTITION_N: {
      position_full.n_start += taskIdX * n_per_core;
      int core_need = (n_num_in_cluster + n_per_core - 1) / n_per_core;
      if (taskIdX >= core_need) return;
      int n_remainder = n_num_in_cluster - (core_need - 1) * n_per_core;
      position_full.n_end =
          position_full.n_start +
          ((taskIdX == core_need - 1) ? n_remainder : n_per_core);
    }; break;
    case PARTITION_H: {
      position_full.h_start += taskIdX * h_per_core;
      int core_need = (h_num_in_cluster + h_per_core - 1) / h_per_core;
      if (taskIdX >= core_need) return;
      int h_remainder = h_num_in_cluster - (core_need - 1) * h_per_core;
      position_full.h_end =
          position_full.h_start +
          ((taskIdX == core_need - 1) ? h_remainder : h_per_core);
    }; break;
  }
  // the count of n ,h and w need to be processed in the current core
  int shape_core_n = position_full.n_end - position_full.n_start;
  int shape_core_h = position_full.h_end - position_full.h_start;
  int shape_core_w = input_full.w;

  limit.n = limit.n < shape_core_n ? limit.n : shape_core_n;
  limit.h = limit.h < shape_core_h ? limit.h : shape_core_h;
  limit.w = limit.w < shape_core_w ? limit.w : shape_core_w;

  // load the data to nram according to the limit
  for (int nidx = position_full.n_start; nidx < position_full.n_end;
       nidx += limit.n) {
    position_seg.n_start = nidx;
    position_seg.n_end =
        position_seg.n_start + (position_full.n_end - nidx < limit.n
                                    ? position_full.n_end - nidx
                                    : limit.n);
    for (int hidx = position_full.h_start; hidx < position_full.h_end;
         hidx += limit.h) {
      position_seg.h_start = hidx;
      position_seg.h_end =
          position_seg.h_start + (position_full.h_end - hidx < limit.h
                                      ? position_full.h_end - hidx
                                      : limit.h);
      for (int widx = position_full.w_start; widx < position_full.w_end;
           widx += limit.w) {
        position_seg.w_start = widx;
        position_seg.w_end =
            position_seg.w_start + (position_full.w_end - widx < limit.w
                                        ? position_full.w_end - widx
                                        : limit.w);

        // record the segment of output except the size of channel
        // channel segments of output and input are the same
        Shape shape_seg;
        shape_seg.n = position_seg.n_end - position_seg.n_start;
        shape_seg.h = position_seg.h_end - position_seg.h_start;
        shape_seg.w = position_seg.w_end - position_seg.w_start;
        shape_seg.c = output_full.c;

        switch (psa_type) {
          case COLLECT: {
            if (is_forward) {
              psamaskCollectForward(input_dram, output_dram, position_seg,
                                    input_full, output_full, shape_seg, h_mask,
                                    w_mask, half_h_mask, half_w_mask);
            } else {
              psamaskCollectBackward(input_dram, output_dram, position_seg,
                                     input_full, output_full, shape_seg, h_mask,
                                     w_mask, half_h_mask, half_w_mask);
            }
          } break;
          case DISTRIBUTE: {
            if (is_forward) {
              psamaskDistributeForward(input_dram, output_dram, position_seg,
                                       input_full, output_full, shape_seg,
                                       h_mask, w_mask, half_h_mask,
                                       half_w_mask);
            } else {
              psamaskDistributeBackward(input_dram, output_dram, position_seg,
                                        input_full, output_full, shape_seg,
                                        h_mask, w_mask, half_h_mask,
                                        half_w_mask);
            }
          } break;
        }
      }
    }
  }
}

template <typename T>
__mlu_global__ void MLUUnion1KernelPsamaskForward(
    const T *x, T *y, const PsamaskType psa_type,
    const DimPartitionType core_partition,
    const DimPartitionType cluster_partition, const int batch,
    const int h_feature, const int w_feature, const int h_mask,
    const int w_mask, const int x_c, const int y_c, const int half_h_mask,
    const int half_w_mask, const int n_per_core, const int h_per_core,
    const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
    const int limit_h_seg, const int limit_w_seg) {
  if (coreId == 0x80) {
    return;
  }
  Shape x_full, y_full;
  x_full.n = batch;
  x_full.h = h_feature;
  x_full.w = w_feature;
  x_full.c = x_c;
  y_full.n = batch;
  y_full.h = h_feature;
  y_full.w = w_feature;
  y_full.c = y_c;

  LimitParam limit;
  limit.n = limit_n_seg;
  limit.h = limit_h_seg;
  limit.w = limit_w_seg;

  psamaskBase(x, y, x_full, y_full, limit, psa_type, core_partition,
              cluster_partition, true, h_mask, w_mask, half_h_mask, half_w_mask,
              n_per_core, h_per_core, n_per_cluster, h_per_cluster);
}

template <typename T>
__mlu_global__ void MLUUnion1KernelPsamaskBackward(
    const T *dy, T *dx, const PsamaskType psa_type,
    const DimPartitionType core_partition,
    const DimPartitionType cluster_partition, const int batch,
    const int h_feature, const int w_feature, const int h_mask,
    const int w_mask, const int dx_c, const int dy_c, const int half_h_mask,
    const int half_w_mask, const int n_per_core, const int h_per_core,
    const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
    const int limit_h_seg, const int limit_w_seg) {
  if (coreId == 0x80) {
    return;
  }
  Shape dy_full, dx_full;
  dx_full.n = batch;
  dx_full.h = h_feature;
  dx_full.w = w_feature;
  dx_full.c = dx_c;
  dy_full.n = batch;
  dy_full.h = h_feature;
  dy_full.w = w_feature;
  dy_full.c = dy_c;

  LimitParam limit;
  limit.n = limit_n_seg;
  limit.h = limit_h_seg;
  limit.w = limit_w_seg;

  psamaskBase(dy, dx, dy_full, dx_full, limit, psa_type, core_partition,
              cluster_partition, false, h_mask, w_mask, half_h_mask,
              half_w_mask, n_per_core, h_per_core, n_per_cluster,
              h_per_cluster);
}

void KernelPsamaskForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const void *x, void *y, const PsamaskType psa_type,
    const DimPartitionType core_partition,
    const DimPartitionType cluster_partition, const int batch,
    const int h_feature, const int w_feature, const int h_mask,
    const int w_mask, const int x_c, const int y_c, const int half_h_mask,
    const int half_w_mask, const int n_per_core, const int h_per_core,
    const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
    const int limit_h_seg, const int limit_w_seg) {
  MLUUnion1KernelPsamaskForward<<<k_dim, k_type, queue>>>(
      static_cast<const float *>(x), static_cast<float *>(y), psa_type,
      core_partition, cluster_partition, batch, h_feature, w_feature, h_mask,
      w_mask, x_c, y_c, half_h_mask, half_w_mask, n_per_core, h_per_core,
      n_per_cluster, h_per_cluster, limit_n_seg, limit_h_seg, limit_w_seg);
}

void KernelPsamaskBackward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const void *dy, void *dx, const PsamaskType psa_type,
    const DimPartitionType core_partition,
    const DimPartitionType cluster_partition, const int batch,
    const int h_feature, const int w_feature, const int h_mask,
    const int w_mask, const int dx_c, const int dy_c, const int half_h_mask,
    const int half_w_mask, const int n_per_core, const int h_per_core,
    const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
    const int limit_h_seg, const int limit_w_seg) {
  MLUUnion1KernelPsamaskBackward<<<k_dim, k_type, queue>>>(
      static_cast<const float *>(dy), static_cast<float *>(dx), psa_type,
      core_partition, cluster_partition, batch, h_feature, w_feature, h_mask,
      w_mask, dx_c, dy_c, half_h_mask, half_w_mask, n_per_core, h_per_core,
      n_per_cluster, h_per_cluster, limit_n_seg, limit_h_seg, limit_w_seg);
}
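For readers unfamiliar with PSANet's psamask, the deleted kernel above implements the COLLECT and DISTRIBUTE mappings on NHWC data. As a reading aid only, here is a minimal CPU restatement of the COLLECT forward mapping, mirroring the hstart/hend/wstart/wend clipping and the channel index math in psamaskCollectForward; the function name and the flat std::vector layout are illustrative assumptions, not code from the repository.

#include <algorithm>
#include <cstddef>
#include <vector>

// Plain CPU sketch of the COLLECT forward mapping computed by the deleted
// psamaskCollectForward kernel. Layout is NHWC: x has h_mask * w_mask
// channels, y has h_feature * w_feature channels and is assumed zero-filled.
void psamask_collect_forward_cpu(const std::vector<float> &x,
                                 std::vector<float> &y, int num, int h_feature,
                                 int w_feature, int h_mask, int w_mask,
                                 int half_h_mask, int half_w_mask) {
  for (int n = 0; n < num; ++n) {
    for (int h = 0; h < h_feature; ++h) {
      for (int w = 0; w < w_feature; ++w) {
        // Clip the mask window so it stays inside the feature map, exactly
        // like hstart/hend/wstart/wend in the kernel above.
        const int hstart = std::max(0, half_h_mask - h);
        const int hend = std::min(h_mask, h_feature + half_h_mask - h);
        const int wstart = std::max(0, half_w_mask - w);
        const int wend = std::min(w_mask, w_feature + half_w_mask - w);
        for (int hidx = hstart; hidx < hend; ++hidx) {
          for (int widx = wstart; widx < wend; ++widx) {
            // (hidx, widx) indexes the mask; the same position in feature
            // coordinates is (h + hidx - half_h_mask, w + widx - half_w_mask).
            const int y_c = (hidx + h - half_h_mask) * w_feature +
                            (widx + w - half_w_mask);
            const int x_c = hidx * w_mask + widx;
            const std::size_t spatial =
                (static_cast<std::size_t>(n) * h_feature + h) * w_feature + w;
            y[spatial * (h_feature * w_feature) + y_c] =
                x[spatial * (h_mask * w_mask) + x_c];
          }
        }
      }
    }
  }
}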
@@ -1,55 +0,0 @@
/*************************************************************************
 * Copyright (C) 2022 Cambricon.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef PSAMASK_UTILS_HPP_
#define PSAMASK_UTILS_HPP_

typedef enum {
  COLLECT = 0,
  DISTRIBUTE = 1,
} PsamaskType;

typedef enum {
  PARTITION_N = 0,
  PARTITION_H = 1,
} DimPartitionType;

struct PartitionSeg {
  int h_per_cluster;
  int n_per_cluster;
  int h_per_core;
  int n_per_core;
  DimPartitionType cluster_partition;
  DimPartitionType core_partition;
};

struct Shape {
  int n;
  int h;
  int w;
  int c;
};

struct LimitParam {
  int n;
  int h;
  int w;
};

struct PositionInCore {
  int n_start;
  int n_end;
  int h_start;
  int h_end;
  int w_start;
  int w_end;
};
#endif  // PSAMASK_UTILS_HPP_
@@ -9,136 +9,7 @@
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <algorithm>

#include "psamask_utils.hpp"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

#define COMPUTE_COUNT_ALIGN 64

void KernelPsamaskForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const void *x, void *y, const PsamaskType psa_type,
    const DimPartitionType core_partition,
    const DimPartitionType cluster_partition, const int batch,
    const int h_feature, const int w_feature, const int h_mask,
    const int w_mask, const int x_c, const int y_c, const int half_h_mask,
    const int half_w_mask, const int n_per_core, const int h_per_core,
    const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
    const int limit_h_seg, const int limit_w_seg);

void KernelPsamaskBackward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const void *dy, void *dx, const PsamaskType psa_type,
    const DimPartitionType core_partition,
    const DimPartitionType cluster_partition, const int batch,
    const int h_feature, const int w_feature, const int h_mask,
    const int w_mask, const int dx_c, const int dy_c, const int half_h_mask,
    const int half_w_mask, const int n_per_core, const int h_per_core,
    const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
    const int limit_h_seg, const int limit_w_seg);

namespace {
void policyFunc(cnrtDim3_t *k_dim_ptr, cnrtFunctionType_t *f_type_ptr,
                PartitionSeg *partition_ptr, const int n,
                const int h_feature) {
  unsigned int core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  unsigned int use_cluster_num = cluster_num;
  unsigned int use_core_num = core_dim;

  if (n >= cluster_num || n >= h_feature) {
    partition_ptr->cluster_partition = PARTITION_N;
    partition_ptr->n_per_cluster = (n + cluster_num - 1) / cluster_num;
    partition_ptr->h_per_cluster = h_feature;
    use_cluster_num =
        (n + partition_ptr->n_per_cluster - 1) / partition_ptr->n_per_cluster;
  } else {
    partition_ptr->cluster_partition = PARTITION_H;
    partition_ptr->h_per_cluster = (h_feature + cluster_num - 1) / cluster_num;
    partition_ptr->n_per_cluster = n;
    use_cluster_num = (h_feature + partition_ptr->h_per_cluster - 1) /
                      partition_ptr->h_per_cluster;
  }

  if (partition_ptr->n_per_cluster >= core_dim ||
      partition_ptr->n_per_cluster >= partition_ptr->h_per_cluster) {
    partition_ptr->core_partition = PARTITION_N;
    partition_ptr->n_per_core =
        (partition_ptr->n_per_cluster + core_dim - 1) / core_dim;
    partition_ptr->h_per_core = partition_ptr->h_per_cluster;
    use_core_num =
        (partition_ptr->n_per_cluster + partition_ptr->n_per_core - 1) /
        partition_ptr->n_per_core;
  } else {
    partition_ptr->core_partition = PARTITION_H;
    partition_ptr->h_per_core =
        (partition_ptr->h_per_cluster + core_dim - 1) / core_dim;
    partition_ptr->n_per_core = partition_ptr->n_per_cluster;
    use_core_num =
        (partition_ptr->h_per_cluster + partition_ptr->h_per_core - 1) /
        partition_ptr->h_per_core;
  }
  *k_dim_ptr = {core_dim, use_cluster_num, 1};
}

}  // namespace

bool findLimit(const int shape_core_n, const int shape_core_h,
               const int shape_core_w, const int shape_core_ci,
               const int shape_core_co, int *limit_n_seg_ptr,
               int *limit_h_seg_ptr, int *limit_w_seg_ptr, const int psa_type) {
  const bool need_temp = psa_type == 1;
  const int input_bytes = sizeof(float);
  int limit_n_seg = shape_core_n;
  int limit_h_seg = shape_core_h;
  int limit_w_seg = shape_core_w;

  const int max_nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
  const int align_base_128 = NFU_ALIGN_SIZE / input_bytes;
  const int align_base_64 = COMPUTE_COUNT_ALIGN / input_bytes;
  const int align_co = CEIL_ALIGN(shape_core_co, align_base_64);
  const int align_w = CEIL_ALIGN(shape_core_w, align_base_64);
  const int align_hw = CEIL_ALIGN(shape_core_h * shape_core_w, align_base_64);
  const int max_num = max_nram_size / input_bytes;

  int n_limit =
      max_num /
      (CEIL_ALIGN(shape_core_h * shape_core_w * shape_core_ci, align_base_128) +
       align_hw * align_co * (1 + need_temp));
  if (n_limit > 0) {
    n_limit = std::min(n_limit, shape_core_n);
    limit_n_seg = n_limit;
  } else {
    int h_limit =
        max_num / (CEIL_ALIGN(shape_core_w * shape_core_ci, align_base_128) +
                   align_w * align_co * (1 + need_temp));
    if (h_limit > 0) {
      h_limit = std::min(h_limit, shape_core_h);
      limit_h_seg = h_limit;
      limit_n_seg = 1;
    } else {
      int w_limit =
          max_num / (CEIL_ALIGN(shape_core_ci, align_base_128) +
                     CEIL_ALIGN(align_co, align_base_128) * (1 + need_temp));
      if (w_limit > 0 && w_limit >= (COMPUTE_COUNT_ALIGN / input_bytes)) {
        w_limit = std::min(w_limit, shape_core_w);
        w_limit = w_limit / (COMPUTE_COUNT_ALIGN / input_bytes) *
                  (COMPUTE_COUNT_ALIGN / input_bytes);
        limit_w_seg = w_limit;
        limit_h_seg = 1;
        limit_n_seg = 1;
      } else {
        CNLOG(INFO) << "The size of input channel is too large.";
        return false;
      }
    }
  }
  *limit_n_seg_ptr = limit_n_seg;
  *limit_h_seg_ptr = limit_h_seg;
  *limit_w_seg_ptr = limit_w_seg;
  return true;
}
#include "mlu_common_helper.h"

void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
                                     Tensor y, const int num_,
|
|||
const int h_mask, const int w_mask,
|
||||
const int half_h_mask,
|
||||
const int half_w_mask) {
|
||||
// params check
|
||||
TORCH_CHECK(x.scalar_type() == at::kFloat, "x type should be Float, got ",
|
||||
x.scalar_type());
|
||||
TORCH_CHECK(y.scalar_type() == x.scalar_type(),
|
||||
"y should have the same type as x");
|
||||
TORCH_CHECK(x.dim() == 4, "x should be a 4d tensor, got ", x.dim(), "D");
|
||||
TORCH_CHECK(y.dim() == 4, "y should be a 4d tensor, got ", y.dim(), "D");
|
||||
|
||||
int x_c = x.size(1);
|
||||
int y_c = y.size(1);
|
||||
TORCH_CHECK(h_mask * w_mask == x_c,
|
||||
"channel of x should be the same as h_mask * w_mask");
|
||||
TORCH_CHECK(h_feature * w_feature == y_c,
|
||||
"channel of y should be the same as h_feature * w_feature");
|
||||
TORCH_CHECK(psa_type == 0 || psa_type == 1,
|
||||
"psa_type only supports 'COLLECT' and 'DISTRIBUTE' currently");
|
||||
|
||||
if (x.numel() == 0) {
|
||||
CNLOG(INFO) << "skip zero-element tensor";
|
||||
return;
|
||||
}
|
||||
|
||||
cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
|
||||
cnrtDim3_t k_dim;
|
||||
PartitionSeg partition_info;
|
||||
policyFunc(&k_dim, &k_type, &partition_info, num_, h_feature);
|
||||
int n_limit_seg, h_limit_seg, w_limit_seg;
|
||||
bool ret =
|
||||
findLimit(partition_info.n_per_core, partition_info.h_per_core, w_feature,
|
||||
x_c, y_c, &n_limit_seg, &h_limit_seg, &w_limit_seg, psa_type);
|
||||
if (ret != true) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto memory_format =
|
||||
torch_mlu::cnnl::ops::get_channels_last_memory_format(x.dim());
|
||||
|
@ -186,22 +25,18 @@ void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
|
|||
at::Tensor y_tmp =
|
||||
at::empty({num_, y_c, h_feature, w_feature}, x.options(), memory_format);
|
||||
|
||||
// get compute queue
|
||||
auto queue = torch_mlu::getCurQueue();
|
||||
MluOpTensorDescriptor x_desc, y_desc;
|
||||
x_desc.set_with_layout(x_tensor, MLUOP_LAYOUT_NHWC);
|
||||
y_desc.set_with_layout(y_tmp, MLUOP_LAYOUT_NHWC);
|
||||
|
||||
// get ptr of tensors
|
||||
auto handle = mluOpGetCurrentHandle();
|
||||
auto x_impl = torch_mlu::getMluTensorImpl(x_tensor);
|
||||
auto x_ptr = x_impl->cnnlMalloc();
|
||||
auto y_impl = torch_mlu::getMluTensorImpl(y_tmp);
|
||||
auto y_ptr = y_impl->cnnlMalloc();
|
||||
|
||||
KernelPsamaskForward(
|
||||
k_dim, k_type, queue, x_ptr, y_ptr, (PsamaskType)psa_type,
|
||||
partition_info.core_partition, partition_info.cluster_partition, num_,
|
||||
h_feature, w_feature, h_mask, w_mask, x_c, y_c, half_h_mask, half_w_mask,
|
||||
partition_info.n_per_core, partition_info.h_per_core,
|
||||
partition_info.n_per_cluster, partition_info.h_per_cluster, n_limit_seg,
|
||||
h_limit_seg, w_limit_seg);
|
||||
mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr, h_mask, w_mask,
|
||||
y_desc.desc(), y_ptr);
|
||||
|
||||
y.copy_(y_tmp);
|
||||
}
|
||||
|
@ -212,39 +47,7 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
|
|||
const int h_mask, const int w_mask,
|
||||
const int half_h_mask,
|
||||
const int half_w_mask) {
|
||||
// params check
|
||||
TORCH_CHECK(dy.scalar_type() == at::kFloat, "dy type should be Float, got ",
|
||||
dy.scalar_type());
|
||||
TORCH_CHECK(dx.scalar_type() == dy.scalar_type(),
|
||||
"dx should have the same type as dy");
|
||||
TORCH_CHECK(dy.dim() == 4, "dy should be a 4d tensor, got ", dy.dim(), "D");
|
||||
TORCH_CHECK(dx.dim() == 4, "dx should be a 4d tensor, got ", dx.dim(), "D");
|
||||
|
||||
int dy_c = dy.size(1);
|
||||
int dx_c = dx.size(1);
|
||||
TORCH_CHECK(h_feature * w_feature == dy_c,
|
||||
"channel of dy should be the same as h_feature * w_feature");
|
||||
TORCH_CHECK(h_mask * w_mask == dx_c,
|
||||
"channel of dx should be the same as h_mask * w_mask");
|
||||
TORCH_CHECK(psa_type == 0 || psa_type == 1,
|
||||
"psa_type only supports 'COLLECT' and 'DISTRIBUTE' currently");
|
||||
|
||||
if (dx.numel() == 0) {
|
||||
CNLOG(INFO) << "skip zero-element tensor";
|
||||
return;
|
||||
}
|
||||
|
||||
cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
|
||||
cnrtDim3_t k_dim;
|
||||
PartitionSeg partition_info;
|
||||
policyFunc(&k_dim, &k_type, &partition_info, num_, h_feature);
|
||||
int n_limit_seg, h_limit_seg, w_limit_seg;
|
||||
bool ret =
|
||||
findLimit(partition_info.n_per_core, partition_info.h_per_core, w_feature,
|
||||
dx_c, dy_c, &n_limit_seg, &h_limit_seg, &w_limit_seg, psa_type);
|
||||
if (ret != true) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto memory_format =
|
||||
torch_mlu::cnnl::ops::get_channels_last_memory_format(dy.dim());
|
||||
|
@ -252,8 +55,11 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
|
|||
at::Tensor dx_tmp = at::empty({num_, dx_c, h_feature, w_feature},
|
||||
dy.options(), memory_format);
|
||||
|
||||
// get compute queue
|
||||
auto queue = torch_mlu::getCurQueue();
|
||||
MluOpTensorDescriptor dy_desc, dx_tmp_desc;
|
||||
dy_desc.set_with_layout(dy_tensor, MLUOP_LAYOUT_NHWC);
|
||||
dx_tmp_desc.set_with_layout(dx_tmp, MLUOP_LAYOUT_NHWC);
|
||||
|
||||
auto handle = mluOpGetCurrentHandle();
|
||||
|
||||
// get ptr of tensors
|
||||
auto dx_impl = torch_mlu::getMluTensorImpl(dx_tmp);
|
||||
|
@ -261,13 +67,8 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
|
|||
auto dy_impl = torch_mlu::getMluTensorImpl(dy_tensor);
|
||||
auto dy_ptr = dy_impl->cnnlMalloc();
|
||||
|
||||
KernelPsamaskBackward(
|
||||
k_dim, k_type, queue, dy_ptr, dx_ptr, (PsamaskType)psa_type,
|
||||
partition_info.core_partition, partition_info.cluster_partition, num_,
|
||||
h_feature, w_feature, h_mask, w_mask, dx_c, dy_c, half_h_mask,
|
||||
half_w_mask, partition_info.n_per_core, partition_info.h_per_core,
|
||||
partition_info.n_per_cluster, partition_info.h_per_cluster, n_limit_seg,
|
||||
h_limit_seg, w_limit_seg);
|
||||
mluOpPsamaskBackward(handle, psa_type, dy_desc.desc(), dy_ptr, h_mask, w_mask,
|
||||
dx_tmp_desc.desc(), dx_ptr);
|
||||
|
||||
dx.copy_(dx_tmp);
|
||||
}
|
||||