[MUSA] mmcv support musa, split pr 4 (#3260)

* mmcv support musa, split pr 4

* fix lint

* fix lint
sunyanguomt 2025-03-20 16:34:23 +08:00 committed by GitHub
parent 24a2bb4f7b
commit 4b38ffcf45
58 changed files with 8741 additions and 71 deletions


@@ -0,0 +1,91 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINT_IN_BOXES_MUSA_KERNEL_MUH
#define POINT_IN_BOXES_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
T &local_x, T &local_y) {
T cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}
template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
T &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
// cz in the bottom center
T x = pt[0], y = pt[1], z = pt[2];
T cx = box3d[0], cy = box3d[1], cz = box3d[2];
T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
cz += z_size /
2.0; // shift to the center since cz in box3d is the bottom center
if (fabsf(z - cz) > z_size / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
(local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
return in_flag;
}
template <typename T>
__global__ void points_in_boxes_part_forward_musa_kernel(
int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
int *box_idx_of_points) {
  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  // coordinate, z is the bottom center, boxes must not overlap each other
  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
  // params box_idx_of_points: (B, npoints), default -1
int bs_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
if (bs_idx >= batch_size) return;
boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num + pt_idx;
T local_x = 0, local_y = 0;
int cur_in_flag = 0;
for (int k = 0; k < boxes_num; k++) {
cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[0] = k;
break;
}
}
}
}
template <typename T>
__global__ void points_in_boxes_all_forward_musa_kernel(
int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
int *box_idx_of_points) {
  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  // coordinate, z is the bottom center, boxes must not overlap each other
  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
  // params box_idx_of_points: (B, npoints), default -1
int bs_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
if (bs_idx >= batch_size) return;
boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num * boxes_num + pt_idx * boxes_num;
T local_x = 0, local_y = 0;
for (int k = 0; k < boxes_num; k++) {
const int cur_in_flag =
check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[k] = 1;
}
}
}
}
#endif // POINT_IN_BOXES_MUSA_KERNEL_MUH
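For reviewers, a minimal host-side launch sketch for the part-forward kernel above (illustration only, not part of this diff). It assumes CUDA-style launch syntax for MUSA and the GET_BLOCKS / THREADS_PER_BLOCK helpers commonly exposed by pytorch_musa_helper.hpp; the actual wrapper in this PR may differ.

// Hypothetical launch wrapper; GET_BLOCKS / THREADS_PER_BLOCK are assumed
// helpers and the stream argument is omitted for brevity.
void points_in_boxes_part_forward_launch(int batch_size, int boxes_num,
                                         int pts_num, const float *boxes,
                                         const float *pts,
                                         int *box_idx_of_points) {
  // Points are covered by the grid-stride loop in x; the batch index is
  // carried by blockIdx.y, matching bs_idx inside the kernel.
  dim3 blocks(GET_BLOCKS(pts_num), batch_size);
  dim3 threads(THREADS_PER_BLOCK);
  points_in_boxes_part_forward_musa_kernel<float><<<blocks, threads>>>(
      batch_size, boxes_num, pts_num, boxes, pts, box_idx_of_points);
}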


@@ -0,0 +1,75 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
#define POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
struct point {
float x, y;
};
template <typename scalar_t>
__global__ void points_in_polygons_forward_musa_kernel(
const int nthreads, const scalar_t *vertex1, const scalar_t *vertex2,
const int rows, const int cols, scalar_t *inside_flag) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
int row = index / cols;
int col = index % cols;
const scalar_t *offset_vertex1 = vertex1 + row * 2;
const scalar_t *offset_vertex2 = vertex2 + col * 8;
point point_[1];
point polygon[4];
point_[0].x = offset_vertex1[0];
point_[0].y = offset_vertex1[1];
polygon[0].x = offset_vertex2[0];
polygon[0].y = offset_vertex2[1];
polygon[1].x = offset_vertex2[2];
polygon[1].y = offset_vertex2[3];
polygon[2].x = offset_vertex2[4];
polygon[2].y = offset_vertex2[5];
polygon[3].x = offset_vertex2[6];
polygon[3].y = offset_vertex2[7];
int nCross = 0;
int i, j;
float sx, sy, tx, ty, px, py, x;
for (i = 0, j = 3; i < 4; j = i, i++) {
sx = polygon[i].x;
sy = polygon[i].y;
tx = polygon[j].x;
ty = polygon[j].y;
px = point_[0].x;
py = point_[0].y;
if (py < min(sy, ty)) continue;
if (py > max(sy, ty)) continue;
if ((sx == px && sy == py) || (tx == px && ty == py)) {
break;
} else {
if ((sy < py && ty >= py) || (sy >= py && ty < py)) {
x = sx + (py - sy) * (tx - sx) / (ty - sy);
if (x == px) {
break;
}
if (x > px) {
nCross++;
}
}
}
}
if (nCross % 2 == 1) {
inside_flag[index] = 1.0;
} else {
inside_flag[index] = 0.0;
}
return;
}
}
#endif // POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
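As a reading aid, a CPU reference of the same even-odd (ray crossing) test follows. It mirrors the kernel's control flow and is for illustration only, not part of this diff; the struct name and function are hypothetical.

#include <algorithm>

struct PtRef {
  float x, y;
};

// CPU mirror of the crossing-number test above: cast a horizontal ray from
// the query point toward +x and count edge crossings; an odd count means the
// point is inside the quadrilateral.
inline bool point_in_quad_reference(const PtRef &p, const PtRef poly[4]) {
  int n_cross = 0;
  for (int i = 0, j = 3; i < 4; j = i, i++) {
    const PtRef &s = poly[i], &t = poly[j];
    if (p.y < std::min(s.y, t.y) || p.y > std::max(s.y, t.y)) continue;
    if ((s.x == p.x && s.y == p.y) || (t.x == p.x && t.y == p.y)) break;
    if ((s.y < p.y && t.y >= p.y) || (s.y >= p.y && t.y < p.y)) {
      // x coordinate where the edge crosses the horizontal line y = p.y
      float x = s.x + (p.y - s.y) * (t.x - s.x) / (t.y - s.y);
      if (x == p.x) break;     // the query point lies on this edge
      if (x > p.x) n_cross++;  // crossing strictly to the right of the point
    }
  }
  return n_cross % 2 == 1;
}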


@@ -0,0 +1,377 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/vacancy/PreciseRoIPooling/blob/master/src/prroi_pooling_gpu_impl.cu
// Distributed under terms of the MIT license.
#ifndef PRROI_POOL_MUSA_KERNEL_MUH
#define PRROI_POOL_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
template <typename T>
__device__ static __forceinline__ T PrRoIPoolingGetData(const T *data,
const int h,
const int w,
const int height,
const int width) {
bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
T retVal = overflow ? 0.0f : data[h * width + w];
return retVal;
}
template <typename T>
__device__ static __forceinline__ T PrRoIPoolingGetCoeff(T dh, T dw) {
return (1.0f - abs(dh)) * (1.0f - abs(dw));
}
template <typename T>
__device__ static __forceinline__ T PrRoIPoolingSingleCoorIntegral(T s, T t,
T c1, T c2) {
return 0.5 * (t * t - s * s) * (c2 - c1) + (t - s) * c1;
}
template <typename T>
__device__ static T PrRoIPoolingInterpolation(const T *data, const T h,
const T w, const int height,
const int width) {
T retVal = 0.0f;
int h1 = floorf(h);
int w1 = floorf(w);
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
h1 = floorf(h) + 1;
w1 = floorf(w);
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
h1 = floorf(h);
w1 = floorf(w) + 1;
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
h1 = floorf(h) + 1;
w1 = floorf(w) + 1;
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
return retVal;
}
template <typename T>
__device__ static T PrRoIPoolingMatCalculation(const T *this_data,
const int s_h, const int s_w,
const int e_h, const int e_w,
const T y0, const T x0,
const T y1, const T x1,
const int h0, const int w0) {
T alpha, beta, lim_alpha, lim_beta, tmp;
T sum_out = 0;
alpha = x0 - T(s_w);
beta = y0 - T(s_h);
lim_alpha = x1 - T(s_w);
lim_beta = y1 - T(s_h);
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp;
alpha = T(e_w) - x1;
lim_alpha = T(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp;
alpha = x0 - T(s_w);
beta = T(e_h) - y1;
lim_alpha = x1 - T(s_w);
lim_beta = T(e_h) - y0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp;
alpha = T(e_w) - x1;
lim_alpha = T(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp;
return sum_out;
}
template <typename T>
__device__ static void PrRoIPoolingDistributeDiff(T *diff, const T top_diff,
const int h, const int w,
const int height,
const int width,
const T coeff) {
bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
if (!overflow) atomicAdd(diff + h * width + w, top_diff * coeff);
}
template <typename T>
__device__ static void PrRoIPoolingMatDistributeDiff(
T *diff, const T top_diff, const int s_h, const int s_w, const int e_h,
const int e_w, const T y0, const T x0, const T y1, const T x1, const int h0,
const int w0) {
T alpha, beta, lim_alpha, lim_beta, tmp;
alpha = x0 - T(s_w);
beta = y0 - T(s_h);
lim_alpha = x1 - T(s_w);
lim_beta = y1 - T(s_h);
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, s_h, s_w, h0, w0, tmp);
alpha = T(e_w) - x1;
lim_alpha = T(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, s_h, e_w, h0, w0, tmp);
alpha = x0 - T(s_w);
beta = T(e_h) - y1;
lim_alpha = x1 - T(s_w);
lim_beta = T(e_h) - y0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, e_h, s_w, h0, w0, tmp);
alpha = T(e_w) - x1;
lim_alpha = T(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, e_h, e_w, h0, w0, tmp);
}
template <typename T>
__global__ void prroi_pool_forward_musa_kernel(
const int nthreads, const T *input, const T *rois, T *output,
const int pooled_height, const int pooled_width, const T spatial_scale,
const int channels, const int height, const int width) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T *offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
T roi_x1 = offset_rois[1] * spatial_scale;
T roi_y1 = offset_rois[2] * spatial_scale;
T roi_x2 = offset_rois[3] * spatial_scale;
T roi_y2 = offset_rois[4] * spatial_scale;
T roi_width = max(roi_x2 - roi_x1, ((T)0.0));
T roi_height = max(roi_y2 - roi_y1, ((T)0.0));
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
const T *this_data =
input + (roi_batch_ind * channels + c) * height * width;
T *this_out = output + index;
T bin_x1 = roi_x1 + bin_size_w * pw;
T bin_y1 = roi_y1 + bin_size_h * ph;
T bin_x2 = bin_x1 + bin_size_w;
T bin_y2 = bin_y1 + bin_size_h;
T bin_size = max(T(0.0), bin_size_w * bin_size_h);
if (bin_size == 0) {
*this_out = 0;
continue;
}
T sum_out = 0;
int start_x, start_y, end_x, end_y;
start_x = floorf(bin_x1);
end_x = ceilf(bin_x2);
start_y = floorf(bin_y1);
end_y = ceilf(bin_y2);
for (int bin_x = start_x; bin_x < end_x; ++bin_x)
for (int bin_y = start_y; bin_y < end_y; ++bin_y)
sum_out += PrRoIPoolingMatCalculation(
this_data, bin_y, bin_x, bin_y + 1, bin_x + 1,
max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)),
min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height,
width);
*this_out = sum_out / bin_size;
}
}
template <typename T>
__global__ void prroi_pool_backward_musa_kernel(
const int nthreads, const T *grad_output, const T *rois, T *grad_input,
const int pooled_height, const int pooled_width, const T spatial_scale,
const int channels, const int height, const int width) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
auto rois_cur = rois + n * 5;
int roi_batch_ind = rois_cur[0];
T roi_x1 = rois_cur[1] * spatial_scale;
T roi_y1 = rois_cur[2] * spatial_scale;
T roi_x2 = rois_cur[3] * spatial_scale;
T roi_y2 = rois_cur[4] * spatial_scale;
T roi_width = max(roi_x2 - roi_x1, (T)0);
T roi_height = max(roi_y2 - roi_y1, (T)0);
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
const T *this_out_grad = grad_output + index;
T *this_data_grad =
grad_input + (roi_batch_ind * channels + c) * height * width;
T bin_x1 = roi_x1 + bin_size_w * pw;
T bin_y1 = roi_y1 + bin_size_h * ph;
T bin_x2 = bin_x1 + bin_size_w;
T bin_y2 = bin_y1 + bin_size_h;
T bin_size = max(T(0.0), bin_size_w * bin_size_h);
T sum_out = bin_size == T(0) ? T(0) : *this_out_grad / bin_size;
int start_x, start_y, end_x, end_y;
start_x = floorf(bin_x1);
end_x = ceilf(bin_x2);
start_y = floorf(bin_y1);
end_y = ceilf(bin_y2);
for (int bin_x = start_x; bin_x < end_x; ++bin_x)
for (int bin_y = start_y; bin_y < end_y; ++bin_y)
PrRoIPoolingMatDistributeDiff(
this_data_grad, sum_out, bin_y, bin_x, bin_y + 1, bin_x + 1,
max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)),
min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height,
width);
}
}
template <typename T>
__global__ void prroi_pool_coor_backward_musa_kernel(
const int nthreads, const T *output, const T *grad_output, const T *input,
const T *rois, T *grad_rois, const int pooled_height,
const int pooled_width, const T spatial_scale, const int channels,
const int height, const int width) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
auto rois_cur = rois + n * 5;
int roi_batch_ind = rois_cur[0];
T roi_x1 = rois_cur[1] * spatial_scale;
T roi_y1 = rois_cur[2] * spatial_scale;
T roi_x2 = rois_cur[3] * spatial_scale;
T roi_y2 = rois_cur[4] * spatial_scale;
T roi_width = max(roi_x2 - roi_x1, (T)0);
T roi_height = max(roi_y2 - roi_y1, (T)0);
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
const T output_grad_val = grad_output[index];
const T *this_input_data =
input + (roi_batch_ind * channels + c) * height * width;
const T output_val = output[index];
T *this_rois_grad = grad_rois + n * 5;
T bin_x1 = roi_x1 + bin_size_w * pw;
T bin_y1 = roi_y1 + bin_size_h * ph;
T bin_x2 = bin_x1 + bin_size_w;
T bin_y2 = bin_y1 + bin_size_h;
T bin_size = max(T(0.0), bin_size_w * bin_size_h);
T sum_out = bin_size == T(0) ? T(0) : output_grad_val / bin_size;
// WARNING: to be discussed
if (sum_out == 0) continue;
int start_x, start_y, end_x, end_y;
start_x = floorf(bin_x1);
end_x = ceilf(bin_x2);
start_y = floorf(bin_y1);
end_y = ceilf(bin_y2);
T grad_x1_y = 0, grad_x2_y = 0, grad_x_y1 = 0, grad_x_y2 = 0;
for (int bin_y = start_y; bin_y < end_y; ++bin_y) {
grad_x1_y += PrRoIPoolingSingleCoorIntegral(
max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y,
PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x1,
height, width),
PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x1,
height, width));
grad_x2_y += PrRoIPoolingSingleCoorIntegral(
max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y,
PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x2,
height, width),
PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x2,
height, width));
}
for (int bin_x = start_x; bin_x < end_x; ++bin_x) {
grad_x_y1 += PrRoIPoolingSingleCoorIntegral(
max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x,
PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x),
height, width),
PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x + 1),
height, width));
grad_x_y2 += PrRoIPoolingSingleCoorIntegral(
max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x,
PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x),
height, width),
PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x + 1),
height, width));
}
T partial_x1 = -grad_x1_y + (bin_y2 - bin_y1) * output_val;
T partial_y1 = -grad_x_y1 + (bin_x2 - bin_x1) * output_val;
T partial_x2 = grad_x2_y - (bin_y2 - bin_y1) * output_val;
T partial_y2 = grad_x_y2 - (bin_x2 - bin_x1) * output_val;
partial_x1 = partial_x1 / bin_size * spatial_scale;
partial_x2 = partial_x2 / bin_size * spatial_scale;
partial_y1 = partial_y1 / bin_size * spatial_scale;
partial_y2 = partial_y2 / bin_size * spatial_scale;
// (index, x1, y1, x2, y2)
this_rois_grad[0] = 0;
atomicAdd(this_rois_grad + 1,
(partial_x1 * (1.0f - T(pw) / pooled_width) +
partial_x2 * (1.0f - T(pw + 1) / pooled_width)) *
output_grad_val);
atomicAdd(this_rois_grad + 2,
(partial_y1 * (1.0f - T(ph) / pooled_height) +
partial_y2 * (1.0f - T(ph + 1) / pooled_height)) *
output_grad_val);
atomicAdd(this_rois_grad + 3, (partial_x2 * T(pw + 1) / pooled_width +
partial_x1 * T(pw) / pooled_width) *
output_grad_val);
atomicAdd(this_rois_grad + 4, (partial_y2 * T(ph + 1) / pooled_height +
partial_y1 * T(ph) / pooled_height) *
output_grad_val);
}
}
#endif  // PRROI_POOL_MUSA_KERNEL_MUH
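For context, a hypothetical host-side launch of the forward kernel (sketch only, not part of this diff); GET_BLOCKS / THREADS_PER_BLOCK are assumed helpers from pytorch_musa_helper.hpp and CUDA-style launch syntax is assumed for MUSA.

// Hypothetical launch wrapper for the forward pass; one thread handles one
// pooled output element (n, c, ph, pw).
void prroi_pool_forward_launch(const float *input, const float *rois,
                               float *output, int num_rois, int channels,
                               int height, int width, int pooled_height,
                               int pooled_width, float spatial_scale) {
  int nthreads = num_rois * channels * pooled_height * pooled_width;
  prroi_pool_forward_musa_kernel<float>
      <<<GET_BLOCKS(nthreads), THREADS_PER_BLOCK>>>(
          nthreads, input, rois, output, pooled_height, pooled_width,
          spatial_scale, channels, height, width);
}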


@@ -0,0 +1,137 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef PSAMASK_MUSA_KERNEL_MUH
#define PSAMASK_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
// MUSA: grid stride looping
#ifndef MUSA_KERNEL_LOOP
#define MUSA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
#endif
template <typename T>
__global__ void psamask_collect_forward_musa(
const int nthreads, const int h_feature, const int w_feature,
const int h_mask, const int w_mask, const int half_h_mask,
const int half_w_mask, const T* mask_data, T* buffer_data) {
MUSA_KERNEL_LOOP(index, nthreads) {
const int w = index % w_feature;
const int h = (index / w_feature) % h_feature;
const int n = index / w_feature / h_feature;
// effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed
const int hstart = max(0, half_h_mask - h);
const int hend = min(h_mask, h_feature + half_h_mask - h);
const int wstart = max(0, half_w_mask - w);
const int wend = min(w_mask, w_feature + half_w_mask - w);
// (hidx, widx ) with mask-indexed
// (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed
for (int hidx = hstart; hidx < hend; hidx++) {
for (int widx = wstart; widx < wend; widx++) {
buffer_data[(n * h_feature * w_feature +
(hidx + h - half_h_mask) * w_feature +
(widx + w - half_w_mask)) *
h_feature * w_feature +
h * w_feature + w] = mask_data
[((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature + h) *
w_feature +
w];
}
}
}
}
template <typename T>
__global__ void psamask_distribute_forward_musa(
const int nthreads, const int h_feature, const int w_feature,
const int h_mask, const int w_mask, const int half_h_mask,
const int half_w_mask, const T* mask_data, T* buffer_data) {
MUSA_KERNEL_LOOP(index, nthreads) {
const int w = index % w_feature;
const int h = (index / w_feature) % h_feature;
const int n = index / w_feature / h_feature;
// effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed
const int hstart = max(0, half_h_mask - h);
const int hend = min(h_mask, h_feature + half_h_mask - h);
const int wstart = max(0, half_w_mask - w);
const int wend = min(w_mask, w_feature + half_w_mask - w);
// (hidx, widx ) with mask-indexed
// (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed
for (int hidx = hstart; hidx < hend; hidx++) {
for (int widx = wstart; widx < wend; widx++) {
buffer_data[(n * h_feature * w_feature + h * w_feature + w) *
h_feature * w_feature +
(hidx + h - half_h_mask) * w_feature +
(widx + w - half_w_mask)] = mask_data
[((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature + h) *
w_feature +
w];
}
}
}
}
template <typename T>
__global__ void psamask_collect_backward_musa(
const int nthreads, const int h_feature, const int w_feature,
const int h_mask, const int w_mask, const int half_h_mask,
const int half_w_mask, const T* buffer_diff, T* mask_diff) {
MUSA_KERNEL_LOOP(index, nthreads) {
const int w = index % w_feature;
const int h = (index / w_feature) % h_feature;
const int n = index / w_feature / h_feature;
// effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed
const int hstart = max(0, half_h_mask - h);
const int hend = min(h_mask, h_feature + half_h_mask - h);
const int wstart = max(0, half_w_mask - w);
const int wend = min(w_mask, w_feature + half_w_mask - w);
// (hidx, widx ) with mask-indexed
// (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed
for (int hidx = hstart; hidx < hend; hidx++) {
for (int widx = wstart; widx < wend; widx++) {
mask_diff[((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature +
h) *
w_feature +
w] = buffer_diff[(n * h_feature * w_feature +
(hidx + h - half_h_mask) * w_feature +
(widx + w - half_w_mask)) *
h_feature * w_feature +
h * w_feature + w];
}
}
}
}
template <typename T>
__global__ void psamask_distribute_backward_musa(
const int nthreads, const int h_feature, const int w_feature,
const int h_mask, const int w_mask, const int half_h_mask,
const int half_w_mask, const T* buffer_diff, T* mask_diff) {
MUSA_KERNEL_LOOP(index, nthreads) {
const int w = index % w_feature;
const int h = (index / w_feature) % h_feature;
const int n = index / w_feature / h_feature;
// effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed
const int hstart = max(0, half_h_mask - h);
const int hend = min(h_mask, h_feature + half_h_mask - h);
const int wstart = max(0, half_w_mask - w);
const int wend = min(w_mask, w_feature + half_w_mask - w);
// (hidx, widx ) with mask-indexed
// (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed
for (int hidx = hstart; hidx < hend; hidx++) {
for (int widx = wstart; widx < wend; widx++) {
mask_diff[((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature +
h) *
w_feature +
w] =
buffer_diff[(n * h_feature * w_feature + h * w_feature + w) *
h_feature * w_feature +
(hidx + h - half_h_mask) * w_feature +
(widx + w - half_w_mask)];
}
}
}
}
#endif // PSAMASK_MUSA_KERNEL_MUH
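For context, a hypothetical host-side launch of the collect-forward kernel (sketch only, not part of this diff). It assumes half_h_mask / half_w_mask are (h_mask - 1) / 2 and (w_mask - 1) / 2, and that GET_BLOCKS / THREADS_PER_BLOCK are available helpers.

// Hypothetical launch wrapper: one thread per (n, h, w) feature location;
// each thread copies its valid mask window into the
// (h_feature * w_feature) x (h_feature * w_feature) buffer.
void psamask_collect_forward_launch(int num, int h_feature, int w_feature,
                                    int h_mask, int w_mask,
                                    const float *mask_data,
                                    float *buffer_data) {
  int nthreads = num * h_feature * w_feature;
  psamask_collect_forward_musa<float>
      <<<GET_BLOCKS(nthreads), THREADS_PER_BLOCK>>>(
          nthreads, h_feature, w_feature, h_mask, w_mask, (h_mask - 1) / 2,
          (w_mask - 1) / 2, mask_data, buffer_data);
}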


@@ -0,0 +1,238 @@
// Modified from
// https://github.com/csuhan/ReDet/blob/master/mmdet/ops/riroi_align/src/riroi_align_kernel.cu
#ifndef RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
#define RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
#include <float.h>
#include "pytorch_musa_helper.hpp"
/*** Forward ***/
template <typename scalar_t>
__global__ void riroi_align_rotated_forward_musa_kernel(
const int nthreads, const scalar_t *bottom_data,
const scalar_t *bottom_rois, const scalar_t spatial_scale,
const int num_samples, const bool clockwise, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int num_orientations, scalar_t *top_data) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int o = (index / pooled_width / pooled_height) % num_orientations;
int c =
(index / pooled_width / pooled_height / num_orientations) % channels;
int n = index / pooled_width / pooled_height / num_orientations / channels;
const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
int roi_batch_ind = offset_bottom_rois[0];
    // Do not use rounding; this implementation detail is critical
scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale;
scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale;
scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
// scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
scalar_t theta = offset_bottom_rois[5];
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (scalar_t)1.);
roi_height = max(roi_height, (scalar_t)1.);
scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
static_cast<scalar_t>(pooled_height);
scalar_t bin_size_w =
static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
// find aligned index
scalar_t ind_float = theta * num_orientations / (2 * M_PI);
int ind = floorf(ind_float);
scalar_t l_var = ind_float - (scalar_t)ind;
scalar_t r_var = 1.0 - l_var;
// correct start channel
ind = (ind + num_orientations) % num_orientations;
// rotated channel
int ind_rot = (o - ind + num_orientations) % num_orientations;
int ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations;
const scalar_t *offset_bottom_data =
bottom_data + (roi_batch_ind * channels * num_orientations +
c * num_orientations + ind_rot) *
height * width;
const scalar_t *offset_bottom_data_plus =
bottom_data + (roi_batch_ind * channels * num_orientations +
c * num_orientations + ind_rot_plus) *
height * width;
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (num_samples > 0)
? num_samples
: ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(num_samples > 0) ? num_samples : ceilf(roi_width / pooled_width);
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after.
if (clockwise) {
theta = -theta; // If clockwise, the angle needs to be reversed.
}
scalar_t roi_start_h = -roi_height / 2.0;
scalar_t roi_start_w = -roi_width / 2.0;
scalar_t cosscalar_theta = cos(theta);
scalar_t sinscalar_theta = sin(theta);
// We do average (integral) pooling inside a bin
const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
scalar_t output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
const scalar_t yy =
roi_start_h + ph * bin_size_h +
static_cast<scalar_t>(iy + .5f) * bin_size_h /
static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const scalar_t xx = roi_start_w + pw * bin_size_w +
static_cast<scalar_t>(ix + .5f) * bin_size_w /
static_cast<scalar_t>(roi_bin_grid_w);
// Rotate by theta (counterclockwise) around the center and translate
scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;
scalar_t val = bilinear_interpolate<scalar_t>(
offset_bottom_data, height, width, y, x, index);
scalar_t val_plus = bilinear_interpolate<scalar_t>(
offset_bottom_data_plus, height, width, y, x, index);
output_val += r_var * val + l_var * val_plus;
}
}
output_val /= count;
top_data[index] = output_val;
}
}
/*** Backward ***/
template <typename scalar_t>
__global__ void riroi_align_rotated_backward_musa_kernel(
const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
const scalar_t spatial_scale, const int num_samples, const bool clockwise,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, const int num_orientations,
scalar_t *bottom_diff) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int o = (index / pooled_width / pooled_height) % num_orientations;
int c =
(index / pooled_width / pooled_height / num_orientations) % channels;
int n = index / pooled_width / pooled_height / num_orientations / channels;
const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
int roi_batch_ind = offset_bottom_rois[0];
// Do not round
scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale;
scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale;
scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
// scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
scalar_t theta = offset_bottom_rois[5];
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (scalar_t)1.);
roi_height = max(roi_height, (scalar_t)1.);
scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
static_cast<scalar_t>(pooled_height);
scalar_t bin_size_w =
static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
// find aligned index
scalar_t ind_float = theta * num_orientations / (2 * M_PI);
int ind = floorf(ind_float);
scalar_t l_var = ind_float - (scalar_t)ind;
scalar_t r_var = 1.0 - l_var;
// correct start channel
ind = (ind + num_orientations) % num_orientations;
// rotated channel
int ind_rot = (o - ind + num_orientations) % num_orientations;
int ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations;
scalar_t *offset_bottom_diff =
bottom_diff + (roi_batch_ind * channels * num_orientations +
c * num_orientations + ind_rot) *
height * width;
scalar_t *offset_bottom_diff_plus =
bottom_diff + (roi_batch_ind * channels * num_orientations +
c * num_orientations + ind_rot_plus) *
height * width;
int top_offset =
(n * channels * num_orientations + c * num_orientations + o) *
pooled_height * pooled_width;
const scalar_t *offset_top_diff = top_diff + top_offset;
const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (num_samples > 0)
? num_samples
: ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(num_samples > 0) ? num_samples : ceilf(roi_width / pooled_width);
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after.
if (clockwise) {
theta = -theta; // If clockwise, the angle needs to be reversed.
}
scalar_t roi_start_h = -roi_height / 2.0;
scalar_t roi_start_w = -roi_width / 2.0;
scalar_t cosTheta = cos(theta);
scalar_t sinTheta = sin(theta);
// We do average (integral) pooling inside a bin
const scalar_t count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
const scalar_t yy =
roi_start_h + ph * bin_size_h +
static_cast<scalar_t>(iy + .5f) * bin_size_h /
static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const scalar_t xx = roi_start_w + pw * bin_size_w +
static_cast<scalar_t>(ix + .5f) * bin_size_w /
static_cast<scalar_t>(roi_bin_grid_w);
// Rotate by theta around the center and translate
scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;
scalar_t w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient<scalar_t>(height, width, y, x, w1, w2, w3,
w4, x_low, x_high, y_low,
y_high, index);
scalar_t g1 = top_diff_this_bin * w1 / count;
scalar_t g2 = top_diff_this_bin * w2 / count;
scalar_t g3 = top_diff_this_bin * w3 / count;
scalar_t g4 = top_diff_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(offset_bottom_diff + y_low * width + x_low, g1 * r_var);
atomicAdd(offset_bottom_diff + y_low * width + x_high, g2 * r_var);
atomicAdd(offset_bottom_diff + y_high * width + x_low, g3 * r_var);
atomicAdd(offset_bottom_diff + y_high * width + x_high, g4 * r_var);
atomicAdd(offset_bottom_diff_plus + y_low * width + x_low,
g1 * l_var);
atomicAdd(offset_bottom_diff_plus + y_low * width + x_high,
g2 * l_var);
atomicAdd(offset_bottom_diff_plus + y_high * width + x_low,
g3 * l_var);
atomicAdd(offset_bottom_diff_plus + y_high * width + x_high,
g4 * l_var);
} // if
} // ix
} // iy
} // MUSA_1D_KERNEL_LOOP
} // RiRoIAlignBackward
#endif // RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
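As a reading aid, the orientation-channel blending used above can be summarized by the small host-side helper below (illustration only, not part of this diff); the struct and function names are hypothetical and the index arithmetic mirrors the kernel.

#include <cmath>

// Mirrors the kernel: the RoI angle selects a fractional orientation index;
// channel ind_rot is weighted by r_var and channel ind_rot + 1 by l_var.
struct OrientBlend {
  int ind_rot, ind_rot_plus;  // source orientation channels
  float l_var, r_var;         // interpolation weights (l_var for ind_rot_plus)
};

inline OrientBlend orientation_blend(float theta, int o, int num_orientations) {
  const float two_pi = 6.283185307179586f;
  float ind_float = theta * num_orientations / two_pi;
  int ind = static_cast<int>(std::floor(ind_float));
  float l_var = ind_float - static_cast<float>(ind);
  // same normalization as the kernel
  ind = (ind + num_orientations) % num_orientations;
  OrientBlend b;
  b.ind_rot = (o - ind + num_orientations) % num_orientations;
  b.ind_rot_plus = (b.ind_rot + 1 + num_orientations) % num_orientations;
  b.l_var = l_var;
  b.r_var = 1.0f - l_var;
  return b;
}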


@@ -0,0 +1,205 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROI_ALIGN_MUSA_KERNEL_MUH
#define ROI_ALIGN_MUSA_KERNEL_MUH
#include <float.h>
#include "pytorch_musa_helper.hpp"
/*** Forward ***/
template <typename T>
__global__ void roi_align_forward_musa_kernel(
const int nthreads, const T* input, const T* rois, T* output, T* argmax_y,
T* argmax_x, const int pooled_height, const int pooled_width,
const T spatial_scale, const int sampling_ratio,
const int pool_mode, // 0 - max pool, 1 - avg pool
const bool aligned, const int channels, const int height, const int width) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
    // Do not use rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) { // for backward-compatibility only
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_w =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
if (pool_mode == 0) {
// We do max pooling inside a bin
T maxval = -FLT_MAX;
T maxidx_y = -1.f, maxidx_x = -1.f;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val =
bilinear_interpolate(offset_input, height, width, y, x, index);
if (val > maxval) {
maxval = val;
maxidx_y = y;
maxidx_x = x;
}
}
}
output[index] = maxval;
argmax_y[index] = maxidx_y;
argmax_x[index] = maxidx_x;
} else if (pool_mode == 1) {
// We do average pooling inside a bin
const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val =
bilinear_interpolate(offset_input, height, width, y, x, index);
output_val += val;
}
}
output[index] = output_val / count;
}
}
}
/*** Backward ***/
template <typename T>
__global__ void roi_align_backward_musa_kernel(
const int nthreads, const T* grad_output, const T* rois, const T* argmax_y,
const T* argmax_x, T* grad_input, const int pooled_height,
const int pooled_width, const T spatial_scale, const int sampling_ratio,
const int pool_mode, // 0 - max pool, 1 - avg pool
const bool aligned, const int channels, const int height, const int width) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T grad_output_this_bin = grad_output[index];
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
T* offset_grad_input =
grad_input + ((roi_batch_ind * channels + c) * height * width);
if (pool_mode == 0) {
T y = argmax_y[index], x = argmax_x[index];
if (y != -1.f) {
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
x_low, x_high, y_low, y_high, index);
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(offset_grad_input + y_low * width + x_low,
grad_output_this_bin * w1);
atomicAdd(offset_grad_input + y_low * width + x_high,
grad_output_this_bin * w2);
atomicAdd(offset_grad_input + y_high * width + x_low,
grad_output_this_bin * w3);
atomicAdd(offset_grad_input + y_high * width + x_high,
grad_output_this_bin * w4);
}
}
} else if (pool_mode == 1) {
      // Do not use rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) { // for backward-compatibility only
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_w =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
x_low, x_high, y_low, y_high, index);
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(offset_grad_input + y_low * width + x_low,
grad_output_this_bin * w1 / count);
atomicAdd(offset_grad_input + y_low * width + x_high,
grad_output_this_bin * w2 / count);
atomicAdd(offset_grad_input + y_high * width + x_low,
grad_output_this_bin * w3 / count);
atomicAdd(offset_grad_input + y_high * width + x_high,
grad_output_this_bin * w4 / count);
}
}
}
}
}
}
#endif // ROI_ALIGN_MUSA_KERNEL_MUH
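For context, a hypothetical host-side launch of the forward kernel (sketch only, not part of this diff). pool_mode 0 selects max pooling (argmax_y / argmax_x are written) and 1 selects average pooling; GET_BLOCKS / THREADS_PER_BLOCK are assumed helpers.

// Hypothetical launch wrapper; one thread per pooled output element.
void roi_align_forward_launch(const float *input, const float *rois,
                              float *output, float *argmax_y, float *argmax_x,
                              int num_rois, int channels, int height, int width,
                              int pooled_height, int pooled_width,
                              float spatial_scale, int sampling_ratio,
                              int pool_mode, bool aligned) {
  int nthreads = num_rois * channels * pooled_height * pooled_width;
  roi_align_forward_musa_kernel<float>
      <<<GET_BLOCKS(nthreads), THREADS_PER_BLOCK>>>(
          nthreads, input, rois, output, argmax_y, argmax_x, pooled_height,
          pooled_width, spatial_scale, sampling_ratio, pool_mode, aligned,
          channels, height, width);
}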


@@ -0,0 +1,194 @@
// Modified from
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlignRotated
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#ifndef ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
#define ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
#include <float.h>
#include "pytorch_musa_helper.hpp"
/*** Forward ***/
template <typename scalar_t>
__global__ void roi_align_rotated_forward_musa_kernel(
const int nthreads, const scalar_t *bottom_data,
const scalar_t *bottom_rois, const scalar_t spatial_scale,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, scalar_t *top_data) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
int roi_batch_ind = offset_bottom_rois[0];
    // Do not use rounding; this implementation detail is critical
scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
// scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
scalar_t theta = offset_bottom_rois[5];
if (clockwise) {
theta = -theta; // If clockwise, the angle needs to be reversed.
}
if (!aligned) { // for backward-compatibility only
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (scalar_t)1.);
roi_height = max(roi_height, (scalar_t)1.);
}
scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
static_cast<scalar_t>(pooled_height);
scalar_t bin_size_w =
static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
const scalar_t *offset_bottom_data =
bottom_data + (roi_batch_ind * channels + c) * height * width;
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after.
scalar_t roi_start_h = -roi_height / 2.0;
scalar_t roi_start_w = -roi_width / 2.0;
scalar_t cosscalar_theta = cos(theta);
scalar_t sinscalar_theta = sin(theta);
// We do average (integral) pooling inside a bin
const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
scalar_t output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
const scalar_t yy =
roi_start_h + ph * bin_size_h +
static_cast<scalar_t>(iy + .5f) * bin_size_h /
static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const scalar_t xx = roi_start_w + pw * bin_size_w +
static_cast<scalar_t>(ix + .5f) * bin_size_w /
static_cast<scalar_t>(roi_bin_grid_w);
// Rotate by theta (counterclockwise) around the center and translate
scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;
scalar_t val = bilinear_interpolate<scalar_t>(
offset_bottom_data, height, width, y, x, index);
output_val += val;
}
}
output_val /= count;
top_data[index] = output_val;
}
}
/*** Backward ***/
template <typename scalar_t>
__global__ void roi_align_rotated_backward_musa_kernel(
const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
const scalar_t spatial_scale, const int sampling_ratio, const bool aligned,
const bool clockwise, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, scalar_t *bottom_diff) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
int roi_batch_ind = offset_bottom_rois[0];
// Do not round
scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
// scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
scalar_t theta = offset_bottom_rois[5];
if (clockwise) {
theta = -theta; // If clockwise, the angle needs to be reversed.
}
if (!aligned) { // for backward-compatibility only
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (scalar_t)1.);
roi_height = max(roi_height, (scalar_t)1.);
}
scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
static_cast<scalar_t>(pooled_height);
scalar_t bin_size_w =
static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
scalar_t *offset_bottom_diff =
bottom_diff + (roi_batch_ind * channels + c) * height * width;
int top_offset = (n * channels + c) * pooled_height * pooled_width;
const scalar_t *offset_top_diff = top_diff + top_offset;
const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after.
scalar_t roi_start_h = -roi_height / 2.0;
scalar_t roi_start_w = -roi_width / 2.0;
scalar_t cosTheta = cos(theta);
scalar_t sinTheta = sin(theta);
// We do average (integral) pooling inside a bin
const scalar_t count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
const scalar_t yy =
roi_start_h + ph * bin_size_h +
static_cast<scalar_t>(iy + .5f) * bin_size_h /
static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const scalar_t xx = roi_start_w + pw * bin_size_w +
static_cast<scalar_t>(ix + .5f) * bin_size_w /
static_cast<scalar_t>(roi_bin_grid_w);
// Rotate by theta around the center and translate
scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;
scalar_t w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient<scalar_t>(height, width, y, x, w1, w2, w3,
w4, x_low, x_high, y_low,
y_high, index);
scalar_t g1 = top_diff_this_bin * w1 / count;
scalar_t g2 = top_diff_this_bin * w2 / count;
scalar_t g3 = top_diff_this_bin * w3 / count;
scalar_t g4 = top_diff_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
} // if
} // ix
} // iy
} // MUSA_1D_KERNEL_LOOP
} // RoIAlignBackward
#endif // ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
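As a reading aid, the coordinate transform shared by the forward and backward kernels above is summarized in the host-side sketch below (illustration only, not part of this diff; the function name is hypothetical).

#include <cmath>

// A bin-local sample (xx, yy), measured from the RoI center, is rotated by
// theta (counterclockwise) and translated to feature-map coordinates, exactly
// as in the kernels above.
inline void rotate_sample_to_feature_map(float xx, float yy, float theta,
                                         float roi_center_w, float roi_center_h,
                                         float &x, float &y) {
  float c = std::cos(theta), s = std::sin(theta);
  y = yy * c - xx * s + roi_center_h;
  x = yy * s + xx * c + roi_center_w;
}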


@@ -0,0 +1,89 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROI_POOL_MUSA_KERNEL_MUH
#define ROI_POOL_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
template <typename T>
__global__ void roi_pool_forward_musa_kernel(
const int nthreads, const T* input, const T* rois, T* output, int* argmax,
const int pooled_height, const int pooled_width, const T spatial_scale,
const int channels, const int height, const int width) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
// calculate the roi region on feature maps
T roi_x1 = offset_rois[1] * spatial_scale;
T roi_y1 = offset_rois[2] * spatial_scale;
T roi_x2 = (offset_rois[3] + 1) * spatial_scale;
T roi_y2 = (offset_rois[4] + 1) * spatial_scale;
    // skip malformed rois (non-positive width or height)
T roi_w = roi_x2 - roi_x1;
T roi_h = roi_y2 - roi_y1;
if (roi_w <= 0 || roi_h <= 0) continue;
T bin_size_w = roi_w / static_cast<T>(pooled_width);
T bin_size_h = roi_h / static_cast<T>(pooled_height);
// the corresponding bin region
int bin_x1 = floorf(static_cast<T>(pw) * bin_size_w + roi_x1);
int bin_y1 = floorf(static_cast<T>(ph) * bin_size_h + roi_y1);
int bin_x2 = ceilf(static_cast<T>(pw + 1) * bin_size_w + roi_x1);
int bin_y2 = ceilf(static_cast<T>(ph + 1) * bin_size_h + roi_y1);
// add roi offsets and clip to input boundaries
bin_x1 = min(max(bin_x1, 0), width);
bin_y1 = min(max(bin_y1, 0), height);
bin_x2 = min(max(bin_x2, 0), width);
bin_y2 = min(max(bin_y2, 0), height);
bool is_empty = (bin_y2 <= bin_y1) || (bin_x2 <= bin_x1);
const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
// Define an empty pooling region to be zero
// If nothing is pooled, argmax = -1 causes nothing to be backprop'd
T max_val = is_empty ? 0 : -FLT_MAX;
int max_idx = -1;
for (int h = bin_y1; h < bin_y2; ++h) {
for (int w = bin_x1; w < bin_x2; ++w) {
int offset = h * width + w;
if (offset_input[offset] > max_val) {
max_val = offset_input[offset];
max_idx = offset;
}
}
}
output[index] = max_val;
if (argmax != NULL) argmax[index] = max_idx;
}
}
template <typename T>
__global__ void roi_pool_backward_musa_kernel(
const int nthreads, const T* grad_output, const T* rois, const int* argmax,
T* grad_input, const int pooled_height, const int pooled_width,
const int channels, const int height, const int width) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c) is an element in the pooled output
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int roi_batch_ind = rois[n * 5];
T* grad_input_offset =
grad_input + ((roi_batch_ind * channels + c) * height * width);
int argmax_index = argmax[index];
if (argmax_index != -1) {
atomicAdd(grad_input_offset + argmax_index, grad_output[index]);
}
}
}
#endif // ROI_POOL_MUSA_KERNEL_MUH
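As a reading aid, the bin-boundary computation used by the forward kernel is summarized in the CPU sketch below (illustration only, not part of this diff; the helper name is hypothetical).

#include <algorithm>
#include <cmath>

// Mirrors the forward kernel: project the RoI to feature-map scale, take the
// (ph, pw) bin, and clip it to the feature map before max pooling over it.
inline void roi_pool_bin_bounds(const float rois[5], float spatial_scale,
                                int pooled_height, int pooled_width, int ph,
                                int pw, int height, int width, int &bin_x1,
                                int &bin_y1, int &bin_x2, int &bin_y2) {
  float roi_x1 = rois[1] * spatial_scale;
  float roi_y1 = rois[2] * spatial_scale;
  float roi_x2 = (rois[3] + 1) * spatial_scale;
  float roi_y2 = (rois[4] + 1) * spatial_scale;
  float bin_w = (roi_x2 - roi_x1) / pooled_width;
  float bin_h = (roi_y2 - roi_y1) / pooled_height;
  bin_x1 = std::min(std::max((int)std::floor(pw * bin_w + roi_x1), 0), width);
  bin_y1 = std::min(std::max((int)std::floor(ph * bin_h + roi_y1), 0), height);
  bin_x2 = std::min(std::max((int)std::ceil((pw + 1) * bin_w + roi_x1), 0), width);
  bin_y2 = std::min(std::max((int)std::ceil((ph + 1) * bin_h + roi_y1), 0), height);
}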


@@ -0,0 +1,256 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROIAWARE_POOL3D_MUSA_KERNEL_MUH
#define ROIAWARE_POOL3D_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
T &local_x, T &local_y) {
T cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}
template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
T &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
  // cz is the bottom center
T x = pt[0], y = pt[1], z = pt[2];
T cx = box3d[0], cy = box3d[1], cz = box3d[2];
T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
cz += z_size /
2.0; // shift to the center since cz in box3d is the bottom center
if (fabsf(z - cz) > z_size / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
(local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
return in_flag;
}
template <typename T>
__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num,
int out_x, int out_y, int out_z,
const T *rois, const T *pts,
int *pts_mask) {
  // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  // coordinate
  // params pts: (npoints, 3) [x, y, z]
  // params pts_mask: (N, npoints): -1 means the point is not in this box;
  // otherwise the (x_idx, y_idx, z_idx) voxel indices are packed into bit
  // fields
int box_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
if (box_idx >= boxes_num) return;
pts += pt_idx * 3;
rois += box_idx * 7;
pts_mask += box_idx * pts_num + pt_idx;
T local_x = 0, local_y = 0;
int cur_in_flag = check_pt_in_box3d(pts, rois, local_x, local_y);
pts_mask[0] = -1;
if (cur_in_flag > 0) {
T local_z = pts[2] - rois[2];
T x_size = rois[3], y_size = rois[4], z_size = rois[5];
T x_res = x_size / out_x;
T y_res = y_size / out_y;
T z_res = z_size / out_z;
unsigned int x_idx = int((local_x + x_size / 2) / x_res);
unsigned int y_idx = int((local_y + y_size / 2) / y_res);
unsigned int z_idx = int(local_z / z_res);
x_idx = min(max(x_idx, 0), out_x - 1);
y_idx = min(max(y_idx, 0), out_y - 1);
z_idx = min(max(z_idx, 0), out_z - 1);
unsigned int idx_encoding = (x_idx << 16) + (y_idx << 8) + z_idx;
pts_mask[0] = idx_encoding;
}
}
}
template <typename T>
__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
int max_pts_each_voxel, int out_x,
int out_y, int out_z,
const int *pts_mask,
T *pts_idx_of_voxels) {
  // params pts_mask: (N, npoints): -1 or a packed (x_idx, y_idx, z_idx)
  // voxel index
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
MUSA_1D_KERNEL_LOOP(box_idx, boxes_num) {
int max_num_pts = max_pts_each_voxel - 1; // index 0 is the counter
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel;
for (int k = 0; k < pts_num; k++) {
if (pts_mask[box_idx * pts_num + k] != -1) {
unsigned int idx_encoding = pts_mask[box_idx * pts_num + k];
unsigned int x_idx = (idx_encoding >> 16) & 0xFF;
unsigned int y_idx = (idx_encoding >> 8) & 0xFF;
unsigned int z_idx = idx_encoding & 0xFF;
unsigned int base_offset = x_idx * out_y * out_z * max_pts_each_voxel +
y_idx * out_z * max_pts_each_voxel +
z_idx * max_pts_each_voxel;
unsigned int cnt = pts_idx_of_voxels[base_offset];
if (cnt < max_num_pts) {
pts_idx_of_voxels[base_offset + cnt + 1] = k;
pts_idx_of_voxels[base_offset]++;
}
}
}
}
}
template <typename T>
__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const T *pts_feature,
const int *pts_idx_of_voxels,
T *pooled_features, int *argmax) {
// params pts_feature: (npoints, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
  // index 0 is the counter
  // params pooled_features: (N, out_x, out_y, out_z, C)
  // params argmax: (N, out_x, out_y, out_z, C)
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels) return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
offset_base * max_pts_each_voxel;
pooled_features += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
argmax += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
int argmax_idx = -1;
float max_val = -1e50;
int total_pts = pts_idx_of_voxels[0];
for (int k = 1; k <= total_pts; k++) {
if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] >
max_val) {
max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
argmax_idx = pts_idx_of_voxels[k];
}
}
if (argmax_idx != -1) {
pooled_features[0] = max_val;
}
argmax[0] = argmax_idx;
}
}
template <typename T>
__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const T *pts_feature,
const int *pts_idx_of_voxels,
T *pooled_features) {
// params pts_feature: (npoints, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
  // index 0 is the counter
  // params pooled_features: (N, out_x, out_y, out_z, C)
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels) return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
offset_base * max_pts_each_voxel;
pooled_features += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
float sum_val = 0;
int total_pts = pts_idx_of_voxels[0];
for (int k = 1; k <= total_pts; k++) {
sum_val += pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
}
if (total_pts > 0) {
pooled_features[0] = sum_val / total_pts;
}
}
}
template <typename T>
__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
int out_x, int out_y, int out_z,
const int *argmax,
const T *grad_out, T *grad_in) {
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels) return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
argmax += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
grad_out += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
if (argmax[0] == -1) return;
atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1);
}
}
template <typename T>
__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
int out_x, int out_y, int out_z,
int max_pts_each_voxel,
const int *pts_idx_of_voxels,
const T *grad_out, T *grad_in) {
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels) return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
offset_base * max_pts_each_voxel;
grad_out += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
int total_pts = pts_idx_of_voxels[0];
float cur_grad = 1 / fmaxf(float(total_pts), 1.0);
for (int k = 1; k <= total_pts; k++) {
atomicAdd(grad_in + pts_idx_of_voxels[k] * channels + channel_idx,
grad_out[0] * cur_grad);
}
}
}
#endif // ROIAWARE_POOL3D_MUSA_KERNEL_MUH
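For reference, a minimal host-side sketch of the flat voxel index decomposition the pooling kernels above rely on (illustrative only, not part of the patch; the helper name decompose_voxel_idx is hypothetical). It assumes row-major (x, y, z) ordering over an out_x x out_y x out_z grid, which matches the offset arithmetic in the kernels.

#include <cassert>

// Decompose a flat voxel index into (x, y, z), mirroring the kernels above.
inline void decompose_voxel_idx(int voxel_idx_flat, int out_y, int out_z,
                                int* x_idx, int* y_idx, int* z_idx) {
  *x_idx = voxel_idx_flat / (out_y * out_z);
  *y_idx = (voxel_idx_flat - *x_idx * (out_y * out_z)) / out_z;
  *z_idx = voxel_idx_flat % out_z;
}

int main() {
  int x, y, z;
  // With out_x = 4, out_y = 3, out_z = 2, flat index 17 maps to (2, 2, 1),
  // and 2 * (3 * 2) + 2 * 2 + 1 = 17 recovers the flat offset.
  decompose_voxel_idx(17, 3, 2, &x, &y, &z);
  assert(x == 2 && y == 2 && z == 1);
  return 0;
}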


@ -0,0 +1,130 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROIPOINT_POOL3D_MUSA_KERNEL_MUH
#define ROIPOINT_POOL3D_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
T &local_x, T &local_y) {
T cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}
template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
T &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, dx, dy, dz, rz) in LiDAR coordinate, cz in the
// bottom center
T x = pt[0], y = pt[1], z = pt[2];
T cx = box3d[0], cy = box3d[1], cz = box3d[2];
T dx = box3d[3], dy = box3d[4], dz = box3d[5], rz = box3d[6];
cz += dz / 2.0; // shift to the center since cz in box3d is the bottom center
if (fabsf(z - cz) > dz / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
T in_flag = (local_x > -dx / 2.0) & (local_x < dx / 2.0) &
(local_y > -dy / 2.0) & (local_y < dy / 2.0);
return in_flag;
}
template <typename T>
__global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num,
const T *xyz, const T *boxes3d,
int *pts_assign) {
// params xyz: (B, N, 3)
// params boxes3d: (B, M, 7)
// params pts_assign: (B, N, M): idx of the corresponding box3d, -1 means
// background points
int box_idx = blockIdx.y;
int bs_idx = blockIdx.z;
MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
if (box_idx >= boxes_num || bs_idx >= batch_size) return;
int assign_idx =
bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx;
pts_assign[assign_idx] = 0;
int box_offset = bs_idx * boxes_num * 7 + box_idx * 7;
int pt_offset = bs_idx * pts_num * 3 + pt_idx * 3;
T local_x = 0, local_y = 0;
int cur_in_flag = check_pt_in_box3d(xyz + pt_offset, boxes3d + box_offset,
local_x, local_y);
pts_assign[assign_idx] = cur_in_flag;
}
}
__global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num,
int sampled_pts_num, const int *pts_assign,
int *pts_idx, int *pooled_empty_flag) {
// params pts_assign: (B, N, M)
// params pts_idx: (B, M, 512)
// params pooled_empty_flag: (B, M)
MUSA_1D_KERNEL_LOOP(boxes_idx, boxes_num) {
int bs_idx = blockIdx.y;
int cnt = 0;
for (int k = 0; k < pts_num; k++) {
if (pts_assign[bs_idx * pts_num * boxes_num + k * boxes_num +
boxes_idx]) {
if (cnt < sampled_pts_num) {
pts_idx[bs_idx * boxes_num * sampled_pts_num +
boxes_idx * sampled_pts_num + cnt] = k;
cnt++;
} else
break;
}
}
if (cnt == 0) {
pooled_empty_flag[bs_idx * boxes_num + boxes_idx] = 1;
} else if (cnt < sampled_pts_num) {
// duplicate same points for sampling
for (int k = cnt; k < sampled_pts_num; k++) {
int duplicate_idx = k % cnt;
int base_offset =
bs_idx * boxes_num * sampled_pts_num + boxes_idx * sampled_pts_num;
pts_idx[base_offset + k] = pts_idx[base_offset + duplicate_idx];
}
}
}
}
template <typename T>
__global__ void roipoint_pool3d_forward(
int batch_size, int pts_num, int boxes_num, int feature_in_len,
int sampled_pts_num, const T *xyz, const int *pts_idx, const T *pts_feature,
T *pooled_features, int *pooled_empty_flag) {
// params xyz: (B, N, 3)
// params pts_idx: (B, M, 512)
// params pts_feature: (B, N, C)
// params pooled_features: (B, M, 512, 3+C)
// params pooled_empty_flag: (B, M)
int box_idx = blockIdx.y;
int bs_idx = blockIdx.z;
MUSA_1D_KERNEL_LOOP(sample_pt_idx, sampled_pts_num) {
if (box_idx >= boxes_num || bs_idx >= batch_size) return;
if (pooled_empty_flag[bs_idx * boxes_num + box_idx]) return;
int temp_idx = bs_idx * boxes_num * sampled_pts_num +
box_idx * sampled_pts_num + sample_pt_idx;
int src_pt_idx = pts_idx[temp_idx];
int dst_feature_offset = temp_idx * (3 + feature_in_len);
for (int j = 0; j < 3; j++)
pooled_features[dst_feature_offset + j] =
xyz[bs_idx * pts_num * 3 + src_pt_idx * 3 + j];
int src_feature_offset =
bs_idx * pts_num * feature_in_len + src_pt_idx * feature_in_len;
memcpy(pooled_features + dst_feature_offset + 3,
pts_feature + src_feature_offset, feature_in_len * sizeof(T));
}
}
#endif // ROIPOINT_POOL3D_MUSA_KERNEL_MUH
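The duplication step in get_pooled_idx pads a box that holds fewer than sampled_pts_num points by cycling through the indices already collected (k % cnt). A CPU sketch of that scheme follows (illustrative only; pad_by_cycling is a hypothetical name, and cnt > 0 is assumed, as guarded by the kernel).

#include <cstdio>
#include <vector>

// Fill a fixed-size sample by cycling through the indices already gathered,
// mirroring the duplication loop in get_pooled_idx above.
std::vector<int> pad_by_cycling(std::vector<int> idx, int sampled_pts_num) {
  int cnt = static_cast<int>(idx.size());  // assumed > 0
  idx.resize(sampled_pts_num);
  for (int k = cnt; k < sampled_pts_num; k++) idx[k] = idx[k % cnt];
  return idx;
}

int main() {
  // Three points found in the box, eight samples requested:
  // the result repeats {7, 11, 42} cyclically.
  for (int v : pad_by_cycling({7, 11, 42}, 8)) std::printf("%d ", v);
  std::printf("\n");  // prints: 7 11 42 7 11 42 7 11
  return 0;
}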


@ -0,0 +1,125 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu
#ifndef ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH
#define ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
template <typename scalar_t>
__global__ void rotated_feature_align_forward_kernel(
const int nthreads, const int points, const scalar_t* bottom_data,
const scalar_t* best_bboxes, const scalar_t spatial_scale,
const int channels, const int height, const int width, scalar_t* top_data) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
const scalar_t* bbox_offset =
best_bboxes + ((n * height + h) * width + w) * 5;
scalar_t roi_y = bbox_offset[0] * spatial_scale;
scalar_t roi_x = bbox_offset[1] * spatial_scale;
scalar_t px[5] = {roi_x, 0, 0, 0, 0};
scalar_t py[5] = {roi_y, 0, 0, 0, 0};
if (points > 1) {
scalar_t roi_w = bbox_offset[2] * spatial_scale;
scalar_t roi_h = bbox_offset[3] * spatial_scale;
scalar_t roi_a = bbox_offset[4];
scalar_t w_2 = roi_w / 2, h_2 = roi_h / 2;
scalar_t cosa = cosf(roi_a), sina = sinf(roi_a);
scalar_t wx = cosa * w_2, wy = sina * w_2;
scalar_t hx = -sina * h_2, hy = cosa * h_2;
px[1] = roi_x + wx + hx;
py[1] = roi_y + wy + hy;
px[2] = roi_x - wx + hx;
py[2] = roi_y - wy + hy;
px[3] = roi_x - wx - hx;
py[3] = roi_y - wy - hy;
px[4] = roi_x + wx - hx;
py[4] = roi_y + wy - hy;
}
const scalar_t* offset_bottom_data =
bottom_data + (n * channels + c) * height * width;
scalar_t output_val = bottom_data[index];
for (int i = 0; i < points; i++) {
output_val += bilinear_interpolate<scalar_t>(offset_bottom_data, height,
width, py[i], px[i], i);
}
top_data[index] = output_val;
}
}
template <typename scalar_t>
__global__ void rotated_feature_align_backward_kernel(
const int nthreads, const int points, const scalar_t* top_diff,
const scalar_t* best_bboxes, const scalar_t spatial_scale,
const int channels, const int height, const int width,
scalar_t* bottom_diff) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
const scalar_t* bbox_offset =
best_bboxes + ((n * height + h) * width + w) * 5;
scalar_t roi_y = bbox_offset[0] * spatial_scale;
scalar_t roi_x = bbox_offset[1] * spatial_scale;
scalar_t px[5] = {roi_x, 0, 0, 0, 0};
scalar_t py[5] = {roi_y, 0, 0, 0, 0};
if (points > 1) {
scalar_t roi_w = bbox_offset[2] * spatial_scale;
scalar_t roi_h = bbox_offset[3] * spatial_scale;
scalar_t roi_a = bbox_offset[4];
scalar_t w_2 = roi_w / 2, h_2 = roi_h / 2;
scalar_t cosa = cosf(roi_a), sina = sinf(roi_a);
scalar_t wx = cosa * w_2, wy = sina * w_2;
scalar_t hx = -sina * h_2, hy = cosa * h_2;
px[1] = roi_x + wx + hx;
py[1] = roi_y + wy + hy;
px[2] = roi_x - wx + hx;
py[2] = roi_y - wy + hy;
px[3] = roi_x - wx - hx;
py[3] = roi_y - wy - hy;
px[4] = roi_x + wx - hx;
py[4] = roi_y + wy - hy;
}
scalar_t* offset_bottom_diff =
bottom_diff + (n * channels + c) * height * width;
scalar_t value_top_diff = top_diff[index];
atomicAdd(bottom_diff + index, value_top_diff);
for (int i = 0; i < points; i++) {
scalar_t w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient<scalar_t>(height, width, py[i], px[i], w1,
w2, w3, w4, x_low, x_high, y_low,
y_high, i);
scalar_t g1 = value_top_diff * w1;
scalar_t g2 = value_top_diff * w2;
scalar_t g3 = value_top_diff * w3;
scalar_t g4 = value_top_diff * w4;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
}
}
}
}
#endif // ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH
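The feature-refine kernels above sample at five locations per position: the RoI centre plus its four corners rotated by the box angle. A host-side sketch of that construction (illustrative only; rotated_sample_points is a hypothetical helper built from the same wx/wy/hx/hy algebra as the kernels):

#include <cmath>
#include <cstdio>

// The RoI centre plus its four rotated corners (w_2, h_2 are half extents).
void rotated_sample_points(float cx, float cy, float w, float h, float a,
                           float px[5], float py[5]) {
  float w_2 = w / 2, h_2 = h / 2;
  float cosa = std::cos(a), sina = std::sin(a);
  float wx = cosa * w_2, wy = sina * w_2;
  float hx = -sina * h_2, hy = cosa * h_2;
  px[0] = cx;           py[0] = cy;
  px[1] = cx + wx + hx; py[1] = cy + wy + hy;
  px[2] = cx - wx + hx; py[2] = cy - wy + hy;
  px[3] = cx - wx - hx; py[3] = cy - wy - hy;
  px[4] = cx + wx - hx; py[4] = cy + wy - hy;
}

int main() {
  float px[5], py[5];
  rotated_sample_points(10.f, 5.f, 4.f, 2.f, 0.f, px, py);
  // With zero rotation the corners are the axis-aligned box corners.
  for (int i = 0; i < 5; i++) std::printf("(%.1f, %.1f) ", px[i], py[i]);
  std::printf("\n");  // (10.0, 5.0) (12.0, 6.0) (8.0, 6.0) (8.0, 4.0) (12.0, 4.0)
  return 0;
}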


@ -0,0 +1,137 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef SCATTER_POINTS_MUSA_KERNEL_MUH
#define SCATTER_POINTS_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
int const maxGridDim = 50000;
__device__ __forceinline__ static void reduceMax(float *address, float val) {
int *address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;
do {
assumed = old;
old = atomicCAS(address_as_i, assumed,
__float_as_int(fmaxf(val, __int_as_float(assumed))));
} while (assumed != old || __int_as_float(old) < val);
}
__device__ __forceinline__ static void reduceMax(double *address, double val) {
unsigned long long *address_as_ull =
reinterpret_cast<unsigned long long *>(address);
unsigned long long old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(
address_as_ull, assumed,
__double_as_longlong(fmax(val, __longlong_as_double(assumed))));
} while (assumed != old || __longlong_as_double(old) < val);
}
__device__ __forceinline__ static void reduceAdd(float *address, float val) {
atomicAdd(address, val);
}
__device__ __forceinline__ static void reduceAdd(double *address, double val) {
atomicAdd(address, val);
}
template <typename T>
__global__ void feats_reduce_kernel(
const T *feats, const int32_t *coors_map,
T *reduced_feats, // shall be 0 at initialization
const int num_input, const int num_feats, const reduce_t reduce_type) {
MUSA_1D_KERNEL_LOOP(x, num_input) {
int32_t reduce_to = coors_map[x];
if (reduce_to == -1) continue;
const T *feats_offset = feats + x * num_feats;
T *reduced_feats_offset = reduced_feats + reduce_to * num_feats;
if (reduce_type == reduce_t::MAX) {
for (int i = 0; i < num_feats; i++) {
reduceMax(&reduced_feats_offset[i], feats_offset[i]);
}
} else {
for (int i = 0; i < num_feats; i++) {
reduceAdd(&reduced_feats_offset[i], feats_offset[i]);
}
}
}
}
template <typename T>
__global__ void add_reduce_traceback_grad_kernel(
T *grad_feats, const T *grad_reduced_feats, const int32_t *coors_map,
const int32_t *reduce_count, const int num_input, const int num_feats,
const reduce_t reduce_type) {
MUSA_1D_KERNEL_LOOP(x, num_input) {
int32_t reduce_to = coors_map[x];
if (reduce_to == -1) {
continue;
}
const int input_offset = x * num_feats;
T *grad_feats_offset = grad_feats + input_offset;
const int reduced_offset = reduce_to * num_feats;
const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset;
if (reduce_type == reduce_t::SUM) {
for (int i = 0; i < num_feats; i++) {
grad_feats_offset[i] = grad_reduced_feats_offset[i];
}
} else if (reduce_type == reduce_t::MEAN) {
for (int i = 0; i < num_feats; i++) {
grad_feats_offset[i] = grad_reduced_feats_offset[i] /
static_cast<T>(reduce_count[reduce_to]);
}
}
}
}
template <typename T>
__global__ void max_reduce_traceback_scatter_idx_kernel(
const T *feats, const T *reduced_feats, int32_t *reduce_from,
const int32_t *coors_map, const int num_input, const int num_feats) {
MUSA_1D_KERNEL_LOOP(x, num_input) {
int32_t reduce_to = coors_map[x];
const int input_offset = x * num_feats;
const T *feats_offset = feats + input_offset;
if (reduce_to == -1) {
continue;
}
const int reduced_offset = reduce_to * num_feats;
const T *reduced_feats_offset = reduced_feats + reduced_offset;
int32_t *reduce_from_offset = reduce_from + reduced_offset;
for (int i = 0; i < num_feats; i++) {
if (feats_offset[i] == reduced_feats_offset[i]) {
atomicMin(&reduce_from_offset[i], static_cast<int32_t>(x));
}
}
}
}
template <typename T>
__global__ void max_reduce_scatter_grad_kernel(T *grad_feats,
const T *grad_reduced_feats,
const int32_t *reduce_from,
const int num_reduced,
const int num_feats) {
MUSA_1D_KERNEL_LOOP(x, num_reduced) {
const int reduced_offset = x * num_feats;
const int32_t *scatter_to_offset = reduce_from + reduced_offset;
const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset;
for (int i = 0; i < num_feats; i++) {
grad_feats[scatter_to_offset[i] * num_feats + i] =
grad_reduced_feats_offset[i];
}
}
}
#endif // SCATTER_POINTS_MUSA_KERNEL_MUH
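The reduceMax helpers above emulate a floating-point atomic max with an atomicCAS loop, presumably because a native float/double atomicMax is unavailable (as in CUDA). A single-thread sketch of the value the loop converges to (illustrative only; reduce_max_reference is a hypothetical name):

#include <algorithm>
#include <cstdio>

// What reduceMax(address, val) computes once the CAS loop settles:
// *address becomes max(*address, val). The bit-level CAS dance in the kernel
// header only exists to make this update atomic across threads.
void reduce_max_reference(float* address, float val) {
  *address = std::max(*address, val);
}

int main() {
  float acc = 1.5f;
  reduce_max_reference(&acc, 0.75f);  // acc stays 1.5
  reduce_max_reference(&acc, 2.25f);  // acc becomes 2.25
  std::printf("%.2f\n", acc);
  return 0;
}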


@ -0,0 +1,327 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef SYNCBN_MUSA_KERNEL_MUH
#define SYNCBN_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
template <typename T>
__global__ void sync_bn_forward_mean_musa_kernel(const T *input, float *mean,
int num, int channels,
int spatial) {
__shared__ float buffer[THREADS_PER_BLOCK];
int tid = threadIdx.x;
int c = blockIdx.x;
buffer[tid] = 0;
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
buffer[tid] += input[index];
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
buffer[tid] += buffer[tid + s];
}
__syncthreads();
}
int total = num * spatial;
if (tid == 0) {
mean[c] = buffer[0] / total;
}
}
template <>
__global__ void sync_bn_forward_mean_musa_kernel(const phalf *input,
float *mean, int num,
int channels, int spatial) {
__shared__ float buffer[THREADS_PER_BLOCK];
int tid = threadIdx.x;
int c = blockIdx.x;
buffer[tid] = 0;
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
buffer[tid] += static_cast<float>(input[index]);
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
buffer[tid] += buffer[tid + s];
}
__syncthreads();
}
int total = num * spatial;
if (tid == 0) {
mean[c] = buffer[0] / total;
}
}
template <typename T>
__global__ void sync_bn_forward_var_musa_kernel(const T *input,
const float *mean, float *var,
int num, int channels,
int spatial) {
__shared__ float buffer[THREADS_PER_BLOCK];
int tid = threadIdx.x;
int c = blockIdx.x;
buffer[tid] = 0;
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
float td = input[index] - mean[c];
buffer[tid] += td * td;
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
buffer[tid] += buffer[tid + s];
}
__syncthreads();
}
int total = num * spatial;
if (tid == 0) {
var[c] = buffer[0] / total;
}
}
template <>
__global__ void sync_bn_forward_var_musa_kernel(const phalf *input,
const float *mean, float *var,
int num, int channels,
int spatial) {
__shared__ float buffer[THREADS_PER_BLOCK];
int tid = threadIdx.x;
int c = blockIdx.x;
buffer[tid] = 0;
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
float td = static_cast<float>(input[index]) - mean[c];
buffer[tid] += td * td;
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
buffer[tid] += buffer[tid + s];
}
__syncthreads();
}
int total = num * spatial;
if (tid == 0) {
var[c] = buffer[0] / total;
}
}
template <typename T>
__global__ void sync_bn_forward_output_musa_kernel(
const T *input, const float *mean, const float *var, float *running_mean,
float *running_var, const float *weight, const float *bias, float *norm,
float *std, T *output, int num, int channels, int spatial, float eps,
float momentum, int group_size) {
int tid = threadIdx.x;
int c = blockIdx.x;
float mean_value = mean[c];
float std_value = sqrt(var[c] + eps);
if (weight != nullptr) {
float weight_value = weight[c];
float bias_value = bias[c];
if (norm != nullptr) {
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index =
(i / spatial) * channels * spatial + c * spatial + i % spatial;
norm[index] = (input[index] - mean_value) / std_value;
output[index] = norm[index] * weight_value + bias_value;
}
} else {
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index =
(i / spatial) * channels * spatial + c * spatial + i % spatial;
output[index] =
(input[index] - mean_value) / std_value * weight_value + bias_value;
}
}
} else {
if (norm != nullptr) {
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index =
(i / spatial) * channels * spatial + c * spatial + i % spatial;
output[index] = norm[index] = (input[index] - mean_value) / std_value;
}
} else {
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index =
(i / spatial) * channels * spatial + c * spatial + i % spatial;
output[index] = (input[index] - mean_value) / std_value;
}
}
}
if (tid == 0) {
if (std != nullptr) std[c] = std_value;
if (running_mean != nullptr) {
running_mean[c] =
momentum * mean_value + (1 - momentum) * running_mean[c];
int count = num * spatial * group_size;
float var_unbias = count > 1 ? var[c] * count / (count - 1) : var[c];
running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c];
}
}
}
template <>
__global__ void sync_bn_forward_output_musa_kernel(
const phalf *input, const float *mean, const float *var,
float *running_mean, float *running_var, const float *weight,
const float *bias, float *norm, float *std, phalf *output, int num,
int channels, int spatial, float eps, float momentum, int group_size) {
int tid = threadIdx.x;
int c = blockIdx.x;
float mean_value = mean[c];
float std_value = sqrt(var[c] + eps);
if (weight != nullptr) {
float weight_value = weight[c];
float bias_value = bias[c];
if (norm != nullptr) {
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index =
(i / spatial) * channels * spatial + c * spatial + i % spatial;
norm[index] =
(static_cast<float>(input[index]) - mean_value) / std_value;
output[index] =
static_cast<phalf>(norm[index] * weight_value + bias_value);
}
} else {
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index =
(i / spatial) * channels * spatial + c * spatial + i % spatial;
output[index] =
static_cast<phalf>((static_cast<float>(input[index]) - mean_value) /
std_value * weight_value +
bias_value);
}
}
} else {
if (norm != nullptr) {
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index =
(i / spatial) * channels * spatial + c * spatial + i % spatial;
norm[index] =
(static_cast<float>(input[index]) - mean_value) / std_value;
output[index] = static_cast<phalf>(norm[index]);
}
} else {
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index =
(i / spatial) * channels * spatial + c * spatial + i % spatial;
output[index] = static_cast<phalf>(
(static_cast<float>(input[index]) - mean_value) / std_value);
}
}
}
if (tid == 0) {
if (std != nullptr) std[c] = std_value;
if (running_mean != nullptr) {
running_mean[c] =
momentum * mean_value + (1 - momentum) * running_mean[c];
int count = num * spatial * group_size;
float var_unbias = count > 1 ? var[c] * count / (count - 1) : var[c];
running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c];
}
}
}
template <typename T>
__global__ void sync_bn_backward_param_musa_kernel(const T *grad_output,
const float *norm,
float *grad_weight,
float *grad_bias, int num,
int channels, int spatial) {
__shared__ float buffer1[THREADS_PER_BLOCK];
__shared__ float buffer2[THREADS_PER_BLOCK];
int tid = threadIdx.x;
int c = blockIdx.x;
buffer1[tid] = buffer2[tid] = 0;
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
buffer1[tid] += grad_output[index] * norm[index];
buffer2[tid] += grad_output[index];
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
buffer1[tid] += buffer1[tid + s];
buffer2[tid] += buffer2[tid + s];
}
__syncthreads();
}
if (tid == 0) {
grad_weight[c] = buffer1[0];
grad_bias[c] = buffer2[0];
}
}
template <>
__global__ void sync_bn_backward_param_musa_kernel(const phalf *grad_output,
const float *norm,
float *grad_weight,
float *grad_bias, int num,
int channels, int spatial) {
__shared__ float buffer1[THREADS_PER_BLOCK];
__shared__ float buffer2[THREADS_PER_BLOCK];
int tid = threadIdx.x;
int c = blockIdx.x;
buffer1[tid] = buffer2[tid] = 0;
for (int i = tid; i < num * spatial; i += blockDim.x) {
int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
buffer1[tid] += static_cast<float>(grad_output[index]) * norm[index];
buffer2[tid] += static_cast<float>(grad_output[index]);
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
buffer1[tid] += buffer1[tid + s];
buffer2[tid] += buffer2[tid + s];
}
__syncthreads();
}
if (tid == 0) {
grad_weight[c] = buffer1[0];
grad_bias[c] = buffer2[0];
}
}
template <typename T>
__global__ void sync_bn_backward_data_musa_kernel(
int output_size, const T *grad_output, const float *weight,
const float *grad_weight, const float *grad_bias, const float *norm,
const float *std, T *grad_input, int num, int channels, int spatial) {
int factor = num * spatial;
MUSA_1D_KERNEL_LOOP(index, output_size) {
int c = (index / spatial) % channels;
grad_input[index] =
weight[c] *
(grad_output[index] -
(grad_weight[c] * norm[index] + grad_bias[c]) / factor) /
std[c];
}
}
template <>
__global__ void sync_bn_backward_data_musa_kernel(
int output_size, const phalf *grad_output, const float *weight,
const float *grad_weight, const float *grad_bias, const float *norm,
const float *std, phalf *grad_input, int num, int channels, int spatial) {
int factor = num * spatial;
MUSA_1D_KERNEL_LOOP(index, output_size) {
int c = (index / spatial) % channels;
grad_input[index] = static_cast<phalf>(
weight[c] *
(static_cast<float>(grad_output[index]) -
(grad_weight[c] * norm[index] + grad_bias[c]) / factor) /
std[c]);
}
}
#endif // SYNCBN_MUSA_KERNEL_MUH
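The sync-BN kernels above all use the same shared-memory tree reduction: the active width is halved each step and the upper half is added into the lower half, so buffer[0] ends up with the block-wide sum. A sequential sketch (illustrative only; tree_reduce_sum is a hypothetical name, and a power-of-two block size is assumed, as THREADS_PER_BLOCK typically is):

#include <cstdio>

// Sequential model of the blockDim.x / 2 halving loop in the kernels above;
// on the device each inner iteration is one thread guarded by tid < s.
float tree_reduce_sum(float* buffer, int block_dim) {
  for (int s = block_dim / 2; s > 0; s >>= 1) {
    for (int tid = 0; tid < s; ++tid) {
      buffer[tid] += buffer[tid + s];
    }
  }
  return buffer[0];
}

int main() {
  float buffer[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  std::printf("%.0f\n", tree_reduce_sum(buffer, 8));  // 36
  return 0;
}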


@ -0,0 +1,57 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef THREE_INTERPOLATE_MUSA_KERNEL_MUH
#define THREE_INTERPOLATE_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
template <typename T>
__global__ void three_interpolate_forward_musa_kernel(
int b, int c, int m, int n, const T *points, const int *__restrict__ idx,
const T *weight, T *out) {
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(pt_idx, n) {
if (bs_idx >= b || c_idx >= c) return;
weight += bs_idx * n * 3 + pt_idx * 3;
points += bs_idx * c * m + c_idx * m;
idx += bs_idx * n * 3 + pt_idx * 3;
out += bs_idx * c * n + c_idx * n;
out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
weight[2] * points[idx[2]];
}
}
template <typename T>
__global__ void three_interpolate_backward_musa_kernel(
int b, int c, int n, int m, const T *grad_out, const int *__restrict__ idx,
const T *weight, T *grad_points) {
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(pt_idx, n) {
if (bs_idx >= b || c_idx >= c) return;
grad_out += bs_idx * c * n + c_idx * n + pt_idx;
weight += bs_idx * n * 3 + pt_idx * 3;
grad_points += bs_idx * c * m + c_idx * m;
idx += bs_idx * n * 3 + pt_idx * 3;
atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
}
}
#endif // THREE_INTERPOLATE_MUSA_KERNEL_MUH
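Per output element, three_interpolate is a weighted sum of the three neighbour features, and the backward kernel above scatters grad_out * w_k back to grad_points[i_k] with atomicAdd, which is exactly the gradient of that sum. A scalar reference (illustrative only; three_interpolate_one is a hypothetical name):

#include <cstdio>

// Reference for one output element: out = w0*p[i0] + w1*p[i1] + w2*p[i2].
float three_interpolate_one(const float* points, const int idx[3],
                            const float weight[3]) {
  return weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
         weight[2] * points[idx[2]];
}

int main() {
  float points[4] = {10.f, 20.f, 30.f, 40.f};
  int idx[3] = {0, 2, 3};
  float weight[3] = {0.5f, 0.3f, 0.2f};
  std::printf("%.1f\n", three_interpolate_one(points, idx, weight));  // 22.0
  return 0;
}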


@ -0,0 +1,63 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef THREE_NN_MUSA_KERNEL_MUH
#define THREE_NN_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
template <typename T>
__global__ void three_nn_forward_musa_kernel(int b, int n, int m,
const T *unknown, const T *known,
T *dist2, int *__restrict__ idx) {
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
int bs_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(pt_idx, n) {
if (bs_idx >= b) return;
unknown += bs_idx * n * 3 + pt_idx * 3;
known += bs_idx * m * 3;
dist2 += bs_idx * n * 3 + pt_idx * 3;
idx += bs_idx * n * 3 + pt_idx * 3;
T ux = unknown[0];
T uy = unknown[1];
T uz = unknown[2];
double best1 = 1e40, best2 = 1e40, best3 = 1e40;
int besti1 = 0, besti2 = 0, besti3 = 0;
for (int k = 0; k < m; ++k) {
T x = known[k * 3 + 0];
T y = known[k * 3 + 1];
T z = known[k * 3 + 2];
T d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
if (d < best1) {
best3 = best2;
besti3 = besti2;
best2 = best1;
besti2 = besti1;
best1 = d;
besti1 = k;
} else if (d < best2) {
best3 = best2;
besti3 = besti2;
best2 = d;
besti2 = k;
} else if (d < best3) {
best3 = d;
besti3 = k;
}
}
dist2[0] = best1;
dist2[1] = best2;
dist2[2] = best3;
idx[0] = besti1;
idx[1] = besti2;
idx[2] = besti3;
}
}
#endif // THREE_NN_MUSA_KERNEL_MUH
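The inner loop of three_nn keeps a running top-3 of squared distances (best1 <= best2 <= best3) and shifts the larger entries down whenever a closer point appears. A host-side reference of that update (illustrative only; update_top3 is a hypothetical name):

#include <cstdio>

// Running top-3 update, mirroring the if/else-if chain in the kernel above.
void update_top3(double d, int k, double best[3], int besti[3]) {
  if (d < best[0]) {
    best[2] = best[1]; besti[2] = besti[1];
    best[1] = best[0]; besti[1] = besti[0];
    best[0] = d; besti[0] = k;
  } else if (d < best[1]) {
    best[2] = best[1]; besti[2] = besti[1];
    best[1] = d; besti[1] = k;
  } else if (d < best[2]) {
    best[2] = d; besti[2] = k;
  }
}

int main() {
  double best[3] = {1e40, 1e40, 1e40};
  int besti[3] = {0, 0, 0};
  double dists[5] = {9.0, 4.0, 25.0, 1.0, 16.0};
  for (int k = 0; k < 5; k++) update_top3(dists[k], k, best, besti);
  std::printf("%d %d %d\n", besti[0], besti[1], besti[2]);  // 3 1 0
  return 0;
}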


@ -0,0 +1,57 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef TIN_SHIFT_MUSA_KERNEL_MUH
#define TIN_SHIFT_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
template <typename T>
__global__ void tin_shift_forward_musa_kernel(
const int nthreads, const T* input, const int* shift, T* output,
const int batch_size, const int channels, const int t_size,
const int hw_size, const int group_size, const int group_channel) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
const int hw_index = index % hw_size;
const int j = (index / hw_size) % channels;
const int n_index = (index / hw_size / channels) % batch_size;
int group_id = j / group_channel;
int t_shift = shift[n_index * group_size + group_id];
int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index;
for (int i = 0; i < t_size; i++) {
int now_t = i + t_shift;
int data_id = i * hw_size * channels + offset;
if (now_t < 0 || now_t >= t_size) {
continue;
}
int out_id = now_t * hw_size * channels + offset;
output[out_id] = input[data_id];
}
}
}
template <typename T>
__global__ void tin_shift_backward_musa_kernel(
const int nthreads, const T* input, const int* shift, T* output,
const int batch_size, const int channels, const int t_size,
const int hw_size, const int group_size, const int group_channel) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
const int hw_index = index % hw_size;
const int j = (index / hw_size) % channels;
const int n_index = (index / hw_size / channels) % batch_size;
int group_id = j / group_channel;
int t_shift = shift[n_index * group_size + group_id];
int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index;
for (int i = 0; i < t_size; i++) {
int now_t = i + t_shift;
int data_id = i * hw_size * channels + offset;
if (now_t < 0 || now_t >= t_size) {
continue;
}
int out_id = now_t * hw_size * channels + offset;
output[out_id] = input[data_id];
}
}
}
#endif // TIN_SHIFT_MUSA_KERNEL_MUH
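The temporal mapping in tin_shift_forward writes frame i of a group to frame i + t_shift of the output and silently drops targets outside [0, t_size); those output frames keep their initial value (the launcher is assumed to zero-initialize the output). A 1-D sketch of that mapping (illustrative only; tin_shift_1d is a hypothetical name):

#include <cstdio>
#include <vector>

// Shift one temporal sequence; out-of-range targets are dropped.
std::vector<float> tin_shift_1d(const std::vector<float>& in, int t_shift) {
  int t_size = static_cast<int>(in.size());
  std::vector<float> out(t_size, 0.f);
  for (int i = 0; i < t_size; i++) {
    int now_t = i + t_shift;
    if (now_t < 0 || now_t >= t_size) continue;
    out[now_t] = in[i];
  }
  return out;
}

int main() {
  // Shifting {1, 2, 3, 4} by +1 gives {0, 1, 2, 3}.
  for (float v : tin_shift_1d({1, 2, 3, 4}, 1)) std::printf("%.0f ", v);
  std::printf("\n");
  return 0;
}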


@ -0,0 +1,212 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef VOXELIZATION_MUSA_KERNEL_MUH
#define VOXELIZATION_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
template <typename T, typename T_int>
__global__ void dynamic_voxelize_kernel(
const T* points, T_int* coors, const float voxel_x, const float voxel_y,
const float voxel_z, const float coors_x_min, const float coors_y_min,
const float coors_z_min, const float coors_x_max, const float coors_y_max,
const float coors_z_max, const int grid_x, const int grid_y,
const int grid_z, const int num_points, const int num_features,
const int NDim) {
// const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
MUSA_1D_KERNEL_LOOP(index, num_points) {
// To save some computation
auto points_offset = points + index * num_features;
auto coors_offset = coors + index * NDim;
int c_x = floorf((points_offset[0] - coors_x_min) / voxel_x);
if (c_x < 0 || c_x >= grid_x) {
coors_offset[0] = -1;
continue;
}
int c_y = floorf((points_offset[1] - coors_y_min) / voxel_y);
if (c_y < 0 || c_y >= grid_y) {
coors_offset[0] = -1;
coors_offset[1] = -1;
continue;
}
int c_z = floorf((points_offset[2] - coors_z_min) / voxel_z);
if (c_z < 0 || c_z >= grid_z) {
coors_offset[0] = -1;
coors_offset[1] = -1;
coors_offset[2] = -1;
} else {
coors_offset[0] = c_z;
coors_offset[1] = c_y;
coors_offset[2] = c_x;
}
}
}
template <typename T, typename T_int>
__global__ void assign_point_to_voxel(const int nthreads, const T* points,
T_int* point_to_voxelidx,
T_int* coor_to_voxelidx, T* voxels,
const int max_points,
const int num_features,
const int num_points, const int NDim) {
MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
// const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
int index = thread_idx / num_features;
int num = point_to_voxelidx[index];
int voxelidx = coor_to_voxelidx[index];
if (num > -1 && voxelidx > -1) {
auto voxels_offset =
voxels + voxelidx * max_points * num_features + num * num_features;
int k = thread_idx % num_features;
voxels_offset[k] = points[thread_idx];
}
}
}
template <typename T, typename T_int>
__global__ void assign_voxel_coors(const int nthreads, T_int* coor,
T_int* point_to_voxelidx,
T_int* coor_to_voxelidx, T_int* voxel_coors,
const int num_points, const int NDim) {
MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
// const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
// if (index >= num_points) return;
int index = thread_idx / NDim;
int num = point_to_voxelidx[index];
int voxelidx = coor_to_voxelidx[index];
if (num == 0 && voxelidx > -1) {
auto coors_offset = voxel_coors + voxelidx * NDim;
int k = thread_idx % NDim;
coors_offset[k] = coor[thread_idx];
}
}
}
template <typename T_int>
__global__ void point_to_voxelidx_kernel(const T_int* coor,
T_int* point_to_voxelidx,
T_int* point_to_pointidx,
const int max_points,
const int max_voxels,
const int num_points, const int NDim) {
MUSA_1D_KERNEL_LOOP(index, num_points) {
auto coor_offset = coor + index * NDim;
// skip invalid points
if (coor_offset[0] == -1) continue;
int num = 0;
int coor_x = coor_offset[0];
int coor_y = coor_offset[1];
int coor_z = coor_offset[2];
// only calculate the coors before this coor[index]
for (int i = 0; i < index; ++i) {
auto prev_coor = coor + i * NDim;
if (prev_coor[0] == -1) continue;
// Find all previous points that have the same coors
// if find the same coor, record it
if ((prev_coor[0] == coor_x) && (prev_coor[1] == coor_y) &&
(prev_coor[2] == coor_z)) {
num++;
if (num == 1) {
// point to the same coor that first show up
point_to_pointidx[index] = i;
} else if (num >= max_points) {
// out of boundary
break;
}
}
}
if (num == 0) {
point_to_pointidx[index] = index;
}
if (num < max_points) {
point_to_voxelidx[index] = num;
}
}
}
template <typename T_int>
__global__ void determin_voxel_num(
// const T_int* coor,
T_int* num_points_per_voxel, T_int* point_to_voxelidx,
T_int* point_to_pointidx, T_int* coor_to_voxelidx, T_int* voxel_num,
const int max_points, const int max_voxels, const int num_points) {
// scan all points in order and assign each one a voxel id
for (int i = 0; i < num_points; ++i) {
int point_pos_in_voxel = point_to_voxelidx[i];
// record voxel
if (point_pos_in_voxel == -1) {
// out of max_points or invalid point
continue;
} else if (point_pos_in_voxel == 0) {
// record new voxel
int voxelidx = voxel_num[0];
if (voxel_num[0] >= max_voxels) continue;
voxel_num[0] += 1;
coor_to_voxelidx[i] = voxelidx;
num_points_per_voxel[voxelidx] = 1;
} else {
int point_idx = point_to_pointidx[i];
int voxelidx = coor_to_voxelidx[point_idx];
if (voxelidx != -1) {
coor_to_voxelidx[i] = voxelidx;
num_points_per_voxel[voxelidx] += 1;
}
}
}
}
__global__ void nondeterministic_get_assign_pos(
const int nthreads, const int32_t* coors_map, int32_t* pts_id,
int32_t* coors_count, int32_t* reduce_count, int32_t* coors_order) {
MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
int coors_idx = coors_map[thread_idx];
if (coors_idx > -1) {
int32_t coors_pts_pos = atomicAdd(&reduce_count[coors_idx], 1);
pts_id[thread_idx] = coors_pts_pos;
if (coors_pts_pos == 0) {
coors_order[coors_idx] = atomicAdd(coors_count, 1);
}
}
}
}
template <typename T>
__global__ void nondeterministic_assign_point_voxel(
const int nthreads, const T* points, const int32_t* coors_map,
const int32_t* pts_id, const int32_t* coors_in, const int32_t* reduce_count,
const int32_t* coors_order, T* voxels, int32_t* coors, int32_t* pts_count,
const int max_voxels, const int max_points, const int num_features,
const int NDim) {
MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
int coors_idx = coors_map[thread_idx];
int coors_pts_pos = pts_id[thread_idx];
if (coors_idx > -1 && coors_pts_pos < max_points) {
int coors_pos = coors_order[coors_idx];
if (coors_pos < max_voxels) {
auto voxels_offset =
voxels + (coors_pos * max_points + coors_pts_pos) * num_features;
auto points_offset = points + thread_idx * num_features;
for (int k = 0; k < num_features; k++) {
voxels_offset[k] = points_offset[k];
}
if (coors_pts_pos == 0) {
pts_count[coors_pos] = min(reduce_count[coors_idx], max_points);
auto coors_offset = coors + coors_pos * NDim;
auto coors_in_offset = coors_in + coors_idx * NDim;
for (int k = 0; k < NDim; k++) {
coors_offset[k] = coors_in_offset[k];
}
}
}
}
}
}
#endif // VOXELIZATION_MUSA_KERNEL_MUH
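dynamic_voxelize_kernel bins each point axis as floor((p - coors_min) / voxel_size) and marks the point invalid (coordinate -1) as soon as one axis falls outside its grid. A per-axis reference (illustrative only; voxel_coord is a hypothetical name):

#include <cmath>
#include <cstdio>

// Bin one coordinate into its voxel index, or -1 if it is out of range.
int voxel_coord(float p, float coors_min, float voxel_size, int grid) {
  int c = static_cast<int>(std::floor((p - coors_min) / voxel_size));
  return (c < 0 || c >= grid) ? -1 : c;
}

int main() {
  // With the range starting at 0.0, voxel size 0.5 and 10 voxels per axis,
  // x = 1.3 lands in voxel 2 and x = 7.0 is out of range.
  std::printf("%d %d\n", voxel_coord(1.3f, 0.f, 0.5f, 10),
              voxel_coord(7.0f, 0.f, 0.5f, 10));  // 2 -1
  return 0;
}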

File diff suppressed because it is too large


@ -540,6 +540,17 @@ torch::Tensor bias_act_op(const torch::Tensor &input, const torch::Tensor &bias,
REGISTER_DEVICE_IMPL(bias_act_op_impl, MUSA, bias_act_op);
torch::Tensor filtered_lrelu_act_op_impl(torch::Tensor x, torch::Tensor si,
int sx, int sy, float gain,
float slope, float clamp,
bool writeSigns);
torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
int sy, float gain, float slope,
float clamp, bool writeSigns);
REGISTER_DEVICE_IMPL(filtered_lrelu_act_op_impl, MUSA, filtered_lrelu_act_op);
void GatherPointsForwardMUSAKernelLauncher(int b, int c, int n, int npoints,
const Tensor points,
const Tensor idx, Tensor out);
@ -869,6 +880,854 @@ Tensor nms_musa(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset);
REGISTER_DEVICE_IMPL(nms_impl, MUSA, nms_musa);
void PointsInBoxesPartForwardMUSAKernelLauncher(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points);
void PointsInBoxesAllForwardMUSAKernelLauncher(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points);
void points_in_boxes_part_forward_musa(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points) {
PointsInBoxesPartForwardMUSAKernelLauncher(batch_size, boxes_num, pts_num,
boxes, pts, box_idx_of_points);
};
void points_in_boxes_all_forward_musa(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points) {
PointsInBoxesAllForwardMUSAKernelLauncher(batch_size, boxes_num, pts_num,
boxes, pts, box_idx_of_points);
};
void points_in_boxes_part_forward_impl(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points);
void points_in_boxes_all_forward_impl(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points);
REGISTER_DEVICE_IMPL(points_in_boxes_part_forward_impl, MUSA,
points_in_boxes_part_forward_musa);
REGISTER_DEVICE_IMPL(points_in_boxes_all_forward_impl, MUSA,
points_in_boxes_all_forward_musa);
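Each op below follows the same pattern: declare the MUSA kernel launcher, wrap it in a *_musa function, declare the generic *_impl entry point, and register the wrapper for the MUSA device. The sketch below is a hypothetical illustration of the dispatch idea only; the actual REGISTER_DEVICE_IMPL macro and registry live in mmcv's dispatch utilities and are not reproduced here.

#include <cstdio>
#include <functional>
#include <map>
#include <string>

// Hypothetical registry mapping (op, device) to a concrete function, so a
// generic *_impl entry point can forward to the MUSA launcher at runtime.
using OpFn = std::function<void()>;
static std::map<std::string, OpFn>& registry() {
  static std::map<std::string, OpFn> r;
  return r;
}

static bool register_impl(const std::string& op, const std::string& device,
                          OpFn fn) {
  registry()[op + "/" + device] = std::move(fn);
  return true;
}

int main() {
  // Mirrors the intent of
  // REGISTER_DEVICE_IMPL(points_in_boxes_part_forward_impl, MUSA, ...).
  register_impl("points_in_boxes_part_forward_impl", "MUSA",
                [] { std::printf("dispatch to MUSA launcher\n"); });
  registry()["points_in_boxes_part_forward_impl/MUSA"]();
  return 0;
}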
void PSAMaskForwardMUSAKernelLauncher(const int psa_type, const Tensor input,
Tensor output, const int num_,
const int h_feature, const int w_feature,
const int h_mask, const int w_mask,
const int half_h_mask,
const int half_w_mask);
void PSAMaskBackwardMUSAKernelLauncher(
const int psa_type, const Tensor grad_output, Tensor grad_input,
const int num_, const int h_feature, const int w_feature, const int h_mask,
const int w_mask, const int half_h_mask, const int half_w_mask);
void psamask_forward_musa(const int psa_type, const Tensor input, Tensor output,
const int num_, const int h_feature,
const int w_feature, const int h_mask,
const int w_mask, const int half_h_mask,
const int half_w_mask) {
PSAMaskForwardMUSAKernelLauncher(psa_type, input, output, num_, h_feature,
w_feature, h_mask, w_mask, half_h_mask,
half_w_mask);
}
void psamask_backward_musa(const int psa_type, const Tensor grad_output,
Tensor grad_input, const int num_,
const int h_feature, const int w_feature,
const int h_mask, const int w_mask,
const int half_h_mask, const int half_w_mask) {
PSAMaskBackwardMUSAKernelLauncher(psa_type, grad_output, grad_input, num_,
h_feature, w_feature, h_mask, w_mask,
half_h_mask, half_w_mask);
}
void psamask_forward_impl(const int psa_type, const Tensor input, Tensor output,
const int num_, const int h_feature,
const int w_feature, const int h_mask,
const int w_mask, const int half_h_mask,
const int half_w_mask);
void psamask_backward_impl(const int psa_type, const Tensor grad_output,
Tensor grad_input, const int num_,
const int h_feature, const int w_feature,
const int h_mask, const int w_mask,
const int half_h_mask, const int half_w_mask);
REGISTER_DEVICE_IMPL(psamask_forward_impl, MUSA, psamask_forward_musa);
REGISTER_DEVICE_IMPL(psamask_backward_impl, MUSA, psamask_backward_musa);
void ROIAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned);
void ROIAlignBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
Tensor argmax_y, Tensor argmax_x,
Tensor grad_input, int aligned_height,
int aligned_width, float spatial_scale,
int sampling_ratio, int pool_mode,
bool aligned);
void roi_align_forward_musa(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned) {
ROIAlignForwardMUSAKernelLauncher(
input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width,
spatial_scale, sampling_ratio, pool_mode, aligned);
}
void roi_align_backward_musa(Tensor grad_output, Tensor rois, Tensor argmax_y,
Tensor argmax_x, Tensor grad_input,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned) {
ROIAlignBackwardMUSAKernelLauncher(
grad_output, rois, argmax_y, argmax_x, grad_input, aligned_height,
aligned_width, spatial_scale, sampling_ratio, pool_mode, aligned);
}
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned);
void roi_align_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax_y,
Tensor argmax_x, Tensor grad_input,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned);
REGISTER_DEVICE_IMPL(roi_align_forward_impl, MUSA, roi_align_forward_musa);
REGISTER_DEVICE_IMPL(roi_align_backward_impl, MUSA, roi_align_backward_musa);
void ROIAlignRotatedForwardMUSAKernelLauncher(
const at::Tensor input, const at::Tensor rois, const float spatial_scale,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor output);
void ROIAlignRotatedBackwardMUSAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor bottom_grad);
void roi_align_rotated_forward_musa(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
int size_rois = rois.size(1);
if (size_rois != 6) {
AT_ERROR("wrong roi size");
}
int num_channels = input.size(1);
int data_height = input.size(2);
int data_width = input.size(3);
ROIAlignRotatedForwardMUSAKernelLauncher(
input, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, output);
}
void roi_align_rotated_backward_musa(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sampling_ratio, bool aligned,
bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
int size_rois = rois.size(1);
if (size_rois != 6) {
AT_ERROR("wrong roi size");
}
int num_channels = bottom_grad.size(1);
int data_height = bottom_grad.size(2);
int data_width = bottom_grad.size(3);
ROIAlignRotatedBackwardMUSAKernelLauncher(
top_grad, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, bottom_grad);
}
void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);
void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sampling_ratio, bool aligned,
bool clockwise);
REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, MUSA,
roi_align_rotated_forward_musa);
REGISTER_DEVICE_IMPL(roi_align_rotated_backward_impl, MUSA,
roi_align_rotated_backward_musa);
void RiROIAlignRotatedForwardMUSAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale,
const int num_samples, const bool clockwise, const int channels,
const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const int num_orientations,
at::Tensor output);
void RiROIAlignRotatedBackwardMUSAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int num_samples, const bool clockwise, const int channels,
const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const int num_orientations,
at::Tensor bottom_grad);
void riroi_align_rotated_forward_musa(Tensor features, Tensor rois,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
int size_rois = rois.size(1);
if (size_rois != 6) {
AT_ERROR("wrong roi size");
}
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(rois);
int num_channels = features.size(1) / num_orientations;
int data_height = features.size(2);
int data_width = features.size(3);
RiROIAlignRotatedForwardMUSAKernelLauncher(
features, rois, spatial_scale, num_samples, clockwise, num_channels,
data_height, data_width, num_rois, pooled_height, pooled_width,
num_orientations, output);
}
void riroi_align_rotated_backward_musa(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
int size_rois = rois.size(1);
if (size_rois != 6) {
AT_ERROR("wrong roi size");
}
CHECK_CONTIGUOUS(top_grad);
CHECK_CONTIGUOUS(rois);
int num_channels = bottom_grad.size(1) / num_orientations;
int data_height = bottom_grad.size(2);
int data_width = bottom_grad.size(3);
RiROIAlignRotatedBackwardMUSAKernelLauncher(
top_grad, rois, spatial_scale, num_samples, clockwise, num_channels,
data_height, data_width, num_rois, pooled_height, pooled_width,
num_orientations, bottom_grad);
}
void riroi_align_rotated_forward_impl(Tensor features, Tensor rois,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise);
void riroi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale,
int num_samples, int num_orientations,
bool clockwise);
REGISTER_DEVICE_IMPL(riroi_align_rotated_forward_impl, MUSA,
riroi_align_rotated_forward_musa);
REGISTER_DEVICE_IMPL(riroi_align_rotated_backward_impl, MUSA,
riroi_align_rotated_backward_musa);
void RoiawarePool3dForwardMUSAKernelLauncher(
int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
int out_y, int out_z, const Tensor rois, const Tensor pts,
const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels,
Tensor pooled_features, int pool_method);
void RoiawarePool3dBackwardMUSAKernelLauncher(
int boxes_num, int out_x, int out_y, int out_z, int channels,
int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax,
const Tensor grad_out, Tensor grad_in, int pool_method);
void roiaware_pool3d_forward_musa(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const Tensor rois,
const Tensor pts, const Tensor pts_feature,
Tensor argmax, Tensor pts_idx_of_voxels,
Tensor pooled_features, int pool_method) {
RoiawarePool3dForwardMUSAKernelLauncher(
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
rois, pts, pts_feature, argmax, pts_idx_of_voxels, pooled_features,
pool_method);
};
void roiaware_pool3d_backward_musa(int boxes_num, int out_x, int out_y,
int out_z, int channels,
int max_pts_each_voxel,
const Tensor pts_idx_of_voxels,
const Tensor argmax, const Tensor grad_out,
Tensor grad_in, int pool_method) {
RoiawarePool3dBackwardMUSAKernelLauncher(
boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
pts_idx_of_voxels, argmax, grad_out, grad_in, pool_method);
};
void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const Tensor rois,
const Tensor pts, const Tensor pts_feature,
Tensor argmax, Tensor pts_idx_of_voxels,
Tensor pooled_features, int pool_method);
void roiaware_pool3d_backward_impl(int boxes_num, int out_x, int out_y,
int out_z, int channels,
int max_pts_each_voxel,
const Tensor pts_idx_of_voxels,
const Tensor argmax, const Tensor grad_out,
Tensor grad_in, int pool_method);
REGISTER_DEVICE_IMPL(roiaware_pool3d_forward_impl, MUSA,
roiaware_pool3d_forward_musa);
REGISTER_DEVICE_IMPL(roiaware_pool3d_backward_impl, MUSA,
roiaware_pool3d_backward_musa);
void RoIPointPool3dForwardMUSAKernelLauncher(
int batch_size, int pts_num, int boxes_num, int feature_in_len,
int sampled_pts_num, const Tensor xyz, const Tensor boxes3d,
const Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag);
void roipoint_pool3d_forward_musa(int batch_size, int pts_num, int boxes_num,
int feature_in_len, int sampled_pts_num,
const Tensor xyz, const Tensor boxes3d,
const Tensor pts_feature,
Tensor pooled_features,
Tensor pooled_empty_flag) {
RoIPointPool3dForwardMUSAKernelLauncher(
batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz,
boxes3d, pts_feature, pooled_features, pooled_empty_flag);
};
void roipoint_pool3d_forward_impl(int batch_size, int pts_num, int boxes_num,
int feature_in_len, int sampled_pts_num,
const Tensor xyz, const Tensor boxes3d,
const Tensor pts_feature,
Tensor pooled_features,
Tensor pooled_empty_flag);
REGISTER_DEVICE_IMPL(roipoint_pool3d_forward_impl, MUSA,
roipoint_pool3d_forward_musa);
void ROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax, int pooled_height,
int pooled_width, float spatial_scale);
void ROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
Tensor argmax, Tensor grad_input,
int pooled_height, int pooled_width,
float spatial_scale);
void roi_pool_forward_musa(Tensor input, Tensor rois, Tensor output,
Tensor argmax, int pooled_height, int pooled_width,
float spatial_scale) {
ROIPoolForwardMUSAKernelLauncher(input, rois, output, argmax, pooled_height,
pooled_width, spatial_scale);
}
void roi_pool_backward_musa(Tensor grad_output, Tensor rois, Tensor argmax,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale) {
ROIPoolBackwardMUSAKernelLauncher(grad_output, rois, argmax, grad_input,
pooled_height, pooled_width, spatial_scale);
}
void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
Tensor argmax, int pooled_height, int pooled_width,
float spatial_scale);
void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale);
REGISTER_DEVICE_IMPL(roi_pool_forward_impl, MUSA, roi_pool_forward_musa);
REGISTER_DEVICE_IMPL(roi_pool_backward_impl, MUSA, roi_pool_backward_musa);
typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
std::vector<at::Tensor> DynamicPointToVoxelForwardMUSAKernelLauncher(
const at::Tensor &feats, const at::Tensor &coors,
const reduce_t reduce_type);
void DynamicPointToVoxelBackwardMUSAKernelLauncher(
at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats,
const at::Tensor &feats, const at::Tensor &reduced_feats,
const at::Tensor &coors_map, const at::Tensor &reduce_count,
const reduce_t reduce_type);
std::vector<torch::Tensor> dynamic_point_to_voxel_forward_musa(
const torch::Tensor &feats, const torch::Tensor &coors,
const reduce_t reduce_type) {
return DynamicPointToVoxelForwardMUSAKernelLauncher(feats, coors,
reduce_type);
};
void dynamic_point_to_voxel_backward_musa(
torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats, const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
const reduce_t reduce_type) {
DynamicPointToVoxelBackwardMUSAKernelLauncher(grad_feats, grad_reduced_feats,
feats, reduced_feats, coors_idx,
reduce_count, reduce_type);
};
std::vector<torch::Tensor> dynamic_point_to_voxel_forward_impl(
const torch::Tensor &feats, const torch::Tensor &coors,
const reduce_t reduce_type);
void dynamic_point_to_voxel_backward_impl(
torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats, const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
const reduce_t reduce_type);
REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, MUSA,
dynamic_point_to_voxel_forward_musa);
REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_backward_impl, MUSA,
dynamic_point_to_voxel_backward_musa);
void SyncBNForwardMeanMUSAKernelLauncher(const Tensor input, Tensor mean);
void SyncBNForwardVarMUSAKernelLauncher(const Tensor input, const Tensor mean,
Tensor var);
void SyncBNForwardOutputMUSAKernelLauncher(
const Tensor input, const Tensor mean, const Tensor var,
Tensor running_mean, Tensor running_var, const Tensor weight,
const Tensor bias, Tensor norm, Tensor std, Tensor output, float eps,
float momentum, int group_size);
void SyncBNBackwardParamMUSAKernelLauncher(const Tensor grad_output,
const Tensor norm,
Tensor grad_weight,
Tensor grad_bias);
void SyncBNBackwardDataMUSAKernelLauncher(const Tensor grad_output,
const Tensor weight,
const Tensor grad_weight,
const Tensor grad_bias,
const Tensor norm, const Tensor std,
Tensor grad_input);
void sync_bn_forward_mean_musa(const Tensor input, Tensor mean) {
SyncBNForwardMeanMUSAKernelLauncher(input, mean);
}
void sync_bn_forward_var_musa(const Tensor input, const Tensor mean,
Tensor var) {
SyncBNForwardVarMUSAKernelLauncher(input, mean, var);
}
void sync_bn_forward_output_musa(const Tensor input, const Tensor mean,
const Tensor var, Tensor running_mean,
Tensor running_var, const Tensor weight,
const Tensor bias, Tensor norm, Tensor std,
Tensor output, float eps, float momentum,
int group_size) {
SyncBNForwardOutputMUSAKernelLauncher(input, mean, var, running_mean,
running_var, weight, bias, norm, std,
output, eps, momentum, group_size);
}
void sync_bn_backward_param_musa(const Tensor grad_output, const Tensor norm,
Tensor grad_weight, Tensor grad_bias) {
SyncBNBackwardParamMUSAKernelLauncher(grad_output, norm, grad_weight,
grad_bias);
}
void sync_bn_backward_data_musa(const Tensor grad_output, const Tensor weight,
const Tensor grad_weight,
const Tensor grad_bias, const Tensor norm,
const Tensor std, Tensor grad_input) {
SyncBNBackwardDataMUSAKernelLauncher(grad_output, weight, grad_weight,
grad_bias, norm, std, grad_input);
}
void sync_bn_forward_mean_impl(const Tensor input, Tensor mean);
void sync_bn_forward_var_impl(const Tensor input, const Tensor mean,
Tensor var);
void sync_bn_forward_output_impl(const Tensor input, const Tensor mean,
const Tensor var, Tensor running_mean,
Tensor running_var, const Tensor weight,
const Tensor bias, Tensor norm, Tensor std,
Tensor output, float eps, float momentum,
int group_size);
void sync_bn_backward_param_impl(const Tensor grad_output, const Tensor norm,
Tensor grad_weight, Tensor grad_bias);
void sync_bn_backward_data_impl(const Tensor grad_output, const Tensor weight,
const Tensor grad_weight,
const Tensor grad_bias, const Tensor norm,
const Tensor std, Tensor grad_input);
REGISTER_DEVICE_IMPL(sync_bn_forward_mean_impl, MUSA,
sync_bn_forward_mean_musa);
REGISTER_DEVICE_IMPL(sync_bn_forward_var_impl, MUSA, sync_bn_forward_var_musa);
REGISTER_DEVICE_IMPL(sync_bn_forward_output_impl, MUSA,
sync_bn_forward_output_musa);
REGISTER_DEVICE_IMPL(sync_bn_backward_param_impl, MUSA,
sync_bn_backward_param_musa);
REGISTER_DEVICE_IMPL(sync_bn_backward_data_impl, MUSA,
sync_bn_backward_data_musa);
void ThreeInterpolateForwardMUSAKernelLauncher(int b, int c, int m, int n,
const Tensor points,
const Tensor idx,
const Tensor weight, Tensor out);
void ThreeInterpolateBackwardMUSAKernelLauncher(int b, int c, int n, int m,
const Tensor grad_out,
const Tensor idx,
const Tensor weight,
Tensor grad_points);
void three_interpolate_forward_musa(int b, int c, int m, int n,
const Tensor points, const Tensor idx,
const Tensor weight, Tensor out) {
ThreeInterpolateForwardMUSAKernelLauncher(b, c, m, n, points, idx, weight,
out);
};
void three_interpolate_backward_musa(int b, int c, int n, int m,
const Tensor grad_out, const Tensor idx,
const Tensor weight, Tensor grad_points) {
ThreeInterpolateBackwardMUSAKernelLauncher(b, c, n, m, grad_out, idx, weight,
grad_points);
};
void three_interpolate_forward_impl(int b, int c, int m, int n,
const Tensor points, const Tensor idx,
const Tensor weight, Tensor out);
void three_interpolate_backward_impl(int b, int c, int n, int m,
const Tensor grad_out, const Tensor idx,
const Tensor weight, Tensor grad_points);
REGISTER_DEVICE_IMPL(three_interpolate_forward_impl, MUSA,
three_interpolate_forward_musa);
REGISTER_DEVICE_IMPL(three_interpolate_backward_impl, MUSA,
three_interpolate_backward_musa);
void ThreeNNForwardMUSAKernelLauncher(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2,
Tensor idx);
void three_nn_forward_musa(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx) {
ThreeNNForwardMUSAKernelLauncher(b, n, m, unknown, known, dist2, idx);
};
void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx);
REGISTER_DEVICE_IMPL(three_nn_forward_impl, MUSA, three_nn_forward_musa);
void TINShiftForwardMUSAKernelLauncher(Tensor input, Tensor shift,
Tensor output);
void TINShiftBackwardMUSAKernelLauncher(Tensor grad_output, Tensor shift,
Tensor grad_input);
void tin_shift_forward_musa(Tensor input, Tensor shift, Tensor output) {
TINShiftForwardMUSAKernelLauncher(input, shift, output);
}
void tin_shift_backward_musa(Tensor grad_output, Tensor shift,
Tensor grad_input) {
TINShiftBackwardMUSAKernelLauncher(grad_output, shift, grad_input);
}
void tin_shift_forward_impl(Tensor input, Tensor shift, Tensor output);
void tin_shift_backward_impl(Tensor grad_output, Tensor shift,
Tensor grad_input);
REGISTER_DEVICE_IMPL(tin_shift_forward_impl, MUSA, tin_shift_forward_musa);
REGISTER_DEVICE_IMPL(tin_shift_backward_impl, MUSA, tin_shift_backward_musa);
#if ((!defined(MUSA_ARCH)) || (defined(MUSA_ARCH)) && (MUSA_ARCH > 21))
torch::Tensor upfirdn2d_op(torch::Tensor input, torch::Tensor filter, int upx,
int upy, int downx, int downy, int padx0, int padx1,
int pady0, int pady1, bool flip, float gain);
torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor filter,
int upx, int upy, int downx, int downy,
int padx0, int padx1, int pady0, int pady1,
bool flip, float gain);
REGISTER_DEVICE_IMPL(upfirdn2d_op_impl, MUSA, upfirdn2d_op);
#endif
int HardVoxelizeForwardMUSAKernelLauncher(
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3);
int NondeterministicHardVoxelizeForwardMUSAKernelLauncher(
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3);
void DynamicVoxelizeForwardMUSAKernelLauncher(
const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size, const std::vector<float> coors_range,
const int NDim = 3);
int hard_voxelize_forward_musa(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors,
at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim) {
return HardVoxelizeForwardMUSAKernelLauncher(
points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
max_points, max_voxels, NDim);
};
int nondeterministic_hard_voxelize_forward_musa(
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim) {
return NondeterministicHardVoxelizeForwardMUSAKernelLauncher(
points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
max_points, max_voxels, NDim);
};
void dynamic_voxelize_forward_musa(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim) {
DynamicVoxelizeForwardMUSAKernelLauncher(points, coors, voxel_size,
coors_range, NDim);
};
int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors,
at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim);
int nondeterministic_hard_voxelize_forward_impl(
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim);
void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim);
REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, MUSA,
hard_voxelize_forward_musa);
REGISTER_DEVICE_IMPL(nondeterministic_hard_voxelize_forward_impl, MUSA,
nondeterministic_hard_voxelize_forward_musa);
REGISTER_DEVICE_IMPL(dynamic_voxelize_forward_impl, MUSA,
dynamic_voxelize_forward_musa);
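// Coordinate sketch (illustrative; exact boundary handling lives in the
// kernels): voxel_size = {vx, vy, vz} and coors_range = {x_min, y_min, z_min,
// x_max, y_max, z_max}. A point (x, y, z) maps to an integer voxel coordinate
// roughly as
//
//   int cx = static_cast<int>(floorf((x - x_min) / vx));
//   int cy = static_cast<int>(floorf((y - y_min) / vy));
//   int cz = static_cast<int>(floorf((z - z_min) / vz));
//   // points outside coors_range are dropped; the hard variants additionally
//   // cap max_points per voxel and max_voxels in total, while the dynamic
//   // variant only writes coordinates.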
void RotatedFeatureAlignForwardMUSAKernelLauncher(const Tensor features,
const Tensor best_bboxes,
const float spatial_scale,
const int points,
Tensor output);
void RotatedFeatureAlignBackwardMUSAKernelLauncher(const Tensor top_grad,
const Tensor best_bboxes,
const float spatial_scale,
const int points,
Tensor bottom_grad);
void rotated_feature_align_forward_musa(const Tensor features,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor output) {
RotatedFeatureAlignForwardMUSAKernelLauncher(features, best_bboxes,
spatial_scale, points, output);
};
void rotated_feature_align_backward_musa(const Tensor top_grad,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor bottom_grad) {
RotatedFeatureAlignBackwardMUSAKernelLauncher(
top_grad, best_bboxes, spatial_scale, points, bottom_grad);
};
void rotated_feature_align_forward_impl(const Tensor features,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor output);
void rotated_feature_align_backward_impl(const Tensor top_grad,
const Tensor best_bboxes,
const float spatial_scale,
const int points, Tensor bottom_grad);
REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, MUSA,
rotated_feature_align_forward_musa);
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, MUSA,
rotated_feature_align_backward_musa);
void PointsInPolygonsForwardMUSAKernelLauncher(const at::Tensor points,
const at::Tensor polygons,
const int rows, const int cols,
at::Tensor output);
void points_in_polygons_forward_musa(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols) {
PointsInPolygonsForwardMUSAKernelLauncher(points, polygons, rows, cols,
output);
};
void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols);
REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, MUSA,
points_in_polygons_forward_musa);
torch::Tensor IndiceMaxpoolForwardMUSAKernelLauncher(torch::Tensor features,
torch::Tensor indicePairs,
torch::Tensor indiceNum,
int64_t numAct);
torch::Tensor indice_maxpool_forward_musa(torch::Tensor features,
torch::Tensor indicePairs,
torch::Tensor indiceNum,
int64_t numAct) {
return IndiceMaxpoolForwardMUSAKernelLauncher(features, indicePairs,
indiceNum, numAct);
};
torch::Tensor indice_maxpool_forward_impl(torch::Tensor features,
torch::Tensor indicePairs,
torch::Tensor indiceNum,
int64_t numAct);
REGISTER_DEVICE_IMPL(indice_maxpool_forward_impl, MUSA,
indice_maxpool_forward_musa);
torch::Tensor IndiceMaxpoolBackwardMUSAKernelLauncher(torch::Tensor features,
torch::Tensor outFeatures,
torch::Tensor outGrad,
torch::Tensor indicePairs,
torch::Tensor indiceNum);
torch::Tensor indice_maxpool_backward_musa(torch::Tensor features,
torch::Tensor outFeatures,
torch::Tensor outGrad,
torch::Tensor indicePairs,
torch::Tensor indiceNum) {
return IndiceMaxpoolBackwardMUSAKernelLauncher(features, outFeatures, outGrad,
indicePairs, indiceNum);
};
torch::Tensor indice_maxpool_backward_impl(torch::Tensor features,
torch::Tensor outFeatures,
torch::Tensor outGrad,
torch::Tensor indicePairs,
torch::Tensor indiceNum);
REGISTER_DEVICE_IMPL(indice_maxpool_backward_impl, MUSA,
indice_maxpool_backward_musa)
torch::Tensor IndiceConvForwardMUSAKernelLauncher(
torch::Tensor features, torch::Tensor filters, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse,
int64_t _subM);
torch::Tensor indice_conv_forward_musa(torch::Tensor features,
torch::Tensor filters,
torch::Tensor indicePairs,
torch::Tensor indiceNum,
int64_t numActOut, int64_t _inverse,
int64_t _subM) {
return IndiceConvForwardMUSAKernelLauncher(
features, filters, indicePairs, indiceNum, numActOut, _inverse, _subM);
};
torch::Tensor indice_conv_forward_impl(torch::Tensor features,
torch::Tensor filters,
torch::Tensor indicePairs,
torch::Tensor indiceNum,
int64_t numActOut, int64_t _inverse,
int64_t _subM);
REGISTER_DEVICE_IMPL(indice_conv_forward_impl, MUSA, indice_conv_forward_musa);
std::vector<torch::Tensor> IndiceConvBackwardMUSAKernelLauncher(
torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM);
std::vector<torch::Tensor> indice_conv_backward_musa(
torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM) {
return IndiceConvBackwardMUSAKernelLauncher(
features, filters, outGrad, indicePairs, indiceNum, _inverse, _subM);
};
std::vector<torch::Tensor> indice_conv_backward_impl(
torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM);
REGISTER_DEVICE_IMPL(indice_conv_backward_impl, MUSA,
indice_conv_backward_musa);
torch::Tensor FusedIndiceConvBatchnormMUSAKernelLauncher(
torch::Tensor features, torch::Tensor filters, torch::Tensor bias,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut,
int64_t _inverse, int64_t _subM);
torch::Tensor fused_indice_conv_batchnorm_forward_musa(
torch::Tensor features, torch::Tensor filters, torch::Tensor bias,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut,
int64_t _inverse, int64_t _subM) {
return FusedIndiceConvBatchnormMUSAKernelLauncher(features, filters, bias,
indicePairs, indiceNum,
numActOut, _inverse, _subM);
};
torch::Tensor fused_indice_conv_batchnorm_forward_impl(
torch::Tensor features, torch::Tensor filters, torch::Tensor bias,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut,
int64_t _inverse, int64_t _subM);
REGISTER_DEVICE_IMPL(fused_indice_conv_batchnorm_forward_impl, MUSA,
fused_indice_conv_batchnorm_forward_musa)
void MinAreaPolygonsMUSAKernelLauncher(const Tensor pointsets, Tensor polygons);
void min_area_polygons_musa(const Tensor pointsets, Tensor polygons) {
@ -990,6 +1849,57 @@ REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, MUSA,
chamfer_distance_backward_musa);
#endif
void PrROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale);
void PrROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale);
void PrROIPoolCoorBackwardMUSAKernelLauncher(
Tensor output, Tensor grad_output, Tensor input, Tensor rois,
Tensor grad_rois, int pooled_height, int pooled_width, float spatial_scale);
void prroi_pool_forward_musa(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale) {
PrROIPoolForwardMUSAKernelLauncher(input, rois, output, pooled_height,
pooled_width, spatial_scale);
}
void prroi_pool_backward_musa(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale) {
PrROIPoolBackwardMUSAKernelLauncher(grad_output, rois, grad_input,
pooled_height, pooled_width,
spatial_scale);
}
void prroi_pool_coor_backward_musa(Tensor output, Tensor grad_output,
Tensor input, Tensor rois, Tensor grad_rois,
int pooled_height, int pooled_width,
float spatial_scale) {
PrROIPoolCoorBackwardMUSAKernelLauncher(output, grad_output, input, rois,
grad_rois, pooled_height,
pooled_width, spatial_scale);
}
void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale);
void prroi_pool_backward_impl(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale);
void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output,
Tensor input, Tensor rois, Tensor grad_rois,
int pooled_height, int pooled_width,
float spatial_scale);
REGISTER_DEVICE_IMPL(prroi_pool_forward_impl, MUSA, prroi_pool_forward_musa);
REGISTER_DEVICE_IMPL(prroi_pool_backward_impl, MUSA, prroi_pool_backward_musa);
REGISTER_DEVICE_IMPL(prroi_pool_coor_backward_impl, MUSA,
prroi_pool_coor_backward_musa);
void BezierAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois,
Tensor output, int aligned_height,
int aligned_width,

View File

@ -0,0 +1,62 @@
// Modified from
// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
// Written by Shaoshuai Shi
// All Rights Reserved 2019.
#include <stdio.h>
#include "points_in_boxes_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"
void PointsInBoxesPartForwardMUSAKernelLauncher(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points) {
  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  // coordinate, z is the bottom center; boxes must not overlap
  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
  // params box_idx_of_points: (B, npoints), default -1
c10::musa::MUSAGuard device_guard(boxes.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES(
boxes.scalar_type(), "points_in_boxes_part_forward_musa_kernel", [&] {
points_in_boxes_part_forward_musa_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
batch_size, boxes_num, pts_num, boxes.data_ptr<scalar_t>(),
pts.data_ptr<scalar_t>(), box_idx_of_points.data_ptr<int>());
});
AT_MUSA_CHECK(musaGetLastError());
}
void PointsInBoxesAllForwardMUSAKernelLauncher(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points) {
  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  // coordinate, z is the bottom center; boxes may overlap
  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
  // params box_idx_of_points: (B, npoints, N), default 0, set to 1 when the
  // point lies inside the corresponding box
c10::musa::MUSAGuard device_guard(boxes.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES(
boxes.scalar_type(), "points_in_boxes_all_forward_musa_kernel", [&] {
points_in_boxes_all_forward_musa_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
batch_size, boxes_num, pts_num, boxes.data_ptr<scalar_t>(),
pts.data_ptr<scalar_t>(), box_idx_of_points.data_ptr<int>());
});
AT_MUSA_CHECK(musaGetLastError());
}
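// Usage sketch (shapes assumed from the comments above, not a snippet from
// this diff): boxes is (B, N, 7) and pts is (B, npoints, 3); the "part"
// variant writes the index of the first containing box per point, the "all"
// variant writes per-box 0/1 flags, e.g.
//
//   auto idx = at::full({B, npoints}, -1, pts.options().dtype(at::kInt));
//   PointsInBoxesPartForwardMUSAKernelLauncher(B, N, npoints, boxes, pts, idx);
//   auto flags = at::zeros({B, npoints, N}, pts.options().dtype(at::kInt));
//   PointsInBoxesAllForwardMUSAKernelLauncher(B, N, npoints, boxes, pts, flags);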

View File

@ -0,0 +1,28 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/ming71/MUSA/blob/master/point_justify/points_justify_kernel.cu
#include <stdio.h>
#include "points_in_polygons_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"
void PointsInPolygonsForwardMUSAKernelLauncher(const at::Tensor points,
const at::Tensor polygons,
const int rows, const int cols,
at::Tensor output) {
const int output_size = rows * cols;
c10::musa::MUSAGuard device_guard(points.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
points.scalar_type(), "points_in_polygons_forward_musa_kernel", ([&] {
const scalar_t *vertex1 = points.data_ptr<scalar_t>();
const scalar_t *vertex2 = polygons.data_ptr<scalar_t>();
scalar_t *inside_flag = output.data_ptr<scalar_t>();
points_in_polygons_forward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, vertex1, vertex2, rows, cols, inside_flag);
}));
AT_MUSA_CHECK(musaGetLastError());
}
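// Shape sketch (assumed layout, following the usual points_in_polygons
// convention rather than anything restated in this diff): points is (rows, 2),
// polygons is (cols, 8) holding four (x, y) vertices per polygon, and output
// is (rows, cols) filled with 1 where a point lies inside a polygon, e.g.
//
//   auto output = at::zeros({rows, cols}, points.options());
//   PointsInPolygonsForwardMUSAKernelLauncher(points, polygons, rows, cols,
//                                             output);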

View File

@ -0,0 +1,65 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "prroi_pool_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"
void PrROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale) {
int output_size = output.numel();
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
c10::musa::MUSAGuard device_guard(input.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
prroi_pool_forward_musa_kernel<float>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<float>(), rois.data_ptr<float>(),
output.data_ptr<float>(), pooled_height, pooled_width,
static_cast<float>(spatial_scale), channels, height, width);
AT_MUSA_CHECK(musaGetLastError());
}
void PrROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width,
float spatial_scale) {
int output_size = grad_output.numel();
int channels = grad_input.size(1);
int height = grad_input.size(2);
int width = grad_input.size(3);
c10::musa::MUSAGuard device_guard(grad_output.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
prroi_pool_backward_musa_kernel<float>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<float>(), rois.data_ptr<float>(),
grad_input.data_ptr<float>(), pooled_height, pooled_width,
static_cast<float>(spatial_scale), channels, height, width);
AT_MUSA_CHECK(musaGetLastError());
}
void PrROIPoolCoorBackwardMUSAKernelLauncher(Tensor output, Tensor grad_output,
Tensor input, Tensor rois,
Tensor grad_rois,
int pooled_height,
int pooled_width,
float spatial_scale) {
int output_size = grad_output.numel();
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
c10::musa::MUSAGuard device_guard(grad_output.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
prroi_pool_coor_backward_musa_kernel<float>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, output.data_ptr<float>(), grad_output.data_ptr<float>(),
input.data_ptr<float>(), rois.data_ptr<float>(),
grad_rois.data_ptr<float>(), pooled_height, pooled_width,
static_cast<float>(spatial_scale), channels, height, width);
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,60 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/hszhao/semseg/blob/master/lib/psa/src
#include <torch/serialize/tensor.h>
#include "psamask_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"
void PSAMaskForwardMUSAKernelLauncher(const int psa_type, const Tensor input,
Tensor output, const int num_,
const int h_feature, const int w_feature,
const int h_mask, const int w_mask,
const int half_h_mask,
const int half_w_mask) {
int nthreads = num_ * h_feature * w_feature;
musaStream_t stream = c10::musa::getCurrentMUSAStream();
if (psa_type == 0)
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "psamask_collect_forward_musa", [&] {
psamask_collect_forward_musa<scalar_t><<<nthreads, 512, 0, stream>>>(
nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
half_w_mask, input.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
});
else
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "psamask_distribute_forward_musa", [&] {
psamask_distribute_forward_musa<scalar_t>
<<<nthreads, 512, 0, stream>>>(
nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
half_w_mask, input.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
});
}
void PSAMaskBackwardMUSAKernelLauncher(
const int psa_type, const Tensor grad_output, Tensor grad_input,
const int num_, const int h_feature, const int w_feature, const int h_mask,
const int w_mask, const int half_h_mask, const int half_w_mask) {
int nthreads = num_ * h_feature * w_feature;
musaStream_t stream = c10::musa::getCurrentMUSAStream();
if (psa_type == 0)
AT_DISPATCH_FLOATING_TYPES(
grad_input.scalar_type(), "psamask_collect_backward_musa", [&] {
psamask_collect_backward_musa<scalar_t><<<nthreads, 512, 0, stream>>>(
nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
half_w_mask, grad_output.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>());
});
else
AT_DISPATCH_FLOATING_TYPES(
grad_input.scalar_type(), "psamask_distribute_backward_musa", [&] {
psamask_distribute_backward_musa<scalar_t>
<<<nthreads, 512, 0, stream>>>(
nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
half_w_mask, grad_output.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>());
});
}
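// Dispatch sketch (assumed enum meaning; the values are only implied by the
// branches above): psa_type 0 selects the COLLECT kernels and any other value
// the DISTRIBUTE kernels, so a caller might drive the forward pass as
//
//   const int psa_type = collect ? 0 : 1;
//   PSAMaskForwardMUSAKernelLauncher(psa_type, input, output, num_, h_feature,
//                                    w_feature, h_mask, w_mask, half_h_mask,
//                                    half_w_mask);
//   // half_h_mask / half_w_mask are the half extents of the mask window
//   // (typically (h_mask - 1) / 2 and (w_mask - 1) / 2).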

View File

@ -0,0 +1,53 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "riroi_align_rotated_musa_kernel.muh"
void RiROIAlignRotatedForwardMUSAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale,
const int num_samples, const bool clockwise, const int channels,
const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const int num_orientations,
at::Tensor output) {
const int output_size =
num_rois * pooled_height * pooled_width * channels * num_orientations;
c10::musa::MUSAGuard device_guard(features.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
features.scalar_type(), "riroi_align_rotated_forward_musa_kernel", ([&] {
const scalar_t *bottom_data = features.data_ptr<scalar_t>();
const scalar_t *rois_data = rois.data_ptr<scalar_t>();
scalar_t *top_data = output.data_ptr<scalar_t>();
riroi_align_rotated_forward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, bottom_data, rois_data, scalar_t(spatial_scale),
num_samples, clockwise, channels, height, width, pooled_height,
pooled_width, num_orientations, top_data);
}));
AT_MUSA_CHECK(musaGetLastError());
}
void RiROIAlignRotatedBackwardMUSAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int num_samples, const bool clockwise, const int channels,
const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const int num_orientations,
at::Tensor bottom_grad) {
const int output_size =
num_rois * pooled_height * pooled_width * channels * num_orientations;
c10::musa::MUSAGuard device_guard(top_grad.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
top_grad.scalar_type(), "riroi_align_rotated_backward_musa_kernel", ([&] {
const scalar_t *top_diff = top_grad.data_ptr<scalar_t>();
const scalar_t *rois_data = rois.data_ptr<scalar_t>();
scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
riroi_align_rotated_backward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, top_diff, rois_data, spatial_scale, num_samples,
clockwise, channels, height, width, pooled_height, pooled_width,
num_orientations, bottom_diff);
}));
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,58 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "roi_align_musa_kernel.muh"
void ROIAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned) {
int output_size = output.numel();
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
c10::musa::MUSAGuard device_guard(input.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "roi_align_forward_musa_kernel", [&] {
roi_align_forward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(),
rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
argmax_y.data_ptr<scalar_t>(), argmax_x.data_ptr<scalar_t>(),
aligned_height, aligned_width,
static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
aligned, channels, height, width);
});
AT_MUSA_CHECK(musaGetLastError());
}
void ROIAlignBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
Tensor argmax_y, Tensor argmax_x,
Tensor grad_input, int aligned_height,
int aligned_width, float spatial_scale,
int sampling_ratio, int pool_mode,
bool aligned) {
int output_size = grad_output.numel();
int channels = grad_input.size(1);
int height = grad_input.size(2);
int width = grad_input.size(3);
c10::musa::MUSAGuard device_guard(grad_output.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "roi_align_backward_musa_kernel", [&] {
roi_align_backward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<scalar_t>(),
rois.data_ptr<scalar_t>(), argmax_y.data_ptr<scalar_t>(),
argmax_x.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
aligned_height, aligned_width,
static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
aligned, channels, height, width);
});
AT_MUSA_CHECK(musaGetLastError());
}
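// Usage sketch (shapes assumed from the standard RoIAlign convention, not
// restated in this diff): input is (N, C, H, W), rois is (K, 5) with rows of
// (batch_idx, x1, y1, x2, y2), and output / argmax_y / argmax_x are
// (K, C, aligned_height, aligned_width); pool_mode 0 = max (argmax written),
// 1 = avg. For example:
//
//   auto output = at::zeros({K, C, aligned_height, aligned_width},
//                           input.options());
//   ROIAlignForwardMUSAKernelLauncher(input, rois, output, argmax_y, argmax_x,
//                                     aligned_height, aligned_width,
//                                     spatial_scale, sampling_ratio,
//                                     /*pool_mode=*/0, /*aligned=*/true);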

View File

@ -0,0 +1,45 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "roi_align_rotated_musa_kernel.muh"
void ROIAlignRotatedForwardMUSAKernelLauncher(
const at::Tensor input, const at::Tensor rois, const float spatial_scale,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor output) {
const int output_size = num_rois * pooled_height * pooled_width * channels;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "ROIAlignRotatedLauncherForward", ([&] {
const scalar_t *bottom_data = input.data_ptr<scalar_t>();
const scalar_t *rois_data = rois.data_ptr<scalar_t>();
scalar_t *top_data = output.data_ptr<scalar_t>();
roi_align_rotated_forward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
output_size, bottom_data, rois_data, scalar_t(spatial_scale),
sampling_ratio, aligned, clockwise, channels, height, width,
pooled_height, pooled_width, top_data);
}));
AT_MUSA_CHECK(musaGetLastError());
}
void ROIAlignRotatedBackwardMUSAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_height * pooled_width * channels;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.scalar_type(), "ROIAlignRotatedLauncherBackward", ([&] {
const scalar_t *top_diff = top_grad.data_ptr<scalar_t>();
const scalar_t *rois_data = rois.data_ptr<scalar_t>();
scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
roi_align_rotated_backward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
output_size, top_diff, rois_data, spatial_scale, sampling_ratio,
aligned, clockwise, channels, height, width, pooled_height,
pooled_width, bottom_diff);
}));
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,50 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "roi_pool_musa_kernel.muh"
void ROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax, int pooled_height,
int pooled_width, float spatial_scale) {
int output_size = output.numel();
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
c10::musa::MUSAGuard device_guard(input.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "roi_pool_forward_musa_kernel", [&] {
roi_pool_forward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(),
rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
argmax.data_ptr<int>(), pooled_height, pooled_width,
static_cast<scalar_t>(spatial_scale), channels, height, width);
});
AT_MUSA_CHECK(musaGetLastError());
}
void ROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
Tensor argmax, Tensor grad_input,
int pooled_height, int pooled_width,
float spatial_scale) {
int output_size = grad_output.numel();
int channels = grad_input.size(1);
int height = grad_input.size(2);
int width = grad_input.size(3);
c10::musa::MUSAGuard device_guard(grad_output.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "roi_pool_backward_musa_kernel", [&] {
roi_pool_backward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<scalar_t>(),
rois.data_ptr<scalar_t>(), argmax.data_ptr<int>(),
grad_input.data_ptr<scalar_t>(), pooled_height, pooled_width,
channels, height, width);
});
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,118 @@
// Modified from
// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
// Written by Shaoshuai Shi
// All Rights Reserved 2019.
#include <stdio.h>
#include "pytorch_musa_helper.hpp"
#include "roiaware_pool3d_musa_kernel.muh"
void RoiawarePool3dForwardMUSAKernelLauncher(
int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
int out_y, int out_z, const Tensor rois, const Tensor pts,
const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels,
Tensor pooled_features, int pool_method) {
  // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
  // coordinate
  // params pts: (npoints, 3) [x, y, z] in LiDAR coordinate
  // params pts_feature: (npoints, C)
  // params argmax: (N, out_x, out_y, out_z, C)
  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
  // params pooled_features: (N, out_x, out_y, out_z, C)
  // params pool_method: 0: max_pool, 1: avg_pool
c10::musa::MUSAGuard device_guard(pts_feature.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
Tensor pts_mask =
-at::ones({boxes_num, pts_num}, pts_feature.options().dtype(at::kInt));
dim3 blocks_mask(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rois.scalar_type(), "generate_pts_mask_for_box3d", [&] {
generate_pts_mask_for_box3d<scalar_t>
<<<blocks_mask, threads, 0, stream>>>(
boxes_num, pts_num, out_x, out_y, out_z,
rois.data_ptr<scalar_t>(), pts.data_ptr<scalar_t>(),
pts_mask.data_ptr<int>());
});
AT_MUSA_CHECK(musaGetLastError());
// TODO: Merge the collect and pool functions, SS
dim3 blocks_collect(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK));
AT_DISPATCH_INTEGRAL_TYPES(
pts_idx_of_voxels.scalar_type(), "collect_inside_pts_for_box3d", [&] {
collect_inside_pts_for_box3d<scalar_t>
<<<blocks_collect, threads, 0, stream>>>(
boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z,
pts_mask.data_ptr<int>(),
pts_idx_of_voxels.data_ptr<scalar_t>());
});
AT_MUSA_CHECK(musaGetLastError());
dim3 blocks_pool(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK),
channels, boxes_num);
if (pool_method == 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
pts_feature.scalar_type(), "roiaware_maxpool3d", [&] {
roiaware_maxpool3d<scalar_t><<<blocks_pool, threads, 0, stream>>>(
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y,
out_z, pts_feature.data_ptr<scalar_t>(),
pts_idx_of_voxels.data_ptr<int>(),
pooled_features.data_ptr<scalar_t>(), argmax.data_ptr<int>());
});
} else if (pool_method == 1) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
pts_feature.scalar_type(), "roiaware_avgpool3d", [&] {
roiaware_avgpool3d<scalar_t><<<blocks_pool, threads, 0, stream>>>(
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y,
out_z, pts_feature.data_ptr<scalar_t>(),
pts_idx_of_voxels.data_ptr<int>(),
pooled_features.data_ptr<scalar_t>());
});
}
AT_MUSA_CHECK(musaGetLastError());
}
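// Pipeline sketch (illustrative): the forward launcher chains three kernels --
// generate_pts_mask_for_box3d marks which points fall inside each box,
// collect_inside_pts_for_box3d gathers up to max_pts_each_voxel point indices
// per output voxel, and roiaware_{max,avg}pool3d reduces their features.
// Conceptually, per box b, voxel v, channel c:
//
//   // pool_method == 0 (max): pooled[b][v][c] = max_i feat[idx[b][v][i]][c],
//   //                         argmax[b][v][c] = index attaining that max
//   // pool_method == 1 (avg): pooled[b][v][c] = mean_i feat[idx[b][v][i]][c]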
void RoiawarePool3dBackwardMUSAKernelLauncher(
int boxes_num, int out_x, int out_y, int out_z, int channels,
int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax,
const Tensor grad_out, Tensor grad_in, int pool_method) {
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
// params pool_method: 0: max_pool, 1: avg_pool
c10::musa::MUSAGuard device_guard(grad_out.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
dim3 blocks(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
boxes_num);
dim3 threads(THREADS_PER_BLOCK);
if (pool_method == 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_in.scalar_type(), "roiaware_maxpool3d_backward", [&] {
roiaware_maxpool3d_backward<scalar_t><<<blocks, threads, 0, stream>>>(
boxes_num, channels, out_x, out_y, out_z, argmax.data_ptr<int>(),
grad_out.data_ptr<scalar_t>(), grad_in.data_ptr<scalar_t>());
});
} else if (pool_method == 1) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_in.scalar_type(), "roiaware_avgpool3d_backward", [&] {
roiaware_avgpool3d_backward<scalar_t><<<blocks, threads, 0, stream>>>(
boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel,
pts_idx_of_voxels.data_ptr<int>(), grad_out.data_ptr<scalar_t>(),
grad_in.data_ptr<scalar_t>());
});
}
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,60 @@
/*
Modified from
https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu
Point cloud feature pooling
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <math.h>
#include <stdio.h>
#include "pytorch_musa_helper.hpp"
#include "roipoint_pool3d_musa_kernel.muh"
void RoIPointPool3dForwardMUSAKernelLauncher(
int batch_size, int pts_num, int boxes_num, int feature_in_len,
int sampled_pts_num, const Tensor xyz, const Tensor boxes3d,
const Tensor pts_feature, Tensor pooled_features,
Tensor pooled_empty_flag) {
Tensor pts_assign = at::empty({batch_size, pts_num, boxes_num},
boxes3d.options().dtype(at::kInt));
c10::musa::MUSAGuard device_guard(xyz.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
xyz.scalar_type(), "assign_pts_to_box3d", [&] {
assign_pts_to_box3d<scalar_t><<<blocks, threads, 0, stream>>>(
batch_size, pts_num, boxes_num, xyz.data_ptr<scalar_t>(),
boxes3d.data_ptr<scalar_t>(), pts_assign.data_ptr<int>());
});
Tensor pts_idx = at::empty({batch_size, boxes_num, sampled_pts_num},
boxes3d.options().dtype(at::kInt));
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks2(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK), batch_size);
get_pooled_idx<<<blocks2, threads, 0, stream>>>(
batch_size, pts_num, boxes_num, sampled_pts_num,
pts_assign.data_ptr<int>(), pts_idx.data_ptr<int>(),
pooled_empty_flag.data_ptr<int>());
dim3 blocks_pool(GET_BLOCKS(sampled_pts_num, THREADS_PER_BLOCK), boxes_num,
batch_size);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
xyz.scalar_type(), "roipoint_pool3d_forward", [&] {
roipoint_pool3d_forward<scalar_t><<<blocks_pool, threads, 0, stream>>>(
batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
xyz.data_ptr<scalar_t>(), pts_idx.data_ptr<int>(),
pts_feature.data_ptr<scalar_t>(),
pooled_features.data_ptr<scalar_t>(),
pooled_empty_flag.data_ptr<int>());
});
}

View File

@ -0,0 +1,53 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu
#include "pytorch_musa_helper.hpp"
#include "rotated_feature_align_musa_kernel.muh"
void RotatedFeatureAlignForwardMUSAKernelLauncher(const Tensor features,
const Tensor best_bboxes,
const float spatial_scale,
const int points,
Tensor output) {
c10::musa::MUSAGuard device_guard(features.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
const int output_size = features.numel();
AT_DISPATCH_FLOATING_TYPES(
features.scalar_type(), "rotated_feature_align_forward_musa_kernel",
([&] {
const scalar_t* bottom_data = features.data_ptr<scalar_t>();
const scalar_t* bboxes_data = best_bboxes.data_ptr<scalar_t>();
scalar_t* top_data = output.data_ptr<scalar_t>();
rotated_feature_align_forward_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, points, bottom_data, bboxes_data,
scalar_t(spatial_scale), features.size(1), features.size(2),
features.size(3), top_data);
}));
AT_MUSA_CHECK(musaGetLastError());
}
void RotatedFeatureAlignBackwardMUSAKernelLauncher(const Tensor top_grad,
const Tensor best_bboxes,
const float spatial_scale,
const int points,
Tensor bottom_grad) {
c10::musa::MUSAGuard device_guard(top_grad.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
const int output_size = top_grad.numel();
AT_DISPATCH_FLOATING_TYPES(
top_grad.scalar_type(), "rotated_feature_align_backward_musa_kernel",
([&] {
const scalar_t* top_diff = top_grad.data_ptr<scalar_t>();
const scalar_t* bboxes_data = best_bboxes.data_ptr<scalar_t>();
scalar_t* bottom_diff = bottom_grad.data_ptr<scalar_t>();
rotated_feature_align_backward_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, points, top_diff, bboxes_data,
scalar_t(spatial_scale), top_grad.size(1), top_grad.size(2),
top_grad.size(3), bottom_diff);
}));
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,132 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <stdio.h>
#include <stdlib.h>
#include <torch/types.h>
#include "pytorch_musa_helper.hpp"
#include "scatter_points_musa_kernel.muh"
std::vector<at::Tensor> DynamicPointToVoxelForwardMUSAKernelLauncher(
const at::Tensor &feats, const at::Tensor &coors,
const reduce_t reduce_type) {
const int num_input = feats.size(0);
const int num_feats = feats.size(1);
if (num_input == 0)
return {feats.clone().detach(), coors.clone().detach(),
coors.new_empty({0}, torch::kInt32),
coors.new_empty({0}, torch::kInt32)};
at::Tensor out_coors;
at::Tensor coors_map;
at::Tensor reduce_count;
auto coors_clean = coors.masked_fill(coors.lt(0).any(-1, true), -1);
std::tie(out_coors, coors_map, reduce_count) =
at::unique_dim(coors_clean, 0, true, true, true);
if (out_coors[0][0].lt(0).item<bool>()) {
    // the first element of out_coors is (-1, -1, -1) and should be removed
out_coors = out_coors.slice(0, 1);
reduce_count = reduce_count.slice(0, 1);
coors_map = coors_map - 1;
}
coors_map = coors_map.to(torch::kInt32);
reduce_count = reduce_count.to(torch::kInt32);
auto reduced_feats =
at::empty({out_coors.size(0), num_feats}, feats.options());
c10::musa::MUSAGuard device_guard(feats.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
feats.scalar_type(), "feats_reduce_kernel", ([&] {
if (reduce_type == reduce_t::MAX)
reduced_feats.fill_(-std::numeric_limits<scalar_t>::infinity());
else
reduced_feats.fill_(static_cast<scalar_t>(0));
dim3 blocks(std::min(
at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
dim3 threads(THREADS_PER_BLOCK);
feats_reduce_kernel<<<blocks, threads, 0, stream>>>(
feats.data_ptr<scalar_t>(), coors_map.data_ptr<int32_t>(),
reduced_feats.data_ptr<scalar_t>(), num_input, num_feats,
reduce_type);
if (reduce_type == reduce_t::MEAN)
reduced_feats /= reduce_count.unsqueeze(-1).to(reduced_feats.dtype());
}));
AT_MUSA_CHECK(musaGetLastError());
return {reduced_feats, out_coors, coors_map, reduce_count};
}
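// Reduction sketch (illustrative): at::unique_dim over the cleaned coordinates
// yields one row per occupied voxel plus an inverse map (coors_map), so the
// scatter done by feats_reduce_kernel is conceptually
//
//   for (int i = 0; i < num_input; ++i) {
//     int v = coors_map[i];  // owning voxel; -1 after dropping invalid coors
//     if (v < 0) continue;
//     // MAX: element-wise max; SUM/MEAN: accumulate, with MEAN divided by
//     // reduce_count[v] afterwards (done on the reduced tensor above).
//     reduce_into(reduced_feats[v], feats[i], reduce_type);  // pseudocode
//   }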
void DynamicPointToVoxelBackwardMUSAKernelLauncher(
at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats,
const at::Tensor &feats, const at::Tensor &reduced_feats,
const at::Tensor &coors_map, const at::Tensor &reduce_count,
const reduce_t reduce_type) {
const int num_input = feats.size(0);
const int num_reduced = reduced_feats.size(0);
const int num_feats = feats.size(1);
grad_feats.fill_(0);
// copy voxel grad to points
if (num_input == 0 || num_reduced == 0) return;
c10::musa::MUSAGuard device_guard(feats.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
if (reduce_type == reduce_t::MEAN || reduce_type == reduce_t::SUM) {
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(), "add_reduce_traceback_grad_kernel",
([&] {
dim3 blocks(std::min(
at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
dim3 threads(THREADS_PER_BLOCK);
add_reduce_traceback_grad_kernel<<<blocks, threads, 0, stream>>>(
grad_feats.data_ptr<scalar_t>(),
grad_reduced_feats.data_ptr<scalar_t>(),
coors_map.data_ptr<int32_t>(), reduce_count.data_ptr<int32_t>(),
num_input, num_feats, reduce_type);
}));
AT_MUSA_CHECK(musaGetLastError());
} else {
auto reduce_from = at::full({num_reduced, num_feats}, num_input,
coors_map.options().dtype(torch::kInt32));
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(),
"max_reduce_traceback_scatter_idx_kernel", ([&] {
dim3 blocks(std::min(
at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
dim3 threads(THREADS_PER_BLOCK);
max_reduce_traceback_scatter_idx_kernel<<<blocks, threads, 0,
stream>>>(
feats.data_ptr<scalar_t>(), reduced_feats.data_ptr<scalar_t>(),
reduce_from.data_ptr<int32_t>(), coors_map.data_ptr<int32_t>(),
num_input, num_feats);
}));
AT_MUSA_CHECK(musaGetLastError());
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(),
"max_reduce_traceback_scatter_idx_kernel", ([&] {
dim3 blocks(
std::min(at::musa::ATenCeilDiv(num_reduced, THREADS_PER_BLOCK),
maxGridDim));
dim3 threads(THREADS_PER_BLOCK);
max_reduce_scatter_grad_kernel<<<blocks, threads, 0, stream>>>(
grad_feats.data_ptr<scalar_t>(),
grad_reduced_feats.data_ptr<scalar_t>(),
reduce_from.data_ptr<int32_t>(), num_reduced, num_feats);
}));
AT_MUSA_CHECK(musaGetLastError());
}
}

View File

@ -0,0 +1,486 @@
// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <ATen/ATen.h>
// clang-format off
// TODO: make spconv_utils.h order agnostic
#include "../spconv_utils.h"
// clang-format on
#include <utils/spconv/spconv/maxpool.h>
#include <utils/spconv/spconv/mp_helper.h>
#include <utils/spconv/tensorview/helper_launch.h>
#include <utils/spconv/tensorview/tensorview.h>
#include <chrono>
#include <limits>
#include <type_traits>
#include <utils/spconv/tensorview/helper_kernel.muh>
#include "pytorch_musa_helper.hpp"
template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolFwdBlockKernel(scalar_t *outFeatures,
const scalar_t *inFeatures,
const Index *indicesIn,
const Index *indicesOut, int numHot,
int numPlanes) {
scalar_t in, out;
int ILPStrideY[NumILP];
Index idxo, idxi;
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
outFeatures += blockIdx.y * NumTLP;
inFeatures += blockIdx.y * NumTLP;
for (int ix = blockIdx.x * blockDim.x; ix < numHot;
ix += blockDim.x * gridDim.x) {
{
#pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) {
idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
in = inFeatures[idxi];
out = outFeatures[idxo];
if (in > out) {
outFeatures[idxo] = in;
}
}
}
}
}
template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolFwdGenericBlockKernel(scalar_t *outFeatures,
const scalar_t *inFeatures,
const Index *indicesIn,
const Index *indicesOut,
int numHot, int numPlanes) {
int ILPStrideX[NumILP];
Index RI[NumILP];
Index RO[NumILP];
scalar_t in, out;
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++) {
RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
}
for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) {
in = inFeatures[RI[ilp] + iy];
out = outFeatures[RO[ilp] + iy];
if (in > out) {
outFeatures[RO[ilp] + iy] = in;
}
}
}
}
}
template <typename scalar_t, typename Index, int NumTLP, int NumILP,
typename VecType>
__global__ void maxPoolFwdVecBlockKernel(scalar_t *outFeatures,
const scalar_t *inFeatures,
const Index *indicesIn,
const Index *indicesOut, int numHot,
int numPlanes) {
int ILPStrideY[NumILP];
constexpr int vecloadFactor = sizeof(VecType) / sizeof(scalar_t);
scalar_t bufi[vecloadFactor];
scalar_t bufo[vecloadFactor];
Index idxi, idxo;
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
outFeatures += blockIdx.y * NumTLP;
inFeatures += blockIdx.y * NumTLP;
for (int ix = blockIdx.x * blockDim.x * vecloadFactor; ix < numHot;
ix += blockDim.x * gridDim.x * vecloadFactor) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) {
idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
reinterpret_cast<VecType *>(bufo)[0] =
reinterpret_cast<VecType *>(outFeatures)[idxo];
reinterpret_cast<VecType *>(bufi)[0] =
reinterpret_cast<const VecType *>(inFeatures)[idxi];
#pragma unroll
for (int i = 0; i < vecloadFactor; i++) {
if (bufi[i] > bufo[i]) {
bufo[i] = bufi[i];
}
}
reinterpret_cast<VecType *>(outFeatures)[idxo] =
reinterpret_cast<VecType *>(bufo)[0];
}
}
}
template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolFwdGenericKernel(scalar_t *outFeatures,
const scalar_t *inFeatures,
const Index *indicesIn,
const Index *indicesOut, int numHot,
int numPlanes) {
int ILPStrideX[NumILP];
Index RI[NumILP];
Index RO[NumILP];
scalar_t in, out;
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++) {
if (ix + ILPStrideX[ilp] < numHot) {
RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
}
}
for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) {
if (ix + ILPStrideX[ilp] < numHot) {
in = inFeatures[RI[ilp] + iy];
out = outFeatures[RO[ilp] + iy];
if (in > out) {
outFeatures[RO[ilp] + iy] = in;
}
}
}
}
}
}
template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolBwdBlockKernel(const scalar_t *outFeatures,
const scalar_t *inFeatures,
const scalar_t *fout, scalar_t *fin,
const Index *indicesIn,
const Index *indicesOut, int numHot,
int numPlanes) {
scalar_t in, out;
Index idxo, idxi;
int ILPStrideY[NumILP];
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
outFeatures += blockIdx.y * NumTLP;
inFeatures += blockIdx.y * NumTLP;
fout += blockIdx.y * NumTLP;
fin += blockIdx.y * NumTLP;
for (int ix = blockIdx.x * blockDim.x; ix < numHot;
ix += blockDim.x * gridDim.x) {
{
#pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) {
idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
in = inFeatures[idxi];
out = outFeatures[idxo];
if (in == out) {
fin[idxi] += fout[idxo];
}
}
}
}
}
template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolBwdGenericBlockKernel(
const scalar_t *outFeatures, const scalar_t *inFeatures,
const scalar_t *fout, scalar_t *fin, const Index *indicesIn,
const Index *indicesOut, int numHot, int numPlanes) {
int ILPStrideX[NumILP];
Index RI[NumILP];
Index RO[NumILP];
scalar_t in, out;
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++) {
RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
}
for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) {
in = inFeatures[RI[ilp] + iy];
out = outFeatures[RO[ilp] + iy];
if (in == out) {
fin[RI[ilp] + iy] += fout[RO[ilp] + iy];
}
}
}
}
}
template <typename scalar_t, typename Index, int NumTLP, int NumILP,
typename VecType>
__global__ void maxPoolBwdVecBlockKernel(const scalar_t *outFeatures,
const scalar_t *inFeatures,
const scalar_t *fout, scalar_t *fin,
const Index *indicesIn,
const Index *indicesOut, int numHot,
int numPlanes) {
int ILPStrideY[NumILP];
constexpr int vecloadFactor = sizeof(VecType) / sizeof(scalar_t);
scalar_t bufi[vecloadFactor];
scalar_t bufo[vecloadFactor];
scalar_t bufdi[vecloadFactor];
scalar_t bufdo[vecloadFactor];
Index idxi, idxo;
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
outFeatures += blockIdx.y * NumTLP;
inFeatures += blockIdx.y * NumTLP;
for (int ix = blockIdx.x * blockDim.x * vecloadFactor; ix < numHot;
ix += blockDim.x * gridDim.x * vecloadFactor) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) {
idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
reinterpret_cast<VecType *>(bufo)[0] =
reinterpret_cast<const VecType *>(outFeatures)[idxo];
reinterpret_cast<VecType *>(bufi)[0] =
reinterpret_cast<const VecType *>(inFeatures)[idxi];
reinterpret_cast<VecType *>(bufdo)[0] =
reinterpret_cast<const VecType *>(fout)[idxo];
reinterpret_cast<VecType *>(bufdi)[0] =
reinterpret_cast<VecType *>(fin)[idxi];
#pragma unroll
for (int i = 0; i < vecloadFactor; i++) {
if (bufi[i] == bufo[i]) {
bufdi[i] += bufdo[i];
}
}
reinterpret_cast<VecType *>(fin)[idxi] =
reinterpret_cast<VecType *>(bufdi)[0];
}
}
}
template <typename scalar_t, typename Index, int NumTLP, int NumILP>
__global__ void maxPoolBwdGenericKernel(const scalar_t *outFeatures,
const scalar_t *inFeatures,
const scalar_t *fout, scalar_t *fin,
const Index *indicesIn,
const Index *indicesOut, int numHot,
int numPlanes) {
int ILPStrideX[NumILP];
Index RI[NumILP];
Index RO[NumILP];
scalar_t in, out;
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++)
ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ilp++) {
if (ix + ILPStrideX[ilp] < numHot) {
RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
}
}
for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
for (int ilp = 0; ilp < NumILP; ++ilp) {
if (ix + ILPStrideX[ilp] < numHot) {
in = inFeatures[RI[ilp] + iy];
out = outFeatures[RO[ilp] + iy];
if (in == out) {
fin[RI[ilp] + iy] += fout[RO[ilp] + iy];
}
}
}
}
}
}
namespace functor {
template <typename scalar_t, typename Index>
struct SparseMaxPoolForwardFunctor<tv::TorchGPU, scalar_t, Index> {
using vecload_type_t =
std::conditional_t<std::is_same<scalar_t, at::Half>::value, int2, int4>;
using kernel_block_t = mp_list_c<int, 64, 32, 16>;
void operator()(const tv::TorchGPU &d, tv::TensorView<scalar_t> outFeatures,
tv::TensorView<const scalar_t> inFeatures,
tv::TensorView<const Index> indices, int size) {
if (size <= 0) return;
int numPlanes = inFeatures.dim(1);
bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(scalar_t);
mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &indices,
&notFound](auto NumTLP) {
constexpr int NumILP = NumTLP / 4;
int numHotBlock = (size / NumTLP) * NumTLP;
if (notFound) {
if (numPlanes % NumTLP == 0) {
if (numHotBlock >= NumTLP) {
maxPoolFwdVecBlockKernel<scalar_t, Index, int(NumTLP), NumILP,
vecload_type_t>
<<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
d.getStream()>>>(outFeatures.data(), inFeatures.data(),
indices.subview(0).data(),
indices.subview(1).data(), numHotBlock,
numPlanes / vecloadFactor);
TV_CHECK_MUSA_ERR();
}
if (size > numHotBlock) {
maxPoolFwdGenericKernel<scalar_t, Index, int(NumTLP), NumILP>
<<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
0, d.getStream()>>>(outFeatures.data(), inFeatures.data(),
indices.subview(0).data() + numHotBlock,
indices.subview(1).data() + numHotBlock,
size - numHotBlock, numPlanes);
TV_CHECK_MUSA_ERR();
}
notFound = false;
}
}
});
if (notFound) {
constexpr int NumTLP = 64;
constexpr int NumILP = NumTLP / 4;
int numHotBlock = (size / NumTLP) * NumTLP;
if (numHotBlock >= NumTLP) {
maxPoolFwdGenericBlockKernel<scalar_t, Index, NumTLP, NumILP>
<<<dim3(size / NumTLP, tv::launch::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
outFeatures.data(), inFeatures.data(),
indices.subview(0).data(), indices.subview(1).data(),
numHotBlock, numPlanes);
TV_CHECK_MUSA_ERR();
}
if (size > numHotBlock) {
maxPoolFwdGenericKernel<scalar_t, Index, NumTLP, NumILP>
<<<dim3(1, tv::launch::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
outFeatures.data(), inFeatures.data(),
indices.subview(0).data() + numHotBlock,
indices.subview(1).data() + numHotBlock, size - numHotBlock,
numPlanes);
TV_CHECK_MUSA_ERR();
}
}
}
};
template <typename scalar_t, typename Index>
struct SparseMaxPoolBackwardFunctor<tv::TorchGPU, scalar_t, Index> {
using vecload_type_t =
std::conditional_t<std::is_same<scalar_t, at::Half>::value, int2, int4>;
using kernel_block_t = mp_list_c<int, 64, 32, 16>;
void operator()(const tv::TorchGPU &d,
tv::TensorView<const scalar_t> outFeatures,
tv::TensorView<const scalar_t> inFeatures,
tv::TensorView<const scalar_t> fout,
tv::TensorView<scalar_t> fin,
tv::TensorView<const Index> indices, int size) {
if (size <= 0) return;
int numPlanes = inFeatures.dim(1);
bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(scalar_t);
mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &fout, &fin,
&indices, &notFound](auto NumTLP) {
constexpr int NumILP = NumTLP / 4;
int numHotBlock = (size / NumTLP) * NumTLP;
if (notFound) {
if (numPlanes % NumTLP == 0) {
if (numHotBlock >= NumTLP) {
maxPoolBwdVecBlockKernel<scalar_t, Index, int(NumTLP), NumILP,
vecload_type_t>
<<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
d.getStream()>>>(outFeatures.data(), inFeatures.data(),
fout.data(), fin.data(),
indices.subview(0).data(),
indices.subview(1).data(), numHotBlock,
numPlanes / vecloadFactor);
TV_CHECK_MUSA_ERR();
}
if (size > numHotBlock) {
maxPoolBwdGenericKernel<scalar_t, Index, int(NumTLP), NumILP>
<<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
0, d.getStream()>>>(outFeatures.data(), inFeatures.data(),
fout.data(), fin.data(),
indices.subview(0).data() + numHotBlock,
indices.subview(1).data() + numHotBlock,
size - numHotBlock, numPlanes);
TV_CHECK_MUSA_ERR();
}
notFound = false;
}
}
});
if (notFound) {
constexpr int NumTLP = 64;
constexpr int NumILP = NumTLP / 4;
int numHotBlock = (size / NumTLP) * NumTLP;
if (numHotBlock >= NumTLP) {
maxPoolBwdGenericBlockKernel<scalar_t, Index, NumTLP, NumILP>
<<<dim3(size / NumTLP, tv::launch::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
outFeatures.data(), inFeatures.data(), fout.data(), fin.data(),
indices.subview(0).data(), indices.subview(1).data(),
numHotBlock, numPlanes);
TV_CHECK_MUSA_ERR();
}
if (size > numHotBlock) {
maxPoolBwdGenericKernel<scalar_t, Index, NumTLP, NumILP>
<<<dim3(1, tv::launch::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
outFeatures.data(), inFeatures.data(), fout.data(), fin.data(),
indices.subview(0).data() + numHotBlock,
indices.subview(1).data() + numHotBlock, size - numHotBlock,
numPlanes);
TV_CHECK_MUSA_ERR();
}
}
}
};
} // namespace functor
#define DECLARE_GPU_SPECS_T_INDEX(scalar_t, Index) \
template struct functor::SparseMaxPoolForwardFunctor<tv::TorchGPU, scalar_t, \
Index>; \
template struct functor::SparseMaxPoolBackwardFunctor<tv::TorchGPU, \
scalar_t, Index>;
#define DECLARE_GPU_SPECS(scalar_t) DECLARE_GPU_SPECS_T_INDEX(scalar_t, int);
DECLARE_GPU_SPECS(float);
DECLARE_GPU_SPECS(double);
DECLARE_GPU_SPECS(at::Half);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_T_INDEX
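// Tiling sketch (illustrative): NumTLP is the tile width along the feature
// planes and NumILP the per-thread unroll factor along the "hot" index
// dimension; the functors try the vectorized block kernels whenever numPlanes
// divides a tile size from kernel_block_t (64/32/16) and fall back to the
// generic kernels for the tail, roughly
//
//   constexpr int NumTLP = 64, NumILP = NumTLP / 4;
//   int numHotBlock = (size / NumTLP) * NumTLP;  // full tiles
//   // full tiles -> maxPoolFwdVecBlockKernel / maxPoolFwdGenericBlockKernel
//   // remainder  -> maxPoolFwdGenericKernel (bounds-checked per element)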

View File

@ -0,0 +1,91 @@
#include <musa_runtime_api.h>
#include <torch/script.h>
// clang-format off
// TODO: make spconv_utils.h order agnostic
#include "../spconv_utils.h"
// clang-format on
#include <utils/spconv/spconv/maxpool.h>
#include "pytorch_musa_helper.hpp"
torch::Tensor IndiceMaxpoolForwardMUSAKernelLauncher(torch::Tensor features,
torch::Tensor indicePairs,
torch::Tensor indiceNum,
int64_t numAct) {
c10::musa::MUSAGuard device_guard(features.device());
auto device = features.device().type();
auto kernelVolume = indicePairs.size(0);
auto numInPlanes = features.size(1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device());
torch::Tensor output = torch::zeros({numAct, numInPlanes}, options);
for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0) {
continue;
}
AT_DISPATCH_FLOATING_TYPES(
features.scalar_type(), "IndiceMaxpoolForwardKernel", [&] {
if (device == torch::kCPU) {
functor::SparseMaxPoolForwardFunctor<tv::CPU, scalar_t, int>
forwardFtor;
forwardFtor(tv::CPU(), tv::torch2tv<scalar_t>(output),
tv::torch2tv<const scalar_t>(features),
tv::torch2tv<const int>(indicePairs).subview(i), nHot);
} else {
functor::SparseMaxPoolForwardFunctor<tv::TorchGPU, scalar_t, int>
forwardFtor;
forwardFtor(tv::TorchGPU(), tv::torch2tv<scalar_t>(output),
tv::torch2tv<const scalar_t>(features),
tv::torch2tv<const int>(indicePairs).subview(i), nHot);
TV_CHECK_MUSA_ERR();
}
});
}
return output;
}
torch::Tensor IndiceMaxpoolBackwardMUSAKernelLauncher(torch::Tensor features,
torch::Tensor outFeatures,
torch::Tensor outGrad,
torch::Tensor indicePairs,
torch::Tensor indiceNum) {
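// Backward of the sparse max pool: for every valid index pair, the gradient in outGrad is
// routed back to the input element that produced the maximum (layouts as in the forward
// launcher above); inputGrad is returned.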
c10::musa::MUSAGuard device_guard(features.device());
auto device = features.device().type();
auto numInPlanes = features.size(1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device());
torch::Tensor inputGrad = torch::zeros(features.sizes(), options);
auto kernelVolume = indicePairs.size(0);
for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0) {
continue;
}
AT_DISPATCH_FLOATING_TYPES(
features.scalar_type(), "IndiceMaxpoolBackwardKernel", [&] {
if (device == torch::kCPU) {
functor::SparseMaxPoolBackwardFunctor<tv::CPU, scalar_t, int>
backwardFtor;
backwardFtor(tv::CPU(), tv::torch2tv<const scalar_t>(outFeatures),
tv::torch2tv<const scalar_t>(features),
tv::torch2tv<const scalar_t>(outGrad),
tv::torch2tv<scalar_t>(inputGrad),
tv::torch2tv<const int>(indicePairs).subview(i), nHot);
} else {
functor::SparseMaxPoolBackwardFunctor<tv::TorchGPU, scalar_t, int>
backwardFtor;
backwardFtor(tv::TorchGPU(),
tv::torch2tv<const scalar_t>(outFeatures),
tv::torch2tv<const scalar_t>(features),
tv::torch2tv<const scalar_t>(outGrad),
tv::torch2tv<scalar_t>(inputGrad),
tv::torch2tv<const int>(indicePairs).subview(i), nHot);
TV_CHECK_MUSA_ERR();
}
});
}
return inputGrad;
}

View File

@ -0,0 +1,110 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "sync_bn_musa_kernel.muh"
void SyncBNForwardMeanMUSAKernelLauncher(const Tensor input, Tensor mean) {
int num = input.size(0);
int channels = input.size(1);
int spatial = input.size(2);
c10::musa::MUSAGuard device_guard(input.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
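// One block per channel; each block reduces the num * spatial values of its channel into
// the per-channel mean.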
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "sync_bn_forward_mean_musa_kernel", [&] {
sync_bn_forward_mean_musa_kernel<scalar_t>
<<<channels, THREADS_PER_BLOCK, 0, stream>>>(
input.data_ptr<scalar_t>(), mean.data_ptr<float>(), num,
channels, spatial);
});
AT_MUSA_CHECK(musaGetLastError());
}
void SyncBNForwardVarMUSAKernelLauncher(const Tensor input, const Tensor mean,
Tensor var) {
int num = input.size(0);
int channels = input.size(1);
int spatial = input.size(2);
c10::musa::MUSAGuard device_guard(input.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "sync_bn_forward_mean_musa_kernel", [&] {
sync_bn_forward_var_musa_kernel<scalar_t>
<<<channels, THREADS_PER_BLOCK, 0, stream>>>(
input.data_ptr<scalar_t>(), mean.data_ptr<float>(),
var.data_ptr<float>(), num, channels, spatial);
});
AT_MUSA_CHECK(musaGetLastError());
}
void SyncBNForwardOutputMUSAKernelLauncher(
const Tensor input, const Tensor mean, const Tensor var,
Tensor running_mean, Tensor running_var, const Tensor weight,
const Tensor bias, Tensor norm, Tensor std, Tensor output, float eps,
float momentum, int group_size) {
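// Normalizes the input with the synchronized mean/var, applies weight and bias, stores
// norm/std for the backward pass, and updates the running statistics with momentum.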
int num = input.size(0);
int channels = input.size(1);
int spatial = input.size(2);
c10::musa::MUSAGuard device_guard(input.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "sync_bn_forward_mean_musa_kernel", [&] {
sync_bn_forward_output_musa_kernel<scalar_t>
<<<channels, THREADS_PER_BLOCK, 0, stream>>>(
input.data_ptr<scalar_t>(), mean.data_ptr<float>(),
var.data_ptr<float>(), running_mean.data_ptr<float>(),
running_var.data_ptr<float>(), weight.data_ptr<float>(),
bias.data_ptr<float>(), norm.data_ptr<float>(),
std.data_ptr<float>(), output.data_ptr<scalar_t>(), num,
channels, spatial, eps, momentum, group_size);
});
AT_MUSA_CHECK(musaGetLastError());
}
void SyncBNBackwardParamMUSAKernelLauncher(const Tensor grad_output,
const Tensor norm,
Tensor grad_weight,
Tensor grad_bias) {
int num = grad_output.size(0);
int channels = grad_output.size(1);
int spatial = grad_output.size(2);
c10::musa::MUSAGuard device_guard(grad_output.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
grad_output.scalar_type(), "sync_bn_backward_param_musa_kernel", [&] {
sync_bn_backward_param_musa_kernel<scalar_t>
<<<channels, THREADS_PER_BLOCK, 0, stream>>>(
grad_output.data_ptr<scalar_t>(), norm.data_ptr<float>(),
grad_weight.data_ptr<float>(), grad_bias.data_ptr<float>(), num,
channels, spatial);
});
AT_MUSA_CHECK(musaGetLastError());
}
void SyncBNBackwardDataMUSAKernelLauncher(const Tensor grad_output,
const Tensor weight,
const Tensor grad_weight,
const Tensor grad_bias,
const Tensor norm, const Tensor std,
Tensor grad_input) {
int output_size = grad_input.numel();
int num = grad_input.size(0);
int channels = grad_input.size(1);
int spatial = grad_input.size(2);
c10::musa::MUSAGuard device_guard(grad_input.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
grad_output.scalar_type(), "sync_bn_backward_data_musa_kernel", [&] {
sync_bn_backward_data_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<scalar_t>(),
weight.data_ptr<float>(), grad_weight.data_ptr<float>(),
grad_bias.data_ptr<float>(), norm.data_ptr<float>(),
std.data_ptr<float>(), grad_input.data_ptr<scalar_t>(), num,
channels, spatial);
});
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,66 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "pytorch_musa_helper.hpp"
#include "three_interpolate_musa_kernel.muh"
void ThreeInterpolateForwardMUSAKernelLauncher(int b, int c, int m, int n,
const Tensor points,
const Tensor idx,
const Tensor weight,
Tensor out) {
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
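// out[b, c, n] = sum_{k < 3} weight[b, n, k] * points[b, c, idx[b, n, k]]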
c10::musa::MUSAGuard device_guard(points.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
points.scalar_type(), "three_interpolate_forward_musa_kernel", [&] {
three_interpolate_forward_musa_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
b, c, m, n, points.data_ptr<scalar_t>(), idx.data_ptr<int>(),
weight.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
});
AT_MUSA_CHECK(musaGetLastError());
}
void ThreeInterpolateBackwardMUSAKernelLauncher(int b, int c, int n, int m,
const Tensor grad_out,
const Tensor idx,
const Tensor weight,
Tensor grad_points) {
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
c10::musa::MUSAGuard device_guard(grad_out.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_out.scalar_type(), "three_interpolate_backward_musa_kernel", [&] {
three_interpolate_backward_musa_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
b, c, n, m, grad_out.data_ptr<scalar_t>(), idx.data_ptr<int>(),
weight.data_ptr<scalar_t>(), grad_points.data_ptr<scalar_t>());
});
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,35 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "pytorch_musa_helper.hpp"
#include "three_nn_musa_kernel.muh"
void ThreeNNForwardMUSAKernelLauncher(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2,
Tensor idx) {
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
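// For each unknown point, the kernel finds its three nearest neighbors among `known` and
// writes their squared distances and indices.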
c10::musa::MUSAGuard device_guard(unknown.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES(
unknown.scalar_type(), "three_nn_forward_musa_kernel", [&] {
three_nn_forward_musa_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
b, n, m, unknown.data_ptr<scalar_t>(), known.data_ptr<scalar_t>(),
dist2.data_ptr<scalar_t>(), idx.data_ptr<int>());
});
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,55 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_musa_helper.hpp"
#include "pytorch_device_registry.hpp"
#include "tin_shift_musa_kernel.muh"
void TINShiftForwardMUSAKernelLauncher(Tensor input, Tensor shift,
Tensor output) {
int output_size = output.numel();
int batch_size = input.size(0);
int t_size = input.size(1);
int channels = input.size(2);
int hw_size = input.size(3);
int group_size = shift.size(1);
int group_channel = channels / group_size;
int num_kernels = batch_size * hw_size * channels;
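// One thread per (batch, spatial, channel) element; each channel group is shifted along
// the temporal dimension by shift[batch, group] frames.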
c10::musa::MUSAGuard device_guard(input.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "tin_shift_forward_musa_kernel", [&] {
tin_shift_forward_musa_kernel<scalar_t>
<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(), shift.data_ptr<int>(),
output.data_ptr<scalar_t>(), batch_size, channels, t_size,
hw_size, group_size, group_channel);
});
AT_MUSA_CHECK(musaGetLastError());
}
void TINShiftBackwardMUSAKernelLauncher(Tensor grad_output, Tensor shift,
Tensor grad_input) {
int output_size = grad_output.numel();
int batch_size = grad_output.size(0);
int t_size = grad_output.size(1);
int channels = grad_output.size(2);
int hw_size = grad_output.size(3);
int group_size = shift.size(1);
int group_channel = channels / group_size;
int num_kernels = batch_size * hw_size * channels;
c10::musa::MUSAGuard device_guard(grad_output.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "tin_shift_backward_musa_kernel", [&] {
tin_shift_backward_musa_kernel<scalar_t>
<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<scalar_t>(),
shift.data_ptr<int>(), grad_input.data_ptr<scalar_t>(),
batch_size, channels, t_size, hw_size, group_size,
group_channel);
});
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,749 @@
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include <torch/types.h>
#include "pytorch_musa_helper.hpp"
#if MUSA_ARCH > 21
struct upfirdn2d_kernel_params {
const void *x;
const float *f;
void *y;
int2 up;
int2 down;
int2 pad0;
int flip;
float gain;
int4 inSize; // [width, height, channel, batch]
int4 inStride;
int2 filterSize; // [width, height]
int2 filterStride;
int4 outSize; // [width, height, channel, batch]
int4 outStride;
int sizeMinor;
int sizeMajor;
int loopMinor;
int loopMajor;
int loopX;
int launchMinor;
int launchMajor;
};
//------------------------------------------------------------------------
// MUSA kernel specialization.
struct upfirdn2d_kernel_spec {
void *kernel;
int tileOutW;
int tileOutH;
int loopMinor;
int loopX;
};
//------------------------------------------------------------------------
// MUSA kernel selection.
template <class T>
upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params &p);
//------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//------------------------------------------------------------------------
// Helpers.
template <class T>
struct InternalType;
template <>
struct InternalType<double> {
typedef double scalar_t;
};
template <>
struct InternalType<float> {
typedef float scalar_t;
};
template <>
struct InternalType<c10::Half> {
typedef float scalar_t;
};
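// Integer division rounding toward negative infinity (plain C++ '/' truncates toward zero).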
static __device__ __forceinline__ int floor_div(int a, int b) {
int t = 1 - a / b;
return (a + t * b) / b - t;
}
//------------------------------------------------------------------------
// Generic MUSA implementation for large filters.
template <class T>
static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) {
typedef typename InternalType<T>::scalar_t scalar_t;
// Calculate thread index.
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
int outY = minorBase / p.launchMinor;
minorBase -= outY * p.launchMinor;
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
int majorBase = blockIdx.z * p.loopMajor;
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
return;
// Setup Y receptive field.
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
int h =
min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
if (p.flip) filterY = p.filterSize.y - 1 - filterY;
// Loop over major, minor, and X.
for (int majorIdx = 0, major = majorBase;
majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
for (int minorIdx = 0, minor = minorBase;
minorIdx < p.loopMinor & minor < p.sizeMinor;
minorIdx++, minor += p.launchMinor) {
int nc = major * p.sizeMinor + minor;
int n = nc / p.inSize.z;
int c = nc - n * p.inSize.z;
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x;
loopX++, outX += blockDim.y) {
// Setup X receptive field.
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
int w =
min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) -
inX;
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
if (p.flip) filterX = p.filterSize.x - 1 - filterX;
// Initialize pointers.
const T *xp =
&((const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y +
c * p.inStride.z + n * p.inStride.w];
const float *fp =
&p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
// Inner loop.
scalar_t v = 0;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
v += (scalar_t)(*xp) * (scalar_t)(*fp);
xp += p.inStride.x;
fp += filterStepX;
}
xp += p.inStride.y - w * p.inStride.x;
fp += filterStepY - w * filterStepX;
}
// Store result.
v *= p.gain;
((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y +
c * p.outStride.z + n * p.outStride.w] = (T)v;
}
}
}
//------------------------------------------------------------------------
// Specialized MUSA implementation for small filters.
template <class T, int upx, int upy, int downx, int downy, int filterW,
int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) {
typedef typename InternalType<T>::scalar_t scalar_t;
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
__shared__ volatile scalar_t sf[filterH][filterW];
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
// Calculate tile index.
int minorBase = blockIdx.x;
int tileOutY = minorBase / p.launchMinor;
minorBase -= tileOutY * p.launchMinor;
minorBase *= loopMinor;
tileOutY *= tileOutH;
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
int majorBase = blockIdx.z * p.loopMajor;
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y |
majorBase >= p.sizeMajor)
return;
// Load filter (flipped).
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW;
tapIdx += blockDim.x) {
int fy = tapIdx / filterW;
int fx = tapIdx - fy * filterW;
scalar_t v = 0;
if (fx < p.filterSize.x & fy < p.filterSize.y) {
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
}
sf[fy][fx] = v;
}
// Loop over major and X.
for (int majorIdx = 0, major = majorBase;
majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) {
int baseNC = major * p.sizeMinor + minorBase;
int n = baseNC / p.inSize.z;
int baseC = baseNC - n * p.inSize.z;
for (int loopX = 0, tileOutX = tileOutXBase;
loopX < p.loopX & tileOutX < p.outSize.x;
loopX++, tileOutX += tileOutW) {
// Load input pixels.
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
int tileInX = floor_div(tileMidX, upx);
int tileInY = floor_div(tileMidY, upy);
__syncthreads();
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor;
inIdx += blockDim.x) {
int relC = inIdx;
int relInX = relC / loopMinor;
int relInY = relInX / tileInW;
relC -= relInX * loopMinor;
relInX -= relInY * tileInW;
int c = baseC + relC;
int inX = tileInX + relInX;
int inY = tileInY + relInY;
scalar_t v = 0;
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y &
c < p.inSize.z)
v = (scalar_t)(
(const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y +
c * p.inStride.z + n * p.inStride.w];
sx[relInY][relInX][relC] = v;
}
// Loop over output pixels.
__syncthreads();
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor;
outIdx += blockDim.x) {
int relC = outIdx;
int relOutX = relC / loopMinor;
int relOutY = relOutX / tileOutW;
relC -= relOutX * loopMinor;
relOutX -= relOutY * tileOutW;
int c = baseC + relC;
int outX = tileOutX + relOutX;
int outY = tileOutY + relOutY;
// Setup receptive field.
int midX = tileMidX + relOutX * downx;
int midY = tileMidY + relOutY * downy;
int inX = floor_div(midX, upx);
int inY = floor_div(midY, upy);
int relInX = inX - tileInX;
int relInY = inY - tileInY;
int filterX = (inX + 1) * upx - midX - 1; // flipped
int filterY = (inY + 1) * upy - midY - 1; // flipped
// Inner loop.
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) {
scalar_t v = 0;
#pragma unroll
for (int y = 0; y < filterH / upy; y++)
#pragma unroll
for (int x = 0; x < filterW / upx; x++)
v += sx[relInY + y][relInX + x][relC] *
sf[filterY + y * upy][filterX + x * upx];
v *= p.gain;
((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y +
c * p.outStride.z + n * p.outStride.w] = (T)v;
}
}
}
}
}
//------------------------------------------------------------------------
// MUSA kernel selection.
template <class T>
upfirdn2d_kernel_spec choose_upfirdn2d_kernel(
const upfirdn2d_kernel_params &p) {
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
upfirdn2d_kernel_spec spec = {(void *)upfirdn2d_kernel_large<T>, -1, -1, 1,
4}; // contiguous
if (s == 1)
spec = {(void *)upfirdn2d_kernel_large<T>, -1, -1, 4, 1}; // channels_last
// No up/downsampling.
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 24, 64, 32, 1>,
64, 32, 1, 1};
if (s != 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 16, 64, 32, 1>,
64, 32, 1, 1};
if (s != 1 && fx <= 7 && fy <= 7)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 7, 7, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 6, 6, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 5 && fy <= 5)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 5, 5, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 4, 4, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 3 && fy <= 3)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 3, 3, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 8, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 24, 32, 32, 1>,
32, 32, 1, 1};
if (s != 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 16, 32, 32, 1>,
32, 32, 1, 1};
if (s != 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 8, 32, 32, 1>,
32, 32, 1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 24, 32, 32, 1>,
32, 32, 1, 1};
if (s == 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 16, 32, 32, 1>,
32, 32, 1, 1};
if (s == 1 && fx <= 7 && fy <= 7)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 7, 7, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 6, 6, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 5 && fy <= 5)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 5, 5, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 4, 4, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 3 && fy <= 3)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 3, 3, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 8, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 24, 1, 128, 16>,
1, 128, 16, 1};
if (s == 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 16, 1, 128, 16>,
1, 128, 16, 1};
if (s == 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 8, 1, 128, 16>,
1, 128, 16, 1};
}
// 2x upsampling.
if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 24, 24, 64, 32, 1>,
64, 32, 1, 1};
if (s != 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 16, 16, 64, 32, 1>,
64, 32, 1, 1};
if (s != 1 && fx <= 8 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 8, 8, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 6, 6, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 4, 4, 64, 16, 1>,
64, 16, 1, 1};
if (s != 1 && fx <= 2 && fy <= 2)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 2, 2, 64, 16, 1>,
64, 16, 1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 24, 24, 32, 32, 1>,
32, 32, 1, 1};
if (s == 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 16, 16, 32, 32, 1>,
32, 32, 1, 1};
if (s == 1 && fx <= 8 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 8, 8, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 6, 6, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 4, 4, 16, 16, 8>,
16, 16, 8, 1};
if (s == 1 && fx <= 2 && fy <= 2)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 2, 2, 16, 16, 8>,
16, 16, 8, 1};
}
if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 24, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 16, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 8, 1, 128, 8, 1>,
128, 8, 1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 24, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 16, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 8, 1, 128, 1, 16>,
128, 1, 16, 1};
}
if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 24, 32, 32, 1>,
32, 32, 1, 1};
if (s != 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 16, 32, 32, 1>,
32, 32, 1, 1};
if (s != 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 8, 32, 32, 1>,
32, 32, 1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 24, 1, 128, 16>,
1, 128, 16, 1};
if (s == 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 16, 1, 128, 16>,
1, 128, 16, 1};
if (s == 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 8, 1, 128, 16>,
1, 128, 16, 1};
}
// 2x downsampling.
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) {
// contiguous
if (s != 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 24, 24, 32, 16, 1>,
32, 16, 1, 1};
if (s != 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 16, 16, 32, 16, 1>,
32, 16, 1, 1};
if (s != 1 && fx <= 8 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 8, 8, 32, 8, 1>, 32,
8, 1, 1};
if (s != 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 6, 6, 32, 8, 1>, 32,
8, 1, 1};
if (s != 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 4, 4, 32, 8, 1>, 32,
8, 1, 1};
if (s != 1 && fx <= 2 && fy <= 2)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 2, 2, 32, 8, 1>, 32,
8, 1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 24, 24, 16, 16, 1>,
16, 16, 1, 1};
if (s == 1 && fx <= 16 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 16, 16, 16, 16, 1>,
16, 16, 1, 1};
if (s == 1 && fx <= 8 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 8, 8, 8, 8, 8>, 8,
8, 8, 1};
if (s == 1 && fx <= 6 && fy <= 6)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 6, 6, 8, 8, 8>, 8,
8, 8, 1};
if (s == 1 && fx <= 4 && fy <= 4)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 4, 4, 8, 8, 8>, 8,
8, 8, 1};
if (s == 1 && fx <= 2 && fy <= 2)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 2, 2, 8, 8, 8>, 8,
8, 8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 24, 1, 64, 8, 1>,
64, 8, 1, 1};
if (s != 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 16, 1, 64, 8, 1>,
64, 8, 1, 1};
if (s != 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 8, 1, 64, 8, 1>, 64,
8, 1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 24, 1, 64, 1, 8>,
64, 1, 8, 1};
if (s == 1 && fx <= 16 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 16, 1, 64, 1, 8>,
64, 1, 8, 1};
if (s == 1 && fx <= 8 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 8, 1, 64, 1, 8>, 64,
1, 8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) {
// contiguous
if (s != 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 24, 32, 16, 1>,
32, 16, 1, 1};
if (s != 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 16, 32, 16, 1>,
32, 16, 1, 1};
if (s != 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 8, 32, 16, 1>,
32, 16, 1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 24)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 24, 1, 64, 8>, 1,
64, 8, 1};
if (s == 1 && fx <= 1 && fy <= 16)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 16, 1, 64, 8>, 1,
64, 8, 1};
if (s == 1 && fx <= 1 && fy <= 8)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 8, 1, 64, 8>, 1,
64, 8, 1};
}
// 4x upsampling.
if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 48 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 48, 48, 64, 32, 1>,
64, 32, 1, 1};
if (s != 1 && fx <= 32 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 32, 32, 64, 32, 1>,
64, 32, 1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 48, 48, 32, 32, 1>,
32, 32, 1, 1};
if (s == 1 && fx <= 32 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 32, 32, 32, 32, 1>,
32, 32, 1, 1};
}
if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 48 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 48, 1, 128, 8, 1>,
128, 8, 1, 1};
if (s != 1 && fx <= 32 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 32, 1, 128, 8, 1>,
128, 8, 1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 48, 1, 128, 1, 16>,
128, 1, 16, 1};
if (s == 1 && fx <= 32 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 32, 1, 128, 1, 16>,
128, 1, 16, 1};
}
if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 1 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 48, 32, 32, 1>,
32, 32, 1, 1};
if (s != 1 && fx <= 1 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 32, 32, 32, 1>,
32, 32, 1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 48, 1, 128, 16>,
1, 128, 16, 1};
if (s == 1 && fx <= 1 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 32, 1, 128, 16>,
1, 128, 16, 1};
}
// 4x downsampling (inefficient).
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) {
// contiguous
if (s != 1 && fx <= 48 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 48, 1, 32, 8, 1>,
32, 8, 1, 1};
if (s != 1 && fx <= 32 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 32, 1, 32, 8, 1>,
32, 8, 1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 48, 1, 32, 1, 8>,
32, 1, 8, 1};
if (s == 1 && fx <= 32 && fy <= 1)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 32, 1, 32, 1, 8>,
32, 1, 8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) {
// contiguous
if (s != 1 && fx <= 1 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 48, 32, 8, 1>,
32, 8, 1, 1};
if (s != 1 && fx <= 1 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 32, 32, 8, 1>,
32, 8, 1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 48)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 48, 1, 32, 8>, 1,
32, 8, 1};
if (s == 1 && fx <= 1 && fy <= 32)
spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 32, 1, 32, 8>, 1,
32, 8, 1};
}
return spec;
}
//------------------------------------------------------------------------
// Template specializations.
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>(
const upfirdn2d_kernel_params &p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>(
const upfirdn2d_kernel_params &p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(
const upfirdn2d_kernel_params &p);
//------------------------------------------------------------------------
//------------------------------------------------------------------------
torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy,
int downx, int downy, int padx0, int padx1,
int pady0, int pady1, bool flip, float gain) {
// Validate arguments.
TORCH_CHECK(x.is_privateuseone(), "x must reside on MUSA device");
TORCH_CHECK(f.device() == x.device(),
"f must reside on the same device as x");
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
TORCH_CHECK(x.numel() > 0, "x has zero size");
TORCH_CHECK(f.numel() > 0, "f has zero size");
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
TORCH_CHECK((x.size(0) - 1) * x.stride(0) + (x.size(1) - 1) * x.stride(1) +
(x.size(2) - 1) * x.stride(2) +
(x.size(3) - 1) * x.stride(3) <=
INT_MAX,
"x memory footprint is too large");
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
TORCH_CHECK(downx >= 1 && downy >= 1,
"downsampling factor must be at least 1");
// Create output tensor.
const at::musa::OptionalMUSAGuard device_guard(device_of(x));
int outW =
((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
int outH =
((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
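// i.e. the usual upfirdn output size floor((in * up + pad_total - filter) / down) + 1,
// with the trailing +1 folded into the numerator as +down.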
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW},
x.options(), x.suggest_memory_format());
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
TORCH_CHECK((y.size(0) - 1) * y.stride(0) + (y.size(1) - 1) * y.stride(1) +
(y.size(2) - 1) * y.stride(2) +
(y.size(3) - 1) * y.stride(3) <=
INT_MAX,
"output memory footprint is too large");
// Initialize MUSA kernel parameters.
upfirdn2d_kernel_params p;
p.x = x.data_ptr();
p.f = f.data_ptr<float>();
p.y = y.data_ptr();
p.up = make_int2(upx, upy);
p.down = make_int2(downx, downy);
p.pad0 = make_int2(padx0, pady0);
p.flip = (flip) ? 1 : 0;
p.gain = gain;
p.inSize =
make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1),
(int)x.stride(0));
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
p.outSize =
make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1),
(int)y.stride(0));
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
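// channels_last input (channel stride == 1): channels are looped in the minor dimension so
// each thread reads them contiguously; contiguous input: channels are folded into the major
// loop instead.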
// Choose MUSA kernel.
upfirdn2d_kernel_spec spec;
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "upfirdn2d_musa", [&] {
spec = choose_upfirdn2d_kernel<scalar_t>(p);
});
// Set looping options.
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
p.loopMinor = spec.loopMinor;
p.loopX = spec.loopX;
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
// Compute grid size.
dim3 blockSize, gridSize;
if (spec.tileOutW < 0) // large
{
blockSize = dim3(4, 32, 1);
gridSize =
dim3(((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, p.launchMajor);
} else // small
{
blockSize = dim3(256, 1, 1);
gridSize =
dim3(((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, p.launchMajor);
}
// Launch MUSA kernel.
void *args[] = {&p};
#ifdef MMCV_WITH_HIP
AT_MUSA_CHECK(hipLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
c10::musa::getCurrentMUSAStream()));
#else
AT_MUSA_CHECK(musaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
c10::musa::getCurrentMUSAStream()));
#endif
return y;
}
#else
#warning "upfirdn2d is supported when MUSA_ARCH > 21"
#endif //MUSA_ARCH

View File

@ -0,0 +1,286 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <stdio.h>
#include <stdlib.h>
#include "pytorch_musa_helper.hpp"
#include "voxelization_musa_kernel.muh"
int HardVoxelizeForwardMUSAKernelLauncher(
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3) {
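// Shapes as used by the mmcv voxelization op (a sketch, not asserted here):
//   points (num_points, num_features), voxels (max_voxels, max_points, num_features),
//   coors (max_voxels, NDim), num_points_per_voxel (max_voxels,);
//   the return value is the number of voxels actually produced.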
// the current version takes about 0.04s per frame on CPU
// check device
c10::musa::MUSAGuard device_guard(points.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
const int num_points = points.size(0);
const int num_features = points.size(1);
const float voxel_x = voxel_size[0];
const float voxel_y = voxel_size[1];
const float voxel_z = voxel_size[2];
const float coors_x_min = coors_range[0];
const float coors_y_min = coors_range[1];
const float coors_z_min = coors_range[2];
const float coors_x_max = coors_range[3];
const float coors_y_max = coors_range[4];
const float coors_z_max = coors_range[5];
const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
// map points to voxel coors
at::Tensor temp_coors =
at::zeros({num_points, NDim}, points.options().dtype(at::kInt));
dim3 grid(std::min(at::musa::ATenCeilDiv(num_points, 512), 4096));
dim3 block(512);
// 1. link point to corresponding voxel coors
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "hard_voxelize_kernel", ([&] {
dynamic_voxelize_kernel<scalar_t, int><<<grid, block, 0, stream>>>(
points.contiguous().data_ptr<scalar_t>(),
temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
coors_z_max, grid_x, grid_y, grid_z, num_points, num_features,
NDim);
}));
AT_MUSA_CHECK(musaGetLastError());
// 2. map point to the idx of the corresponding voxel, find duplicate coor
// create some temporary variables
auto point_to_pointidx = -at::ones(
{
num_points,
},
points.options().dtype(at::kInt));
auto point_to_voxelidx = -at::ones(
{
num_points,
},
points.options().dtype(at::kInt));
dim3 map_grid(std::min(at::musa::ATenCeilDiv(num_points, 512), 4096));
dim3 map_block(512);
AT_DISPATCH_ALL_TYPES(
temp_coors.scalar_type(), "determin_duplicate", ([&] {
point_to_voxelidx_kernel<int><<<map_grid, map_block, 0, stream>>>(
temp_coors.contiguous().data_ptr<int>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
point_to_pointidx.contiguous().data_ptr<int>(), max_points,
max_voxels, num_points, NDim);
}));
AT_MUSA_CHECK(musaGetLastError());
// 3. determine the number of voxels and each voxel's coordinate index
// keeping this logic on the MUSA device makes it roughly 10x faster
auto coor_to_voxelidx = -at::ones(
{
num_points,
},
points.options().dtype(at::kInt));
auto voxel_num = at::zeros(
{
1,
},
points.options().dtype(at::kInt)); // must be zero from the beginning
AT_DISPATCH_ALL_TYPES(temp_coors.scalar_type(), "determin_duplicate", ([&] {
determin_voxel_num<int><<<1, 1, 0, stream>>>(
num_points_per_voxel.contiguous().data_ptr<int>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
point_to_pointidx.contiguous().data_ptr<int>(),
coor_to_voxelidx.contiguous().data_ptr<int>(),
voxel_num.contiguous().data_ptr<int>(),
max_points, max_voxels, num_points);
}));
AT_MUSA_CHECK(musaGetLastError());
// 4. copy point features to voxels
// Step 4 & 5 could be parallel
auto pts_output_size = num_points * num_features;
dim3 cp_grid(std::min(at::musa::ATenCeilDiv(pts_output_size, 512), 4096));
dim3 cp_block(512);
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "assign_point_to_voxel", ([&] {
assign_point_to_voxel<float, int><<<cp_grid, cp_block, 0, stream>>>(
pts_output_size, points.contiguous().data_ptr<float>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
coor_to_voxelidx.contiguous().data_ptr<int>(),
voxels.contiguous().data_ptr<float>(), max_points, num_features,
num_points, NDim);
}));
// musaDeviceSynchronize();
// AT_MUSA_CHECK(musaGetLastError());
// 5. copy coors of each voxels
auto coors_output_size = num_points * NDim;
dim3 coors_cp_grid(
std::min(at::musa::ATenCeilDiv(coors_output_size, 512), 4096));
dim3 coors_cp_block(512);
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "assign_point_to_voxel", ([&] {
assign_voxel_coors<float, int>
<<<coors_cp_grid, coors_cp_block, 0, stream>>>(
coors_output_size, temp_coors.contiguous().data_ptr<int>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
coor_to_voxelidx.contiguous().data_ptr<int>(),
coors.contiguous().data_ptr<int>(), num_points, NDim);
}));
AT_MUSA_CHECK(musaGetLastError());
auto voxel_num_cpu = voxel_num.to(at::kCPU);
int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
return voxel_num_int;
}
int NondeterministicHardVoxelizeForwardMUSAKernelLauncher(
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3) {
c10::musa::MUSAGuard device_guard(points.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
const int num_points = points.size(0);
const int num_features = points.size(1);
if (num_points == 0) return 0;
dim3 blocks(
std::min(at::musa::ATenCeilDiv(num_points, THREADS_PER_BLOCK), 4096));
dim3 threads(THREADS_PER_BLOCK);
const float voxel_x = voxel_size[0];
const float voxel_y = voxel_size[1];
const float voxel_z = voxel_size[2];
const float coors_x_min = coors_range[0];
const float coors_y_min = coors_range[1];
const float coors_z_min = coors_range[2];
const float coors_x_max = coors_range[3];
const float coors_y_max = coors_range[4];
const float coors_z_max = coors_range[5];
const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
// map points to voxel coors
at::Tensor temp_coors =
at::zeros({num_points, NDim}, points.options().dtype(at::kInt));
// 1. link point to corresponding voxel coors
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "hard_voxelize_kernel", ([&] {
dynamic_voxelize_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<scalar_t>(),
temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
coors_z_max, grid_x, grid_y, grid_z, num_points, num_features,
NDim);
}));
at::Tensor coors_map;
at::Tensor reduce_count;
auto coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, true), -1);
std::tie(temp_coors, coors_map, reduce_count) =
at::unique_dim(coors_clean, 0, true, true, false);
if (temp_coors[0][0].lt(0).item<bool>()) {
// the first element of temp_coors is (-1,-1,-1) and should be removed
temp_coors = temp_coors.slice(0, 1);
coors_map = coors_map - 1;
}
int num_coors = temp_coors.size(0);
temp_coors = temp_coors.to(at::kInt);
coors_map = coors_map.to(at::kInt);
at::Tensor coors_count = at::zeros({1}, coors_map.options());
at::Tensor coors_order = at::empty({num_coors}, coors_map.options());
at::Tensor pts_id = at::zeros({num_points}, coors_map.options());
reduce_count = at::zeros({num_coors}, coors_map.options());
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "get_assign_pos", ([&] {
nondeterministic_get_assign_pos<<<blocks, threads, 0, stream>>>(
num_points, coors_map.contiguous().data_ptr<int32_t>(),
pts_id.contiguous().data_ptr<int32_t>(),
coors_count.contiguous().data_ptr<int32_t>(),
reduce_count.contiguous().data_ptr<int32_t>(),
coors_order.contiguous().data_ptr<int32_t>());
}));
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "assign_point_to_voxel", ([&] {
nondeterministic_assign_point_voxel<scalar_t>
<<<blocks, threads, 0, stream>>>(
num_points, points.contiguous().data_ptr<scalar_t>(),
coors_map.contiguous().data_ptr<int32_t>(),
pts_id.contiguous().data_ptr<int32_t>(),
temp_coors.contiguous().data_ptr<int32_t>(),
reduce_count.contiguous().data_ptr<int32_t>(),
coors_order.contiguous().data_ptr<int32_t>(),
voxels.contiguous().data_ptr<scalar_t>(),
coors.contiguous().data_ptr<int32_t>(),
num_points_per_voxel.contiguous().data_ptr<int32_t>(),
max_voxels, max_points, num_features, NDim);
}));
AT_MUSA_CHECK(musaGetLastError());
return max_voxels < num_coors ? max_voxels : num_coors;
}
void DynamicVoxelizeForwardMUSAKernelLauncher(
const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size, const std::vector<float> coors_range,
const int NDim = 3) {
// the current version takes about 0.04s per frame on CPU
// check device
c10::musa::MUSAGuard device_guard(points.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
const int num_points = points.size(0);
const int num_features = points.size(1);
const float voxel_x = voxel_size[0];
const float voxel_y = voxel_size[1];
const float voxel_z = voxel_size[2];
const float coors_x_min = coors_range[0];
const float coors_y_min = coors_range[1];
const float coors_z_min = coors_range[2];
const float coors_x_max = coors_range[3];
const float coors_y_max = coors_range[4];
const float coors_z_max = coors_range[5];
const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
const int col_blocks = at::musa::ATenCeilDiv(num_points, THREADS_PER_BLOCK);
dim3 blocks(col_blocks);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_ALL_TYPES(points.scalar_type(), "dynamic_voxelize_kernel", [&] {
dynamic_voxelize_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<scalar_t>(),
coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
coors_z_max, grid_x, grid_y, grid_z, num_points, num_features, NDim);
});
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -1,4 +1,5 @@
import torch
from mmengine.device import is_cuda_available, is_musa_available
from torch import Tensor
from ..utils import ext_loader
@ -10,7 +11,7 @@ ext_module = ext_loader.load_ext('_ext', [
def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
"""Find the box in which each point is (CUDA).
"""Find the box in which each point is (CUDA/MUSA).
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate.
@ -38,7 +39,7 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
# If manually put the tensor 'points' or 'boxes' on a device
# which is not the current device, some temporary variables
# will be created on the current device in the cuda op,
# will be created on the current device in the cuda/musa op,
# and the output will be incorrect.
# Therefore, we force the current device to be the same
# as the device of the tensors if it was not.
@ -48,8 +49,12 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
assert points_device == boxes.get_device(), \
'Points and boxes should be put on the same device'
if points.device.type != 'npu':
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
if is_cuda_available():
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
elif is_musa_available():
if torch.musa.current_device() != points_device:
torch.musa.set_device(points_device)
else:
boxes[:, :, 2] += boxes[:, :, 5] / 2.0
@ -99,7 +104,7 @@ def points_in_boxes_cpu(points: Tensor, boxes: Tensor) -> Tensor:
def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor:
"""Find all boxes in which each point is (CUDA).
"""Find all boxes in which each point is (CUDA/MUSA).
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
@ -131,8 +136,12 @@ def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor:
assert points_device == boxes.get_device(), \
'Points and boxes should be put on the same device'
if points.device.type != 'npu':
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
if is_cuda_available():
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
elif is_musa_available():
if torch.musa.current_device() != points_device:
torch.musa.set_device(points_device)
ext_module.points_in_boxes_all_forward(boxes.contiguous(),
points.contiguous(),

View File

@ -4,6 +4,7 @@ from typing import Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmengine.device import is_cuda_available, is_musa_available
from mmengine.registry import MODELS
from torch.autograd import Function
from torch.autograd.function import once_differentiable
@ -47,10 +48,20 @@ class SyncBatchNormFunction(Function):
self.group_size = group_size
self.stats_mode = stats_mode
assert isinstance(
input, (torch.HalfTensor, torch.FloatTensor,
torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
f'only support Half or Float Tensor, but {input.type()}'
if is_cuda_available():
assert isinstance(
input, (torch.HalfTensor, torch.FloatTensor,
torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
f'only support Half or Float Tensor, but {input.type()}'
elif is_musa_available():
assert isinstance(
input, (torch.HalfTensor, torch.FloatTensor,
torch.musa.HalfTensor, torch.musa.FloatTensor)), \
f'only support Half or Float Tensor, but {input.type()}'
else:
assert isinstance(
input, (torch.HalfTensor, torch.FloatTensor)), \
f'only support Half or Float Tensor, but {input.type()}'
output = torch.zeros_like(input)
input3d = input.flatten(start_dim=2)
output3d = output.view_as(input3d)

View File

@ -116,6 +116,13 @@ def upfirdn2d(input: torch.Tensor,
padding=padding,
flip_filter=flip_filter,
gain=gain).apply(input, filter)
elif use_custom_op and input.device.type == 'musa':
return _upfirdn2d_musa(
up=up,
down=down,
padding=padding,
flip_filter=flip_filter,
gain=gain).apply(input, filter)
return _upfirdn2d_ref(
input,
filter,
@ -303,6 +310,101 @@ def _upfirdn2d_cuda(up: int = 1,
return Upfirdn2dCuda
_upfirdn2d_musa_cache: Dict = dict()
def _upfirdn2d_musa(up: int = 1,
down: int = 1,
padding: Union[int, List[int]] = 0,
flip_filter: bool = False,
gain: Union[float, int] = 1):
"""Fast MUSA implementation of `upfirdn2d()` using custom ops.
Args:
up (int): Integer upsampling factor. Can be a single int or a
list/tuple `[x, y]`. Defaults to 1.
down (int): Integer downsampling factor. Can be a single int
or a list/tuple `[x, y]`. Defaults to 1.
padding (int | tuple[int]): Padding with respect to the upsampled
image. Can be a single number or a list/tuple `[x, y]` or
`[x_before, x_after, y_before, y_after]`. Defaults to 0.
flip_filter (bool): False = convolution, True = correlation.
Defaults to False.
gain (int): Overall scaling factor for signal magnitude.
Defaults to 1.
Returns:
torch.Tensor: Tensor of the shape `[batch_size, num_channels,
out_height, out_width]`
"""
# Parse arguments.
upx, upy = _parse_scaling(up)
downx, downy = _parse_scaling(down)
padx0, padx1, pady0, pady1 = _parse_padding(padding)
# Lookup from cache.
key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter,
gain)
if key in _upfirdn2d_musa_cache:
return _upfirdn2d_musa_cache[key]
# Forward op.
class Upfirdn2dMusa(torch.autograd.Function):
@staticmethod
def forward(ctx, x, f): # pylint: disable=arguments-differ
assert isinstance(x, torch.Tensor) and x.ndim == 4
if f is None:
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
if f.ndim == 1 and f.shape[0] == 1:
f = f.square().unsqueeze(
0) # Convert separable-1 into full-1x1.
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
y = x
if f.ndim == 2:
y = ext_module.upfirdn2d(y, f, upx, upy, downx, downy, padx0,
padx1, pady0, pady1, flip_filter,
gain)
else:
y = ext_module.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1,
padx0, padx1, 0, 0, flip_filter, 1.0)
y = ext_module.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy,
0, 0, pady0, pady1, flip_filter, gain)
ctx.save_for_backward(f)
ctx.x_shape = x.shape
return y
@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
f, = ctx.saved_tensors
_, _, ih, iw = ctx.x_shape
_, _, oh, ow = dy.shape
fw, fh = _get_filter_size(f)
p = [
fw - padx0 - 1,
iw * upx - ow * downx + padx0 - upx + 1,
fh - pady0 - 1,
ih * upy - oh * downy + pady0 - upy + 1,
]
dx = None
df = None
if ctx.needs_input_grad[0]:
dx = _upfirdn2d_musa(
up=down,
down=up,
padding=p,
flip_filter=(not flip_filter),
gain=gain).apply(dy, f)
assert not ctx.needs_input_grad[1]
return dx, df
# Add to cache.
_upfirdn2d_musa_cache[key] = Upfirdn2dMusa
return Upfirdn2dMusa
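# Hypothetical usage sketch (tensor names, shapes and values are for illustration only):
#   x = torch.randn(1, 3, 64, 64, device='musa')
#   f = torch.ones(4, 4, device='musa') / 16
#   y = _upfirdn2d_musa(up=2, padding=1).apply(x, f)  # 2x upsample, box filter, crop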
def filter2d(input: torch.Tensor,
filter: torch.Tensor,
padding: Union[int, List[int]] = 0,

View File

@ -4,7 +4,7 @@ import pytest
import torch
from mmcv.ops import points_in_polygons
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE
@pytest.mark.parametrize('device', [
@ -15,7 +15,11 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
])
def test_points_in_polygons(device):
points = np.array([[300., 300.], [400., 400.], [100., 100], [300, 250],

View File

@ -3,7 +3,7 @@ import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE
_USING_PARROTS = True
try:
@ -41,7 +41,11 @@ class TestPrRoiPool:
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_roipool_gradcheck(self, device):
from mmcv.ops import PrRoIPool
@ -92,7 +96,11 @@ class TestPrRoiPool:
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_roipool_allclose_float(self, device):
self._test_roipool_allclose(device, dtype=torch.float)

View File

@ -4,7 +4,8 @@ import pytest
import torch
import torch.nn as nn
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
IS_NPU_AVAILABLE)
class Loss(nn.Module):
@ -32,7 +33,11 @@ class TestPSAMask:
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_psa_mask_collect(self, device):
from mmcv.ops import PSAMask
@ -84,7 +89,11 @@ class TestPSAMask:
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_psa_mask_distribute(self, device):
from mmcv.ops import PSAMask

View File

@ -3,7 +3,8 @@ import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
IS_NPU_AVAILABLE)
_USING_PARROTS = True
try:
@ -107,7 +108,11 @@ def _test_roialign_allclose(device, dtype):
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
])
def test_roialign_float(device, dtype):
_test_roialign_allclose(device=device, dtype=dtype)

View File

@ -3,7 +3,7 @@ import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
_USING_PARROTS = True
try:
@ -132,15 +132,19 @@ def _test_roialign_rotated_allclose(device, dtype):
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
])
@pytest.mark.parametrize('dtype', [
torch.float,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE,
reason='MLU does not support for 64-bit floating point')),
IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE,
reason='MLU, MUSA does not support for 64-bit floating point')),
torch.half
])
def test_roialign_rotated(device, dtype):

View File

@ -5,7 +5,8 @@ import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
IS_NPU_AVAILABLE)
_USING_PARROTS = True
try:
@ -89,16 +90,20 @@ class TestRoiPool:
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
])
@pytest.mark.parametrize('dtype', [
torch.float,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
reason='MLU, NPU does not support for 64-bit floating point')),
torch.half
IS_MLU_AVAILABLE or IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE,
reason='MLU, NPU, MUSA '
'does not support for 64-bit floating point')), torch.half
])
def test_roipool_allclose(self, device, dtype):
self._test_roipool_allclose(device, dtype)

View File

@ -5,7 +5,8 @@ import torch
from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
IS_NPU_AVAILABLE)
@pytest.mark.parametrize('dtype', [
@ -13,7 +14,8 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE, reason='MLU does not support for double'))
IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE,
reason='MLU, MUSA does not support for double'))
])
@pytest.mark.parametrize('device', [
pytest.param(
@ -23,7 +25,11 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_RoIAwarePool3d(device, dtype):
roiaware_pool3d_max = RoIAwarePool3d(
@ -64,7 +70,11 @@ def test_RoIAwarePool3d(device, dtype):
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_points_in_boxes_part(device):
boxes = torch.tensor(

View File

@ -3,7 +3,8 @@ import pytest
import torch
from mmcv.ops import RoIPointPool3d
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
IS_NPU_AVAILABLE)
@pytest.mark.parametrize('device', [
@ -18,15 +19,19 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
])
@pytest.mark.parametrize('dtype', [
torch.float, torch.half,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
reason='MLU and NPU does not support for double'))
IS_MLU_AVAILABLE or IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE,
reason='MLU, NPU, MUSA does not support for double'))
])
def test_roipoint(device, dtype):
points = torch.tensor(

View File

@ -3,7 +3,8 @@ import pytest
import torch
from mmcv.ops import rotated_feature_align
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
IS_NPU_AVAILABLE)
@pytest.mark.skipif(
@ -21,6 +22,10 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
pytest.param(
'cpu',
marks=pytest.mark.skipif(

View File

@ -4,7 +4,7 @@ import torch
from torch.autograd import gradcheck
from mmcv.ops import DynamicScatter
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
if torch.__version__ == 'parrots':
pytest.skip('not supported in parrots now', allow_module_level=True)
@ -18,7 +18,11 @@ if torch.__version__ == 'parrots':
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
])
def test_dynamic_scatter(device):
dsmean = DynamicScatter([0.32, 0.32, 6],

View File

@ -10,7 +10,7 @@ from mmcv.ops import (SparseConvTensor, SparseInverseConv3d, SparseSequential,
if torch.__version__ == 'parrots':
pytest.skip('not supported in parrots now', allow_module_level=True)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
def make_sparse_convmodule(in_channels,
@ -86,10 +86,17 @@ def make_sparse_convmodule(in_channels,
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
])
def test_make_sparse_convmodule(device):
torch.cuda.empty_cache()
if IS_CUDA_AVAILABLE:
torch.cuda.empty_cache()
elif IS_MUSA_AVAILABLE:
torch.musa.empty_cache()
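Note: condensed into a helper (name hypothetical), the cache-clearing dispatch just above would read roughly as follows; torch.musa is provided by the torch_musa extension, which these tests assume is installed on MUSA machines.

import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE

def empty_device_cache() -> None:
    # Release cached allocator blocks on whichever accelerator is present.
    if IS_CUDA_AVAILABLE:
        torch.cuda.empty_cache()
    elif IS_MUSA_AVAILABLE:
        torch.musa.empty_cache()  # provided by torch_musa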
voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
[6.8162713, -2.480431, -1.3616394, 0.36],
[11.643568, -4.744306, -1.3580885, 0.16],

View File

@ -8,6 +8,8 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE
if platform.system() == 'Windows':
import regex as re
else:
@ -29,10 +31,24 @@ class TestSyncBN:
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['RANK'] = str(rank)
dist.init_process_group('nccl')
torch.cuda.set_device(local_rank)
if IS_CUDA_AVAILABLE:
dist.init_process_group('nccl')
torch.cuda.set_device(local_rank)
elif IS_MUSA_AVAILABLE:
dist.init_process_group('mccl')
torch.musa.set_device(local_rank)
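Note: the SyncBN test now selects the collective backend per device: NCCL with torch.cuda on CUDA, and MCCL (its MUSA counterpart) with torch.musa on MUSA, while the rest of the test moves tensors with .to(device) instead of hard-coded .cuda(). As a stand-alone helper (name hypothetical) the setup is roughly:

import torch
import torch.distributed as dist
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE

def init_collective(local_rank: int) -> None:
    # Pick the collective backend that matches the available accelerator.
    if IS_CUDA_AVAILABLE:
        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)
    elif IS_MUSA_AVAILABLE:
        dist.init_process_group('mccl')
        torch.musa.set_device(local_rank)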
def _test_syncbn_train(self, size=1, half=False):
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def _test_syncbn_train(self, size=1, half=False, device='cuda'):
if 'SLURM_NTASKS' not in os.environ or int(
os.environ['SLURM_NTASKS']) != 4:
@ -49,10 +65,13 @@ class TestSyncBN:
rank = dist.get_rank()
torch.manual_seed(9)
torch.cuda.manual_seed(9)
if IS_CUDA_AVAILABLE:
torch.cuda.manual_seed(9)
elif IS_MUSA_AVAILABLE:
torch.musa.manual_seed(9)
self.x = torch.rand(16, 3, 2, 3).cuda()
self.y_bp = torch.rand(16, 3, 2, 3).cuda()
self.x = torch.rand(16, 3, 2, 3).to(device)
self.y_bp = torch.rand(16, 3, 2, 3).to(device)
if half:
self.x = self.x.half()
@ -60,7 +79,10 @@ class TestSyncBN:
dist.broadcast(self.x, src=0)
dist.broadcast(self.y_bp, src=0)
torch.cuda.synchronize()
if IS_CUDA_AVAILABLE:
torch.cuda.synchronize()
elif IS_MUSA_AVAILABLE:
torch.musa.synchronize()
if size == 1:
groups = [None, None, None, None]
groups[0] = dist.new_group([0])
@ -75,13 +97,13 @@ class TestSyncBN:
group = groups[rank]
elif size == 4:
group = dist.group.WORLD
syncbn = SyncBatchNorm(3, group=group).cuda()
syncbn = SyncBatchNorm(3, group=group).to(device)
syncbn.weight.data[0] = 0.2
syncbn.weight.data[1] = 0.5
syncbn.weight.data[2] = 0.7
syncbn.train()
bn = nn.BatchNorm2d(3).cuda()
bn = nn.BatchNorm2d(3).to(device)
bn.weight.data[0] = 0.2
bn.weight.data[1] = 0.5
bn.weight.data[2] = 0.7
@ -143,7 +165,17 @@ class TestSyncBN:
assert np.allclose(x_grad.data.cpu().numpy(),
sx_grad.data.cpu().numpy(), 1e-2)
def _test_syncbn_empty_train(self, size=1, half=False):
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def _test_syncbn_empty_train(self, size=1, half=False, device='cuda'):
if 'SLURM_NTASKS' not in os.environ or int(
os.environ['SLURM_NTASKS']) != 4:
@ -160,10 +192,13 @@ class TestSyncBN:
rank = dist.get_rank()
torch.manual_seed(9)
torch.cuda.manual_seed(9)
if IS_CUDA_AVAILABLE:
torch.cuda.manual_seed(9)
elif IS_MUSA_AVAILABLE:
torch.musa.manual_seed(9)
self.x = torch.rand(0, 3, 2, 3).cuda()
self.y_bp = torch.rand(0, 3, 2, 3).cuda()
self.x = torch.rand(0, 3, 2, 3).to(device)
self.y_bp = torch.rand(0, 3, 2, 3).to(device)
if half:
self.x = self.x.half()
@ -171,7 +206,10 @@ class TestSyncBN:
dist.broadcast(self.x, src=0)
dist.broadcast(self.y_bp, src=0)
torch.cuda.synchronize()
if IS_CUDA_AVAILABLE:
torch.cuda.synchronize()
elif IS_MUSA_AVAILABLE:
torch.musa.synchronize()
if size == 1:
groups = [None, None, None, None]
groups[0] = dist.new_group([0])
@ -187,13 +225,13 @@ class TestSyncBN:
elif size == 4:
group = dist.group.WORLD
syncbn = SyncBatchNorm(3, group=group, stats_mode='N').cuda()
syncbn = SyncBatchNorm(3, group=group, stats_mode='N').to(device)
syncbn.weight.data[0] = 0.2
syncbn.weight.data[1] = 0.5
syncbn.weight.data[2] = 0.7
syncbn.train()
bn = nn.BatchNorm2d(3).cuda()
bn = nn.BatchNorm2d(3).to(device)
bn.weight.data[0] = 0.2
bn.weight.data[1] = 0.5
bn.weight.data[2] = 0.7

View File

@ -3,7 +3,7 @@ import pytest
import torch
from mmcv.ops import three_interpolate
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE
@pytest.mark.parametrize('dtype', [
@ -11,8 +11,8 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_NPU_AVAILABLE,
reason='NPU does not support for 64-bit floating point'))
IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE,
reason='NPU, MUSA does not support for 64-bit floating point'))
])
@pytest.mark.parametrize('device', [
pytest.param(
@ -22,9 +22,15 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_three_interpolate(dtype, device):
if IS_MUSA_AVAILABLE:
torch.musa.empty_cache()
features = torch.tensor(
[[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
[3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],

View File

@ -3,7 +3,7 @@ import pytest
import torch
from mmcv.ops import three_nn
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
known = [[[-1.8373, 3.5605, -0.7867], [0.7615, 2.9420, 0.2314],
[-0.6503, 3.6637, -1.0622], [-1.8373, 3.5605, -0.7867],
@ -48,7 +48,11 @@ expected_idx = [[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4], [2, 1, 0],
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
@pytest.mark.parametrize('dtype,rtol', [(torch.float, 1e-8),
(torch.half, 1e-3)])

View File

@ -5,7 +5,7 @@ import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
_USING_PARROTS = True
try:
@ -209,15 +209,19 @@ def _test_tinshift_assert(device, dtype):
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
])
@pytest.mark.parametrize('dtype', [
torch.float,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE,
reason='MLU does not support for 64-bit floating point')),
IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE,
reason='MLU, MUSA does not support for 64-bit floating point')),
torch.half
])
def test_tinshift(device, dtype):

View File

@ -4,7 +4,8 @@ import pytest
import torch
from mmcv.ops import Voxelization
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
IS_NPU_AVAILABLE)
def _get_voxel_points_indices(points, coors, voxel):
@ -215,3 +216,38 @@ def test_voxelization_npu(device_type):
assert np.all(coors == expected_coors)
assert np.all(voxels == expected_voxels)
assert np.all(num_points_per_voxel == expected_num_points_per_voxel)
@pytest.mark.parametrize('device_type', [
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
])
def test_voxelization_musa(device_type):
voxel_size = [0.5, 0.5, 0.5]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
voxel_dict = np.load(
'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item()
expected_coors = voxel_dict['coors']
expected_voxels = voxel_dict['voxels']
expected_num_points_per_voxel = voxel_dict['num_points_per_voxel']
points = voxel_dict['points']
points = torch.tensor(points)
max_num_points = 1000
hard_voxelization = Voxelization(voxel_size, point_cloud_range,
max_num_points)
device = torch.device(device_type)
# test hard_voxelization on musa
points = points.contiguous().to(device)
coors, voxels, num_points_per_voxel = hard_voxelization.forward(points)
coors = coors.cpu().detach().numpy()
voxels = voxels.cpu().detach().numpy()
num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy()
assert np.all(coors == expected_coors)
assert np.all(voxels == expected_voxels)
assert np.all(num_points_per_voxel == expected_num_points_per_voxel)
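Note: a condensed usage sketch of the new MUSA voxelization test (random points stand in for the .npy fixture; it assumes torch_musa and a MUSA device are available):

import torch
from mmcv.ops import Voxelization

voxelizer = Voxelization(
    voxel_size=[0.5, 0.5, 0.5],
    point_cloud_range=[0, -40, -3, 70.4, 40, 1],
    max_num_points=1000)
points = torch.rand(2000, 4)  # placeholder (x, y, z, reflectance) cloud
device = torch.device('musa')
outputs = voxelizer(points.contiguous().to(device))
# Bring results back to the CPU for numpy comparison, as in the test above.
results = [t.cpu().detach().numpy() for t in outputs]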