mirror of https://github.com/open-mmlab/mmcv.git
[MUSA] mmcv support musa, split pr 4 (#3260)
* mmcv support musa, split pr 4
* fix lint
* fix lint
pull/3264/head
parent 24a2bb4f7b
commit 4b38ffcf45
@@ -0,0 +1,91 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINT_IN_BOXES_MUSA_KERNEL_MUH
#define POINT_IN_BOXES_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
                                             T &local_x, T &local_y) {
  T cosa = cos(-rz), sina = sin(-rz);
  local_x = shift_x * cosa + shift_y * (-sina);
  local_y = shift_x * sina + shift_y * cosa;
}

template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
                                        T &local_y) {
  // param pt: (x, y, z)
  // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
  // cz is the bottom center
  T x = pt[0], y = pt[1], z = pt[2];
  T cx = box3d[0], cy = box3d[1], cz = box3d[2];
  T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
  cz += z_size /
        2.0;  // shift to the center since cz in box3d is the bottom center

  if (fabsf(z - cz) > z_size / 2.0) return 0;
  lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
  float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
                  (local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
  return in_flag;
}

template <typename T>
__global__ void points_in_boxes_part_forward_musa_kernel(
    int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
    int *box_idx_of_points) {
  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  //   coordinate, z is the bottom center; boxes must not overlap
  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
  // params box_idx_of_points: (B, npoints), default -1

  int bs_idx = blockIdx.y;
  MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
    if (bs_idx >= batch_size) return;

    boxes += bs_idx * boxes_num * 7;
    pts += bs_idx * pts_num * 3 + pt_idx * 3;
    box_idx_of_points += bs_idx * pts_num + pt_idx;

    T local_x = 0, local_y = 0;
    int cur_in_flag = 0;
    for (int k = 0; k < boxes_num; k++) {
      cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
      if (cur_in_flag) {
        box_idx_of_points[0] = k;
        break;
      }
    }
  }
}

template <typename T>
__global__ void points_in_boxes_all_forward_musa_kernel(
    int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
    int *box_idx_of_points) {
  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  //   coordinate, z is the bottom center; boxes must not overlap
  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
  // params box_idx_of_points: (B, npoints), default -1

  int bs_idx = blockIdx.y;
  MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
    if (bs_idx >= batch_size) return;

    boxes += bs_idx * boxes_num * 7;
    pts += bs_idx * pts_num * 3 + pt_idx * 3;
    box_idx_of_points += bs_idx * pts_num * boxes_num + pt_idx * boxes_num;

    T local_x = 0, local_y = 0;
    for (int k = 0; k < boxes_num; k++) {
      const int cur_in_flag =
          check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
      if (cur_in_flag) {
        box_idx_of_points[k] = 1;
      }
    }
  }
}

#endif  // POINT_IN_BOXES_MUSA_KERNEL_MUH
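The two kernels above take bs_idx from blockIdx.y, so the host side must launch them on a 2-D grid whose y dimension spans the batch. A minimal launcher sketch follows; it assumes GET_BLOCKS and THREADS_PER_BLOCK come from pytorch_musa_helper.hpp and that musaStream_t names the MUSA stream type (the actual launcher for this PR lives in a separate source file):

// Sketch only: shows the launch geometry expected by
// points_in_boxes_part_forward_musa_kernel.
void points_in_boxes_part_launch(int batch_size, int boxes_num, int pts_num,
                                 const float *boxes, const float *pts,
                                 int *box_idx_of_points, musaStream_t stream) {
  // x covers points, y covers the batch (read back as blockIdx.y above).
  dim3 blocks(GET_BLOCKS(pts_num), batch_size);
  dim3 threads(THREADS_PER_BLOCK);
  points_in_boxes_part_forward_musa_kernel<float>
      <<<blocks, threads, 0, stream>>>(batch_size, boxes_num, pts_num, boxes,
                                       pts, box_idx_of_points);
}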
@@ -0,0 +1,75 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
#define POINTS_IN_POLYGONS_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

struct point {
  float x, y;
};

template <typename scalar_t>
__global__ void points_in_polygons_forward_musa_kernel(
    const int nthreads, const scalar_t *vertex1, const scalar_t *vertex2,
    const int rows, const int cols, scalar_t *inside_flag) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    int row = index / cols;
    int col = index % cols;

    const scalar_t *offset_vertex1 = vertex1 + row * 2;
    const scalar_t *offset_vertex2 = vertex2 + col * 8;

    point point_[1];
    point polygon[4];

    point_[0].x = offset_vertex1[0];
    point_[0].y = offset_vertex1[1];

    polygon[0].x = offset_vertex2[0];
    polygon[0].y = offset_vertex2[1];
    polygon[1].x = offset_vertex2[2];
    polygon[1].y = offset_vertex2[3];
    polygon[2].x = offset_vertex2[4];
    polygon[2].y = offset_vertex2[5];
    polygon[3].x = offset_vertex2[6];
    polygon[3].y = offset_vertex2[7];

    int nCross = 0;
    int i, j;
    float sx, sy, tx, ty, px, py, x;
    for (i = 0, j = 3; i < 4; j = i, i++) {
      sx = polygon[i].x;
      sy = polygon[i].y;
      tx = polygon[j].x;
      ty = polygon[j].y;

      px = point_[0].x;
      py = point_[0].y;

      if (py < min(sy, ty)) continue;
      if (py > max(sy, ty)) continue;

      if ((sx == px && sy == py) || (tx == px && ty == py)) {
        break;
      } else {
        if ((sy < py && ty >= py) || (sy >= py && ty < py)) {
          x = sx + (py - sy) * (tx - sx) / (ty - sy);
          if (x == px) {
            break;
          }
          if (x > px) {
            nCross++;
          }
        }
      }
    }
    if (nCross % 2 == 1) {
      inside_flag[index] = 1.0;
    } else {
      inside_flag[index] = 0.0;
    }
    return;
  }
}

#endif  // POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
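The kernel above is a standard crossing-number (ray-casting) test: a horizontal ray is shot from the query point toward +x, and an odd number of edge crossings means the point is inside the quadrilateral; hitting a vertex or an edge stops the scan early. A plain C++ mirror of the same loop, handy for unit-testing the kernel on the CPU (a sketch, not part of this diff):

#include <algorithm>

// CPU reference for the crossing-number test above; poly holds the four
// vertices as (x0, y0, ..., x3, y3), matching the vertex2 layout.
bool point_in_quad(float px, float py, const float poly[8]) {
  int nCross = 0;
  for (int i = 0, j = 3; i < 4; j = i, i++) {
    float sx = poly[2 * i], sy = poly[2 * i + 1];
    float tx = poly[2 * j], ty = poly[2 * j + 1];
    if (py < std::min(sy, ty) || py > std::max(sy, ty)) continue;
    if ((sx == px && sy == py) || (tx == px && ty == py)) break;  // on a vertex
    if ((sy < py && ty >= py) || (sy >= py && ty < py)) {
      float x = sx + (py - sy) * (tx - sx) / (ty - sy);
      if (x == px) break;    // on an edge: stop counting, parity decides
      if (x > px) nCross++;  // ray toward +x crosses this edge
    }
  }
  return nCross % 2 == 1;
}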
@@ -0,0 +1,377 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/vacancy/PreciseRoIPooling/blob/master/src/prroi_pooling_gpu_impl.cu
// Distributed under terms of the MIT license.
#ifndef PRROI_POOL_MUSA_KERNEL_MUH
#define PRROI_POOL_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

template <typename T>
__device__ static __forceinline__ T PrRoIPoolingGetData(const T *data,
                                                        const int h,
                                                        const int w,
                                                        const int height,
                                                        const int width) {
  bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
  T retVal = overflow ? 0.0f : data[h * width + w];
  return retVal;
}

template <typename T>
__device__ static __forceinline__ T PrRoIPoolingGetCoeff(T dh, T dw) {
  return (1.0f - abs(dh)) * (1.0f - abs(dw));
}

template <typename T>
__device__ static __forceinline__ T PrRoIPoolingSingleCoorIntegral(T s, T t,
                                                                   T c1, T c2) {
  return 0.5 * (t * t - s * s) * (c2 - c1) + (t - s) * c1;
}

template <typename T>
__device__ static T PrRoIPoolingInterpolation(const T *data, const T h,
                                              const T w, const int height,
                                              const int width) {
  T retVal = 0.0f;
  int h1 = floorf(h);
  int w1 = floorf(w);
  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
  h1 = floorf(h) + 1;
  w1 = floorf(w);
  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
  h1 = floorf(h);
  w1 = floorf(w) + 1;
  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
  h1 = floorf(h) + 1;
  w1 = floorf(w) + 1;
  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
  return retVal;
}

template <typename T>
__device__ static T PrRoIPoolingMatCalculation(const T *this_data,
                                               const int s_h, const int s_w,
                                               const int e_h, const int e_w,
                                               const T y0, const T x0,
                                               const T y1, const T x1,
                                               const int h0, const int w0) {
  T alpha, beta, lim_alpha, lim_beta, tmp;
  T sum_out = 0;

  alpha = x0 - T(s_w);
  beta = y0 - T(s_h);
  lim_alpha = x1 - T(s_w);
  lim_beta = y1 - T(s_h);
  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
         0.5f * alpha * alpha) *
        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
  sum_out += PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp;

  alpha = T(e_w) - x1;
  lim_alpha = T(e_w) - x0;
  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
         0.5f * alpha * alpha) *
        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
  sum_out += PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp;

  alpha = x0 - T(s_w);
  beta = T(e_h) - y1;
  lim_alpha = x1 - T(s_w);
  lim_beta = T(e_h) - y0;
  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
         0.5f * alpha * alpha) *
        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
  sum_out += PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp;

  alpha = T(e_w) - x1;
  lim_alpha = T(e_w) - x0;
  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
         0.5f * alpha * alpha) *
        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
  sum_out += PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp;

  return sum_out;
}

template <typename T>
__device__ static void PrRoIPoolingDistributeDiff(T *diff, const T top_diff,
                                                  const int h, const int w,
                                                  const int height,
                                                  const int width,
                                                  const T coeff) {
  bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
  if (!overflow) atomicAdd(diff + h * width + w, top_diff * coeff);
}

template <typename T>
__device__ static void PrRoIPoolingMatDistributeDiff(
    T *diff, const T top_diff, const int s_h, const int s_w, const int e_h,
    const int e_w, const T y0, const T x0, const T y1, const T x1, const int h0,
    const int w0) {
  T alpha, beta, lim_alpha, lim_beta, tmp;

  alpha = x0 - T(s_w);
  beta = y0 - T(s_h);
  lim_alpha = x1 - T(s_w);
  lim_beta = y1 - T(s_h);
  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
         0.5f * alpha * alpha) *
        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
  PrRoIPoolingDistributeDiff(diff, top_diff, s_h, s_w, h0, w0, tmp);

  alpha = T(e_w) - x1;
  lim_alpha = T(e_w) - x0;
  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
         0.5f * alpha * alpha) *
        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
  PrRoIPoolingDistributeDiff(diff, top_diff, s_h, e_w, h0, w0, tmp);

  alpha = x0 - T(s_w);
  beta = T(e_h) - y1;
  lim_alpha = x1 - T(s_w);
  lim_beta = T(e_h) - y0;
  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
         0.5f * alpha * alpha) *
        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
  PrRoIPoolingDistributeDiff(diff, top_diff, e_h, s_w, h0, w0, tmp);

  alpha = T(e_w) - x1;
  lim_alpha = T(e_w) - x0;
  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
         0.5f * alpha * alpha) *
        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
  PrRoIPoolingDistributeDiff(diff, top_diff, e_h, e_w, h0, w0, tmp);
}

template <typename T>
__global__ void prroi_pool_forward_musa_kernel(
    const int nthreads, const T *input, const T *rois, T *output,
    const int pooled_height, const int pooled_width, const T spatial_scale,
    const int channels, const int height, const int width) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T *offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    T roi_x1 = offset_rois[1] * spatial_scale;
    T roi_y1 = offset_rois[2] * spatial_scale;
    T roi_x2 = offset_rois[3] * spatial_scale;
    T roi_y2 = offset_rois[4] * spatial_scale;

    T roi_width = max(roi_x2 - roi_x1, ((T)0.0));
    T roi_height = max(roi_y2 - roi_y1, ((T)0.0));
    T bin_size_h = roi_height / static_cast<T>(pooled_height);
    T bin_size_w = roi_width / static_cast<T>(pooled_width);

    const T *this_data =
        input + (roi_batch_ind * channels + c) * height * width;
    T *this_out = output + index;

    T bin_x1 = roi_x1 + bin_size_w * pw;
    T bin_y1 = roi_y1 + bin_size_h * ph;
    T bin_x2 = bin_x1 + bin_size_w;
    T bin_y2 = bin_y1 + bin_size_h;

    T bin_size = max(T(0.0), bin_size_w * bin_size_h);
    if (bin_size == 0) {
      *this_out = 0;
      continue;
    }

    T sum_out = 0;

    int start_x, start_y, end_x, end_y;

    start_x = floorf(bin_x1);
    end_x = ceilf(bin_x2);
    start_y = floorf(bin_y1);
    end_y = ceilf(bin_y2);

    for (int bin_x = start_x; bin_x < end_x; ++bin_x)
      for (int bin_y = start_y; bin_y < end_y; ++bin_y)
        sum_out += PrRoIPoolingMatCalculation(
            this_data, bin_y, bin_x, bin_y + 1, bin_x + 1,
            max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)),
            min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height,
            width);
    *this_out = sum_out / bin_size;
  }
}

template <typename T>
__global__ void prroi_pool_backward_musa_kernel(
    const int nthreads, const T *grad_output, const T *rois, T *grad_input,
    const int pooled_height, const int pooled_width, const T spatial_scale,
    const int channels, const int height, const int width) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;
    auto rois_cur = rois + n * 5;

    int roi_batch_ind = rois_cur[0];
    T roi_x1 = rois_cur[1] * spatial_scale;
    T roi_y1 = rois_cur[2] * spatial_scale;
    T roi_x2 = rois_cur[3] * spatial_scale;
    T roi_y2 = rois_cur[4] * spatial_scale;

    T roi_width = max(roi_x2 - roi_x1, (T)0);
    T roi_height = max(roi_y2 - roi_y1, (T)0);
    T bin_size_h = roi_height / static_cast<T>(pooled_height);
    T bin_size_w = roi_width / static_cast<T>(pooled_width);

    const T *this_out_grad = grad_output + index;
    T *this_data_grad =
        grad_input + (roi_batch_ind * channels + c) * height * width;

    T bin_x1 = roi_x1 + bin_size_w * pw;
    T bin_y1 = roi_y1 + bin_size_h * ph;
    T bin_x2 = bin_x1 + bin_size_w;
    T bin_y2 = bin_y1 + bin_size_h;

    T bin_size = max(T(0.0), bin_size_w * bin_size_h);

    T sum_out = bin_size == T(0) ? T(0) : *this_out_grad / bin_size;

    int start_x, start_y, end_x, end_y;

    start_x = floorf(bin_x1);
    end_x = ceilf(bin_x2);
    start_y = floorf(bin_y1);
    end_y = ceilf(bin_y2);

    for (int bin_x = start_x; bin_x < end_x; ++bin_x)
      for (int bin_y = start_y; bin_y < end_y; ++bin_y)
        PrRoIPoolingMatDistributeDiff(
            this_data_grad, sum_out, bin_y, bin_x, bin_y + 1, bin_x + 1,
            max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)),
            min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height,
            width);
  }
}

template <typename T>
__global__ void prroi_pool_coor_backward_musa_kernel(
    const int nthreads, const T *output, const T *grad_output, const T *input,
    const T *rois, T *grad_rois, const int pooled_height,
    const int pooled_width, const T spatial_scale, const int channels,
    const int height, const int width) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;
    auto rois_cur = rois + n * 5;

    int roi_batch_ind = rois_cur[0];
    T roi_x1 = rois_cur[1] * spatial_scale;
    T roi_y1 = rois_cur[2] * spatial_scale;
    T roi_x2 = rois_cur[3] * spatial_scale;
    T roi_y2 = rois_cur[4] * spatial_scale;

    T roi_width = max(roi_x2 - roi_x1, (T)0);
    T roi_height = max(roi_y2 - roi_y1, (T)0);
    T bin_size_h = roi_height / static_cast<T>(pooled_height);
    T bin_size_w = roi_width / static_cast<T>(pooled_width);

    const T output_grad_val = grad_output[index];
    const T *this_input_data =
        input + (roi_batch_ind * channels + c) * height * width;
    const T output_val = output[index];
    T *this_rois_grad = grad_rois + n * 5;

    T bin_x1 = roi_x1 + bin_size_w * pw;
    T bin_y1 = roi_y1 + bin_size_h * ph;
    T bin_x2 = bin_x1 + bin_size_w;
    T bin_y2 = bin_y1 + bin_size_h;

    T bin_size = max(T(0.0), bin_size_w * bin_size_h);

    T sum_out = bin_size == T(0) ? T(0) : output_grad_val / bin_size;

    // WARNING: to be discussed
    if (sum_out == 0) continue;

    int start_x, start_y, end_x, end_y;

    start_x = floorf(bin_x1);
    end_x = ceilf(bin_x2);
    start_y = floorf(bin_y1);
    end_y = ceilf(bin_y2);

    T grad_x1_y = 0, grad_x2_y = 0, grad_x_y1 = 0, grad_x_y2 = 0;
    for (int bin_y = start_y; bin_y < end_y; ++bin_y) {
      grad_x1_y += PrRoIPoolingSingleCoorIntegral(
          max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y,
          PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x1,
                                    height, width),
          PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x1,
                                    height, width));

      grad_x2_y += PrRoIPoolingSingleCoorIntegral(
          max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y,
          PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x2,
                                    height, width),
          PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x2,
                                    height, width));
    }

    for (int bin_x = start_x; bin_x < end_x; ++bin_x) {
      grad_x_y1 += PrRoIPoolingSingleCoorIntegral(
          max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x,
          PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x),
                                    height, width),
          PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x + 1),
                                    height, width));

      grad_x_y2 += PrRoIPoolingSingleCoorIntegral(
          max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x,
          PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x),
                                    height, width),
          PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x + 1),
                                    height, width));
    }

    T partial_x1 = -grad_x1_y + (bin_y2 - bin_y1) * output_val;
    T partial_y1 = -grad_x_y1 + (bin_x2 - bin_x1) * output_val;
    T partial_x2 = grad_x2_y - (bin_y2 - bin_y1) * output_val;
    T partial_y2 = grad_x_y2 - (bin_x2 - bin_x1) * output_val;

    partial_x1 = partial_x1 / bin_size * spatial_scale;
    partial_x2 = partial_x2 / bin_size * spatial_scale;
    partial_y1 = partial_y1 / bin_size * spatial_scale;
    partial_y2 = partial_y2 / bin_size * spatial_scale;

    // (index, x1, y1, x2, y2)
    this_rois_grad[0] = 0;
    atomicAdd(this_rois_grad + 1,
              (partial_x1 * (1.0f - T(pw) / pooled_width) +
               partial_x2 * (1.0f - T(pw + 1) / pooled_width)) *
                  output_grad_val);
    atomicAdd(this_rois_grad + 2,
              (partial_y1 * (1.0f - T(ph) / pooled_height) +
               partial_y2 * (1.0f - T(ph + 1) / pooled_height)) *
                  output_grad_val);
    atomicAdd(this_rois_grad + 3, (partial_x2 * T(pw + 1) / pooled_width +
                                   partial_x1 * T(pw) / pooled_width) *
                                      output_grad_val);
    atomicAdd(this_rois_grad + 4, (partial_y2 * T(ph + 1) / pooled_height +
                                   partial_y1 * T(ph) / pooled_height) *
                                      output_grad_val);
  }
}

#endif  // PRROI_POOL_MUSA_KERNEL_MUH
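PrRoIPoolingSingleCoorIntegral is the closed-form 1-D piece of precise RoI pooling: it integrates the linear interpolant c1 + (c2 - c1) * u over [s, t], giving (t - s) * c1 + 0.5 * (t*t - s*s) * (c2 - c1), which is exactly what the return statement computes. A quick host-side sanity check of that identity (a sketch, assuming nothing beyond the standard library):

#include <cassert>
#include <cmath>

// Host copy of the device helper above, for testing.
template <typename T>
T single_coor_integral(T s, T t, T c1, T c2) {
  return 0.5 * (t * t - s * s) * (c2 - c1) + (t - s) * c1;
}

int main() {
  // Integrating f(u) = 1 + u over [0, 1] gives 1.5.
  assert(std::fabs(single_coor_integral(0.0, 1.0, 1.0, 2.0) - 1.5) < 1e-12);
  // A zero-length interval contributes nothing.
  assert(single_coor_integral(0.25, 0.25, 3.0, 4.0) == 0.0);
  return 0;
}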
@@ -0,0 +1,137 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef PSAMASK_MUSA_KERNEL_MUH
#define PSAMASK_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

// MUSA: grid stride looping
#ifndef MUSA_KERNEL_LOOP
#define MUSA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)
#endif

template <typename T>
__global__ void psamask_collect_forward_musa(
    const int nthreads, const int h_feature, const int w_feature,
    const int h_mask, const int w_mask, const int half_h_mask,
    const int half_w_mask, const T* mask_data, T* buffer_data) {
  MUSA_KERNEL_LOOP(index, nthreads) {
    const int w = index % w_feature;
    const int h = (index / w_feature) % h_feature;
    const int n = index / w_feature / h_feature;
    // effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed
    const int hstart = max(0, half_h_mask - h);
    const int hend = min(h_mask, h_feature + half_h_mask - h);
    const int wstart = max(0, half_w_mask - w);
    const int wend = min(w_mask, w_feature + half_w_mask - w);
    // (hidx, widx) with mask-indexed
    // (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed
    for (int hidx = hstart; hidx < hend; hidx++) {
      for (int widx = wstart; widx < wend; widx++) {
        buffer_data[(n * h_feature * w_feature +
                     (hidx + h - half_h_mask) * w_feature +
                     (widx + w - half_w_mask)) *
                        h_feature * w_feature +
                    h * w_feature + w] = mask_data
            [((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature + h) *
                 w_feature +
             w];
      }
    }
  }
}

template <typename T>
__global__ void psamask_distribute_forward_musa(
    const int nthreads, const int h_feature, const int w_feature,
    const int h_mask, const int w_mask, const int half_h_mask,
    const int half_w_mask, const T* mask_data, T* buffer_data) {
  MUSA_KERNEL_LOOP(index, nthreads) {
    const int w = index % w_feature;
    const int h = (index / w_feature) % h_feature;
    const int n = index / w_feature / h_feature;
    // effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed
    const int hstart = max(0, half_h_mask - h);
    const int hend = min(h_mask, h_feature + half_h_mask - h);
    const int wstart = max(0, half_w_mask - w);
    const int wend = min(w_mask, w_feature + half_w_mask - w);
    // (hidx, widx) with mask-indexed
    // (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed
    for (int hidx = hstart; hidx < hend; hidx++) {
      for (int widx = wstart; widx < wend; widx++) {
        buffer_data[(n * h_feature * w_feature + h * w_feature + w) *
                        h_feature * w_feature +
                    (hidx + h - half_h_mask) * w_feature +
                    (widx + w - half_w_mask)] = mask_data
            [((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature + h) *
                 w_feature +
             w];
      }
    }
  }
}

template <typename T>
__global__ void psamask_collect_backward_musa(
    const int nthreads, const int h_feature, const int w_feature,
    const int h_mask, const int w_mask, const int half_h_mask,
    const int half_w_mask, const T* buffer_diff, T* mask_diff) {
  MUSA_KERNEL_LOOP(index, nthreads) {
    const int w = index % w_feature;
    const int h = (index / w_feature) % h_feature;
    const int n = index / w_feature / h_feature;
    // effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed
    const int hstart = max(0, half_h_mask - h);
    const int hend = min(h_mask, h_feature + half_h_mask - h);
    const int wstart = max(0, half_w_mask - w);
    const int wend = min(w_mask, w_feature + half_w_mask - w);
    // (hidx, widx) with mask-indexed
    // (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed
    for (int hidx = hstart; hidx < hend; hidx++) {
      for (int widx = wstart; widx < wend; widx++) {
        mask_diff[((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature +
                   h) *
                      w_feature +
                  w] = buffer_diff[(n * h_feature * w_feature +
                                    (hidx + h - half_h_mask) * w_feature +
                                    (widx + w - half_w_mask)) *
                                       h_feature * w_feature +
                                   h * w_feature + w];
      }
    }
  }
}

template <typename T>
__global__ void psamask_distribute_backward_musa(
    const int nthreads, const int h_feature, const int w_feature,
    const int h_mask, const int w_mask, const int half_h_mask,
    const int half_w_mask, const T* buffer_diff, T* mask_diff) {
  MUSA_KERNEL_LOOP(index, nthreads) {
    const int w = index % w_feature;
    const int h = (index / w_feature) % h_feature;
    const int n = index / w_feature / h_feature;
    // effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed
    const int hstart = max(0, half_h_mask - h);
    const int hend = min(h_mask, h_feature + half_h_mask - h);
    const int wstart = max(0, half_w_mask - w);
    const int wend = min(w_mask, w_feature + half_w_mask - w);
    // (hidx, widx) with mask-indexed
    // (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed
    for (int hidx = hstart; hidx < hend; hidx++) {
      for (int widx = wstart; widx < wend; widx++) {
        mask_diff[((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature +
                   h) *
                      w_feature +
                  w] =
            buffer_diff[(n * h_feature * w_feature + h * w_feature + w) *
                            h_feature * w_feature +
                        (hidx + h - half_h_mask) * w_feature +
                        (widx + w - half_w_mask)];
      }
    }
  }
}

#endif  // PSAMASK_MUSA_KERNEL_MUH
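The four kernels above differ only in which side of the flat buffer index carries the feature position and which carries the mask-shifted position; the backward kernels move data through the same addresses in the opposite direction. Hypothetical index helpers that make the collect/distribute symmetry explicit (a sketch; both buffers are logically shaped (n, h_feature * w_feature, h_feature, w_feature)):

// fh = hidx + h - half_h_mask and fw = widx + w - half_w_mask, as in the loops.
inline int collect_buffer_index(int n, int fh, int fw, int h, int w,
                                int h_feature, int w_feature) {
  // collect: the mask-shifted position picks the plane, (h, w) the pixel.
  return (n * h_feature * w_feature + fh * w_feature + fw) * h_feature *
             w_feature +
         h * w_feature + w;
}

inline int distribute_buffer_index(int n, int fh, int fw, int h, int w,
                                   int h_feature, int w_feature) {
  // distribute: the roles are swapped relative to collect.
  return (n * h_feature * w_feature + h * w_feature + w) * h_feature *
             w_feature +
         fh * w_feature + fw;
}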
@@ -0,0 +1,238 @@
// Modified from
// https://github.com/csuhan/ReDet/blob/master/mmdet/ops/riroi_align/src/riroi_align_kernel.cu
#ifndef RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
#define RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH

#include <float.h>
#include "pytorch_musa_helper.hpp"

/*** Forward ***/
template <typename scalar_t>
__global__ void riroi_align_rotated_forward_musa_kernel(
    const int nthreads, const scalar_t *bottom_data,
    const scalar_t *bottom_rois, const scalar_t spatial_scale,
    const int num_samples, const bool clockwise, const int channels,
    const int height, const int width, const int pooled_height,
    const int pooled_width, const int num_orientations, scalar_t *top_data) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int o = (index / pooled_width / pooled_height) % num_orientations;
    int c =
        (index / pooled_width / pooled_height / num_orientations) % channels;
    int n = index / pooled_width / pooled_height / num_orientations / channels;

    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
    int roi_batch_ind = offset_bottom_rois[0];

    // Do not use rounding; this implementation detail is critical
    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale;
    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale;
    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
    scalar_t theta = offset_bottom_rois[5];
    // Force malformed ROIs to be 1x1
    roi_width = max(roi_width, (scalar_t)1.);
    roi_height = max(roi_height, (scalar_t)1.);
    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
                          static_cast<scalar_t>(pooled_height);
    scalar_t bin_size_w =
        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);

    // find aligned index
    scalar_t ind_float = theta * num_orientations / (2 * M_PI);
    int ind = floorf(ind_float);
    scalar_t l_var = ind_float - (scalar_t)ind;
    scalar_t r_var = 1.0 - l_var;
    // correct start channel
    ind = (ind + num_orientations) % num_orientations;
    // rotated channel
    int ind_rot = (o - ind + num_orientations) % num_orientations;
    int ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations;
    const scalar_t *offset_bottom_data =
        bottom_data + (roi_batch_ind * channels * num_orientations +
                       c * num_orientations + ind_rot) *
                          height * width;

    const scalar_t *offset_bottom_data_plus =
        bottom_data + (roi_batch_ind * channels * num_orientations +
                       c * num_orientations + ind_rot_plus) *
                          height * width;
    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (num_samples > 0)
                             ? num_samples
                             : ceilf(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (num_samples > 0) ? num_samples : ceilf(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
    // Appropriate translation needs to be applied after.
    if (clockwise) {
      theta = -theta;  // If clockwise, the angle needs to be reversed.
    }
    scalar_t roi_start_h = -roi_height / 2.0;
    scalar_t roi_start_w = -roi_width / 2.0;
    scalar_t cosscalar_theta = cos(theta);
    scalar_t sinscalar_theta = sin(theta);

    // We do average (integral) pooling inside a bin
    const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1);  // e.g. = 4

    scalar_t output_val = 0.;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
      const scalar_t yy =
          roi_start_h + ph * bin_size_h +
          static_cast<scalar_t>(iy + .5f) * bin_size_h /
              static_cast<scalar_t>(roi_bin_grid_h);  // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const scalar_t xx = roi_start_w + pw * bin_size_w +
                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
                                static_cast<scalar_t>(roi_bin_grid_w);

        // Rotate by theta (counterclockwise) around the center and translate
        scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
        scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;

        scalar_t val = bilinear_interpolate<scalar_t>(
            offset_bottom_data, height, width, y, x, index);
        scalar_t val_plus = bilinear_interpolate<scalar_t>(
            offset_bottom_data_plus, height, width, y, x, index);
        output_val += r_var * val + l_var * val_plus;
      }
    }
    output_val /= count;

    top_data[index] = output_val;
  }
}

/*** Backward ***/
template <typename scalar_t>
__global__ void riroi_align_rotated_backward_musa_kernel(
    const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
    const scalar_t spatial_scale, const int num_samples, const bool clockwise,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, const int num_orientations,
    scalar_t *bottom_diff) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int o = (index / pooled_width / pooled_height) % num_orientations;
    int c =
        (index / pooled_width / pooled_height / num_orientations) % channels;
    int n = index / pooled_width / pooled_height / num_orientations / channels;

    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
    int roi_batch_ind = offset_bottom_rois[0];

    // Do not round
    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale;
    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale;
    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
    scalar_t theta = offset_bottom_rois[5];
    // Force malformed ROIs to be 1x1
    roi_width = max(roi_width, (scalar_t)1.);
    roi_height = max(roi_height, (scalar_t)1.);

    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
                          static_cast<scalar_t>(pooled_height);
    scalar_t bin_size_w =
        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);

    // find aligned index
    scalar_t ind_float = theta * num_orientations / (2 * M_PI);
    int ind = floorf(ind_float);
    scalar_t l_var = ind_float - (scalar_t)ind;
    scalar_t r_var = 1.0 - l_var;
    // correct start channel
    ind = (ind + num_orientations) % num_orientations;
    // rotated channel
    int ind_rot = (o - ind + num_orientations) % num_orientations;
    int ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations;
    scalar_t *offset_bottom_diff =
        bottom_diff + (roi_batch_ind * channels * num_orientations +
                       c * num_orientations + ind_rot) *
                          height * width;
    scalar_t *offset_bottom_diff_plus =
        bottom_diff + (roi_batch_ind * channels * num_orientations +
                       c * num_orientations + ind_rot_plus) *
                          height * width;
    int top_offset =
        (n * channels * num_orientations + c * num_orientations + o) *
        pooled_height * pooled_width;
    const scalar_t *offset_top_diff = top_diff + top_offset;
    const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (num_samples > 0)
                             ? num_samples
                             : ceilf(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (num_samples > 0) ? num_samples : ceilf(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
    // Appropriate translation needs to be applied after.
    if (clockwise) {
      theta = -theta;  // If clockwise, the angle needs to be reversed.
    }
    scalar_t roi_start_h = -roi_height / 2.0;
    scalar_t roi_start_w = -roi_width / 2.0;
    scalar_t cosTheta = cos(theta);
    scalar_t sinTheta = sin(theta);

    // We do average (integral) pooling inside a bin
    const scalar_t count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
      const scalar_t yy =
          roi_start_h + ph * bin_size_h +
          static_cast<scalar_t>(iy + .5f) * bin_size_h /
              static_cast<scalar_t>(roi_bin_grid_h);  // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const scalar_t xx = roi_start_w + pw * bin_size_w +
                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
                                static_cast<scalar_t>(roi_bin_grid_w);

        // Rotate by theta around the center and translate
        scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
        scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;

        scalar_t w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient<scalar_t>(height, width, y, x, w1, w2, w3,
                                                w4, x_low, x_high, y_low,
                                                y_high, index);

        scalar_t g1 = top_diff_this_bin * w1 / count;
        scalar_t g2 = top_diff_this_bin * w2 / count;
        scalar_t g3 = top_diff_this_bin * w3 / count;
        scalar_t g4 = top_diff_this_bin * w4 / count;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          atomicAdd(offset_bottom_diff + y_low * width + x_low, g1 * r_var);
          atomicAdd(offset_bottom_diff + y_low * width + x_high, g2 * r_var);
          atomicAdd(offset_bottom_diff + y_high * width + x_low, g3 * r_var);
          atomicAdd(offset_bottom_diff + y_high * width + x_high, g4 * r_var);

          atomicAdd(offset_bottom_diff_plus + y_low * width + x_low,
                    g1 * l_var);
          atomicAdd(offset_bottom_diff_plus + y_low * width + x_high,
                    g2 * l_var);
          atomicAdd(offset_bottom_diff_plus + y_high * width + x_low,
                    g3 * l_var);
          atomicAdd(offset_bottom_diff_plus + y_high * width + x_high,
                    g4 * l_var);

        }  // if
      }    // ix
    }      // iy
  }        // MUSA_1D_KERNEL_LOOP
}  // RiRoIAlignBackward

#endif  // RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
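The rotation-invariant part of the kernels above blends the two orientation channels that bracket theta: ind_float picks the fractional channel, and l_var/r_var weight the neighboring planes ind_rot and ind_rot_plus. A scalar mirror of that bookkeeping, usable for host-side testing (a sketch with a hypothetical helper name):

#include <cmath>

// Mirrors the orientation-channel selection in the kernels above.
void orientation_blend(float theta, int num_orientations, int o, int &ind_rot,
                       int &ind_rot_plus, float &r_var, float &l_var) {
  float ind_float = theta * num_orientations / (2.0f * M_PI);
  int ind = static_cast<int>(std::floor(ind_float));
  l_var = ind_float - ind;  // weight of the next orientation plane
  r_var = 1.0f - l_var;     // weight of the base orientation plane
  ind = (ind + num_orientations) % num_orientations;
  ind_rot = (o - ind + num_orientations) % num_orientations;
  ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations;
}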
@@ -0,0 +1,205 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROI_ALIGN_MUSA_KERNEL_MUH
#define ROI_ALIGN_MUSA_KERNEL_MUH

#include <float.h>
#include "pytorch_musa_helper.hpp"

/*** Forward ***/
template <typename T>
__global__ void roi_align_forward_musa_kernel(
    const int nthreads, const T* input, const T* rois, T* output, T* argmax_y,
    T* argmax_x, const int pooled_height, const int pooled_width,
    const T spatial_scale, const int sampling_ratio,
    const int pool_mode,  // 0 - max pool, 1 - avg pool
    const bool aligned, const int channels, const int height, const int width) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!aligned) {  // for backward-compatibility only
      roi_width = max(roi_width, (T)1.);
      roi_height = max(roi_height, (T)1.);
    }

    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    const T* offset_input =
        input + (roi_batch_ind * channels + c) * height * width;

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h =
        (sampling_ratio > 0)
            ? sampling_ratio
            : static_cast<int>(ceilf(roi_height / pooled_height));
    int roi_bin_grid_w =
        (sampling_ratio > 0)
            ? sampling_ratio
            : static_cast<int>(ceilf(roi_width / pooled_width));

    if (pool_mode == 0) {
      // We do max pooling inside a bin
      T maxval = -FLT_MAX;
      T maxidx_y = -1.f, maxidx_x = -1.f;
      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        const T y = roi_start_h + ph * bin_size_h +
                    static_cast<T>(iy + .5f) * bin_size_h /
                        static_cast<T>(roi_bin_grid_h);
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
          const T x = roi_start_w + pw * bin_size_w +
                      static_cast<T>(ix + .5f) * bin_size_w /
                          static_cast<T>(roi_bin_grid_w);
          T val =
              bilinear_interpolate(offset_input, height, width, y, x, index);
          if (val > maxval) {
            maxval = val;
            maxidx_y = y;
            maxidx_x = x;
          }
        }
      }
      output[index] = maxval;
      argmax_y[index] = maxidx_y;
      argmax_x[index] = maxidx_x;
    } else if (pool_mode == 1) {
      // We do average pooling inside a bin
      const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
      T output_val = 0.;
      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        const T y = roi_start_h + ph * bin_size_h +
                    static_cast<T>(iy + .5f) * bin_size_h /
                        static_cast<T>(roi_bin_grid_h);
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
          const T x = roi_start_w + pw * bin_size_w +
                      static_cast<T>(ix + .5f) * bin_size_w /
                          static_cast<T>(roi_bin_grid_w);
          T val =
              bilinear_interpolate(offset_input, height, width, y, x, index);
          output_val += val;
        }
      }
      output[index] = output_val / count;
    }
  }
}

/*** Backward ***/
template <typename T>
__global__ void roi_align_backward_musa_kernel(
    const int nthreads, const T* grad_output, const T* rois, const T* argmax_y,
    const T* argmax_x, T* grad_input, const int pooled_height,
    const int pooled_width, const T spatial_scale, const int sampling_ratio,
    const int pool_mode,  // 0 - max pool, 1 - avg pool
    const bool aligned, const int channels, const int height, const int width) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T grad_output_this_bin = grad_output[index];

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];
    T* offset_grad_input =
        grad_input + ((roi_batch_ind * channels + c) * height * width);

    if (pool_mode == 0) {
      T y = argmax_y[index], x = argmax_x[index];
      if (y != -1.f) {
        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;
        bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
                                      x_low, x_high, y_low, y_high, index);

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          atomicAdd(offset_grad_input + y_low * width + x_low,
                    grad_output_this_bin * w1);
          atomicAdd(offset_grad_input + y_low * width + x_high,
                    grad_output_this_bin * w2);
          atomicAdd(offset_grad_input + y_high * width + x_low,
                    grad_output_this_bin * w3);
          atomicAdd(offset_grad_input + y_high * width + x_high,
                    grad_output_this_bin * w4);
        }
      }
    } else if (pool_mode == 1) {
      // Do not use rounding; this implementation detail is critical
      T offset = aligned ? (T)0.5 : (T)0.0;
      T roi_start_w = offset_rois[1] * spatial_scale - offset;
      T roi_start_h = offset_rois[2] * spatial_scale - offset;
      T roi_end_w = offset_rois[3] * spatial_scale - offset;
      T roi_end_h = offset_rois[4] * spatial_scale - offset;

      T roi_width = roi_end_w - roi_start_w;
      T roi_height = roi_end_h - roi_start_h;
      if (!aligned) {  // for backward-compatibility only
        roi_width = max(roi_width, (T)1.);
        roi_height = max(roi_height, (T)1.);
      }

      T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
      T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

      // We use roi_bin_grid to sample the grid and mimic integral
      int roi_bin_grid_h =
          (sampling_ratio > 0)
              ? sampling_ratio
              : static_cast<int>(ceilf(roi_height / pooled_height));
      int roi_bin_grid_w =
          (sampling_ratio > 0)
              ? sampling_ratio
              : static_cast<int>(ceilf(roi_width / pooled_width));

      // We do average (integral) pooling inside a bin
      const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4

      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        const T y = roi_start_h + ph * bin_size_h +
                    static_cast<T>(iy + .5f) * bin_size_h /
                        static_cast<T>(roi_bin_grid_h);
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
          const T x = roi_start_w + pw * bin_size_w +
                      static_cast<T>(ix + .5f) * bin_size_w /
                          static_cast<T>(roi_bin_grid_w);

          T w1, w2, w3, w4;
          int x_low, x_high, y_low, y_high;
          bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
                                        x_low, x_high, y_low, y_high, index);

          if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
            atomicAdd(offset_grad_input + y_low * width + x_low,
                      grad_output_this_bin * w1 / count);
            atomicAdd(offset_grad_input + y_low * width + x_high,
                      grad_output_this_bin * w2 / count);
            atomicAdd(offset_grad_input + y_high * width + x_low,
                      grad_output_this_bin * w3 / count);
            atomicAdd(offset_grad_input + y_high * width + x_high,
                      grad_output_this_bin * w4 / count);
          }
        }
      }
    }
  }
}

#endif  // ROI_ALIGN_MUSA_KERNEL_MUH
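These kernels are usually reached through a type-dispatched host launcher. A sketch of what such a launcher could look like in a PyTorch extension follows; GET_BLOCKS and THREADS_PER_BLOCK are assumed helpers from pytorch_musa_helper.hpp, and the real launcher for this PR lives in a separate file:

#include <ATen/ATen.h>

// Hypothetical launcher sketch for the forward kernel above.
void roi_align_forward_launch(at::Tensor input, at::Tensor rois,
                              at::Tensor output, at::Tensor argmax_y,
                              at::Tensor argmax_x, int pooled_height,
                              int pooled_width, float spatial_scale,
                              int sampling_ratio, int pool_mode, bool aligned) {
  int output_size = output.numel();
  int channels = input.size(1), height = input.size(2), width = input.size(3);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "roi_align_forward_musa_kernel", [&] {
        roi_align_forward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
                output_size, input.data_ptr<scalar_t>(),
                rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
                argmax_y.data_ptr<scalar_t>(), argmax_x.data_ptr<scalar_t>(),
                pooled_height, pooled_width,
                static_cast<scalar_t>(spatial_scale), sampling_ratio,
                pool_mode, aligned, channels, height, width);
      });
}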
@@ -0,0 +1,194 @@
// Modified from
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlignRotated
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#ifndef ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
#define ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH

#include <float.h>
#include "pytorch_musa_helper.hpp"

/*** Forward ***/
template <typename scalar_t>
__global__ void roi_align_rotated_forward_musa_kernel(
    const int nthreads, const scalar_t *bottom_data,
    const scalar_t *bottom_rois, const scalar_t spatial_scale,
    const int sampling_ratio, const bool aligned, const bool clockwise,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, scalar_t *top_data) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
    int roi_batch_ind = offset_bottom_rois[0];

    // Do not use rounding; this implementation detail is critical
    scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
    scalar_t theta = offset_bottom_rois[5];
    if (clockwise) {
      theta = -theta;  // If clockwise, the angle needs to be reversed.
    }
    if (!aligned) {  // for backward-compatibility only
      // Force malformed ROIs to be 1x1
      roi_width = max(roi_width, (scalar_t)1.);
      roi_height = max(roi_height, (scalar_t)1.);
    }
    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
                          static_cast<scalar_t>(pooled_height);
    scalar_t bin_size_w =
        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);

    const scalar_t *offset_bottom_data =
        bottom_data + (roi_batch_ind * channels + c) * height * width;

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
                             ? sampling_ratio
                             : ceilf(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
    // Appropriate translation needs to be applied after.
    scalar_t roi_start_h = -roi_height / 2.0;
    scalar_t roi_start_w = -roi_width / 2.0;
    scalar_t cosscalar_theta = cos(theta);
    scalar_t sinscalar_theta = sin(theta);

    // We do average (integral) pooling inside a bin
    const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1);  // e.g. = 4

    scalar_t output_val = 0.;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
      const scalar_t yy =
          roi_start_h + ph * bin_size_h +
          static_cast<scalar_t>(iy + .5f) * bin_size_h /
              static_cast<scalar_t>(roi_bin_grid_h);  // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const scalar_t xx = roi_start_w + pw * bin_size_w +
                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
                                static_cast<scalar_t>(roi_bin_grid_w);

        // Rotate by theta (counterclockwise) around the center and translate
        scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
        scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;

        scalar_t val = bilinear_interpolate<scalar_t>(
            offset_bottom_data, height, width, y, x, index);
        output_val += val;
      }
    }
    output_val /= count;

    top_data[index] = output_val;
  }
}

/*** Backward ***/
template <typename scalar_t>
__global__ void roi_align_rotated_backward_musa_kernel(
    const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
    const scalar_t spatial_scale, const int sampling_ratio, const bool aligned,
    const bool clockwise, const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, scalar_t *bottom_diff) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
    int roi_batch_ind = offset_bottom_rois[0];

    // Do not round
    scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
    scalar_t theta = offset_bottom_rois[5];
    if (clockwise) {
      theta = -theta;  // If clockwise, the angle needs to be reversed.
    }
    if (!aligned) {  // for backward-compatibility only
      // Force malformed ROIs to be 1x1
      roi_width = max(roi_width, (scalar_t)1.);
      roi_height = max(roi_height, (scalar_t)1.);
    }
    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
                          static_cast<scalar_t>(pooled_height);
    scalar_t bin_size_w =
        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);

    scalar_t *offset_bottom_diff =
        bottom_diff + (roi_batch_ind * channels + c) * height * width;

    int top_offset = (n * channels + c) * pooled_height * pooled_width;
    const scalar_t *offset_top_diff = top_diff + top_offset;
    const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
                             ? sampling_ratio
                             : ceilf(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
    // Appropriate translation needs to be applied after.
    scalar_t roi_start_h = -roi_height / 2.0;
    scalar_t roi_start_w = -roi_width / 2.0;
    scalar_t cosTheta = cos(theta);
    scalar_t sinTheta = sin(theta);

    // We do average (integral) pooling inside a bin
    const scalar_t count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
      const scalar_t yy =
          roi_start_h + ph * bin_size_h +
          static_cast<scalar_t>(iy + .5f) * bin_size_h /
              static_cast<scalar_t>(roi_bin_grid_h);  // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const scalar_t xx = roi_start_w + pw * bin_size_w +
                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
                                static_cast<scalar_t>(roi_bin_grid_w);

        // Rotate by theta around the center and translate
        scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
        scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;

        scalar_t w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient<scalar_t>(height, width, y, x, w1, w2, w3,
                                                w4, x_low, x_high, y_low,
                                                y_high, index);

        scalar_t g1 = top_diff_this_bin * w1 / count;
        scalar_t g2 = top_diff_this_bin * w2 / count;
        scalar_t g3 = top_diff_this_bin * w3 / count;
        scalar_t g4 = top_diff_this_bin * w4 / count;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
          atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
          atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
          atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
        }  // if
      }    // ix
    }      // iy
  }        // MUSA_1D_KERNEL_LOOP
}  // RoIAlignBackward

#endif  // ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
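The sampling-point rotation used in both kernels maps bin-local offsets (xx, yy) into image space via y = yy*cos(theta) - xx*sin(theta) + center_h and x = yy*sin(theta) + xx*cos(theta) + center_w. As a quick check of the convention, for theta = pi/2 the offset (xx, yy) = (1, 0) around center (10, 10) lands at (x, y) = (10, 9); a standalone sketch verifying this:

#include <cassert>
#include <cmath>

int main() {
  // Mirror of the rotation in the kernels above.
  double theta = M_PI / 2, xx = 1.0, yy = 0.0, ch = 10.0, cw = 10.0;
  double y = yy * std::cos(theta) - xx * std::sin(theta) + ch;
  double x = yy * std::sin(theta) + xx * std::cos(theta) + cw;
  assert(std::fabs(x - 10.0) < 1e-9 && std::fabs(y - 9.0) < 1e-9);
  return 0;
}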
@ -0,0 +1,89 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROI_POOL_MUSA_KERNEL_MUH
#define ROI_POOL_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

template <typename T>
__global__ void roi_pool_forward_musa_kernel(
    const int nthreads, const T* input, const T* rois, T* output, int* argmax,
    const int pooled_height, const int pooled_width, const T spatial_scale,
    const int channels, const int height, const int width) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];
    // calculate the roi region on feature maps
    T roi_x1 = offset_rois[1] * spatial_scale;
    T roi_y1 = offset_rois[2] * spatial_scale;
    T roi_x2 = (offset_rois[3] + 1) * spatial_scale;
    T roi_y2 = (offset_rois[4] + 1) * spatial_scale;

    // skip malformed rois (non-positive width or height)
    T roi_w = roi_x2 - roi_x1;
    T roi_h = roi_y2 - roi_y1;
    if (roi_w <= 0 || roi_h <= 0) continue;

    T bin_size_w = roi_w / static_cast<T>(pooled_width);
    T bin_size_h = roi_h / static_cast<T>(pooled_height);

    // the corresponding bin region
    int bin_x1 = floorf(static_cast<T>(pw) * bin_size_w + roi_x1);
    int bin_y1 = floorf(static_cast<T>(ph) * bin_size_h + roi_y1);
    int bin_x2 = ceilf(static_cast<T>(pw + 1) * bin_size_w + roi_x1);
    int bin_y2 = ceilf(static_cast<T>(ph + 1) * bin_size_h + roi_y1);

    // add roi offsets and clip to input boundaries
    bin_x1 = min(max(bin_x1, 0), width);
    bin_y1 = min(max(bin_y1, 0), height);
    bin_x2 = min(max(bin_x2, 0), width);
    bin_y2 = min(max(bin_y2, 0), height);
    bool is_empty = (bin_y2 <= bin_y1) || (bin_x2 <= bin_x1);

    const T* offset_input =
        input + (roi_batch_ind * channels + c) * height * width;
    // Define an empty pooling region to be zero
    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
    T max_val = is_empty ? 0 : -FLT_MAX;
    int max_idx = -1;
    for (int h = bin_y1; h < bin_y2; ++h) {
      for (int w = bin_x1; w < bin_x2; ++w) {
        int offset = h * width + w;
        if (offset_input[offset] > max_val) {
          max_val = offset_input[offset];
          max_idx = offset;
        }
      }
    }
    output[index] = max_val;
    if (argmax != NULL) argmax[index] = max_idx;
  }
}

template <typename T>
__global__ void roi_pool_backward_musa_kernel(
    const int nthreads, const T* grad_output, const T* rois, const int* argmax,
    T* grad_input, const int pooled_height, const int pooled_width,
    const int channels, const int height, const int width) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c) is an element in the pooled output
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    int roi_batch_ind = rois[n * 5];
    T* grad_input_offset =
        grad_input + ((roi_batch_ind * channels + c) * height * width);
    int argmax_index = argmax[index];

    if (argmax_index != -1) {
      atomicAdd(grad_input_offset + argmax_index, grad_output[index]);
    }
  }
}

#endif  // ROI_POOL_MUSA_KERNEL_MUH
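
The kernel is written so that one thread handles one pooled output element. A hypothetical launch sketch (GET_BLOCKS and THREADS_PER_BLOCK are assumed to be the usual helpers from pytorch_musa_helper.hpp; the real launcher in this PR may differ):

// Hypothetical launcher: one thread per (n, c, ph, pw) output element.
void launch_roi_pool_forward(const float* input, const float* rois,
                             float* output, int* argmax, int num_rois,
                             int channels, int height, int width,
                             int pooled_height, int pooled_width,
                             float spatial_scale) {
  int nthreads = num_rois * channels * pooled_height * pooled_width;
  roi_pool_forward_musa_kernel<float>
      <<<GET_BLOCKS(nthreads), THREADS_PER_BLOCK>>>(
          nthreads, input, rois, output, argmax, pooled_height, pooled_width,
          spatial_scale, channels, height, width);
}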
@ -0,0 +1,256 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROIAWARE_POOL3D_MUSA_KERNEL_MUH
#define ROIAWARE_POOL3D_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
                                             T &local_x, T &local_y) {
  T cosa = cos(-rz), sina = sin(-rz);
  local_x = shift_x * cosa + shift_y * (-sina);
  local_y = shift_x * sina + shift_y * cosa;
}

template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
                                        T &local_y) {
  // param pt: (x, y, z)
  // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
  // cz is the bottom center
  T x = pt[0], y = pt[1], z = pt[2];
  T cx = box3d[0], cy = box3d[1], cz = box3d[2];
  T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
  cz += z_size /
        2.0;  // shift to the center since cz in box3d is the bottom center

  if (fabsf(z - cz) > z_size / 2.0) return 0;
  lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
  float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
                  (local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
  return in_flag;
}

template <typename T>
__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num,
                                            int out_x, int out_y, int out_z,
                                            const T *rois, const T *pts,
                                            int *pts_mask) {
  // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  // coordinate
  // params pts: (npoints, 3) [x, y, z]
  // params pts_mask: (N, npoints): -1 means the point is not in this box;
  // otherwise (x_idx, y_idx, z_idx) are packed into the bits of the value
  int box_idx = blockIdx.y;
  MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
    if (box_idx >= boxes_num) return;

    pts += pt_idx * 3;
    rois += box_idx * 7;
    pts_mask += box_idx * pts_num + pt_idx;

    T local_x = 0, local_y = 0;
    int cur_in_flag = check_pt_in_box3d(pts, rois, local_x, local_y);

    pts_mask[0] = -1;
    if (cur_in_flag > 0) {
      T local_z = pts[2] - rois[2];
      T x_size = rois[3], y_size = rois[4], z_size = rois[5];

      T x_res = x_size / out_x;
      T y_res = y_size / out_y;
      T z_res = z_size / out_z;

      unsigned int x_idx = int((local_x + x_size / 2) / x_res);
      unsigned int y_idx = int((local_y + y_size / 2) / y_res);
      unsigned int z_idx = int(local_z / z_res);

      x_idx = min(max(x_idx, 0), out_x - 1);
      y_idx = min(max(y_idx, 0), out_y - 1);
      z_idx = min(max(z_idx, 0), out_z - 1);

      unsigned int idx_encoding = (x_idx << 16) + (y_idx << 8) + z_idx;

      pts_mask[0] = idx_encoding;
    }
  }
}

template <typename T>
__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
                                             int max_pts_each_voxel, int out_x,
                                             int out_y, int out_z,
                                             const int *pts_mask,
                                             T *pts_idx_of_voxels) {
  // params pts_mask: (N, npoints), -1 or an encoded voxel index
  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
  MUSA_1D_KERNEL_LOOP(box_idx, boxes_num) {
    int max_num_pts = max_pts_each_voxel - 1;  // index 0 is the counter
    pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel;

    for (int k = 0; k < pts_num; k++) {
      if (pts_mask[box_idx * pts_num + k] != -1) {
        unsigned int idx_encoding = pts_mask[box_idx * pts_num + k];
        unsigned int x_idx = (idx_encoding >> 16) & 0xFF;
        unsigned int y_idx = (idx_encoding >> 8) & 0xFF;
        unsigned int z_idx = idx_encoding & 0xFF;
        unsigned int base_offset = x_idx * out_y * out_z * max_pts_each_voxel +
                                   y_idx * out_z * max_pts_each_voxel +
                                   z_idx * max_pts_each_voxel;
        unsigned int cnt = pts_idx_of_voxels[base_offset];
        if (cnt < max_num_pts) {
          pts_idx_of_voxels[base_offset + cnt + 1] = k;
          pts_idx_of_voxels[base_offset]++;
        }
      }
    }
  }
}

template <typename T>
__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
                                   int max_pts_each_voxel, int out_x, int out_y,
                                   int out_z, const T *pts_feature,
                                   const int *pts_idx_of_voxels,
                                   T *pooled_features, int *argmax) {
  // params pts_feature: (npoints, C)
  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
  // index 0 is the counter
  // params pooled_features: (N, out_x, out_y, out_z, C)
  // params argmax: (N, out_x, out_y, out_z, C)

  int box_idx = blockIdx.z;
  int channel_idx = blockIdx.y;
  MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
    int x_idx = voxel_idx_flat / (out_y * out_z);
    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
    int z_idx = voxel_idx_flat % out_z;
    if (box_idx >= boxes_num || channel_idx >= channels) return;

    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
    pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
                         offset_base * max_pts_each_voxel;
    pooled_features += box_idx * out_x * out_y * out_z * channels +
                       offset_base * channels + channel_idx;
    argmax += box_idx * out_x * out_y * out_z * channels +
              offset_base * channels + channel_idx;

    int argmax_idx = -1;
    float max_val = -1e50;

    int total_pts = pts_idx_of_voxels[0];

    for (int k = 1; k <= total_pts; k++) {
      if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] >
          max_val) {
        max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
        argmax_idx = pts_idx_of_voxels[k];
      }
    }

    if (argmax_idx != -1) {
      pooled_features[0] = max_val;
    }
    argmax[0] = argmax_idx;
  }
}

template <typename T>
__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
                                   int max_pts_each_voxel, int out_x, int out_y,
                                   int out_z, const T *pts_feature,
                                   const int *pts_idx_of_voxels,
                                   T *pooled_features) {
  // params pts_feature: (npoints, C)
  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
  // index 0 is the counter
  // params pooled_features: (N, out_x, out_y, out_z, C)

  int box_idx = blockIdx.z;
  int channel_idx = blockIdx.y;
  MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
    int x_idx = voxel_idx_flat / (out_y * out_z);
    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
    int z_idx = voxel_idx_flat % out_z;
    if (box_idx >= boxes_num || channel_idx >= channels) return;

    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
    pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
                         offset_base * max_pts_each_voxel;
    pooled_features += box_idx * out_x * out_y * out_z * channels +
                       offset_base * channels + channel_idx;

    float sum_val = 0;
    int total_pts = pts_idx_of_voxels[0];

    for (int k = 1; k <= total_pts; k++) {
      sum_val += pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
    }

    if (total_pts > 0) {
      pooled_features[0] = sum_val / total_pts;
    }
  }
}

template <typename T>
__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
                                            int out_x, int out_y, int out_z,
                                            const int *argmax,
                                            const T *grad_out, T *grad_in) {
  // params argmax: (N, out_x, out_y, out_z, C)
  // params grad_out: (N, out_x, out_y, out_z, C)
  // params grad_in: (npoints, C), return value

  int box_idx = blockIdx.z;
  int channel_idx = blockIdx.y;
  MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
    int x_idx = voxel_idx_flat / (out_y * out_z);
    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
    int z_idx = voxel_idx_flat % out_z;
    if (box_idx >= boxes_num || channel_idx >= channels) return;

    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
    argmax += box_idx * out_x * out_y * out_z * channels +
              offset_base * channels + channel_idx;
    grad_out += box_idx * out_x * out_y * out_z * channels +
                offset_base * channels + channel_idx;

    if (argmax[0] == -1) return;

    atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1);
  }
}

template <typename T>
__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
                                            int out_x, int out_y, int out_z,
                                            int max_pts_each_voxel,
                                            const int *pts_idx_of_voxels,
                                            const T *grad_out, T *grad_in) {
  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
  // params grad_out: (N, out_x, out_y, out_z, C)
  // params grad_in: (npoints, C), return value

  int box_idx = blockIdx.z;
  int channel_idx = blockIdx.y;
  MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
    int x_idx = voxel_idx_flat / (out_y * out_z);
    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
    int z_idx = voxel_idx_flat % out_z;
    if (box_idx >= boxes_num || channel_idx >= channels) return;

    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
    pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
                         offset_base * max_pts_each_voxel;
    grad_out += box_idx * out_x * out_y * out_z * channels +
                offset_base * channels + channel_idx;

    int total_pts = pts_idx_of_voxels[0];
    float cur_grad = 1 / fmaxf(float(total_pts), 1.0);
    for (int k = 1; k <= total_pts; k++) {
      atomicAdd(grad_in + pts_idx_of_voxels[k] * channels + channel_idx,
                grad_out[0] * cur_grad);
    }
  }
}

#endif  // ROIAWARE_POOL3D_MUSA_KERNEL_MUH
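
generate_pts_mask_for_box3d packs a point's voxel coordinates into one int, and collect_inside_pts_for_box3d unpacks them. Because the decoder masks every field with 0xFF, each of x_idx, y_idx and z_idx must stay below 256 for the round trip to be lossless. A small host sketch of the same packing:

#include <stdio.h>

unsigned int encode(unsigned int x, unsigned int y, unsigned int z) {
  return (x << 16) + (y << 8) + z;  // same packing as the kernel
}

int main(void) {
  unsigned int e = encode(12, 3, 7);
  // decode exactly as collect_inside_pts_for_box3d does
  printf("x=%u y=%u z=%u\n", (e >> 16) & 0xFF, (e >> 8) & 0xFF, e & 0xFF);
  return 0;  // prints x=12 y=3 z=7
}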
@ -0,0 +1,130 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROIPOINT_POOL3D_MUSA_KERNEL_MUH
#define ROIPOINT_POOL3D_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
                                             T &local_x, T &local_y) {
  T cosa = cos(-rz), sina = sin(-rz);
  local_x = shift_x * cosa + shift_y * (-sina);
  local_y = shift_x * sina + shift_y * cosa;
}

template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
                                        T &local_y) {
  // param pt: (x, y, z)
  // param box3d: (cx, cy, cz, dx, dy, dz, rz) in LiDAR coordinate, cz is the
  // bottom center
  T x = pt[0], y = pt[1], z = pt[2];
  T cx = box3d[0], cy = box3d[1], cz = box3d[2];
  T dx = box3d[3], dy = box3d[4], dz = box3d[5], rz = box3d[6];
  cz += dz / 2.0;  // shift to the center since cz in box3d is the bottom center

  if (fabsf(z - cz) > dz / 2.0) return 0;
  lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
  T in_flag = (local_x > -dx / 2.0) & (local_x < dx / 2.0) &
              (local_y > -dy / 2.0) & (local_y < dy / 2.0);
  return in_flag;
}

template <typename T>
__global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num,
                                    const T *xyz, const T *boxes3d,
                                    int *pts_assign) {
  // params xyz: (B, N, 3)
  // params boxes3d: (B, M, 7)
  // params pts_assign: (B, N, M): idx of the corresponding box3d, -1 means
  // background points
  int box_idx = blockIdx.y;
  int bs_idx = blockIdx.z;
  MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
    if (box_idx >= boxes_num || bs_idx >= batch_size) return;

    int assign_idx =
        bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx;
    pts_assign[assign_idx] = 0;

    int box_offset = bs_idx * boxes_num * 7 + box_idx * 7;
    int pt_offset = bs_idx * pts_num * 3 + pt_idx * 3;

    T local_x = 0, local_y = 0;
    int cur_in_flag = check_pt_in_box3d(xyz + pt_offset, boxes3d + box_offset,
                                        local_x, local_y);
    pts_assign[assign_idx] = cur_in_flag;
  }
}

__global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num,
                               int sampled_pts_num, const int *pts_assign,
                               int *pts_idx, int *pooled_empty_flag) {
  // params xyz: (B, N, 3)
  // params pts_feature: (B, N, C)
  // params pts_assign: (B, N)
  // params pts_idx: (B, M, 512)
  // params pooled_empty_flag: (B, M)
  MUSA_1D_KERNEL_LOOP(boxes_idx, boxes_num) {
    int bs_idx = blockIdx.y;

    int cnt = 0;
    for (int k = 0; k < pts_num; k++) {
      if (pts_assign[bs_idx * pts_num * boxes_num + k * boxes_num +
                     boxes_idx]) {
        if (cnt < sampled_pts_num) {
          pts_idx[bs_idx * boxes_num * sampled_pts_num +
                  boxes_idx * sampled_pts_num + cnt] = k;
          cnt++;
        } else
          break;
      }
    }

    if (cnt == 0) {
      pooled_empty_flag[bs_idx * boxes_num + boxes_idx] = 1;
    } else if (cnt < sampled_pts_num) {
      // duplicate the same points for sampling
      for (int k = cnt; k < sampled_pts_num; k++) {
        int duplicate_idx = k % cnt;
        int base_offset =
            bs_idx * boxes_num * sampled_pts_num + boxes_idx * sampled_pts_num;
        pts_idx[base_offset + k] = pts_idx[base_offset + duplicate_idx];
      }
    }
  }
}

template <typename T>
__global__ void roipoint_pool3d_forward(
    int batch_size, int pts_num, int boxes_num, int feature_in_len,
    int sampled_pts_num, const T *xyz, const int *pts_idx, const T *pts_feature,
    T *pooled_features, int *pooled_empty_flag) {
  // params xyz: (B, N, 3)
  // params pts_idx: (B, M, 512)
  // params pts_feature: (B, N, C)
  // params pooled_features: (B, M, 512, 3+C)
  // params pooled_empty_flag: (B, M)
  int box_idx = blockIdx.y;
  int bs_idx = blockIdx.z;
  MUSA_1D_KERNEL_LOOP(sample_pt_idx, sampled_pts_num) {
    if (box_idx >= boxes_num || bs_idx >= batch_size) return;
    if (pooled_empty_flag[bs_idx * boxes_num + box_idx]) return;

    int temp_idx = bs_idx * boxes_num * sampled_pts_num +
                   box_idx * sampled_pts_num + sample_pt_idx;
    int src_pt_idx = pts_idx[temp_idx];
    int dst_feature_offset = temp_idx * (3 + feature_in_len);

    for (int j = 0; j < 3; j++)
      pooled_features[dst_feature_offset + j] =
          xyz[bs_idx * pts_num * 3 + src_pt_idx * 3 + j];

    int src_feature_offset =
        bs_idx * pts_num * feature_in_len + src_pt_idx * feature_in_len;
    memcpy(pooled_features + dst_feature_offset + 3,
           pts_feature + src_feature_offset, feature_in_len * sizeof(T));
  }
}

#endif  // ROIPOINT_POOL3D_MUSA_KERNEL_MUH
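
When a box holds fewer than sampled_pts_num points, get_pooled_idx fills the remaining slots by cycling over the indices already collected (k % cnt). A tiny host illustration of the fill pattern:

#include <stdio.h>

int main(void) {
  int cnt = 3, sampled_pts_num = 8;
  for (int k = cnt; k < sampled_pts_num; k++)
    printf("slot %d copies slot %d\n", k, k % cnt);
  return 0;  // slots 3..7 copy slots 0, 1, 2, 0, 1
}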
@ -0,0 +1,125 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu
#ifndef ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH
#define ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

template <typename scalar_t>
__global__ void rotated_feature_align_forward_kernel(
    const int nthreads, const int points, const scalar_t* bottom_data,
    const scalar_t* best_bboxes, const scalar_t spatial_scale,
    const int channels, const int height, const int width, scalar_t* top_data) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    int w = index % width;
    int h = (index / width) % height;
    int c = (index / width / height) % channels;
    int n = index / width / height / channels;

    const scalar_t* bbox_offset =
        best_bboxes + ((n * height + h) * width + w) * 5;
    scalar_t roi_y = bbox_offset[0] * spatial_scale;
    scalar_t roi_x = bbox_offset[1] * spatial_scale;

    scalar_t px[5] = {roi_x, 0, 0, 0, 0};
    scalar_t py[5] = {roi_y, 0, 0, 0, 0};

    if (points > 1) {
      scalar_t roi_w = bbox_offset[2] * spatial_scale;
      scalar_t roi_h = bbox_offset[3] * spatial_scale;
      scalar_t roi_a = bbox_offset[4];

      scalar_t w_2 = roi_w / 2, h_2 = roi_h / 2;
      scalar_t cosa = cosf(roi_a), sina = sinf(roi_a);
      scalar_t wx = cosa * w_2, wy = sina * w_2;
      scalar_t hx = -sina * h_2, hy = cosa * h_2;

      px[1] = roi_x + wx + hx;
      py[1] = roi_y + wy + hy;
      px[2] = roi_x - wx + hx;
      py[2] = roi_y - wy + hy;
      px[3] = roi_x - wx - hx;
      py[3] = roi_y - wy - hy;
      px[4] = roi_x + wx - hx;
      py[4] = roi_y + wy - hy;
    }

    const scalar_t* offset_bottom_data =
        bottom_data + (n * channels + c) * height * width;

    scalar_t output_val = bottom_data[index];
    for (int i = 0; i < points; i++) {
      output_val += bilinear_interpolate<scalar_t>(offset_bottom_data, height,
                                                   width, py[i], px[i], i);
    }
    top_data[index] = output_val;
  }
}

template <typename scalar_t>
__global__ void rotated_feature_align_backward_kernel(
    const int nthreads, const int points, const scalar_t* top_diff,
    const scalar_t* best_bboxes, const scalar_t spatial_scale,
    const int channels, const int height, const int width,
    scalar_t* bottom_diff) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    int w = index % width;
    int h = (index / width) % height;
    int c = (index / width / height) % channels;
    int n = index / width / height / channels;

    const scalar_t* bbox_offset =
        best_bboxes + ((n * height + h) * width + w) * 5;
    scalar_t roi_y = bbox_offset[0] * spatial_scale;
    scalar_t roi_x = bbox_offset[1] * spatial_scale;

    scalar_t px[5] = {roi_x, 0, 0, 0, 0};
    scalar_t py[5] = {roi_y, 0, 0, 0, 0};

    if (points > 1) {
      scalar_t roi_w = bbox_offset[2] * spatial_scale;
      scalar_t roi_h = bbox_offset[3] * spatial_scale;
      scalar_t roi_a = bbox_offset[4];

      scalar_t w_2 = roi_w / 2, h_2 = roi_h / 2;
      scalar_t cosa = cosf(roi_a), sina = sinf(roi_a);
      scalar_t wx = cosa * w_2, wy = sina * w_2;
      scalar_t hx = -sina * h_2, hy = cosa * h_2;

      px[1] = roi_x + wx + hx;
      py[1] = roi_y + wy + hy;
      px[2] = roi_x - wx + hx;
      py[2] = roi_y - wy + hy;
      px[3] = roi_x - wx - hx;
      py[3] = roi_y - wy - hy;
      px[4] = roi_x + wx - hx;
      py[4] = roi_y + wy - hy;
    }

    scalar_t* offset_bottom_diff =
        bottom_diff + (n * channels + c) * height * width;
    scalar_t value_top_diff = top_diff[index];

    atomicAdd(bottom_diff + index, value_top_diff);
    for (int i = 0; i < points; i++) {
      scalar_t w1, w2, w3, w4;
      int x_low, x_high, y_low, y_high;

      bilinear_interpolate_gradient<scalar_t>(height, width, py[i], px[i], w1,
                                              w2, w3, w4, x_low, x_high, y_low,
                                              y_high, i);
      scalar_t g1 = value_top_diff * w1;
      scalar_t g2 = value_top_diff * w2;
      scalar_t g3 = value_top_diff * w3;
      scalar_t g4 = value_top_diff * w4;
      if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
        atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
        atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
        atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
        atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
      }
    }
  }
}

#endif  // ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH
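
The px/py arrays hold the anchor center plus its four rotated corners; (wx, wy) and (hx, hy) are the half-width and half-height vectors after rotation by roi_a. A quick host-side sanity check of the corner construction for the axis-aligned case roi_a = 0, where the corners must reduce to (roi_x ± w/2, roi_y ± h/2):

#include <cmath>
#include <cstdio>

int main() {
  float roi_x = 4.f, roi_y = 2.f, roi_w = 6.f, roi_h = 2.f, roi_a = 0.f;
  float w_2 = roi_w / 2, h_2 = roi_h / 2;
  float cosa = cosf(roi_a), sina = sinf(roi_a);
  float wx = cosa * w_2, wy = sina * w_2;   // (3, 0) when roi_a = 0
  float hx = -sina * h_2, hy = cosa * h_2;  // (0, 1) when roi_a = 0
  printf("corner1 = (%.1f, %.1f)\n", roi_x + wx + hx, roi_y + wy + hy);  // (7, 3)
  printf("corner3 = (%.1f, %.1f)\n", roi_x - wx - hx, roi_y - wy - hy);  // (1, 1)
  return 0;
}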
@ -0,0 +1,137 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef SCATTER_POINTS_MUSA_KERNEL_MUH
#define SCATTER_POINTS_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
int const maxGridDim = 50000;

__device__ __forceinline__ static void reduceMax(float *address, float val) {
  int *address_as_i = reinterpret_cast<int *>(address);
  int old = *address_as_i, assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_i, assumed,
                    __float_as_int(fmaxf(val, __int_as_float(assumed))));
  } while (assumed != old || __int_as_float(old) < val);
}

__device__ __forceinline__ static void reduceMax(double *address, double val) {
  unsigned long long *address_as_ull =
      reinterpret_cast<unsigned long long *>(address);
  unsigned long long old = *address_as_ull, assumed;
  do {
    assumed = old;
    old = atomicCAS(
        address_as_ull, assumed,
        __double_as_longlong(fmax(val, __longlong_as_double(assumed))));
  } while (assumed != old || __longlong_as_double(old) < val);
}

__device__ __forceinline__ static void reduceAdd(float *address, float val) {
  atomicAdd(address, val);
}

__device__ __forceinline__ static void reduceAdd(double *address, double val) {
  atomicAdd(address, val);
}

template <typename T>
__global__ void feats_reduce_kernel(
    const T *feats, const int32_t *coors_map,
    T *reduced_feats,  // shall be 0 at initialization
    const int num_input, const int num_feats, const reduce_t reduce_type) {
  MUSA_1D_KERNEL_LOOP(x, num_input) {
    int32_t reduce_to = coors_map[x];
    if (reduce_to == -1) continue;

    const T *feats_offset = feats + x * num_feats;
    T *reduced_feats_offset = reduced_feats + reduce_to * num_feats;
    if (reduce_type == reduce_t::MAX) {
      for (int i = 0; i < num_feats; i++) {
        reduceMax(&reduced_feats_offset[i], feats_offset[i]);
      }
    } else {
      for (int i = 0; i < num_feats; i++) {
        reduceAdd(&reduced_feats_offset[i], feats_offset[i]);
      }
    }
  }
}

template <typename T>
__global__ void add_reduce_traceback_grad_kernel(
    T *grad_feats, const T *grad_reduced_feats, const int32_t *coors_map,
    const int32_t *reduce_count, const int num_input, const int num_feats,
    const reduce_t reduce_type) {
  MUSA_1D_KERNEL_LOOP(x, num_input) {
    int32_t reduce_to = coors_map[x];
    if (reduce_to == -1) {
      continue;
    }

    const int input_offset = x * num_feats;
    T *grad_feats_offset = grad_feats + input_offset;
    const int reduced_offset = reduce_to * num_feats;
    const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset;

    if (reduce_type == reduce_t::SUM) {
      for (int i = 0; i < num_feats; i++) {
        grad_feats_offset[i] = grad_reduced_feats_offset[i];
      }
    } else if (reduce_type == reduce_t::MEAN) {
      for (int i = 0; i < num_feats; i++) {
        grad_feats_offset[i] = grad_reduced_feats_offset[i] /
                               static_cast<T>(reduce_count[reduce_to]);
      }
    }
  }
}

template <typename T>
__global__ void max_reduce_traceback_scatter_idx_kernel(
    const T *feats, const T *reduced_feats, int32_t *reduce_from,
    const int32_t *coors_map, const int num_input, const int num_feats) {
  MUSA_1D_KERNEL_LOOP(x, num_input) {
    int32_t reduce_to = coors_map[x];

    const int input_offset = x * num_feats;
    const T *feats_offset = feats + input_offset;

    if (reduce_to == -1) {
      continue;
    }

    const int reduced_offset = reduce_to * num_feats;
    const T *reduced_feats_offset = reduced_feats + reduced_offset;
    int32_t *reduce_from_offset = reduce_from + reduced_offset;

    for (int i = 0; i < num_feats; i++) {
      if (feats_offset[i] == reduced_feats_offset[i]) {
        atomicMin(&reduce_from_offset[i], static_cast<int32_t>(x));
      }
    }
  }
}

template <typename T>
__global__ void max_reduce_scatter_grad_kernel(T *grad_feats,
                                               const T *grad_reduced_feats,
                                               const int32_t *reduce_from,
                                               const int num_reduced,
                                               const int num_feats) {
  MUSA_1D_KERNEL_LOOP(x, num_reduced) {
    const int reduced_offset = x * num_feats;
    const int32_t *scatter_to_offset = reduce_from + reduced_offset;
    const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset;

    for (int i = 0; i < num_feats; i++) {
      grad_feats[scatter_to_offset[i] * num_feats + i] =
          grad_reduced_feats_offset[i];
    }
  }
}

#endif  // SCATTER_POINTS_MUSA_KERNEL_MUH
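
reduceMax implements a floating-point atomic max via a compare-and-swap loop, since hardware atomicMax is only defined for integer types; the loop retries until the stored value is at least val. A CPU analogue of the same idea using std::atomic (a sketch of the pattern, not the device code):

#include <atomic>

// Keep retrying until the slot already holds a value >= val or our CAS wins.
void reduce_max_cpu(std::atomic<float>* address, float val) {
  float old = address->load();
  while (old < val && !address->compare_exchange_weak(old, val)) {
    // compare_exchange_weak reloads `old` on failure, so the loop converges
  }
}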
@ -0,0 +1,327 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef SYNCBN_MUSA_KERNEL_MUH
#define SYNCBN_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

template <typename T>
__global__ void sync_bn_forward_mean_musa_kernel(const T *input, float *mean,
                                                 int num, int channels,
                                                 int spatial) {
  __shared__ float buffer[THREADS_PER_BLOCK];
  int tid = threadIdx.x;
  int c = blockIdx.x;
  buffer[tid] = 0;
  for (int i = tid; i < num * spatial; i += blockDim.x) {
    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
    buffer[tid] += input[index];
  }
  __syncthreads();

  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      buffer[tid] += buffer[tid + s];
    }
    __syncthreads();
  }
  int total = num * spatial;
  if (tid == 0) {
    mean[c] = buffer[0] / total;
  }
}

template <>
__global__ void sync_bn_forward_mean_musa_kernel(const phalf *input,
                                                 float *mean, int num,
                                                 int channels, int spatial) {
  __shared__ float buffer[THREADS_PER_BLOCK];
  int tid = threadIdx.x;
  int c = blockIdx.x;
  buffer[tid] = 0;
  for (int i = tid; i < num * spatial; i += blockDim.x) {
    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
    buffer[tid] += static_cast<float>(input[index]);
  }
  __syncthreads();

  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      buffer[tid] += buffer[tid + s];
    }
    __syncthreads();
  }
  int total = num * spatial;
  if (tid == 0) {
    mean[c] = buffer[0] / total;
  }
}

template <typename T>
__global__ void sync_bn_forward_var_musa_kernel(const T *input,
                                                const float *mean, float *var,
                                                int num, int channels,
                                                int spatial) {
  __shared__ float buffer[THREADS_PER_BLOCK];
  int tid = threadIdx.x;
  int c = blockIdx.x;
  buffer[tid] = 0;
  for (int i = tid; i < num * spatial; i += blockDim.x) {
    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
    float td = input[index] - mean[c];
    buffer[tid] += td * td;
  }
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      buffer[tid] += buffer[tid + s];
    }
    __syncthreads();
  }
  int total = num * spatial;
  if (tid == 0) {
    var[c] = buffer[0] / total;
  }
}

template <>
__global__ void sync_bn_forward_var_musa_kernel(const phalf *input,
                                                const float *mean, float *var,
                                                int num, int channels,
                                                int spatial) {
  __shared__ float buffer[THREADS_PER_BLOCK];
  int tid = threadIdx.x;
  int c = blockIdx.x;
  buffer[tid] = 0;
  for (int i = tid; i < num * spatial; i += blockDim.x) {
    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
    float td = static_cast<float>(input[index]) - mean[c];
    buffer[tid] += td * td;
  }
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      buffer[tid] += buffer[tid + s];
    }
    __syncthreads();
  }
  int total = num * spatial;
  if (tid == 0) {
    var[c] = buffer[0] / total;
  }
}

template <typename T>
__global__ void sync_bn_forward_output_musa_kernel(
    const T *input, const float *mean, const float *var, float *running_mean,
    float *running_var, const float *weight, const float *bias, float *norm,
    float *std, T *output, int num, int channels, int spatial, float eps,
    float momentum, int group_size) {
  int tid = threadIdx.x;
  int c = blockIdx.x;
  float mean_value = mean[c];
  float std_value = sqrt(var[c] + eps);

  if (weight != nullptr) {
    float weight_value = weight[c];
    float bias_value = bias[c];
    if (norm != nullptr) {
      for (int i = tid; i < num * spatial; i += blockDim.x) {
        int index =
            (i / spatial) * channels * spatial + c * spatial + i % spatial;
        norm[index] = (input[index] - mean_value) / std_value;
        output[index] = norm[index] * weight_value + bias_value;
      }
    } else {
      for (int i = tid; i < num * spatial; i += blockDim.x) {
        int index =
            (i / spatial) * channels * spatial + c * spatial + i % spatial;
        output[index] =
            (input[index] - mean_value) / std_value * weight_value + bias_value;
      }
    }
  } else {
    if (norm != nullptr) {
      for (int i = tid; i < num * spatial; i += blockDim.x) {
        int index =
            (i / spatial) * channels * spatial + c * spatial + i % spatial;
        output[index] = norm[index] = (input[index] - mean_value) / std_value;
      }
    } else {
      for (int i = tid; i < num * spatial; i += blockDim.x) {
        int index =
            (i / spatial) * channels * spatial + c * spatial + i % spatial;
        output[index] = (input[index] - mean_value) / std_value;
      }
    }
  }
  if (tid == 0) {
    if (std != nullptr) std[c] = std_value;
    if (running_mean != nullptr) {
      running_mean[c] =
          momentum * mean_value + (1 - momentum) * running_mean[c];
      int count = num * spatial * group_size;
      float var_unbias = count > 1 ? var[c] * count / (count - 1) : var[c];
      running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c];
    }
  }
}

template <>
__global__ void sync_bn_forward_output_musa_kernel(
    const phalf *input, const float *mean, const float *var,
    float *running_mean, float *running_var, const float *weight,
    const float *bias, float *norm, float *std, phalf *output, int num,
    int channels, int spatial, float eps, float momentum, int group_size) {
  int tid = threadIdx.x;
  int c = blockIdx.x;
  float mean_value = mean[c];
  float std_value = sqrt(var[c] + eps);
  if (weight != nullptr) {
    float weight_value = weight[c];
    float bias_value = bias[c];
    if (norm != nullptr) {
      for (int i = tid; i < num * spatial; i += blockDim.x) {
        int index =
            (i / spatial) * channels * spatial + c * spatial + i % spatial;
        norm[index] =
            (static_cast<float>(input[index]) - mean_value) / std_value;
        output[index] =
            static_cast<phalf>(norm[index] * weight_value + bias_value);
      }
    } else {
      for (int i = tid; i < num * spatial; i += blockDim.x) {
        int index =
            (i / spatial) * channels * spatial + c * spatial + i % spatial;
        output[index] =
            static_cast<phalf>((static_cast<float>(input[index]) - mean_value) /
                                   std_value * weight_value +
                               bias_value);
      }
    }
  } else {
    if (norm != nullptr) {
      for (int i = tid; i < num * spatial; i += blockDim.x) {
        int index =
            (i / spatial) * channels * spatial + c * spatial + i % spatial;
        norm[index] =
            (static_cast<float>(input[index]) - mean_value) / std_value;
        output[index] = static_cast<phalf>(norm[index]);
      }
    } else {
      for (int i = tid; i < num * spatial; i += blockDim.x) {
        int index =
            (i / spatial) * channels * spatial + c * spatial + i % spatial;
        output[index] = static_cast<phalf>(
            (static_cast<float>(input[index]) - mean_value) / std_value);
      }
    }
  }
  if (tid == 0) {
    if (std != nullptr) std[c] = std_value;
    if (running_mean != nullptr) {
      running_mean[c] =
          momentum * mean_value + (1 - momentum) * running_mean[c];
      int count = num * spatial * group_size;
      float var_unbias = count > 1 ? var[c] * count / (count - 1) : var[c];
      running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c];
    }
  }
}

template <typename T>
__global__ void sync_bn_backward_param_musa_kernel(const T *grad_output,
                                                   const float *norm,
                                                   float *grad_weight,
                                                   float *grad_bias, int num,
                                                   int channels, int spatial) {
  __shared__ float buffer1[THREADS_PER_BLOCK];
  __shared__ float buffer2[THREADS_PER_BLOCK];

  int tid = threadIdx.x;
  int c = blockIdx.x;
  buffer1[tid] = buffer2[tid] = 0;
  for (int i = tid; i < num * spatial; i += blockDim.x) {
    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
    buffer1[tid] += grad_output[index] * norm[index];
    buffer2[tid] += grad_output[index];
  }
  __syncthreads();

  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      buffer1[tid] += buffer1[tid + s];
      buffer2[tid] += buffer2[tid + s];
    }
    __syncthreads();
  }
  if (tid == 0) {
    grad_weight[c] = buffer1[0];
    grad_bias[c] = buffer2[0];
  }
}

template <>
__global__ void sync_bn_backward_param_musa_kernel(const phalf *grad_output,
                                                   const float *norm,
                                                   float *grad_weight,
                                                   float *grad_bias, int num,
                                                   int channels, int spatial) {
  __shared__ float buffer1[THREADS_PER_BLOCK];
  __shared__ float buffer2[THREADS_PER_BLOCK];

  int tid = threadIdx.x;
  int c = blockIdx.x;
  buffer1[tid] = buffer2[tid] = 0;
  for (int i = tid; i < num * spatial; i += blockDim.x) {
    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
    buffer1[tid] += static_cast<float>(grad_output[index]) * norm[index];
    buffer2[tid] += static_cast<float>(grad_output[index]);
  }
  __syncthreads();

  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      buffer1[tid] += buffer1[tid + s];
      buffer2[tid] += buffer2[tid + s];
    }
    __syncthreads();
  }
  if (tid == 0) {
    grad_weight[c] = buffer1[0];
    grad_bias[c] = buffer2[0];
  }
}

template <typename T>
__global__ void sync_bn_backward_data_musa_kernel(
    int output_size, const T *grad_output, const float *weight,
    const float *grad_weight, const float *grad_bias, const float *norm,
    const float *std, T *grad_input, int num, int channels, int spatial) {
  int factor = num * spatial;
  MUSA_1D_KERNEL_LOOP(index, output_size) {
    int c = (index / spatial) % channels;
    grad_input[index] =
        weight[c] *
        (grad_output[index] -
         (grad_weight[c] * norm[index] + grad_bias[c]) / factor) /
        std[c];
  }
}

template <>
__global__ void sync_bn_backward_data_musa_kernel(
    int output_size, const phalf *grad_output, const float *weight,
    const float *grad_weight, const float *grad_bias, const float *norm,
    const float *std, phalf *grad_input, int num, int channels, int spatial) {
  int factor = num * spatial;
  MUSA_1D_KERNEL_LOOP(index, output_size) {
    int c = (index / spatial) % channels;
    grad_input[index] = static_cast<phalf>(
        weight[c] *
        (static_cast<float>(grad_output[index]) -
         (grad_weight[c] * norm[index] + grad_bias[c]) / factor) /
        std[c]);
  }
}

#endif  // SYNCBN_MUSA_KERNEL_MUH
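
Every reduction in this file follows the same two-phase shape: a grid-stride accumulation into shared memory, then a tree reduction in which each step folds the upper half of the buffer onto the lower half. A serial model of the tree phase (assumes n, like THREADS_PER_BLOCK, is a power of two):

// Serial model of the in-block tree reduction: after the loop, buffer[0]
// holds the sum of all n entries, matching the __syncthreads() version.
float tree_reduce(float* buffer, int n) {
  for (int s = n / 2; s > 0; s >>= 1)
    for (int tid = 0; tid < s; ++tid) buffer[tid] += buffer[tid + s];
  return buffer[0];
}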
@ -0,0 +1,57 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef THREE_INTERPOLATE_MUSA_KERNEL_MUH
#define THREE_INTERPOLATE_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

template <typename T>
__global__ void three_interpolate_forward_musa_kernel(
    int b, int c, int m, int n, const T *points, const int *__restrict__ idx,
    const T *weight, T *out) {
  // points: (B, C, M)
  // idx: (B, N, 3)
  // weight: (B, N, 3)
  // output:
  //      out: (B, C, N)

  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  MUSA_1D_KERNEL_LOOP(pt_idx, n) {
    if (bs_idx >= b || c_idx >= c) return;

    weight += bs_idx * n * 3 + pt_idx * 3;
    points += bs_idx * c * m + c_idx * m;
    idx += bs_idx * n * 3 + pt_idx * 3;
    out += bs_idx * c * n + c_idx * n;

    out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
                  weight[2] * points[idx[2]];
  }
}

template <typename T>
__global__ void three_interpolate_backward_musa_kernel(
    int b, int c, int n, int m, const T *grad_out, const int *__restrict__ idx,
    const T *weight, T *grad_points) {
  // grad_out: (B, C, N)
  // weight: (B, N, 3)
  // output:
  //      grad_points: (B, C, M)

  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  MUSA_1D_KERNEL_LOOP(pt_idx, n) {
    if (bs_idx >= b || c_idx >= c) return;

    grad_out += bs_idx * c * n + c_idx * n + pt_idx;
    weight += bs_idx * n * 3 + pt_idx * 3;
    grad_points += bs_idx * c * m + c_idx * m;
    idx += bs_idx * n * 3 + pt_idx * 3;

    atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
    atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
    atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
  }
}

#endif  // THREE_INTERPOLATE_MUSA_KERNEL_MUH
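
Each output element is a weighted combination of the three nearest known features; the weights come from elsewhere (typically derived from the inverse distances returned by three_nn) and are expected to sum to 1. A host reference for a single element:

// Host reference for one (b, c, pt) output value, mirroring the kernel body.
float three_interpolate_one(const float* points, const int* idx,
                            const float* weight) {
  return weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
         weight[2] * points[idx[2]];
}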
@ -0,0 +1,63 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef THREE_NN_MUSA_KERNEL_MUH
#define THREE_NN_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

template <typename T>
__global__ void three_nn_forward_musa_kernel(int b, int n, int m,
                                             const T *unknown, const T *known,
                                             T *dist2, int *__restrict__ idx) {
  // unknown: (B, N, 3)
  // known: (B, M, 3)
  // output:
  //      dist2: (B, N, 3)
  //      idx: (B, N, 3)

  int bs_idx = blockIdx.y;
  MUSA_1D_KERNEL_LOOP(pt_idx, n) {
    if (bs_idx >= b) return;

    unknown += bs_idx * n * 3 + pt_idx * 3;
    known += bs_idx * m * 3;
    dist2 += bs_idx * n * 3 + pt_idx * 3;
    idx += bs_idx * n * 3 + pt_idx * 3;

    T ux = unknown[0];
    T uy = unknown[1];
    T uz = unknown[2];

    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
    int besti1 = 0, besti2 = 0, besti3 = 0;
    for (int k = 0; k < m; ++k) {
      T x = known[k * 3 + 0];
      T y = known[k * 3 + 1];
      T z = known[k * 3 + 2];
      T d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
      if (d < best1) {
        best3 = best2;
        besti3 = besti2;
        best2 = best1;
        besti2 = besti1;
        best1 = d;
        besti1 = k;
      } else if (d < best2) {
        best3 = best2;
        besti3 = besti2;
        best2 = d;
        besti2 = k;
      } else if (d < best3) {
        best3 = d;
        besti3 = k;
      }
    }
    dist2[0] = best1;
    dist2[1] = best2;
    dist2[2] = best3;
    idx[0] = besti1;
    idx[1] = besti2;
    idx[2] = besti3;
  }
}

#endif  // THREE_NN_MUSA_KERNEL_MUH
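
Note that dist2 stores squared distances: the kernel never takes a square root, which is cheaper and preserves the neighbor ordering, so callers needing metric distances must apply sqrt themselves. Host reference for the per-candidate term:

// Squared Euclidean distance between two xyz points, as computed per
// candidate inside the kernel loop.
float squared_dist3(const float* a, const float* b) {
  float dx = a[0] - b[0], dy = a[1] - b[1], dz = a[2] - b[2];
  return dx * dx + dy * dy + dz * dz;
}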
@ -0,0 +1,57 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef TIN_SHIFT_MUSA_KERNEL_MUH
#define TIN_SHIFT_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

template <typename T>
__global__ void tin_shift_forward_musa_kernel(
    const int nthreads, const T* input, const int* shift, T* output,
    const int batch_size, const int channels, const int t_size,
    const int hw_size, const int group_size, const int group_channel) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    const int hw_index = index % hw_size;
    const int j = (index / hw_size) % channels;

    const int n_index = (index / hw_size / channels) % batch_size;
    int group_id = j / group_channel;
    int t_shift = shift[n_index * group_size + group_id];
    int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index;
    for (int i = 0; i < t_size; i++) {
      int now_t = i + t_shift;
      int data_id = i * hw_size * channels + offset;
      if (now_t < 0 || now_t >= t_size) {
        continue;
      }
      int out_id = now_t * hw_size * channels + offset;
      output[out_id] = input[data_id];
    }
  }
}

template <typename T>
__global__ void tin_shift_backward_musa_kernel(
    const int nthreads, const T* input, const int* shift, T* output,
    const int batch_size, const int channels, const int t_size,
    const int hw_size, const int group_size, const int group_channel) {
  MUSA_1D_KERNEL_LOOP(index, nthreads) {
    const int hw_index = index % hw_size;
    const int j = (index / hw_size) % channels;

    const int n_index = (index / hw_size / channels) % batch_size;
    int group_id = j / group_channel;
    int t_shift = shift[n_index * group_size + group_id];
    int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index;
    for (int i = 0; i < t_size; i++) {
      int now_t = i + t_shift;
      int data_id = i * hw_size * channels + offset;
      if (now_t < 0 || now_t >= t_size) {
        continue;
      }
      int out_id = now_t * hw_size * channels + offset;
      output[out_id] = input[data_id];
    }
  }
}

#endif  // TIN_SHIFT_MUSA_KERNEL_MUH
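
The shift moves each group of channels forward or backward in time by t_shift frames; source frames whose destination falls outside [0, t_size) are dropped, and destinations that nothing maps to keep their initial value (the launcher is assumed to zero-initialize the output). A host illustration:

#include <stdio.h>

int main(void) {
  int t_size = 4, t_shift = 1;
  for (int i = 0; i < t_size; i++) {
    int now_t = i + t_shift;
    if (now_t < 0 || now_t >= t_size) continue;  // frame 3 is dropped
    printf("input frame %d -> output frame %d\n", i, now_t);
  }
  return 0;  // output frame 0 is never written
}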
@ -0,0 +1,212 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef VOXELIZATION_MUSA_KERNEL_MUH
#define VOXELIZATION_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;

template <typename T, typename T_int>
__global__ void dynamic_voxelize_kernel(
    const T* points, T_int* coors, const float voxel_x, const float voxel_y,
    const float voxel_z, const float coors_x_min, const float coors_y_min,
    const float coors_z_min, const float coors_x_max, const float coors_y_max,
    const float coors_z_max, const int grid_x, const int grid_y,
    const int grid_z, const int num_points, const int num_features,
    const int NDim) {
  MUSA_1D_KERNEL_LOOP(index, num_points) {
    // To save some computation
    auto points_offset = points + index * num_features;
    auto coors_offset = coors + index * NDim;
    int c_x = floorf((points_offset[0] - coors_x_min) / voxel_x);
    if (c_x < 0 || c_x >= grid_x) {
      coors_offset[0] = -1;
      continue;
    }

    int c_y = floorf((points_offset[1] - coors_y_min) / voxel_y);
    if (c_y < 0 || c_y >= grid_y) {
      coors_offset[0] = -1;
      coors_offset[1] = -1;
      continue;
    }

    int c_z = floorf((points_offset[2] - coors_z_min) / voxel_z);
    if (c_z < 0 || c_z >= grid_z) {
      coors_offset[0] = -1;
      coors_offset[1] = -1;
      coors_offset[2] = -1;
    } else {
      coors_offset[0] = c_z;
      coors_offset[1] = c_y;
      coors_offset[2] = c_x;
    }
  }
}

template <typename T, typename T_int>
__global__ void assign_point_to_voxel(const int nthreads, const T* points,
                                      T_int* point_to_voxelidx,
                                      T_int* coor_to_voxelidx, T* voxels,
                                      const int max_points,
                                      const int num_features,
                                      const int num_points, const int NDim) {
  MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    int index = thread_idx / num_features;

    int num = point_to_voxelidx[index];
    int voxelidx = coor_to_voxelidx[index];
    if (num > -1 && voxelidx > -1) {
      auto voxels_offset =
          voxels + voxelidx * max_points * num_features + num * num_features;

      int k = thread_idx % num_features;
      voxels_offset[k] = points[thread_idx];
    }
  }
}

template <typename T, typename T_int>
__global__ void assign_voxel_coors(const int nthreads, T_int* coor,
                                   T_int* point_to_voxelidx,
                                   T_int* coor_to_voxelidx, T_int* voxel_coors,
                                   const int num_points, const int NDim) {
  MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    int index = thread_idx / NDim;
    int num = point_to_voxelidx[index];
    int voxelidx = coor_to_voxelidx[index];
    if (num == 0 && voxelidx > -1) {
      auto coors_offset = voxel_coors + voxelidx * NDim;
      int k = thread_idx % NDim;
      coors_offset[k] = coor[thread_idx];
    }
  }
}

template <typename T_int>
__global__ void point_to_voxelidx_kernel(const T_int* coor,
                                         T_int* point_to_voxelidx,
                                         T_int* point_to_pointidx,
                                         const int max_points,
                                         const int max_voxels,
                                         const int num_points, const int NDim) {
  MUSA_1D_KERNEL_LOOP(index, num_points) {
    auto coor_offset = coor + index * NDim;
    // skip invalid points
    if (coor_offset[0] == -1) continue;

    int num = 0;
    int coor_x = coor_offset[0];
    int coor_y = coor_offset[1];
    int coor_z = coor_offset[2];
    // only compare against the coors before this coor[index]
    for (int i = 0; i < index; ++i) {
      auto prev_coor = coor + i * NDim;
      if (prev_coor[0] == -1) continue;

      // Find all previous points that have the same coors;
      // if the same coor is found, record it
      if ((prev_coor[0] == coor_x) && (prev_coor[1] == coor_y) &&
          (prev_coor[2] == coor_z)) {
        num++;
        if (num == 1) {
          // point to the first point with the same coor
          point_to_pointidx[index] = i;
        } else if (num >= max_points) {
          // out of boundary
          break;
        }
      }
    }
    if (num == 0) {
      point_to_pointidx[index] = index;
    }
    if (num < max_points) {
      point_to_voxelidx[index] = num;
    }
  }
}

template <typename T_int>
__global__ void determin_voxel_num(
    T_int* num_points_per_voxel, T_int* point_to_voxelidx,
    T_int* point_to_pointidx, T_int* coor_to_voxelidx, T_int* voxel_num,
    const int max_points, const int max_voxels, const int num_points) {
  // scan points sequentially so voxel ids are assigned deterministically
  for (int i = 0; i < num_points; ++i) {
    int point_pos_in_voxel = point_to_voxelidx[i];
    // record voxel
    if (point_pos_in_voxel == -1) {
      // out of max_points or invalid point
      continue;
    } else if (point_pos_in_voxel == 0) {
      // record new voxel
      int voxelidx = voxel_num[0];
      if (voxel_num[0] >= max_voxels) continue;
      voxel_num[0] += 1;
      coor_to_voxelidx[i] = voxelidx;
      num_points_per_voxel[voxelidx] = 1;
    } else {
      int point_idx = point_to_pointidx[i];
      int voxelidx = coor_to_voxelidx[point_idx];
      if (voxelidx != -1) {
        coor_to_voxelidx[i] = voxelidx;
        num_points_per_voxel[voxelidx] += 1;
      }
    }
  }
}

__global__ void nondeterministic_get_assign_pos(
    const int nthreads, const int32_t* coors_map, int32_t* pts_id,
    int32_t* coors_count, int32_t* reduce_count, int32_t* coors_order) {
  MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    int coors_idx = coors_map[thread_idx];
    if (coors_idx > -1) {
      int32_t coors_pts_pos = atomicAdd(&reduce_count[coors_idx], 1);
      pts_id[thread_idx] = coors_pts_pos;
      if (coors_pts_pos == 0) {
        coors_order[coors_idx] = atomicAdd(coors_count, 1);
      }
    }
  }
}

template <typename T>
__global__ void nondeterministic_assign_point_voxel(
    const int nthreads, const T* points, const int32_t* coors_map,
    const int32_t* pts_id, const int32_t* coors_in, const int32_t* reduce_count,
    const int32_t* coors_order, T* voxels, int32_t* coors, int32_t* pts_count,
    const int max_voxels, const int max_points, const int num_features,
    const int NDim) {
  MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    int coors_idx = coors_map[thread_idx];
    int coors_pts_pos = pts_id[thread_idx];
    if (coors_idx > -1 && coors_pts_pos < max_points) {
      int coors_pos = coors_order[coors_idx];
      if (coors_pos < max_voxels) {
        auto voxels_offset =
            voxels + (coors_pos * max_points + coors_pts_pos) * num_features;
        auto points_offset = points + thread_idx * num_features;
        for (int k = 0; k < num_features; k++) {
          voxels_offset[k] = points_offset[k];
        }
        if (coors_pts_pos == 0) {
          pts_count[coors_pos] = min(reduce_count[coors_idx], max_points);
          auto coors_offset = coors + coors_pos * NDim;
          auto coors_in_offset = coors_in + coors_idx * NDim;
          for (int k = 0; k < NDim; k++) {
            coors_offset[k] = coors_in_offset[k];
          }
        }
      }
    }
  }
}

#endif  // VOXELIZATION_MUSA_KERNEL_MUH
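
dynamic_voxelize_kernel maps each point to an integer voxel coordinate by subtracting the range minimum and dividing by the voxel size, writing -1 for points outside the grid; note the output order is (z, y, x). A worked numeric example:

#include <math.h>
#include <stdio.h>

int main(void) {
  /* e.g. voxel_x = 0.05 with a range starting at -50.0 */
  float voxel_x = 0.05f, coors_x_min = -50.0f, px = -49.87f;
  int c_x = (int)floorf((px - coors_x_min) / voxel_x);
  printf("c_x = %d\n", c_x); /* floor(0.13 / 0.05) = 2 */
  return 0;
}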
@ -540,6 +540,17 @@ torch::Tensor bias_act_op(const torch::Tensor &input, const torch::Tensor &bias,
REGISTER_DEVICE_IMPL(bias_act_op_impl, MUSA, bias_act_op);

torch::Tensor filtered_lrelu_act_op_impl(torch::Tensor x, torch::Tensor si,
                                         int sx, int sy, float gain,
                                         float slope, float clamp,
                                         bool writeSigns);

torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
                                    int sy, float gain, float slope,
                                    float clamp, bool writeSigns);

REGISTER_DEVICE_IMPL(filtered_lrelu_act_op_impl, MUSA, filtered_lrelu_act_op);

void GatherPointsForwardMUSAKernelLauncher(int b, int c, int n, int npoints,
                                           const Tensor points,
                                           const Tensor idx, Tensor out);
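
Every op in this file is wired up the same way: a kernel-launcher declaration, a thin musa wrapper that forwards to it, and a REGISTER_DEVICE_IMPL call binding the device-agnostic *_impl symbol to the MUSA implementation, as the next hunk shows at length. Schematically (my_op is a made-up name used only for illustration):

// Schematic of the registration pattern; my_op is hypothetical.
void MyOpForwardMUSAKernelLauncher(Tensor input, Tensor output);  // launcher

void my_op_forward_musa(Tensor input, Tensor output) {  // thin wrapper
  MyOpForwardMUSAKernelLauncher(input, output);
}

void my_op_forward_impl(Tensor input, Tensor output);  // dispatch symbol
REGISTER_DEVICE_IMPL(my_op_forward_impl, MUSA, my_op_forward_musa);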
@ -869,6 +880,854 @@ Tensor nms_musa(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset);
|
||||
REGISTER_DEVICE_IMPL(nms_impl, MUSA, nms_musa);
|
||||
|
||||
void PointsInBoxesPartForwardMUSAKernelLauncher(int batch_size, int boxes_num,
|
||||
int pts_num, const Tensor boxes,
|
||||
const Tensor pts,
|
||||
Tensor box_idx_of_points);
|
||||
|
||||
void PointsInBoxesAllForwardMUSAKernelLauncher(int batch_size, int boxes_num,
|
||||
int pts_num, const Tensor boxes,
|
||||
const Tensor pts,
|
||||
Tensor box_idx_of_points);
|
||||
|
||||
void points_in_boxes_part_forward_musa(int batch_size, int boxes_num,
|
||||
int pts_num, const Tensor boxes,
|
||||
const Tensor pts,
|
||||
Tensor box_idx_of_points) {
|
||||
PointsInBoxesPartForwardMUSAKernelLauncher(batch_size, boxes_num, pts_num,
|
||||
boxes, pts, box_idx_of_points);
|
||||
};
|
||||
|
||||
void points_in_boxes_all_forward_musa(int batch_size, int boxes_num,
|
||||
int pts_num, const Tensor boxes,
|
||||
const Tensor pts,
|
||||
Tensor box_idx_of_points) {
|
||||
PointsInBoxesAllForwardMUSAKernelLauncher(batch_size, boxes_num, pts_num,
|
||||
boxes, pts, box_idx_of_points);
|
||||
};
|
||||
|
||||
void points_in_boxes_part_forward_impl(int batch_size, int boxes_num,
|
||||
int pts_num, const Tensor boxes,
|
||||
const Tensor pts,
|
||||
Tensor box_idx_of_points);
|
||||
|
||||
void points_in_boxes_all_forward_impl(int batch_size, int boxes_num,
|
||||
int pts_num, const Tensor boxes,
|
||||
const Tensor pts,
|
||||
Tensor box_idx_of_points);
|
||||
REGISTER_DEVICE_IMPL(points_in_boxes_part_forward_impl, MUSA,
|
||||
points_in_boxes_part_forward_musa);
|
||||
REGISTER_DEVICE_IMPL(points_in_boxes_all_forward_impl, MUSA,
|
||||
points_in_boxes_all_forward_musa);

void PSAMaskForwardMUSAKernelLauncher(const int psa_type, const Tensor input,
                                      Tensor output, const int num_,
                                      const int h_feature, const int w_feature,
                                      const int h_mask, const int w_mask,
                                      const int half_h_mask,
                                      const int half_w_mask);

void PSAMaskBackwardMUSAKernelLauncher(
    const int psa_type, const Tensor grad_output, Tensor grad_input,
    const int num_, const int h_feature, const int w_feature, const int h_mask,
    const int w_mask, const int half_h_mask, const int half_w_mask);

void psamask_forward_musa(const int psa_type, const Tensor input, Tensor output,
                          const int num_, const int h_feature,
                          const int w_feature, const int h_mask,
                          const int w_mask, const int half_h_mask,
                          const int half_w_mask) {
  PSAMaskForwardMUSAKernelLauncher(psa_type, input, output, num_, h_feature,
                                   w_feature, h_mask, w_mask, half_h_mask,
                                   half_w_mask);
}

void psamask_backward_musa(const int psa_type, const Tensor grad_output,
                           Tensor grad_input, const int num_,
                           const int h_feature, const int w_feature,
                           const int h_mask, const int w_mask,
                           const int half_h_mask, const int half_w_mask) {
  PSAMaskBackwardMUSAKernelLauncher(psa_type, grad_output, grad_input, num_,
                                    h_feature, w_feature, h_mask, w_mask,
                                    half_h_mask, half_w_mask);
}

void psamask_forward_impl(const int psa_type, const Tensor input, Tensor output,
                          const int num_, const int h_feature,
                          const int w_feature, const int h_mask,
                          const int w_mask, const int half_h_mask,
                          const int half_w_mask);

void psamask_backward_impl(const int psa_type, const Tensor grad_output,
                           Tensor grad_input, const int num_,
                           const int h_feature, const int w_feature,
                           const int h_mask, const int w_mask,
                           const int half_h_mask, const int half_w_mask);
REGISTER_DEVICE_IMPL(psamask_forward_impl, MUSA, psamask_forward_musa);
REGISTER_DEVICE_IMPL(psamask_backward_impl, MUSA, psamask_backward_musa);

void ROIAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                       Tensor argmax_y, Tensor argmax_x,
                                       int aligned_height, int aligned_width,
                                       float spatial_scale, int sampling_ratio,
                                       int pool_mode, bool aligned);

void ROIAlignBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
                                        Tensor argmax_y, Tensor argmax_x,
                                        Tensor grad_input, int aligned_height,
                                        int aligned_width, float spatial_scale,
                                        int sampling_ratio, int pool_mode,
                                        bool aligned);

void roi_align_forward_musa(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax_y, Tensor argmax_x,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned) {
  ROIAlignForwardMUSAKernelLauncher(
      input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width,
      spatial_scale, sampling_ratio, pool_mode, aligned);
}

void roi_align_backward_musa(Tensor grad_output, Tensor rois, Tensor argmax_y,
                             Tensor argmax_x, Tensor grad_input,
                             int aligned_height, int aligned_width,
                             float spatial_scale, int sampling_ratio,
                             int pool_mode, bool aligned) {
  ROIAlignBackwardMUSAKernelLauncher(
      grad_output, rois, argmax_y, argmax_x, grad_input, aligned_height,
      aligned_width, spatial_scale, sampling_ratio, pool_mode, aligned);
}

void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax_y, Tensor argmax_x,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned);

void roi_align_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax_y,
                             Tensor argmax_x, Tensor grad_input,
                             int aligned_height, int aligned_width,
                             float spatial_scale, int sampling_ratio,
                             int pool_mode, bool aligned);

REGISTER_DEVICE_IMPL(roi_align_forward_impl, MUSA, roi_align_forward_musa);
REGISTER_DEVICE_IMPL(roi_align_backward_impl, MUSA, roi_align_backward_musa);

void ROIAlignRotatedForwardMUSAKernelLauncher(
    const at::Tensor input, const at::Tensor rois, const float spatial_scale,
    const int sampling_ratio, const bool aligned, const bool clockwise,
    const int channels, const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, at::Tensor output);

void ROIAlignRotatedBackwardMUSAKernelLauncher(
    const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
    const int sampling_ratio, const bool aligned, const bool clockwise,
    const int channels, const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, at::Tensor bottom_grad);

void roi_align_rotated_forward_musa(Tensor input, Tensor rois, Tensor output,
                                    int aligned_height, int aligned_width,
                                    float spatial_scale, int sampling_ratio,
                                    bool aligned, bool clockwise) {
  // Number of ROIs
  int num_rois = rois.size(0);
  int size_rois = rois.size(1);

  if (size_rois != 6) {
    AT_ERROR("wrong roi size");
  }

  int num_channels = input.size(1);
  int data_height = input.size(2);
  int data_width = input.size(3);
  ROIAlignRotatedForwardMUSAKernelLauncher(
      input, rois, spatial_scale, sampling_ratio, aligned, clockwise,
      num_channels, data_height, data_width, num_rois, aligned_height,
      aligned_width, output);
}

void roi_align_rotated_backward_musa(Tensor top_grad, Tensor rois,
                                     Tensor bottom_grad, int aligned_height,
                                     int aligned_width, float spatial_scale,
                                     int sampling_ratio, bool aligned,
                                     bool clockwise) {
  // Number of ROIs
  int num_rois = rois.size(0);
  int size_rois = rois.size(1);
  if (size_rois != 6) {
    AT_ERROR("wrong roi size");
  }

  int num_channels = bottom_grad.size(1);
  int data_height = bottom_grad.size(2);
  int data_width = bottom_grad.size(3);
  ROIAlignRotatedBackwardMUSAKernelLauncher(
      top_grad, rois, spatial_scale, sampling_ratio, aligned, clockwise,
      num_channels, data_height, data_width, num_rois, aligned_height,
      aligned_width, bottom_grad);
}

void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
                                    int aligned_height, int aligned_width,
                                    float spatial_scale, int sampling_ratio,
                                    bool aligned, bool clockwise);

void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
                                     Tensor bottom_grad, int aligned_height,
                                     int aligned_width, float spatial_scale,
                                     int sampling_ratio, bool aligned,
                                     bool clockwise);
REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, MUSA,
                     roi_align_rotated_forward_musa);
REGISTER_DEVICE_IMPL(roi_align_rotated_backward_impl, MUSA,
                     roi_align_rotated_backward_musa);

void RiROIAlignRotatedForwardMUSAKernelLauncher(
    const at::Tensor features, const at::Tensor rois, const float spatial_scale,
    const int num_samples, const bool clockwise, const int channels,
    const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, const int num_orientations,
    at::Tensor output);

void RiROIAlignRotatedBackwardMUSAKernelLauncher(
    const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
    const int num_samples, const bool clockwise, const int channels,
    const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, const int num_orientations,
    at::Tensor bottom_grad);

void riroi_align_rotated_forward_musa(Tensor features, Tensor rois,
                                      Tensor output, int pooled_height,
                                      int pooled_width, float spatial_scale,
                                      int num_samples, int num_orientations,
                                      bool clockwise) {
  // Number of ROIs
  int num_rois = rois.size(0);
  int size_rois = rois.size(1);
  if (size_rois != 6) {
    AT_ERROR("wrong roi size");
  }
  CHECK_CONTIGUOUS(features);
  CHECK_CONTIGUOUS(rois);
  int num_channels = features.size(1) / num_orientations;
  int data_height = features.size(2);
  int data_width = features.size(3);
  RiROIAlignRotatedForwardMUSAKernelLauncher(
      features, rois, spatial_scale, num_samples, clockwise, num_channels,
      data_height, data_width, num_rois, pooled_height, pooled_width,
      num_orientations, output);
}

void riroi_align_rotated_backward_musa(Tensor top_grad, Tensor rois,
                                       Tensor bottom_grad, int pooled_height,
                                       int pooled_width, float spatial_scale,
                                       int num_samples, int num_orientations,
                                       bool clockwise) {
  // Number of ROIs
  int num_rois = rois.size(0);
  int size_rois = rois.size(1);
  if (size_rois != 6) {
    AT_ERROR("wrong roi size");
  }
  CHECK_CONTIGUOUS(top_grad);
  CHECK_CONTIGUOUS(rois);
  int num_channels = bottom_grad.size(1) / num_orientations;
  int data_height = bottom_grad.size(2);
  int data_width = bottom_grad.size(3);
  RiROIAlignRotatedBackwardMUSAKernelLauncher(
      top_grad, rois, spatial_scale, num_samples, clockwise, num_channels,
      data_height, data_width, num_rois, pooled_height, pooled_width,
      num_orientations, bottom_grad);
}

void riroi_align_rotated_forward_impl(Tensor features, Tensor rois,
                                      Tensor output, int pooled_height,
                                      int pooled_width, float spatial_scale,
                                      int num_samples, int num_orientations,
                                      bool clockwise);

void riroi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
                                       Tensor bottom_grad, int pooled_height,
                                       int pooled_width, float spatial_scale,
                                       int num_samples, int num_orientations,
                                       bool clockwise);

REGISTER_DEVICE_IMPL(riroi_align_rotated_forward_impl, MUSA,
                     riroi_align_rotated_forward_musa);
REGISTER_DEVICE_IMPL(riroi_align_rotated_backward_impl, MUSA,
                     riroi_align_rotated_backward_musa);

void RoiawarePool3dForwardMUSAKernelLauncher(
    int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
    int out_y, int out_z, const Tensor rois, const Tensor pts,
    const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels,
    Tensor pooled_features, int pool_method);

void RoiawarePool3dBackwardMUSAKernelLauncher(
    int boxes_num, int out_x, int out_y, int out_z, int channels,
    int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax,
    const Tensor grad_out, Tensor grad_in, int pool_method);

void roiaware_pool3d_forward_musa(int boxes_num, int pts_num, int channels,
                                  int max_pts_each_voxel, int out_x, int out_y,
                                  int out_z, const Tensor rois,
                                  const Tensor pts, const Tensor pts_feature,
                                  Tensor argmax, Tensor pts_idx_of_voxels,
                                  Tensor pooled_features, int pool_method) {
  RoiawarePool3dForwardMUSAKernelLauncher(
      boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
      rois, pts, pts_feature, argmax, pts_idx_of_voxels, pooled_features,
      pool_method);
};

void roiaware_pool3d_backward_musa(int boxes_num, int out_x, int out_y,
                                   int out_z, int channels,
                                   int max_pts_each_voxel,
                                   const Tensor pts_idx_of_voxels,
                                   const Tensor argmax, const Tensor grad_out,
                                   Tensor grad_in, int pool_method) {
  RoiawarePool3dBackwardMUSAKernelLauncher(
      boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
      pts_idx_of_voxels, argmax, grad_out, grad_in, pool_method);
};

void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels,
                                  int max_pts_each_voxel, int out_x, int out_y,
                                  int out_z, const Tensor rois,
                                  const Tensor pts, const Tensor pts_feature,
                                  Tensor argmax, Tensor pts_idx_of_voxels,
                                  Tensor pooled_features, int pool_method);

void roiaware_pool3d_backward_impl(int boxes_num, int out_x, int out_y,
                                   int out_z, int channels,
                                   int max_pts_each_voxel,
                                   const Tensor pts_idx_of_voxels,
                                   const Tensor argmax, const Tensor grad_out,
                                   Tensor grad_in, int pool_method);

REGISTER_DEVICE_IMPL(roiaware_pool3d_forward_impl, MUSA,
                     roiaware_pool3d_forward_musa);
REGISTER_DEVICE_IMPL(roiaware_pool3d_backward_impl, MUSA,
                     roiaware_pool3d_backward_musa);

void RoIPointPool3dForwardMUSAKernelLauncher(
    int batch_size, int pts_num, int boxes_num, int feature_in_len,
    int sampled_pts_num, const Tensor xyz, const Tensor boxes3d,
    const Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag);

void roipoint_pool3d_forward_musa(int batch_size, int pts_num, int boxes_num,
                                  int feature_in_len, int sampled_pts_num,
                                  const Tensor xyz, const Tensor boxes3d,
                                  const Tensor pts_feature,
                                  Tensor pooled_features,
                                  Tensor pooled_empty_flag) {
  RoIPointPool3dForwardMUSAKernelLauncher(
      batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz,
      boxes3d, pts_feature, pooled_features, pooled_empty_flag);
};

void roipoint_pool3d_forward_impl(int batch_size, int pts_num, int boxes_num,
                                  int feature_in_len, int sampled_pts_num,
                                  const Tensor xyz, const Tensor boxes3d,
                                  const Tensor pts_feature,
                                  Tensor pooled_features,
                                  Tensor pooled_empty_flag);
REGISTER_DEVICE_IMPL(roipoint_pool3d_forward_impl, MUSA,
                     roipoint_pool3d_forward_musa);

void ROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                      Tensor argmax, int pooled_height,
                                      int pooled_width, float spatial_scale);

void ROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
                                       Tensor argmax, Tensor grad_input,
                                       int pooled_height, int pooled_width,
                                       float spatial_scale);

void roi_pool_forward_musa(Tensor input, Tensor rois, Tensor output,
                           Tensor argmax, int pooled_height, int pooled_width,
                           float spatial_scale) {
  ROIPoolForwardMUSAKernelLauncher(input, rois, output, argmax, pooled_height,
                                   pooled_width, spatial_scale);
}

void roi_pool_backward_musa(Tensor grad_output, Tensor rois, Tensor argmax,
                            Tensor grad_input, int pooled_height,
                            int pooled_width, float spatial_scale) {
  ROIPoolBackwardMUSAKernelLauncher(grad_output, rois, argmax, grad_input,
                                    pooled_height, pooled_width, spatial_scale);
}

void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
                           Tensor argmax, int pooled_height, int pooled_width,
                           float spatial_scale);
void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax,
                            Tensor grad_input, int pooled_height,
                            int pooled_width, float spatial_scale);
REGISTER_DEVICE_IMPL(roi_pool_forward_impl, MUSA, roi_pool_forward_musa);
REGISTER_DEVICE_IMPL(roi_pool_backward_impl, MUSA, roi_pool_backward_musa);

typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
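// Editor's note: an illustrative example of the three reduction modes over one
// voxel that receives the point features {1, 2, 3} (values chosen for the
// example): SUM -> 6, MEAN -> 2, MAX -> 3. The backward launcher below
// scatters gradients by the same mode; roughly speaking, MEAN divides by the
// per-voxel count in reduce_count and MAX routes each gradient to the point
// that produced the maximum.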

std::vector<at::Tensor> DynamicPointToVoxelForwardMUSAKernelLauncher(
    const at::Tensor &feats, const at::Tensor &coors,
    const reduce_t reduce_type);

void DynamicPointToVoxelBackwardMUSAKernelLauncher(
    at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats,
    const at::Tensor &feats, const at::Tensor &reduced_feats,
    const at::Tensor &coors_map, const at::Tensor &reduce_count,
    const reduce_t reduce_type);

std::vector<torch::Tensor> dynamic_point_to_voxel_forward_musa(
    const torch::Tensor &feats, const torch::Tensor &coors,
    const reduce_t reduce_type) {
  return DynamicPointToVoxelForwardMUSAKernelLauncher(feats, coors,
                                                      reduce_type);
};

void dynamic_point_to_voxel_backward_musa(
    torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
    const torch::Tensor &feats, const torch::Tensor &reduced_feats,
    const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
    const reduce_t reduce_type) {
  DynamicPointToVoxelBackwardMUSAKernelLauncher(grad_feats, grad_reduced_feats,
                                                feats, reduced_feats, coors_idx,
                                                reduce_count, reduce_type);
};

std::vector<torch::Tensor> dynamic_point_to_voxel_forward_impl(
    const torch::Tensor &feats, const torch::Tensor &coors,
    const reduce_t reduce_type);

void dynamic_point_to_voxel_backward_impl(
    torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
    const torch::Tensor &feats, const torch::Tensor &reduced_feats,
    const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
    const reduce_t reduce_type);

REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, MUSA,
                     dynamic_point_to_voxel_forward_musa);
REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_backward_impl, MUSA,
                     dynamic_point_to_voxel_backward_musa);

void SyncBNForwardMeanMUSAKernelLauncher(const Tensor input, Tensor mean);

void SyncBNForwardVarMUSAKernelLauncher(const Tensor input, const Tensor mean,
                                        Tensor var);

void SyncBNForwardOutputMUSAKernelLauncher(
    const Tensor input, const Tensor mean, const Tensor var,
    Tensor running_mean, Tensor running_var, const Tensor weight,
    const Tensor bias, Tensor norm, Tensor std, Tensor output, float eps,
    float momentum, int group_size);

void SyncBNBackwardParamMUSAKernelLauncher(const Tensor grad_output,
                                           const Tensor norm,
                                           Tensor grad_weight,
                                           Tensor grad_bias);

void SyncBNBackwardDataMUSAKernelLauncher(const Tensor grad_output,
                                          const Tensor weight,
                                          const Tensor grad_weight,
                                          const Tensor grad_bias,
                                          const Tensor norm, const Tensor std,
                                          Tensor grad_input);

void sync_bn_forward_mean_musa(const Tensor input, Tensor mean) {
  SyncBNForwardMeanMUSAKernelLauncher(input, mean);
}

void sync_bn_forward_var_musa(const Tensor input, const Tensor mean,
                              Tensor var) {
  SyncBNForwardVarMUSAKernelLauncher(input, mean, var);
}

void sync_bn_forward_output_musa(const Tensor input, const Tensor mean,
                                 const Tensor var, Tensor running_mean,
                                 Tensor running_var, const Tensor weight,
                                 const Tensor bias, Tensor norm, Tensor std,
                                 Tensor output, float eps, float momentum,
                                 int group_size) {
  SyncBNForwardOutputMUSAKernelLauncher(input, mean, var, running_mean,
                                        running_var, weight, bias, norm, std,
                                        output, eps, momentum, group_size);
}

void sync_bn_backward_param_musa(const Tensor grad_output, const Tensor norm,
                                 Tensor grad_weight, Tensor grad_bias) {
  SyncBNBackwardParamMUSAKernelLauncher(grad_output, norm, grad_weight,
                                        grad_bias);
}

void sync_bn_backward_data_musa(const Tensor grad_output, const Tensor weight,
                                const Tensor grad_weight,
                                const Tensor grad_bias, const Tensor norm,
                                const Tensor std, Tensor grad_input) {
  SyncBNBackwardDataMUSAKernelLauncher(grad_output, weight, grad_weight,
                                       grad_bias, norm, std, grad_input);
}

void sync_bn_forward_mean_impl(const Tensor input, Tensor mean);

void sync_bn_forward_var_impl(const Tensor input, const Tensor mean,
                              Tensor var);

void sync_bn_forward_output_impl(const Tensor input, const Tensor mean,
                                 const Tensor var, Tensor running_mean,
                                 Tensor running_var, const Tensor weight,
                                 const Tensor bias, Tensor norm, Tensor std,
                                 Tensor output, float eps, float momentum,
                                 int group_size);

void sync_bn_backward_param_impl(const Tensor grad_output, const Tensor norm,
                                 Tensor grad_weight, Tensor grad_bias);

void sync_bn_backward_data_impl(const Tensor grad_output, const Tensor weight,
                                const Tensor grad_weight,
                                const Tensor grad_bias, const Tensor norm,
                                const Tensor std, Tensor grad_input);

REGISTER_DEVICE_IMPL(sync_bn_forward_mean_impl, MUSA,
                     sync_bn_forward_mean_musa);
REGISTER_DEVICE_IMPL(sync_bn_forward_var_impl, MUSA, sync_bn_forward_var_musa);
REGISTER_DEVICE_IMPL(sync_bn_forward_output_impl, MUSA,
                     sync_bn_forward_output_musa);
REGISTER_DEVICE_IMPL(sync_bn_backward_param_impl, MUSA,
                     sync_bn_backward_param_musa);
REGISTER_DEVICE_IMPL(sync_bn_backward_data_impl, MUSA,
                     sync_bn_backward_data_musa);
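// Editor's note: a sketch of the expected call order, inferred from the kernel
// names above rather than spelled out in this file: the forward pass runs
// mean -> var -> output, with the cross-process all-reduce happening between
// these stages outside this extension; the backward pass first accumulates
// grad_weight / grad_bias (param) and then computes grad_input (data).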

void ThreeInterpolateForwardMUSAKernelLauncher(int b, int c, int m, int n,
                                               const Tensor points,
                                               const Tensor idx,
                                               const Tensor weight, Tensor out);

void ThreeInterpolateBackwardMUSAKernelLauncher(int b, int c, int n, int m,
                                                const Tensor grad_out,
                                                const Tensor idx,
                                                const Tensor weight,
                                                Tensor grad_points);

void three_interpolate_forward_musa(int b, int c, int m, int n,
                                    const Tensor points, const Tensor idx,
                                    const Tensor weight, Tensor out) {
  ThreeInterpolateForwardMUSAKernelLauncher(b, c, m, n, points, idx, weight,
                                            out);
};

void three_interpolate_backward_musa(int b, int c, int n, int m,
                                     const Tensor grad_out, const Tensor idx,
                                     const Tensor weight, Tensor grad_points) {
  ThreeInterpolateBackwardMUSAKernelLauncher(b, c, n, m, grad_out, idx, weight,
                                             grad_points);
};

void three_interpolate_forward_impl(int b, int c, int m, int n,
                                    const Tensor points, const Tensor idx,
                                    const Tensor weight, Tensor out);

void three_interpolate_backward_impl(int b, int c, int n, int m,
                                     const Tensor grad_out, const Tensor idx,
                                     const Tensor weight, Tensor grad_points);
REGISTER_DEVICE_IMPL(three_interpolate_forward_impl, MUSA,
                     three_interpolate_forward_musa);
REGISTER_DEVICE_IMPL(three_interpolate_backward_impl, MUSA,
                     three_interpolate_backward_musa);

void ThreeNNForwardMUSAKernelLauncher(int b, int n, int m, const Tensor unknown,
                                      const Tensor known, Tensor dist2,
                                      Tensor idx);

void three_nn_forward_musa(int b, int n, int m, const Tensor unknown,
                           const Tensor known, Tensor dist2, Tensor idx) {
  ThreeNNForwardMUSAKernelLauncher(b, n, m, unknown, known, dist2, idx);
};

void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
                           const Tensor known, Tensor dist2, Tensor idx);
REGISTER_DEVICE_IMPL(three_nn_forward_impl, MUSA, three_nn_forward_musa);

void TINShiftForwardMUSAKernelLauncher(Tensor input, Tensor shift,
                                       Tensor output);

void TINShiftBackwardMUSAKernelLauncher(Tensor grad_output, Tensor shift,
                                        Tensor grad_input);

void tin_shift_forward_musa(Tensor input, Tensor shift, Tensor output) {
  TINShiftForwardMUSAKernelLauncher(input, shift, output);
}

void tin_shift_backward_musa(Tensor grad_output, Tensor shift,
                             Tensor grad_input) {
  TINShiftBackwardMUSAKernelLauncher(grad_output, shift, grad_input);
}

void tin_shift_forward_impl(Tensor input, Tensor shift, Tensor output);
void tin_shift_backward_impl(Tensor grad_output, Tensor shift,
                             Tensor grad_input);
REGISTER_DEVICE_IMPL(tin_shift_forward_impl, MUSA, tin_shift_forward_musa);
REGISTER_DEVICE_IMPL(tin_shift_backward_impl, MUSA, tin_shift_backward_musa);

#if ((!defined(MUSA_ARCH)) || (defined(MUSA_ARCH)) && (MUSA_ARCH > 21))
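// Editor's note: because && binds tighter than ||, this guard reads as
// "MUSA_ARCH is undefined, or it is defined and greater than 21", i.e. the
// upfirdn2d registration below is compiled only when no architecture is
// pinned or the MUSA architecture is newer than 21 (the cut-off is taken from
// the condition itself, not documented elsewhere in this diff).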
torch::Tensor upfirdn2d_op(torch::Tensor input, torch::Tensor filter, int upx,
                           int upy, int downx, int downy, int padx0, int padx1,
                           int pady0, int pady1, bool flip, float gain);

torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor filter,
                                int upx, int upy, int downx, int downy,
                                int padx0, int padx1, int pady0, int pady1,
                                bool flip, float gain);
REGISTER_DEVICE_IMPL(upfirdn2d_op_impl, MUSA, upfirdn2d_op);
#endif

int HardVoxelizeForwardMUSAKernelLauncher(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim = 3);

int NondeterministicHardVoxelizeForwardMUSAKernelLauncher(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim = 3);

void DynamicVoxelizeForwardMUSAKernelLauncher(
    const at::Tensor &points, at::Tensor &coors,
    const std::vector<float> voxel_size, const std::vector<float> coors_range,
    const int NDim = 3);
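// Editor's note: an illustrative grid-size calculation under the usual
// voxelization convention (example values, not from this diff): with
// coors_range = {0, -40, -3, 70.4, 40, 1} and voxel_size = {0.05, 0.05, 0.1},
// the grid spans (70.4 - 0) / 0.05 = 1408 by (40 - (-40)) / 0.05 = 1600 by
// (1 - (-3)) / 0.1 = 40 voxels. Hard voxelization additionally caps each voxel
// at max_points points and the scene at max_voxels voxels, while the dynamic
// variant only assigns coordinates.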

int hard_voxelize_forward_musa(const at::Tensor &points, at::Tensor &voxels,
                               at::Tensor &coors,
                               at::Tensor &num_points_per_voxel,
                               const std::vector<float> voxel_size,
                               const std::vector<float> coors_range,
                               const int max_points, const int max_voxels,
                               const int NDim) {
  return HardVoxelizeForwardMUSAKernelLauncher(
      points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
      max_points, max_voxels, NDim);
};

int nondeterministic_hard_voxelize_forward_musa(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim) {
  return NondeterministicHardVoxelizeForwardMUSAKernelLauncher(
      points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
      max_points, max_voxels, NDim);
};

void dynamic_voxelize_forward_musa(const at::Tensor &points, at::Tensor &coors,
                                   const std::vector<float> voxel_size,
                                   const std::vector<float> coors_range,
                                   const int NDim) {
  DynamicVoxelizeForwardMUSAKernelLauncher(points, coors, voxel_size,
                                           coors_range, NDim);
};

int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
                               at::Tensor &coors,
                               at::Tensor &num_points_per_voxel,
                               const std::vector<float> voxel_size,
                               const std::vector<float> coors_range,
                               const int max_points, const int max_voxels,
                               const int NDim);

int nondeterministic_hard_voxelize_forward_impl(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim);

void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors,
                                   const std::vector<float> voxel_size,
                                   const std::vector<float> coors_range,
                                   const int NDim);

REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, MUSA,
                     hard_voxelize_forward_musa);
REGISTER_DEVICE_IMPL(nondeterministic_hard_voxelize_forward_impl, MUSA,
                     nondeterministic_hard_voxelize_forward_musa);
REGISTER_DEVICE_IMPL(dynamic_voxelize_forward_impl, MUSA,
                     dynamic_voxelize_forward_musa);

void RotatedFeatureAlignForwardMUSAKernelLauncher(const Tensor features,
                                                  const Tensor best_bboxes,
                                                  const float spatial_scale,
                                                  const int points,
                                                  Tensor output);

void RotatedFeatureAlignBackwardMUSAKernelLauncher(const Tensor top_grad,
                                                   const Tensor best_bboxes,
                                                   const float spatial_scale,
                                                   const int points,
                                                   Tensor bottom_grad);

void rotated_feature_align_forward_musa(const Tensor features,
                                        const Tensor best_bboxes,
                                        const float spatial_scale,
                                        const int points, Tensor output) {
  RotatedFeatureAlignForwardMUSAKernelLauncher(features, best_bboxes,
                                               spatial_scale, points, output);
};

void rotated_feature_align_backward_musa(const Tensor top_grad,
                                         const Tensor best_bboxes,
                                         const float spatial_scale,
                                         const int points, Tensor bottom_grad) {
  RotatedFeatureAlignBackwardMUSAKernelLauncher(
      top_grad, best_bboxes, spatial_scale, points, bottom_grad);
};

void rotated_feature_align_forward_impl(const Tensor features,
                                        const Tensor best_bboxes,
                                        const float spatial_scale,
                                        const int points, Tensor output);

void rotated_feature_align_backward_impl(const Tensor top_grad,
                                         const Tensor best_bboxes,
                                         const float spatial_scale,
                                         const int points, Tensor bottom_grad);

REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, MUSA,
                     rotated_feature_align_forward_musa);
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, MUSA,
                     rotated_feature_align_backward_musa);

void PointsInPolygonsForwardMUSAKernelLauncher(const at::Tensor points,
                                               const at::Tensor polygons,
                                               const int rows, const int cols,
                                               at::Tensor output);

void points_in_polygons_forward_musa(const Tensor points, const Tensor polygons,
                                     Tensor output, const int rows,
                                     const int cols) {
  PointsInPolygonsForwardMUSAKernelLauncher(points, polygons, rows, cols,
                                            output);
};

void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
                                     Tensor output, const int rows,
                                     const int cols);

REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, MUSA,
                     points_in_polygons_forward_musa);

torch::Tensor IndiceMaxpoolForwardMUSAKernelLauncher(torch::Tensor features,
                                                     torch::Tensor indicePairs,
                                                     torch::Tensor indiceNum,
                                                     int64_t numAct);

torch::Tensor indice_maxpool_forward_musa(torch::Tensor features,
                                          torch::Tensor indicePairs,
                                          torch::Tensor indiceNum,
                                          int64_t numAct) {
  return IndiceMaxpoolForwardMUSAKernelLauncher(features, indicePairs,
                                                indiceNum, numAct);
};

torch::Tensor indice_maxpool_forward_impl(torch::Tensor features,
                                          torch::Tensor indicePairs,
                                          torch::Tensor indiceNum,
                                          int64_t numAct);
REGISTER_DEVICE_IMPL(indice_maxpool_forward_impl, MUSA,
                     indice_maxpool_forward_musa);

torch::Tensor IndiceMaxpoolBackwardMUSAKernelLauncher(torch::Tensor features,
                                                      torch::Tensor outFeatures,
                                                      torch::Tensor outGrad,
                                                      torch::Tensor indicePairs,
                                                      torch::Tensor indiceNum);

torch::Tensor indice_maxpool_backward_musa(torch::Tensor features,
                                           torch::Tensor outFeatures,
                                           torch::Tensor outGrad,
                                           torch::Tensor indicePairs,
                                           torch::Tensor indiceNum) {
  return IndiceMaxpoolBackwardMUSAKernelLauncher(features, outFeatures, outGrad,
                                                 indicePairs, indiceNum);
};

torch::Tensor indice_maxpool_backward_impl(torch::Tensor features,
                                           torch::Tensor outFeatures,
                                           torch::Tensor outGrad,
                                           torch::Tensor indicePairs,
                                           torch::Tensor indiceNum);

REGISTER_DEVICE_IMPL(indice_maxpool_backward_impl, MUSA,
                     indice_maxpool_backward_musa);

torch::Tensor IndiceConvForwardMUSAKernelLauncher(
    torch::Tensor features, torch::Tensor filters, torch::Tensor indicePairs,
    torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse,
    int64_t _subM);

torch::Tensor indice_conv_forward_musa(torch::Tensor features,
                                       torch::Tensor filters,
                                       torch::Tensor indicePairs,
                                       torch::Tensor indiceNum,
                                       int64_t numActOut, int64_t _inverse,
                                       int64_t _subM) {
  return IndiceConvForwardMUSAKernelLauncher(
      features, filters, indicePairs, indiceNum, numActOut, _inverse, _subM);
};

torch::Tensor indice_conv_forward_impl(torch::Tensor features,
                                       torch::Tensor filters,
                                       torch::Tensor indicePairs,
                                       torch::Tensor indiceNum,
                                       int64_t numActOut, int64_t _inverse,
                                       int64_t _subM);

REGISTER_DEVICE_IMPL(indice_conv_forward_impl, MUSA, indice_conv_forward_musa);

std::vector<torch::Tensor> IndiceConvBackwardMUSAKernelLauncher(
    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
    int64_t _subM);

std::vector<torch::Tensor> indice_conv_backward_musa(
    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
    int64_t _subM) {
  return IndiceConvBackwardMUSAKernelLauncher(
      features, filters, outGrad, indicePairs, indiceNum, _inverse, _subM);
};

std::vector<torch::Tensor> indice_conv_backward_impl(
    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
    int64_t _subM);

REGISTER_DEVICE_IMPL(indice_conv_backward_impl, MUSA,
                     indice_conv_backward_musa);

torch::Tensor FusedIndiceConvBatchnormMUSAKernelLauncher(
    torch::Tensor features, torch::Tensor filters, torch::Tensor bias,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut,
    int64_t _inverse, int64_t _subM);

torch::Tensor fused_indice_conv_batchnorm_forward_musa(
    torch::Tensor features, torch::Tensor filters, torch::Tensor bias,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut,
    int64_t _inverse, int64_t _subM) {
  return FusedIndiceConvBatchnormMUSAKernelLauncher(features, filters, bias,
                                                    indicePairs, indiceNum,
                                                    numActOut, _inverse, _subM);
};

torch::Tensor fused_indice_conv_batchnorm_forward_impl(
    torch::Tensor features, torch::Tensor filters, torch::Tensor bias,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut,
    int64_t _inverse, int64_t _subM);

REGISTER_DEVICE_IMPL(fused_indice_conv_batchnorm_forward_impl, MUSA,
                     fused_indice_conv_batchnorm_forward_musa);

void MinAreaPolygonsMUSAKernelLauncher(const Tensor pointsets, Tensor polygons);

void min_area_polygons_musa(const Tensor pointsets, Tensor polygons) {
@ -990,6 +1849,57 @@ REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, MUSA,
                     chamfer_distance_backward_musa);
#endif

void PrROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois,
                                        Tensor output, int pooled_height,
                                        int pooled_width, float spatial_scale);

void PrROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
                                         Tensor grad_input, int pooled_height,
                                         int pooled_width, float spatial_scale);

void PrROIPoolCoorBackwardMUSAKernelLauncher(
    Tensor output, Tensor grad_output, Tensor input, Tensor rois,
    Tensor grad_rois, int pooled_height, int pooled_width, float spatial_scale);

void prroi_pool_forward_musa(Tensor input, Tensor rois, Tensor output,
                             int pooled_height, int pooled_width,
                             float spatial_scale) {
  PrROIPoolForwardMUSAKernelLauncher(input, rois, output, pooled_height,
                                     pooled_width, spatial_scale);
}

void prroi_pool_backward_musa(Tensor grad_output, Tensor rois,
                              Tensor grad_input, int pooled_height,
                              int pooled_width, float spatial_scale) {
  PrROIPoolBackwardMUSAKernelLauncher(grad_output, rois, grad_input,
                                      pooled_height, pooled_width,
                                      spatial_scale);
}

void prroi_pool_coor_backward_musa(Tensor output, Tensor grad_output,
                                   Tensor input, Tensor rois, Tensor grad_rois,
                                   int pooled_height, int pooled_width,
                                   float spatial_scale) {
  PrROIPoolCoorBackwardMUSAKernelLauncher(output, grad_output, input, rois,
                                          grad_rois, pooled_height,
                                          pooled_width, spatial_scale);
}

void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
                             int pooled_height, int pooled_width,
                             float spatial_scale);
void prroi_pool_backward_impl(Tensor grad_output, Tensor rois,
                              Tensor grad_input, int pooled_height,
                              int pooled_width, float spatial_scale);
void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output,
                                   Tensor input, Tensor rois, Tensor grad_rois,
                                   int pooled_height, int pooled_width,
                                   float spatial_scale);
REGISTER_DEVICE_IMPL(prroi_pool_forward_impl, MUSA, prroi_pool_forward_musa);
REGISTER_DEVICE_IMPL(prroi_pool_backward_impl, MUSA, prroi_pool_backward_musa);
REGISTER_DEVICE_IMPL(prroi_pool_coor_backward_impl, MUSA,
                     prroi_pool_coor_backward_musa);

void BezierAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois,
                                          Tensor output, int aligned_height,
                                          int aligned_width,

@ -0,0 +1,62 @@
// Modified from
// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
// Written by Shaoshuai Shi
// All Rights Reserved 2019.

#include <stdio.h>

#include "points_in_boxes_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"

void PointsInBoxesPartForwardMUSAKernelLauncher(int batch_size, int boxes_num,
                                                int pts_num, const Tensor boxes,
                                                const Tensor pts,
                                                Tensor box_idx_of_points) {
  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  //   coordinate; z is the bottom center, and the boxes must not overlap
  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
  // params box_idx_of_points: (B, npoints), default -1

  c10::musa::MUSAGuard device_guard(boxes.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size);
  dim3 threads(THREADS_PER_BLOCK);
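  // Editor's note: an illustrative instance of this launch configuration,
  // assuming THREADS_PER_BLOCK is 512 as in the common MUSA helpers: with
  // pts_num = 16384 points per scene, GET_BLOCKS(16384, 512) = 32, so the grid
  // is 32 x batch_size blocks; blockIdx.y selects the batch element and the
  // kernel's 1D loop covers the points.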

  AT_DISPATCH_FLOATING_TYPES(
      boxes.scalar_type(), "points_in_boxes_part_forward_musa_kernel", [&] {
        points_in_boxes_part_forward_musa_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                batch_size, boxes_num, pts_num, boxes.data_ptr<scalar_t>(),
                pts.data_ptr<scalar_t>(), box_idx_of_points.data_ptr<int>());
      });

  AT_MUSA_CHECK(musaGetLastError());
}

void PointsInBoxesAllForwardMUSAKernelLauncher(int batch_size, int boxes_num,
                                               int pts_num, const Tensor boxes,
                                               const Tensor pts,
                                               Tensor box_idx_of_points) {
  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  //   coordinate; z is the bottom center
  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
  // params box_idx_of_points: (B, npoints), default -1

  c10::musa::MUSAGuard device_guard(boxes.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size);
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES(
      boxes.scalar_type(), "points_in_boxes_all_forward_musa_kernel", [&] {
        points_in_boxes_all_forward_musa_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                batch_size, boxes_num, pts_num, boxes.data_ptr<scalar_t>(),
                pts.data_ptr<scalar_t>(), box_idx_of_points.data_ptr<int>());
      });

  AT_MUSA_CHECK(musaGetLastError());
}

@ -0,0 +1,28 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/ming71/CUDA/blob/master/point_justify/points_justify_kernel.cu

#include <stdio.h>

#include "points_in_polygons_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"

void PointsInPolygonsForwardMUSAKernelLauncher(const at::Tensor points,
                                               const at::Tensor polygons,
                                               const int rows, const int cols,
                                               at::Tensor output) {
  const int output_size = rows * cols;
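  // Editor's note: one kernel thread is mapped to each (point, polygon) pair,
  // so output_size = rows * cols covers the full rows x cols result grid; for
  // example (values chosen for the example), 1000 points tested against 50
  // polygons launch GET_BLOCKS(50000) blocks of THREADS_PER_BLOCK threads.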
  c10::musa::MUSAGuard device_guard(points.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES(
      points.scalar_type(), "points_in_polygons_forward_musa_kernel", ([&] {
        const scalar_t *vertex1 = points.data_ptr<scalar_t>();
        const scalar_t *vertex2 = polygons.data_ptr<scalar_t>();
        scalar_t *inside_flag = output.data_ptr<scalar_t>();

        points_in_polygons_forward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, vertex1, vertex2, rows, cols, inside_flag);
      }));
  AT_MUSA_CHECK(musaGetLastError());
}

@ -0,0 +1,65 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "prroi_pool_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"

void PrROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois,
                                        Tensor output, int pooled_height,
                                        int pooled_width, float spatial_scale) {
  int output_size = output.numel();
  int channels = input.size(1);
  int height = input.size(2);
  int width = input.size(3);

  c10::musa::MUSAGuard device_guard(input.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
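  // Editor's note: unlike the dispatch-macro launchers elsewhere in this PR,
  // the PrRoI Pool kernels below are instantiated directly as <float>, so this
  // op effectively supports float32 inputs only; the static_cast on
  // spatial_scale (already a float parameter) is then a no-op kept for
  // symmetry with the templated launchers.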
  prroi_pool_forward_musa_kernel<float>
      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
          output_size, input.data_ptr<float>(), rois.data_ptr<float>(),
          output.data_ptr<float>(), pooled_height, pooled_width,
          static_cast<float>(spatial_scale), channels, height, width);

  AT_MUSA_CHECK(musaGetLastError());
}

void PrROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
                                         Tensor grad_input, int pooled_height,
                                         int pooled_width,
                                         float spatial_scale) {
  int output_size = grad_output.numel();
  int channels = grad_input.size(1);
  int height = grad_input.size(2);
  int width = grad_input.size(3);

  c10::musa::MUSAGuard device_guard(grad_output.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  prroi_pool_backward_musa_kernel<float>
      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
          output_size, grad_output.data_ptr<float>(), rois.data_ptr<float>(),
          grad_input.data_ptr<float>(), pooled_height, pooled_width,
          static_cast<float>(spatial_scale), channels, height, width);

  AT_MUSA_CHECK(musaGetLastError());
}

void PrROIPoolCoorBackwardMUSAKernelLauncher(Tensor output, Tensor grad_output,
                                             Tensor input, Tensor rois,
                                             Tensor grad_rois,
                                             int pooled_height,
                                             int pooled_width,
                                             float spatial_scale) {
  int output_size = grad_output.numel();
  int channels = input.size(1);
  int height = input.size(2);
  int width = input.size(3);

  c10::musa::MUSAGuard device_guard(grad_output.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  prroi_pool_coor_backward_musa_kernel<float>
      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
          output_size, output.data_ptr<float>(), grad_output.data_ptr<float>(),
          input.data_ptr<float>(), rois.data_ptr<float>(),
          grad_rois.data_ptr<float>(), pooled_height, pooled_width,
          static_cast<float>(spatial_scale), channels, height, width);

  AT_MUSA_CHECK(musaGetLastError());
}

@ -0,0 +1,60 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/hszhao/semseg/blob/master/lib/psa/src

#include <torch/serialize/tensor.h>

#include "psamask_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"

void PSAMaskForwardMUSAKernelLauncher(const int psa_type, const Tensor input,
                                      Tensor output, const int num_,
                                      const int h_feature, const int w_feature,
                                      const int h_mask, const int w_mask,
                                      const int half_h_mask,
                                      const int half_w_mask) {
  int nthreads = num_ * h_feature * w_feature;
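  // Editor's note: nthreads assigns one worker per spatial location of the
  // feature map across the batch; psa_type selects between the two PSA modes
  // (0 = COLLECT, 1 = DISTRIBUTE, following the convention of the psamask op
  // this port is based on).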
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  if (psa_type == 0)
    AT_DISPATCH_FLOATING_TYPES(
        input.scalar_type(), "psamask_collect_forward_musa", [&] {
          psamask_collect_forward_musa<scalar_t><<<nthreads, 512, 0, stream>>>(
              nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
              half_w_mask, input.data_ptr<scalar_t>(),
              output.data_ptr<scalar_t>());
        });
  else
    AT_DISPATCH_FLOATING_TYPES(
        input.scalar_type(), "psamask_distribute_forward_musa", [&] {
          psamask_distribute_forward_musa<scalar_t>
              <<<nthreads, 512, 0, stream>>>(
                  nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
                  half_w_mask, input.data_ptr<scalar_t>(),
                  output.data_ptr<scalar_t>());
        });
}

void PSAMaskBackwardMUSAKernelLauncher(
    const int psa_type, const Tensor grad_output, Tensor grad_input,
    const int num_, const int h_feature, const int w_feature, const int h_mask,
    const int w_mask, const int half_h_mask, const int half_w_mask) {
  int nthreads = num_ * h_feature * w_feature;
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  if (psa_type == 0)
    AT_DISPATCH_FLOATING_TYPES(
        grad_input.scalar_type(), "psamask_collect_backward_musa", [&] {
          psamask_collect_backward_musa<scalar_t><<<nthreads, 512, 0, stream>>>(
              nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
              half_w_mask, grad_output.data_ptr<scalar_t>(),
              grad_input.data_ptr<scalar_t>());
        });
  else
    AT_DISPATCH_FLOATING_TYPES(
        grad_input.scalar_type(), "psamask_distribute_backward_musa", [&] {
          psamask_distribute_backward_musa<scalar_t>
              <<<nthreads, 512, 0, stream>>>(
                  nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
                  half_w_mask, grad_output.data_ptr<scalar_t>(),
                  grad_input.data_ptr<scalar_t>());
        });
}

@ -0,0 +1,53 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "riroi_align_rotated_musa_kernel.muh"

void RiROIAlignRotatedForwardMUSAKernelLauncher(
    const at::Tensor features, const at::Tensor rois, const float spatial_scale,
    const int num_samples, const bool clockwise, const int channels,
    const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, const int num_orientations,
    at::Tensor output) {
  const int output_size =
      num_rois * pooled_height * pooled_width * channels * num_orientations;
  c10::musa::MUSAGuard device_guard(features.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES(
      features.scalar_type(), "riroi_align_rotated_forward_musa_kernel", ([&] {
        const scalar_t *bottom_data = features.data_ptr<scalar_t>();
        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
        scalar_t *top_data = output.data_ptr<scalar_t>();

        riroi_align_rotated_forward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, bottom_data, rois_data, scalar_t(spatial_scale),
                num_samples, clockwise, channels, height, width, pooled_height,
                pooled_width, num_orientations, top_data);
      }));

  AT_MUSA_CHECK(musaGetLastError());
}

void RiROIAlignRotatedBackwardMUSAKernelLauncher(
    const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
    const int num_samples, const bool clockwise, const int channels,
    const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, const int num_orientations,
    at::Tensor bottom_grad) {
  const int output_size =
      num_rois * pooled_height * pooled_width * channels * num_orientations;
  c10::musa::MUSAGuard device_guard(top_grad.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES(
      top_grad.scalar_type(), "riroi_align_rotated_backward_musa_kernel", ([&] {
        const scalar_t *top_diff = top_grad.data_ptr<scalar_t>();
        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
        scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
        riroi_align_rotated_backward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, top_diff, rois_data, spatial_scale, num_samples,
                clockwise, channels, height, width, pooled_height, pooled_width,
                num_orientations, bottom_diff);
      }));
  AT_MUSA_CHECK(musaGetLastError());
}

@ -0,0 +1,58 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "roi_align_musa_kernel.muh"

void ROIAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                       Tensor argmax_y, Tensor argmax_x,
                                       int aligned_height, int aligned_width,
                                       float spatial_scale, int sampling_ratio,
                                       int pool_mode, bool aligned) {
  int output_size = output.numel();
  int channels = input.size(1);
  int height = input.size(2);
  int width = input.size(3);

  c10::musa::MUSAGuard device_guard(input.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "roi_align_forward_musa_kernel", [&] {
        roi_align_forward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, input.data_ptr<scalar_t>(),
                rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
                argmax_y.data_ptr<scalar_t>(), argmax_x.data_ptr<scalar_t>(),
                aligned_height, aligned_width,
                static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
                aligned, channels, height, width);
      });

  AT_MUSA_CHECK(musaGetLastError());
}

void ROIAlignBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
                                        Tensor argmax_y, Tensor argmax_x,
                                        Tensor grad_input, int aligned_height,
                                        int aligned_width, float spatial_scale,
                                        int sampling_ratio, int pool_mode,
                                        bool aligned) {
  int output_size = grad_output.numel();
  int channels = grad_input.size(1);
  int height = grad_input.size(2);
  int width = grad_input.size(3);

  c10::musa::MUSAGuard device_guard(grad_output.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "roi_align_backward_musa_kernel", [&] {
        roi_align_backward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, grad_output.data_ptr<scalar_t>(),
                rois.data_ptr<scalar_t>(), argmax_y.data_ptr<scalar_t>(),
                argmax_x.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
                aligned_height, aligned_width,
                static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
                aligned, channels, height, width);
      });

  AT_MUSA_CHECK(musaGetLastError());
}

@ -0,0 +1,45 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "roi_align_rotated_musa_kernel.muh"

void ROIAlignRotatedForwardMUSAKernelLauncher(
    const at::Tensor input, const at::Tensor rois, const float spatial_scale,
    const int sampling_ratio, const bool aligned, const bool clockwise,
    const int channels, const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, at::Tensor output) {
  const int output_size = num_rois * pooled_height * pooled_width * channels;
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "ROIAlignRotatedLauncherForward", ([&] {
        const scalar_t *bottom_data = input.data_ptr<scalar_t>();
        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
        scalar_t *top_data = output.data_ptr<scalar_t>();

        roi_align_rotated_forward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
                output_size, bottom_data, rois_data, scalar_t(spatial_scale),
                sampling_ratio, aligned, clockwise, channels, height, width,
                pooled_height, pooled_width, top_data);
      }));

  AT_MUSA_CHECK(musaGetLastError());
}

void ROIAlignRotatedBackwardMUSAKernelLauncher(
    const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
    const int sampling_ratio, const bool aligned, const bool clockwise,
    const int channels, const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, at::Tensor bottom_grad) {
  const int output_size = num_rois * pooled_height * pooled_width * channels;
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.scalar_type(), "ROIAlignRotatedLauncherBackward", ([&] {
        const scalar_t *top_diff = top_grad.data_ptr<scalar_t>();
        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
        scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
        roi_align_rotated_backward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
                output_size, top_diff, rois_data, spatial_scale, sampling_ratio,
                aligned, clockwise, channels, height, width, pooled_height,
                pooled_width, bottom_diff);
      }));
  AT_MUSA_CHECK(musaGetLastError());
}

@ -0,0 +1,50 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "roi_pool_musa_kernel.muh"

void ROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                      Tensor argmax, int pooled_height,
                                      int pooled_width, float spatial_scale) {
  int output_size = output.numel();
  int channels = input.size(1);
  int height = input.size(2);
  int width = input.size(3);

  c10::musa::MUSAGuard device_guard(input.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "roi_pool_forward_musa_kernel", [&] {
        roi_pool_forward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, input.data_ptr<scalar_t>(),
                rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
                argmax.data_ptr<int>(), pooled_height, pooled_width,
                static_cast<scalar_t>(spatial_scale), channels, height, width);
      });

  AT_MUSA_CHECK(musaGetLastError());
}

void ROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
                                       Tensor argmax, Tensor grad_input,
                                       int pooled_height, int pooled_width,
                                       float spatial_scale) {
  int output_size = grad_output.numel();
  int channels = grad_input.size(1);
  int height = grad_input.size(2);
  int width = grad_input.size(3);

  c10::musa::MUSAGuard device_guard(grad_output.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "roi_pool_backward_musa_kernel", [&] {
        roi_pool_backward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, grad_output.data_ptr<scalar_t>(),
                rois.data_ptr<scalar_t>(), argmax.data_ptr<int>(),
                grad_input.data_ptr<scalar_t>(), pooled_height, pooled_width,
                channels, height, width);
      });

  AT_MUSA_CHECK(musaGetLastError());
}

@ -0,0 +1,118 @@
// Modified from
// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
// Written by Shaoshuai Shi
// All Rights Reserved 2019.

#include <stdio.h>

#include "pytorch_musa_helper.hpp"
#include "roiaware_pool3d_musa_kernel.muh"

void RoiawarePool3dForwardMUSAKernelLauncher(
    int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
    int out_y, int out_z, const Tensor rois, const Tensor pts,
    const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels,
    Tensor pooled_features, int pool_method) {
  // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  //   coordinate
  // params pts: (npoints, 3) [x, y, z] in LiDAR coordinate
  // params pts_feature: (npoints, C)
  // params argmax: (N, out_x, out_y, out_z, C)
  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
  // params pooled_features: (N, out_x, out_y, out_z, C)
  // params pool_method: 0: max_pool, 1: avg_pool

  c10::musa::MUSAGuard device_guard(pts_feature.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  Tensor pts_mask =
      -at::ones({boxes_num, pts_num}, pts_feature.options().dtype(at::kInt));
|
||||
|
||||
dim3 blocks_mask(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
rois.scalar_type(), "generate_pts_mask_for_box3d", [&] {
|
||||
generate_pts_mask_for_box3d<scalar_t>
|
||||
<<<blocks_mask, threads, 0, stream>>>(
|
||||
boxes_num, pts_num, out_x, out_y, out_z,
|
||||
rois.data_ptr<scalar_t>(), pts.data_ptr<scalar_t>(),
|
||||
pts_mask.data_ptr<int>());
|
||||
});
|
||||
|
||||
AT_MUSA_CHECK(musaGetLastError());
|
||||
|
||||
// TODO: Merge the collect and pool functions, SS
|
||||
|
||||
dim3 blocks_collect(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK));
|
||||
|
||||
AT_DISPATCH_INTEGRAL_TYPES(
|
||||
pts_idx_of_voxels.scalar_type(), "collect_inside_pts_for_box3d", [&] {
|
||||
collect_inside_pts_for_box3d<scalar_t>
|
||||
<<<blocks_collect, threads, 0, stream>>>(
|
||||
boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z,
|
||||
pts_mask.data_ptr<int>(),
|
||||
pts_idx_of_voxels.data_ptr<scalar_t>());
|
||||
});
|
||||
|
||||
AT_MUSA_CHECK(musaGetLastError());
|
||||
|
||||
dim3 blocks_pool(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK),
|
||||
channels, boxes_num);
|
||||
if (pool_method == 0) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
pts_feature.scalar_type(), "roiaware_maxpool3d", [&] {
|
||||
roiaware_maxpool3d<scalar_t><<<blocks_pool, threads, 0, stream>>>(
|
||||
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y,
|
||||
out_z, pts_feature.data_ptr<scalar_t>(),
|
||||
pts_idx_of_voxels.data_ptr<int>(),
|
||||
pooled_features.data_ptr<scalar_t>(), argmax.data_ptr<int>());
|
||||
});
|
||||
} else if (pool_method == 1) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
pts_feature.scalar_type(), "roiaware_avgpool3d", [&] {
|
||||
roiaware_avgpool3d<scalar_t><<<blocks_pool, threads, 0, stream>>>(
|
||||
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y,
|
||||
out_z, pts_feature.data_ptr<scalar_t>(),
|
||||
pts_idx_of_voxels.data_ptr<int>(),
|
||||
pooled_features.data_ptr<scalar_t>());
|
||||
});
|
||||
}
|
||||
|
||||
AT_MUSA_CHECK(musaGetLastError());
|
||||
}
|
||||
|
||||
void RoiawarePool3dBackwardMUSAKernelLauncher(
|
||||
int boxes_num, int out_x, int out_y, int out_z, int channels,
|
||||
int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax,
|
||||
const Tensor grad_out, Tensor grad_in, int pool_method) {
|
||||
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
|
||||
// params argmax: (N, out_x, out_y, out_z, C)
|
||||
// params grad_out: (N, out_x, out_y, out_z, C)
|
||||
// params grad_in: (npoints, C), return value
|
||||
// params pool_method: 0: max_pool, 1: avg_pool
|
||||
|
||||
c10::musa::MUSAGuard device_guard(grad_out.device());
|
||||
musaStream_t stream = c10::musa::getCurrentMUSAStream();
|
||||
|
||||
dim3 blocks(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
|
||||
boxes_num);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
if (pool_method == 0) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
grad_in.scalar_type(), "roiaware_maxpool3d_backward", [&] {
|
||||
roiaware_maxpool3d_backward<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
boxes_num, channels, out_x, out_y, out_z, argmax.data_ptr<int>(),
|
||||
grad_out.data_ptr<scalar_t>(), grad_in.data_ptr<scalar_t>());
|
||||
});
|
||||
} else if (pool_method == 1) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
grad_in.scalar_type(), "roiaware_avgpool3d_backward", [&] {
|
||||
roiaware_avgpool3d_backward<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel,
|
||||
pts_idx_of_voxels.data_ptr<int>(), grad_out.data_ptr<scalar_t>(),
|
||||
grad_in.data_ptr<scalar_t>());
|
||||
});
|
||||
}
|
||||
|
||||
AT_MUSA_CHECK(musaGetLastError());
|
||||
}
|
|
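A hedged allocation sketch matching the shape comments above; the variable names and the -1 fill for argmax are assumptions, not code from this commit:

// hypothetical caller sketch; N = rois.size(0), C = pts_feature.size(1)
auto int_opts = pts_feature.options().dtype(at::kInt);
auto argmax = at::full({N, out_x, out_y, out_z, C}, -1, int_opts);
auto pts_idx_of_voxels =
    at::zeros({N, out_x, out_y, out_z, max_pts_each_voxel}, int_opts);
auto pooled_features =
    at::zeros({N, out_x, out_y, out_z, C}, pts_feature.options());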
@ -0,0 +1,60 @@
/*
Modified from
https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu
Point cloud feature pooling
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/

#include <math.h>
#include <stdio.h>

#include "pytorch_musa_helper.hpp"
#include "roipoint_pool3d_musa_kernel.muh"

void RoIPointPool3dForwardMUSAKernelLauncher(
    int batch_size, int pts_num, int boxes_num, int feature_in_len,
    int sampled_pts_num, const Tensor xyz, const Tensor boxes3d,
    const Tensor pts_feature, Tensor pooled_features,
    Tensor pooled_empty_flag) {
  Tensor pts_assign = at::empty({batch_size, pts_num, boxes_num},
                                boxes3d.options().dtype(at::kInt));

  c10::musa::MUSAGuard device_guard(xyz.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  // blockIdx.x(col), blockIdx.y(row)
  dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size);
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      xyz.scalar_type(), "assign_pts_to_box3d", [&] {
        assign_pts_to_box3d<scalar_t><<<blocks, threads, 0, stream>>>(
            batch_size, pts_num, boxes_num, xyz.data_ptr<scalar_t>(),
            boxes3d.data_ptr<scalar_t>(), pts_assign.data_ptr<int>());
      });

  Tensor pts_idx = at::empty({batch_size, boxes_num, sampled_pts_num},
                             boxes3d.options().dtype(at::kInt));

  // blockIdx.x(col), blockIdx.y(row)
  dim3 blocks2(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK), batch_size);

  get_pooled_idx<<<blocks2, threads, 0, stream>>>(
      batch_size, pts_num, boxes_num, sampled_pts_num,
      pts_assign.data_ptr<int>(), pts_idx.data_ptr<int>(),
      pooled_empty_flag.data_ptr<int>());

  dim3 blocks_pool(GET_BLOCKS(sampled_pts_num, THREADS_PER_BLOCK), boxes_num,
                   batch_size);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      xyz.scalar_type(), "roipoint_pool3d_forward", [&] {
        roipoint_pool3d_forward<scalar_t><<<blocks_pool, threads, 0, stream>>>(
            batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
            xyz.data_ptr<scalar_t>(), pts_idx.data_ptr<int>(),
            pts_feature.data_ptr<scalar_t>(),
            pooled_features.data_ptr<scalar_t>(),
            pooled_empty_flag.data_ptr<int>());
      });
}
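The launcher above runs three stages on the same stream: assign points to boxes (pts_assign), sample per-box point indices (get_pooled_idx), then gather features (roipoint_pool3d_forward). A hedged sketch of the output buffers a caller would prepare; the 3 + C last dimension assumes the kernel concatenates xyz with per-point features, which this file does not spell out:

auto pooled_features =
    at::zeros({batch_size, boxes_num, sampled_pts_num, 3 + feature_in_len},
              xyz.options());
auto pooled_empty_flag =
    at::zeros({batch_size, boxes_num}, xyz.options().dtype(at::kInt));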
@ -0,0 +1,53 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu
#include "pytorch_musa_helper.hpp"
#include "rotated_feature_align_musa_kernel.muh"

void RotatedFeatureAlignForwardMUSAKernelLauncher(const Tensor features,
                                                  const Tensor best_bboxes,
                                                  const float spatial_scale,
                                                  const int points,
                                                  Tensor output) {
  c10::musa::MUSAGuard device_guard(features.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  const int output_size = features.numel();
  AT_DISPATCH_FLOATING_TYPES(
      features.scalar_type(), "rotated_feature_align_forward_musa_kernel",
      ([&] {
        const scalar_t* bottom_data = features.data_ptr<scalar_t>();
        const scalar_t* bboxes_data = best_bboxes.data_ptr<scalar_t>();
        scalar_t* top_data = output.data_ptr<scalar_t>();

        rotated_feature_align_forward_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, points, bottom_data, bboxes_data,
                scalar_t(spatial_scale), features.size(1), features.size(2),
                features.size(3), top_data);
      }));
  AT_MUSA_CHECK(musaGetLastError());
}

void RotatedFeatureAlignBackwardMUSAKernelLauncher(const Tensor top_grad,
                                                   const Tensor best_bboxes,
                                                   const float spatial_scale,
                                                   const int points,
                                                   Tensor bottom_grad) {
  c10::musa::MUSAGuard device_guard(top_grad.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  const int output_size = top_grad.numel();
  AT_DISPATCH_FLOATING_TYPES(
      top_grad.scalar_type(), "rotated_feature_align_backward_musa_kernel",
      ([&] {
        const scalar_t* top_diff = top_grad.data_ptr<scalar_t>();
        const scalar_t* bboxes_data = best_bboxes.data_ptr<scalar_t>();
        scalar_t* bottom_diff = bottom_grad.data_ptr<scalar_t>();

        rotated_feature_align_backward_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, points, top_diff, bboxes_data,
                scalar_t(spatial_scale), top_grad.size(1), top_grad.size(2),
                top_grad.size(3), bottom_diff);
      }));
  AT_MUSA_CHECK(musaGetLastError());
}
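A hedged invocation sketch for the forward launcher above; the 1/8 scale and points = 1 are illustrative values only (the op is typically run with 1 or 5 sample points):

auto output = at::zeros_like(features);
RotatedFeatureAlignForwardMUSAKernelLauncher(
    features, best_bboxes, /*spatial_scale=*/1.0f / 8.0f, /*points=*/1,
    output);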
@ -0,0 +1,132 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <stdio.h>
#include <stdlib.h>
#include <torch/types.h>

#include "pytorch_musa_helper.hpp"
#include "scatter_points_musa_kernel.muh"

std::vector<at::Tensor> DynamicPointToVoxelForwardMUSAKernelLauncher(
    const at::Tensor &feats, const at::Tensor &coors,
    const reduce_t reduce_type) {
  const int num_input = feats.size(0);
  const int num_feats = feats.size(1);

  if (num_input == 0)
    return {feats.clone().detach(), coors.clone().detach(),
            coors.new_empty({0}, torch::kInt32),
            coors.new_empty({0}, torch::kInt32)};

  at::Tensor out_coors;
  at::Tensor coors_map;
  at::Tensor reduce_count;

  auto coors_clean = coors.masked_fill(coors.lt(0).any(-1, true), -1);

  std::tie(out_coors, coors_map, reduce_count) =
      at::unique_dim(coors_clean, 0, true, true, true);

  if (out_coors[0][0].lt(0).item<bool>()) {
    // the first element of out_coors is (-1,-1,-1) and should be removed
    out_coors = out_coors.slice(0, 1);
    reduce_count = reduce_count.slice(0, 1);
    coors_map = coors_map - 1;
  }

  coors_map = coors_map.to(torch::kInt32);
  reduce_count = reduce_count.to(torch::kInt32);

  auto reduced_feats =
      at::empty({out_coors.size(0), num_feats}, feats.options());

  c10::musa::MUSAGuard device_guard(feats.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  AT_DISPATCH_FLOATING_TYPES(
      feats.scalar_type(), "feats_reduce_kernel", ([&] {
        if (reduce_type == reduce_t::MAX)
          reduced_feats.fill_(-std::numeric_limits<scalar_t>::infinity());
        else
          reduced_feats.fill_(static_cast<scalar_t>(0));

        dim3 blocks(std::min(
            at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
        dim3 threads(THREADS_PER_BLOCK);
        feats_reduce_kernel<<<blocks, threads, 0, stream>>>(
            feats.data_ptr<scalar_t>(), coors_map.data_ptr<int32_t>(),
            reduced_feats.data_ptr<scalar_t>(), num_input, num_feats,
            reduce_type);
        if (reduce_type == reduce_t::MEAN)
          reduced_feats /= reduce_count.unsqueeze(-1).to(reduced_feats.dtype());
      }));

  AT_MUSA_CHECK(musaGetLastError());

  return {reduced_feats, out_coors, coors_map, reduce_count};
}

void DynamicPointToVoxelBackwardMUSAKernelLauncher(
    at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats,
    const at::Tensor &feats, const at::Tensor &reduced_feats,
    const at::Tensor &coors_map, const at::Tensor &reduce_count,
    const reduce_t reduce_type) {
  const int num_input = feats.size(0);
  const int num_reduced = reduced_feats.size(0);
  const int num_feats = feats.size(1);

  grad_feats.fill_(0);
  // copy voxel grad to points

  if (num_input == 0 || num_reduced == 0) return;
  c10::musa::MUSAGuard device_guard(feats.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  if (reduce_type == reduce_t::MEAN || reduce_type == reduce_t::SUM) {
    AT_DISPATCH_FLOATING_TYPES(
        grad_reduced_feats.scalar_type(), "add_reduce_traceback_grad_kernel",
        ([&] {
          dim3 blocks(std::min(
              at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
          dim3 threads(THREADS_PER_BLOCK);
          add_reduce_traceback_grad_kernel<<<blocks, threads, 0, stream>>>(
              grad_feats.data_ptr<scalar_t>(),
              grad_reduced_feats.data_ptr<scalar_t>(),
              coors_map.data_ptr<int32_t>(), reduce_count.data_ptr<int32_t>(),
              num_input, num_feats, reduce_type);
        }));

    AT_MUSA_CHECK(musaGetLastError());
  } else {
    auto reduce_from = at::full({num_reduced, num_feats}, num_input,
                                coors_map.options().dtype(torch::kInt32));
    AT_DISPATCH_FLOATING_TYPES(
        grad_reduced_feats.scalar_type(),
        "max_reduce_traceback_scatter_idx_kernel", ([&] {
          dim3 blocks(std::min(
              at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
          dim3 threads(THREADS_PER_BLOCK);
          max_reduce_traceback_scatter_idx_kernel<<<blocks, threads, 0,
                                                    stream>>>(
              feats.data_ptr<scalar_t>(), reduced_feats.data_ptr<scalar_t>(),
              reduce_from.data_ptr<int32_t>(), coors_map.data_ptr<int32_t>(),
              num_input, num_feats);
        }));

    AT_MUSA_CHECK(musaGetLastError());

    AT_DISPATCH_FLOATING_TYPES(
        grad_reduced_feats.scalar_type(),
        "max_reduce_scatter_grad_kernel", ([&] {
          dim3 blocks(
              std::min(at::musa::ATenCeilDiv(num_reduced, THREADS_PER_BLOCK),
                       maxGridDim));
          dim3 threads(THREADS_PER_BLOCK);
          max_reduce_scatter_grad_kernel<<<blocks, threads, 0, stream>>>(
              grad_feats.data_ptr<scalar_t>(),
              grad_reduced_feats.data_ptr<scalar_t>(),
              reduce_from.data_ptr<int32_t>(), num_reduced, num_feats);
        }));

    AT_MUSA_CHECK(musaGetLastError());
  }
}
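Both launchers above branch on reduce_t. A sketch of the enum they dispatch on, assuming the usual SUM/MEAN/MAX encoding from the shared ops headers (the exact definition is not part of this file):

enum class reduce_t { SUM = 0, MEAN = 1, MAX = 2 };  // assumed encoding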
@ -0,0 +1,486 @@
// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <ATen/ATen.h>
// clang-format off
// TODO: make spconv_utils.h order agnostic
#include "../spconv_utils.h"
// clang-format on
#include <utils/spconv/spconv/maxpool.h>
#include <utils/spconv/spconv/mp_helper.h>
#include <utils/spconv/tensorview/helper_launch.h>
#include <utils/spconv/tensorview/tensorview.h>

#include <chrono>
#include <limits>
#include <type_traits>
#include <utils/spconv/tensorview/helper_kernel.muh>

#include "pytorch_musa_helper.hpp"

template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolFwdBlockKernel(scalar_t *outFeatures,
                                      const scalar_t *inFeatures,
                                      const Index *indicesIn,
                                      const Index *indicesOut, int numHot,
                                      int numPlanes) {
  scalar_t in, out;
  int ILPStrideY[NumILP];
  Index idxo, idxi;
#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
    ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
  outFeatures += blockIdx.y * NumTLP;
  inFeatures += blockIdx.y * NumTLP;
  for (int ix = blockIdx.x * blockDim.x; ix < numHot;
       ix += blockDim.x * gridDim.x) {
    {
#pragma unroll
      for (int ilp = 0; ilp < NumILP; ++ilp) {
        idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
        idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
        in = inFeatures[idxi];
        out = outFeatures[idxo];
        if (in > out) {
          outFeatures[idxo] = in;
        }
      }
    }
  }
}

template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolFwdGenericBlockKernel(scalar_t *outFeatures,
                                             const scalar_t *inFeatures,
                                             const Index *indicesIn,
                                             const Index *indicesOut,
                                             int numHot, int numPlanes) {
  int ILPStrideX[NumILP];
  Index RI[NumILP];
  Index RO[NumILP];
  scalar_t in, out;
#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
  for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
#pragma unroll
    for (int ilp = 0; ilp < NumILP; ilp++) {
      RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
      RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
    }
    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
      for (int ilp = 0; ilp < NumILP; ++ilp) {
        in = inFeatures[RI[ilp] + iy];
        out = outFeatures[RO[ilp] + iy];
        if (in > out) {
          outFeatures[RO[ilp] + iy] = in;
        }
      }
    }
  }
}

template <typename scalar_t, typename Index, int NumTLP, int NumILP,
          typename VecType>
__global__ void maxPoolFwdVecBlockKernel(scalar_t *outFeatures,
                                         const scalar_t *inFeatures,
                                         const Index *indicesIn,
                                         const Index *indicesOut, int numHot,
                                         int numPlanes) {
  int ILPStrideY[NumILP];
  constexpr int vecloadFactor = sizeof(VecType) / sizeof(scalar_t);
  scalar_t bufi[vecloadFactor];
  scalar_t bufo[vecloadFactor];
  Index idxi, idxo;
#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
    ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
  outFeatures += blockIdx.y * NumTLP;
  inFeatures += blockIdx.y * NumTLP;
  for (int ix = blockIdx.x * blockDim.x * vecloadFactor; ix < numHot;
       ix += blockDim.x * gridDim.x * vecloadFactor) {
#pragma unroll
    for (int ilp = 0; ilp < NumILP; ++ilp) {
      idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
      idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
      reinterpret_cast<VecType *>(bufo)[0] =
          reinterpret_cast<VecType *>(outFeatures)[idxo];
      reinterpret_cast<VecType *>(bufi)[0] =
          reinterpret_cast<const VecType *>(inFeatures)[idxi];
#pragma unroll
      for (int i = 0; i < vecloadFactor; i++) {
        if (bufi[i] > bufo[i]) {
          bufo[i] = bufi[i];
        }
      }
      reinterpret_cast<VecType *>(outFeatures)[idxo] =
          reinterpret_cast<VecType *>(bufo)[0];
    }
  }
}

template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolFwdGenericKernel(scalar_t *outFeatures,
                                        const scalar_t *inFeatures,
                                        const Index *indicesIn,
                                        const Index *indicesOut, int numHot,
                                        int numPlanes) {
  int ILPStrideX[NumILP];
  Index RI[NumILP];
  Index RO[NumILP];
  scalar_t in, out;
#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
  for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
#pragma unroll
    for (int ilp = 0; ilp < NumILP; ilp++) {
      if (ix + ILPStrideX[ilp] < numHot) {
        RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
        RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
      }
    }
    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
      for (int ilp = 0; ilp < NumILP; ++ilp) {
        if (ix + ILPStrideX[ilp] < numHot) {
          in = inFeatures[RI[ilp] + iy];
          out = outFeatures[RO[ilp] + iy];
          if (in > out) {
            outFeatures[RO[ilp] + iy] = in;
          }
        }
      }
    }
  }
}

template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolBwdBlockKernel(const scalar_t *outFeatures,
                                      const scalar_t *inFeatures,
                                      const scalar_t *fout, scalar_t *fin,
                                      const Index *indicesIn,
                                      const Index *indicesOut, int numHot,
                                      int numPlanes) {
  scalar_t in, out;
  Index idxo, idxi;
  int ILPStrideY[NumILP];
#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
    ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
  outFeatures += blockIdx.y * NumTLP;
  inFeatures += blockIdx.y * NumTLP;
  fout += blockIdx.y * NumTLP;
  fin += blockIdx.y * NumTLP;
  for (int ix = blockIdx.x * blockDim.x; ix < numHot;
       ix += blockDim.x * gridDim.x) {
    {
#pragma unroll
      for (int ilp = 0; ilp < NumILP; ++ilp) {
        idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
        idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
        in = inFeatures[idxi];
        out = outFeatures[idxo];
        if (in == out) {
          fin[idxi] += fout[idxo];
        }
      }
    }
  }
}

template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolBwdGenericBlockKernel(
    const scalar_t *outFeatures, const scalar_t *inFeatures,
    const scalar_t *fout, scalar_t *fin, const Index *indicesIn,
    const Index *indicesOut, int numHot, int numPlanes) {
  int ILPStrideX[NumILP];
  Index RI[NumILP];
  Index RO[NumILP];
  scalar_t in, out;
#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
  for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
#pragma unroll
    for (int ilp = 0; ilp < NumILP; ilp++) {
      RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
      RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
    }
    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
      for (int ilp = 0; ilp < NumILP; ++ilp) {
        in = inFeatures[RI[ilp] + iy];
        out = outFeatures[RO[ilp] + iy];
        if (in == out) {
          fin[RI[ilp] + iy] += fout[RO[ilp] + iy];
        }
      }
    }
  }
}

template <typename scalar_t, typename Index, int NumTLP, int NumILP,
          typename VecType>
__global__ void maxPoolBwdVecBlockKernel(const scalar_t *outFeatures,
                                         const scalar_t *inFeatures,
                                         const scalar_t *fout, scalar_t *fin,
                                         const Index *indicesIn,
                                         const Index *indicesOut, int numHot,
                                         int numPlanes) {
  int ILPStrideY[NumILP];
  constexpr int vecloadFactor = sizeof(VecType) / sizeof(scalar_t);
  scalar_t bufi[vecloadFactor];
  scalar_t bufo[vecloadFactor];
  scalar_t bufdi[vecloadFactor];
  scalar_t bufdo[vecloadFactor];
  Index idxi, idxo;
#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
    ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
  outFeatures += blockIdx.y * NumTLP;
  inFeatures += blockIdx.y * NumTLP;
  for (int ix = blockIdx.x * blockDim.x * vecloadFactor; ix < numHot;
       ix += blockDim.x * gridDim.x * vecloadFactor) {
#pragma unroll
    for (int ilp = 0; ilp < NumILP; ++ilp) {
      idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
      idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
      reinterpret_cast<VecType *>(bufo)[0] =
          reinterpret_cast<const VecType *>(outFeatures)[idxo];
      reinterpret_cast<VecType *>(bufi)[0] =
          reinterpret_cast<const VecType *>(inFeatures)[idxi];
      reinterpret_cast<VecType *>(bufdo)[0] =
          reinterpret_cast<const VecType *>(fout)[idxo];
      reinterpret_cast<VecType *>(bufdi)[0] =
          reinterpret_cast<VecType *>(fin)[idxi];

#pragma unroll
      for (int i = 0; i < vecloadFactor; i++) {
        if (bufi[i] == bufo[i]) {
          bufdi[i] += bufdo[i];
        }
      }
      reinterpret_cast<VecType *>(fin)[idxi] =
          reinterpret_cast<VecType *>(bufdi)[0];
    }
  }
}

template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolBwdGenericKernel(const scalar_t *outFeatures,
                                        const scalar_t *inFeatures,
                                        const scalar_t *fout, scalar_t *fin,
                                        const Index *indicesIn,
                                        const Index *indicesOut, int numHot,
                                        int numPlanes) {
  int ILPStrideX[NumILP];
  Index RI[NumILP];
  Index RO[NumILP];
  scalar_t in, out;
#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
  for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
#pragma unroll
    for (int ilp = 0; ilp < NumILP; ilp++) {
      if (ix + ILPStrideX[ilp] < numHot) {
        RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
        RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
      }
    }
    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
      for (int ilp = 0; ilp < NumILP; ++ilp) {
        if (ix + ILPStrideX[ilp] < numHot) {
          in = inFeatures[RI[ilp] + iy];
          out = outFeatures[RO[ilp] + iy];
          if (in == out) {
            fin[RI[ilp] + iy] += fout[RO[ilp] + iy];
          }
        }
      }
    }
  }
}

namespace functor {
template <typename scalar_t, typename Index>
struct SparseMaxPoolForwardFunctor<tv::TorchGPU, scalar_t, Index> {
  using vecload_type_t =
      std::conditional_t<std::is_same<scalar_t, at::Half>::value, int2, int4>;
  using kernel_block_t = mp_list_c<int, 64, 32, 16>;
  void operator()(const tv::TorchGPU &d, tv::TensorView<scalar_t> outFeatures,
                  tv::TensorView<const scalar_t> inFeatures,
                  tv::TensorView<const Index> indices, int size) {
    if (size <= 0) return;
    int numPlanes = inFeatures.dim(1);
    bool notFound = true;
    constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(scalar_t);
    mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &indices,
                                 &notFound](auto NumTLP) {
      constexpr int NumILP = NumTLP / 4;

      int numHotBlock = (size / NumTLP) * NumTLP;
      if (notFound) {
        if (numPlanes % NumTLP == 0) {
          if (numHotBlock >= NumTLP) {
            maxPoolFwdVecBlockKernel<scalar_t, Index, int(NumTLP), NumILP,
                                     vecload_type_t>
                <<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP),
                   dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
                   d.getStream()>>>(outFeatures.data(), inFeatures.data(),
                                    indices.subview(0).data(),
                                    indices.subview(1).data(), numHotBlock,
                                    numPlanes / vecloadFactor);
            TV_CHECK_MUSA_ERR();
          }

          if (size > numHotBlock) {
            maxPoolFwdGenericKernel<scalar_t, Index, int(NumTLP), NumILP>
                <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
                   0, d.getStream()>>>(outFeatures.data(), inFeatures.data(),
                                       indices.subview(0).data() + numHotBlock,
                                       indices.subview(1).data() + numHotBlock,
                                       size - numHotBlock, numPlanes);
            TV_CHECK_MUSA_ERR();
          }
          notFound = false;
        }
      }
    });

    if (notFound) {
      constexpr int NumTLP = 64;
      constexpr int NumILP = NumTLP / 4;
      int numHotBlock = (size / NumTLP) * NumTLP;
      if (numHotBlock >= NumTLP) {
        maxPoolFwdGenericBlockKernel<scalar_t, Index, NumTLP, NumILP>
            <<<dim3(size / NumTLP, tv::launch::DivUp(numPlanes, NumTLP)),
               dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
                outFeatures.data(), inFeatures.data(),
                indices.subview(0).data(), indices.subview(1).data(),
                numHotBlock, numPlanes);
        TV_CHECK_MUSA_ERR();
      }

      if (size > numHotBlock) {
        maxPoolFwdGenericKernel<scalar_t, Index, NumTLP, NumILP>
            <<<dim3(1, tv::launch::DivUp(numPlanes, NumTLP)),
               dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
                outFeatures.data(), inFeatures.data(),
                indices.subview(0).data() + numHotBlock,
                indices.subview(1).data() + numHotBlock, size - numHotBlock,
                numPlanes);
        TV_CHECK_MUSA_ERR();
      }
    }
  }
};

template <typename scalar_t, typename Index>
struct SparseMaxPoolBackwardFunctor<tv::TorchGPU, scalar_t, Index> {
  using vecload_type_t =
      std::conditional_t<std::is_same<scalar_t, at::Half>::value, int2, int4>;
  using kernel_block_t = mp_list_c<int, 64, 32, 16>;
  void operator()(const tv::TorchGPU &d,
                  tv::TensorView<const scalar_t> outFeatures,
                  tv::TensorView<const scalar_t> inFeatures,
                  tv::TensorView<const scalar_t> fout,
                  tv::TensorView<scalar_t> fin,
                  tv::TensorView<const Index> indices, int size) {
    if (size <= 0) return;
    int numPlanes = inFeatures.dim(1);
    bool notFound = true;
    constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(scalar_t);
    mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &fout, &fin,
                                 &indices, &notFound](auto NumTLP) {
      constexpr int NumILP = NumTLP / 4;

      int numHotBlock = (size / NumTLP) * NumTLP;
      if (notFound) {
        if (numPlanes % NumTLP == 0) {
          if (numHotBlock >= NumTLP) {
            maxPoolBwdVecBlockKernel<scalar_t, Index, int(NumTLP), NumILP,
                                     vecload_type_t>
                <<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP),
                   dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
                   d.getStream()>>>(outFeatures.data(), inFeatures.data(),
                                    fout.data(), fin.data(),
                                    indices.subview(0).data(),
                                    indices.subview(1).data(), numHotBlock,
                                    numPlanes / vecloadFactor);
            TV_CHECK_MUSA_ERR();
          }

          if (size > numHotBlock) {
            maxPoolBwdGenericKernel<scalar_t, Index, int(NumTLP), NumILP>
                <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
                   0, d.getStream()>>>(outFeatures.data(), inFeatures.data(),
                                       fout.data(), fin.data(),
                                       indices.subview(0).data() + numHotBlock,
                                       indices.subview(1).data() + numHotBlock,
                                       size - numHotBlock, numPlanes);
            TV_CHECK_MUSA_ERR();
          }
          notFound = false;
        }
      }
    });

    if (notFound) {
      constexpr int NumTLP = 64;
      constexpr int NumILP = NumTLP / 4;
      int numHotBlock = (size / NumTLP) * NumTLP;
      if (numHotBlock >= NumTLP) {
        maxPoolBwdGenericBlockKernel<scalar_t, Index, NumTLP, NumILP>
            <<<dim3(size / NumTLP, tv::launch::DivUp(numPlanes, NumTLP)),
               dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
                outFeatures.data(), inFeatures.data(), fout.data(), fin.data(),
                indices.subview(0).data(), indices.subview(1).data(),
                numHotBlock, numPlanes);
        TV_CHECK_MUSA_ERR();
      }

      if (size > numHotBlock) {
        maxPoolBwdGenericKernel<scalar_t, Index, NumTLP, NumILP>
            <<<dim3(1, tv::launch::DivUp(numPlanes, NumTLP)),
               dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
                outFeatures.data(), inFeatures.data(), fout.data(), fin.data(),
                indices.subview(0).data() + numHotBlock,
                indices.subview(1).data() + numHotBlock, size - numHotBlock,
                numPlanes);
        TV_CHECK_MUSA_ERR();
      }
    }
  }
};

}  // namespace functor

#define DECLARE_GPU_SPECS_T_INDEX(scalar_t, Index)                            \
  template struct functor::SparseMaxPoolForwardFunctor<tv::TorchGPU, scalar_t, \
                                                       Index>;                \
  template struct functor::SparseMaxPoolBackwardFunctor<tv::TorchGPU,         \
                                                        scalar_t, Index>;

#define DECLARE_GPU_SPECS(scalar_t) DECLARE_GPU_SPECS_T_INDEX(scalar_t, int);

DECLARE_GPU_SPECS(float);
DECLARE_GPU_SPECS(double);
DECLARE_GPU_SPECS(at::Half);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_T_INDEX
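The functors above try NumTLP in {64, 32, 16} and split the work into a NumTLP-aligned body plus a tail. A compile-time illustration of that split (not part of the diff):

constexpr int num_hot_block(int size, int num_tlp) {
  return (size / num_tlp) * num_tlp;  // NumTLP-aligned prefix
}
// e.g. 1000 active sites with NumTLP = 64: 960 go to the block/vec kernels,
// the remaining 40 to the maxPool*Generic boundary kernels
static_assert(num_hot_block(1000, 64) == 960, "tail of 40 sites");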
@ -0,0 +1,91 @@
#include <musa_runtime_api.h>
#include <torch/script.h>
// clang-format off
// TODO: make spconv_utils.h order agnostic
#include "../spconv_utils.h"
// clang-format on
#include <utils/spconv/spconv/maxpool.h>

#include "pytorch_musa_helper.hpp"

torch::Tensor IndiceMaxpoolForwardMUSAKernelLauncher(torch::Tensor features,
                                                     torch::Tensor indicePairs,
                                                     torch::Tensor indiceNum,
                                                     int64_t numAct) {
  c10::musa::MUSAGuard device_guard(features.device());
  auto device = features.device().type();
  auto kernelVolume = indicePairs.size(0);
  auto numInPlanes = features.size(1);
  auto indicePairNumCpu = indiceNum.to({torch::kCPU});
  auto options =
      torch::TensorOptions().dtype(features.dtype()).device(features.device());
  torch::Tensor output = torch::zeros({numAct, numInPlanes}, options);
  for (int i = 0; i < kernelVolume; ++i) {
    auto nHot = indicePairNumCpu.data_ptr<int>()[i];
    if (nHot <= 0) {
      continue;
    }
    AT_DISPATCH_FLOATING_TYPES(
        features.scalar_type(), "IndiceMaxpoolForwardKernel", [&] {
          if (device == torch::kCPU) {
            functor::SparseMaxPoolForwardFunctor<tv::CPU, scalar_t, int>
                forwardFtor;
            forwardFtor(tv::CPU(), tv::torch2tv<scalar_t>(output),
                        tv::torch2tv<const scalar_t>(features),
                        tv::torch2tv<const int>(indicePairs).subview(i), nHot);
          } else {
            functor::SparseMaxPoolForwardFunctor<tv::TorchGPU, scalar_t, int>
                forwardFtor;
            forwardFtor(tv::TorchGPU(), tv::torch2tv<scalar_t>(output),
                        tv::torch2tv<const scalar_t>(features),
                        tv::torch2tv<const int>(indicePairs).subview(i), nHot);
            TV_CHECK_MUSA_ERR();
          }
        });
  }
  return output;
}

torch::Tensor IndiceMaxpoolBackwardMUSAKernelLauncher(torch::Tensor features,
                                                      torch::Tensor outFeatures,
                                                      torch::Tensor outGrad,
                                                      torch::Tensor indicePairs,
                                                      torch::Tensor indiceNum) {
  c10::musa::MUSAGuard device_guard(features.device());
  auto device = features.device().type();
  auto numInPlanes = features.size(1);
  auto indicePairNumCpu = indiceNum.to({torch::kCPU});
  auto options =
      torch::TensorOptions().dtype(features.dtype()).device(features.device());
  torch::Tensor inputGrad = torch::zeros(features.sizes(), options);
  auto kernelVolume = indicePairs.size(0);
  for (int i = 0; i < kernelVolume; ++i) {
    auto nHot = indicePairNumCpu.data_ptr<int>()[i];
    if (nHot <= 0) {
      continue;
    }
    AT_DISPATCH_FLOATING_TYPES(
        features.scalar_type(), "IndiceMaxpoolBackwardKernel", [&] {
          if (device == torch::kCPU) {
            functor::SparseMaxPoolBackwardFunctor<tv::CPU, scalar_t, int>
                backwardFtor;
            backwardFtor(tv::CPU(), tv::torch2tv<const scalar_t>(outFeatures),
                         tv::torch2tv<const scalar_t>(features),
                         tv::torch2tv<const scalar_t>(outGrad),
                         tv::torch2tv<scalar_t>(inputGrad),
                         tv::torch2tv<const int>(indicePairs).subview(i), nHot);
          } else {
            functor::SparseMaxPoolBackwardFunctor<tv::TorchGPU, scalar_t, int>
                backwardFtor;
            backwardFtor(tv::TorchGPU(),
                         tv::torch2tv<const scalar_t>(outFeatures),
                         tv::torch2tv<const scalar_t>(features),
                         tv::torch2tv<const scalar_t>(outGrad),
                         tv::torch2tv<scalar_t>(inputGrad),
                         tv::torch2tv<const int>(indicePairs).subview(i), nHot);
            TV_CHECK_MUSA_ERR();
          }
        });
  }
  return inputGrad;
}
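A hedged usage sketch for the forward launcher above; the (kernelVolume, 2, numActIn) layout for indicePairs follows the spconv convention and is an assumption here:

// indiceNum holds one active-pair count per kernel offset; it is copied to
// CPU once up front so empty offsets can be skipped without per-iteration
// device synchronization inside the loop
auto out = IndiceMaxpoolForwardMUSAKernelLauncher(features, indicePairs,
                                                  indiceNum, numActOut);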
@ -0,0 +1,110 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "sync_bn_musa_kernel.muh"

void SyncBNForwardMeanMUSAKernelLauncher(const Tensor input, Tensor mean) {
  int num = input.size(0);
  int channels = input.size(1);
  int spatial = input.size(2);

  c10::musa::MUSAGuard device_guard(input.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "sync_bn_forward_mean_musa_kernel", [&] {
        sync_bn_forward_mean_musa_kernel<scalar_t>
            <<<channels, THREADS_PER_BLOCK, 0, stream>>>(
                input.data_ptr<scalar_t>(), mean.data_ptr<float>(), num,
                channels, spatial);
      });
  AT_MUSA_CHECK(musaGetLastError());
}

void SyncBNForwardVarMUSAKernelLauncher(const Tensor input, const Tensor mean,
                                        Tensor var) {
  int num = input.size(0);
  int channels = input.size(1);
  int spatial = input.size(2);

  c10::musa::MUSAGuard device_guard(input.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "sync_bn_forward_var_musa_kernel", [&] {
        sync_bn_forward_var_musa_kernel<scalar_t>
            <<<channels, THREADS_PER_BLOCK, 0, stream>>>(
                input.data_ptr<scalar_t>(), mean.data_ptr<float>(),
                var.data_ptr<float>(), num, channels, spatial);
      });
  AT_MUSA_CHECK(musaGetLastError());
}

void SyncBNForwardOutputMUSAKernelLauncher(
    const Tensor input, const Tensor mean, const Tensor var,
    Tensor running_mean, Tensor running_var, const Tensor weight,
    const Tensor bias, Tensor norm, Tensor std, Tensor output, float eps,
    float momentum, int group_size) {
  int num = input.size(0);
  int channels = input.size(1);
  int spatial = input.size(2);

  c10::musa::MUSAGuard device_guard(input.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "sync_bn_forward_output_musa_kernel", [&] {
        sync_bn_forward_output_musa_kernel<scalar_t>
            <<<channels, THREADS_PER_BLOCK, 0, stream>>>(
                input.data_ptr<scalar_t>(), mean.data_ptr<float>(),
                var.data_ptr<float>(), running_mean.data_ptr<float>(),
                running_var.data_ptr<float>(), weight.data_ptr<float>(),
                bias.data_ptr<float>(), norm.data_ptr<float>(),
                std.data_ptr<float>(), output.data_ptr<scalar_t>(), num,
                channels, spatial, eps, momentum, group_size);
      });
  AT_MUSA_CHECK(musaGetLastError());
}

void SyncBNBackwardParamMUSAKernelLauncher(const Tensor grad_output,
                                           const Tensor norm,
                                           Tensor grad_weight,
                                           Tensor grad_bias) {
  int num = grad_output.size(0);
  int channels = grad_output.size(1);
  int spatial = grad_output.size(2);

  c10::musa::MUSAGuard device_guard(grad_output.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES(
      grad_output.scalar_type(), "sync_bn_backward_param_musa_kernel", [&] {
        sync_bn_backward_param_musa_kernel<scalar_t>
            <<<channels, THREADS_PER_BLOCK, 0, stream>>>(
                grad_output.data_ptr<scalar_t>(), norm.data_ptr<float>(),
                grad_weight.data_ptr<float>(), grad_bias.data_ptr<float>(), num,
                channels, spatial);
      });
  AT_MUSA_CHECK(musaGetLastError());
}

void SyncBNBackwardDataMUSAKernelLauncher(const Tensor grad_output,
                                          const Tensor weight,
                                          const Tensor grad_weight,
                                          const Tensor grad_bias,
                                          const Tensor norm, const Tensor std,
                                          Tensor grad_input) {
  int output_size = grad_input.numel();
  int num = grad_input.size(0);
  int channels = grad_input.size(1);
  int spatial = grad_input.size(2);

  c10::musa::MUSAGuard device_guard(grad_input.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES(
      grad_output.scalar_type(), "sync_bn_backward_data_musa_kernel", [&] {
        sync_bn_backward_data_musa_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, grad_output.data_ptr<scalar_t>(),
                weight.data_ptr<float>(), grad_weight.data_ptr<float>(),
                grad_bias.data_ptr<float>(), norm.data_ptr<float>(),
                std.data_ptr<float>(), grad_input.data_ptr<scalar_t>(), num,
                channels, spatial);
      });
  AT_MUSA_CHECK(musaGetLastError());
}
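The launchers above are meant to be chained; a hedged sketch of the forward order, with inputs viewed as (N, C, spatial) per the size(2) reads above:

SyncBNForwardMeanMUSAKernelLauncher(input, mean);      // per-channel mean
SyncBNForwardVarMUSAKernelLauncher(input, mean, var);  // per-channel variance
SyncBNForwardOutputMUSAKernelLauncher(input, mean, var, running_mean,
                                      running_var, weight, bias, norm, std,
                                      output, eps, momentum, group_size);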
@ -0,0 +1,66 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "pytorch_musa_helper.hpp"
#include "three_interpolate_musa_kernel.muh"

void ThreeInterpolateForwardMUSAKernelLauncher(int b, int c, int m, int n,
                                               const Tensor points,
                                               const Tensor idx,
                                               const Tensor weight,
                                               Tensor out) {
  // points: (B, C, M)
  // idx: (B, N, 3)
  // weight: (B, N, 3)
  // output:
  //      out: (B, C, N)

  c10::musa::MUSAGuard device_guard(points.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  // blockIdx.x(col), blockIdx.y(row)
  dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b);
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      points.scalar_type(), "three_interpolate_forward_musa_kernel", [&] {
        three_interpolate_forward_musa_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                b, c, m, n, points.data_ptr<scalar_t>(), idx.data_ptr<int>(),
                weight.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
      });

  AT_MUSA_CHECK(musaGetLastError());
}

void ThreeInterpolateBackwardMUSAKernelLauncher(int b, int c, int n, int m,
                                                const Tensor grad_out,
                                                const Tensor idx,
                                                const Tensor weight,
                                                Tensor grad_points) {
  // grad_out: (B, C, N)
  // weight: (B, N, 3)
  // output:
  //      grad_points: (B, C, M)

  c10::musa::MUSAGuard device_guard(grad_out.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  // blockIdx.x(col), blockIdx.y(row)
  dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b);
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_out.scalar_type(), "three_interpolate_backward_musa_kernel", [&] {
        three_interpolate_backward_musa_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                b, c, n, m, grad_out.data_ptr<scalar_t>(), idx.data_ptr<int>(),
                weight.data_ptr<scalar_t>(), grad_points.data_ptr<scalar_t>());
      });

  AT_MUSA_CHECK(musaGetLastError());
}
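For reference, the computation the forward kernel performs, written as a plain CPU loop over the shapes documented above (a sketch, not code from this diff):

// out[b][c][j] = sum_{k<3} weight[b][j][k] * points[b][c][idx[b][j][k]]
void three_interpolate_ref(int b, int c, int m, int n, const float *points,
                           const int *idx, const float *weight, float *out) {
  for (int bi = 0; bi < b; ++bi)
    for (int ci = 0; ci < c; ++ci)
      for (int j = 0; j < n; ++j) {
        float acc = 0.f;
        for (int k = 0; k < 3; ++k) {
          int src = idx[(bi * n + j) * 3 + k];     // index into the M points
          float w = weight[(bi * n + j) * 3 + k];  // interpolation weight
          acc += w * points[(bi * c + ci) * m + src];
        }
        out[(bi * c + ci) * n + j] = acc;
      }
}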
@ -0,0 +1,35 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "pytorch_musa_helper.hpp"
#include "three_nn_musa_kernel.muh"

void ThreeNNForwardMUSAKernelLauncher(int b, int n, int m, const Tensor unknown,
                                      const Tensor known, Tensor dist2,
                                      Tensor idx) {
  // unknown: (B, N, 3)
  // known: (B, M, 3)
  // output:
  //      dist2: (B, N, 3)
  //      idx: (B, N, 3)

  c10::musa::MUSAGuard device_guard(unknown.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  // blockIdx.x(col), blockIdx.y(row)
  dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), b);
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES(
      unknown.scalar_type(), "three_nn_forward_musa_kernel", [&] {
        three_nn_forward_musa_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            b, n, m, unknown.data_ptr<scalar_t>(), known.data_ptr<scalar_t>(),
            dist2.data_ptr<scalar_t>(), idx.data_ptr<int>());
      });

  AT_MUSA_CHECK(musaGetLastError());
}
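In plain terms, for every unknown point the kernel records the squared distances to its three nearest known points in dist2 and their indices in idx. The per-pair quantity is just (sketch):

inline float sq_dist3(const float *a, const float *b) {
  float dx = a[0] - b[0], dy = a[1] - b[1], dz = a[2] - b[2];
  return dx * dx + dy * dy + dz * dz;  // the quantity stored in dist2
}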
@ -0,0 +1,55 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "pytorch_device_registry.hpp"
#include "tin_shift_musa_kernel.muh"

void TINShiftForwardMUSAKernelLauncher(Tensor input, Tensor shift,
                                       Tensor output) {
  int output_size = output.numel();
  int batch_size = input.size(0);
  int t_size = input.size(1);
  int channels = input.size(2);
  int hw_size = input.size(3);
  int group_size = shift.size(1);
  int group_channel = channels / group_size;
  int num_kernels = batch_size * hw_size * channels;

  c10::musa::MUSAGuard device_guard(input.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "tin_shift_forward_musa_kernel", [&] {
        tin_shift_forward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, input.data_ptr<scalar_t>(), shift.data_ptr<int>(),
                output.data_ptr<scalar_t>(), batch_size, channels, t_size,
                hw_size, group_size, group_channel);
      });

  AT_MUSA_CHECK(musaGetLastError());
}

void TINShiftBackwardMUSAKernelLauncher(Tensor grad_output, Tensor shift,
                                        Tensor grad_input) {
  int output_size = grad_output.numel();
  int batch_size = grad_output.size(0);
  int t_size = grad_output.size(1);
  int channels = grad_output.size(2);
  int hw_size = grad_output.size(3);
  int group_size = shift.size(1);
  int group_channel = channels / group_size;
  int num_kernels = batch_size * hw_size * channels;

  c10::musa::MUSAGuard device_guard(grad_output.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "tin_shift_backward_musa_kernel", [&] {
        tin_shift_backward_musa_kernel<scalar_t>
            <<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, grad_output.data_ptr<scalar_t>(),
                shift.data_ptr<int>(), grad_input.data_ptr<scalar_t>(),
                batch_size, channels, t_size, hw_size, group_size,
                group_channel);
      });

  AT_MUSA_CHECK(musaGetLastError());
}
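group_channel above is an integer division, so a channel count that is not a multiple of group_size would silently truncate. A hedged caller-side guard (an assumption, not in the diff):

TORCH_CHECK(channels % group_size == 0,
            "channels (", channels, ") must be divisible by group_size (",
            group_size, ")");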
@ -0,0 +1,749 @@
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include <torch/types.h>

#include "pytorch_musa_helper.hpp"
#if MUSA_ARCH > 21
struct upfirdn2d_kernel_params {
  const void *x;
  const float *f;
  void *y;

  int2 up;
  int2 down;
  int2 pad0;
  int flip;
  float gain;

  int4 inSize;  // [width, height, channel, batch]
  int4 inStride;
  int2 filterSize;  // [width, height]
  int2 filterStride;
  int4 outSize;  // [width, height, channel, batch]
  int4 outStride;
  int sizeMinor;
  int sizeMajor;

  int loopMinor;
  int loopMajor;
  int loopX;
  int launchMinor;
  int launchMajor;
};

//------------------------------------------------------------------------
// MUSA kernel specialization.

struct upfirdn2d_kernel_spec {
  void *kernel;
  int tileOutW;
  int tileOutH;
  int loopMinor;
  int loopX;
};

//------------------------------------------------------------------------
// MUSA kernel selection.

template <class T>
upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params &p);
//------------------------------------------------------------------------

// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

//------------------------------------------------------------------------
// Helpers.

template <class T>
struct InternalType;
template <>
struct InternalType<double> {
  typedef double scalar_t;
};
template <>
struct InternalType<float> {
  typedef float scalar_t;
};
template <>
struct InternalType<c10::Half> {
  typedef float scalar_t;
};

static __device__ __forceinline__ int floor_div(int a, int b) {
  int t = 1 - a / b;
  return (a + t * b) / b - t;
}

//------------------------------------------------------------------------
// Generic MUSA implementation for large filters.

template <class T>
static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) {
  typedef typename InternalType<T>::scalar_t scalar_t;

  // Calculate thread index.
  int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
  int outY = minorBase / p.launchMinor;
  minorBase -= outY * p.launchMinor;
  int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
  int majorBase = blockIdx.z * p.loopMajor;
  if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
    return;

  // Setup Y receptive field.
  int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
  int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
  int h =
      min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
  int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
  if (p.flip) filterY = p.filterSize.y - 1 - filterY;

  // Loop over major, minor, and X.
  for (int majorIdx = 0, major = majorBase;
       majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    for (int minorIdx = 0, minor = minorBase;
         minorIdx < p.loopMinor & minor < p.sizeMinor;
         minorIdx++, minor += p.launchMinor) {
      int nc = major * p.sizeMinor + minor;
      int n = nc / p.inSize.z;
      int c = nc - n * p.inSize.z;
      for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x;
           loopX++, outX += blockDim.y) {
        // Setup X receptive field.
        int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
        int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
        int w =
            min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) -
            inX;
        int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
        if (p.flip) filterX = p.filterSize.x - 1 - filterX;

        // Initialize pointers.
        const T *xp =
            &((const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y +
                              c * p.inStride.z + n * p.inStride.w];
        const float *fp =
            &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
        int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
        int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;

        // Inner loop.
        scalar_t v = 0;
        for (int y = 0; y < h; y++) {
          for (int x = 0; x < w; x++) {
            v += (scalar_t)(*xp) * (scalar_t)(*fp);
            xp += p.inStride.x;
            fp += filterStepX;
          }
          xp += p.inStride.y - w * p.inStride.x;
          fp += filterStepY - w * filterStepX;
        }

        // Store result.
        v *= p.gain;
        ((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y +
                   c * p.outStride.z + n * p.outStride.w] = (T)v;
      }
    }
}

//------------------------------------------------------------------------
// Specialized MUSA implementation for small filters.

template <class T, int upx, int upy, int downx, int downy, int filterW,
          int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) {
  typedef typename InternalType<T>::scalar_t scalar_t;
  const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
  const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
  __shared__ volatile scalar_t sf[filterH][filterW];
  __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];

  // Calculate tile index.
  int minorBase = blockIdx.x;
  int tileOutY = minorBase / p.launchMinor;
  minorBase -= tileOutY * p.launchMinor;
  minorBase *= loopMinor;
  tileOutY *= tileOutH;
  int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
  int majorBase = blockIdx.z * p.loopMajor;
  if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y |
      majorBase >= p.sizeMajor)
    return;

  // Load filter (flipped).
  for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW;
       tapIdx += blockDim.x) {
    int fy = tapIdx / filterW;
    int fx = tapIdx - fy * filterW;
    scalar_t v = 0;
    if (fx < p.filterSize.x & fy < p.filterSize.y) {
      int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
      int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
      v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
    }
    sf[fy][fx] = v;
  }

  // Loop over major and X.
  for (int majorIdx = 0, major = majorBase;
       majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) {
    int baseNC = major * p.sizeMinor + minorBase;
    int n = baseNC / p.inSize.z;
    int baseC = baseNC - n * p.inSize.z;
    for (int loopX = 0, tileOutX = tileOutXBase;
         loopX < p.loopX & tileOutX < p.outSize.x;
         loopX++, tileOutX += tileOutW) {
      // Load input pixels.
      int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
      int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
      int tileInX = floor_div(tileMidX, upx);
      int tileInY = floor_div(tileMidY, upy);
      __syncthreads();
      for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor;
           inIdx += blockDim.x) {
        int relC = inIdx;
        int relInX = relC / loopMinor;
        int relInY = relInX / tileInW;
        relC -= relInX * loopMinor;
        relInX -= relInY * tileInW;
        int c = baseC + relC;
        int inX = tileInX + relInX;
        int inY = tileInY + relInY;
        scalar_t v = 0;
        if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y &
            c < p.inSize.z)
          v = (scalar_t)(
              (const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y +
                              c * p.inStride.z + n * p.inStride.w];
        sx[relInY][relInX][relC] = v;
      }

      // Loop over output pixels.
      __syncthreads();
      for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor;
           outIdx += blockDim.x) {
        int relC = outIdx;
        int relOutX = relC / loopMinor;
        int relOutY = relOutX / tileOutW;
        relC -= relOutX * loopMinor;
        relOutX -= relOutY * tileOutW;
        int c = baseC + relC;
        int outX = tileOutX + relOutX;
        int outY = tileOutY + relOutY;

        // Setup receptive field.
        int midX = tileMidX + relOutX * downx;
        int midY = tileMidY + relOutY * downy;
        int inX = floor_div(midX, upx);
        int inY = floor_div(midY, upy);
        int relInX = inX - tileInX;
        int relInY = inY - tileInY;
        int filterX = (inX + 1) * upx - midX - 1;  // flipped
        int filterY = (inY + 1) * upy - midY - 1;  // flipped

        // Inner loop.
        if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) {
          scalar_t v = 0;
#pragma unroll
          for (int y = 0; y < filterH / upy; y++)
#pragma unroll
            for (int x = 0; x < filterW / upx; x++)
              v += sx[relInY + y][relInX + x][relC] *
                   sf[filterY + y * upy][filterX + x * upx];
          v *= p.gain;
          ((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y +
                     c * p.outStride.z + n * p.outStride.w] = (T)v;
        }
      }
    }
  }
}

//------------------------------------------------------------------------
// MUSA kernel selection.

template <class T>
upfirdn2d_kernel_spec choose_upfirdn2d_kernel(
    const upfirdn2d_kernel_params &p) {
  int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
  upfirdn2d_kernel_spec spec = {(void *)upfirdn2d_kernel_large<T>, -1, -1, 1,
                                4};  // contiguous
  if (s == 1)
    spec = {(void *)upfirdn2d_kernel_large<T>, -1, -1, 4, 1};  // channels_last

  // No up/downsampling.
  if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
    // contiguous
    if (s != 1 && fx <= 24 && fy <= 24)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 24, 64, 32, 1>,
              64, 32, 1, 1};
    if (s != 1 && fx <= 16 && fy <= 16)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 16, 64, 32, 1>,
              64, 32, 1, 1};
    if (s != 1 && fx <= 7 && fy <= 7)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 7, 7, 64, 16, 1>,
              64, 16, 1, 1};
    if (s != 1 && fx <= 6 && fy <= 6)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 6, 6, 64, 16, 1>,
              64, 16, 1, 1};
    if (s != 1 && fx <= 5 && fy <= 5)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 5, 5, 64, 16, 1>,
              64, 16, 1, 1};
    if (s != 1 && fx <= 4 && fy <= 4)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 4, 4, 64, 16, 1>,
              64, 16, 1, 1};
    if (s != 1 && fx <= 3 && fy <= 3)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 3, 3, 64, 16, 1>,
              64, 16, 1, 1};
    if (s != 1 && fx <= 24 && fy <= 1)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 1, 128, 8, 1>,
              128, 8, 1, 1};
    if (s != 1 && fx <= 16 && fy <= 1)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 1, 128, 8, 1>,
              128, 8, 1, 1};
    if (s != 1 && fx <= 8 && fy <= 1)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 8, 1, 128, 8, 1>,
              128, 8, 1, 1};
    if (s != 1 && fx <= 1 && fy <= 24)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 24, 32, 32, 1>,
              32, 32, 1, 1};
    if (s != 1 && fx <= 1 && fy <= 16)
      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 16, 32, 32, 1>,
              32, 32, 1, 1};
    if (s != 1 && fx <= 1 && fy <= 8)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 8, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 24, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 16, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
if (s == 1 && fx <= 7 && fy <= 7)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 7, 7, 16, 16, 8>,
|
||||
16, 16, 8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 6, 6, 16, 16, 8>,
|
||||
16, 16, 8, 1};
|
||||
if (s == 1 && fx <= 5 && fy <= 5)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 5, 5, 16, 16, 8>,
|
||||
16, 16, 8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 4, 4, 16, 16, 8>,
|
||||
16, 16, 8, 1};
|
||||
if (s == 1 && fx <= 3 && fy <= 3)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 3, 3, 16, 16, 8>,
|
||||
16, 16, 8, 1};
|
||||
if (s == 1 && fx <= 24 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 1, 128, 1, 16>,
|
||||
128, 1, 16, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 1, 128, 1, 16>,
|
||||
128, 1, 16, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 8, 1, 128, 1, 16>,
|
||||
128, 1, 16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 24)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 24, 1, 128, 16>,
|
||||
1, 128, 16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 16, 1, 128, 16>,
|
||||
1, 128, 16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 8, 1, 128, 16>,
|
||||
1, 128, 16, 1};
|
||||
}
|
||||
|
||||
// 2x upsampling.
|
||||
if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) {
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 24)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 24, 24, 64, 32, 1>,
|
||||
64, 32, 1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 16)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 16, 16, 64, 32, 1>,
|
||||
64, 32, 1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 8)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 8, 8, 64, 16, 1>,
|
||||
64, 16, 1, 1};
|
||||
if (s != 1 && fx <= 6 && fy <= 6)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 6, 6, 64, 16, 1>,
|
||||
64, 16, 1, 1};
|
||||
if (s != 1 && fx <= 4 && fy <= 4)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 4, 4, 64, 16, 1>,
|
||||
64, 16, 1, 1};
|
||||
if (s != 1 && fx <= 2 && fy <= 2)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 2, 2, 64, 16, 1>,
|
||||
64, 16, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 24, 24, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 16, 16, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 8)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 8, 8, 16, 16, 8>,
|
||||
16, 16, 8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 6, 6, 16, 16, 8>,
|
||||
16, 16, 8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 4, 4, 16, 16, 8>,
|
||||
16, 16, 8, 1};
|
||||
if (s == 1 && fx <= 2 && fy <= 2)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 2, 2, 16, 16, 8>,
|
||||
16, 16, 8, 1};
|
||||
}
|
||||
if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 24, 1, 128, 8, 1>,
|
||||
128, 8, 1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 16, 1, 128, 8, 1>,
|
||||
128, 8, 1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 8, 1, 128, 8, 1>,
|
||||
128, 8, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 24, 1, 128, 1, 16>,
|
||||
128, 1, 16, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 16, 1, 128, 1, 16>,
|
||||
128, 1, 16, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 8, 1, 128, 1, 16>,
|
||||
128, 1, 16, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) {
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 24)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 24, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 16)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 16, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 8)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 8, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 24)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 24, 1, 128, 16>,
|
||||
1, 128, 16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 16, 1, 128, 16>,
|
||||
1, 128, 16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 8, 1, 128, 16>,
|
||||
1, 128, 16, 1};
|
||||
}
|
||||
|
||||
// 2x downsampling.
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) {
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 24)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 24, 24, 32, 16, 1>,
|
||||
32, 16, 1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 16)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 16, 16, 32, 16, 1>,
|
||||
32, 16, 1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 8)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 8, 8, 32, 8, 1>, 32,
|
||||
8, 1, 1};
|
||||
if (s != 1 && fx <= 6 && fy <= 6)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 6, 6, 32, 8, 1>, 32,
|
||||
8, 1, 1};
|
||||
if (s != 1 && fx <= 4 && fy <= 4)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 4, 4, 32, 8, 1>, 32,
|
||||
8, 1, 1};
|
||||
if (s != 1 && fx <= 2 && fy <= 2)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 2, 2, 32, 8, 1>, 32,
|
||||
8, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 24, 24, 16, 16, 1>,
|
||||
16, 16, 1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 16, 16, 16, 16, 1>,
|
||||
16, 16, 1, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 8)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 8, 8, 8, 8, 8>, 8,
|
||||
8, 8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 6, 6, 8, 8, 8>, 8,
|
||||
8, 8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 4, 4, 8, 8, 8>, 8,
|
||||
8, 8, 1};
|
||||
if (s == 1 && fx <= 2 && fy <= 2)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 2, 2, 8, 8, 8>, 8,
|
||||
8, 8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) {
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 24, 1, 64, 8, 1>,
|
||||
64, 8, 1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 16, 1, 64, 8, 1>,
|
||||
64, 8, 1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 8, 1, 64, 8, 1>, 64,
|
||||
8, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 24, 1, 64, 1, 8>,
|
||||
64, 1, 8, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 16, 1, 64, 1, 8>,
|
||||
64, 1, 8, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 8, 1, 64, 1, 8>, 64,
|
||||
1, 8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) {
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 24)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 24, 32, 16, 1>,
|
||||
32, 16, 1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 16)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 16, 32, 16, 1>,
|
||||
32, 16, 1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 8)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 8, 32, 16, 1>,
|
||||
32, 16, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 24)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 24, 1, 64, 8>, 1,
|
||||
64, 8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 16, 1, 64, 8>, 1,
|
||||
64, 8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 8, 1, 64, 8>, 1,
|
||||
64, 8, 1};
|
||||
}
|
||||
|
||||
// 4x upsampling.
|
||||
if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) {
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 48)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 48, 48, 64, 32, 1>,
|
||||
64, 32, 1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 32)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 32, 32, 64, 32, 1>,
|
||||
64, 32, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 48)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 48, 48, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 32)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 32, 32, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
}
|
||||
if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 48, 1, 128, 8, 1>,
|
||||
128, 8, 1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 32, 1, 128, 8, 1>,
|
||||
128, 8, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 48, 1, 128, 1, 16>,
|
||||
128, 1, 16, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 32, 1, 128, 1, 16>,
|
||||
128, 1, 16, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) {
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 48)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 48, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 32)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 32, 32, 32, 1>,
|
||||
32, 32, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 48)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 48, 1, 128, 16>,
|
||||
1, 128, 16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 32)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 32, 1, 128, 16>,
|
||||
1, 128, 16, 1};
|
||||
}
|
||||
|
||||
// 4x downsampling (inefficient).
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) {
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 48, 1, 32, 8, 1>,
|
||||
32, 8, 1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 32, 1, 32, 8, 1>,
|
||||
32, 8, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 48, 1, 32, 1, 8>,
|
||||
32, 1, 8, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 1)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 32, 1, 32, 1, 8>,
|
||||
32, 1, 8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) {
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 48)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 48, 32, 8, 1>,
|
||||
32, 8, 1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 32)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 32, 32, 8, 1>,
|
||||
32, 8, 1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 48)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 48, 1, 32, 8>, 1,
|
||||
32, 8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 32)
|
||||
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 32, 1, 32, 8>, 1,
|
||||
32, 8, 1};
|
||||
}
|
||||
return spec;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Template specializations.
|
||||
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>(
|
||||
const upfirdn2d_kernel_params &p);
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>(
|
||||
const upfirdn2d_kernel_params &p);
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(
|
||||
const upfirdn2d_kernel_params &p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy,
|
||||
int downx, int downy, int padx0, int padx1,
|
||||
int pady0, int pady1, bool flip, float gain) {
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(x.is_privateuseone(), "x must reside on MUSA device");
|
||||
TORCH_CHECK(f.device() == x.device(),
|
||||
"f must reside on the same device as x");
|
||||
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
|
||||
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
|
||||
TORCH_CHECK(x.numel() > 0, "x has zero size");
|
||||
TORCH_CHECK(f.numel() > 0, "f has zero size");
|
||||
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
||||
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
|
||||
TORCH_CHECK((x.size(0) - 1) * x.stride(0) + (x.size(1) - 1) * x.stride(1) +
|
||||
(x.size(2) - 1) * x.stride(2) +
|
||||
(x.size(3) - 1) * x.stride(3) <=
|
||||
INT_MAX,
|
||||
"x memory footprint is too large");
|
||||
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
|
||||
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
|
||||
TORCH_CHECK(downx >= 1 && downy >= 1,
|
||||
"downsampling factor must be at least 1");
|
||||
|
||||
// Create output tensor.
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(x));
|
||||
int outW =
|
||||
((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
|
||||
int outH =
|
||||
((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
|
||||
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
|
||||
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW},
|
||||
x.options(), x.suggest_memory_format());
|
||||
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
|
||||
TORCH_CHECK((y.size(0) - 1) * y.stride(0) + (y.size(1) - 1) * y.stride(1) +
|
||||
(y.size(2) - 1) * y.stride(2) +
|
||||
(y.size(3) - 1) * y.stride(3) <=
|
||||
INT_MAX,
|
||||
"output memory footprint is too large");
|
||||
|
||||
// Initialize MUSA kernel parameters.
|
||||
upfirdn2d_kernel_params p;
|
||||
p.x = x.data_ptr();
|
||||
p.f = f.data_ptr<float>();
|
||||
p.y = y.data_ptr();
|
||||
p.up = make_int2(upx, upy);
|
||||
p.down = make_int2(downx, downy);
|
||||
p.pad0 = make_int2(padx0, pady0);
|
||||
p.flip = (flip) ? 1 : 0;
|
||||
p.gain = gain;
|
||||
p.inSize =
|
||||
make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
||||
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1),
|
||||
(int)x.stride(0));
|
||||
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
|
||||
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
|
||||
p.outSize =
|
||||
make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
||||
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1),
|
||||
(int)y.stride(0));
|
||||
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
|
||||
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
|
||||
|
||||
// Choose MUSA kernel.
|
||||
upfirdn2d_kernel_spec spec;
|
||||
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "upfirdn2d_musa", [&] {
|
||||
spec = choose_upfirdn2d_kernel<scalar_t>(p);
|
||||
});
|
||||
|
||||
// Set looping options.
|
||||
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
|
||||
p.loopMinor = spec.loopMinor;
|
||||
p.loopX = spec.loopX;
|
||||
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
|
||||
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
|
||||
|
||||
// Compute grid size.
|
||||
dim3 blockSize, gridSize;
|
||||
if (spec.tileOutW < 0) // large
|
||||
{
|
||||
blockSize = dim3(4, 32, 1);
|
||||
gridSize =
|
||||
dim3(((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
|
||||
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, p.launchMajor);
|
||||
} else // small
|
||||
{
|
||||
blockSize = dim3(256, 1, 1);
|
||||
gridSize =
|
||||
dim3(((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
|
||||
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, p.launchMajor);
|
||||
}
|
||||
|
||||
// Launch MUSA kernel.
|
||||
void *args[] = {&p};
|
||||
#ifdef MMCV_WITH_HIP
|
||||
AT_MUSA_CHECK(hipLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
|
||||
c10::musa::getCurrentMUSAStream()));
|
||||
#else
|
||||
AT_MUSA_CHECK(musaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
|
||||
c10::musa::getCurrentMUSAStream()));
|
||||
#endif
|
||||
|
||||
return y;
|
||||
}
|
||||
#else
#warning "upfirdn2d is only supported when MUSA_ARCH > 21"
#endif  // MUSA_ARCH
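A minimal usage sketch of the op defined above, from the Python side. This assumes a torch_musa build that exposes the 'musa' device type and an mmcv compiled with MUSA support (MUSA_ARCH > 21):

    import torch
    from mmcv.ops import upfirdn2d

    x = torch.randn(1, 3, 16, 16, device='musa')  # NCHW input
    f = torch.ones(4, 4, dtype=torch.float32, device='musa') / 16  # box filter
    # 2x upsampling with 1 pixel of padding; small filters like this one are
    # dispatched to upfirdn2d_kernel_small by choose_upfirdn2d_kernel.
    y = upfirdn2d(x, f, up=2, padding=1)
    print(y.shape)  # expected: torch.Size([1, 3, 31, 31])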
@@ -0,0 +1,286 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <stdio.h>
#include <stdlib.h>

#include "pytorch_musa_helper.hpp"
#include "voxelization_musa_kernel.muh"

int HardVoxelizeForwardMUSAKernelLauncher(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim = 3) {
  // the current version takes about 0.04s for one frame on cpu
  // check device

  c10::musa::MUSAGuard device_guard(points.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  const int num_points = points.size(0);
  const int num_features = points.size(1);

  const float voxel_x = voxel_size[0];
  const float voxel_y = voxel_size[1];
  const float voxel_z = voxel_size[2];
  const float coors_x_min = coors_range[0];
  const float coors_y_min = coors_range[1];
  const float coors_z_min = coors_range[2];
  const float coors_x_max = coors_range[3];
  const float coors_y_max = coors_range[4];
  const float coors_z_max = coors_range[5];

  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);

  // map points to voxel coors
  at::Tensor temp_coors =
      at::zeros({num_points, NDim}, points.options().dtype(at::kInt));

  dim3 grid(std::min(at::musa::ATenCeilDiv(num_points, 512), 4096));
  dim3 block(512);

  // 1. link point to corresponding voxel coors
  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "hard_voxelize_kernel", ([&] {
        dynamic_voxelize_kernel<scalar_t, int><<<grid, block, 0, stream>>>(
            points.contiguous().data_ptr<scalar_t>(),
            temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
            coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
            coors_z_max, grid_x, grid_y, grid_z, num_points, num_features,
            NDim);
      }));

  AT_MUSA_CHECK(musaGetLastError());

  // 2. map point to the idx of the corresponding voxel, find duplicate coor
  // create some temporary variables
  auto point_to_pointidx = -at::ones(
      {
          num_points,
      },
      points.options().dtype(at::kInt));
  auto point_to_voxelidx = -at::ones(
      {
          num_points,
      },
      points.options().dtype(at::kInt));

  dim3 map_grid(std::min(at::musa::ATenCeilDiv(num_points, 512), 4096));
  dim3 map_block(512);

  AT_DISPATCH_ALL_TYPES(
      temp_coors.scalar_type(), "determin_duplicate", ([&] {
        point_to_voxelidx_kernel<int><<<map_grid, map_block, 0, stream>>>(
            temp_coors.contiguous().data_ptr<int>(),
            point_to_voxelidx.contiguous().data_ptr<int>(),
            point_to_pointidx.contiguous().data_ptr<int>(), max_points,
            max_voxels, num_points, NDim);
      }));

  AT_MUSA_CHECK(musaGetLastError());

  // 3. determine voxel num and voxel's coor index
  // running this logic on the MUSA device makes it roughly 10x faster
  auto coor_to_voxelidx = -at::ones(
      {
          num_points,
      },
      points.options().dtype(at::kInt));
  auto voxel_num = at::zeros(
      {
          1,
      },
      points.options().dtype(at::kInt));  // must be zero from the beginning

  AT_DISPATCH_ALL_TYPES(temp_coors.scalar_type(), "determin_duplicate", ([&] {
                          determin_voxel_num<int><<<1, 1, 0, stream>>>(
                              num_points_per_voxel.contiguous().data_ptr<int>(),
                              point_to_voxelidx.contiguous().data_ptr<int>(),
                              point_to_pointidx.contiguous().data_ptr<int>(),
                              coor_to_voxelidx.contiguous().data_ptr<int>(),
                              voxel_num.contiguous().data_ptr<int>(),
                              max_points, max_voxels, num_points);
                        }));

  AT_MUSA_CHECK(musaGetLastError());

  // 4. copy point features to voxels
  // Steps 4 & 5 could run in parallel
  auto pts_output_size = num_points * num_features;
  dim3 cp_grid(std::min(at::musa::ATenCeilDiv(pts_output_size, 512), 4096));
  dim3 cp_block(512);
  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "assign_point_to_voxel", ([&] {
        assign_point_to_voxel<float, int><<<cp_grid, cp_block, 0, stream>>>(
            pts_output_size, points.contiguous().data_ptr<float>(),
            point_to_voxelidx.contiguous().data_ptr<int>(),
            coor_to_voxelidx.contiguous().data_ptr<int>(),
            voxels.contiguous().data_ptr<float>(), max_points, num_features,
            num_points, NDim);
      }));
  // musaDeviceSynchronize();
  // AT_MUSA_CHECK(musaGetLastError());

  // 5. copy coors of each voxel
  auto coors_output_size = num_points * NDim;
  dim3 coors_cp_grid(
      std::min(at::musa::ATenCeilDiv(coors_output_size, 512), 4096));
  dim3 coors_cp_block(512);
  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "assign_point_to_voxel", ([&] {
        assign_voxel_coors<float, int>
            <<<coors_cp_grid, coors_cp_block, 0, stream>>>(
                coors_output_size, temp_coors.contiguous().data_ptr<int>(),
                point_to_voxelidx.contiguous().data_ptr<int>(),
                coor_to_voxelidx.contiguous().data_ptr<int>(),
                coors.contiguous().data_ptr<int>(), num_points, NDim);
      }));

  AT_MUSA_CHECK(musaGetLastError());

  auto voxel_num_cpu = voxel_num.to(at::kCPU);
  int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];

  return voxel_num_int;
}

int NondeterministicHardVoxelizeForwardMUSAKernelLauncher(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim = 3) {
  c10::musa::MUSAGuard device_guard(points.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  const int num_points = points.size(0);
  const int num_features = points.size(1);

  if (num_points == 0) return 0;

  dim3 blocks(
      std::min(at::musa::ATenCeilDiv(num_points, THREADS_PER_BLOCK), 4096));
  dim3 threads(THREADS_PER_BLOCK);

  const float voxel_x = voxel_size[0];
  const float voxel_y = voxel_size[1];
  const float voxel_z = voxel_size[2];
  const float coors_x_min = coors_range[0];
  const float coors_y_min = coors_range[1];
  const float coors_z_min = coors_range[2];
  const float coors_x_max = coors_range[3];
  const float coors_y_max = coors_range[4];
  const float coors_z_max = coors_range[5];

  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);

  // map points to voxel coors
  at::Tensor temp_coors =
      at::zeros({num_points, NDim}, points.options().dtype(at::kInt));

  // 1. link point to corresponding voxel coors
  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "hard_voxelize_kernel", ([&] {
        dynamic_voxelize_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
            points.contiguous().data_ptr<scalar_t>(),
            temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
            coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
            coors_z_max, grid_x, grid_y, grid_z, num_points, num_features,
            NDim);
      }));

  at::Tensor coors_map;
  at::Tensor reduce_count;

  auto coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, true), -1);

  std::tie(temp_coors, coors_map, reduce_count) =
      at::unique_dim(coors_clean, 0, true, true, false);

  if (temp_coors[0][0].lt(0).item<bool>()) {
    // the first element of temp_coors is (-1,-1,-1) and should be removed
    temp_coors = temp_coors.slice(0, 1);
    coors_map = coors_map - 1;
  }

  int num_coors = temp_coors.size(0);
  temp_coors = temp_coors.to(at::kInt);
  coors_map = coors_map.to(at::kInt);

  at::Tensor coors_count = at::zeros({1}, coors_map.options());
  at::Tensor coors_order = at::empty({num_coors}, coors_map.options());
  at::Tensor pts_id = at::zeros({num_points}, coors_map.options());
  reduce_count = at::zeros({num_coors}, coors_map.options());

  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "get_assign_pos", ([&] {
        nondeterministic_get_assign_pos<<<blocks, threads, 0, stream>>>(
            num_points, coors_map.contiguous().data_ptr<int32_t>(),
            pts_id.contiguous().data_ptr<int32_t>(),
            coors_count.contiguous().data_ptr<int32_t>(),
            reduce_count.contiguous().data_ptr<int32_t>(),
            coors_order.contiguous().data_ptr<int32_t>());
      }));

  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "assign_point_to_voxel", ([&] {
        nondeterministic_assign_point_voxel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                num_points, points.contiguous().data_ptr<scalar_t>(),
                coors_map.contiguous().data_ptr<int32_t>(),
                pts_id.contiguous().data_ptr<int32_t>(),
                temp_coors.contiguous().data_ptr<int32_t>(),
                reduce_count.contiguous().data_ptr<int32_t>(),
                coors_order.contiguous().data_ptr<int32_t>(),
                voxels.contiguous().data_ptr<scalar_t>(),
                coors.contiguous().data_ptr<int32_t>(),
                num_points_per_voxel.contiguous().data_ptr<int32_t>(),
                max_voxels, max_points, num_features, NDim);
      }));
  AT_MUSA_CHECK(musaGetLastError());
  return max_voxels < num_coors ? max_voxels : num_coors;
}

void DynamicVoxelizeForwardMUSAKernelLauncher(
    const at::Tensor &points, at::Tensor &coors,
    const std::vector<float> voxel_size, const std::vector<float> coors_range,
    const int NDim = 3) {
  // the current version takes about 0.04s for one frame on cpu
  // check device

  c10::musa::MUSAGuard device_guard(points.device());
  musaStream_t stream = c10::musa::getCurrentMUSAStream();

  const int num_points = points.size(0);
  const int num_features = points.size(1);

  const float voxel_x = voxel_size[0];
  const float voxel_y = voxel_size[1];
  const float voxel_z = voxel_size[2];
  const float coors_x_min = coors_range[0];
  const float coors_y_min = coors_range[1];
  const float coors_z_min = coors_range[2];
  const float coors_x_max = coors_range[3];
  const float coors_y_max = coors_range[4];
  const float coors_z_max = coors_range[5];

  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);

  const int col_blocks = at::musa::ATenCeilDiv(num_points, THREADS_PER_BLOCK);
  dim3 blocks(col_blocks);
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_ALL_TYPES(points.scalar_type(), "dynamic_voxelize_kernel", [&] {
    dynamic_voxelize_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
        points.contiguous().data_ptr<scalar_t>(),
        coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
        coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
        coors_z_max, grid_x, grid_y, grid_z, num_points, num_features, NDim);
  });

  AT_MUSA_CHECK(musaGetLastError());
}
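A minimal sketch of reaching these launchers from Python via mmcv's Voxelization wrapper; assumes a torch_musa build and an mmcv compiled with the MUSA extensions:

    import torch
    from mmcv.ops import Voxelization

    voxel_layer = Voxelization(
        voxel_size=[0.16, 0.16, 4],
        point_cloud_range=[0, -39.68, -3, 69.12, 39.68, 1],
        max_num_points=32,    # hard voxelization caps the points per voxel
        max_voxels=16000,
        deterministic=True)   # False selects the nondeterministic launcher
    points = torch.rand(1000, 4, device='musa')  # (N, >=3): xyz + features
    voxels, coors, num_points_per_voxel = voxel_layer(points)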
@@ -1,4 +1,5 @@
 import torch
+from mmengine.device import is_cuda_available, is_musa_available
 from torch import Tensor

 from ..utils import ext_loader
@@ -10,7 +11,7 @@ ext_module = ext_loader.load_ext('_ext', [


 def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
-    """Find the box in which each point is (CUDA).
+    """Find the box in which each point is (CUDA/MUSA).

     Args:
         points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate.
@@ -38,7 +39,7 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:

     # If manually put the tensor 'points' or 'boxes' on a device
     # which is not the current device, some temporary variables
-    # will be created on the current device in the cuda op,
+    # will be created on the current device in the cuda/musa op,
     # and the output will be incorrect.
     # Therefore, we force the current device to be the same
     # as the device of the tensors if it was not.
@@ -48,8 +49,12 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
     assert points_device == boxes.get_device(), \
         'Points and boxes should be put on the same device'
     if points.device.type != 'npu':
-        if torch.cuda.current_device() != points_device:
-            torch.cuda.set_device(points_device)
+        if is_cuda_available():
+            if torch.cuda.current_device() != points_device:
+                torch.cuda.set_device(points_device)
+        elif is_musa_available():
+            if torch.musa.current_device() != points_device:
+                torch.musa.set_device(points_device)
     else:
         boxes[:, :, 2] += boxes[:, :, 5] / 2.0

@@ -99,7 +104,7 @@ def points_in_boxes_cpu(points: Tensor, boxes: Tensor) -> Tensor:


 def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor:
-    """Find all boxes in which each point is (CUDA).
+    """Find all boxes in which each point is (CUDA/MUSA).

     Args:
         points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
@@ -131,8 +136,12 @@ def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor:
     assert points_device == boxes.get_device(), \
         'Points and boxes should be put on the same device'
     if points.device.type != 'npu':
-        if torch.cuda.current_device() != points_device:
-            torch.cuda.set_device(points_device)
+        if is_cuda_available():
+            if torch.cuda.current_device() != points_device:
+                torch.cuda.set_device(points_device)
+        elif is_musa_available():
+            if torch.musa.current_device() != points_device:
+                torch.musa.set_device(points_device)

     ext_module.points_in_boxes_all_forward(boxes.contiguous(),
                                            points.contiguous(),
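A short sketch of the call path this hunk protects, with both tensors placed on a secondary device; assumes a torch_musa build where the device 'musa:1' exists:

    import torch
    from mmcv.ops import points_in_boxes_part

    boxes = torch.rand(1, 10, 7, device='musa:1')    # (B, T, 7): x, y, z, dx, dy, dz, rz
    points = torch.rand(1, 100, 3, device='musa:1')  # (B, M, 3)
    # The wrapper above switches the current MUSA device to index 1 first, so
    # the op's temporaries land on the same device as its inputs.
    idx = points_in_boxes_part(points, boxes)        # (B, M); -1 when in no box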
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from mmengine.device import is_cuda_available, is_musa_available
 from mmengine.registry import MODELS
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
@@ -47,10 +48,20 @@ class SyncBatchNormFunction(Function):
         self.group_size = group_size
         self.stats_mode = stats_mode

-        assert isinstance(
-            input, (torch.HalfTensor, torch.FloatTensor,
-                    torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
-            f'only support Half or Float Tensor, but {input.type()}'
+        if is_cuda_available():
+            assert isinstance(
+                input, (torch.HalfTensor, torch.FloatTensor,
+                        torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
+                f'only support Half or Float Tensor, but {input.type()}'
+        elif is_musa_available():
+            assert isinstance(
+                input, (torch.HalfTensor, torch.FloatTensor,
+                        torch.musa.HalfTensor, torch.musa.FloatTensor)), \
+                f'only support Half or Float Tensor, but {input.type()}'
+        else:
+            assert isinstance(
+                input, (torch.HalfTensor, torch.FloatTensor)), \
+                f'only support Half or Float Tensor, but {input.type()}'
         output = torch.zeros_like(input)
         input3d = input.flatten(start_dim=2)
         output3d = output.view_as(input3d)
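The three branches of this assert differ only in which device-specific tensor classes are accepted. A hypothetical helper (not part of this commit) showing the same check written once; the torch.musa tensor types exist only under torch_musa:

    import torch
    from mmengine.device import is_cuda_available, is_musa_available

    def _allowed_float_types():
        """Half/Float tensor classes SyncBatchNormFunction accepts."""
        types = [torch.HalfTensor, torch.FloatTensor]
        if is_cuda_available():
            types += [torch.cuda.HalfTensor, torch.cuda.FloatTensor]
        elif is_musa_available():
            types += [torch.musa.HalfTensor, torch.musa.FloatTensor]
        return tuple(types)

    def _check_input(input: torch.Tensor) -> None:
        assert isinstance(input, _allowed_float_types()), \
            f'only support Half or Float Tensor, but {input.type()}'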
@@ -116,6 +116,13 @@ def upfirdn2d(input: torch.Tensor,
             padding=padding,
             flip_filter=flip_filter,
             gain=gain).apply(input, filter)
+    elif use_custom_op and input.device.type == 'musa':
+        return _upfirdn2d_musa(
+            up=up,
+            down=down,
+            padding=padding,
+            flip_filter=flip_filter,
+            gain=gain).apply(input, filter)
     return _upfirdn2d_ref(
         input,
         filter,
@@ -303,6 +310,101 @@ def _upfirdn2d_cuda(up: int = 1,
     return Upfirdn2dCuda


+_upfirdn2d_musa_cache: Dict = dict()
+
+
+def _upfirdn2d_musa(up: int = 1,
+                    down: int = 1,
+                    padding: Union[int, List[int]] = 0,
+                    flip_filter: bool = False,
+                    gain: Union[float, int] = 1):
+    """Fast MUSA implementation of `upfirdn2d()` using custom ops.
+
+    Args:
+        up (int): Integer upsampling factor. Can be a single int or a
+            list/tuple `[x, y]`. Defaults to 1.
+        down (int): Integer downsampling factor. Can be a single int
+            or a list/tuple `[x, y]`. Defaults to 1.
+        padding (int | tuple[int]): Padding with respect to the upsampled
+            image. Can be a single number or a list/tuple `[x, y]` or
+            `[x_before, x_after, y_before, y_after]`. Defaults to 0.
+        flip_filter (bool): False = convolution, True = correlation.
+            Defaults to False.
+        gain (int): Overall scaling factor for signal magnitude.
+            Defaults to 1.
+
+    Returns:
+        torch.Tensor: Tensor of the shape `[batch_size, num_channels,
+        out_height, out_width]`
+    """
+    # Parse arguments.
+    upx, upy = _parse_scaling(up)
+    downx, downy = _parse_scaling(down)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+    # Lookup from cache.
+    key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter,
+           gain)
+    if key in _upfirdn2d_musa_cache:
+        return _upfirdn2d_musa_cache[key]
+
+    # Forward op.
+    class Upfirdn2dMusa(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, x, f):  # pylint: disable=arguments-differ
+            assert isinstance(x, torch.Tensor) and x.ndim == 4
+            if f is None:
+                f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+            if f.ndim == 1 and f.shape[0] == 1:
+                f = f.square().unsqueeze(
+                    0)  # Convert separable-1 into full-1x1.
+            assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+            y = x
+            if f.ndim == 2:
+                y = ext_module.upfirdn2d(y, f, upx, upy, downx, downy, padx0,
+                                         padx1, pady0, pady1, flip_filter,
+                                         gain)
+            else:
+                y = ext_module.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1,
+                                         padx0, padx1, 0, 0, flip_filter, 1.0)
+                y = ext_module.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy,
+                                         0, 0, pady0, pady1, flip_filter, gain)
+            ctx.save_for_backward(f)
+            ctx.x_shape = x.shape
+            return y
+
+        @staticmethod
+        def backward(ctx, dy):  # pylint: disable=arguments-differ
+            f, = ctx.saved_tensors
+            _, _, ih, iw = ctx.x_shape
+            _, _, oh, ow = dy.shape
+            fw, fh = _get_filter_size(f)
+            p = [
+                fw - padx0 - 1,
+                iw * upx - ow * downx + padx0 - upx + 1,
+                fh - pady0 - 1,
+                ih * upy - oh * downy + pady0 - upy + 1,
+            ]
+            dx = None
+            df = None
+
+            if ctx.needs_input_grad[0]:
+                dx = _upfirdn2d_musa(
+                    up=down,
+                    down=up,
+                    padding=p,
+                    flip_filter=(not flip_filter),
+                    gain=gain).apply(dy, f)
+
+            assert not ctx.needs_input_grad[1]
+            return dx, df
+
+    # Add to cache.
+    _upfirdn2d_musa_cache[key] = Upfirdn2dMusa
+    return Upfirdn2dMusa
+
+
 def filter2d(input: torch.Tensor,
              filter: torch.Tensor,
              padding: Union[int, List[int]] = 0,
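Note the module-level cache: _upfirdn2d_musa is keyed on the full scaling/padding configuration, so repeated calls with identical arguments reuse one autograd.Function subclass instead of defining a new class per call. An illustrative check (assumes a torch_musa build so the extension module imports cleanly):

    from mmcv.ops.upfirdn2d import _upfirdn2d_musa

    op_a = _upfirdn2d_musa(up=2, padding=1)
    op_b = _upfirdn2d_musa(up=2, padding=1)
    op_c = _upfirdn2d_musa(down=2)
    assert op_a is op_b      # same key -> same cached Function class
    assert op_a is not op_c  # different key -> a different class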
@@ -4,7 +4,7 @@ import pytest
 import torch

 from mmcv.ops import points_in_polygons
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE


 @pytest.mark.parametrize('device', [
@@ -15,7 +15,11 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 def test_points_in_polygons(device):
     points = np.array([[300., 300.], [400., 400.], [100., 100], [300, 250],
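Every remaining test hunk in this commit applies the same device-parametrization idiom shown above, so the idiom is worth spelling out once; a generic sketch (the test name and tensor below are illustrative, not part of the commit):

    import pytest
    import torch
    from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE

    @pytest.mark.parametrize('device', [
        pytest.param(
            'cuda',
            marks=pytest.mark.skipif(
                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
        pytest.param(
            'musa',
            marks=pytest.mark.skipif(
                not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
    ])
    def test_some_op(device):  # hypothetical test
        x = torch.rand(2, 3).to(device)  # the op under test runs on `device`
        assert x.device.type == device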
@@ -3,7 +3,7 @@ import numpy as np
 import pytest
 import torch

-from mmcv.utils import IS_CUDA_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE

 _USING_PARROTS = True
 try:
@@ -41,7 +41,11 @@ class TestPrRoiPool:
     pytest.param(
         'cuda',
         marks=pytest.mark.skipif(
-            not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_roipool_gradcheck(self, device):
     from mmcv.ops import PrRoIPool
@@ -92,7 +96,11 @@ class TestPrRoiPool:
     pytest.param(
         'cuda',
         marks=pytest.mark.skipif(
-            not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_roipool_allclose_float(self, device):
     self._test_roipool_allclose(device, dtype=torch.float)
@@ -4,7 +4,8 @@ import pytest
 import torch
 import torch.nn as nn

-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)


 class Loss(nn.Module):
@@ -32,7 +33,11 @@ class TestPSAMask:
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_psa_mask_collect(self, device):
     from mmcv.ops import PSAMask
@@ -84,7 +89,11 @@ class TestPSAMask:
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_psa_mask_distribute(self, device):
     from mmcv.ops import PSAMask
@@ -3,7 +3,8 @@ import numpy as np
 import pytest
 import torch

-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)

 _USING_PARROTS = True
 try:
@@ -107,7 +108,11 @@ def _test_roialign_allclose(device, dtype):
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 def test_roialign_float(device, dtype):
     _test_roialign_allclose(device=device, dtype=dtype)
@@ -3,7 +3,7 @@ import numpy as np
 import pytest
 import torch

-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE

 _USING_PARROTS = True
 try:
@@ -132,15 +132,19 @@ def _test_roialign_rotated_allclose(device, dtype):
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 @pytest.mark.parametrize('dtype', [
     torch.float,
     pytest.param(
         torch.double,
         marks=pytest.mark.skipif(
-            IS_MLU_AVAILABLE,
-            reason='MLU does not support for 64-bit floating point')),
+            IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE,
+            reason='MLU, MUSA does not support for 64-bit floating point')),
     torch.half
 ])
 def test_roialign_rotated(device, dtype):

@@ -5,7 +5,8 @@ import numpy as np
 import pytest
 import torch

-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)

 _USING_PARROTS = True
 try:
@@ -89,16 +90,20 @@ class TestRoiPool:
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 @pytest.mark.parametrize('dtype', [
     torch.float,
     pytest.param(
         torch.double,
         marks=pytest.mark.skipif(
-            IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
-            reason='MLU, NPU does not support for 64-bit floating point')),
-    torch.half
+            IS_MLU_AVAILABLE or IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE,
+            reason='MLU, NPU, MUSA '
+            'does not support for 64-bit floating point')), torch.half
 ])
 def test_roipool_allclose(self, device, dtype):
     self._test_roipool_allclose(device, dtype)

@@ -5,7 +5,8 @@ import torch

 from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu,
                       points_in_boxes_part)
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)


 @pytest.mark.parametrize('dtype', [
@@ -13,7 +14,8 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         torch.double,
         marks=pytest.mark.skipif(
-            IS_MLU_AVAILABLE, reason='MLU does not support for double'))
+            IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE,
+            reason='MLU, MUSA does not support for double'))
 ])
 @pytest.mark.parametrize('device', [
     pytest.param(
@@ -23,7 +25,11 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_RoIAwarePool3d(device, dtype):
     roiaware_pool3d_max = RoIAwarePool3d(
@@ -64,7 +70,11 @@ def test_RoIAwarePool3d(device, dtype):
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_points_in_boxes_part(device):
     boxes = torch.tensor(

@@ -3,7 +3,8 @@ import pytest
 import torch

 from mmcv.ops import RoIPointPool3d
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)


 @pytest.mark.parametrize('device', [
@@ -18,15 +19,19 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 @pytest.mark.parametrize('dtype', [
     torch.float, torch.half,
     pytest.param(
         torch.double,
         marks=pytest.mark.skipif(
-            IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
-            reason='MLU and NPU does not support for double'))
+            IS_MLU_AVAILABLE or IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE,
+            reason='MLU, NPU, MUSA does not support for double'))
 ])
 def test_roipoint(device, dtype):
     points = torch.tensor(

@@ -3,7 +3,8 @@ import pytest
 import torch

 from mmcv.ops import rotated_feature_align
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)


 @pytest.mark.skipif(
@@ -21,6 +22,10 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
         'npu',
         marks=pytest.mark.skipif(
             not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
     pytest.param(
         'cpu',
         marks=pytest.mark.skipif(

@@ -4,7 +4,7 @@ import torch
 from torch.autograd import gradcheck

 from mmcv.ops import DynamicScatter
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE

 if torch.__version__ == 'parrots':
     pytest.skip('not supported in parrots now', allow_module_level=True)
@@ -18,7 +18,11 @@ if torch.__version__ == 'parrots':
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 def test_dynamic_scatter(device):
     dsmean = DynamicScatter([0.32, 0.32, 6],

@@ -10,7 +10,7 @@ from mmcv.ops import (SparseConvTensor, SparseInverseConv3d, SparseSequential,
 if torch.__version__ == 'parrots':
     pytest.skip('not supported in parrots now', allow_module_level=True)

-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE


 def make_sparse_convmodule(in_channels,
@@ -86,10 +86,17 @@ def make_sparse_convmodule(in_channels,
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 def test_make_sparse_convmodule(device):
-    torch.cuda.empty_cache()
+    if IS_CUDA_AVAILABLE:
+        torch.cuda.empty_cache()
+    elif IS_MUSA_AVAILABLE:
+        torch.musa.empty_cache()
     voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
                                    [6.8162713, -2.480431, -1.3616394, 0.36],
                                    [11.643568, -4.744306, -1.3580885, 0.16],
@ -8,6 +8,8 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE
|
||||
|
||||
if platform.system() == 'Windows':
|
||||
import regex as re
|
||||
else:
|
||||
|
@ -29,10 +31,24 @@ class TestSyncBN:
|
|||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['RANK'] = str(rank)
|
||||
|
||||
dist.init_process_group('nccl')
|
||||
torch.cuda.set_device(local_rank)
|
||||
if IS_CUDA_AVAILABLE:
|
||||
dist.init_process_group('nccl')
|
||||
torch.cuda.set_device(local_rank)
|
||||
elif IS_MUSA_AVAILABLE:
|
||||
dist.init_process_group('mccl')
|
||||
torch.musa.set_device(local_rank)
|
||||
|
||||
def _test_syncbn_train(self, size=1, half=False):
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
'cuda',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
|
||||
pytest.param(
|
||||
'musa',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
|
||||
])
|
||||
def _test_syncbn_train(self, size=1, half=False, device='cuda'):
|
||||
|
||||
if 'SLURM_NTASKS' not in os.environ or int(
|
||||
os.environ['SLURM_NTASKS']) != 4:
|
||||
|
@ -49,10 +65,13 @@ class TestSyncBN:
|
|||
rank = dist.get_rank()
|
||||
|
||||
torch.manual_seed(9)
|
||||
torch.cuda.manual_seed(9)
|
||||
if IS_CUDA_AVAILABLE:
|
||||
torch.cuda.manual_seed(9)
|
||||
elif IS_MUSA_AVAILABLE:
|
||||
torch.musa.manual_seed(9)
|
||||
|
||||
self.x = torch.rand(16, 3, 2, 3).cuda()
|
||||
self.y_bp = torch.rand(16, 3, 2, 3).cuda()
|
||||
self.x = torch.rand(16, 3, 2, 3).to(device)
|
||||
self.y_bp = torch.rand(16, 3, 2, 3).to(device)
|
||||
|
||||
if half:
|
||||
self.x = self.x.half()
|
||||
|
@@ -60,7 +79,10 @@ class TestSyncBN:
         dist.broadcast(self.x, src=0)
         dist.broadcast(self.y_bp, src=0)
 
-        torch.cuda.synchronize()
+        if IS_CUDA_AVAILABLE:
+            torch.cuda.synchronize()
+        elif IS_MUSA_AVAILABLE:
+            torch.musa.synchronize()
         if size == 1:
             groups = [None, None, None, None]
             groups[0] = dist.new_group([0])
@@ -75,13 +97,13 @@ class TestSyncBN:
             group = groups[rank]
         elif size == 4:
             group = dist.group.WORLD
-        syncbn = SyncBatchNorm(3, group=group).cuda()
+        syncbn = SyncBatchNorm(3, group=group).to(device)
         syncbn.weight.data[0] = 0.2
         syncbn.weight.data[1] = 0.5
         syncbn.weight.data[2] = 0.7
         syncbn.train()
 
-        bn = nn.BatchNorm2d(3).cuda()
+        bn = nn.BatchNorm2d(3).to(device)
         bn.weight.data[0] = 0.2
         bn.weight.data[1] = 0.5
         bn.weight.data[2] = 0.7
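
The rewrite running through these SyncBN hunks is mechanical: hard-coded
.cuda() calls become .to(device), so the same assertions exercise whichever
backend the parametrized device string names. A trivial sketch, with 'cpu'
standing in for any backend string:

import torch.nn as nn


def build_bn(device):
    # .to(device) replaces the CUDA-only nn.BatchNorm2d(3).cuda() call.
    return nn.BatchNorm2d(3).to(device)


bn = build_bn('cpu')  # 'cuda' or 'musa' behave the same when available
assert next(bn.parameters()).device.type == 'cpu'
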
@@ -143,7 +165,17 @@ class TestSyncBN:
         assert np.allclose(x_grad.data.cpu().numpy(),
                            sx_grad.data.cpu().numpy(), 1e-2)
 
-    def _test_syncbn_empty_train(self, size=1, half=False):
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'musa',
+            marks=pytest.mark.skipif(
+                not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
+    ])
+    def _test_syncbn_empty_train(self, size=1, half=False, device='cuda'):
 
         if 'SLURM_NTASKS' not in os.environ or int(
                 os.environ['SLURM_NTASKS']) != 4:
@@ -160,10 +192,13 @@ class TestSyncBN:
         rank = dist.get_rank()
 
         torch.manual_seed(9)
-        torch.cuda.manual_seed(9)
+        if IS_CUDA_AVAILABLE:
+            torch.cuda.manual_seed(9)
+        elif IS_MUSA_AVAILABLE:
+            torch.musa.manual_seed(9)
 
-        self.x = torch.rand(0, 3, 2, 3).cuda()
-        self.y_bp = torch.rand(0, 3, 2, 3).cuda()
+        self.x = torch.rand(0, 3, 2, 3).to(device)
+        self.y_bp = torch.rand(0, 3, 2, 3).to(device)
 
         if half:
             self.x = self.x.half()
@@ -171,7 +206,10 @@ class TestSyncBN:
         dist.broadcast(self.x, src=0)
         dist.broadcast(self.y_bp, src=0)
 
-        torch.cuda.synchronize()
+        if IS_CUDA_AVAILABLE:
+            torch.cuda.synchronize()
+        elif IS_MUSA_AVAILABLE:
+            torch.musa.synchronize()
         if size == 1:
             groups = [None, None, None, None]
             groups[0] = dist.new_group([0])
@@ -187,13 +225,13 @@ class TestSyncBN:
         elif size == 4:
             group = dist.group.WORLD
 
-        syncbn = SyncBatchNorm(3, group=group, stats_mode='N').cuda()
+        syncbn = SyncBatchNorm(3, group=group, stats_mode='N').to(device)
         syncbn.weight.data[0] = 0.2
         syncbn.weight.data[1] = 0.5
         syncbn.weight.data[2] = 0.7
         syncbn.train()
 
-        bn = nn.BatchNorm2d(3).cuda()
+        bn = nn.BatchNorm2d(3).to(device)
         bn.weight.data[0] = 0.2
         bn.weight.data[1] = 0.5
         bn.weight.data[2] = 0.7

tests/test_ops/test_three_interpolate.py

@@ -3,7 +3,7 @@ import pytest
 import torch
 
 from mmcv.ops import three_interpolate
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE
 
 
 @pytest.mark.parametrize('dtype', [
@@ -11,8 +11,8 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         torch.double,
         marks=pytest.mark.skipif(
-            IS_NPU_AVAILABLE,
-            reason='NPU does not support for 64-bit floating point'))
+            IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE,
+            reason='NPU and MUSA do not support 64-bit floating point'))
 ])
 @pytest.mark.parametrize('device', [
     pytest.param(
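
The dtype axis is gated the same way as the device axis, only inverted: when
a backend without 64-bit float support is present at all, the torch.double
case is skipped outright, since the parametrized device may be that backend.
A compact sketch with illustrative stand-ins for the mmcv flags:

import pytest
import torch

IS_NPU_AVAILABLE = False  # illustrative stand-ins for the mmcv flags
IS_MUSA_AVAILABLE = False


@pytest.mark.parametrize('dtype', [
    torch.float,
    pytest.param(
        torch.double,
        marks=pytest.mark.skipif(
            IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE,
            reason='NPU and MUSA do not support 64-bit floating point')),
])
def test_mean(dtype):
    assert torch.ones(4, dtype=dtype).mean().item() == 1.0
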
@@ -22,9 +22,15 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_three_interpolate(dtype, device):
+    if IS_MUSA_AVAILABLE:
+        torch.musa.empty_cache()
     features = torch.tensor(
         [[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
           [3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],

tests/test_ops/test_three_nn.py

@@ -3,7 +3,7 @@ import pytest
 import torch
 
 from mmcv.ops import three_nn
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
 
 known = [[[-1.8373, 3.5605, -0.7867], [0.7615, 2.9420, 0.2314],
           [-0.6503, 3.6637, -1.0622], [-1.8373, 3.5605, -0.7867],
@@ -48,7 +48,11 @@ expected_idx = [[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4], [2, 1, 0],
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 @pytest.mark.parametrize('dtype,rtol', [(torch.float, 1e-8),
                                         (torch.half, 1e-3)])

tests/test_ops/test_tin_shift.py

@@ -5,7 +5,7 @@ import numpy as np
 import pytest
 import torch
 
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
 
 _USING_PARROTS = True
 try:
@@ -209,15 +209,19 @@ def _test_tinshift_assert(device, dtype):
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 @pytest.mark.parametrize('dtype', [
     torch.float,
     pytest.param(
         torch.double,
         marks=pytest.mark.skipif(
-            IS_MLU_AVAILABLE,
-            reason='MLU does not support for 64-bit floating point')),
+            IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE,
+            reason='MLU and MUSA do not support 64-bit floating point')),
     torch.half
 ])
 def test_tinshift(device, dtype):

tests/test_ops/test_voxelization.py

@@ -4,7 +4,8 @@ import pytest
 import torch
 
 from mmcv.ops import Voxelization
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)
 
 
 def _get_voxel_points_indices(points, coors, voxel):
@@ -215,3 +216,38 @@ def test_voxelization_npu(device_type):
     assert np.all(coors == expected_coors)
     assert np.all(voxels == expected_voxels)
     assert np.all(num_points_per_voxel == expected_num_points_per_voxel)
+
+
+@pytest.mark.parametrize('device_type', [
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
+])
+def test_voxelization_musa(device_type):
+    voxel_size = [0.5, 0.5, 0.5]
+    point_cloud_range = [0, -40, -3, 70.4, 40, 1]
+
+    voxel_dict = np.load(
+        'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item()
+    expected_coors = voxel_dict['coors']
+    expected_voxels = voxel_dict['voxels']
+    expected_num_points_per_voxel = voxel_dict['num_points_per_voxel']
+    points = voxel_dict['points']
+
+    points = torch.tensor(points)
+    max_num_points = 1000
+    hard_voxelization = Voxelization(voxel_size, point_cloud_range,
+                                     max_num_points)
+
+    device = torch.device(device_type)
+
+    # test hard_voxelization on musa
+    points = points.contiguous().to(device)
+    coors, voxels, num_points_per_voxel = hard_voxelization.forward(points)
+    coors = coors.cpu().detach().numpy()
+    voxels = voxels.cpu().detach().numpy()
+    num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy()
+    assert np.all(coors == expected_coors)
+    assert np.all(voxels == expected_voxels)
+    assert np.all(num_points_per_voxel == expected_num_points_per_voxel)
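
For context, a CPU-side usage sketch of the hard-voxelization path the new
test exercises; it mirrors the call above (same constructor arguments and
unpack order), with random in-range points standing in for the checked-in
fixture:

import torch
from mmcv.ops import Voxelization

voxelizer = Voxelization(
    voxel_size=[0.5, 0.5, 0.5],
    point_cloud_range=[0, -40, -3, 70.4, 40, 1],
    max_num_points=1000)

# (N, 4) points: x, y, z plus one extra feature, scaled into the range above
points = torch.rand(100, 4)
points[:, 0] = points[:, 0] * 70.4
points[:, 1] = points[:, 1] * 80 - 40
points[:, 2] = points[:, 2] * 4 - 3

coors, voxels, num_points_per_voxel = voxelizer(points.contiguous())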