From 4b38ffcf45a6a42e3cb34efd6fa38b9c8fa915b7 Mon Sep 17 00:00:00 2001
From: sunyanguomt <yanguo.sun@mthreads.com>
Date: Thu, 20 Mar 2025 16:34:23 +0800
Subject: [PATCH] [MUSA] mmcv support musa, split pr 4 (#3260)

* mmcv support musa, split pr 4

* fix lint

* fix lint
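
A minimal usage sketch (illustration only, not part of this patch), assuming
torch_musa registers the 'musa' device type with PyTorch:

    import torch
    import torch_musa  # assumed import name; registers the 'musa' device
    from mmcv.ops import roi_align

    feats = torch.rand(1, 1, 4, 4, device='musa')
    # each roi is (batch_idx, x1, y1, x2, y2)
    rois = torch.tensor([[0., 0., 0., 3., 3.]], device='musa')
    out = roi_align(feats, rois, (2, 2), spatial_scale=1.0)  # -> (1, 1, 2, 2)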
---
 .../musa/points_in_boxes_musa_kernel.muh      |   91 +
 .../musa/points_in_polygons_musa_kernel.muh   |   75 +
 .../common/musa/prroi_pool_musa_kernel.muh    |  377 +++
 .../csrc/common/musa/psamask_musa_kernel.muh  |  137 ++
 .../musa/riroi_align_rotated_musa_kernel.muh  |  238 ++
 .../common/musa/roi_align_musa_kernel.muh     |  205 ++
 .../musa/roi_align_rotated_musa_kernel.muh    |  194 ++
 .../csrc/common/musa/roi_pool_musa_kernel.muh |   89 +
 .../musa/roiaware_pool3d_musa_kernel.muh      |  256 ++
 .../musa/roipoint_pool3d_musa_kernel.muh      |  130 ++
 .../rotated_feature_align_musa_kernel.muh     |  125 +
 .../musa/scatter_points_musa_kernel.muh       |  137 ++
 .../csrc/common/musa/sync_bn_musa_kernel.muh  |  327 +++
 .../musa/three_interpolate_musa_kernel.muh    |   57 +
 .../csrc/common/musa/three_nn_musa_kernel.muh |   63 +
 .../common/musa/tin_shift_musa_kernel.muh     |   57 +
 .../common/musa/voxelization_musa_kernel.muh  |  212 ++
 mmcv/ops/csrc/pytorch/musa/filtered_lrelu.mu  | 2052 +++++++++++++++++
 mmcv/ops/csrc/pytorch/musa/musabind.cpp       |  910 ++++++++
 .../csrc/pytorch/musa/points_in_boxes_musa.mu |   62 +
 .../pytorch/musa/points_in_polygons_musa.mu   |   28 +
 mmcv/ops/csrc/pytorch/musa/prroi_pool_musa.mu |   65 +
 mmcv/ops/csrc/pytorch/musa/psamask_musa.mu    |   60 +
 .../pytorch/musa/riroi_align_rotated_musa.mu  |   53 +
 mmcv/ops/csrc/pytorch/musa/roi_align_musa.mu  |   58 +
 .../pytorch/musa/roi_align_rotated_musa.mu    |   45 +
 mmcv/ops/csrc/pytorch/musa/roi_pool_musa.mu   |   50 +
 .../csrc/pytorch/musa/roiaware_pool3d_musa.mu |  118 +
 .../csrc/pytorch/musa/roipoint_pool3d_musa.mu |   60 +
 .../musa/rotated_feature_align_musa.mu        |   53 +
 .../csrc/pytorch/musa/scatter_points_musa.mu  |  132 ++
 mmcv/ops/csrc/pytorch/musa/sparse_maxpool.mu  |  486 ++++
 .../csrc/pytorch/musa/sparse_pool_ops_musa.mu |   91 +
 mmcv/ops/csrc/pytorch/musa/sync_bn_musa.mu    |  110 +
 .../pytorch/musa/three_interpolate_musa.mu    |   66 +
 mmcv/ops/csrc/pytorch/musa/three_nn_musa.mu   |   35 +
 mmcv/ops/csrc/pytorch/musa/tin_shift_musa.mu  |   55 +
 .../ops/csrc/pytorch/musa/upfirdn2d_kernel.mu |  749 ++++++
 .../csrc/pytorch/musa/voxelization_musa.mu    |  286 +++
 mmcv/ops/points_in_boxes.py                   |   23 +-
 mmcv/ops/sync_bn.py                           |   19 +-
 mmcv/ops/upfirdn2d.py                         |  102 +
 tests/test_ops/test_points_in_polygons.py     |    8 +-
 tests/test_ops/test_prroi_pool.py             |   14 +-
 tests/test_ops/test_psa_mask.py               |   15 +-
 tests/test_ops/test_roi_align.py              |    9 +-
 tests/test_ops/test_roi_align_rotated.py      |   12 +-
 tests/test_ops/test_roi_pool.py               |   15 +-
 tests/test_ops/test_roiaware_pool3d.py        |   18 +-
 tests/test_ops/test_roipoint_pool3d.py        |   13 +-
 tests/test_ops/test_rotated_feature_align.py  |    7 +-
 tests/test_ops/test_scatter_points.py         |    8 +-
 tests/test_ops/test_spconv.py                 |   13 +-
 tests/test_ops/test_syncbn.py                 |   70 +-
 tests/test_ops/test_three_interpolate.py      |   14 +-
 tests/test_ops/test_three_nn.py               |    8 +-
 tests/test_ops/test_tin_shift.py              |   12 +-
 tests/test_ops/test_voxelization.py           |   38 +-
 58 files changed, 8741 insertions(+), 71 deletions(-)
 create mode 100644 mmcv/ops/csrc/common/musa/points_in_boxes_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/points_in_polygons_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/prroi_pool_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/psamask_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/riroi_align_rotated_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/roi_align_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/roi_align_rotated_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/roi_pool_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/roiaware_pool3d_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/roipoint_pool3d_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/rotated_feature_align_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/scatter_points_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/sync_bn_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/three_interpolate_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/three_nn_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/tin_shift_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/common/musa/voxelization_musa_kernel.muh
 create mode 100644 mmcv/ops/csrc/pytorch/musa/filtered_lrelu.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/points_in_boxes_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/points_in_polygons_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/prroi_pool_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/psamask_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/riroi_align_rotated_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/roi_align_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/roi_align_rotated_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/roi_pool_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/roiaware_pool3d_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/roipoint_pool3d_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/rotated_feature_align_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/scatter_points_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/sparse_maxpool.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/sparse_pool_ops_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/sync_bn_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/three_interpolate_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/three_nn_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/tin_shift_musa.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/upfirdn2d_kernel.mu
 create mode 100644 mmcv/ops/csrc/pytorch/musa/voxelization_musa.mu

diff --git a/mmcv/ops/csrc/common/musa/points_in_boxes_musa_kernel.muh b/mmcv/ops/csrc/common/musa/points_in_boxes_musa_kernel.muh
new file mode 100644
index 000000000..e20ac68c7
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/points_in_boxes_musa_kernel.muh
@@ -0,0 +1,91 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef POINT_IN_BOXES_MUSA_KERNEL_MUH
+#define POINT_IN_BOXES_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename T>
+__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
+                                             T &local_x, T &local_y) {
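+  // Rotate the offset (shift_x, shift_y) by -rz to express the point in the
+  // box's local, axis-aligned frame.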
+  T cosa = cos(-rz), sina = sin(-rz);
+  local_x = shift_x * cosa + shift_y * (-sina);
+  local_y = shift_x * sina + shift_y * cosa;
+}
+
+template <typename T>
+__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
+                                        T &local_y) {
+  // param pt: (x, y, z)
+  // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
+  // cz in the bottom center
+  T x = pt[0], y = pt[1], z = pt[2];
+  T cx = box3d[0], cy = box3d[1], cz = box3d[2];
+  T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
+  cz += z_size /
+        2.0;  // shift to the center since cz in box3d is the bottom center
+
+  if (fabsf(z - cz) > z_size / 2.0) return 0;
+  lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
+  const int in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
+                      (local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
+  return in_flag;
+}
+
+template <typename T>
+__global__ void points_in_boxes_part_forward_musa_kernel(
+    int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
+    int *box_idx_of_points) {
+  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
+  // coordinate; z is the bottom center and the boxes must not overlap
+  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
+  // params box_idx_of_points: (B, npoints), default -1
+
+  int bs_idx = blockIdx.y;
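+  // blockIdx.y indexes the batch; the 1D kernel loop strides over points.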
+  MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
+    if (bs_idx >= batch_size) return;
+
+    boxes += bs_idx * boxes_num * 7;
+    pts += bs_idx * pts_num * 3 + pt_idx * 3;
+    box_idx_of_points += bs_idx * pts_num + pt_idx;
+
+    T local_x = 0, local_y = 0;
+    int cur_in_flag = 0;
+    for (int k = 0; k < boxes_num; k++) {
+      cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
+      if (cur_in_flag) {
+        box_idx_of_points[0] = k;
+        break;
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void points_in_boxes_all_forward_musa_kernel(
+    int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
+    int *box_idx_of_points) {
+  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
+  // coordinate; z is the bottom center
+  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
+  // params box_idx_of_points: (B, npoints, N) 0/1 flags, default 0
+
+  int bs_idx = blockIdx.y;
+  MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
+    if (bs_idx >= batch_size) return;
+
+    boxes += bs_idx * boxes_num * 7;
+    pts += bs_idx * pts_num * 3 + pt_idx * 3;
+    box_idx_of_points += bs_idx * pts_num * boxes_num + pt_idx * boxes_num;
+
+    T local_x = 0, local_y = 0;
+    for (int k = 0; k < boxes_num; k++) {
+      const int cur_in_flag =
+          check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
+      if (cur_in_flag) {
+        box_idx_of_points[k] = 1;
+      }
+    }
+  }
+}
+
+#endif  // POINT_IN_BOXES_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/points_in_polygons_musa_kernel.muh b/mmcv/ops/csrc/common/musa/points_in_polygons_musa_kernel.muh
new file mode 100644
index 000000000..714c889bd
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/points_in_polygons_musa_kernel.muh
@@ -0,0 +1,75 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
+#define POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+struct point {
+  float x, y;
+};
+
+template <typename scalar_t>
+__global__ void points_in_polygons_forward_musa_kernel(
+    const int nthreads, const scalar_t *vertex1, const scalar_t *vertex2,
+    const int rows, const int cols, scalar_t *inside_flag) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    int row = index / cols;
+    int col = index % cols;
+
+    const scalar_t *offset_vertex1 = vertex1 + row * 2;
+    const scalar_t *offset_vertex2 = vertex2 + col * 8;
+
+    point point_[1];
+    point polygon[4];
+
+    point_[0].x = offset_vertex1[0];
+    point_[0].y = offset_vertex1[1];
+
+    polygon[0].x = offset_vertex2[0];
+    polygon[0].y = offset_vertex2[1];
+    polygon[1].x = offset_vertex2[2];
+    polygon[1].y = offset_vertex2[3];
+    polygon[2].x = offset_vertex2[4];
+    polygon[2].y = offset_vertex2[5];
+    polygon[3].x = offset_vertex2[6];
+    polygon[3].y = offset_vertex2[7];
+
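+    // Crossing-number (ray casting) test: cast a horizontal ray from the
+    // point towards +x and count how many polygon edges it crosses; an odd
+    // count means the point lies inside.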
+    int nCross = 0;
+    int i, j;
+    float sx, sy, tx, ty, px, py, x;
+    for (i = 0, j = 3; i < 4; j = i, i++) {
+      sx = polygon[i].x;
+      sy = polygon[i].y;
+      tx = polygon[j].x;
+      ty = polygon[j].y;
+
+      px = point_[0].x;
+      py = point_[0].y;
+
+      if (py < min(sy, ty)) continue;
+      if (py > max(sy, ty)) continue;
+
+      if ((sx == px && sy == py) || (tx == px && ty == py)) {
+        break;
+      } else {
+        if ((sy < py && ty >= py) || (sy >= py && ty < py)) {
+          x = sx + (py - sy) * (tx - sx) / (ty - sy);
+          if (x == px) {
+            break;
+          }
+          if (x > px) {
+            nCross++;
+          }
+        }
+      }
+    }
+    if (nCross % 2 == 1) {
+      inside_flag[index] = 1.0;
+    } else {
+      inside_flag[index] = 0.0;
+    }
+  }
+}
+
+#endif  // POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/prroi_pool_musa_kernel.muh b/mmcv/ops/csrc/common/musa/prroi_pool_musa_kernel.muh
new file mode 100644
index 000000000..9394b7e89
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/prroi_pool_musa_kernel.muh
@@ -0,0 +1,377 @@
+// Copyright (c) OpenMMLab. All rights reserved
+// Modified from
+// https://github.com/vacancy/PreciseRoIPooling/blob/master/src/prroi_pooling_gpu_impl.cu
+// Distributed under terms of the MIT license.
+#ifndef PRROI_POOL_MUSA_KERNEL_MUH
+#define PRROI_POOL_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename T>
+__device__ static __forceinline__ T PrRoIPoolingGetData(const T *data,
+                                                        const int h,
+                                                        const int w,
+                                                        const int height,
+                                                        const int width) {
+  bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
+  T retVal = overflow ? 0.0f : data[h * width + w];
+  return retVal;
+}
+
+template <typename T>
+__device__ static __forceinline__ T PrRoIPoolingGetCoeff(T dh, T dw) {
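+  // Bilinear interpolation weight (1 - |dh|) * (1 - |dw|); callers only pass
+  // offsets with |dh|, |dw| <= 1.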
+  return (1.0f - abs(dh)) * (1.0f - abs(dw));
+}
+
+template <typename T>
+__device__ static __forceinline__ T PrRoIPoolingSingleCoorIntegral(T s, T t,
+                                                                   T c1, T c2) {
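+  // Closed-form integral over [s, t] of the linear ramp that equals c1 at 0
+  // and c2 at 1 (the 1D building block of the bilinear integral).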
+  return 0.5 * (t * t - s * s) * (c2 - c1) + (t - s) * c1;
+}
+
+template <typename T>
+__device__ static T PrRoIPoolingInterpolation(const T *data, const T h,
+                                              const T w, const int height,
+                                              const int width) {
+  T retVal = 0.0f;
+  int h1 = floorf(h);
+  int w1 = floorf(w);
+  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
+            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
+  h1 = floorf(h) + 1;
+  w1 = floorf(w);
+  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
+            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
+  h1 = floorf(h);
+  w1 = floorf(w) + 1;
+  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
+            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
+  h1 = floorf(h) + 1;
+  w1 = floorf(w) + 1;
+  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
+            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
+  return retVal;
+}
+
+template <typename T>
+__device__ static T PrRoIPoolingMatCalculation(const T *this_data,
+                                               const int s_h, const int s_w,
+                                               const int e_h, const int e_w,
+                                               const T y0, const T x0,
+                                               const T y1, const T x1,
+                                               const int h0, const int w0) {
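+  // Exact integral of the bilinearly interpolated feature over the rectangle
+  // [x0, x1] x [y0, y1] inside the grid cell with corners (s_h, s_w) and
+  // (e_h, e_w); tmp is the analytic weight of each corner value.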
+  T alpha, beta, lim_alpha, lim_beta, tmp;
+  T sum_out = 0;
+
+  alpha = x0 - T(s_w);
+  beta = y0 - T(s_h);
+  lim_alpha = x1 - T(s_w);
+  lim_beta = y1 - T(s_h);
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  sum_out += PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp;
+
+  alpha = T(e_w) - x1;
+  lim_alpha = T(e_w) - x0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  sum_out += PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp;
+
+  alpha = x0 - T(s_w);
+  beta = T(e_h) - y1;
+  lim_alpha = x1 - T(s_w);
+  lim_beta = T(e_h) - y0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  sum_out += PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp;
+
+  alpha = T(e_w) - x1;
+  lim_alpha = T(e_w) - x0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  sum_out += PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp;
+
+  return sum_out;
+}
+
+template <typename T>
+__device__ static void PrRoIPoolingDistributeDiff(T *diff, const T top_diff,
+                                                  const int h, const int w,
+                                                  const int height,
+                                                  const int width,
+                                                  const T coeff) {
+  bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
+  if (!overflow) atomicAdd(diff + h * width + w, top_diff * coeff);
+}
+
+template <typename T>
+__device__ static void PrRoIPoolingMatDistributeDiff(
+    T *diff, const T top_diff, const int s_h, const int s_w, const int e_h,
+    const int e_w, const T y0, const T x0, const T y1, const T x1, const int h0,
+    const int w0) {
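+  // Backward counterpart of PrRoIPoolingMatCalculation: scatter top_diff to
+  // the four cell corners with the same analytic integral weights.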
+  T alpha, beta, lim_alpha, lim_beta, tmp;
+
+  alpha = x0 - T(s_w);
+  beta = y0 - T(s_h);
+  lim_alpha = x1 - T(s_w);
+  lim_beta = y1 - T(s_h);
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  PrRoIPoolingDistributeDiff(diff, top_diff, s_h, s_w, h0, w0, tmp);
+
+  alpha = T(e_w) - x1;
+  lim_alpha = T(e_w) - x0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  PrRoIPoolingDistributeDiff(diff, top_diff, s_h, e_w, h0, w0, tmp);
+
+  alpha = x0 - T(s_w);
+  beta = T(e_h) - y1;
+  lim_alpha = x1 - T(s_w);
+  lim_beta = T(e_h) - y0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  PrRoIPoolingDistributeDiff(diff, top_diff, e_h, s_w, h0, w0, tmp);
+
+  alpha = T(e_w) - x1;
+  lim_alpha = T(e_w) - x0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  PrRoIPoolingDistributeDiff(diff, top_diff, e_h, e_w, h0, w0, tmp);
+}
+
+template <typename T>
+__global__ void prroi_pool_forward_musa_kernel(
+    const int nthreads, const T *input, const T *rois, T *output,
+    const int pooled_height, const int pooled_width, const T spatial_scale,
+    const int channels, const int height, const int width) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T *offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    T roi_x1 = offset_rois[1] * spatial_scale;
+    T roi_y1 = offset_rois[2] * spatial_scale;
+    T roi_x2 = offset_rois[3] * spatial_scale;
+    T roi_y2 = offset_rois[4] * spatial_scale;
+
+    T roi_width = max(roi_x2 - roi_x1, ((T)0.0));
+    T roi_height = max(roi_y2 - roi_y1, ((T)0.0));
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    const T *this_data =
+        input + (roi_batch_ind * channels + c) * height * width;
+    T *this_out = output + index;
+
+    T bin_x1 = roi_x1 + bin_size_w * pw;
+    T bin_y1 = roi_y1 + bin_size_h * ph;
+    T bin_x2 = bin_x1 + bin_size_w;
+    T bin_y2 = bin_y1 + bin_size_h;
+
+    T bin_size = max(T(0.0), bin_size_w * bin_size_h);
+    if (bin_size == 0) {
+      *this_out = 0;
+      continue;
+    }
+
+    T sum_out = 0;
+
+    int start_x, start_y, end_x, end_y;
+
+    start_x = floorf(bin_x1);
+    end_x = ceilf(bin_x2);
+    start_y = floorf(bin_y1);
+    end_y = ceilf(bin_y2);
+
+    for (int bin_x = start_x; bin_x < end_x; ++bin_x)
+      for (int bin_y = start_y; bin_y < end_y; ++bin_y)
+        sum_out += PrRoIPoolingMatCalculation(
+            this_data, bin_y, bin_x, bin_y + 1, bin_x + 1,
+            max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)),
+            min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height,
+            width);
+    *this_out = sum_out / bin_size;
+  }
+}
+
+template <typename T>
+__global__ void prroi_pool_backward_musa_kernel(
+    const int nthreads, const T *grad_output, const T *rois, T *grad_input,
+    const int pooled_height, const int pooled_width, const T spatial_scale,
+    const int channels, const int height, const int width) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    auto rois_cur = rois + n * 5;
+
+    int roi_batch_ind = rois_cur[0];
+    T roi_x1 = rois_cur[1] * spatial_scale;
+    T roi_y1 = rois_cur[2] * spatial_scale;
+    T roi_x2 = rois_cur[3] * spatial_scale;
+    T roi_y2 = rois_cur[4] * spatial_scale;
+
+    T roi_width = max(roi_x2 - roi_x1, (T)0);
+    T roi_height = max(roi_y2 - roi_y1, (T)0);
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    const T *this_out_grad = grad_output + index;
+    T *this_data_grad =
+        grad_input + (roi_batch_ind * channels + c) * height * width;
+
+    T bin_x1 = roi_x1 + bin_size_w * pw;
+    T bin_y1 = roi_y1 + bin_size_h * ph;
+    T bin_x2 = bin_x1 + bin_size_w;
+    T bin_y2 = bin_y1 + bin_size_h;
+
+    T bin_size = max(T(0.0), bin_size_w * bin_size_h);
+
+    T sum_out = bin_size == T(0) ? T(0) : *this_out_grad / bin_size;
+
+    int start_x, start_y, end_x, end_y;
+
+    start_x = floorf(bin_x1);
+    end_x = ceilf(bin_x2);
+    start_y = floorf(bin_y1);
+    end_y = ceilf(bin_y2);
+
+    for (int bin_x = start_x; bin_x < end_x; ++bin_x)
+      for (int bin_y = start_y; bin_y < end_y; ++bin_y)
+        PrRoIPoolingMatDistributeDiff(
+            this_data_grad, sum_out, bin_y, bin_x, bin_y + 1, bin_x + 1,
+            max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)),
+            min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height,
+            width);
+  }
+}
+
+template <typename T>
+__global__ void prroi_pool_coor_backward_musa_kernel(
+    const int nthreads, const T *output, const T *grad_output, const T *input,
+    const T *rois, T *grad_rois, const int pooled_height,
+    const int pooled_width, const T spatial_scale, const int channels,
+    const int height, const int width) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    auto rois_cur = rois + n * 5;
+
+    int roi_batch_ind = rois_cur[0];
+    T roi_x1 = rois_cur[1] * spatial_scale;
+    T roi_y1 = rois_cur[2] * spatial_scale;
+    T roi_x2 = rois_cur[3] * spatial_scale;
+    T roi_y2 = rois_cur[4] * spatial_scale;
+
+    T roi_width = max(roi_x2 - roi_x1, (T)0);
+    T roi_height = max(roi_y2 - roi_y1, (T)0);
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    const T output_grad_val = grad_output[index];
+    const T *this_input_data =
+        input + (roi_batch_ind * channels + c) * height * width;
+    const T output_val = output[index];
+    T *this_rois_grad = grad_rois + n * 5;
+
+    T bin_x1 = roi_x1 + bin_size_w * pw;
+    T bin_y1 = roi_y1 + bin_size_h * ph;
+    T bin_x2 = bin_x1 + bin_size_w;
+    T bin_y2 = bin_y1 + bin_size_h;
+
+    T bin_size = max(T(0.0), bin_size_w * bin_size_h);
+
+    T sum_out = bin_size == T(0) ? T(0) : output_grad_val / bin_size;
+
+    // WARNING: to be discussed
+    if (sum_out == 0) continue;
+
+    int start_x, start_y, end_x, end_y;
+
+    start_x = floorf(bin_x1);
+    end_x = ceilf(bin_x2);
+    start_y = floorf(bin_y1);
+    end_y = ceilf(bin_y2);
+
+    T grad_x1_y = 0, grad_x2_y = 0, grad_x_y1 = 0, grad_x_y2 = 0;
+    for (int bin_y = start_y; bin_y < end_y; ++bin_y) {
+      grad_x1_y += PrRoIPoolingSingleCoorIntegral(
+          max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y,
+          PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x1,
+                                    height, width),
+          PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x1,
+                                    height, width));
+
+      grad_x2_y += PrRoIPoolingSingleCoorIntegral(
+          max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y,
+          PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x2,
+                                    height, width),
+          PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x2,
+                                    height, width));
+    }
+
+    for (int bin_x = start_x; bin_x < end_x; ++bin_x) {
+      grad_x_y1 += PrRoIPoolingSingleCoorIntegral(
+          max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x,
+          PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x),
+                                    height, width),
+          PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x + 1),
+                                    height, width));
+
+      grad_x_y2 += PrRoIPoolingSingleCoorIntegral(
+          max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x,
+          PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x),
+                                    height, width),
+          PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x + 1),
+                                    height, width));
+    }
+
+    T partial_x1 = -grad_x1_y + (bin_y2 - bin_y1) * output_val;
+    T partial_y1 = -grad_x_y1 + (bin_x2 - bin_x1) * output_val;
+    T partial_x2 = grad_x2_y - (bin_y2 - bin_y1) * output_val;
+    T partial_y2 = grad_x_y2 - (bin_x2 - bin_x1) * output_val;
+
+    partial_x1 = partial_x1 / bin_size * spatial_scale;
+    partial_x2 = partial_x2 / bin_size * spatial_scale;
+    partial_y1 = partial_y1 / bin_size * spatial_scale;
+    partial_y2 = partial_y2 / bin_size * spatial_scale;
+
+    // (index, x1, y1, x2, y2)
+    this_rois_grad[0] = 0;
+    atomicAdd(this_rois_grad + 1,
+              (partial_x1 * (1.0f - T(pw) / pooled_width) +
+               partial_x2 * (1.0f - T(pw + 1) / pooled_width)) *
+                  output_grad_val);
+    atomicAdd(this_rois_grad + 2,
+              (partial_y1 * (1.0f - T(ph) / pooled_height) +
+               partial_y2 * (1.0f - T(ph + 1) / pooled_height)) *
+                  output_grad_val);
+    atomicAdd(this_rois_grad + 3, (partial_x2 * T(pw + 1) / pooled_width +
+                                   partial_x1 * T(pw) / pooled_width) *
+                                      output_grad_val);
+    atomicAdd(this_rois_grad + 4, (partial_y2 * T(ph + 1) / pooled_height +
+                                   partial_y1 * T(ph) / pooled_height) *
+                                      output_grad_val);
+  }
+}
+
+#endif  // PRROI_POOL_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/psamask_musa_kernel.muh b/mmcv/ops/csrc/common/musa/psamask_musa_kernel.muh
new file mode 100644
index 000000000..75091ea4b
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/psamask_musa_kernel.muh
@@ -0,0 +1,137 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef PSAMASK_MUSA_KERNEL_MUH
+#define PSAMASK_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+// MUSA: grid stride looping
+#ifndef MUSA_KERNEL_LOOP
+#define MUSA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+#endif
+
+template <typename T>
+__global__ void psamask_collect_forward_musa(
+    const int nthreads, const int h_feature, const int w_feature,
+    const int h_mask, const int w_mask, const int half_h_mask,
+    const int half_w_mask, const T* mask_data, T* buffer_data) {
+  MUSA_KERNEL_LOOP(index, nthreads) {
+    const int w = index % w_feature;
+    const int h = (index / w_feature) % h_feature;
+    const int n = index / w_feature / h_feature;
+    // effective mask region: [hstart, hend) x [wstart, wend) in mask coordinates
+    const int hstart = max(0, half_h_mask - h);
+    const int hend = min(h_mask, h_feature + half_h_mask - h);
+    const int wstart = max(0, half_w_mask - w);
+    const int wend = min(w_mask, w_feature + half_w_mask - w);
+    // (hidx,                   widx                  ) in mask coordinates
+    // (hidx + h - half_h_mask, widx + w - half_w_mask) in feature-map coordinates
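+    // mask layout:   (n, h_mask * w_mask, h_feature, w_feature)
+    // buffer layout: (n, h_feature * w_feature, h_feature * w_feature)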
+    for (int hidx = hstart; hidx < hend; hidx++) {
+      for (int widx = wstart; widx < wend; widx++) {
+        buffer_data[(n * h_feature * w_feature +
+                     (hidx + h - half_h_mask) * w_feature +
+                     (widx + w - half_w_mask)) *
+                        h_feature * w_feature +
+                    h * w_feature + w] = mask_data
+            [((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature + h) *
+                 w_feature +
+             w];
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void psamask_distribute_forward_musa(
+    const int nthreads, const int h_feature, const int w_feature,
+    const int h_mask, const int w_mask, const int half_h_mask,
+    const int half_w_mask, const T* mask_data, T* buffer_data) {
+  MUSA_KERNEL_LOOP(index, nthreads) {
+    const int w = index % w_feature;
+    const int h = (index / w_feature) % h_feature;
+    const int n = index / w_feature / h_feature;
+    // effective mask region: [hstart, hend) x [wstart, wend) in mask coordinates
+    const int hstart = max(0, half_h_mask - h);
+    const int hend = min(h_mask, h_feature + half_h_mask - h);
+    const int wstart = max(0, half_w_mask - w);
+    const int wend = min(w_mask, w_feature + half_w_mask - w);
+    // (hidx,                   widx                  ) in mask coordinates
+    // (hidx + h - half_h_mask, widx + w - half_w_mask) in feature-map coordinates
+    for (int hidx = hstart; hidx < hend; hidx++) {
+      for (int widx = wstart; widx < wend; widx++) {
+        buffer_data[(n * h_feature * w_feature + h * w_feature + w) *
+                        h_feature * w_feature +
+                    (hidx + h - half_h_mask) * w_feature +
+                    (widx + w - half_w_mask)] = mask_data
+            [((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature + h) *
+                 w_feature +
+             w];
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void psamask_collect_backward_musa(
+    const int nthreads, const int h_feature, const int w_feature,
+    const int h_mask, const int w_mask, const int half_h_mask,
+    const int half_w_mask, const T* buffer_diff, T* mask_diff) {
+  MUSA_KERNEL_LOOP(index, nthreads) {
+    const int w = index % w_feature;
+    const int h = (index / w_feature) % h_feature;
+    const int n = index / w_feature / h_feature;
+    // effective mask region: [hstart, hend) x [wstart, wend) in mask coordinates
+    const int hstart = max(0, half_h_mask - h);
+    const int hend = min(h_mask, h_feature + half_h_mask - h);
+    const int wstart = max(0, half_w_mask - w);
+    const int wend = min(w_mask, w_feature + half_w_mask - w);
+    // (hidx,                   widx                  ) in mask coordinates
+    // (hidx + h - half_h_mask, widx + w - half_w_mask) in feature-map coordinates
+    for (int hidx = hstart; hidx < hend; hidx++) {
+      for (int widx = wstart; widx < wend; widx++) {
+        mask_diff[((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature +
+                   h) *
+                      w_feature +
+                  w] = buffer_diff[(n * h_feature * w_feature +
+                                    (hidx + h - half_h_mask) * w_feature +
+                                    (widx + w - half_w_mask)) *
+                                       h_feature * w_feature +
+                                   h * w_feature + w];
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void psamask_distribute_backward_musa(
+    const int nthreads, const int h_feature, const int w_feature,
+    const int h_mask, const int w_mask, const int half_h_mask,
+    const int half_w_mask, const T* buffer_diff, T* mask_diff) {
+  MUSA_KERNEL_LOOP(index, nthreads) {
+    const int w = index % w_feature;
+    const int h = (index / w_feature) % h_feature;
+    const int n = index / w_feature / h_feature;
+    // effective mask region: [hstart, hend) x [wstart, wend) in mask coordinates
+    const int hstart = max(0, half_h_mask - h);
+    const int hend = min(h_mask, h_feature + half_h_mask - h);
+    const int wstart = max(0, half_w_mask - w);
+    const int wend = min(w_mask, w_feature + half_w_mask - w);
+    // (hidx,                   widx                  ) in mask coordinates
+    // (hidx + h - half_h_mask, widx + w - half_w_mask) in feature-map coordinates
+    for (int hidx = hstart; hidx < hend; hidx++) {
+      for (int widx = wstart; widx < wend; widx++) {
+        mask_diff[((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature +
+                   h) *
+                      w_feature +
+                  w] =
+            buffer_diff[(n * h_feature * w_feature + h * w_feature + w) *
+                            h_feature * w_feature +
+                        (hidx + h - half_h_mask) * w_feature +
+                        (widx + w - half_w_mask)];
+      }
+    }
+  }
+}
+
+#endif  // PSAMASK_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/riroi_align_rotated_musa_kernel.muh b/mmcv/ops/csrc/common/musa/riroi_align_rotated_musa_kernel.muh
new file mode 100644
index 000000000..b5124798b
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/riroi_align_rotated_musa_kernel.muh
@@ -0,0 +1,238 @@
+// Modified from
+// https://github.com/csuhan/ReDet/blob/master/mmdet/ops/riroi_align/src/riroi_align_kernel.cu
+#ifndef RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
+#define RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
+
+#include <float.h>
+#include "pytorch_musa_helper.hpp"
+
+/*** Forward ***/
+template <typename scalar_t>
+__global__ void riroi_align_rotated_forward_musa_kernel(
+    const int nthreads, const scalar_t *bottom_data,
+    const scalar_t *bottom_rois, const scalar_t spatial_scale,
+    const int num_samples, const bool clockwise, const int channels,
+    const int height, const int width, const int pooled_height,
+    const int pooled_width, const int num_orientations, scalar_t *top_data) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int o = (index / pooled_width / pooled_height) % num_orientations;
+    int c =
+        (index / pooled_width / pooled_height / num_orientations) % channels;
+    int n = index / pooled_width / pooled_height / num_orientations / channels;
+
+    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
+    int roi_batch_ind = offset_bottom_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale;
+    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale;
+    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
+    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
+    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
+    scalar_t theta = offset_bottom_rois[5];
+    // Force malformed ROIs to be 1x1
+    roi_width = max(roi_width, (scalar_t)1.);
+    roi_height = max(roi_height, (scalar_t)1.);
+    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
+                          static_cast<scalar_t>(pooled_height);
+    scalar_t bin_size_w =
+        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
+
+    // find aligned index
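+    // theta in [0, 2*pi) maps to the fractional orientation index ind_float;
+    // l_var / r_var linearly interpolate between the two nearest channels.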
+    scalar_t ind_float = theta * num_orientations / (2 * M_PI);
+    int ind = floorf(ind_float);
+    scalar_t l_var = ind_float - (scalar_t)ind;
+    scalar_t r_var = 1.0 - l_var;
+    // correct start channel
+    ind = (ind + num_orientations) % num_orientations;
+    // rotated channel
+    int ind_rot = (o - ind + num_orientations) % num_orientations;
+    int ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations;
+    const scalar_t *offset_bottom_data =
+        bottom_data + (roi_batch_ind * channels * num_orientations +
+                       c * num_orientations + ind_rot) *
+                          height * width;
+
+    const scalar_t *offset_bottom_data_plus =
+        bottom_data + (roi_batch_ind * channels * num_orientations +
+                       c * num_orientations + ind_rot_plus) *
+                          height * width;
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (num_samples > 0)
+                             ? num_samples
+                             : ceilf(roi_height / pooled_height);  // e.g., = 2
+    int roi_bin_grid_w =
+        (num_samples > 0) ? num_samples : ceilf(roi_width / pooled_width);
+
+    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+    // Appropriate translation needs to be applied after.
+    if (clockwise) {
+      theta = -theta;  // If clockwise, the angle needs to be reversed.
+    }
+    scalar_t roi_start_h = -roi_height / 2.0;
+    scalar_t roi_start_w = -roi_width / 2.0;
+    scalar_t cosscalar_theta = cos(theta);
+    scalar_t sinscalar_theta = sin(theta);
+
+    // We do average (integral) pooling inside a bin
+    const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1);  // e.g. = 4
+
+    scalar_t output_val = 0.;
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
+      const scalar_t yy =
+          roi_start_h + ph * bin_size_h +
+          static_cast<scalar_t>(iy + .5f) * bin_size_h /
+              static_cast<scalar_t>(roi_bin_grid_h);  // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const scalar_t xx = roi_start_w + pw * bin_size_w +
+                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
+                                static_cast<scalar_t>(roi_bin_grid_w);
+
+        // Rotate by theta (counterclockwise) around the center and translate
+        scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
+        scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;
+
+        scalar_t val = bilinear_interpolate<scalar_t>(
+            offset_bottom_data, height, width, y, x, index);
+        scalar_t val_plus = bilinear_interpolate<scalar_t>(
+            offset_bottom_data_plus, height, width, y, x, index);
+        output_val += r_var * val + l_var * val_plus;
+      }
+    }
+    output_val /= count;
+
+    top_data[index] = output_val;
+  }
+}
+
+/*** Backward ***/
+template <typename scalar_t>
+__global__ void riroi_align_rotated_backward_musa_kernel(
+    const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
+    const scalar_t spatial_scale, const int num_samples, const bool clockwise,
+    const int channels, const int height, const int width,
+    const int pooled_height, const int pooled_width, const int num_orientations,
+    scalar_t *bottom_diff) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int o = (index / pooled_width / pooled_height) % num_orientations;
+    int c =
+        (index / pooled_width / pooled_height / num_orientations) % channels;
+    int n = index / pooled_width / pooled_height / num_orientations / channels;
+
+    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
+    int roi_batch_ind = offset_bottom_rois[0];
+
+    // Do not round
+    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale;
+    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale;
+    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
+    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
+    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
+    scalar_t theta = offset_bottom_rois[5];
+    // Force malformed ROIs to be 1x1
+    roi_width = max(roi_width, (scalar_t)1.);
+    roi_height = max(roi_height, (scalar_t)1.);
+
+    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
+                          static_cast<scalar_t>(pooled_height);
+    scalar_t bin_size_w =
+        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
+
+    // find aligned index
+    scalar_t ind_float = theta * num_orientations / (2 * M_PI);
+    int ind = floorf(ind_float);
+    scalar_t l_var = ind_float - (scalar_t)ind;
+    scalar_t r_var = 1.0 - l_var;
+    // correct start channel
+    ind = (ind + num_orientations) % num_orientations;
+    // rotated channel
+    int ind_rot = (o - ind + num_orientations) % num_orientations;
+    int ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations;
+    scalar_t *offset_bottom_diff =
+        bottom_diff + (roi_batch_ind * channels * num_orientations +
+                       c * num_orientations + ind_rot) *
+                          height * width;
+    scalar_t *offset_bottom_diff_plus =
+        bottom_diff + (roi_batch_ind * channels * num_orientations +
+                       c * num_orientations + ind_rot_plus) *
+                          height * width;
+    int top_offset =
+        (n * channels * num_orientations + c * num_orientations + o) *
+        pooled_height * pooled_width;
+    const scalar_t *offset_top_diff = top_diff + top_offset;
+    const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (num_samples > 0)
+                             ? num_samples
+                             : ceilf(roi_height / pooled_height);  // e.g., = 2
+    int roi_bin_grid_w =
+        (num_samples > 0) ? num_samples : ceilf(roi_width / pooled_width);
+
+    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+    // Appropriate translation needs to be applied after.
+    if (clockwise) {
+      theta = -theta;  // If clockwise, the angle needs to be reversed.
+    }
+    scalar_t roi_start_h = -roi_height / 2.0;
+    scalar_t roi_start_w = -roi_width / 2.0;
+    scalar_t cosTheta = cos(theta);
+    scalar_t sinTheta = sin(theta);
+
+    // We do average (integral) pooling inside a bin
+    const scalar_t count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
+      const scalar_t yy =
+          roi_start_h + ph * bin_size_h +
+          static_cast<scalar_t>(iy + .5f) * bin_size_h /
+              static_cast<scalar_t>(roi_bin_grid_h);  // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const scalar_t xx = roi_start_w + pw * bin_size_w +
+                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
+                                static_cast<scalar_t>(roi_bin_grid_w);
+
+        // Rotate by theta around the center and translate
+        scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
+        scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;
+
+        scalar_t w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient<scalar_t>(height, width, y, x, w1, w2, w3,
+                                                w4, x_low, x_high, y_low,
+                                                y_high, index);
+
+        scalar_t g1 = top_diff_this_bin * w1 / count;
+        scalar_t g2 = top_diff_this_bin * w2 / count;
+        scalar_t g3 = top_diff_this_bin * w3 / count;
+        scalar_t g4 = top_diff_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          atomicAdd(offset_bottom_diff + y_low * width + x_low, g1 * r_var);
+          atomicAdd(offset_bottom_diff + y_low * width + x_high, g2 * r_var);
+          atomicAdd(offset_bottom_diff + y_high * width + x_low, g3 * r_var);
+          atomicAdd(offset_bottom_diff + y_high * width + x_high, g4 * r_var);
+
+          atomicAdd(offset_bottom_diff_plus + y_low * width + x_low,
+                    g1 * l_var);
+          atomicAdd(offset_bottom_diff_plus + y_low * width + x_high,
+                    g2 * l_var);
+          atomicAdd(offset_bottom_diff_plus + y_high * width + x_low,
+                    g3 * l_var);
+          atomicAdd(offset_bottom_diff_plus + y_high * width + x_high,
+                    g4 * l_var);
+
+        }  // if
+      }    // ix
+    }      // iy
+  }        // MUSA_1D_KERNEL_LOOP
+}  // RiRoIAlignBackward
+
+#endif  // RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/roi_align_musa_kernel.muh b/mmcv/ops/csrc/common/musa/roi_align_musa_kernel.muh
new file mode 100644
index 000000000..afbc1de68
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/roi_align_musa_kernel.muh
@@ -0,0 +1,205 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef ROI_ALIGN_MUSA_KERNEL_MUH
+#define ROI_ALIGN_MUSA_KERNEL_MUH
+
+#include <float.h>
+#include "pytorch_musa_helper.hpp"
+
+/*** Forward ***/
+template <typename T>
+__global__ void roi_align_forward_musa_kernel(
+    const int nthreads, const T* input, const T* rois, T* output, T* argmax_y,
+    T* argmax_x, const int pooled_height, const int pooled_width,
+    const T spatial_scale, const int sampling_ratio,
+    const int pool_mode,  // 0 - max pool, 1 - avg pool
+    const bool aligned, const int channels, const int height, const int width) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T offset = aligned ? (T)0.5 : (T)0.0;
+    T roi_start_w = offset_rois[1] * spatial_scale - offset;
+    T roi_start_h = offset_rois[2] * spatial_scale - offset;
+    T roi_end_w = offset_rois[3] * spatial_scale - offset;
+    T roi_end_h = offset_rois[4] * spatial_scale - offset;
+
+    T roi_width = roi_end_w - roi_start_w;
+    T roi_height = roi_end_h - roi_start_h;
+    if (!aligned) {  // for backward-compatibility only
+      roi_width = max(roi_width, (T)1.);
+      roi_height = max(roi_height, (T)1.);
+    }
+
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    const T* offset_input =
+        input + (roi_batch_ind * channels + c) * height * width;
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h =
+        (sampling_ratio > 0)
+            ? sampling_ratio
+            : static_cast<int>(ceilf(roi_height / pooled_height));
+    int roi_bin_grid_w =
+        (sampling_ratio > 0)
+            ? sampling_ratio
+            : static_cast<int>(ceilf(roi_width / pooled_width));
+
+    if (pool_mode == 0) {
+      // We do max pooling inside a bin
+      T maxval = -FLT_MAX;
+      T maxidx_y = -1.f, maxidx_x = -1.f;
+      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+        const T y = roi_start_h + ph * bin_size_h +
+                    static_cast<T>(iy + .5f) * bin_size_h /
+                        static_cast<T>(roi_bin_grid_h);
+        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+          const T x = roi_start_w + pw * bin_size_w +
+                      static_cast<T>(ix + .5f) * bin_size_w /
+                          static_cast<T>(roi_bin_grid_w);
+          T val =
+              bilinear_interpolate(offset_input, height, width, y, x, index);
+          if (val > maxval) {
+            maxval = val;
+            maxidx_y = y;
+            maxidx_x = x;
+          }
+        }
+      }
+      output[index] = maxval;
+      argmax_y[index] = maxidx_y;
+      argmax_x[index] = maxidx_x;
+    } else if (pool_mode == 1) {
+      // We do average pooling inside a bin
+      const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
+      T output_val = 0.;
+      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+        const T y = roi_start_h + ph * bin_size_h +
+                    static_cast<T>(iy + .5f) * bin_size_h /
+                        static_cast<T>(roi_bin_grid_h);
+        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+          const T x = roi_start_w + pw * bin_size_w +
+                      static_cast<T>(ix + .5f) * bin_size_w /
+                          static_cast<T>(roi_bin_grid_w);
+          T val =
+              bilinear_interpolate(offset_input, height, width, y, x, index);
+          output_val += val;
+        }
+      }
+      output[index] = output_val / count;
+    }
+  }
+}
+
+/*** Backward ***/
+template <typename T>
+__global__ void roi_align_backward_musa_kernel(
+    const int nthreads, const T* grad_output, const T* rois, const T* argmax_y,
+    const T* argmax_x, T* grad_input, const int pooled_height,
+    const int pooled_width, const T spatial_scale, const int sampling_ratio,
+    const int pool_mode,  // 0 - max pool, 1 - avg pool
+    const bool aligned, const int channels, const int height, const int width) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T grad_output_this_bin = grad_output[index];
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+    T* offset_grad_input =
+        grad_input + ((roi_batch_ind * channels + c) * height * width);
+
+    if (pool_mode == 0) {
+      T y = argmax_y[index], x = argmax_x[index];
+      if (y != -1.f) {
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+        bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
+                                      x_low, x_high, y_low, y_high, index);
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          atomicAdd(offset_grad_input + y_low * width + x_low,
+                    grad_output_this_bin * w1);
+          atomicAdd(offset_grad_input + y_low * width + x_high,
+                    grad_output_this_bin * w2);
+          atomicAdd(offset_grad_input + y_high * width + x_low,
+                    grad_output_this_bin * w3);
+          atomicAdd(offset_grad_input + y_high * width + x_high,
+                    grad_output_this_bin * w4);
+        }
+      }
+    } else if (pool_mode == 1) {
+      // Do not use rounding; this implementation detail is critical
+      T offset = aligned ? (T)0.5 : (T)0.0;
+      T roi_start_w = offset_rois[1] * spatial_scale - offset;
+      T roi_start_h = offset_rois[2] * spatial_scale - offset;
+      T roi_end_w = offset_rois[3] * spatial_scale - offset;
+      T roi_end_h = offset_rois[4] * spatial_scale - offset;
+
+      T roi_width = roi_end_w - roi_start_w;
+      T roi_height = roi_end_h - roi_start_h;
+      if (!aligned) {  // for backward-compatibility only
+        roi_width = max(roi_width, (T)1.);
+        roi_height = max(roi_height, (T)1.);
+      }
+
+      T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+      T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+      // We use roi_bin_grid to sample the grid and mimic integral
+      int roi_bin_grid_h =
+          (sampling_ratio > 0)
+              ? sampling_ratio
+              : static_cast<int>(ceilf(roi_height / pooled_height));
+      int roi_bin_grid_w =
+          (sampling_ratio > 0)
+              ? sampling_ratio
+              : static_cast<int>(ceilf(roi_width / pooled_width));
+
+      // We do average (integral) pooling inside a bin
+      const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
+
+      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+        const T y = roi_start_h + ph * bin_size_h +
+                    static_cast<T>(iy + .5f) * bin_size_h /
+                        static_cast<T>(roi_bin_grid_h);
+        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+          const T x = roi_start_w + pw * bin_size_w +
+                      static_cast<T>(ix + .5f) * bin_size_w /
+                          static_cast<T>(roi_bin_grid_w);
+
+          T w1, w2, w3, w4;
+          int x_low, x_high, y_low, y_high;
+          bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
+                                        x_low, x_high, y_low, y_high, index);
+
+          if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+            atomicAdd(offset_grad_input + y_low * width + x_low,
+                      grad_output_this_bin * w1 / count);
+            atomicAdd(offset_grad_input + y_low * width + x_high,
+                      grad_output_this_bin * w2 / count);
+            atomicAdd(offset_grad_input + y_high * width + x_low,
+                      grad_output_this_bin * w3 / count);
+            atomicAdd(offset_grad_input + y_high * width + x_high,
+                      grad_output_this_bin * w4 / count);
+          }
+        }
+      }
+    }
+  }
+}
+
+#endif  // ROI_ALIGN_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/roi_align_rotated_musa_kernel.muh b/mmcv/ops/csrc/common/musa/roi_align_rotated_musa_kernel.muh
new file mode 100644
index 000000000..76249a122
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/roi_align_rotated_musa_kernel.muh
@@ -0,0 +1,194 @@
+// Modified from
+// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlignRotated
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#ifndef ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
+#define ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
+
+#include <float.h>
+#include "pytorch_musa_helper.hpp"
+
+/*** Forward ***/
+template <typename scalar_t>
+__global__ void roi_align_rotated_forward_musa_kernel(
+    const int nthreads, const scalar_t *bottom_data,
+    const scalar_t *bottom_rois, const scalar_t spatial_scale,
+    const int sampling_ratio, const bool aligned, const bool clockwise,
+    const int channels, const int height, const int width,
+    const int pooled_height, const int pooled_width, scalar_t *top_data) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
+    int roi_batch_ind = offset_bottom_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
+    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
+    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
+    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
+    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
+    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
+    scalar_t theta = offset_bottom_rois[5];
+    if (clockwise) {
+      theta = -theta;  // If clockwise, the angle needs to be reversed.
+    }
+    if (!aligned) {  // for backward-compatibility only
+      // Force malformed ROIs to be 1x1
+      roi_width = max(roi_width, (scalar_t)1.);
+      roi_height = max(roi_height, (scalar_t)1.);
+    }
+    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
+                          static_cast<scalar_t>(pooled_height);
+    scalar_t bin_size_w =
+        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
+
+    const scalar_t *offset_bottom_data =
+        bottom_data + (roi_batch_ind * channels + c) * height * width;
+
+    // We use roi_bin_grid to sample the grid and mimic integral pooling
+    int roi_bin_grid_h = (sampling_ratio > 0)
+                             ? sampling_ratio
+                             : ceilf(roi_height / pooled_height);  // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
+
+    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+    // Appropriate translation needs to be applied after.
+    scalar_t roi_start_h = -roi_height / 2.0;
+    scalar_t roi_start_w = -roi_width / 2.0;
+    scalar_t cosscalar_theta = cos(theta);
+    scalar_t sinscalar_theta = sin(theta);
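+
+    // A point (xx, yy) on the bin's axis-aligned sampling grid maps to
+    // feature-map coordinates through a 2-D rotation about the box centre:
+    //   y = yy * cos(theta) - xx * sin(theta) + roi_center_h
+    //   x = yy * sin(theta) + xx * cos(theta) + roi_center_w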
+
+    // We do average (integral) pooling inside a bin
+    const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1);  // e.g. = 4
+
+    scalar_t output_val = 0.;
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
+      const scalar_t yy =
+          roi_start_h + ph * bin_size_h +
+          static_cast<scalar_t>(iy + .5f) * bin_size_h /
+              static_cast<scalar_t>(roi_bin_grid_h);  // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const scalar_t xx = roi_start_w + pw * bin_size_w +
+                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
+                                static_cast<scalar_t>(roi_bin_grid_w);
+
+        // Rotate by theta (counterclockwise) around the center and translate
+        scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
+        scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;
+
+        scalar_t val = bilinear_interpolate<scalar_t>(
+            offset_bottom_data, height, width, y, x, index);
+        output_val += val;
+      }
+    }
+    output_val /= count;
+
+    top_data[index] = output_val;
+  }
+}
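+
+// A host-side launch might look like the sketch below (assuming the MUSA
+// runtime keeps CUDA's <<<grid, block, smem, stream>>> syntax and that
+// GET_BLOCKS / THREADS_PER_BLOCK come from pytorch_musa_helper.hpp):
+//
+//   const int output_size = num_rois * channels * pooled_height * pooled_width;
+//   roi_align_rotated_forward_musa_kernel<float>
+//       <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+//           output_size, input, rois, spatial_scale, sampling_ratio, aligned,
+//           clockwise, channels, height, width, pooled_height, pooled_width,
+//           output);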
+
+/*** Backward ***/
+template <typename scalar_t>
+__global__ void roi_align_rotated_backward_musa_kernel(
+    const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
+    const scalar_t spatial_scale, const int sampling_ratio, const bool aligned,
+    const bool clockwise, const int channels, const int height, const int width,
+    const int pooled_height, const int pooled_width, scalar_t *bottom_diff) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
+    int roi_batch_ind = offset_bottom_rois[0];
+
+    // Do not round
+    scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
+    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
+    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
+    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
+    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
+    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
+    scalar_t theta = offset_bottom_rois[5];
+    if (clockwise) {
+      theta = -theta;  // If clockwise, the angle needs to be reversed.
+    }
+    if (!aligned) {  // for backward-compatibility only
+      // Force malformed ROIs to be 1x1
+      roi_width = max(roi_width, (scalar_t)1.);
+      roi_height = max(roi_height, (scalar_t)1.);
+    }
+    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
+                          static_cast<scalar_t>(pooled_height);
+    scalar_t bin_size_w =
+        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
+
+    scalar_t *offset_bottom_diff =
+        bottom_diff + (roi_batch_ind * channels + c) * height * width;
+
+    int top_offset = (n * channels + c) * pooled_height * pooled_width;
+    const scalar_t *offset_top_diff = top_diff + top_offset;
+    const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+
+    // We use roi_bin_grid to sample the grid and mimic integral pooling
+    int roi_bin_grid_h = (sampling_ratio > 0)
+                             ? sampling_ratio
+                             : ceilf(roi_height / pooled_height);  // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
+
+    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+    // Appropriate translation needs to be applied after.
+    scalar_t roi_start_h = -roi_height / 2.0;
+    scalar_t roi_start_w = -roi_width / 2.0;
+    scalar_t cosTheta = cos(theta);
+    scalar_t sinTheta = sin(theta);
+
+    // We do average (integral) pooling inside a bin
+    const scalar_t count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
+      const scalar_t yy =
+          roi_start_h + ph * bin_size_h +
+          static_cast<scalar_t>(iy + .5f) * bin_size_h /
+              static_cast<scalar_t>(roi_bin_grid_h);  // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const scalar_t xx = roi_start_w + pw * bin_size_w +
+                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
+                                static_cast<scalar_t>(roi_bin_grid_w);
+
+        // Rotate by theta around the center and translate
+        scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
+        scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;
+
+        scalar_t w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient<scalar_t>(height, width, y, x, w1, w2, w3,
+                                                w4, x_low, x_high, y_low,
+                                                y_high, index);
+
+        scalar_t g1 = top_diff_this_bin * w1 / count;
+        scalar_t g2 = top_diff_this_bin * w2 / count;
+        scalar_t g3 = top_diff_this_bin * w3 / count;
+        scalar_t g4 = top_diff_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
+          atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
+          atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
+          atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
+        }  // if
+      }    // ix
+    }      // iy
+  }        // MUSA_1D_KERNEL_LOOP
+}  // RoIAlignBackward
+
+#endif  // ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/roi_pool_musa_kernel.muh b/mmcv/ops/csrc/common/musa/roi_pool_musa_kernel.muh
new file mode 100644
index 000000000..ec7738d2c
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/roi_pool_musa_kernel.muh
@@ -0,0 +1,89 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef ROI_POOL_MUSA_KERNEL_MUH
+#define ROI_POOL_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename T>
+__global__ void roi_pool_forward_musa_kernel(
+    const int nthreads, const T* input, const T* rois, T* output, int* argmax,
+    const int pooled_height, const int pooled_width, const T spatial_scale,
+    const int channels, const int height, const int width) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+    // calculate the roi region on feature maps
+    T roi_x1 = offset_rois[1] * spatial_scale;
+    T roi_y1 = offset_rois[2] * spatial_scale;
+    T roi_x2 = (offset_rois[3] + 1) * spatial_scale;
+    T roi_y2 = (offset_rois[4] + 1) * spatial_scale;
+
+    // skip malformed rois
+    T roi_w = roi_x2 - roi_x1;
+    T roi_h = roi_y2 - roi_y1;
+    if (roi_w <= 0 || roi_h <= 0) continue;
+
+    T bin_size_w = roi_w / static_cast<T>(pooled_width);
+    T bin_size_h = roi_h / static_cast<T>(pooled_height);
+
+    // the corresponding bin region
+    int bin_x1 = floorf(static_cast<T>(pw) * bin_size_w + roi_x1);
+    int bin_y1 = floorf(static_cast<T>(ph) * bin_size_h + roi_y1);
+    int bin_x2 = ceilf(static_cast<T>(pw + 1) * bin_size_w + roi_x1);
+    int bin_y2 = ceilf(static_cast<T>(ph + 1) * bin_size_h + roi_y1);
+
+    // add roi offsets and clip to input boundaries
+    bin_x1 = min(max(bin_x1, 0), width);
+    bin_y1 = min(max(bin_y1, 0), height);
+    bin_x2 = min(max(bin_x2, 0), width);
+    bin_y2 = min(max(bin_y2, 0), height);
+    bool is_empty = (bin_y2 <= bin_y1) || (bin_x2 <= bin_x1);
+
+    const T* offset_input =
+        input + (roi_batch_ind * channels + c) * height * width;
+    // Define an empty pooling region to be zero
+    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
+    T max_val = is_empty ? 0 : -FLT_MAX;
+    int max_idx = -1;
+    for (int h = bin_y1; h < bin_y2; ++h) {
+      for (int w = bin_x1; w < bin_x2; ++w) {
+        int offset = h * width + w;
+        if (offset_input[offset] > max_val) {
+          max_val = offset_input[offset];
+          max_idx = offset;
+        }
+      }
+    }
+    output[index] = max_val;
+    if (argmax != NULL) argmax[index] = max_idx;
+  }
+}
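+
+// The argmax buffer recorded above makes the backward pass a pure scatter:
+// each pooled output sends its whole gradient to the single input element
+// that won the max (see roi_pool_backward_musa_kernel below).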
+
+template <typename T>
+__global__ void roi_pool_backward_musa_kernel(
+    const int nthreads, const T* grad_output, const T* rois, const int* argmax,
+    T* grad_input, const int pooled_height, const int pooled_width,
+    const int channels, const int height, const int width) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c) is an element in the pooled output
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    int roi_batch_ind = rois[n * 5];
+    T* grad_input_offset =
+        grad_input + ((roi_batch_ind * channels + c) * height * width);
+    int argmax_index = argmax[index];
+
+    if (argmax_index != -1) {
+      atomicAdd(grad_input_offset + argmax_index, grad_output[index]);
+    }
+  }
+}
+
+#endif  // ROI_POOL_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/roiaware_pool3d_musa_kernel.muh b/mmcv/ops/csrc/common/musa/roiaware_pool3d_musa_kernel.muh
new file mode 100644
index 000000000..d6de6a01c
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/roiaware_pool3d_musa_kernel.muh
@@ -0,0 +1,256 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef ROIAWARE_POOL3D_MUSA_KERNEL_MUH
+#define ROIAWARE_POOL3D_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename T>
+__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
+                                             T &local_x, T &local_y) {
+  T cosa = cos(-rz), sina = sin(-rz);
+  local_x = shift_x * cosa + shift_y * (-sina);
+  local_y = shift_x * sina + shift_y * cosa;
+}
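+
+// lidar_to_local_coords rotates the offset (shift_x, shift_y) by -rz, i.e.
+// into the box's own frame where its heading lies along the local +x axis,
+// so check_pt_in_box3d can test a simple axis-aligned extent.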
+
+template <typename T>
+__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
+                                        T &local_y) {
+  // param pt: (x, y, z)
+  // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate;
+  // cz is the bottom center
+  T x = pt[0], y = pt[1], z = pt[2];
+  T cx = box3d[0], cy = box3d[1], cz = box3d[2];
+  T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
+  cz += z_size /
+        2.0;  // shift to the center since cz in box3d is the bottom center
+
+  if (fabsf(z - cz) > z_size / 2.0) return 0;
+  lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
+  int in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
+                (local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
+  return in_flag;
+}
+
+template <typename T>
+__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num,
+                                            int out_x, int out_y, int out_z,
+                                            const T *rois, const T *pts,
+                                            int *pts_mask) {
+  // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
+  //   coordinate
+  // params pts: (npoints, 3) [x, y, z]
+  // params pts_mask: (N, npoints): -1 means the point is not in this box;
+  //   otherwise the (x_idx, y_idx, z_idx) voxel indices encoded in binary bits
+  int box_idx = blockIdx.y;
+  MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
+    if (box_idx >= boxes_num) return;
+
+    pts += pt_idx * 3;
+    rois += box_idx * 7;
+    pts_mask += box_idx * pts_num + pt_idx;
+
+    T local_x = 0, local_y = 0;
+    int cur_in_flag = check_pt_in_box3d(pts, rois, local_x, local_y);
+
+    pts_mask[0] = -1;
+    if (cur_in_flag > 0) {
+      T local_z = pts[2] - rois[2];
+      T x_size = rois[3], y_size = rois[4], z_size = rois[5];
+
+      T x_res = x_size / out_x;
+      T y_res = y_size / out_y;
+      T z_res = z_size / out_z;
+
+      unsigned int x_idx = int((local_x + x_size / 2) / x_res);
+      unsigned int y_idx = int((local_y + y_size / 2) / y_res);
+      unsigned int z_idx = int(local_z / z_res);
+
+      x_idx = min(max(x_idx, 0), out_x - 1);
+      y_idx = min(max(y_idx, 0), out_y - 1);
+      z_idx = min(max(z_idx, 0), out_z - 1);
+
+      unsigned int idx_encoding = (x_idx << 16) + (y_idx << 8) + z_idx;
+
+      pts_mask[0] = idx_encoding;
+    }
+  }
+}
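+
+// The voxel index triple is packed into one int as
+// (x_idx << 16) + (y_idx << 8) + z_idx, which assumes each index fits in
+// 8 bits (out_x, out_y, out_z <= 256); collect_inside_pts_for_box3d below
+// unpacks it with the matching shifts and 0xFF masks.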
+
+template <typename T>
+__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
+                                             int max_pts_each_voxel, int out_x,
+                                             int out_y, int out_z,
+                                             const int *pts_mask,
+                                             T *pts_idx_of_voxels) {
+  // params pts_mask: (N, npoints): -1 or an encoded (x, y, z) voxel index
+  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
+  MUSA_1D_KERNEL_LOOP(box_idx, boxes_num) {
+    int max_num_pts = max_pts_each_voxel - 1;  // index 0 is the counter
+    pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel;
+
+    for (int k = 0; k < pts_num; k++) {
+      if (pts_mask[box_idx * pts_num + k] != -1) {
+        unsigned int idx_encoding = pts_mask[box_idx * pts_num + k];
+        unsigned int x_idx = (idx_encoding >> 16) & 0xFF;
+        unsigned int y_idx = (idx_encoding >> 8) & 0xFF;
+        unsigned int z_idx = idx_encoding & 0xFF;
+        unsigned int base_offset = x_idx * out_y * out_z * max_pts_each_voxel +
+                                   y_idx * out_z * max_pts_each_voxel +
+                                   z_idx * max_pts_each_voxel;
+        unsigned int cnt = pts_idx_of_voxels[base_offset];
+        if (cnt < max_num_pts) {
+          pts_idx_of_voxels[base_offset + cnt + 1] = k;
+          pts_idx_of_voxels[base_offset]++;
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
+                                   int max_pts_each_voxel, int out_x, int out_y,
+                                   int out_z, const T *pts_feature,
+                                   const int *pts_idx_of_voxels,
+                                   T *pooled_features, int *argmax) {
+  // params pts_feature: (npoints, C)
+  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
+  //   index 0 is the counter
+  // params pooled_features: (N, out_x, out_y, out_z, C)
+  // params argmax: (N, out_x, out_y, out_z, C)
+
+  int box_idx = blockIdx.z;
+  int channel_idx = blockIdx.y;
+  MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
+    int x_idx = voxel_idx_flat / (out_y * out_z);
+    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
+    int z_idx = voxel_idx_flat % out_z;
+    if (box_idx >= boxes_num || channel_idx >= channels) return;
+
+    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
+    pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
+                         offset_base * max_pts_each_voxel;
+    pooled_features += box_idx * out_x * out_y * out_z * channels +
+                       offset_base * channels + channel_idx;
+    argmax += box_idx * out_x * out_y * out_z * channels +
+              offset_base * channels + channel_idx;
+
+    int argmax_idx = -1;
+    float max_val = -FLT_MAX;  // -1e50 would overflow a float
+
+    int total_pts = pts_idx_of_voxels[0];
+
+    for (int k = 1; k <= total_pts; k++) {
+      if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] >
+          max_val) {
+        max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
+        argmax_idx = pts_idx_of_voxels[k];
+      }
+    }
+
+    if (argmax_idx != -1) {
+      pooled_features[0] = max_val;
+    }
+    argmax[0] = argmax_idx;
+  }
+}
+
+template <typename T>
+__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
+                                   int max_pts_each_voxel, int out_x, int out_y,
+                                   int out_z, const T *pts_feature,
+                                   const int *pts_idx_of_voxels,
+                                   T *pooled_features) {
+  // params pts_feature: (npoints, C)
+  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
+  //   index 0 is the counter
+  // params pooled_features: (N, out_x, out_y, out_z, C)
+
+  int box_idx = blockIdx.z;
+  int channel_idx = blockIdx.y;
+  MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
+    int x_idx = voxel_idx_flat / (out_y * out_z);
+    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
+    int z_idx = voxel_idx_flat % out_z;
+    if (box_idx >= boxes_num || channel_idx >= channels) return;
+
+    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
+    pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
+                         offset_base * max_pts_each_voxel;
+    pooled_features += box_idx * out_x * out_y * out_z * channels +
+                       offset_base * channels + channel_idx;
+
+    float sum_val = 0;
+    int total_pts = pts_idx_of_voxels[0];
+
+    for (int k = 1; k <= total_pts; k++) {
+      sum_val += pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
+    }
+
+    if (total_pts > 0) {
+      pooled_features[0] = sum_val / total_pts;
+    }
+  }
+}
+
+template <typename T>
+__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
+                                            int out_x, int out_y, int out_z,
+                                            const int *argmax,
+                                            const T *grad_out, T *grad_in) {
+  // params argmax: (N, out_x, out_y, out_z, C)
+  // params grad_out: (N, out_x, out_y, out_z, C)
+  // params grad_in: (npoints, C), return value
+
+  int box_idx = blockIdx.z;
+  int channel_idx = blockIdx.y;
+  MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
+    int x_idx = voxel_idx_flat / (out_y * out_z);
+    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
+    int z_idx = voxel_idx_flat % out_z;
+    if (box_idx >= boxes_num || channel_idx >= channels) return;
+
+    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
+    argmax += box_idx * out_x * out_y * out_z * channels +
+              offset_base * channels + channel_idx;
+    grad_out += box_idx * out_x * out_y * out_z * channels +
+                offset_base * channels + channel_idx;
+
+    if (argmax[0] == -1) return;
+
+    atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1);
+  }
+}
+
+template <typename T>
+__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
+                                            int out_x, int out_y, int out_z,
+                                            int max_pts_each_voxel,
+                                            const int *pts_idx_of_voxels,
+                                            const T *grad_out, T *grad_in) {
+  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
+  // params grad_out: (N, out_x, out_y, out_z, C)
+  // params grad_in: (npoints, C), return value
+
+  int box_idx = blockIdx.z;
+  int channel_idx = blockIdx.y;
+  MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
+    int x_idx = voxel_idx_flat / (out_y * out_z);
+    int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
+    int z_idx = voxel_idx_flat % out_z;
+    if (box_idx >= boxes_num || channel_idx >= channels) return;
+
+    int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
+    pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
+                         offset_base * max_pts_each_voxel;
+    grad_out += box_idx * out_x * out_y * out_z * channels +
+                offset_base * channels + channel_idx;
+
+    int total_pts = pts_idx_of_voxels[0];
+    float cur_grad = 1 / fmaxf(float(total_pts), 1.0);
+    for (int k = 1; k <= total_pts; k++) {
+      atomicAdd(grad_in + pts_idx_of_voxels[k] * channels + channel_idx,
+                grad_out[0] * cur_grad);
+    }
+  }
+}
+
+#endif  // ROIAWARE_POOL3D_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/roipoint_pool3d_musa_kernel.muh b/mmcv/ops/csrc/common/musa/roipoint_pool3d_musa_kernel.muh
new file mode 100644
index 000000000..0a8d1ba69
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/roipoint_pool3d_musa_kernel.muh
@@ -0,0 +1,130 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef ROIPOINT_POOL3D_MUSA_KERNEL_MUH
+#define ROIPOINT_POOL3D_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename T>
+__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
+                                             T &local_x, T &local_y) {
+  T cosa = cos(-rz), sina = sin(-rz);
+  local_x = shift_x * cosa + shift_y * (-sina);
+  local_y = shift_x * sina + shift_y * cosa;
+}
+
+template <typename T>
+__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
+                                        T &local_y) {
+  // param pt: (x, y, z)
+  // param box3d: (cx, cy, cz, dx, dy, dz, rz) in LiDAR coordinate; cz is the
+  // bottom center
+  T x = pt[0], y = pt[1], z = pt[2];
+  T cx = box3d[0], cy = box3d[1], cz = box3d[2];
+  T dx = box3d[3], dy = box3d[4], dz = box3d[5], rz = box3d[6];
+  cz += dz / 2.0;  // shift to the center since cz in box3d is the bottom center
+
+  if (fabsf(z - cz) > dz / 2.0) return 0;
+  lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
+  T in_flag = (local_x > -dx / 2.0) & (local_x < dx / 2.0) &
+              (local_y > -dy / 2.0) & (local_y < dy / 2.0);
+  return in_flag;
+}
+
+template <typename T>
+__global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num,
+                                    const T *xyz, const T *boxes3d,
+                                    int *pts_assign) {
+  // params xyz: (B, N, 3)
+  // params boxes3d: (B, M, 7)
+  // params pts_assign: (B, N, M): 1 if the point lies inside the box,
+  // otherwise 0
+  int box_idx = blockIdx.y;
+  int bs_idx = blockIdx.z;
+  MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
+    if (box_idx >= boxes_num || bs_idx >= batch_size) return;
+
+    int assign_idx =
+        bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx;
+    pts_assign[assign_idx] = 0;
+
+    int box_offset = bs_idx * boxes_num * 7 + box_idx * 7;
+    int pt_offset = bs_idx * pts_num * 3 + pt_idx * 3;
+
+    T local_x = 0, local_y = 0;
+    int cur_in_flag = check_pt_in_box3d(xyz + pt_offset, boxes3d + box_offset,
+                                        local_x, local_y);
+    pts_assign[assign_idx] = cur_in_flag;
+  }
+}
+
+__global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num,
+                               int sampled_pts_num, const int *pts_assign,
+                               int *pts_idx, int *pooled_empty_flag) {
+  // params pts_assign: (B, N, M)
+  // params pts_idx: (B, M, sampled_pts_num)
+  // params pooled_empty_flag: (B, M)
+  MUSA_1D_KERNEL_LOOP(boxes_idx, boxes_num) {
+    int bs_idx = blockIdx.y;
+
+    int cnt = 0;
+    for (int k = 0; k < pts_num; k++) {
+      if (pts_assign[bs_idx * pts_num * boxes_num + k * boxes_num +
+                     boxes_idx]) {
+        if (cnt < sampled_pts_num) {
+          pts_idx[bs_idx * boxes_num * sampled_pts_num +
+                  boxes_idx * sampled_pts_num + cnt] = k;
+          cnt++;
+        } else
+          break;
+      }
+    }
+
+    if (cnt == 0) {
+      pooled_empty_flag[bs_idx * boxes_num + boxes_idx] = 1;
+    } else if (cnt < sampled_pts_num) {
+      // duplicate same points for sampling
+      for (int k = cnt; k < sampled_pts_num; k++) {
+        int duplicate_idx = k % cnt;
+        int base_offset =
+            bs_idx * boxes_num * sampled_pts_num + boxes_idx * sampled_pts_num;
+        pts_idx[base_offset + k] = pts_idx[base_offset + duplicate_idx];
+      }
+    }
+  }
+}
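+
+// When a box contains fewer than sampled_pts_num interior points, the indices
+// already found are repeated cyclically (k % cnt), so every box yields a
+// fixed-size (B, M, sampled_pts_num) slice of pts_idx.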
+
+template <typename T>
+__global__ void roipoint_pool3d_forward(
+    int batch_size, int pts_num, int boxes_num, int feature_in_len,
+    int sampled_pts_num, const T *xyz, const int *pts_idx, const T *pts_feature,
+    T *pooled_features, int *pooled_empty_flag) {
+  // params xyz: (B, N, 3)
+  // params pts_idx: (B, M, sampled_pts_num)
+  // params pts_feature: (B, N, C)
+  // params pooled_features: (B, M, sampled_pts_num, 3+C)
+  // params pooled_empty_flag: (B, M)
+  int box_idx = blockIdx.y;
+  int bs_idx = blockIdx.z;
+  MUSA_1D_KERNEL_LOOP(sample_pt_idx, sampled_pts_num) {
+    if (box_idx >= boxes_num || bs_idx >= batch_size) return;
+    if (pooled_empty_flag[bs_idx * boxes_num + box_idx]) return;
+
+    int temp_idx = bs_idx * boxes_num * sampled_pts_num +
+                   box_idx * sampled_pts_num + sample_pt_idx;
+    int src_pt_idx = pts_idx[temp_idx];
+    int dst_feature_offset = temp_idx * (3 + feature_in_len);
+
+    for (int j = 0; j < 3; j++)
+      pooled_features[dst_feature_offset + j] =
+          xyz[bs_idx * pts_num * 3 + src_pt_idx * 3 + j];
+
+    int src_feature_offset =
+        bs_idx * pts_num * feature_in_len + src_pt_idx * feature_in_len;
+    memcpy(pooled_features + dst_feature_offset + 3,
+           pts_feature + src_feature_offset, feature_in_len * sizeof(T));
+  }
+}
+
+#endif  // ROIPOINT_POOL3D_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/rotated_feature_align_musa_kernel.muh b/mmcv/ops/csrc/common/musa/rotated_feature_align_musa_kernel.muh
new file mode 100644
index 000000000..b1d8785ea
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/rotated_feature_align_musa_kernel.muh
@@ -0,0 +1,125 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+// Modified from
+// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu
+#ifndef ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH
+#define ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename scalar_t>
+__global__ void rotated_feature_align_forward_kernel(
+    const int nthreads, const int points, const scalar_t* bottom_data,
+    const scalar_t* best_bboxes, const scalar_t spatial_scale,
+    const int channels, const int height, const int width, scalar_t* top_data) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    int w = index % width;
+    int h = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+
+    const scalar_t* bbox_offset =
+        best_bboxes + ((n * height + h) * width + w) * 5;
+    scalar_t roi_y = bbox_offset[0] * spatial_scale;
+    scalar_t roi_x = bbox_offset[1] * spatial_scale;
+
+    scalar_t px[5] = {roi_x, 0, 0, 0, 0};
+    scalar_t py[5] = {roi_y, 0, 0, 0, 0};
+
+    if (points > 1) {
+      scalar_t roi_w = bbox_offset[2] * spatial_scale;
+      scalar_t roi_h = bbox_offset[3] * spatial_scale;
+      scalar_t roi_a = bbox_offset[4];
+
+      scalar_t w_2 = roi_w / 2, h_2 = roi_h / 2;
+      scalar_t cosa = cosf(roi_a), sina = sinf(roi_a);
+      scalar_t wx = cosa * w_2, wy = sina * w_2;
+      scalar_t hx = -sina * h_2, hy = cosa * h_2;
+
+      px[1] = roi_x + wx + hx;
+      py[1] = roi_y + wy + hy;
+      px[2] = roi_x - wx + hx;
+      py[2] = roi_y - wy + hy;
+      px[3] = roi_x - wx - hx;
+      py[3] = roi_y - wy - hy;
+      px[4] = roi_x + wx - hx;
+      py[4] = roi_y + wy - hy;
+    }
+
+    const scalar_t* offset_bottom_data =
+        bottom_data + (n * channels + c) * height * width;
+
+    scalar_t output_val = bottom_data[index];
+    for (int i = 0; i < points; i++) {
+      output_val += bilinear_interpolate<scalar_t>(offset_bottom_data, height,
+                                                   width, py[i], px[i], i);
+    }
+    top_data[index] = output_val;
+  }
+}
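+
+// The refined feature is residual: the original value at (h, w) plus bilinear
+// samples taken at the anchor locations of the best bbox (its centre, and for
+// points > 1 also its four corners).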
+
+template <typename scalar_t>
+__global__ void rotated_feature_align_backward_kernel(
+    const int nthreads, const int points, const scalar_t* top_diff,
+    const scalar_t* best_bboxes, const scalar_t spatial_scale,
+    const int channels, const int height, const int width,
+    scalar_t* bottom_diff) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    int w = index % width;
+    int h = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+
+    const scalar_t* bbox_offset =
+        best_bboxes + ((n * height + h) * width + w) * 5;
+    scalar_t roi_y = bbox_offset[0] * spatial_scale;
+    scalar_t roi_x = bbox_offset[1] * spatial_scale;
+
+    scalar_t px[5] = {roi_x, 0, 0, 0, 0};
+    scalar_t py[5] = {roi_y, 0, 0, 0, 0};
+
+    if (points > 1) {
+      scalar_t roi_w = bbox_offset[2] * spatial_scale;
+      scalar_t roi_h = bbox_offset[3] * spatial_scale;
+      scalar_t roi_a = bbox_offset[4];
+
+      scalar_t w_2 = roi_w / 2, h_2 = roi_h / 2;
+      scalar_t cosa = cosf(roi_a), sina = sinf(roi_a);
+      scalar_t wx = cosa * w_2, wy = sina * w_2;
+      scalar_t hx = -sina * h_2, hy = cosa * h_2;
+
+      px[1] = roi_x + wx + hx;
+      py[1] = roi_y + wy + hy;
+      px[2] = roi_x - wx + hx;
+      py[2] = roi_y - wy + hy;
+      px[3] = roi_x - wx - hx;
+      py[3] = roi_y - wy - hy;
+      px[4] = roi_x + wx - hx;
+      py[4] = roi_y + wy - hy;
+    }
+
+    scalar_t* offset_bottom_diff =
+        bottom_diff + (n * channels + c) * height * width;
+    scalar_t value_top_diff = top_diff[index];
+
+    atomicAdd(bottom_diff + index, value_top_diff);
+    for (int i = 0; i < points; i++) {
+      scalar_t w1, w2, w3, w4;
+      int x_low, x_high, y_low, y_high;
+
+      bilinear_interpolate_gradient<scalar_t>(height, width, py[i], px[i], w1,
+                                              w2, w3, w4, x_low, x_high, y_low,
+                                              y_high, i);
+      scalar_t g1 = value_top_diff * w1;
+      scalar_t g2 = value_top_diff * w2;
+      scalar_t g3 = value_top_diff * w3;
+      scalar_t g4 = value_top_diff * w4;
+      if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+        atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
+        atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
+        atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
+        atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
+      }
+    }
+  }
+}
+#endif  // ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/scatter_points_musa_kernel.muh b/mmcv/ops/csrc/common/musa/scatter_points_musa_kernel.muh
new file mode 100644
index 000000000..ba418eceb
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/scatter_points_musa_kernel.muh
@@ -0,0 +1,137 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef SCATTER_POINTS_MUSA_KERNEL_MUH
+#define SCATTER_POINTS_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
+int const maxGridDim = 50000;
+
+__device__ __forceinline__ static void reduceMax(float *address, float val) {
+  int *address_as_i = reinterpret_cast<int *>(address);
+  int old = *address_as_i, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_i, assumed,
+                    __float_as_int(fmaxf(val, __int_as_float(assumed))));
+  } while (assumed != old || __int_as_float(old) < val);
+}
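+
+// reduceMax builds an atomic floating-point max on top of atomicCAS: the bits
+// are reinterpreted as an integer only for the compare-and-swap, the max
+// itself is computed in floating point, and the loop retries until the CAS
+// lands with a stored value no smaller than val.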
+
+__device__ __forceinline__ static void reduceMax(double *address, double val) {
+  unsigned long long *address_as_ull =
+      reinterpret_cast<unsigned long long *>(address);
+  unsigned long long old = *address_as_ull, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(
+        address_as_ull, assumed,
+        __double_as_longlong(fmax(val, __longlong_as_double(assumed))));
+  } while (assumed != old || __longlong_as_double(old) < val);
+}
+
+__device__ __forceinline__ static void reduceAdd(float *address, float val) {
+  atomicAdd(address, val);
+}
+
+__device__ __forceinline__ static void reduceAdd(double *address, double val) {
+  atomicAdd(address, val);
+}
+
+template <typename T>
+__global__ void feats_reduce_kernel(
+    const T *feats, const int32_t *coors_map,
+    T *reduced_feats,  // shall be 0 at initialization
+    const int num_input, const int num_feats, const reduce_t reduce_type) {
+  MUSA_1D_KERNEL_LOOP(x, num_input) {
+    int32_t reduce_to = coors_map[x];
+    if (reduce_to == -1) continue;
+
+    const T *feats_offset = feats + x * num_feats;
+    T *reduced_feats_offset = reduced_feats + reduce_to * num_feats;
+    if (reduce_type == reduce_t::MAX) {
+      for (int i = 0; i < num_feats; i++) {
+        reduceMax(&reduced_feats_offset[i], feats_offset[i]);
+      }
+    } else {
+      for (int i = 0; i < num_feats; i++) {
+        reduceAdd(&reduced_feats_offset[i], feats_offset[i]);
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void add_reduce_traceback_grad_kernel(
+    T *grad_feats, const T *grad_reduced_feats, const int32_t *coors_map,
+    const int32_t *reduce_count, const int num_input, const int num_feats,
+    const reduce_t reduce_type) {
+  MUSA_1D_KERNEL_LOOP(x, num_input) {
+    int32_t reduce_to = coors_map[x];
+    if (reduce_to == -1) {
+      continue;
+    }
+
+    const int input_offset = x * num_feats;
+    T *grad_feats_offset = grad_feats + input_offset;
+    const int reduced_offset = reduce_to * num_feats;
+    const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset;
+
+    if (reduce_type == reduce_t::SUM) {
+      for (int i = 0; i < num_feats; i++) {
+        grad_feats_offset[i] = grad_reduced_feats_offset[i];
+      }
+    } else if (reduce_type == reduce_t::MEAN) {
+      for (int i = 0; i < num_feats; i++) {
+        grad_feats_offset[i] = grad_reduced_feats_offset[i] /
+                               static_cast<T>(reduce_count[reduce_to]);
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void max_reduce_traceback_scatter_idx_kernel(
+    const T *feats, const T *reduced_feats, int32_t *reduce_from,
+    const int32_t *coors_map, const int num_input, const int num_feats) {
+  MUSA_1D_KERNEL_LOOP(x, num_input) {
+    int32_t reduce_to = coors_map[x];
+
+    const int input_offset = x * num_feats;
+    const T *feats_offset = feats + input_offset;
+
+    if (reduce_to == -1) {
+      continue;
+    }
+
+    const int reduced_offset = reduce_to * num_feats;
+    const T *reduced_feats_offset = reduced_feats + reduced_offset;
+    int32_t *reduce_from_offset = reduce_from + reduced_offset;
+
+    for (int i = 0; i < num_feats; i++) {
+      if (feats_offset[i] == reduced_feats_offset[i]) {
+        atomicMin(&reduce_from_offset[i], static_cast<int32_t>(x));
+      }
+    }
+  }
+}
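+
+// When several inputs tie with the reduced max, atomicMin keeps the smallest
+// contributing input index, so max_reduce_scatter_grad_kernel below routes
+// each gradient to exactly one deterministic winner.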
+
+template <typename T>
+__global__ void max_reduce_scatter_grad_kernel(T *grad_feats,
+                                               const T *grad_reduced_feats,
+                                               const int32_t *reduce_from,
+                                               const int num_reduced,
+                                               const int num_feats) {
+  MUSA_1D_KERNEL_LOOP(x, num_reduced) {
+    const int reduced_offset = x * num_feats;
+    const int32_t *scatter_to_offset = reduce_from + reduced_offset;
+    const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset;
+
+    for (int i = 0; i < num_feats; i++) {
+      grad_feats[scatter_to_offset[i] * num_feats + i] =
+          grad_reduced_feats_offset[i];
+    }
+  }
+}
+
+#endif  // SCATTER_POINTS_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/sync_bn_musa_kernel.muh b/mmcv/ops/csrc/common/musa/sync_bn_musa_kernel.muh
new file mode 100644
index 000000000..7eb5e0382
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/sync_bn_musa_kernel.muh
@@ -0,0 +1,327 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef SYNCBN_MUSA_KERNEL_MUH
+#define SYNCBN_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename T>
+__global__ void sync_bn_forward_mean_musa_kernel(const T *input, float *mean,
+                                                 int num, int channels,
+                                                 int spatial) {
+  __shared__ float buffer[THREADS_PER_BLOCK];
+  int tid = threadIdx.x;
+  int c = blockIdx.x;
+  buffer[tid] = 0;
+  for (int i = tid; i < num * spatial; i += blockDim.x) {
+    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+    buffer[tid] += input[index];
+  }
+  __syncthreads();
+
+  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      buffer[tid] += buffer[tid + s];
+    }
+    __syncthreads();
+  }
+  int total = num * spatial;
+  if (tid == 0) {
+    mean[c] = buffer[0] / total;
+  }
+}
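+
+// One block reduces one channel: each thread strides over the num * spatial
+// elements of channel c, then the shared-memory tree reduction halves the
+// active threads each step, which assumes THREADS_PER_BLOCK is a power of
+// two.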
+
+template <>
+__global__ void sync_bn_forward_mean_musa_kernel(const phalf *input,
+                                                 float *mean, int num,
+                                                 int channels, int spatial) {
+  __shared__ float buffer[THREADS_PER_BLOCK];
+  int tid = threadIdx.x;
+  int c = blockIdx.x;
+  buffer[tid] = 0;
+  for (int i = tid; i < num * spatial; i += blockDim.x) {
+    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+    buffer[tid] += static_cast<float>(input[index]);
+  }
+  __syncthreads();
+
+  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      buffer[tid] += buffer[tid + s];
+    }
+    __syncthreads();
+  }
+  int total = num * spatial;
+  if (tid == 0) {
+    mean[c] = buffer[0] / total;
+  }
+}
+
+template <typename T>
+__global__ void sync_bn_forward_var_musa_kernel(const T *input,
+                                                const float *mean, float *var,
+                                                int num, int channels,
+                                                int spatial) {
+  __shared__ float buffer[THREADS_PER_BLOCK];
+  int tid = threadIdx.x;
+  int c = blockIdx.x;
+  buffer[tid] = 0;
+  for (int i = tid; i < num * spatial; i += blockDim.x) {
+    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+    float td = input[index] - mean[c];
+    buffer[tid] += td * td;
+  }
+  __syncthreads();
+  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      buffer[tid] += buffer[tid + s];
+    }
+    __syncthreads();
+  }
+  int total = num * spatial;
+  if (tid == 0) {
+    var[c] = buffer[0] / total;
+  }
+}
+
+template <>
+__global__ void sync_bn_forward_var_musa_kernel(const phalf *input,
+                                                const float *mean, float *var,
+                                                int num, int channels,
+                                                int spatial) {
+  __shared__ float buffer[THREADS_PER_BLOCK];
+  int tid = threadIdx.x;
+  int c = blockIdx.x;
+  buffer[tid] = 0;
+  for (int i = tid; i < num * spatial; i += blockDim.x) {
+    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+    float td = static_cast<float>(input[index]) - mean[c];
+    buffer[tid] += td * td;
+  }
+  __syncthreads();
+  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      buffer[tid] += buffer[tid + s];
+    }
+    __syncthreads();
+  }
+  int total = num * spatial;
+  if (tid == 0) {
+    var[c] = buffer[0] / total;
+  }
+}
+
+template <typename T>
+__global__ void sync_bn_forward_output_musa_kernel(
+    const T *input, const float *mean, const float *var, float *running_mean,
+    float *running_var, const float *weight, const float *bias, float *norm,
+    float *std, T *output, int num, int channels, int spatial, float eps,
+    float momentum, int group_size) {
+  int tid = threadIdx.x;
+  int c = blockIdx.x;
+  float mean_value = mean[c];
+  float std_value = sqrt(var[c] + eps);
+
+  if (weight != nullptr) {
+    float weight_value = weight[c];
+    float bias_value = bias[c];
+    if (norm != nullptr) {
+      for (int i = tid; i < num * spatial; i += blockDim.x) {
+        int index =
+            (i / spatial) * channels * spatial + c * spatial + i % spatial;
+        norm[index] = (input[index] - mean_value) / std_value;
+        output[index] = norm[index] * weight_value + bias_value;
+      }
+    } else {
+      for (int i = tid; i < num * spatial; i += blockDim.x) {
+        int index =
+            (i / spatial) * channels * spatial + c * spatial + i % spatial;
+        output[index] =
+            (input[index] - mean_value) / std_value * weight_value + bias_value;
+      }
+    }
+  } else {
+    if (norm != nullptr) {
+      for (int i = tid; i < num * spatial; i += blockDim.x) {
+        int index =
+            (i / spatial) * channels * spatial + c * spatial + i % spatial;
+        output[index] = norm[index] = (input[index] - mean_value) / std_value;
+      }
+    } else {
+      for (int i = tid; i < num * spatial; i += blockDim.x) {
+        int index =
+            (i / spatial) * channels * spatial + c * spatial + i % spatial;
+        output[index] = (input[index] - mean_value) / std_value;
+      }
+    }
+  }
+  if (tid == 0) {
+    if (std != nullptr) std[c] = std_value;
+    if (running_mean != nullptr) {
+      running_mean[c] =
+          momentum * mean_value + (1 - momentum) * running_mean[c];
+      int count = num * spatial * group_size;
+      float var_unbias = count > 1 ? var[c] * count / (count - 1) : var[c];
+      running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c];
+    }
+  }
+}
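+
+// Running statistics follow the convention
+//   running = momentum * batch_stat + (1 - momentum) * running,
+// with the batch variance de-biased by count / (count - 1) over the
+// group-wide element count num * spatial * group_size.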
+
+template <>
+__global__ void sync_bn_forward_output_musa_kernel(
+    const phalf *input, const float *mean, const float *var,
+    float *running_mean, float *running_var, const float *weight,
+    const float *bias, float *norm, float *std, phalf *output, int num,
+    int channels, int spatial, float eps, float momentum, int group_size) {
+  int tid = threadIdx.x;
+  int c = blockIdx.x;
+  float mean_value = mean[c];
+  float std_value = sqrt(var[c] + eps);
+  if (weight != nullptr) {
+    float weight_value = weight[c];
+    float bias_value = bias[c];
+    if (norm != nullptr) {
+      for (int i = tid; i < num * spatial; i += blockDim.x) {
+        int index =
+            (i / spatial) * channels * spatial + c * spatial + i % spatial;
+        norm[index] =
+            (static_cast<float>(input[index]) - mean_value) / std_value;
+        output[index] =
+            static_cast<phalf>(norm[index] * weight_value + bias_value);
+      }
+    } else {
+      for (int i = tid; i < num * spatial; i += blockDim.x) {
+        int index =
+            (i / spatial) * channels * spatial + c * spatial + i % spatial;
+        output[index] =
+            static_cast<phalf>((static_cast<float>(input[index]) - mean_value) /
+                                   std_value * weight_value +
+                               bias_value);
+      }
+    }
+  } else {
+    if (norm != nullptr) {
+      for (int i = tid; i < num * spatial; i += blockDim.x) {
+        int index =
+            (i / spatial) * channels * spatial + c * spatial + i % spatial;
+        norm[index] =
+            (static_cast<float>(input[index]) - mean_value) / std_value;
+        output[index] = static_cast<phalf>(norm[index]);
+      }
+    } else {
+      for (int i = tid; i < num * spatial; i += blockDim.x) {
+        int index =
+            (i / spatial) * channels * spatial + c * spatial + i % spatial;
+        output[index] = static_cast<phalf>(
+            (static_cast<float>(input[index]) - mean_value) / std_value);
+      }
+    }
+  }
+  if (tid == 0) {
+    if (std != nullptr) std[c] = std_value;
+    if (running_mean != nullptr) {
+      running_mean[c] =
+          momentum * mean_value + (1 - momentum) * running_mean[c];
+      int count = num * spatial * group_size;
+      float var_unbias = count > 1 ? var[c] * count / (count - 1) : var[c];
+      running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c];
+    }
+  }
+}
+
+template <typename T>
+__global__ void sync_bn_backward_param_musa_kernel(const T *grad_output,
+                                                   const float *norm,
+                                                   float *grad_weight,
+                                                   float *grad_bias, int num,
+                                                   int channels, int spatial) {
+  __shared__ float buffer1[THREADS_PER_BLOCK];
+  __shared__ float buffer2[THREADS_PER_BLOCK];
+
+  int tid = threadIdx.x;
+  int c = blockIdx.x;
+  buffer1[tid] = buffer2[tid] = 0;
+  for (int i = tid; i < num * spatial; i += blockDim.x) {
+    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+    buffer1[tid] += grad_output[index] * norm[index];
+    buffer2[tid] += grad_output[index];
+  }
+  __syncthreads();
+
+  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      buffer1[tid] += buffer1[tid + s];
+      buffer2[tid] += buffer2[tid + s];
+    }
+    __syncthreads();
+  }
+  if (tid == 0) {
+    grad_weight[c] = buffer1[0];
+    grad_bias[c] = buffer2[0];
+  }
+}
+
+template <>
+__global__ void sync_bn_backward_param_musa_kernel(const phalf *grad_output,
+                                                   const float *norm,
+                                                   float *grad_weight,
+                                                   float *grad_bias, int num,
+                                                   int channels, int spatial) {
+  __shared__ float buffer1[THREADS_PER_BLOCK];
+  __shared__ float buffer2[THREADS_PER_BLOCK];
+
+  int tid = threadIdx.x;
+  int c = blockIdx.x;
+  buffer1[tid] = buffer2[tid] = 0;
+  for (int i = tid; i < num * spatial; i += blockDim.x) {
+    int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+    buffer1[tid] += static_cast<float>(grad_output[index]) * norm[index];
+    buffer2[tid] += static_cast<float>(grad_output[index]);
+  }
+  __syncthreads();
+
+  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      buffer1[tid] += buffer1[tid + s];
+      buffer2[tid] += buffer2[tid + s];
+    }
+    __syncthreads();
+  }
+  if (tid == 0) {
+    grad_weight[c] = buffer1[0];
+    grad_bias[c] = buffer2[0];
+  }
+}
+
+template <typename T>
+__global__ void sync_bn_backward_data_musa_kernel(
+    int output_size, const T *grad_output, const float *weight,
+    const float *grad_weight, const float *grad_bias, const float *norm,
+    const float *std, T *grad_input, int num, int channels, int spatial) {
+  int factor = num * spatial;
+  MUSA_1D_KERNEL_LOOP(index, output_size) {
+    int c = (index / spatial) % channels;
+    grad_input[index] =
+        weight[c] *
+        (grad_output[index] -
+         (grad_weight[c] * norm[index] + grad_bias[c]) / factor) /
+        std[c];
+  }
+}
+
+template <>
+__global__ void sync_bn_backward_data_musa_kernel(
+    int output_size, const phalf *grad_output, const float *weight,
+    const float *grad_weight, const float *grad_bias, const float *norm,
+    const float *std, phalf *grad_input, int num, int channels, int spatial) {
+  int factor = num * spatial;
+  MUSA_1D_KERNEL_LOOP(index, output_size) {
+    int c = (index / spatial) % channels;
+    grad_input[index] = static_cast<phalf>(
+        weight[c] *
+        (static_cast<float>(grad_output[index]) -
+         (grad_weight[c] * norm[index] + grad_bias[c]) / factor) /
+        std[c]);
+  }
+}
+
+#endif  // SYNCBN_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/three_interpolate_musa_kernel.muh b/mmcv/ops/csrc/common/musa/three_interpolate_musa_kernel.muh
new file mode 100644
index 000000000..4d5086ffd
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/three_interpolate_musa_kernel.muh
@@ -0,0 +1,57 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef THREE_INTERPOLATE_MUSA_KERNEL_MUH
+#define THREE_INTERPOLATE_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename T>
+__global__ void three_interpolate_forward_musa_kernel(
+    int b, int c, int m, int n, const T *points, const int *__restrict__ idx,
+    const T *weight, T *out) {
+  // points: (B, C, M)
+  // idx: (B, N, 3)
+  // weight: (B, N, 3)
+  // output:
+  //      out: (B, C, N)
+
+  int bs_idx = blockIdx.z;
+  int c_idx = blockIdx.y;
+  MUSA_1D_KERNEL_LOOP(pt_idx, n) {
+    if (bs_idx >= b || c_idx >= c) return;
+
+    weight += bs_idx * n * 3 + pt_idx * 3;
+    points += bs_idx * c * m + c_idx * m;
+    idx += bs_idx * n * 3 + pt_idx * 3;
+    out += bs_idx * c * n + c_idx * n;
+
+    out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
+                  weight[2] * points[idx[2]];
+  }
+}
+
+template <typename T>
+__global__ void three_interpolate_backward_musa_kernel(
+    int b, int c, int n, int m, const T *grad_out, const int *__restrict__ idx,
+    const T *weight, T *grad_points) {
+  // grad_out: (B, C, N)
+  // weight: (B, N, 3)
+  // output:
+  //      grad_points: (B, C, M)
+
+  int bs_idx = blockIdx.z;
+  int c_idx = blockIdx.y;
+  MUSA_1D_KERNEL_LOOP(pt_idx, n) {
+    if (bs_idx >= b || c_idx >= c) return;
+
+    grad_out += bs_idx * c * n + c_idx * n + pt_idx;
+    weight += bs_idx * n * 3 + pt_idx * 3;
+    grad_points += bs_idx * c * m + c_idx * m;
+    idx += bs_idx * n * 3 + pt_idx * 3;
+
+    atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
+    atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
+    atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
+  }
+}
+
+#endif  // THREE_INTERPOLATE_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/three_nn_musa_kernel.muh b/mmcv/ops/csrc/common/musa/three_nn_musa_kernel.muh
new file mode 100644
index 000000000..c25af0623
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/three_nn_musa_kernel.muh
@@ -0,0 +1,63 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef THREE_NN_MUSA_KERNEL_MUH
+#define THREE_NN_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename T>
+__global__ void three_nn_forward_musa_kernel(int b, int n, int m,
+                                             const T *unknown, const T *known,
+                                             T *dist2, int *__restrict__ idx) {
+  // unknown: (B, N, 3)
+  // known: (B, M, 3)
+  // output:
+  //      dist2: (B, N, 3)
+  //      idx: (B, N, 3)
+
+  int bs_idx = blockIdx.y;
+  MUSA_1D_KERNEL_LOOP(pt_idx, n) {
+    if (bs_idx >= b) return;
+
+    unknown += bs_idx * n * 3 + pt_idx * 3;
+    known += bs_idx * m * 3;
+    dist2 += bs_idx * n * 3 + pt_idx * 3;
+    idx += bs_idx * n * 3 + pt_idx * 3;
+
+    T ux = unknown[0];
+    T uy = unknown[1];
+    T uz = unknown[2];
+
+    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
+    int besti1 = 0, besti2 = 0, besti3 = 0;
+    for (int k = 0; k < m; ++k) {
+      T x = known[k * 3 + 0];
+      T y = known[k * 3 + 1];
+      T z = known[k * 3 + 2];
+      T d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
+      if (d < best1) {
+        best3 = best2;
+        besti3 = besti2;
+        best2 = best1;
+        besti2 = besti1;
+        best1 = d;
+        besti1 = k;
+      } else if (d < best2) {
+        best3 = best2;
+        besti3 = besti2;
+        best2 = d;
+        besti2 = k;
+      } else if (d < best3) {
+        best3 = d;
+        besti3 = k;
+      }
+    }
+    dist2[0] = best1;
+    dist2[1] = best2;
+    dist2[2] = best3;
+    idx[0] = besti1;
+    idx[1] = besti2;
+    idx[2] = besti3;
+  }
+}
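+
+// best1 <= best2 <= best3 hold the three smallest squared distances seen so
+// far; each candidate is slotted in by the three-way comparison above, giving
+// an O(m) scan per query point with no extra memory.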
+
+#endif  // THREE_NN_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/tin_shift_musa_kernel.muh b/mmcv/ops/csrc/common/musa/tin_shift_musa_kernel.muh
new file mode 100644
index 000000000..ba460883c
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/tin_shift_musa_kernel.muh
@@ -0,0 +1,57 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef TIN_SHIFT_MUSA_KERNEL_MUH
+#define TIN_SHIFT_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename T>
+__global__ void tin_shift_forward_musa_kernel(
+    const int nthreads, const T* input, const int* shift, T* output,
+    const int batch_size, const int channels, const int t_size,
+    const int hw_size, const int group_size, const int group_channel) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    const int hw_index = index % hw_size;
+    const int j = (index / hw_size) % channels;
+
+    const int n_index = (index / hw_size / channels) % batch_size;
+    int group_id = j / group_channel;
+    int t_shift = shift[n_index * group_size + group_id];
+    int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index;
+    for (int i = 0; i < t_size; i++) {
+      int now_t = i + t_shift;
+      int data_id = i * hw_size * channels + offset;
+      if (now_t < 0 || now_t >= t_size) {
+        continue;
+      }
+      int out_id = now_t * hw_size * channels + offset;
+      output[out_id] = input[data_id];
+    }
+  }
+}
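+
+// Shifted temporal positions falling outside [0, t_size) are skipped, so the
+// corresponding output slots keep their initial values; callers are expected
+// to zero-initialize the output buffer.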
+
+template <typename T>
+__global__ void tin_shift_backward_musa_kernel(
+    const int nthreads, const T* input, const int* shift, T* output,
+    const int batch_size, const int channels, const int t_size,
+    const int hw_size, const int group_size, const int group_channel) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    const int hw_index = index % hw_size;
+    const int j = (index / hw_size) % channels;
+
+    const int n_index = (index / hw_size / channels) % batch_size;
+    int group_id = j / group_channel;
+    int t_shift = shift[n_index * group_size + group_id];
+    int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index;
+    for (int i = 0; i < t_size; i++) {
+      int now_t = i + t_shift;
+      int data_id = i * hw_size * channels + offset;
+      if (now_t < 0 || now_t >= t_size) {
+        continue;
+      }
+      int out_id = now_t * hw_size * channels + offset;
+      output[out_id] = input[data_id];
+    }
+  }
+}
+
+#endif  // TIN_SHIFT_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/voxelization_musa_kernel.muh b/mmcv/ops/csrc/common/musa/voxelization_musa_kernel.muh
new file mode 100644
index 000000000..24bc770f5
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/voxelization_musa_kernel.muh
@@ -0,0 +1,212 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+#ifndef VOXELIZATION_MUSA_KERNEL_MUH
+#define VOXELIZATION_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
+
+template <typename T, typename T_int>
+__global__ void dynamic_voxelize_kernel(
+    const T* points, T_int* coors, const float voxel_x, const float voxel_y,
+    const float voxel_z, const float coors_x_min, const float coors_y_min,
+    const float coors_z_min, const float coors_x_max, const float coors_y_max,
+    const float coors_z_max, const int grid_x, const int grid_y,
+    const int grid_z, const int num_points, const int num_features,
+    const int NDim) {
+  MUSA_1D_KERNEL_LOOP(index, num_points) {
+    // To save some computation
+    auto points_offset = points + index * num_features;
+    auto coors_offset = coors + index * NDim;
+    int c_x = floorf((points_offset[0] - coors_x_min) / voxel_x);
+    if (c_x < 0 || c_x >= grid_x) {
+      coors_offset[0] = -1;
+      continue;
+    }
+
+    int c_y = floorf((points_offset[1] - coors_y_min) / voxel_y);
+    if (c_y < 0 || c_y >= grid_y) {
+      coors_offset[0] = -1;
+      coors_offset[1] = -1;
+      continue;
+    }
+
+    int c_z = floorf((points_offset[2] - coors_z_min) / voxel_z);
+    if (c_z < 0 || c_z >= grid_z) {
+      coors_offset[0] = -1;
+      coors_offset[1] = -1;
+      coors_offset[2] = -1;
+    } else {
+      coors_offset[0] = c_z;
+      coors_offset[1] = c_y;
+      coors_offset[2] = c_x;
+    }
+  }
+}
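+
+// Illustration (not part of the kernel): with voxel_x = 0.2, coors_x_min =
+// 0.0 and grid_x = 10, a point with x = 0.45 maps to c_x = floorf(0.45 /
+// 0.2) = 2. Note the output order is (c_z, c_y, c_x), and a point rejected
+// on some axis has the already-written prefix of coors_offset set to -1.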
+
+template <typename T, typename T_int>
+__global__ void assign_point_to_voxel(const int nthreads, const T* points,
+                                      T_int* point_to_voxelidx,
+                                      T_int* coor_to_voxelidx, T* voxels,
+                                      const int max_points,
+                                      const int num_features,
+                                      const int num_points, const int NDim) {
+  MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
+    int index = thread_idx / num_features;
+
+    int num = point_to_voxelidx[index];
+    int voxelidx = coor_to_voxelidx[index];
+    if (num > -1 && voxelidx > -1) {
+      auto voxels_offset =
+          voxels + voxelidx * max_points * num_features + num * num_features;
+
+      int k = thread_idx % num_features;
+      voxels_offset[k] = points[thread_idx];
+    }
+  }
+}
+
+template <typename T, typename T_int>
+__global__ void assign_voxel_coors(const int nthreads, T_int* coor,
+                                   T_int* point_to_voxelidx,
+                                   T_int* coor_to_voxelidx, T_int* voxel_coors,
+                                   const int num_points, const int NDim) {
+  MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
+    int index = thread_idx / NDim;
+    int num = point_to_voxelidx[index];
+    int voxelidx = coor_to_voxelidx[index];
+    if (num == 0 && voxelidx > -1) {
+      auto coors_offset = voxel_coors + voxelidx * NDim;
+      int k = thread_idx % NDim;
+      coors_offset[k] = coor[thread_idx];
+    }
+  }
+}
+
+template <typename T_int>
+__global__ void point_to_voxelidx_kernel(const T_int* coor,
+                                         T_int* point_to_voxelidx,
+                                         T_int* point_to_pointidx,
+                                         const int max_points,
+                                         const int max_voxels,
+                                         const int num_points, const int NDim) {
+  MUSA_1D_KERNEL_LOOP(index, num_points) {
+    auto coor_offset = coor + index * NDim;
+    // skip invalid points
+    if (coor_offset[0] == -1) continue;
+
+    int num = 0;
+    int coor_x = coor_offset[0];
+    int coor_y = coor_offset[1];
+    int coor_z = coor_offset[2];
+    // only examine the coordinates that appear before coor[index]
+    for (int i = 0; i < index; ++i) {
+      auto prev_coor = coor + i * NDim;
+      if (prev_coor[0] == -1) continue;
+
+      // Find all previous points that have the same coordinates;
+      // when a match is found, record it
+      if ((prev_coor[0] == coor_x) && (prev_coor[1] == coor_y) &&
+          (prev_coor[2] == coor_z)) {
+        num++;
+        if (num == 1) {
+          // point to the first occurrence of this coor
+          point_to_pointidx[index] = i;
+        } else if (num >= max_points) {
+          // reached max_points for this voxel
+          break;
+        }
+      }
+    }
+    if (num == 0) {
+      point_to_pointidx[index] = index;
+    }
+    if (num < max_points) {
+      point_to_voxelidx[index] = num;
+    }
+  }
+}
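+
+// Each thread scans all earlier points, so this kernel is O(N^2) in the
+// number of points. Illustration: for coordinates [A, B, A, A] (with
+// max_points >= 3) it yields point_to_voxelidx = [0, 0, 1, 2] and
+// point_to_pointidx = [0, 1, 0, 0].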
+
+template <typename T_int>
+__global__ void determin_voxel_num(
+    T_int* num_points_per_voxel, T_int* point_to_voxelidx,
+    T_int* point_to_pointidx, T_int* coor_to_voxelidx, T_int* voxel_num,
+    const int max_points, const int max_voxels, const int num_points) {
+  // Serially scan all points in order, so voxel indices are assigned
+  // deterministically.
+  for (int i = 0; i < num_points; ++i) {
+    int point_pos_in_voxel = point_to_voxelidx[i];
+    // record voxel
+    if (point_pos_in_voxel == -1) {
+      // out of max_points or invalid point
+      continue;
+    } else if (point_pos_in_voxel == 0) {
+      // record new voxel
+      int voxelidx = voxel_num[0];
+      if (voxel_num[0] >= max_voxels) continue;
+      voxel_num[0] += 1;
+      coor_to_voxelidx[i] = voxelidx;
+      num_points_per_voxel[voxelidx] = 1;
+    } else {
+      int point_idx = point_to_pointidx[i];
+      int voxelidx = coor_to_voxelidx[point_idx];
+      if (voxelidx != -1) {
+        coor_to_voxelidx[i] = voxelidx;
+        num_points_per_voxel[voxelidx] += 1;
+      }
+    }
+  }
+}
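+
+// Continuing the example above: point_to_voxelidx = [0, 0, 1, 2] and
+// point_to_pointidx = [0, 1, 0, 0] produce voxel_num[0] = 2,
+// coor_to_voxelidx = [0, 1, 0, 0] and num_points_per_voxel = [3, 1]
+// (assuming max_voxels >= 2).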
+
+__global__ void nondeterministic_get_assign_pos(
+    const int nthreads, const int32_t* coors_map, int32_t* pts_id,
+    int32_t* coors_count, int32_t* reduce_count, int32_t* coors_order) {
+  MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
+    int coors_idx = coors_map[thread_idx];
+    if (coors_idx > -1) {
+      int32_t coors_pts_pos = atomicAdd(&reduce_count[coors_idx], 1);
+      pts_id[thread_idx] = coors_pts_pos;
+      if (coors_pts_pos == 0) {
+        coors_order[coors_idx] = atomicAdd(coors_count, 1);
+      }
+    }
+  }
+}
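+
+// Bookkeeping summary (illustration): reduce_count[v] counts the points
+// mapped to voxel v, pts_id[p] is point p's slot within its voxel, and the
+// first point to touch a voxel claims the voxel's output position by
+// atomically incrementing coors_count. The resulting voxel order depends
+// on thread scheduling, hence "nondeterministic".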
+
+template <typename T>
+__global__ void nondeterministic_assign_point_voxel(
+    const int nthreads, const T* points, const int32_t* coors_map,
+    const int32_t* pts_id, const int32_t* coors_in, const int32_t* reduce_count,
+    const int32_t* coors_order, T* voxels, int32_t* coors, int32_t* pts_count,
+    const int max_voxels, const int max_points, const int num_features,
+    const int NDim) {
+  MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) {
+    int coors_idx = coors_map[thread_idx];
+    int coors_pts_pos = pts_id[thread_idx];
+    if (coors_idx > -1 && coors_pts_pos < max_points) {
+      int coors_pos = coors_order[coors_idx];
+      if (coors_pos < max_voxels) {
+        auto voxels_offset =
+            voxels + (coors_pos * max_points + coors_pts_pos) * num_features;
+        auto points_offset = points + thread_idx * num_features;
+        for (int k = 0; k < num_features; k++) {
+          voxels_offset[k] = points_offset[k];
+        }
+        if (coors_pts_pos == 0) {
+          pts_count[coors_pos] = min(reduce_count[coors_idx], max_points);
+          auto coors_offset = coors + coors_pos * NDim;
+          auto coors_in_offset = coors_in + coors_idx * NDim;
+          for (int k = 0; k < NDim; k++) {
+            coors_offset[k] = coors_in_offset[k];
+          }
+        }
+      }
+    }
+  }
+}
+
+#endif  // VOXELIZATION_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/pytorch/musa/filtered_lrelu.mu b/mmcv/ops/csrc/pytorch/musa/filtered_lrelu.mu
new file mode 100644
index 000000000..b8994a768
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/filtered_lrelu.mu
@@ -0,0 +1,2052 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto.  Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+#include <c10/util/Half.h>
+#include <musa_runtime.h>
+#include <torch/types.h>
+
+#include <cstdint>
+
+#include "pytorch_musa_helper.hpp"
+#include "pytorch_device_registry.hpp"
+
+//------------------------------------------------------------------------
+// MUSA kernel parameters.
+
+struct filtered_lrelu_kernel_params {
+  // These parameters decide which kernel to use.
+  int up;        // upsampling ratio (1, 2, 4)
+  int down;      // downsampling ratio (1, 2, 4)
+  int2 fuShape;  // [size, 1] | [size, size]
+  int2 fdShape;  // [size, 1] | [size, size]
+
+  int _dummy;  // Alignment.
+
+  // Rest of the parameters.
+  const void *x;     // Input tensor.
+  void *y;           // Output tensor.
+  const void *b;     // Bias tensor.
+  unsigned char *s;  // Sign tensor in/out. NULL if unused.
+  const float *fu;   // Upsampling filter.
+  const float *fd;   // Downsampling filter.
+
+  int2 pad0;    // Left/top padding.
+  float gain;   // Additional gain factor.
+  float slope;  // Leaky ReLU slope on negative side.
+  float clamp;  // Clamp after nonlinearity.
+  int flip;     // Filter kernel flip for gradient computation.
+
+  int tilesXdim;  // Original number of horizontal output tiles.
+  int tilesXrep;  // Number of horizontal tiles per CTA.
+  int blockZofs;  // Block z offset to support large minibatch, channel
+                  // dimensions.
+
+  int4 xShape;  // [width, height, channel, batch]
+  int4 yShape;  // [width, height, channel, batch]
+  int2 sShape;  // [width, height] - width is in bytes. Contiguous. Zeros if
+                // unused.
+  int2 sOfs;  // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+  int swLimit;  // Active width of sign tensor in bytes.
+
+  longlong4 xStride;   // Strides of all tensors except signs, same component
+                       // order as shapes.
+  longlong4 yStride;   //
+  int64_t bStride;     //
+  longlong3 fuStride;  //
+  longlong3 fdStride;  //
+};
+
+struct filtered_lrelu_act_kernel_params {
+  void *x;           // Input/output, modified in-place.
+  unsigned char *s;  // Sign tensor in/out. NULL if unused.
+
+  float gain;   // Additional gain factor.
+  float slope;  // Leaky ReLU slope on negative side.
+  float clamp;  // Clamp after nonlinearity.
+
+  int4 xShape;        // [width, height, channel, batch]
+  longlong4 xStride;  // Input/output tensor strides, same order as in shape.
+  int2 sShape;  // [width, height] - width is in elements. Contiguous. Zeros if
+                // unused.
+  int2 sOfs;  // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+};
+
+//------------------------------------------------------------------------
+// MUSA kernel specialization.
+
+struct filtered_lrelu_kernel_spec {
+  void *setup;   // Function for filter kernel setup.
+  void *exec;    // Function for main operation.
+  int2 tileOut;  // Width/height of launch tile.
+  int numWarps;  // Number of warps per thread block, determines launch block
+                 // size.
+  int xrep;      // For processing multiple horizontal tiles per thread block.
+  int dynamicSharedKB;  // How much dynamic shared memory the exec kernel wants.
+};
+
+//------------------------------------------------------------------------
+// MUSA kernel selection.
+
+template <class T, class index_t, bool signWrite, bool signRead>
+filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
+    const filtered_lrelu_kernel_params &p, int sharedKB);
+template <class T, bool signWrite, bool signRead>
+void *choose_filtered_lrelu_act_kernel(void);
+
+//------------------------------------------------------------------------
+// Helpers.
+
+enum              // Filter modes.
+{ MODE_SUSD = 0,  // Separable upsampling, separable downsampling.
+  MODE_FUSD = 1,  // Full upsampling, separable downsampling.
+  MODE_SUFD = 2,  // Separable upsampling, full downsampling.
+  MODE_FUFD = 3,  // Full upsampling, full downsampling.
+};
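+
+// "Separable" means the up/down filter factors into 1-D horizontal and
+// vertical passes applied one after the other; "full" applies the 2-D
+// kernel in a single pass.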
+
+template <class T>
+struct InternalType;
+template <>
+struct InternalType<double> {
+  typedef double scalar_t;
+  typedef double2 vec2_t;
+  typedef double4 vec4_t;
+  __device__ __forceinline__ static vec2_t zero_vec2(void) {
+    return make_double2(0, 0);
+  }
+  __device__ __forceinline__ static vec4_t zero_vec4(void) {
+    return make_double4(0, 0, 0, 0);
+  }
+  __device__ __forceinline__ static double clamp(double x, double c) {
+    return fmin(fmax(x, -c), c);
+  }
+};
+template <>
+struct InternalType<float> {
+  typedef float scalar_t;
+  typedef float2 vec2_t;
+  typedef float4 vec4_t;
+  __device__ __forceinline__ static vec2_t zero_vec2(void) {
+    return make_float2(0, 0);
+  }
+  __device__ __forceinline__ static vec4_t zero_vec4(void) {
+    return make_float4(0, 0, 0, 0);
+  }
+  __device__ __forceinline__ static float clamp(float x, float c) {
+    return fminf(fmaxf(x, -c), c);
+  }
+};
+template <>
+struct InternalType<c10::Half> {
+  typedef float scalar_t;
+  typedef float2 vec2_t;
+  typedef float4 vec4_t;
+  __device__ __forceinline__ static vec2_t zero_vec2(void) {
+    return make_float2(0, 0);
+  }
+  __device__ __forceinline__ static vec4_t zero_vec4(void) {
+    return make_float4(0, 0, 0, 0);
+  }
+  __device__ __forceinline__ static float clamp(float x, float c) {
+    return fminf(fmaxf(x, -c), c);
+  }
+};
+
+#define MIN(A, B) ((A) < (B) ? (A) : (B))
+#define MAX(A, B) ((A) > (B) ? (A) : (B))
+#define CEIL_DIV(A, B)                                   \
+  (((B) == 1)                                            \
+       ? (A)                                             \
+       : ((B) == 2) ? ((int)((A) + 1) >> 1)              \
+                    : ((B) == 4) ? ((int)((A) + 3) >> 2) \
+                                 : (((A) + ((A) > 0 ? (B)-1 : 0)) / (B)))
+
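+// CEIL_DIV is a true ceiling for B in {1, 2, 4} via shifts, and for other
+// positive B via the biased division in the last branch. Illustration:
+// CEIL_DIV(5, 4) = (5 + 3) >> 2 = 2; CEIL_DIV(-5, 3) = -5 / 3 = -1 =
+// ceil(-5/3), because C division truncates toward zero.
+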
+// Fast division and modulus by a compile-time constant N. Non-power-of-two
+// N <= 256 uses an approximate multiplicative trick (valid for i < N*256);
+// power-of-two N falls back to exact division, which compiles to a shift.
+template <int N>
+__device__ __forceinline__ void fast_div_mod(int &x, int &y, unsigned int i) {
+  if ((N & (N - 1)) && N <= 256)
+    y = (i * ((1 << 24) / N + 1)) >> 24;  // Assumes N <= 256, i < N*256.
+  else
+    y = i / N;
+
+  x = i - y * N;
+}
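+
+// Worked example of the multiplicative trick (illustration): for N = 3 and
+// i = 7, the constant is (1 << 24) / 3 + 1 = 5592406, so
+// y = (7 * 5592406) >> 24 = 39146842 >> 24 = 2 and x = 7 - 2*3 = 1.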
+
+// Type cast stride before reading it.
+template <class T>
+__device__ __forceinline__ T get_stride(const int64_t &x) {
+  return *reinterpret_cast<const T *>(&x);
+}
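+
+// With index_t = int this reads only the low 32 bits of the stored 64-bit
+// stride (assuming a little-endian target and strides that fit in 32 bits,
+// which the host-side kernel selection is expected to guarantee).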
+
+//------------------------------------------------------------------------
+// Filters, setup kernel, copying function.
+
+#define MAX_FILTER_SIZE 32
+
+// Combined up/down filter buffers so that transfer can be done with one copy.
+__device__ float
+    g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE];  // Filters in global memory,
+                                                    // written by setup kernel.
+__device__ __constant__ float
+    c_fbuf[2 * MAX_FILTER_SIZE *
+           MAX_FILTER_SIZE];  // Filters in constant memory, read by main
+                              // kernel.
+
+// Accessors to combined buffers to index up/down filters individually.
+#define c_fu (c_fbuf)
+#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+#define g_fu (g_fbuf)
+#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+
+// Set up filters into global memory buffer.
+static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p) {
+  for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE;
+       idx += blockDim.x) {
+    int x, y;
+    fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
+
+    int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
+    int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
+    if (p.fuShape.y > 0)
+      g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y)
+                      ? 0.0f
+                      : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
+    else
+      g_fu[idx] =
+          (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
+
+    int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
+    int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
+    if (p.fdShape.y > 0)
+      g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y)
+                      ? 0.0f
+                      : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
+    else
+      g_fd[idx] =
+          (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
+  }
+}
+
+// Host function to copy filters written by setup kernel into constant buffer
+// for main kernel.
+static musaError_t copy_filters(musaStream_t stream) {
+  void *src = 0;
+  musaError_t err = musaGetSymbolAddress(&src, g_fbuf);
+  if (err) return err;
+  return musaMemcpyToSymbolAsync(
+      c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0,
+      musaMemcpyDeviceToDevice, stream);
+}
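+
+// Host-side usage sketch (illustration only; the actual launch lives in the
+// kernel-selection code, and the launch configuration is assumed by analogy
+// with the CUDA port):
+//   setup_filters_kernel<<<1, 1024, 0, stream>>>(p);
+//   musaError_t err = copy_filters(stream);
+//   // check err before launching the main filtered_lrelu kernel.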
+
+//------------------------------------------------------------------------
+// Coordinate spaces:
+// - Relative to input tensor:      inX, inY, tileInX, tileInY
+// - Relative to input tile:        relInX, relInY, tileInW, tileInH
+// - Relative to upsampled tile:    relUpX, relUpY, tileUpW, tileUpH
+// - Relative to output tile:       relOutX, relOutY, tileOutW, tileOutH
+// - Relative to output tensor:     outX, outY, tileOutX, tileOutY
+//
+// Relationships between coordinate spaces:
+// - inX = tileInX + relInX
+// - inY = tileInY + relInY
+// - relUpX = relInX * up + phaseInX
+// - relUpY = relInY * up + phaseInY
+// - relUpX = relOutX * down
+// - relUpY = relOutY * down
+// - outX = tileOutX + relOutX
+// - outY = tileOutY + relOutY
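+//
+// Illustration (values chosen for exposition): with up = 2, down = 2 and
+// phaseInX = 1, input sample relInX = 2 lands at upsampled position
+// relUpX = 2*2 + 1 = 5, while output sample relOutX = 3 reads upsampled
+// position relUpX = 3*2 = 6.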
+
+extern __shared__ char
+    s_buf_raw[];  // When sharedKB <= 48, allocate shared memory statically
+                  // inside the kernel, otherwise use the externally allocated
+                  // shared memory buffer.
+
+template <class T, class index_t, int sharedKB, bool signWrite, bool signRead,
+          int filterMode, int up, int fuSize, int down, int fdSize,
+          int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep,
+          bool enableWriteSkip>
+static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
+  // Check that we don't try to support non-existing filter modes.
+  static_assert(up == 1 || up == 2 || up == 4,
+                "only up=1, up=2, up=4 scales supported");
+  static_assert(down == 1 || down == 2 || down == 4,
+                "only down=1, down=2, down=4 scales supported");
+  static_assert(fuSize >= up,
+                "upsampling filter size must be at least upsampling factor");
+  static_assert(
+      fdSize >= down,
+      "downsampling filter size must be at least downsampling factor");
+  static_assert(
+      fuSize % up == 0,
+      "upsampling filter size must be divisible with upsampling factor");
+  static_assert(
+      fdSize % down == 0,
+      "downsampling filter size must be divisible with downsampling factor");
+  static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE,
+                "filter size greater than MAX_FILTER_SIZE");
+  static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD ||
+                                            filterMode == MODE_FUSD)),
+                "up=1 supported only for 1x1 full filters");
+  static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD ||
+                                              filterMode == MODE_SUFD)),
+                "down=1 supported only for 1x1 full filters");
+  static_assert(
+      !(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)),
+      "full filters not supported for up=4");
+  static_assert(
+      !(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)),
+      "full filters not supported for down=4");
+
+  // Static definitions.
+  typedef typename InternalType<T>::scalar_t scalar_t;
+  typedef typename InternalType<T>::vec2_t vec2_t;
+  typedef typename InternalType<T>::vec4_t vec4_t;
+  const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) &
+                      ~3;  // Upsampled tile width, rounded up to multiple of 4.
+  const int tileUpH =
+      tileOutH * down + (fdSize - 1) - (down - 1);  // Upsampled tile height.
+  const int tileInW =
+      CEIL_DIV(tileUpW + (fuSize - 1), up);  // Input tile width.
+  const int tileInH =
+      CEIL_DIV(tileUpH + (fuSize - 1), up);  // Input tile height.
+  const int tileUpH_up =
+      CEIL_DIV(tileUpH, up) *
+      up;  // Upsampled tile height rounded up to a multiple of up.
+  const int tileInH_up =
+      CEIL_DIV(tileUpH_up + (fuSize - 1),
+               up);  // For allocations only, to avoid shared memory read
+                     // overruns with up=2 and up=4.
+
+  // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
+  const bool downInline =
+      (down == 1) && ((up == 1 && filterMode == MODE_FUFD) ||
+                      (up == 2 && filterMode == MODE_SUFD));
+
+  // Sizes of logical buffers.
+  const int szIn = tileInH_up * tileInW;
+  const int szUpX = tileInH_up * tileUpW;
+  const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
+  const int szDownX = tileUpH * tileOutW;
+
+  // Sizes for shared memory arrays.
+  const int s_buf0_size_base =
+      (filterMode == MODE_SUSD)
+          ? MAX(szIn, szUpXY)
+          : (filterMode == MODE_FUSD)
+                ? MAX(szIn, szDownX)
+                : (filterMode == MODE_SUFD)
+                      ? MAX(szIn, szUpXY)
+                      : (filterMode == MODE_FUFD) ? szIn : -1;
+  const int s_buf1_size_base =
+      (filterMode == MODE_SUSD)
+          ? MAX(szUpX, szDownX)
+          : (filterMode == MODE_FUSD)
+                ? szUpXY
+                : (filterMode == MODE_SUFD)
+                      ? szUpX
+                      : (filterMode == MODE_FUFD) ? szUpXY : -1;
+
+  // Ensure U128 alignment.
+  const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
+  const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
+
+  // Check at compile time that we don't use too much shared memory.
+  static_assert(
+      (s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10),
+      "shared memory overflow");
+
+  // Declare shared memory arrays.
+  scalar_t *s_buf0;
+  scalar_t *s_buf1;
+  if (sharedKB <= 48) {
+    // Allocate shared memory arrays here.
+    __shared__ scalar_t
+        s_buf0_st[(sharedKB > 48)
+                      ? (1 << 24)
+                      : (s_buf0_size +
+                         s_buf1_size)];  // Prevent launching if this isn't
+                                         // optimized away when unused.
+    s_buf0 = s_buf0_st;
+    s_buf1 = s_buf0 + s_buf0_size;
+  } else {
+    // Use the dynamically allocated shared memory array.
+    s_buf0 = (scalar_t *)s_buf_raw;
+    s_buf1 = s_buf0 + s_buf0_size;
+  }
+
+  // Pointers to the buffers.
+  scalar_t *
+      s_tileIn;  // Input tile:                      [relInX * tileInH + relInY]
+  scalar_t *s_tileUpX;   // After horizontal upsampling:     [relInY * tileUpW +
+                         // relUpX]
+  scalar_t *s_tileUpXY;  // After upsampling:                [relUpY * tileUpW +
+                         // relUpX]
+  scalar_t *s_tileDownX;  // After horizontal downsampling:   [relUpY * tileOutW
+                          // + relOutX]
+  if (filterMode == MODE_SUSD) {
+    s_tileIn = s_buf0;
+    s_tileUpX = s_buf1;
+    s_tileUpXY = s_buf0;
+    s_tileDownX = s_buf1;
+  } else if (filterMode == MODE_FUSD) {
+    s_tileIn = s_buf0;
+    s_tileUpXY = s_buf1;
+    s_tileDownX = s_buf0;
+  } else if (filterMode == MODE_SUFD) {
+    s_tileIn = s_buf0;
+    s_tileUpX = s_buf1;
+    s_tileUpXY = s_buf0;
+  } else if (filterMode == MODE_FUFD) {
+    s_tileIn = s_buf0;
+    s_tileUpXY = s_buf1;
+  }
+
+  // Allow large grids in z direction via per-launch offset.
+  int channelIdx = blockIdx.z + p.blockZofs;
+  int batchIdx = channelIdx / p.yShape.z;
+  channelIdx -= batchIdx * p.yShape.z;
+
+  // Offset to output feature map. In bytes.
+  index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) +
+                      batchIdx * get_stride<index_t>(p.yStride.w);
+
+  // Sign shift amount.
+  uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
+
+// Inner tile loop.
+#pragma unroll 1
+  for (int tileIdx = 0;
+       !enableXrep ||
+       (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y));
+       tileIdx++) {
+    // Locate output tile.
+    int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
+    int tileOutX = tileX * tileOutW;
+    int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
+
+    // Locate input tile.
+    int tmpX = tileOutX * down - p.pad0.x;
+    int tmpY = tileOutY * down - p.pad0.y;
+    int tileInX = CEIL_DIV(tmpX, up);
+    int tileInY = CEIL_DIV(tmpY, up);
+    const int phaseInX = tileInX * up - tmpX;
+    const int phaseInY = tileInY * up - tmpY;
+
+    // Extra sync if input and output buffers are the same and we are not on
+    // first tile.
+    if (enableXrep && tileIdx > 0 &&
+        (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) ||
+         (filterMode == MODE_FUFD && downInline)))
+      __syncthreads();
+
+    // Load input tile & apply bias. Unrolled.
+    scalar_t b =
+        (scalar_t) * (const T *)((const char *)p.b +
+                                 (channelIdx * get_stride<index_t>(p.bStride)));
+    index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) +
+                       batchIdx * get_stride<index_t>(p.xStride.w);
+    int idx = threadIdx.x;
+    const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
+#pragma unroll
+    for (int loop = 0; loop < loopCountIN; loop++) {
+      int relInX, relInY;
+      fast_div_mod<tileInW>(relInX, relInY, idx);
+      int inX = tileInX + relInX;
+      int inY = tileInY + relInY;
+      scalar_t v = 0;
+
+      if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
+        v = (scalar_t) * ((const T *)((const char *)p.x +
+                                      (inX * get_stride<index_t>(p.xStride.x) +
+                                       inY * get_stride<index_t>(p.xStride.y) +
+                                       mapOfsIn))) +
+            b;
+
+      bool skip = (loop == loopCountIN - 1) && (idx >= tileInW * tileInH);
+      if (!skip) s_tileIn[idx] = v;
+
+      idx += threadsPerBlock;
+    }
+
+    if (filterMode == MODE_SUSD ||
+        filterMode == MODE_SUFD)  // Separable upsampling filter.
+    {
+      // Horizontal upsampling.
+      __syncthreads();
+      if (up == 4) {
+        for (int idx = threadIdx.x * up; idx < tileUpW * tileInH;
+             idx += blockDim.x * up) {
+          int relUpX0, relInY;
+          fast_div_mod<tileUpW>(relUpX0, relInY, idx);
+          int relInX0 = relUpX0 / up;
+          int src0 = relInX0 + tileInW * relInY;
+          int dst = relInY * tileUpW + relUpX0;
+          vec4_t v = InternalType<T>::zero_vec4();
+          scalar_t a = s_tileIn[src0];
+          if (phaseInX == 0) {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileIn[src0 + step + 1];
+              v.y += a * (scalar_t)c_fu[step * up + 3];
+              v.z += a * (scalar_t)c_fu[step * up + 2];
+              v.w += a * (scalar_t)c_fu[step * up + 1];
+            }
+          } else if (phaseInX == 1) {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 1];
+              v.y += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileIn[src0 + step + 1];
+              v.z += a * (scalar_t)c_fu[step * up + 3];
+              v.w += a * (scalar_t)c_fu[step * up + 2];
+            }
+          } else if (phaseInX == 2) {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 2];
+              v.y += a * (scalar_t)c_fu[step * up + 1];
+              v.z += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileIn[src0 + step + 1];
+              v.w += a * (scalar_t)c_fu[step * up + 3];
+            }
+          } else  // (phaseInX == 3)
+          {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 3];
+              v.y += a * (scalar_t)c_fu[step * up + 2];
+              v.z += a * (scalar_t)c_fu[step * up + 1];
+              v.w += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileIn[src0 + step + 1];
+            }
+          }
+          s_tileUpX[dst + 0] = v.x;
+          s_tileUpX[dst + 1] = v.y;
+          s_tileUpX[dst + 2] = v.z;
+          s_tileUpX[dst + 3] = v.w;
+        }
+      } else if (up == 2) {
+        bool p0 = (phaseInX == 0);
+        for (int idx = threadIdx.x * up; idx < tileUpW * tileInH;
+             idx += blockDim.x * up) {
+          int relUpX0, relInY;
+          fast_div_mod<tileUpW>(relUpX0, relInY, idx);
+          int relInX0 = relUpX0 / up;
+          int src0 = relInX0 + tileInW * relInY;
+          int dst = relInY * tileUpW + relUpX0;
+          vec2_t v = InternalType<T>::zero_vec2();
+          scalar_t a = s_tileIn[src0];
+          if (p0)  // (phaseInX == 0)
+          {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileIn[src0 + step + 1];
+              v.y += a * (scalar_t)c_fu[step * up + 1];
+            }
+          } else  // (phaseInX == 1)
+          {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 1];
+              v.y += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileIn[src0 + step + 1];
+            }
+          }
+          s_tileUpX[dst + 0] = v.x;
+          s_tileUpX[dst + 1] = v.y;
+        }
+      }
+
+      // Vertical upsampling & nonlinearity.
+
+      __syncthreads();
+      int groupMask = 15 << ((threadIdx.x & 31) & ~3);
+      int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH
+                          : 0;  // Skip already written signs.
+      int sShapeMaxY =
+          MIN(p.sShape.y,
+              tileOutY * down + tileUpH);  // Avoid out-of-tile sign writes.
+      if (up == 4) {
+        minY -= 3;  // Adjust according to block height.
+        for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up;
+             idx += blockDim.x) {
+          int relUpX, relInY0;
+          fast_div_mod<tileUpW>(relUpX, relInY0, idx);
+          int relUpY0 = relInY0 * up;
+          int src0 = relInY0 * tileUpW + relUpX;
+          int dst = relUpY0 * tileUpW + relUpX;
+          vec4_t v = InternalType<T>::zero_vec4();
+
+          scalar_t a = s_tileUpX[src0];
+          if (phaseInY == 0) {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileUpX[src0 + (step + 1) * tileUpW];
+              v.y += a * (scalar_t)c_fu[step * up + 3];
+              v.z += a * (scalar_t)c_fu[step * up + 2];
+              v.w += a * (scalar_t)c_fu[step * up + 1];
+            }
+          } else if (phaseInY == 1) {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 1];
+              v.y += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileUpX[src0 + (step + 1) * tileUpW];
+              v.z += a * (scalar_t)c_fu[step * up + 3];
+              v.w += a * (scalar_t)c_fu[step * up + 2];
+            }
+          } else if (phaseInY == 2) {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 2];
+              v.y += a * (scalar_t)c_fu[step * up + 1];
+              v.z += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileUpX[src0 + (step + 1) * tileUpW];
+              v.w += a * (scalar_t)c_fu[step * up + 3];
+            }
+          } else  // (phaseInY == 3)
+          {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 3];
+              v.y += a * (scalar_t)c_fu[step * up + 2];
+              v.z += a * (scalar_t)c_fu[step * up + 1];
+              v.w += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileUpX[src0 + (step + 1) * tileUpW];
+            }
+          }
+
+          int x = tileOutX * down + relUpX;
+          int y = tileOutY * down + relUpY0;
+          int signX = x + p.sOfs.x;
+          int signY = y + p.sOfs.y;
+          int signZ = blockIdx.z + p.blockZofs;
+          int signXb = signX >> 2;
+          index_t si0 =
+              signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+          index_t si1 = si0 + p.sShape.x;
+          index_t si2 = si0 + p.sShape.x * 2;
+          index_t si3 = si0 + p.sShape.x * 3;
+
+          v.x *= (scalar_t)((float)up * (float)up * p.gain);
+          v.y *= (scalar_t)((float)up * (float)up * p.gain);
+          v.z *= (scalar_t)((float)up * (float)up * p.gain);
+          v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+          if (signWrite) {
+            if (!enableWriteSkip) {
+              // Determine and write signs.
+              int sx = __float_as_uint(v.x) >> 31 << 0;
+              int sy = __float_as_uint(v.y) >> 31 << 8;
+              int sz = __float_as_uint(v.z) >> 31 << 16;
+              int sw = __float_as_uint(v.w) >> 31 << 24;
+              if (sx) v.x *= p.slope;
+              if (sy) v.y *= p.slope;
+              if (sz) v.z *= p.slope;
+              if (sw) v.w *= p.slope;
+              if (fabsf(v.x) > p.clamp) {
+                sx = 2 << 0;
+                v.x = InternalType<T>::clamp(v.x, p.clamp);
+              }
+              if (fabsf(v.y) > p.clamp) {
+                sy = 2 << 8;
+                v.y = InternalType<T>::clamp(v.y, p.clamp);
+              }
+              if (fabsf(v.z) > p.clamp) {
+                sz = 2 << 16;
+                v.z = InternalType<T>::clamp(v.z, p.clamp);
+              }
+              if (fabsf(v.w) > p.clamp) {
+                sw = 2 << 24;
+                v.w = InternalType<T>::clamp(v.w, p.clamp);
+              }
+
+              if ((uint32_t)signXb < p.swLimit && signY >= minY) {
+                // Combine signs.
+                uint32_t s = sx + sy + sw + sz;
+                s <<= (signX & 3) << 1;
+#ifdef MMCV_WITH_HIP
+                s |= __shfl_xor(s, 1);
+                s |= __shfl_xor(s, 2);
+#else
+                s |= __shfl_xor_sync(groupMask, s, 1);
+                s |= __shfl_xor_sync(groupMask, s, 2);
+#endif
+
+                // Write signs.
+                if ((uint32_t)(signY + 0) < sShapeMaxY) {
+                  p.s[si0] = (unsigned char)(s >> 0);
+                }
+                if ((uint32_t)(signY + 1) < sShapeMaxY) {
+                  p.s[si1] = (unsigned char)(s >> 8);
+                }
+                if ((uint32_t)(signY + 2) < sShapeMaxY) {
+                  p.s[si2] = (unsigned char)(s >> 16);
+                }
+                if ((uint32_t)(signY + 3) < sShapeMaxY) {
+                  p.s[si3] = (unsigned char)(s >> 24);
+                }
+              }
+            } else {
+              // Determine and write signs.
+              if ((uint32_t)signXb < p.swLimit && signY >= minY) {
+                int sx = __float_as_uint(v.x) >> 31 << 0;
+                int sy = __float_as_uint(v.y) >> 31 << 8;
+                int sz = __float_as_uint(v.z) >> 31 << 16;
+                int sw = __float_as_uint(v.w) >> 31 << 24;
+                if (sx) v.x *= p.slope;
+                if (sy) v.y *= p.slope;
+                if (sz) v.z *= p.slope;
+                if (sw) v.w *= p.slope;
+                if (fabsf(v.x) > p.clamp) {
+                  sx = 2 << 0;
+                  v.x = InternalType<T>::clamp(v.x, p.clamp);
+                }
+                if (fabsf(v.y) > p.clamp) {
+                  sy = 2 << 8;
+                  v.y = InternalType<T>::clamp(v.y, p.clamp);
+                }
+                if (fabsf(v.z) > p.clamp) {
+                  sz = 2 << 16;
+                  v.z = InternalType<T>::clamp(v.z, p.clamp);
+                }
+                if (fabsf(v.w) > p.clamp) {
+                  sw = 2 << 24;
+                  v.w = InternalType<T>::clamp(v.w, p.clamp);
+                }
+
+                // Combine signs.
+                uint32_t s = sx + sy + sw + sz;
+                s <<= (signX & 3) << 1;
+#ifdef MMCV_WITH_HIP
+                s |= __shfl_xor(s, 1);
+                s |= __shfl_xor(s, 2);
+#else
+                s |= __shfl_xor_sync(groupMask, s, 1);
+                s |= __shfl_xor_sync(groupMask, s, 2);
+#endif
+
+                // Write signs.
+                if ((uint32_t)(signY + 0) < sShapeMaxY) {
+                  p.s[si0] = (unsigned char)(s >> 0);
+                }
+                if ((uint32_t)(signY + 1) < sShapeMaxY) {
+                  p.s[si1] = (unsigned char)(s >> 8);
+                }
+                if ((uint32_t)(signY + 2) < sShapeMaxY) {
+                  p.s[si2] = (unsigned char)(s >> 16);
+                }
+                if ((uint32_t)(signY + 3) < sShapeMaxY) {
+                  p.s[si3] = (unsigned char)(s >> 24);
+                }
+              } else {
+                // Just compute the values.
+                if (v.x < 0.f) v.x *= p.slope;
+                v.x = InternalType<T>::clamp(v.x, p.clamp);
+                if (v.y < 0.f) v.y *= p.slope;
+                v.y = InternalType<T>::clamp(v.y, p.clamp);
+                if (v.z < 0.f) v.z *= p.slope;
+                v.z = InternalType<T>::clamp(v.z, p.clamp);
+                if (v.w < 0.f) v.w *= p.slope;
+                v.w = InternalType<T>::clamp(v.w, p.clamp);
+              }
+            }
+          } else if (signRead)  // Read signs and apply.
+          {
+            if ((uint32_t)signXb < p.swLimit) {
+              int ss = (signX & 3) << 1;
+              if ((uint32_t)(signY + 0) < p.sShape.y) {
+                int s = p.s[si0] >> ss;
+                if (s & 1) v.x *= p.slope;
+                if (s & 2) v.x = 0.f;
+              }
+              if ((uint32_t)(signY + 1) < p.sShape.y) {
+                int s = p.s[si1] >> ss;
+                if (s & 1) v.y *= p.slope;
+                if (s & 2) v.y = 0.f;
+              }
+              if ((uint32_t)(signY + 2) < p.sShape.y) {
+                int s = p.s[si2] >> ss;
+                if (s & 1) v.z *= p.slope;
+                if (s & 2) v.z = 0.f;
+              }
+              if ((uint32_t)(signY + 3) < p.sShape.y) {
+                int s = p.s[si3] >> ss;
+                if (s & 1) v.w *= p.slope;
+                if (s & 2) v.w = 0.f;
+              }
+            }
+          } else  // Forward pass with no sign write.
+          {
+            if (v.x < 0.f) v.x *= p.slope;
+            v.x = InternalType<T>::clamp(v.x, p.clamp);
+            if (v.y < 0.f) v.y *= p.slope;
+            v.y = InternalType<T>::clamp(v.y, p.clamp);
+            if (v.z < 0.f) v.z *= p.slope;
+            v.z = InternalType<T>::clamp(v.z, p.clamp);
+            if (v.w < 0.f) v.w *= p.slope;
+            v.w = InternalType<T>::clamp(v.w, p.clamp);
+          }
+
+          s_tileUpXY[dst + 0 * tileUpW] = v.x;
+          if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
+          if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
+          if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
+        }
+      } else if (up == 2) {
+        minY -= 1;  // Adjust according to block height.
+        for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up;
+             idx += blockDim.x) {
+          int relUpX, relInY0;
+          fast_div_mod<tileUpW>(relUpX, relInY0, idx);
+          int relUpY0 = relInY0 * up;
+          int src0 = relInY0 * tileUpW + relUpX;
+          int dst = relUpY0 * tileUpW + relUpX;
+          vec2_t v = InternalType<T>::zero_vec2();
+
+          scalar_t a = s_tileUpX[src0];
+          if (phaseInY == 0) {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileUpX[src0 + (step + 1) * tileUpW];
+              v.y += a * (scalar_t)c_fu[step * up + 1];
+            }
+          } else  // (phaseInY == 1)
+          {
+#pragma unroll
+            for (int step = 0; step < fuSize / up; step++) {
+              v.x += a * (scalar_t)c_fu[step * up + 1];
+              v.y += a * (scalar_t)c_fu[step * up + 0];
+              a = s_tileUpX[src0 + (step + 1) * tileUpW];
+            }
+          }
+
+          int x = tileOutX * down + relUpX;
+          int y = tileOutY * down + relUpY0;
+          int signX = x + p.sOfs.x;
+          int signY = y + p.sOfs.y;
+          int signZ = blockIdx.z + p.blockZofs;
+          int signXb = signX >> 2;
+          index_t si0 =
+              signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+          index_t si1 = si0 + p.sShape.x;
+
+          v.x *= (scalar_t)((float)up * (float)up * p.gain);
+          v.y *= (scalar_t)((float)up * (float)up * p.gain);
+
+          if (signWrite) {
+            if (!enableWriteSkip) {
+              // Determine and write signs.
+              int sx = __float_as_uint(v.x) >> 31 << 0;
+              int sy = __float_as_uint(v.y) >> 31 << 8;
+              if (sx) v.x *= p.slope;
+              if (sy) v.y *= p.slope;
+              if (fabsf(v.x) > p.clamp) {
+                sx = 2 << 0;
+                v.x = InternalType<T>::clamp(v.x, p.clamp);
+              }
+              if (fabsf(v.y) > p.clamp) {
+                sy = 2 << 8;
+                v.y = InternalType<T>::clamp(v.y, p.clamp);
+              }
+
+              if ((uint32_t)signXb < p.swLimit && signY >= minY) {
+                // Combine signs.
+                int s = sx + sy;
+                s <<= signXo;
+#ifdef MMCV_WITH_HIP
+                s |= __shfl_xor(s, 1);
+                s |= __shfl_xor(s, 2);
+#else
+                s |= __shfl_xor_sync(groupMask, s, 1);
+                s |= __shfl_xor_sync(groupMask, s, 2);
+#endif
+
+                // Write signs.
+                if ((uint32_t)(signY + 0) < sShapeMaxY) {
+                  p.s[si0] = (unsigned char)(s >> 0);
+                }
+                if ((uint32_t)(signY + 1) < sShapeMaxY) {
+                  p.s[si1] = (unsigned char)(s >> 8);
+                }
+              }
+            } else {
+              // Determine and write signs.
+              if ((uint32_t)signXb < p.swLimit && signY >= minY) {
+                int sx = __float_as_uint(v.x) >> 31 << 0;
+                int sy = __float_as_uint(v.y) >> 31 << 8;
+                if (sx) v.x *= p.slope;
+                if (sy) v.y *= p.slope;
+                if (fabsf(v.x) > p.clamp) {
+                  sx = 2 << 0;
+                  v.x = InternalType<T>::clamp(v.x, p.clamp);
+                }
+                if (fabsf(v.y) > p.clamp) {
+                  sy = 2 << 8;
+                  v.y = InternalType<T>::clamp(v.y, p.clamp);
+                }
+
+                // Combine signs.
+                int s = sx + sy;
+                s <<= signXo;
+#ifdef MMCV_WITH_HIP
+                s |= __shfl_xor(s, 1);
+                s |= __shfl_xor(s, 2);
+#else
+                s |= __shfl_xor_sync(groupMask, s, 1);
+                s |= __shfl_xor_sync(groupMask, s, 2);
+#endif
+
+                // Write signs.
+                if ((uint32_t)(signY + 0) < sShapeMaxY) {
+                  p.s[si0] = (unsigned char)(s >> 0);
+                }
+                if ((uint32_t)(signY + 1) < sShapeMaxY) {
+                  p.s[si1] = (unsigned char)(s >> 8);
+                }
+              } else {
+                // Just compute the values.
+                if (v.x < 0.f) v.x *= p.slope;
+                v.x = InternalType<T>::clamp(v.x, p.clamp);
+                if (v.y < 0.f) v.y *= p.slope;
+                v.y = InternalType<T>::clamp(v.y, p.clamp);
+              }
+            }
+          } else if (signRead)  // Read signs and apply.
+          {
+            if ((uint32_t)signXb < p.swLimit) {
+              if ((uint32_t)(signY + 0) < p.sShape.y) {
+                int s = p.s[si0] >> signXo;
+                if (s & 1) v.x *= p.slope;
+                if (s & 2) v.x = 0.f;
+              }
+              if ((uint32_t)(signY + 1) < p.sShape.y) {
+                int s = p.s[si1] >> signXo;
+                if (s & 1) v.y *= p.slope;
+                if (s & 2) v.y = 0.f;
+              }
+            }
+          } else  // Forward pass with no sign write.
+          {
+            if (v.x < 0.f) v.x *= p.slope;
+            v.x = InternalType<T>::clamp(v.x, p.clamp);
+            if (v.y < 0.f) v.y *= p.slope;
+            v.y = InternalType<T>::clamp(v.y, p.clamp);
+          }
+
+          if (!downInline) {
+            // Write into temporary buffer.
+            s_tileUpXY[dst] = v.x;
+            if (relUpY0 < tileUpH - 1) s_tileUpXY[dst + tileUpW] = v.y;
+          } else {
+            // Write directly into output buffer.
+            if ((uint32_t)x < p.yShape.x) {
+              int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
+              index_t ofs = x * get_stride<index_t>(p.yStride.x) +
+                            y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
+              if ((uint32_t)y + 0 < p.yShape.y)
+                *((T *)((char *)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
+              if ((uint32_t)y + 1 < ymax)
+                *((T *)((char *)p.y + ofs + get_stride<index_t>(p.yStride.y))) =
+                    (T)(v.y * (scalar_t)c_fd[0]);
+            }
+          }
+        }
+      }
+    } else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) {
+      // Full upsampling filter.
+
+      if (up == 2) {
+        // 2 x 2-wide.
+        __syncthreads();
+        int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y
+                            : 0;  // Skip already written signs.
+        for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH;
+             idx += blockDim.x * 4) {
+          int relUpX0, relUpY0;
+          fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
+          int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
+          int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
+          int src0 = relInX0 + tileInW * relInY0;
+          int tap0y = (relInY0 * up + phaseInY - relUpY0);
+
+#define X_LOOP(TAPY, PX)                                             \
+  for (int sx = 0; sx < fuSize / up; sx++) {                         \
+    v.x += a * (scalar_t)c_fu[(sx * up + (((PX)-0) & (up - 1))) +    \
+                              (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+    v.z += b * (scalar_t)c_fu[(sx * up + (((PX)-0) & (up - 1))) +    \
+                              (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+    if ((PX) == 0) {                                                 \
+      a = b;                                                         \
+      b = s_tileIn[src0 + 2 + sx + sy * tileInW];                    \
+    }                                                                \
+    v.y += a * (scalar_t)c_fu[(sx * up + (((PX)-1) & (up - 1))) +    \
+                              (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+    v.w += b * (scalar_t)c_fu[(sx * up + (((PX)-1) & (up - 1))) +    \
+                              (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+    if ((PX) == 1) {                                                 \
+      a = b;                                                         \
+      b = s_tileIn[src0 + 2 + sx + sy * tileInW];                    \
+    }                                                                \
+  }
+
+          vec4_t v = InternalType<T>::zero_vec4();
+          if (tap0y == 0 && phaseInX == 0)
+#pragma unroll
+            for (int sy = 0; sy < fuSize / up; sy++) {
+              scalar_t a = s_tileIn[src0 + sy * tileInW];
+              scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+#pragma unroll
+              X_LOOP(0, 0)
+            }
+          if (tap0y == 0 && phaseInX == 1)
+#pragma unroll
+            for (int sy = 0; sy < fuSize / up; sy++) {
+              scalar_t a = s_tileIn[src0 + sy * tileInW];
+              scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+#pragma unroll
+              X_LOOP(0, 1)
+            }
+          if (tap0y == 1 && phaseInX == 0)
+#pragma unroll
+            for (int sy = 0; sy < fuSize / up; sy++) {
+              scalar_t a = s_tileIn[src0 + sy * tileInW];
+              scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+#pragma unroll
+              X_LOOP(1, 0)
+            }
+          if (tap0y == 1 && phaseInX == 1)
+#pragma unroll
+            for (int sy = 0; sy < fuSize / up; sy++) {
+              scalar_t a = s_tileIn[src0 + sy * tileInW];
+              scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+#pragma unroll
+              X_LOOP(1, 1)
+            }
+
+#undef X_LOOP
+
+          int x = tileOutX * down + relUpX0;
+          int y = tileOutY * down + relUpY0;
+          int signX = x + p.sOfs.x;
+          int signY = y + p.sOfs.y;
+          int signZ = blockIdx.z + p.blockZofs;
+          int signXb = signX >> 2;
+          index_t si =
+              signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+
+          v.x *= (scalar_t)((float)up * (float)up * p.gain);
+          v.y *= (scalar_t)((float)up * (float)up * p.gain);
+          v.z *= (scalar_t)((float)up * (float)up * p.gain);
+          v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+          if (signWrite) {
+            if (!enableWriteSkip) {
+              // Determine and write signs.
+              int sx = __float_as_uint(v.x) >> 31;
+              int sy = __float_as_uint(v.y) >> 31;
+              int sz = __float_as_uint(v.z) >> 31;
+              int sw = __float_as_uint(v.w) >> 31;
+              if (sx) v.x *= p.slope;
+              if (fabsf(v.x) > p.clamp) {
+                sx = 2;
+                v.x = InternalType<T>::clamp(v.x, p.clamp);
+              }
+              if (sy) v.y *= p.slope;
+              if (fabsf(v.y) > p.clamp) {
+                sy = 2;
+                v.y = InternalType<T>::clamp(v.y, p.clamp);
+              }
+              if (sz) v.z *= p.slope;
+              if (fabsf(v.z) > p.clamp) {
+                sz = 2;
+                v.z = InternalType<T>::clamp(v.z, p.clamp);
+              }
+              if (sw) v.w *= p.slope;
+              if (fabsf(v.w) > p.clamp) {
+                sw = 2;
+                v.w = InternalType<T>::clamp(v.w, p.clamp);
+              }
+
+              if ((uint32_t)signXb < p.swLimit &&
+                  (uint32_t)signY < p.sShape.y && signY >= minY) {
+                p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+              }
+            } else {
+              // Determine and write signs.
+              if ((uint32_t)signXb < p.swLimit &&
+                  (uint32_t)signY < p.sShape.y && signY >= minY) {
+                int sx = __float_as_uint(v.x) >> 31;
+                int sy = __float_as_uint(v.y) >> 31;
+                int sz = __float_as_uint(v.z) >> 31;
+                int sw = __float_as_uint(v.w) >> 31;
+                if (sx) v.x *= p.slope;
+                if (fabsf(v.x) > p.clamp) {
+                  sx = 2;
+                  v.x = InternalType<T>::clamp(v.x, p.clamp);
+                }
+                if (sy) v.y *= p.slope;
+                if (fabsf(v.y) > p.clamp) {
+                  sy = 2;
+                  v.y = InternalType<T>::clamp(v.y, p.clamp);
+                }
+                if (sz) v.z *= p.slope;
+                if (fabsf(v.z) > p.clamp) {
+                  sz = 2;
+                  v.z = InternalType<T>::clamp(v.z, p.clamp);
+                }
+                if (sw) v.w *= p.slope;
+                if (fabsf(v.w) > p.clamp) {
+                  sw = 2;
+                  v.w = InternalType<T>::clamp(v.w, p.clamp);
+                }
+
+                p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+              } else {
+                // Just compute the values.
+                if (v.x < 0.f) v.x *= p.slope;
+                v.x = InternalType<T>::clamp(v.x, p.clamp);
+                if (v.y < 0.f) v.y *= p.slope;
+                v.y = InternalType<T>::clamp(v.y, p.clamp);
+                if (v.z < 0.f) v.z *= p.slope;
+                v.z = InternalType<T>::clamp(v.z, p.clamp);
+                if (v.w < 0.f) v.w *= p.slope;
+                v.w = InternalType<T>::clamp(v.w, p.clamp);
+              }
+            }
+          } else if (signRead)  // Read sign and apply.
+          {
+            if ((uint32_t)signY < p.sShape.y) {
+              int s = 0;
+              if ((uint32_t)signXb < p.swLimit) s = p.s[si];
+              if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
+              s >>= (signX & 3) << 1;
+              if (s & 0x01) v.x *= p.slope;
+              if (s & 0x02) v.x = 0.f;
+              if (s & 0x04) v.y *= p.slope;
+              if (s & 0x08) v.y = 0.f;
+              if (s & 0x10) v.z *= p.slope;
+              if (s & 0x20) v.z = 0.f;
+              if (s & 0x40) v.w *= p.slope;
+              if (s & 0x80) v.w = 0.f;
+            }
+          } else  // Forward pass with no sign write.
+          {
+            if (v.x < 0.f) v.x *= p.slope;
+            v.x = InternalType<T>::clamp(v.x, p.clamp);
+            if (v.y < 0.f) v.y *= p.slope;
+            v.y = InternalType<T>::clamp(v.y, p.clamp);
+            if (v.z < 0.f) v.z *= p.slope;
+            v.z = InternalType<T>::clamp(v.z, p.clamp);
+            if (v.w < 0.f) v.w *= p.slope;
+            v.w = InternalType<T>::clamp(v.w, p.clamp);
+          }
+
+          s_tileUpXY[idx + 0] = v.x;
+          s_tileUpXY[idx + 1] = v.y;
+          s_tileUpXY[idx + 2] = v.z;
+          s_tileUpXY[idx + 3] = v.w;
+        }
+      } else if (up == 1) {
+        __syncthreads();
+        uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
+        int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH
+                            : 0;  // Skip already written signs.
+        for (int idx = threadIdx.x; idx < tileUpW * tileUpH;
+             idx += blockDim.x) {
+          int relUpX0, relUpY0;
+          fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
+          scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0];  // 1x1 filter.
+
+          int x = tileOutX * down + relUpX0;
+          int y = tileOutY * down + relUpY0;
+          int signX = x + p.sOfs.x;
+          int signY = y + p.sOfs.y;
+          int signZ = blockIdx.z + p.blockZofs;
+          int signXb = signX >> 2;
+          index_t si =
+              signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+          v *= (scalar_t)((float)up * (float)up * p.gain);
+
+          if (signWrite) {
+            if (!enableWriteSkip) {
+              // Determine and write sign.
+              uint32_t s = 0;
+              uint32_t signXbit = (1u << signXo);
+              if (v < 0.f) {
+                s = signXbit;
+                v *= p.slope;
+              }
+              if (fabsf(v) > p.clamp) {
+                s = signXbit * 2;
+                v = InternalType<T>::clamp(v, p.clamp);
+              }
+              if ((uint32_t)signXb < p.swLimit &&
+                  (uint32_t)signY < p.sShape.y && signY >= minY) {
+#ifdef MMCV_WITH_HIP
+                s += __shfl_xor(s, 1);  // Coalesce.
+                s += __shfl_xor(s, 2);  // Coalesce.
+#else
+                s += __shfl_xor_sync(groupMask, s, 1);  // Coalesce.
+                s += __shfl_xor_sync(groupMask, s, 2);  // Coalesce.
+#endif
+                p.s[si] = s;  // Write.
+              }
+            } else {
+              // Determine and write sign.
+              if ((uint32_t)signXb < p.swLimit &&
+                  (uint32_t)signY < p.sShape.y && signY >= minY) {
+                uint32_t s = 0;
+                uint32_t signXbit = (1u << signXo);
+                if (v < 0.f) {
+                  s = signXbit;
+                  v *= p.slope;
+                }
+                if (fabsf(v) > p.clamp) {
+                  s = signXbit * 2;
+                  v = InternalType<T>::clamp(v, p.clamp);
+                }
+#ifdef MMCV_WITH_HIP
+                s += __shfl_xor(s, 1);  // Coalesce.
+                s += __shfl_xor(s, 2);  // Coalesce.
+#else
+                s += __shfl_xor_sync(groupMask, s, 1);  // Coalesce.
+                s += __shfl_xor_sync(groupMask, s, 2);  // Coalesce.
+#endif
+                p.s[si] = s;  // Write.
+              } else {
+                // Just compute the value.
+                if (v < 0.f) v *= p.slope;
+                v = InternalType<T>::clamp(v, p.clamp);
+              }
+            }
+          } else if (signRead) {
+            // Read sign and apply if within sign tensor bounds.
+            if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) {
+              int s = p.s[si];
+              s >>= signXo;
+              if (s & 1) v *= p.slope;
+              if (s & 2) v = 0.f;
+            }
+          } else  // Forward pass with no sign write.
+          {
+            if (v < 0.f) v *= p.slope;
+            v = InternalType<T>::clamp(v, p.clamp);
+          }
+
+          if (!downInline)  // Write into temporary buffer.
+            s_tileUpXY[idx] = v;
+          else if ((uint32_t)x < p.yShape.x &&
+                   (uint32_t)y <
+                       p.yShape.y)  // Write directly into output buffer
+            *((T *)((char *)p.y + (x * get_stride<index_t>(p.yStride.x) +
+                                   y * get_stride<index_t>(p.yStride.y) +
+                                   mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
+        }
+      }
+    }
+
+    // Downsampling.
+    if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) {
+      // Horizontal downsampling.
+      __syncthreads();
+      if (down == 4 && tileOutW % 4 == 0) {
+        // Calculate 4 pixels at a time.
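+        // Each thread computes 4 adjacent output pixels whose filter windows
+        // start down (= 4) upsampled samples apart, hence the src0 + 0/4/8/12
+        // reads below.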
+        for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH;
+             idx += blockDim.x * 4) {
+          int relOutX0, relUpY;
+          fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
+          int relUpX0 = relOutX0 * down;
+          int src0 = relUpY * tileUpW + relUpX0;
+          vec4_t v = InternalType<T>::zero_vec4();
+#pragma unroll
+          for (int step = 0; step < fdSize; step++) {
+            v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+            v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
+            v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
+            v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
+          }
+          s_tileDownX[idx + 0] = v.x;
+          s_tileDownX[idx + 1] = v.y;
+          s_tileDownX[idx + 2] = v.z;
+          s_tileDownX[idx + 3] = v.w;
+        }
+      } else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) {
+        // Calculate 2 pixels at a time.
+        for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH;
+             idx += blockDim.x * 2) {
+          int relOutX0, relUpY;
+          fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
+          int relUpX0 = relOutX0 * down;
+          int src0 = relUpY * tileUpW + relUpX0;
+          vec2_t v = InternalType<T>::zero_vec2();
+#pragma unroll
+          for (int step = 0; step < fdSize; step++) {
+            v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+            v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
+          }
+          s_tileDownX[idx + 0] = v.x;
+          s_tileDownX[idx + 1] = v.y;
+        }
+      } else {
+        // Calculate 1 pixel at a time.
+        for (int idx = threadIdx.x; idx < tileOutW * tileUpH;
+             idx += blockDim.x) {
+          int relOutX0, relUpY;
+          fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
+          int relUpX0 = relOutX0 * down;
+          int src = relUpY * tileUpW + relUpX0;
+          scalar_t v = 0.f;
+#pragma unroll
+          for (int step = 0; step < fdSize; step++)
+            v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
+          s_tileDownX[idx] = v;
+        }
+      }
+
+      // Vertical downsampling & store output tile.
+      __syncthreads();
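+      // s_tileDownX holds the horizontally filtered tile (tileOutW wide,
+      // tileUpH tall), so consecutive vertical taps below are tileOutW
+      // elements apart.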
+      for (int idx = threadIdx.x; idx < tileOutW * tileOutH;
+           idx += blockDim.x) {
+        int relOutX, relOutY0;
+        fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
+        int relUpY0 = relOutY0 * down;
+        int src0 = relUpY0 * tileOutW + relOutX;
+        scalar_t v = 0;
+#pragma unroll
+        for (int step = 0; step < fdSize; step++)
+          v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
+
+        int outX = tileOutX + relOutX;
+        int outY = tileOutY + relOutY0;
+
+        if (outX < p.yShape.x && outY < p.yShape.y)
+          *((T *)((char *)p.y + (outX * get_stride<index_t>(p.yStride.x) +
+                                 outY * get_stride<index_t>(p.yStride.y) +
+                                 mapOfsOut))) = (T)v;
+      }
+    } else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) {
+      // Full downsampling filter.
+      if (down == 2) {
+        // 2-wide.
+        __syncthreads();
+        for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH;
+             idx += blockDim.x * 2) {
+          int relOutX0, relOutY0;
+          fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
+          int relUpX0 = relOutX0 * down;
+          int relUpY0 = relOutY0 * down;
+          int src0 = relUpY0 * tileUpW + relUpX0;
+          vec2_t v = InternalType<T>::zero_vec2();
+#pragma unroll
+          for (int sy = 0; sy < fdSize; sy++)
+#pragma unroll
+            for (int sx = 0; sx < fdSize; sx++) {
+              v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] *
+                     (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+              v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] *
+                     (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+            }
+
+          int outX = tileOutX + relOutX0;
+          int outY = tileOutY + relOutY0;
+          if ((uint32_t)outY < p.yShape.y) {
+            index_t ofs = outX * get_stride<index_t>(p.yStride.x) +
+                          outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
+            if (outX + 0 < p.yShape.x) *((T *)((char *)p.y + ofs)) = (T)v.x;
+            if (outX + 1 < p.yShape.x)
+              *((T *)((char *)p.y + ofs + get_stride<index_t>(p.yStride.x))) =
+                  (T)v.y;
+          }
+        }
+      } else if (down == 1 && !downInline) {
+        // Thread per pixel.
+        __syncthreads();
+        for (int idx = threadIdx.x; idx < tileOutW * tileOutH;
+             idx += blockDim.x) {
+          int relOutX0, relOutY0;
+          fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
+          scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0];  // 1x1 filter.
+
+          int outX = tileOutX + relOutX0;
+          int outY = tileOutY + relOutY0;
+          if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
+            *((T *)((char *)p.y + (outX * get_stride<index_t>(p.yStride.x) +
+                                   outY * get_stride<index_t>(p.yStride.y) +
+                                   mapOfsOut))) = (T)v;
+        }
+      }
+    }
+
+    if (!enableXrep) break;
+  }
+}
+
+//------------------------------------------------------------------------
+// Compute activation function and signs for upsampled data tensor, modifying
+// data tensor in-place. Used for accelerating the generic variant. Sign tensor
+// is known to be contiguous, and p.x and p.s have the same z, w dimensions.
+// 64-bit indexing is always used.
+
+template <class T, bool signWrite, bool signRead>
+static __global__ void filtered_lrelu_act_kernel(
+    filtered_lrelu_act_kernel_params p) {
+  typedef typename InternalType<T>::scalar_t scalar_t;
+
+  // Indexing.
+  int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
+  int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
+  int32_t qmax =
+      p.xShape.z * p.xShape.w;  // Combined minibatch*channel maximum index.
+
+  // Loop to accommodate oversized tensors.
+  for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
+    for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y) {
+      // Extract z and w (channel, minibatch index).
+      int32_t w = q / p.xShape.z;
+      int32_t z = q - w * p.xShape.z;
+
+      // Choose behavior based on sign read/write mode.
+      if (signWrite) {
+        // Process value if in p.x.
+        uint32_t s = 0;
+        if (x < p.xShape.x && y < p.xShape.y) {
+          int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z +
+                       w * p.xStride.w;
+          T *pv = ((T *)p.x) + ix;
+          scalar_t v = (scalar_t)(*pv);
+
+          // Gain, LReLU, clamp.
+          v *= p.gain;
+          if (v < 0.f) {
+            v *= p.slope;
+            s = 1;  // Sign.
+          }
+          if (fabsf(v) > p.clamp) {
+            v = InternalType<T>::clamp(v, p.clamp);
+            s = 2;  // Clamp.
+          }
+
+          *pv = (T)v;  // Write value.
+        }
+
+        // Coalesce into threads 0 and 16 of warp.
+        uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
+        s <<= ((threadIdx.x & 15) << 1);  // Shift into place.
+#ifdef MMCV_WITH_HIP
+        s |= __shfl_xor(s, 1);  // Distribute.
+        s |= __shfl_xor(s, 2);
+        s |= __shfl_xor(s, 4);
+        s |= __shfl_xor(s, 8);
+#else
+        s |= __shfl_xor_sync(m, s, 1);                  // Distribute.
+        s |= __shfl_xor_sync(m, s, 2);
+        s |= __shfl_xor_sync(m, s, 4);
+        s |= __shfl_xor_sync(m, s, 8);
+#endif
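+        // Each half-warp now holds the OR of its sixteen 2-bit fields, i.e.
+        // one complete 32-bit sign word; only lanes 0 and 16 write it below.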
+
+        // Write signs if leader and in p.s.
+        if (!(threadIdx.x & 15) && x < p.sShape.x)  // y is always in.
+        {
+          uint64_t is =
+              x + p.sShape.x * (y + (int64_t)p.sShape.y * q);  // Contiguous.
+          ((uint32_t *)p.s)[is >> 4] = s;
+        }
+      } else if (signRead) {
+        // Process value if in p.x.
+        if (x < p.xShape.x)  // y is always in.
+        {
+          int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z +
+                       w * p.xStride.w;
+          T *pv = ((T *)p.x) + ix;
+          scalar_t v = (scalar_t)(*pv);
+          v *= p.gain;
+
+          // Apply sign buffer offset.
+          uint32_t sx = x + p.sOfs.x;
+          uint32_t sy = y + p.sOfs.y;
+
+          // Read and apply signs if we land inside valid region of sign buffer.
+          if (sx < p.sShape.x && sy < p.sShape.y) {
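+            // p.sShape.x is in elements here, so both the column and the row
+            // pitch are divided by 4 to index bytes (four 2-bit signs each).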
+            uint64_t is =
+                (sx >> 2) + (p.sShape.x >> 2) *
+                                (sy + (uint64_t)p.sShape.y * q);  // Contiguous.
+            unsigned char s = p.s[is];
+            s >>= (sx & 3) << 1;  // Shift into place.
+            if (s & 1)            // Sign?
+              v *= p.slope;
+            if (s & 2)  // Clamp?
+              v = 0.f;
+          }
+
+          *pv = (T)v;  // Write value.
+        }
+      } else {
+        // Forward pass with no sign write. Process value if in p.x.
+        if (x < p.xShape.x)  // y is always in.
+        {
+          int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z +
+                       w * p.xStride.w;
+          T *pv = ((T *)p.x) + ix;
+          scalar_t v = (scalar_t)(*pv);
+          v *= p.gain;
+          if (v < 0.f) v *= p.slope;
+          if (fabsf(v) > p.clamp) v = InternalType<T>::clamp(v, p.clamp);
+          *pv = (T)v;  // Write value.
+        }
+      }
+    }
+}
+
+template <class T, bool signWrite, bool signRead>
+void *choose_filtered_lrelu_act_kernel(void) {
+  return (void *)filtered_lrelu_act_kernel<T, signWrite, signRead>;
+}
+
+//------------------------------------------------------------------------
+// MUSA kernel selection.
+
+template <class T, class index_t, bool signWrite, bool signRead>
+filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
+    const filtered_lrelu_kernel_params &p, int sharedKB) {
+  filtered_lrelu_kernel_spec s = {0};
+
+  // Return the first matching kernel.
+#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS)                        \
+  if (sharedKB >= SH)                                                          \
+    if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) ||      \
+        (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD)))         \
+      if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) ||    \
+          (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD)))       \
+        if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU &&             \
+            p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) {           \
+          static_assert((D * TW % 4) == 0,                                     \
+                        "down * tileWidth must be divisible by 4");            \
+          static_assert(                                                       \
+              FU % U == 0,                                                     \
+              "upscaling filter size must be multiple of upscaling factor");   \
+          static_assert(FD % D == 0,                                           \
+                        "downscaling filter size must be multiple of "         \
+                        "downscaling factor");                                 \
+          s.setup = (void *)setup_filters_kernel;                              \
+          s.exec = (void *)                                                    \
+              filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, \
+                                    U, FU, D, FD, TW, TH, W * 32, !!XR, !!WS>; \
+          s.tileOut = make_int2(TW, TH);                                       \
+          s.numWarps = W;                                                      \
+          s.xrep = XR;                                                         \
+          s.dynamicSharedKB = (SH == 48) ? 0 : SH;                             \
+          return s;                                                            \
+        }
+
+  // Launch parameters for various kernel specializations.
+  // Small filters must be listed before large filters; otherwise the kernel
+  // for the larger filter would always match first. Kernels that use more
+  // shared memory must be listed before those that use less, for the same
+  // reason.
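+  // For example, a query with up == 2, separable fuShape == (8, 0),
+  // down == 2 and separable fdShape == (8, 0) falls through to the
+  // 4t-ups2-downs2 MODE_SUSD entry below: the earlier entries fail on
+  // either the up/down factor or the filter mode.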
+
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 1, 1, /*mode*/ MODE_FUFD,
+       /*tw,th,warps,xrep,wskip*/ 64, 178, 32, 0, 0)  // 1t-upf1-downf1
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD,
+       /*tw,th,warps,xrep,wskip*/ 152, 95, 16, 0, 0)  // 4t-ups2-downf1
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 8, /*mode*/ MODE_FUSD,
+       /*tw,th,warps,xrep,wskip*/ 56, 22, 16, 0, 0)  // 4t-upf1-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_SUSD,
+       /*tw,th,warps,xrep,wskip*/ 56, 29, 16, 11, 0)  // 4t-ups2-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_FUSD,
+       /*tw,th,warps,xrep,wskip*/ 60, 28, 16, 0, 0)  // 4t-upf2-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_SUFD,
+       /*tw,th,warps,xrep,wskip*/ 56, 28, 16, 0, 0)  // 4t-ups2-downf2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 16, /*down,fd*/ 2, 8, /*mode*/ MODE_SUSD,
+       /*tw,th,warps,xrep,wskip*/ 56, 31, 16, 11, 0)  // 4t-ups4-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 16, /*down,fd*/ 2, 8, /*mode*/ MODE_SUFD,
+       /*tw,th,warps,xrep,wskip*/ 56, 36, 16, 0, 0)  // 4t-ups4-downf2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 4, 16, /*mode*/ MODE_SUSD,
+       /*tw,th,warps,xrep,wskip*/ 16, 22, 16, 12, 0)  // 4t-ups2-downs4
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 4, 16, /*mode*/ MODE_FUSD,
+       /*tw,th,warps,xrep,wskip*/ 29, 15, 16, 0, 0)  // 4t-upf2-downs4
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD,
+       /*tw,th,warps,xrep,wskip*/ 96, 150, 28, 0, 0)  // 6t-ups2-downf1
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 12, /*mode*/ MODE_FUSD,
+       /*tw,th,warps,xrep,wskip*/ 32, 35, 24, 0, 0)  // 6t-upf1-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_SUSD,
+       /*tw,th,warps,xrep,wskip*/ 32, 46, 16, 10, 0)  // 6t-ups2-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_FUSD,
+       /*tw,th,warps,xrep,wskip*/ 58, 28, 24, 8, 0)  // 6t-upf2-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_SUFD,
+       /*tw,th,warps,xrep,wskip*/ 52, 28, 16, 0, 0)  // 6t-ups2-downf2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 24, /*down,fd*/ 2, 12, /*mode*/ MODE_SUSD,
+       /*tw,th,warps,xrep,wskip*/ 32, 51, 16, 5, 0)  // 6t-ups4-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 24, /*down,fd*/ 2, 12, /*mode*/ MODE_SUFD,
+       /*tw,th,warps,xrep,wskip*/ 32, 56, 16, 6, 0)  // 6t-ups4-downf2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_SUSD,
+       /*tw,th,warps,xrep,wskip*/ 16, 18, 16, 12, 0)  // 6t-ups2-downs4
+  CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_FUSD,
+       /*tw,th,warps,xrep,wskip*/ 27, 31, 32, 6, 0)  // 6t-upf2-downs4 96kB
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_FUSD,
+       /*tw,th,warps,xrep,wskip*/ 27, 13, 24, 0, 0)  // 6t-upf2-downs4
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD,
+       /*tw,th,warps,xrep,wskip*/ 148, 89, 24, 0, 0)  // 8t-ups2-downf1
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 16, /*mode*/ MODE_FUSD,
+       /*tw,th,warps,xrep,wskip*/ 32, 31, 16, 5, 0)  // 8t-upf1-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_SUSD,
+       /*tw,th,warps,xrep,wskip*/ 32, 41, 16, 9, 0)  // 8t-ups2-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_FUSD,
+       /*tw,th,warps,xrep,wskip*/ 56, 26, 24, 0, 0)  // 8t-upf2-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_SUFD,
+       /*tw,th,warps,xrep,wskip*/ 32, 40, 16, 0, 0)  // 8t-ups2-downf2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 32, /*down,fd*/ 2, 16, /*mode*/ MODE_SUSD,
+       /*tw,th,warps,xrep,wskip*/ 32, 46, 24, 5, 0)  // 8t-ups4-downs2
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 32, /*down,fd*/ 2, 16, /*mode*/ MODE_SUFD,
+       /*tw,th,warps,xrep,wskip*/ 32, 50, 16, 0, 0)  // 8t-ups4-downf2
+  CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_SUSD,
+       /*tw,th,warps,xrep,wskip*/ 24, 24, 32, 12, 1)  // 8t-ups2-downs4 96kB
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_SUSD,
+       /*tw,th,warps,xrep,wskip*/ 16, 13, 16, 10, 1)  // 8t-ups2-downs4
+  CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_FUSD,
+       /*tw,th,warps,xrep,wskip*/ 25, 28, 28, 4, 0)  // 8t-upf2-downs4 96kB
+  CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_FUSD,
+       /*tw,th,warps,xrep,wskip*/ 25, 10, 24, 0, 0)  // 8t-upf2-downs4
+
+#undef CASE
+  return s;  // No kernel found.
+}
+
+//------------------------------------------------------------------------
+
+#define BUILD_FILTERED_LRELU_OP 1
+
+// The op below cannot be built on the HIP path or with old GCC releases; in
+// either case fall through to the stub that emits a compile-time notice.
+#ifdef MMCV_WITH_HIP
+#undef BUILD_FILTERED_LRELU_OP
+#define BUILD_FILTERED_LRELU_OP 0
+#endif
+#ifdef __GNUC__
+#if __GNUC__ < 6
+#undef BUILD_FILTERED_LRELU_OP
+#define BUILD_FILTERED_LRELU_OP 0
+#endif
+#endif
+
+#if BUILD_FILTERED_LRELU_OP
+
+std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
+    torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
+    torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
+    int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
+    bool writeSigns) {
+  // Set MUSA device.
+  TORCH_CHECK(x.is_privateuseone(), "x must reside on MUSA device");
+  const at::musa::OptionalMUSAGuard device_guard(device_of(x));
+
+  // Validate arguments.
+  TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() &&
+                  b.device() == x.device(),
+              "all input tensors must reside on the same device");
+  TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat,
+              "fu and fd must be float32");
+  TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
+  TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat,
+              "x and b must be float16 or float32");
+  TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+  TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX &&
+                  x.size(3) <= INT_MAX,
+              "x is too large");
+  TORCH_CHECK(x.numel() > 0, "x is empty");
+  TORCH_CHECK(
+      (fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2),
+      "fu and fd must be rank 1 or 2");
+  TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX,
+              "fu is too large");
+  TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX,
+              "fd is too large");
+  TORCH_CHECK(fu.numel() > 0, "fu is empty");
+  TORCH_CHECK(fd.numel() > 0, "fd is empty");
+  TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1),
+              "b must be a vector with the same number of channels as x");
+  TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
+
+  // Figure out how much shared memory is available on the device.
+  int maxSharedBytes = 0;
+#ifdef MMCV_WITH_HIP
+  musaDeviceGetAttribute(&maxSharedBytes,
+                         hipDeviceAttributeSharedMemPerBlockOptin,
+                         x.device().index());
+#else
+  AT_MUSA_CHECK(musaDeviceGetAttribute(&maxSharedBytes,
+                                       musaDevAttrMaxSharedMemoryPerBlockOptin,
+                                       x.device().index()));
+#endif
+  int sharedKB = maxSharedBytes >> 10;
+
+  // Populate enough launch parameters to check if a MUSA kernel exists.
+  filtered_lrelu_kernel_params p;
+  p.up = up;
+  p.down = down;
+  p.fuShape =
+      make_int2((int)fu.size(-1),
+                fu.dim() == 2 ? (int)fu.size(0)
+                              : 0);  // shape [n, 0] indicates separable filter.
+  p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
+  filtered_lrelu_kernel_spec test_spec =
+      choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
+  if (!test_spec.exec) {
+    // No kernel found - return empty tensors and indicate missing kernel with
+    // return code of -1.
+    return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
+  }
+
+  // Input/output element size.
+  int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
+
+  // Input sizes.
+  int64_t xw = (int)x.size(3);
+  int64_t xh = (int)x.size(2);
+  int64_t fut_w = (int)fu.size(-1) - 1;
+  int64_t fut_h = (int)fu.size(0) - 1;
+  int64_t fdt_w = (int)fd.size(-1) - 1;
+  int64_t fdt_h = (int)fd.size(0) - 1;
+
+  // Logical size of upsampled buffer.
+  int64_t cw = xw * up + (px0 + px1) - fut_w;
+  int64_t ch = xh * up + (py0 + py1) - fut_h;
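+  // Worked example: xw = 64, up = 2, px0 = px1 = 3 and an 8-tap fu give
+  // cw = 128 + 6 - 7 = 127; with down = 2 and an 8-tap fd, the ceil division
+  // below then yields yw = (127 - 7 + 1) / 2 = 60 output columns.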
+  TORCH_CHECK(
+      cw > fdt_w && ch > fdt_h,
+      "upsampled buffer must be at least the size of downsampling filter");
+  TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
+
+  // Compute output size and allocate.
+  int64_t yw = (cw - fdt_w + (down - 1)) / down;
+  int64_t yh = (ch - fdt_h + (down - 1)) / down;
+  TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
+  TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
+  torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(),
+                                 x.suggest_memory_format());
+
+  // Allocate sign tensor.
+  torch::Tensor so;
+  torch::Tensor s = si;
+  bool readSigns = !!s.numel();
+  int64_t sw_active = 0;  // Active width of sign tensor.
+  if (writeSigns) {
+    sw_active = yw * down - (down - 1) + fdt_w;   // Active width in elements.
+    int64_t sh = yh * down - (down - 1) + fdt_h;  // Height = active height.
+    int64_t sw = (sw_active + 15) & ~15;  // Width  = active width in elements,
+                                          // rounded up to multiple of 16.
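+    // Continuing the example above (yw = 60, down = 2, fdt_w = 7):
+    // sw_active = 120 - 1 + 7 = 126, which rounds up to sw = 128 sign slots,
+    // stored as sw >> 2 = 32 bytes per row (four 2-bit entries per byte).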
+    TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
+    s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2},
+                          x.options().dtype(torch::kUInt8),
+                          at::MemoryFormat::Contiguous);
+  } else if (readSigns)
+    sw_active = s.size(3) << 2;
+
+  // Validate sign tensor if in use.
+  if (readSigns || writeSigns) {
+    TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+    TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+    TORCH_CHECK(s.device() == x.device(),
+                "signs must reside on the same device as x");
+    TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+    TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1),
+                "signs must have same batch & channels as x");
+    TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX,
+                "signs is too large");
+  }
+
+  // Populate rest of MUSA kernel parameters.
+  p.x = x.data_ptr();
+  p.y = y.data_ptr();
+  p.b = b.data_ptr();
+  p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
+  p.fu = fu.data_ptr<float>();
+  p.fd = fd.data_ptr<float>();
+  p.pad0 = make_int2(px0, py0);
+  p.gain = gain;
+  p.slope = slope;
+  p.clamp = clamp;
+  p.flip = (flip_filters) ? 1 : 0;
+  p.xShape =
+      make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+  p.yShape =
+      make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+  p.sShape = (readSigns || writeSigns)
+                 ? make_int2((int)s.size(3), (int)s.size(2))
+                 : make_int2(0, 0);  // Width is in bytes. Contiguous.
+  p.sOfs = make_int2(sx, sy);
+  p.swLimit = (sw_active + 3) >> 2;  // Rounded up to bytes.
+
+  // x, y, b strides are in bytes.
+  p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2),
+                             sz * x.stride(1), sz * x.stride(0));
+  p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2),
+                             sz * y.stride(1), sz * y.stride(0));
+  p.bStride = sz * b.stride(0);
+
+  // fu, fd strides are in elements.
+  p.fuStride =
+      make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
+  p.fdStride =
+      make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
+
+  // Determine whether indices fit in int32. Negative strides are supported,
+  // although Torch currently never produces them.
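+  // The min()/max() sums below bound the most negative and most positive
+  // byte offsets reachable in x and y; if either bound falls outside the
+  // int32 range, the 64-bit-index kernel variants are selected instead.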
+  bool index64b = false;
+  if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
+  if (std::min(x.size(0) * p.xStride.w, 0ll) +
+          std::min(x.size(1) * p.xStride.z, 0ll) +
+          std::min(x.size(2) * p.xStride.y, 0ll) +
+          std::min(x.size(3) * p.xStride.x, 0ll) <
+      -INT_MAX)
+    index64b = true;
+  if (std::max(x.size(0) * p.xStride.w, 0ll) +
+          std::max(x.size(1) * p.xStride.z, 0ll) +
+          std::max(x.size(2) * p.xStride.y, 0ll) +
+          std::max(x.size(3) * p.xStride.x, 0ll) >
+      INT_MAX)
+    index64b = true;
+  if (std::min(y.size(0) * p.yStride.w, 0ll) +
+          std::min(y.size(1) * p.yStride.z, 0ll) +
+          std::min(y.size(2) * p.yStride.y, 0ll) +
+          std::min(y.size(3) * p.yStride.x, 0ll) <
+      -INT_MAX)
+    index64b = true;
+  if (std::max(y.size(0) * p.yStride.w, 0ll) +
+          std::max(y.size(1) * p.yStride.z, 0ll) +
+          std::max(y.size(2) * p.yStride.y, 0ll) +
+          std::max(y.size(3) * p.yStride.x, 0ll) >
+      INT_MAX)
+    index64b = true;
+  if (s.numel() > INT_MAX) index64b = true;
+
+  // Choose MUSA kernel.
+  filtered_lrelu_kernel_spec spec = {0};
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      x.scalar_type(), "filtered_lrelu_musa", [&] {
+        // Exclude doubles; the constexpr-if prevents template instantiation
+        // for them. The TORCH_CHECK above restricts x to half or float,
+        // matching the dispatch macro.
+        if constexpr (sizeof(scalar_t) <= 4) {
+          // Choose kernel based on index type, datatype and sign read/write
+          // modes.
+          if (!index64b && writeSigns && !readSigns)
+            spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(
+                p, sharedKB);
+          else if (!index64b && !writeSigns && readSigns)
+            spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true>(
+                p, sharedKB);
+          else if (!index64b && !writeSigns && !readSigns)
+            spec =
+                choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(
+                    p, sharedKB);
+          else if (index64b && writeSigns && !readSigns)
+            spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(
+                p, sharedKB);
+          else if (index64b && !writeSigns && readSigns)
+            spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true>(
+                p, sharedKB);
+          else if (index64b && !writeSigns && !readSigns)
+            spec =
+                choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(
+                    p, sharedKB);
+        }
+      });
+  // This should not happen because we verified above that a kernel exists.
+  TORCH_CHECK(spec.exec, "internal error - MUSA kernel not found");
+
+  // Launch MUSA kernel.
+  void *args[] = {&p};
+  int bx = spec.numWarps * 32;
+  int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
+  int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
+  int gz = p.yShape.z * p.yShape.w;
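+  // gx and gy are ceil-divisions of the output extent by the tile size;
+  // gz launches one block per (channel, batch) pair of the output.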
+
+  // Repeat multiple horizontal tiles in a CTA?
+  if (spec.xrep) {
+    p.tilesXrep = spec.xrep;
+    p.tilesXdim = gx;
+
+    gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
+    std::swap(gx, gy);
+  } else {
+    p.tilesXrep = 0;
+    p.tilesXdim = 0;
+  }
+#ifdef MMCV_WITH_HIP
+  AT_MUSA_CHECK(hipLaunchKernel(spec.setup, 1, 1024, args, 0,
+                                c10::musa::getCurrentMUSAStream()));
+#else
+  // Launch filter setup kernel.
+  AT_MUSA_CHECK(musaLaunchKernel(spec.setup, 1, 1024, args, 0,
+                                 c10::musa::getCurrentMUSAStream()));
+#endif
+
+  // Copy kernels to constant memory.
+  if (writeSigns && !readSigns)
+    AT_MUSA_CHECK((copy_filters(c10::musa::getCurrentMUSAStream())));
+  else if (!writeSigns && readSigns)
+    AT_MUSA_CHECK((copy_filters(c10::musa::getCurrentMUSAStream())));
+  else if (!writeSigns && !readSigns)
+    AT_MUSA_CHECK((copy_filters(c10::musa::getCurrentMUSAStream())));
+
+  // Set cache and shared memory configurations for main kernel.
+  // FIXME/TODO: fix bug.
+  AT_MUSA_CHECK(musaFuncSetCacheConfig(spec.exec, musaFuncCachePreferShared));
+  if (spec.dynamicSharedKB)  // Need dynamically allocated shared memory?
+#ifdef MMCV_WITH_HIP
+    AT_MUSA_CHECK(hipFuncSetAttribute(
+        spec.exec, hipFuncAttributeMaxDynamicSharedMemorySize,
+        spec.dynamicSharedKB << 10));
+#else
+    AT_MUSA_CHECK(musaFuncSetAttribute(
+        spec.exec, musaFuncAttributeMaxDynamicSharedMemorySize,
+        spec.dynamicSharedKB << 10));
+#endif
+  // FIXME/TODO: fix bug.
+  AT_MUSA_CHECK(
+      musaFuncSetSharedMemConfig(spec.exec, musaSharedMemBankSizeFourByte));
+
+  // Launch main kernel.
+  const int maxSubGz = 65535;  // MUSA maximum for block z dimension.
+  for (int zofs = 0; zofs < gz;
+       zofs += maxSubGz)  // Do multiple launches if gz is too big.
+  {
+    p.blockZofs = zofs;
+    int subGz = std::min(maxSubGz, gz - zofs);
+// FIXME/TODO: fix bug.
+#ifdef MMCV_WITH_HIP
+    AT_MUSA_CHECK(hipLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
+                                  spec.dynamicSharedKB << 10,
+                                  c10::musa::getCurrentMUSAStream()));
+#else
+    AT_MUSA_CHECK(musaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
+                                   spec.dynamicSharedKB << 10,
+                                   c10::musa::getCurrentMUSAStream()));
+#endif
+  }
+
+  // Done.
+  return std::make_tuple(y, so, 0);
+}
+
+std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op_impl(
+    torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b,
+    torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1,
+    int sx, int sy, float gain, float slope, float clamp, bool flip_filters,
+    bool writeSigns);
+
+REGISTER_DEVICE_IMPL(filtered_lrelu_op_impl, MUSA, filtered_lrelu_op);
+
+#else
+
+#pragma message(                           \
+    "filtered_lrelu_op is not available. " \
+    "Please update your compiler and MUSA version.")
+
+#endif
+#undef BUILD_FILTERED_LRELU_OP
+
+//------------------------------------------------------------------------
+
+torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
+                                    int sy, float gain, float slope,
+                                    float clamp, bool writeSigns) {
+  // Set MUSA device.
+  TORCH_CHECK(x.is_privateuseone(), "x must reside on MUSA device");
+  const at::musa::OptionalMUSAGuard device_guard(device_of(x));
+
+  // Validate arguments.
+  TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+  TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX &&
+                  x.size(3) <= INT_MAX,
+              "x is too large");
+  TORCH_CHECK(x.numel() > 0, "x is empty");
+  TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat ||
+                  x.dtype() == torch::kDouble,
+              "x must be float16, float32 or float64");
+
+  // Output signs if we don't have sign input.
+  torch::Tensor so;
+  torch::Tensor s = si;
+  bool readSigns = !!s.numel();
+  if (writeSigns) {
+    int64_t sw = x.size(3);
+    sw = (sw + 15) & ~15;  // Round to a multiple of 16 for coalescing.
+    s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2},
+                          x.options().dtype(torch::kUInt8),
+                          at::MemoryFormat::Contiguous);
+  }
+
+  // Validate sign tensor if in use.
+  if (readSigns || writeSigns) {
+    TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+    TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+    TORCH_CHECK(s.device() == x.device(),
+                "signs must reside on the same device as x");
+    TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+    TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1),
+                "signs must have same batch & channels as x");
+    TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX,
+                "signs tensor is too large");
+  }
+
+  // Initialize MUSA kernel parameters.
+  filtered_lrelu_act_kernel_params p;
+  p.x = x.data_ptr();
+  p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
+  p.gain = gain;
+  p.slope = slope;
+  p.clamp = clamp;
+  p.xShape =
+      make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+  p.xStride =
+      make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
+  p.sShape = (readSigns || writeSigns)
+                 ? make_int2((int)s.size(3) << 2, (int)s.size(2))
+                 : make_int2(0, 0);  // Width is in elements. Contiguous.
+  p.sOfs = make_int2(sx, sy);
+
+  // Choose MUSA kernel.
+  void *func = 0;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      x.scalar_type(), "filtered_lrelu_act_musa", [&] {
+        if (writeSigns)
+          func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
+        else if (readSigns)
+          func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
+        else
+          func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
+      });
+  TORCH_CHECK(func, "internal error - MUSA kernel not found");
+
+  // Launch MUSA kernel.
+  void *args[] = {&p};
+  int bx = 128;  // 4 warps per block.
+
+  // Logical size of launch = writeSigns ? p.s : p.x
+  uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
+  uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
+  uint32_t gz =
+      p.xShape.z * p.xShape.w;  // Same as in p.sShape if signs are in use.
+  gx = (gx - 1) / bx + 1;
+
+  // Make sure grid y and z dimensions are within MUSA launch limits. Kernel
+  // loops internally to do the rest.
+  const uint32_t gmax = 65535;
+  gy = std::min(gy, gmax);
+  gz = std::min(gz, gmax);
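+  // Example: writing signs for a 256-element-wide sign row gives
+  // gx = ceil(256 / 128) = 2; gy and gz can be clamped to 65535 because the
+  // kernel grid-strides over y and over the combined channel*batch index.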
+
+  // Launch.
+#ifdef MMCV_WITH_HIP
+  AT_MUSA_CHECK(hipLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
+                                c10::musa::getCurrentMUSAStream()));
+#else
+  AT_MUSA_CHECK(musaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
+                                 c10::musa::getCurrentMUSAStream()));
+#endif
+
+  return so;
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/musabind.cpp b/mmcv/ops/csrc/pytorch/musa/musabind.cpp
index ebbb69275..889574f95 100644
--- a/mmcv/ops/csrc/pytorch/musa/musabind.cpp
+++ b/mmcv/ops/csrc/pytorch/musa/musabind.cpp
@@ -540,6 +540,17 @@ torch::Tensor bias_act_op(const torch::Tensor &input, const torch::Tensor &bias,
 
 REGISTER_DEVICE_IMPL(bias_act_op_impl, MUSA, bias_act_op);
 
+torch::Tensor filtered_lrelu_act_op_impl(torch::Tensor x, torch::Tensor si,
+                                         int sx, int sy, float gain,
+                                         float slope, float clamp,
+                                         bool writeSigns);
+
+torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
+                                    int sy, float gain, float slope,
+                                    float clamp, bool writeSigns);
+
+REGISTER_DEVICE_IMPL(filtered_lrelu_act_op_impl, MUSA, filtered_lrelu_act_op);
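+// As with the other ops in this file, the *_impl declaration is the
+// device-agnostic dispatch entry and REGISTER_DEVICE_IMPL binds the MUSA
+// implementation to it.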
+
 void GatherPointsForwardMUSAKernelLauncher(int b, int c, int n, int npoints,
                                            const Tensor points,
                                            const Tensor idx, Tensor out);
@@ -869,6 +880,854 @@ Tensor nms_musa(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
 Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset);
 REGISTER_DEVICE_IMPL(nms_impl, MUSA, nms_musa);
 
+void PointsInBoxesPartForwardMUSAKernelLauncher(int batch_size, int boxes_num,
+                                                int pts_num, const Tensor boxes,
+                                                const Tensor pts,
+                                                Tensor box_idx_of_points);
+
+void PointsInBoxesAllForwardMUSAKernelLauncher(int batch_size, int boxes_num,
+                                               int pts_num, const Tensor boxes,
+                                               const Tensor pts,
+                                               Tensor box_idx_of_points);
+
+void points_in_boxes_part_forward_musa(int batch_size, int boxes_num,
+                                       int pts_num, const Tensor boxes,
+                                       const Tensor pts,
+                                       Tensor box_idx_of_points) {
+  PointsInBoxesPartForwardMUSAKernelLauncher(batch_size, boxes_num, pts_num,
+                                             boxes, pts, box_idx_of_points);
+};
+
+void points_in_boxes_all_forward_musa(int batch_size, int boxes_num,
+                                      int pts_num, const Tensor boxes,
+                                      const Tensor pts,
+                                      Tensor box_idx_of_points) {
+  PointsInBoxesAllForwardMUSAKernelLauncher(batch_size, boxes_num, pts_num,
+                                            boxes, pts, box_idx_of_points);
+};
+
+void points_in_boxes_part_forward_impl(int batch_size, int boxes_num,
+                                       int pts_num, const Tensor boxes,
+                                       const Tensor pts,
+                                       Tensor box_idx_of_points);
+
+void points_in_boxes_all_forward_impl(int batch_size, int boxes_num,
+                                      int pts_num, const Tensor boxes,
+                                      const Tensor pts,
+                                      Tensor box_idx_of_points);
+REGISTER_DEVICE_IMPL(points_in_boxes_part_forward_impl, MUSA,
+                     points_in_boxes_part_forward_musa);
+REGISTER_DEVICE_IMPL(points_in_boxes_all_forward_impl, MUSA,
+                     points_in_boxes_all_forward_musa);
+
+void PSAMaskForwardMUSAKernelLauncher(const int psa_type, const Tensor input,
+                                      Tensor output, const int num_,
+                                      const int h_feature, const int w_feature,
+                                      const int h_mask, const int w_mask,
+                                      const int half_h_mask,
+                                      const int half_w_mask);
+
+void PSAMaskBackwardMUSAKernelLauncher(
+    const int psa_type, const Tensor grad_output, Tensor grad_input,
+    const int num_, const int h_feature, const int w_feature, const int h_mask,
+    const int w_mask, const int half_h_mask, const int half_w_mask);
+
+void psamask_forward_musa(const int psa_type, const Tensor input, Tensor output,
+                          const int num_, const int h_feature,
+                          const int w_feature, const int h_mask,
+                          const int w_mask, const int half_h_mask,
+                          const int half_w_mask) {
+  PSAMaskForwardMUSAKernelLauncher(psa_type, input, output, num_, h_feature,
+                                   w_feature, h_mask, w_mask, half_h_mask,
+                                   half_w_mask);
+}
+
+void psamask_backward_musa(const int psa_type, const Tensor grad_output,
+                           Tensor grad_input, const int num_,
+                           const int h_feature, const int w_feature,
+                           const int h_mask, const int w_mask,
+                           const int half_h_mask, const int half_w_mask) {
+  PSAMaskBackwardMUSAKernelLauncher(psa_type, grad_output, grad_input, num_,
+                                    h_feature, w_feature, h_mask, w_mask,
+                                    half_h_mask, half_w_mask);
+}
+
+void psamask_forward_impl(const int psa_type, const Tensor input, Tensor output,
+                          const int num_, const int h_feature,
+                          const int w_feature, const int h_mask,
+                          const int w_mask, const int half_h_mask,
+                          const int half_w_mask);
+
+void psamask_backward_impl(const int psa_type, const Tensor grad_output,
+                           Tensor grad_input, const int num_,
+                           const int h_feature, const int w_feature,
+                           const int h_mask, const int w_mask,
+                           const int half_h_mask, const int half_w_mask);
+REGISTER_DEVICE_IMPL(psamask_forward_impl, MUSA, psamask_forward_musa);
+REGISTER_DEVICE_IMPL(psamask_backward_impl, MUSA, psamask_backward_musa);
+
+void ROIAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
+                                       Tensor argmax_y, Tensor argmax_x,
+                                       int aligned_height, int aligned_width,
+                                       float spatial_scale, int sampling_ratio,
+                                       int pool_mode, bool aligned);
+
+void ROIAlignBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
+                                        Tensor argmax_y, Tensor argmax_x,
+                                        Tensor grad_input, int aligned_height,
+                                        int aligned_width, float spatial_scale,
+                                        int sampling_ratio, int pool_mode,
+                                        bool aligned);
+
+void roi_align_forward_musa(Tensor input, Tensor rois, Tensor output,
+                            Tensor argmax_y, Tensor argmax_x,
+                            int aligned_height, int aligned_width,
+                            float spatial_scale, int sampling_ratio,
+                            int pool_mode, bool aligned) {
+  ROIAlignForwardMUSAKernelLauncher(
+      input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width,
+      spatial_scale, sampling_ratio, pool_mode, aligned);
+}
+
+void roi_align_backward_musa(Tensor grad_output, Tensor rois, Tensor argmax_y,
+                             Tensor argmax_x, Tensor grad_input,
+                             int aligned_height, int aligned_width,
+                             float spatial_scale, int sampling_ratio,
+                             int pool_mode, bool aligned) {
+  ROIAlignBackwardMUSAKernelLauncher(
+      grad_output, rois, argmax_y, argmax_x, grad_input, aligned_height,
+      aligned_width, spatial_scale, sampling_ratio, pool_mode, aligned);
+}
+
+void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
+                            Tensor argmax_y, Tensor argmax_x,
+                            int aligned_height, int aligned_width,
+                            float spatial_scale, int sampling_ratio,
+                            int pool_mode, bool aligned);
+
+void roi_align_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax_y,
+                             Tensor argmax_x, Tensor grad_input,
+                             int aligned_height, int aligned_width,
+                             float spatial_scale, int sampling_ratio,
+                             int pool_mode, bool aligned);
+
+REGISTER_DEVICE_IMPL(roi_align_forward_impl, MUSA, roi_align_forward_musa);
+REGISTER_DEVICE_IMPL(roi_align_backward_impl, MUSA, roi_align_backward_musa);
+
+void ROIAlignRotatedForwardMUSAKernelLauncher(
+    const at::Tensor input, const at::Tensor rois, const float spatial_scale,
+    const int sampling_ratio, const bool aligned, const bool clockwise,
+    const int channels, const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, at::Tensor output);
+
+void ROIAlignRotatedBackwardMUSAKernelLauncher(
+    const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
+    const int sampling_ratio, const bool aligned, const bool clockwise,
+    const int channels, const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, at::Tensor bottom_grad);
+
+void roi_align_rotated_forward_musa(Tensor input, Tensor rois, Tensor output,
+                                    int aligned_height, int aligned_width,
+                                    float spatial_scale, int sampling_ratio,
+                                    bool aligned, bool clockwise) {
+  // Number of ROIs
+  int num_rois = rois.size(0);
+  int size_rois = rois.size(1);
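+  // Rotated ROIs carry six values per row: (batch_idx, cx, cy, w, h, angle).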
+
+  if (size_rois != 6) {
+    AT_ERROR("wrong roi size");
+  }
+
+  int num_channels = input.size(1);
+  int data_height = input.size(2);
+  int data_width = input.size(3);
+  ROIAlignRotatedForwardMUSAKernelLauncher(
+      input, rois, spatial_scale, sampling_ratio, aligned, clockwise,
+      num_channels, data_height, data_width, num_rois, aligned_height,
+      aligned_width, output);
+}
+
+void roi_align_rotated_backward_musa(Tensor top_grad, Tensor rois,
+                                     Tensor bottom_grad, int aligned_height,
+                                     int aligned_width, float spatial_scale,
+                                     int sampling_ratio, bool aligned,
+                                     bool clockwise) {
+  // Number of ROIs
+  int num_rois = rois.size(0);
+  int size_rois = rois.size(1);
+  if (size_rois != 6) {
+    AT_ERROR("wrong roi size");
+  }
+
+  int num_channels = bottom_grad.size(1);
+  int data_height = bottom_grad.size(2);
+  int data_width = bottom_grad.size(3);
+  ROIAlignRotatedBackwardMUSAKernelLauncher(
+      top_grad, rois, spatial_scale, sampling_ratio, aligned, clockwise,
+      num_channels, data_height, data_width, num_rois, aligned_height,
+      aligned_width, bottom_grad);
+}
+
+void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
+                                    int aligned_height, int aligned_width,
+                                    float spatial_scale, int sampling_ratio,
+                                    bool aligned, bool clockwise);
+
+void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
+                                     Tensor bottom_grad, int aligned_height,
+                                     int aligned_width, float spatial_scale,
+                                     int sampling_ratio, bool aligned,
+                                     bool clockwise);
+REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, MUSA,
+                     roi_align_rotated_forward_musa);
+REGISTER_DEVICE_IMPL(roi_align_rotated_backward_impl, MUSA,
+                     roi_align_rotated_backward_musa);
+
+void RiROIAlignRotatedForwardMUSAKernelLauncher(
+    const at::Tensor features, const at::Tensor rois, const float spatial_scale,
+    const int num_samples, const bool clockwise, const int channels,
+    const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, const int num_orientations,
+    at::Tensor output);
+
+void RiROIAlignRotatedBackwardMUSAKernelLauncher(
+    const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
+    const int num_samples, const bool clockwise, const int channels,
+    const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, const int num_orientations,
+    at::Tensor bottom_grad);
+
+void riroi_align_rotated_forward_musa(Tensor features, Tensor rois,
+                                      Tensor output, int pooled_height,
+                                      int pooled_width, float spatial_scale,
+                                      int num_samples, int num_orientations,
+                                      bool clockwise) {
+  // Number of ROIs
+  int num_rois = rois.size(0);
+  int size_rois = rois.size(1);
+  if (size_rois != 6) {
+    AT_ERROR("wrong roi size");
+  }
+  CHECK_CONTIGUOUS(features);
+  CHECK_CONTIGUOUS(rois);
+  int num_channels = features.size(1) / num_orientations;
+  int data_height = features.size(2);
+  int data_width = features.size(3);
+  RiROIAlignRotatedForwardMUSAKernelLauncher(
+      features, rois, spatial_scale, num_samples, clockwise, num_channels,
+      data_height, data_width, num_rois, pooled_height, pooled_width,
+      num_orientations, output);
+}
+
+void riroi_align_rotated_backward_musa(Tensor top_grad, Tensor rois,
+                                       Tensor bottom_grad, int pooled_height,
+                                       int pooled_width, float spatial_scale,
+                                       int num_samples, int num_orientations,
+                                       bool clockwise) {
+  // Number of ROIs
+  int num_rois = rois.size(0);
+  int size_rois = rois.size(1);
+  if (size_rois != 6) {
+    AT_ERROR("wrong roi size");
+  }
+  CHECK_CONTIGUOUS(top_grad);
+  CHECK_CONTIGUOUS(rois);
+  int num_channels = bottom_grad.size(1) / num_orientations;
+  int data_height = bottom_grad.size(2);
+  int data_width = bottom_grad.size(3);
+  RiROIAlignRotatedBackwardMUSAKernelLauncher(
+      top_grad, rois, spatial_scale, num_samples, clockwise, num_channels,
+      data_height, data_width, num_rois, pooled_height, pooled_width,
+      num_orientations, bottom_grad);
+}
+
+void riroi_align_rotated_forward_impl(Tensor features, Tensor rois,
+                                      Tensor output, int pooled_height,
+                                      int pooled_width, float spatial_scale,
+                                      int num_samples, int num_orientations,
+                                      bool clockwise);
+
+void riroi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
+                                       Tensor bottom_grad, int pooled_height,
+                                       int pooled_width, float spatial_scale,
+                                       int num_samples, int num_orientations,
+                                       bool clockwise);
+
+REGISTER_DEVICE_IMPL(riroi_align_rotated_forward_impl, MUSA,
+                     riroi_align_rotated_forward_musa);
+REGISTER_DEVICE_IMPL(riroi_align_rotated_backward_impl, MUSA,
+                     riroi_align_rotated_backward_musa);
+
+void RoiawarePool3dForwardMUSAKernelLauncher(
+    int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
+    int out_y, int out_z, const Tensor rois, const Tensor pts,
+    const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels,
+    Tensor pooled_features, int pool_method);
+
+void RoiawarePool3dBackwardMUSAKernelLauncher(
+    int boxes_num, int out_x, int out_y, int out_z, int channels,
+    int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax,
+    const Tensor grad_out, Tensor grad_in, int pool_method);
+
+void roiaware_pool3d_forward_musa(int boxes_num, int pts_num, int channels,
+                                  int max_pts_each_voxel, int out_x, int out_y,
+                                  int out_z, const Tensor rois,
+                                  const Tensor pts, const Tensor pts_feature,
+                                  Tensor argmax, Tensor pts_idx_of_voxels,
+                                  Tensor pooled_features, int pool_method) {
+  RoiawarePool3dForwardMUSAKernelLauncher(
+      boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
+      rois, pts, pts_feature, argmax, pts_idx_of_voxels, pooled_features,
+      pool_method);
+};
+
+void roiaware_pool3d_backward_musa(int boxes_num, int out_x, int out_y,
+                                   int out_z, int channels,
+                                   int max_pts_each_voxel,
+                                   const Tensor pts_idx_of_voxels,
+                                   const Tensor argmax, const Tensor grad_out,
+                                   Tensor grad_in, int pool_method) {
+  RoiawarePool3dBackwardMUSAKernelLauncher(
+      boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
+      pts_idx_of_voxels, argmax, grad_out, grad_in, pool_method);
+};
+
+void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels,
+                                  int max_pts_each_voxel, int out_x, int out_y,
+                                  int out_z, const Tensor rois,
+                                  const Tensor pts, const Tensor pts_feature,
+                                  Tensor argmax, Tensor pts_idx_of_voxels,
+                                  Tensor pooled_features, int pool_method);
+
+void roiaware_pool3d_backward_impl(int boxes_num, int out_x, int out_y,
+                                   int out_z, int channels,
+                                   int max_pts_each_voxel,
+                                   const Tensor pts_idx_of_voxels,
+                                   const Tensor argmax, const Tensor grad_out,
+                                   Tensor grad_in, int pool_method);
+
+REGISTER_DEVICE_IMPL(roiaware_pool3d_forward_impl, MUSA,
+                     roiaware_pool3d_forward_musa);
+REGISTER_DEVICE_IMPL(roiaware_pool3d_backward_impl, MUSA,
+                     roiaware_pool3d_backward_musa);
+
+void RoIPointPool3dForwardMUSAKernelLauncher(
+    int batch_size, int pts_num, int boxes_num, int feature_in_len,
+    int sampled_pts_num, const Tensor xyz, const Tensor boxes3d,
+    const Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag);
+
+void roipoint_pool3d_forward_musa(int batch_size, int pts_num, int boxes_num,
+                                  int feature_in_len, int sampled_pts_num,
+                                  const Tensor xyz, const Tensor boxes3d,
+                                  const Tensor pts_feature,
+                                  Tensor pooled_features,
+                                  Tensor pooled_empty_flag) {
+  RoIPointPool3dForwardMUSAKernelLauncher(
+      batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz,
+      boxes3d, pts_feature, pooled_features, pooled_empty_flag);
+};
+
+void roipoint_pool3d_forward_impl(int batch_size, int pts_num, int boxes_num,
+                                  int feature_in_len, int sampled_pts_num,
+                                  const Tensor xyz, const Tensor boxes3d,
+                                  const Tensor pts_feature,
+                                  Tensor pooled_features,
+                                  Tensor pooled_empty_flag);
+REGISTER_DEVICE_IMPL(roipoint_pool3d_forward_impl, MUSA,
+                     roipoint_pool3d_forward_musa);
+
+void ROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
+                                      Tensor argmax, int pooled_height,
+                                      int pooled_width, float spatial_scale);
+
+void ROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
+                                       Tensor argmax, Tensor grad_input,
+                                       int pooled_height, int pooled_width,
+                                       float spatial_scale);
+
+void roi_pool_forward_musa(Tensor input, Tensor rois, Tensor output,
+                           Tensor argmax, int pooled_height, int pooled_width,
+                           float spatial_scale) {
+  ROIPoolForwardMUSAKernelLauncher(input, rois, output, argmax, pooled_height,
+                                   pooled_width, spatial_scale);
+}
+
+void roi_pool_backward_musa(Tensor grad_output, Tensor rois, Tensor argmax,
+                            Tensor grad_input, int pooled_height,
+                            int pooled_width, float spatial_scale) {
+  ROIPoolBackwardMUSAKernelLauncher(grad_output, rois, argmax, grad_input,
+                                    pooled_height, pooled_width, spatial_scale);
+}
+
+void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
+                           Tensor argmax, int pooled_height, int pooled_width,
+                           float spatial_scale);
+void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax,
+                            Tensor grad_input, int pooled_height,
+                            int pooled_width, float spatial_scale);
+REGISTER_DEVICE_IMPL(roi_pool_forward_impl, MUSA, roi_pool_forward_musa);
+REGISTER_DEVICE_IMPL(roi_pool_backward_impl, MUSA, roi_pool_backward_musa);
+
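+// reduce_t selects how the features of points mapped to the same voxel are
+// combined by dynamic_point_to_voxel: summed, averaged, or element-wise max.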
+typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
+
+std::vector<at::Tensor> DynamicPointToVoxelForwardMUSAKernelLauncher(
+    const at::Tensor &feats, const at::Tensor &coors,
+    const reduce_t reduce_type);
+
+void DynamicPointToVoxelBackwardMUSAKernelLauncher(
+    at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats,
+    const at::Tensor &feats, const at::Tensor &reduced_feats,
+    const at::Tensor &coors_map, const at::Tensor &reduce_count,
+    const reduce_t reduce_type);
+
+std::vector<torch::Tensor> dynamic_point_to_voxel_forward_musa(
+    const torch::Tensor &feats, const torch::Tensor &coors,
+    const reduce_t reduce_type) {
+  return DynamicPointToVoxelForwardMUSAKernelLauncher(feats, coors,
+                                                      reduce_type);
+}
+
+void dynamic_point_to_voxel_backward_musa(
+    torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
+    const torch::Tensor &feats, const torch::Tensor &reduced_feats,
+    const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
+    const reduce_t reduce_type) {
+  DynamicPointToVoxelBackwardMUSAKernelLauncher(grad_feats, grad_reduced_feats,
+                                                feats, reduced_feats, coors_idx,
+                                                reduce_count, reduce_type);
+}
+
+std::vector<torch::Tensor> dynamic_point_to_voxel_forward_impl(
+    const torch::Tensor &feats, const torch::Tensor &coors,
+    const reduce_t reduce_type);
+
+void dynamic_point_to_voxel_backward_impl(
+    torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
+    const torch::Tensor &feats, const torch::Tensor &reduced_feats,
+    const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
+    const reduce_t reduce_type);
+
+REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, MUSA,
+                     dynamic_point_to_voxel_forward_musa);
+REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_backward_impl, MUSA,
+                     dynamic_point_to_voxel_backward_musa);
+
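+// SyncBN is split into five launchers: two reduction passes (mean, then
+// variance), a normalization pass that also updates the running statistics,
+// and two backward passes (parameter gradients, then input gradients).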
+void SyncBNForwardMeanMUSAKernelLauncher(const Tensor input, Tensor mean);
+
+void SyncBNForwardVarMUSAKernelLauncher(const Tensor input, const Tensor mean,
+                                        Tensor var);
+
+void SyncBNForwardOutputMUSAKernelLauncher(
+    const Tensor input, const Tensor mean, const Tensor var,
+    Tensor running_mean, Tensor running_var, const Tensor weight,
+    const Tensor bias, Tensor norm, Tensor std, Tensor output, float eps,
+    float momentum, int group_size);
+
+void SyncBNBackwardParamMUSAKernelLauncher(const Tensor grad_output,
+                                           const Tensor norm,
+                                           Tensor grad_weight,
+                                           Tensor grad_bias);
+
+void SyncBNBackwardDataMUSAKernelLauncher(const Tensor grad_output,
+                                          const Tensor weight,
+                                          const Tensor grad_weight,
+                                          const Tensor grad_bias,
+                                          const Tensor norm, const Tensor std,
+                                          Tensor grad_input);
+
+void sync_bn_forward_mean_musa(const Tensor input, Tensor mean) {
+  SyncBNForwardMeanMUSAKernelLauncher(input, mean);
+}
+
+void sync_bn_forward_var_musa(const Tensor input, const Tensor mean,
+                              Tensor var) {
+  SyncBNForwardVarMUSAKernelLauncher(input, mean, var);
+}
+
+void sync_bn_forward_output_musa(const Tensor input, const Tensor mean,
+                                 const Tensor var, Tensor running_mean,
+                                 Tensor running_var, const Tensor weight,
+                                 const Tensor bias, Tensor norm, Tensor std,
+                                 Tensor output, float eps, float momentum,
+                                 int group_size) {
+  SyncBNForwardOutputMUSAKernelLauncher(input, mean, var, running_mean,
+                                        running_var, weight, bias, norm, std,
+                                        output, eps, momentum, group_size);
+}
+
+void sync_bn_backward_param_musa(const Tensor grad_output, const Tensor norm,
+                                 Tensor grad_weight, Tensor grad_bias) {
+  SyncBNBackwardParamMUSAKernelLauncher(grad_output, norm, grad_weight,
+                                        grad_bias);
+}
+
+void sync_bn_backward_data_musa(const Tensor grad_output, const Tensor weight,
+                                const Tensor grad_weight,
+                                const Tensor grad_bias, const Tensor norm,
+                                const Tensor std, Tensor grad_input) {
+  SyncBNBackwardDataMUSAKernelLauncher(grad_output, weight, grad_weight,
+                                       grad_bias, norm, std, grad_input);
+}
+
+void sync_bn_forward_mean_impl(const Tensor input, Tensor mean);
+
+void sync_bn_forward_var_impl(const Tensor input, const Tensor mean,
+                              Tensor var);
+
+void sync_bn_forward_output_impl(const Tensor input, const Tensor mean,
+                                 const Tensor var, Tensor running_mean,
+                                 Tensor running_var, const Tensor weight,
+                                 const Tensor bias, Tensor norm, Tensor std,
+                                 Tensor output, float eps, float momentum,
+                                 int group_size);
+
+void sync_bn_backward_param_impl(const Tensor grad_output, const Tensor norm,
+                                 Tensor grad_weight, Tensor grad_bias);
+
+void sync_bn_backward_data_impl(const Tensor grad_output, const Tensor weight,
+                                const Tensor grad_weight,
+                                const Tensor grad_bias, const Tensor norm,
+                                const Tensor std, Tensor grad_input);
+
+REGISTER_DEVICE_IMPL(sync_bn_forward_mean_impl, MUSA,
+                     sync_bn_forward_mean_musa);
+REGISTER_DEVICE_IMPL(sync_bn_forward_var_impl, MUSA, sync_bn_forward_var_musa);
+REGISTER_DEVICE_IMPL(sync_bn_forward_output_impl, MUSA,
+                     sync_bn_forward_output_musa);
+REGISTER_DEVICE_IMPL(sync_bn_backward_param_impl, MUSA,
+                     sync_bn_backward_param_musa);
+REGISTER_DEVICE_IMPL(sync_bn_backward_data_impl, MUSA,
+                     sync_bn_backward_data_musa);
+
+void ThreeInterpolateForwardMUSAKernelLauncher(int b, int c, int m, int n,
+                                               const Tensor points,
+                                               const Tensor idx,
+                                               const Tensor weight, Tensor out);
+
+void ThreeInterpolateBackwardMUSAKernelLauncher(int b, int c, int n, int m,
+                                                const Tensor grad_out,
+                                                const Tensor idx,
+                                                const Tensor weight,
+                                                Tensor grad_points);
+
+void three_interpolate_forward_musa(int b, int c, int m, int n,
+                                    const Tensor points, const Tensor idx,
+                                    const Tensor weight, Tensor out) {
+  ThreeInterpolateForwardMUSAKernelLauncher(b, c, m, n, points, idx, weight,
+                                            out);
+}
+
+void three_interpolate_backward_musa(int b, int c, int n, int m,
+                                     const Tensor grad_out, const Tensor idx,
+                                     const Tensor weight, Tensor grad_points) {
+  ThreeInterpolateBackwardMUSAKernelLauncher(b, c, n, m, grad_out, idx, weight,
+                                             grad_points);
+}
+
+void three_interpolate_forward_impl(int b, int c, int m, int n,
+                                    const Tensor points, const Tensor idx,
+                                    const Tensor weight, Tensor out);
+
+void three_interpolate_backward_impl(int b, int c, int n, int m,
+                                     const Tensor grad_out, const Tensor idx,
+                                     const Tensor weight, Tensor grad_points);
+REGISTER_DEVICE_IMPL(three_interpolate_forward_impl, MUSA,
+                     three_interpolate_forward_musa);
+REGISTER_DEVICE_IMPL(three_interpolate_backward_impl, MUSA,
+                     three_interpolate_backward_musa);
+
+void ThreeNNForwardMUSAKernelLauncher(int b, int n, int m, const Tensor unknown,
+                                      const Tensor known, Tensor dist2,
+                                      Tensor idx);
+
+void three_nn_forward_musa(int b, int n, int m, const Tensor unknown,
+                           const Tensor known, Tensor dist2, Tensor idx) {
+  ThreeNNForwardMUSAKernelLauncher(b, n, m, unknown, known, dist2, idx);
+}
+
+void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
+                           const Tensor known, Tensor dist2, Tensor idx);
+REGISTER_DEVICE_IMPL(three_nn_forward_impl, MUSA, three_nn_forward_musa);
+
+void TINShiftForwardMUSAKernelLauncher(Tensor input, Tensor shift,
+                                       Tensor output);
+
+void TINShiftBackwardMUSAKernelLauncher(Tensor grad_output, Tensor shift,
+                                        Tensor grad_input);
+
+void tin_shift_forward_musa(Tensor input, Tensor shift, Tensor output) {
+  TINShiftForwardMUSAKernelLauncher(input, shift, output);
+}
+
+void tin_shift_backward_musa(Tensor grad_output, Tensor shift,
+                             Tensor grad_input) {
+  TINShiftBackwardMUSAKernelLauncher(grad_output, shift, grad_input);
+}
+
+void tin_shift_forward_impl(Tensor input, Tensor shift, Tensor output);
+void tin_shift_backward_impl(Tensor grad_output, Tensor shift,
+                             Tensor grad_input);
+REGISTER_DEVICE_IMPL(tin_shift_forward_impl, MUSA, tin_shift_forward_musa);
+REGISTER_DEVICE_IMPL(tin_shift_backward_impl, MUSA, tin_shift_backward_musa);
+
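+// upfirdn2d is only registered when MUSA_ARCH is undefined (host-side pass)
+// or greater than 21; older MUSA architectures get no MUSA implementation
+// for this op.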
+#if ((!defined(MUSA_ARCH)) || (defined(MUSA_ARCH)) && (MUSA_ARCH > 21))
+torch::Tensor upfirdn2d_op(torch::Tensor input, torch::Tensor filter, int upx,
+                           int upy, int downx, int downy, int padx0, int padx1,
+                           int pady0, int pady1, bool flip, float gain);
+
+torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor filter,
+                                int upx, int upy, int downx, int downy,
+                                int padx0, int padx1, int pady0, int pady1,
+                                bool flip, float gain);
+REGISTER_DEVICE_IMPL(upfirdn2d_op_impl, MUSA, upfirdn2d_op);
+#endif
+
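+// Hard voxelization caps points per voxel and the total voxel count (with a
+// deterministic and a faster nondeterministic variant); dynamic voxelization
+// only computes per-point voxel coordinates and keeps every point.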
+int HardVoxelizeForwardMUSAKernelLauncher(
+    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
+    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
+    const std::vector<float> coors_range, const int max_points,
+    const int max_voxels, const int NDim = 3);
+
+int NondeterministicHardVoxelizeForwardMUSAKernelLauncher(
+    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
+    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
+    const std::vector<float> coors_range, const int max_points,
+    const int max_voxels, const int NDim = 3);
+
+void DynamicVoxelizeForwardMUSAKernelLauncher(
+    const at::Tensor &points, at::Tensor &coors,
+    const std::vector<float> voxel_size, const std::vector<float> coors_range,
+    const int NDim = 3);
+
+int hard_voxelize_forward_musa(const at::Tensor &points, at::Tensor &voxels,
+                               at::Tensor &coors,
+                               at::Tensor &num_points_per_voxel,
+                               const std::vector<float> voxel_size,
+                               const std::vector<float> coors_range,
+                               const int max_points, const int max_voxels,
+                               const int NDim) {
+  return HardVoxelizeForwardMUSAKernelLauncher(
+      points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
+      max_points, max_voxels, NDim);
+}
+
+int nondeterministic_hard_voxelize_forward_musa(
+    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
+    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
+    const std::vector<float> coors_range, const int max_points,
+    const int max_voxels, const int NDim) {
+  return NondeterministicHardVoxelizeForwardMUSAKernelLauncher(
+      points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
+      max_points, max_voxels, NDim);
+}
+
+void dynamic_voxelize_forward_musa(const at::Tensor &points, at::Tensor &coors,
+                                   const std::vector<float> voxel_size,
+                                   const std::vector<float> coors_range,
+                                   const int NDim) {
+  DynamicVoxelizeForwardMUSAKernelLauncher(points, coors, voxel_size,
+                                           coors_range, NDim);
+}
+
+int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
+                               at::Tensor &coors,
+                               at::Tensor &num_points_per_voxel,
+                               const std::vector<float> voxel_size,
+                               const std::vector<float> coors_range,
+                               const int max_points, const int max_voxels,
+                               const int NDim);
+
+int nondeterministic_hard_voxelize_forward_impl(
+    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
+    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
+    const std::vector<float> coors_range, const int max_points,
+    const int max_voxels, const int NDim);
+
+void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors,
+                                   const std::vector<float> voxel_size,
+                                   const std::vector<float> coors_range,
+                                   const int NDim);
+
+REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, MUSA,
+                     hard_voxelize_forward_musa);
+REGISTER_DEVICE_IMPL(nondeterministic_hard_voxelize_forward_impl, MUSA,
+                     nondeterministic_hard_voxelize_forward_musa);
+REGISTER_DEVICE_IMPL(dynamic_voxelize_forward_impl, MUSA,
+                     dynamic_voxelize_forward_musa);
+
+void RotatedFeatureAlignForwardMUSAKernelLauncher(const Tensor features,
+                                                  const Tensor best_bboxes,
+                                                  const float spatial_scale,
+                                                  const int points,
+                                                  Tensor output);
+
+void RotatedFeatureAlignBackwardMUSAKernelLauncher(const Tensor top_grad,
+                                                   const Tensor best_bboxes,
+                                                   const float spatial_scale,
+                                                   const int points,
+                                                   Tensor bottom_grad);
+
+void rotated_feature_align_forward_musa(const Tensor features,
+                                        const Tensor best_bboxes,
+                                        const float spatial_scale,
+                                        const int points, Tensor output) {
+  RotatedFeatureAlignForwardMUSAKernelLauncher(features, best_bboxes,
+                                               spatial_scale, points, output);
+}
+
+void rotated_feature_align_backward_musa(const Tensor top_grad,
+                                         const Tensor best_bboxes,
+                                         const float spatial_scale,
+                                         const int points, Tensor bottom_grad) {
+  RotatedFeatureAlignBackwardMUSAKernelLauncher(
+      top_grad, best_bboxes, spatial_scale, points, bottom_grad);
+}
+
+void rotated_feature_align_forward_impl(const Tensor features,
+                                        const Tensor best_bboxes,
+                                        const float spatial_scale,
+                                        const int points, Tensor output);
+
+void rotated_feature_align_backward_impl(const Tensor top_grad,
+                                         const Tensor best_bboxes,
+                                         const float spatial_scale,
+                                         const int points, Tensor bottom_grad);
+
+REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, MUSA,
+                     rotated_feature_align_forward_musa);
+REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, MUSA,
+                     rotated_feature_align_backward_musa);
+
+void PointsInPolygonsForwardMUSAKernelLauncher(const at::Tensor points,
+                                               const at::Tensor polygons,
+                                               const int rows, const int cols,
+                                               at::Tensor output);
+
+void points_in_polygons_forward_musa(const Tensor points, const Tensor polygons,
+                                     Tensor output, const int rows,
+                                     const int cols) {
+  PointsInPolygonsForwardMUSAKernelLauncher(points, polygons, rows, cols,
+                                            output);
+}
+
+void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
+                                     Tensor output, const int rows,
+                                     const int cols);
+
+REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, MUSA,
+                     points_in_polygons_forward_musa);
+
+torch::Tensor IndiceMaxpoolForwardMUSAKernelLauncher(torch::Tensor features,
+                                                     torch::Tensor indicePairs,
+                                                     torch::Tensor indiceNum,
+                                                     int64_t numAct);
+
+torch::Tensor indice_maxpool_forward_musa(torch::Tensor features,
+                                          torch::Tensor indicePairs,
+                                          torch::Tensor indiceNum,
+                                          int64_t numAct) {
+  return IndiceMaxpoolForwardMUSAKernelLauncher(features, indicePairs,
+                                                indiceNum, numAct);
+}
+
+torch::Tensor indice_maxpool_forward_impl(torch::Tensor features,
+                                          torch::Tensor indicePairs,
+                                          torch::Tensor indiceNum,
+                                          int64_t numAct);
+REGISTER_DEVICE_IMPL(indice_maxpool_forward_impl, MUSA,
+                     indice_maxpool_forward_musa);
+
+torch::Tensor IndiceMaxpoolBackwardMUSAKernelLauncher(torch::Tensor features,
+                                                      torch::Tensor outFeatures,
+                                                      torch::Tensor outGrad,
+                                                      torch::Tensor indicePairs,
+                                                      torch::Tensor indiceNum);
+
+torch::Tensor indice_maxpool_backward_musa(torch::Tensor features,
+                                           torch::Tensor outFeatures,
+                                           torch::Tensor outGrad,
+                                           torch::Tensor indicePairs,
+                                           torch::Tensor indiceNum) {
+  return IndiceMaxpoolBackwardMUSAKernelLauncher(features, outFeatures, outGrad,
+                                                 indicePairs, indiceNum);
+}
+
+torch::Tensor indice_maxpool_backward_impl(torch::Tensor features,
+                                           torch::Tensor outFeatures,
+                                           torch::Tensor outGrad,
+                                           torch::Tensor indicePairs,
+                                           torch::Tensor indiceNum);
+
+REGISTER_DEVICE_IMPL(indice_maxpool_backward_impl, MUSA,
+                     indice_maxpool_backward_musa);
+
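+// _inverse and _subM follow spconv's convention, flagging inverse sparse
+// convolution and submanifold sparse convolution respectively.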
+torch::Tensor IndiceConvForwardMUSAKernelLauncher(
+    torch::Tensor features, torch::Tensor filters, torch::Tensor indicePairs,
+    torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse,
+    int64_t _subM);
+
+torch::Tensor indice_conv_forward_musa(torch::Tensor features,
+                                       torch::Tensor filters,
+                                       torch::Tensor indicePairs,
+                                       torch::Tensor indiceNum,
+                                       int64_t numActOut, int64_t _inverse,
+                                       int64_t _subM) {
+  return IndiceConvForwardMUSAKernelLauncher(
+      features, filters, indicePairs, indiceNum, numActOut, _inverse, _subM);
+}
+
+torch::Tensor indice_conv_forward_impl(torch::Tensor features,
+                                       torch::Tensor filters,
+                                       torch::Tensor indicePairs,
+                                       torch::Tensor indiceNum,
+                                       int64_t numActOut, int64_t _inverse,
+                                       int64_t _subM);
+
+REGISTER_DEVICE_IMPL(indice_conv_forward_impl, MUSA, indice_conv_forward_musa);
+
+std::vector<torch::Tensor> IndiceConvBackwardMUSAKernelLauncher(
+    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
+    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
+    int64_t _subM);
+
+std::vector<torch::Tensor> indice_conv_backward_musa(
+    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
+    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
+    int64_t _subM) {
+  return IndiceConvBackwardMUSAKernelLauncher(
+      features, filters, outGrad, indicePairs, indiceNum, _inverse, _subM);
+}
+
+std::vector<torch::Tensor> indice_conv_backward_impl(
+    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
+    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
+    int64_t _subM);
+
+REGISTER_DEVICE_IMPL(indice_conv_backward_impl, MUSA,
+                     indice_conv_backward_musa);
+
+torch::Tensor FusedIndiceConvBatchnormMUSAKernelLauncher(
+    torch::Tensor features, torch::Tensor filters, torch::Tensor bias,
+    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut,
+    int64_t _inverse, int64_t _subM);
+
+torch::Tensor fused_indice_conv_batchnorm_forward_musa(
+    torch::Tensor features, torch::Tensor filters, torch::Tensor bias,
+    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut,
+    int64_t _inverse, int64_t _subM) {
+  return FusedIndiceConvBatchnormMUSAKernelLauncher(features, filters, bias,
+                                                    indicePairs, indiceNum,
+                                                    numActOut, _inverse, _subM);
+}
+
+torch::Tensor fused_indice_conv_batchnorm_forward_impl(
+    torch::Tensor features, torch::Tensor filters, torch::Tensor bias,
+    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut,
+    int64_t _inverse, int64_t _subM);
+
+REGISTER_DEVICE_IMPL(fused_indice_conv_batchnorm_forward_impl, MUSA,
+                     fused_indice_conv_batchnorm_forward_musa);
+
 void MinAreaPolygonsMUSAKernelLauncher(const Tensor pointsets, Tensor polygons);
 
 void min_area_polygons_musa(const Tensor pointsets, Tensor polygons) {
@@ -990,6 +1849,57 @@ REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, MUSA,
                      chamfer_distance_backward_musa);
 #endif
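+// Precise RoI Pooling is also differentiable w.r.t. the RoI coordinates,
+// hence the third launcher for the coordinate gradient.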
 
+void PrROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois,
+                                        Tensor output, int pooled_height,
+                                        int pooled_width, float spatial_scale);
+
+void PrROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
+                                         Tensor grad_input, int pooled_height,
+                                         int pooled_width, float spatial_scale);
+
+void PrROIPoolCoorBackwardMUSAKernelLauncher(
+    Tensor output, Tensor grad_output, Tensor input, Tensor rois,
+    Tensor grad_rois, int pooled_height, int pooled_width, float spatial_scale);
+
+void prroi_pool_forward_musa(Tensor input, Tensor rois, Tensor output,
+                             int pooled_height, int pooled_width,
+                             float spatial_scale) {
+  PrROIPoolForwardMUSAKernelLauncher(input, rois, output, pooled_height,
+                                     pooled_width, spatial_scale);
+}
+
+void prroi_pool_backward_musa(Tensor grad_output, Tensor rois,
+                              Tensor grad_input, int pooled_height,
+                              int pooled_width, float spatial_scale) {
+  PrROIPoolBackwardMUSAKernelLauncher(grad_output, rois, grad_input,
+                                      pooled_height, pooled_width,
+                                      spatial_scale);
+}
+
+void prroi_pool_coor_backward_musa(Tensor output, Tensor grad_output,
+                                   Tensor input, Tensor rois, Tensor grad_rois,
+                                   int pooled_height, int pooled_width,
+                                   float spatial_scale) {
+  PrROIPoolCoorBackwardMUSAKernelLauncher(output, grad_output, input, rois,
+                                          grad_rois, pooled_height,
+                                          pooled_width, spatial_scale);
+}
+
+void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
+                             int pooled_height, int pooled_width,
+                             float spatial_scale);
+void prroi_pool_backward_impl(Tensor grad_output, Tensor rois,
+                              Tensor grad_input, int pooled_height,
+                              int pooled_width, float spatial_scale);
+void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output,
+                                   Tensor input, Tensor rois, Tensor grad_rois,
+                                   int pooled_height, int pooled_width,
+                                   float spatial_scale);
+REGISTER_DEVICE_IMPL(prroi_pool_forward_impl, MUSA, prroi_pool_forward_musa);
+REGISTER_DEVICE_IMPL(prroi_pool_backward_impl, MUSA, prroi_pool_backward_musa);
+REGISTER_DEVICE_IMPL(prroi_pool_coor_backward_impl, MUSA,
+                     prroi_pool_coor_backward_musa);
+
 void BezierAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois,
                                           Tensor output, int aligned_height,
                                           int aligned_width,
diff --git a/mmcv/ops/csrc/pytorch/musa/points_in_boxes_musa.mu b/mmcv/ops/csrc/pytorch/musa/points_in_boxes_musa.mu
new file mode 100644
index 000000000..5330aeaac
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/points_in_boxes_musa.mu
@@ -0,0 +1,62 @@
+// Modified from
+// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
+// Written by Shaoshuai Shi
+// All Rights Reserved 2019.
+
+#include <stdio.h>
+
+#include "points_in_boxes_musa_kernel.muh"
+#include "pytorch_musa_helper.hpp"
+
+void PointsInBoxesPartForwardMUSAKernelLauncher(int batch_size, int boxes_num,
+                                                int pts_num, const Tensor boxes,
+                                                const Tensor pts,
+                                                Tensor box_idx_of_points) {
+  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
+  // coordinate; z is the bottom center and the boxes must not overlap
+  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
+  // params box_idx_of_points: (B, npoints), default -1
+
+  c10::musa::MUSAGuard device_guard(boxes.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
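+  // 2D launch grid: blockIdx.x tiles the points of one sample,
+  // blockIdx.y indexes the batch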
+  dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size);
+  dim3 threads(THREADS_PER_BLOCK);
+
+  AT_DISPATCH_FLOATING_TYPES(
+      boxes.scalar_type(), "points_in_boxes_part_forward_musa_kernel", [&] {
+        points_in_boxes_part_forward_musa_kernel<scalar_t>
+            <<<blocks, threads, 0, stream>>>(
+                batch_size, boxes_num, pts_num, boxes.data_ptr<scalar_t>(),
+                pts.data_ptr<scalar_t>(), box_idx_of_points.data_ptr<int>());
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void PointsInBoxesAllForwardMUSAKernelLauncher(int batch_size, int boxes_num,
+                                               int pts_num, const Tensor boxes,
+                                               const Tensor pts,
+                                               Tensor box_idx_of_points) {
+  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
+  // coordinate; z is the bottom center
+  // params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
+  // params box_idx_of_points: (B, npoints), default -1
+
+  c10::musa::MUSAGuard device_guard(boxes.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size);
+  dim3 threads(THREADS_PER_BLOCK);
+
+  AT_DISPATCH_FLOATING_TYPES(
+      boxes.scalar_type(), "points_in_boxes_all_forward_musa_kernel", [&] {
+        points_in_boxes_all_forward_musa_kernel<scalar_t>
+            <<<blocks, threads, 0, stream>>>(
+                batch_size, boxes_num, pts_num, boxes.data_ptr<scalar_t>(),
+                pts.data_ptr<scalar_t>(), box_idx_of_points.data_ptr<int>());
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/points_in_polygons_musa.mu b/mmcv/ops/csrc/pytorch/musa/points_in_polygons_musa.mu
new file mode 100644
index 000000000..307cb38ea
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/points_in_polygons_musa.mu
@@ -0,0 +1,28 @@
+// Copyright (c) OpenMMLab. All rights reserved
+// Modified from
+// https://github.com/ming71/MUSA/blob/master/point_justify/points_justify_kernel.cu
+
+#include <stdio.h>
+
+#include "points_in_polygons_musa_kernel.muh"
+#include "pytorch_musa_helper.hpp"
+
+void PointsInPolygonsForwardMUSAKernelLauncher(const at::Tensor points,
+                                               const at::Tensor polygons,
+                                               const int rows, const int cols,
+                                               at::Tensor output) {
+  const int output_size = rows * cols;
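+  // one thread per (point, polygon) pair: rows points are tested against
+  // cols polygons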
+  c10::musa::MUSAGuard device_guard(points.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES(
+      points.scalar_type(), "points_in_polygons_forward_musa_kernel", ([&] {
+        const scalar_t *vertex1 = points.data_ptr<scalar_t>();
+        const scalar_t *vertex2 = polygons.data_ptr<scalar_t>();
+        scalar_t *inside_flag = output.data_ptr<scalar_t>();
+
+        points_in_polygons_forward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, vertex1, vertex2, rows, cols, inside_flag);
+      }));
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/prroi_pool_musa.mu b/mmcv/ops/csrc/pytorch/musa/prroi_pool_musa.mu
new file mode 100644
index 000000000..fb7131776
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/prroi_pool_musa.mu
@@ -0,0 +1,65 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#include "prroi_pool_musa_kernel.muh"
+#include "pytorch_musa_helper.hpp"
+
+void PrROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois,
+                                        Tensor output, int pooled_height,
+                                        int pooled_width, float spatial_scale) {
+  int output_size = output.numel();
+  int channels = input.size(1);
+  int height = input.size(2);
+  int width = input.size(3);
+
+  c10::musa::MUSAGuard device_guard(input.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
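+  // The PrRoI Pooling kernels are instantiated for float only (no dispatch
+  // macro), matching the upstream CUDA implementation, so float32 inputs are
+  // expected.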
+  prroi_pool_forward_musa_kernel<float>
+      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+          output_size, input.data_ptr<float>(), rois.data_ptr<float>(),
+          output.data_ptr<float>(), pooled_height, pooled_width,
+          static_cast<float>(spatial_scale), channels, height, width);
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void PrROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
+                                         Tensor grad_input, int pooled_height,
+                                         int pooled_width,
+                                         float spatial_scale) {
+  int output_size = grad_output.numel();
+  int channels = grad_input.size(1);
+  int height = grad_input.size(2);
+  int width = grad_input.size(3);
+
+  c10::musa::MUSAGuard device_guard(grad_output.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  prroi_pool_backward_musa_kernel<float>
+      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+          output_size, grad_output.data_ptr<float>(), rois.data_ptr<float>(),
+          grad_input.data_ptr<float>(), pooled_height, pooled_width,
+          static_cast<float>(spatial_scale), channels, height, width);
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void PrROIPoolCoorBackwardMUSAKernelLauncher(Tensor output, Tensor grad_output,
+                                             Tensor input, Tensor rois,
+                                             Tensor grad_rois,
+                                             int pooled_height,
+                                             int pooled_width,
+                                             float spatial_scale) {
+  int output_size = grad_output.numel();
+  int channels = input.size(1);
+  int height = input.size(2);
+  int width = input.size(3);
+
+  c10::musa::MUSAGuard device_guard(grad_output.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  prroi_pool_coor_backward_musa_kernel<float>
+      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+          output_size, output.data_ptr<float>(), grad_output.data_ptr<float>(),
+          input.data_ptr<float>(), rois.data_ptr<float>(),
+          grad_rois.data_ptr<float>(), pooled_height, pooled_width,
+          static_cast<float>(spatial_scale), channels, height, width);
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/psamask_musa.mu b/mmcv/ops/csrc/pytorch/musa/psamask_musa.mu
new file mode 100644
index 000000000..d432954fa
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/psamask_musa.mu
@@ -0,0 +1,60 @@
+// Copyright (c) OpenMMLab. All rights reserved
+// Modified from
+// https://github.com/hszhao/semseg/blob/master/lib/psa/src
+
+#include <torch/serialize/tensor.h>
+
+#include "psamask_musa_kernel.muh"
+#include "pytorch_musa_helper.hpp"
+
+void PSAMaskForwardMUSAKernelLauncher(const int psa_type, const Tensor input,
+                                      Tensor output, const int num_,
+                                      const int h_feature, const int w_feature,
+                                      const int h_mask, const int w_mask,
+                                      const int half_h_mask,
+                                      const int half_w_mask) {
+  int nthreads = num_ * h_feature * w_feature;
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
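+  // psa_type == 0 launches the COLLECT kernels; any other value launches the
+  // DISTRIBUTE kernels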
+  if (psa_type == 0)
+    AT_DISPATCH_FLOATING_TYPES(
+        input.scalar_type(), "psamask_collect_forward_musa", [&] {
+          psamask_collect_forward_musa<scalar_t><<<nthreads, 512, 0, stream>>>(
+              nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
+              half_w_mask, input.data_ptr<scalar_t>(),
+              output.data_ptr<scalar_t>());
+        });
+  else
+    AT_DISPATCH_FLOATING_TYPES(
+        input.scalar_type(), "psamask_distribute_forward_musa", [&] {
+          psamask_distribute_forward_musa<scalar_t>
+              <<<nthreads, 512, 0, stream>>>(
+                  nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
+                  half_w_mask, input.data_ptr<scalar_t>(),
+                  output.data_ptr<scalar_t>());
+        });
+}
+
+void PSAMaskBackwardMUSAKernelLauncher(
+    const int psa_type, const Tensor grad_output, Tensor grad_input,
+    const int num_, const int h_feature, const int w_feature, const int h_mask,
+    const int w_mask, const int half_h_mask, const int half_w_mask) {
+  int nthreads = num_ * h_feature * w_feature;
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  if (psa_type == 0)
+    AT_DISPATCH_FLOATING_TYPES(
+        grad_input.scalar_type(), "psamask_collect_backward_musa", [&] {
+          psamask_collect_backward_musa<scalar_t><<<nthreads, 512, 0, stream>>>(
+              nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
+              half_w_mask, grad_output.data_ptr<scalar_t>(),
+              grad_input.data_ptr<scalar_t>());
+        });
+  else
+    AT_DISPATCH_FLOATING_TYPES(
+        grad_input.scalar_type(), "psamask_distribute_backward_musa", [&] {
+          psamask_distribute_backward_musa<scalar_t>
+              <<<nthreads, 512, 0, stream>>>(
+                  nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask,
+                  half_w_mask, grad_output.data_ptr<scalar_t>(),
+                  grad_input.data_ptr<scalar_t>());
+        });
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/riroi_align_rotated_musa.mu b/mmcv/ops/csrc/pytorch/musa/riroi_align_rotated_musa.mu
new file mode 100644
index 000000000..575071e33
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/riroi_align_rotated_musa.mu
@@ -0,0 +1,53 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#include "pytorch_musa_helper.hpp"
+#include "riroi_align_rotated_musa_kernel.muh"
+
+void RiROIAlignRotatedForwardMUSAKernelLauncher(
+    const at::Tensor features, const at::Tensor rois, const float spatial_scale,
+    const int num_samples, const bool clockwise, const int channels,
+    const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, const int num_orientations,
+    at::Tensor output) {
+  const int output_size =
+      num_rois * pooled_height * pooled_width * channels * num_orientations;
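+  // one thread per pooled output element, across all RoIs, channels and
+  // orientations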
+  c10::musa::MUSAGuard device_guard(features.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES(
+      features.scalar_type(), "riroi_align_rotated_forward_musa_kernel", ([&] {
+        const scalar_t *bottom_data = features.data_ptr<scalar_t>();
+        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
+        scalar_t *top_data = output.data_ptr<scalar_t>();
+
+        riroi_align_rotated_forward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, bottom_data, rois_data, scalar_t(spatial_scale),
+                num_samples, clockwise, channels, height, width, pooled_height,
+                pooled_width, num_orientations, top_data);
+      }));
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void RiROIAlignRotatedBackwardMUSAKernelLauncher(
+    const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
+    const int num_samples, const bool clockwise, const int channels,
+    const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, const int num_orientations,
+    at::Tensor bottom_grad) {
+  const int output_size =
+      num_rois * pooled_height * pooled_width * channels * num_orientations;
+  c10::musa::MUSAGuard device_guard(top_grad.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES(
+      top_grad.scalar_type(), "riroi_align_rotated_backward_musa_kernel", ([&] {
+        const scalar_t *top_diff = top_grad.data_ptr<scalar_t>();
+        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
+        scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
+        riroi_align_rotated_backward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, top_diff, rois_data, spatial_scale, num_samples,
+                clockwise, channels, height, width, pooled_height, pooled_width,
+                num_orientations, bottom_diff);
+      }));
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/roi_align_musa.mu b/mmcv/ops/csrc/pytorch/musa/roi_align_musa.mu
new file mode 100644
index 000000000..f44d8e743
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/roi_align_musa.mu
@@ -0,0 +1,58 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#include "pytorch_musa_helper.hpp"
+#include "roi_align_musa_kernel.muh"
+
+void ROIAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
+                                       Tensor argmax_y, Tensor argmax_x,
+                                       int aligned_height, int aligned_width,
+                                       float spatial_scale, int sampling_ratio,
+                                       int pool_mode, bool aligned) {
+  int output_size = output.numel();
+  int channels = input.size(1);
+  int height = input.size(2);
+  int width = input.size(3);
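+  // argmax_y / argmax_x cache the bilinear sampling position of the maximum
+  // when pooling in max mode, so the backward kernel can route gradients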
+
+  c10::musa::MUSAGuard device_guard(input.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.scalar_type(), "roi_align_forward_musa_kernel", [&] {
+        roi_align_forward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, input.data_ptr<scalar_t>(),
+                rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
+                argmax_y.data_ptr<scalar_t>(), argmax_x.data_ptr<scalar_t>(),
+                aligned_height, aligned_width,
+                static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
+                aligned, channels, height, width);
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void ROIAlignBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
+                                        Tensor argmax_y, Tensor argmax_x,
+                                        Tensor grad_input, int aligned_height,
+                                        int aligned_width, float spatial_scale,
+                                        int sampling_ratio, int pool_mode,
+                                        bool aligned) {
+  int output_size = grad_output.numel();
+  int channels = grad_input.size(1);
+  int height = grad_input.size(2);
+  int width = grad_input.size(3);
+
+  c10::musa::MUSAGuard device_guard(grad_output.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_output.scalar_type(), "roi_align_backward_musa_kernel", [&] {
+        roi_align_backward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, grad_output.data_ptr<scalar_t>(),
+                rois.data_ptr<scalar_t>(), argmax_y.data_ptr<scalar_t>(),
+                argmax_x.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
+                aligned_height, aligned_width,
+                static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
+                aligned, channels, height, width);
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/roi_align_rotated_musa.mu b/mmcv/ops/csrc/pytorch/musa/roi_align_rotated_musa.mu
new file mode 100644
index 000000000..793744e24
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/roi_align_rotated_musa.mu
@@ -0,0 +1,45 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#include "pytorch_musa_helper.hpp"
+#include "roi_align_rotated_musa_kernel.muh"
+
+void ROIAlignRotatedForwardMUSAKernelLauncher(
+    const at::Tensor input, const at::Tensor rois, const float spatial_scale,
+    const int sampling_ratio, const bool aligned, const bool clockwise,
+    const int channels, const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, at::Tensor output) {
+  const int output_size = num_rois * pooled_height * pooled_width * channels;
+  // run on the current MUSA stream, consistent with the other launchers
+  c10::musa::MUSAGuard device_guard(input.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.scalar_type(), "ROIAlignRotatedLauncherForward", ([&] {
+        const scalar_t *bottom_data = input.data_ptr<scalar_t>();
+        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
+        scalar_t *top_data = output.data_ptr<scalar_t>();
+
+        roi_align_rotated_forward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, bottom_data, rois_data, scalar_t(spatial_scale),
+                sampling_ratio, aligned, clockwise, channels, height, width,
+                pooled_height, pooled_width, top_data);
+      }));
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void ROIAlignRotatedBackwardMUSAKernelLauncher(
+    const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
+    const int sampling_ratio, const bool aligned, const bool clockwise,
+    const int channels, const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, at::Tensor bottom_grad) {
+  const int output_size = num_rois * pooled_height * pooled_width * channels;
+  c10::musa::MUSAGuard device_guard(top_grad.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      top_grad.scalar_type(), "ROIAlignRotatedLauncherBackward", ([&] {
+        const scalar_t *top_diff = top_grad.data_ptr<scalar_t>();
+        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
+        scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
+        roi_align_rotated_backward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, top_diff, rois_data, spatial_scale, sampling_ratio,
+                aligned, clockwise, channels, height, width, pooled_height,
+                pooled_width, bottom_diff);
+      }));
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/roi_pool_musa.mu b/mmcv/ops/csrc/pytorch/musa/roi_pool_musa.mu
new file mode 100644
index 000000000..f6cb3b699
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/roi_pool_musa.mu
@@ -0,0 +1,50 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#include "pytorch_musa_helper.hpp"
+#include "roi_pool_musa_kernel.muh"
+
+void ROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
+                                      Tensor argmax, int pooled_height,
+                                      int pooled_width, float spatial_scale) {
+  int output_size = output.numel();
+  int channels = input.size(1);
+  int height = input.size(2);
+  int width = input.size(3);
+
+  c10::musa::MUSAGuard device_guard(input.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.scalar_type(), "roi_pool_forward_musa_kernel", [&] {
+        roi_pool_forward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, input.data_ptr<scalar_t>(),
+                rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
+                argmax.data_ptr<int>(), pooled_height, pooled_width,
+                static_cast<scalar_t>(spatial_scale), channels, height, width);
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void ROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
+                                       Tensor argmax, Tensor grad_input,
+                                       int pooled_height, int pooled_width,
+                                       float spatial_scale) {
+  int output_size = grad_output.numel();
+  int channels = grad_input.size(1);
+  int height = grad_input.size(2);
+  int width = grad_input.size(3);
+
+  c10::musa::MUSAGuard device_guard(grad_output.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_output.scalar_type(), "roi_pool_backward_musa_kernel", [&] {
+        roi_pool_backward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, grad_output.data_ptr<scalar_t>(),
+                rois.data_ptr<scalar_t>(), argmax.data_ptr<int>(),
+                grad_input.data_ptr<scalar_t>(), pooled_height, pooled_width,
+                channels, height, width);
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/roiaware_pool3d_musa.mu b/mmcv/ops/csrc/pytorch/musa/roiaware_pool3d_musa.mu
new file mode 100644
index 000000000..55e283da3
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/roiaware_pool3d_musa.mu
@@ -0,0 +1,118 @@
+// Modified from
+// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
+// Written by Shaoshuai Shi
+// All Rights Reserved 2019.
+
+#include <stdio.h>
+
+#include "pytorch_musa_helper.hpp"
+#include "roiaware_pool3d_musa_kernel.muh"
+
+void RoiawarePool3dForwardMUSAKernelLauncher(
+    int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
+    int out_y, int out_z, const Tensor rois, const Tensor pts,
+    const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels,
+    Tensor pooled_features, int pool_method) {
+  // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
+  // coordinate
+  // params pts: (npoints, 3) [x, y, z] in LiDAR coordinate
+  // params pts_feature: (npoints, C)
+  // params argmax: (N, out_x, out_y, out_z, C)
+  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
+  // params pooled_features: (N, out_x, out_y, out_z, C)
+  // params pool_method: 0: max_pool, 1: avg_pool
+
+  c10::musa::MUSAGuard device_guard(pts_feature.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  Tensor pts_mask =
+      -at::ones({boxes_num, pts_num}, pts_feature.options().dtype(at::kInt));
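+  // pts_mask starts at -1 (point inside no box); the mask kernel marks points
+  // that fall inside a box, and the collect kernel then builds the per-voxel
+  // point-index lists from it.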
+
+  dim3 blocks_mask(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num);
+  dim3 threads(THREADS_PER_BLOCK);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      rois.scalar_type(), "generate_pts_mask_for_box3d", [&] {
+        generate_pts_mask_for_box3d<scalar_t>
+            <<<blocks_mask, threads, 0, stream>>>(
+                boxes_num, pts_num, out_x, out_y, out_z,
+                rois.data_ptr<scalar_t>(), pts.data_ptr<scalar_t>(),
+                pts_mask.data_ptr<int>());
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+
+  // TODO: Merge the collect and pool functions, SS
+
+  dim3 blocks_collect(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK));
+
+  AT_DISPATCH_INTEGRAL_TYPES(
+      pts_idx_of_voxels.scalar_type(), "collect_inside_pts_for_box3d", [&] {
+        collect_inside_pts_for_box3d<scalar_t>
+            <<<blocks_collect, threads, 0, stream>>>(
+                boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z,
+                pts_mask.data_ptr<int>(),
+                pts_idx_of_voxels.data_ptr<scalar_t>());
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+
+  dim3 blocks_pool(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK),
+                   channels, boxes_num);
+  if (pool_method == 0) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        pts_feature.scalar_type(), "roiaware_maxpool3d", [&] {
+          roiaware_maxpool3d<scalar_t><<<blocks_pool, threads, 0, stream>>>(
+              boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y,
+              out_z, pts_feature.data_ptr<scalar_t>(),
+              pts_idx_of_voxels.data_ptr<int>(),
+              pooled_features.data_ptr<scalar_t>(), argmax.data_ptr<int>());
+        });
+  } else if (pool_method == 1) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        pts_feature.scalar_type(), "roiaware_avgpool3d", [&] {
+          roiaware_avgpool3d<scalar_t><<<blocks_pool, threads, 0, stream>>>(
+              boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y,
+              out_z, pts_feature.data_ptr<scalar_t>(),
+              pts_idx_of_voxels.data_ptr<int>(),
+              pooled_features.data_ptr<scalar_t>());
+        });
+  }
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void RoiawarePool3dBackwardMUSAKernelLauncher(
+    int boxes_num, int out_x, int out_y, int out_z, int channels,
+    int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax,
+    const Tensor grad_out, Tensor grad_in, int pool_method) {
+  // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
+  // params argmax: (N, out_x, out_y, out_z, C)
+  // params grad_out: (N, out_x, out_y, out_z, C)
+  // params grad_in: (npoints, C), return value
+  // params pool_method: 0: max_pool, 1: avg_pool
+
+  c10::musa::MUSAGuard device_guard(grad_out.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  dim3 blocks(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
+              boxes_num);
+  dim3 threads(THREADS_PER_BLOCK);
+
+  if (pool_method == 0) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        grad_in.scalar_type(), "roiaware_maxpool3d_backward", [&] {
+          roiaware_maxpool3d_backward<scalar_t><<<blocks, threads, 0, stream>>>(
+              boxes_num, channels, out_x, out_y, out_z, argmax.data_ptr<int>(),
+              grad_out.data_ptr<scalar_t>(), grad_in.data_ptr<scalar_t>());
+        });
+  } else if (pool_method == 1) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        grad_in.scalar_type(), "roiaware_avgpool3d_backward", [&] {
+          roiaware_avgpool3d_backward<scalar_t><<<blocks, threads, 0, stream>>>(
+              boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel,
+              pts_idx_of_voxels.data_ptr<int>(), grad_out.data_ptr<scalar_t>(),
+              grad_in.data_ptr<scalar_t>());
+        });
+  }
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/roipoint_pool3d_musa.mu b/mmcv/ops/csrc/pytorch/musa/roipoint_pool3d_musa.mu
new file mode 100644
index 000000000..a4c11e7e6
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/roipoint_pool3d_musa.mu
@@ -0,0 +1,60 @@
+/*
+Modified from
+https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu
+Point cloud feature pooling
+Written by Shaoshuai Shi
+All Rights Reserved 2018.
+*/
+
+#include <math.h>
+#include <stdio.h>
+
+#include "pytorch_musa_helper.hpp"
+#include "roipoint_pool3d_musa_kernel.muh"
+
+void RoIPointPool3dForwardMUSAKernelLauncher(
+    int batch_size, int pts_num, int boxes_num, int feature_in_len,
+    int sampled_pts_num, const Tensor xyz, const Tensor boxes3d,
+    const Tensor pts_feature, Tensor pooled_features,
+    Tensor pooled_empty_flag) {
+  Tensor pts_assign = at::empty({batch_size, pts_num, boxes_num},
+                                boxes3d.options().dtype(at::kInt));
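+  // Three stages: assign each point to the boxes containing it, sample
+  // sampled_pts_num point indices per box, then gather xyz + features for the
+  // sampled points.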
+
+  c10::musa::MUSAGuard device_guard(xyz.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  // blockIdx.x(col), blockIdx.y(row)
+  dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size);
+  dim3 threads(THREADS_PER_BLOCK);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      xyz.scalar_type(), "assign_pts_to_box3d", [&] {
+        assign_pts_to_box3d<scalar_t><<<blocks, threads, 0, stream>>>(
+            batch_size, pts_num, boxes_num, xyz.data_ptr<scalar_t>(),
+            boxes3d.data_ptr<scalar_t>(), pts_assign.data_ptr<int>());
+      });
+
+  Tensor pts_idx = at::empty({batch_size, boxes_num, sampled_pts_num},
+                             boxes3d.options().dtype(at::kInt));
+
+  // blockIdx.x(col), blockIdx.y(row)
+  dim3 blocks2(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK), batch_size);
+
+  get_pooled_idx<<<blocks2, threads, 0, stream>>>(
+      batch_size, pts_num, boxes_num, sampled_pts_num,
+      pts_assign.data_ptr<int>(), pts_idx.data_ptr<int>(),
+      pooled_empty_flag.data_ptr<int>());
+
+  dim3 blocks_pool(GET_BLOCKS(sampled_pts_num, THREADS_PER_BLOCK), boxes_num,
+                   batch_size);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      xyz.scalar_type(), "roipoint_pool3d_forward", [&] {
+        roipoint_pool3d_forward<scalar_t><<<blocks_pool, threads, 0, stream>>>(
+            batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
+            xyz.data_ptr<scalar_t>(), pts_idx.data_ptr<int>(),
+            pts_feature.data_ptr<scalar_t>(),
+            pooled_features.data_ptr<scalar_t>(),
+            pooled_empty_flag.data_ptr<int>());
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
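+
+// The launcher above is a three-stage pipeline: (1) assign_pts_to_box3d marks,
+// per batch and box, which points fall inside each 3D box; (2) get_pooled_idx
+// picks sampled_pts_num point indices per box and sets pooled_empty_flag for
+// boxes containing no points; (3) roipoint_pool3d_forward gathers xyz and
+// pts_feature at those indices into pooled_features.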
diff --git a/mmcv/ops/csrc/pytorch/musa/rotated_feature_align_musa.mu b/mmcv/ops/csrc/pytorch/musa/rotated_feature_align_musa.mu
new file mode 100644
index 000000000..dd9ffe6c0
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/rotated_feature_align_musa.mu
@@ -0,0 +1,53 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+// Modified from
+// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu
+#include "pytorch_musa_helper.hpp"
+#include "rotated_feature_align_musa_kernel.muh"
+
+void RotatedFeatureAlignForwardMUSAKernelLauncher(const Tensor features,
+                                                  const Tensor best_bboxes,
+                                                  const float spatial_scale,
+                                                  const int points,
+                                                  Tensor output) {
+  c10::musa::MUSAGuard device_guard(features.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  const int output_size = features.numel();
+  AT_DISPATCH_FLOATING_TYPES(
+      features.scalar_type(), "rotated_feature_align_forward_musa_kernel",
+      ([&] {
+        const scalar_t* bottom_data = features.data_ptr<scalar_t>();
+        const scalar_t* bboxes_data = best_bboxes.data_ptr<scalar_t>();
+        scalar_t* top_data = output.data_ptr<scalar_t>();
+
+        rotated_feature_align_forward_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, points, bottom_data, bboxes_data,
+                scalar_t(spatial_scale), features.size(1), features.size(2),
+                features.size(3), top_data);
+      }));
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void RotatedFeatureAlignBackwardMUSAKernelLauncher(const Tensor top_grad,
+                                                   const Tensor best_bboxes,
+                                                   const float spatial_scale,
+                                                   const int points,
+                                                   Tensor bottom_grad) {
+  c10::musa::MUSAGuard device_guard(top_grad.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  const int output_size = top_grad.numel();
+  AT_DISPATCH_FLOATING_TYPES(
+      top_grad.scalar_type(), "rotated_feature_align_backward_musa_kernel",
+      ([&] {
+        const scalar_t* top_diff = top_grad.data_ptr<scalar_t>();
+        const scalar_t* bboxes_data = best_bboxes.data_ptr<scalar_t>();
+        scalar_t* bottom_diff = bottom_grad.data_ptr<scalar_t>();
+
+        rotated_feature_align_backward_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, points, top_diff, bboxes_data,
+                scalar_t(spatial_scale), top_grad.size(1), top_grad.size(2),
+                top_grad.size(3), bottom_diff);
+      }));
+  AT_MUSA_CHECK(musaGetLastError());
+}
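+
+// rotated_feature_align aligns each feature location with features sampled at
+// `points` locations (1 or 5) of its best rotated box using bilinear
+// interpolation; the backward kernel routes gradients back through the same
+// bilinear weights.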
diff --git a/mmcv/ops/csrc/pytorch/musa/scatter_points_musa.mu b/mmcv/ops/csrc/pytorch/musa/scatter_points_musa.mu
new file mode 100644
index 000000000..1edca61a4
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/scatter_points_musa.mu
@@ -0,0 +1,132 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+#include <stdio.h>
+#include <stdlib.h>
+#include <torch/types.h>
+
+#include "pytorch_musa_helper.hpp"
+#include "scatter_points_musa_kernel.muh"
+
+std::vector<at::Tensor> DynamicPointToVoxelForwardMUSAKernelLauncher(
+    const at::Tensor &feats, const at::Tensor &coors,
+    const reduce_t reduce_type) {
+  const int num_input = feats.size(0);
+  const int num_feats = feats.size(1);
+
+  if (num_input == 0)
+    return {feats.clone().detach(), coors.clone().detach(),
+            coors.new_empty({0}, torch::kInt32),
+            coors.new_empty({0}, torch::kInt32)};
+
+  at::Tensor out_coors;
+  at::Tensor coors_map;
+  at::Tensor reduce_count;
+
+  auto coors_clean = coors.masked_fill(coors.lt(0).any(-1, true), -1);
+
+  std::tie(out_coors, coors_map, reduce_count) =
+      at::unique_dim(coors_clean, 0, true, true, true);
+
+  if (out_coors[0][0].lt(0).item<bool>()) {
+    // the first element of out_coors is (-1, -1, -1) and should be removed
+    out_coors = out_coors.slice(0, 1);
+    reduce_count = reduce_count.slice(0, 1);
+    coors_map = coors_map - 1;
+  }
+
+  coors_map = coors_map.to(torch::kInt32);
+  reduce_count = reduce_count.to(torch::kInt32);
+
+  auto reduced_feats =
+      at::empty({out_coors.size(0), num_feats}, feats.options());
+
+  c10::musa::MUSAGuard device_guard(feats.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  AT_DISPATCH_FLOATING_TYPES(
+      feats.scalar_type(), "feats_reduce_kernel", ([&] {
+        if (reduce_type == reduce_t::MAX)
+          reduced_feats.fill_(-std::numeric_limits<scalar_t>::infinity());
+        else
+          reduced_feats.fill_(static_cast<scalar_t>(0));
+
+        dim3 blocks(std::min(
+            at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
+        dim3 threads(THREADS_PER_BLOCK);
+        feats_reduce_kernel<<<blocks, threads, 0, stream>>>(
+            feats.data_ptr<scalar_t>(), coors_map.data_ptr<int32_t>(),
+            reduced_feats.data_ptr<scalar_t>(), num_input, num_feats,
+            reduce_type);
+        if (reduce_type == reduce_t::MEAN)
+          reduced_feats /= reduce_count.unsqueeze(-1).to(reduced_feats.dtype());
+      }));
+
+  AT_MUSA_CHECK(musaGetLastError());
+
+  return {reduced_feats, out_coors, coors_map, reduce_count};
+}
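+
+// Semantics of the forward launcher (summarizing the code above):
+//   reduced_feats[i, :] = reduce_{j : coors_map[j] == i} feats[j, :]
+// with reduce in {max, sum, mean}; for MEAN the kernel accumulates a sum and
+// the host divides by reduce_count afterwards. Voxels with any negative
+// coordinate are collapsed to -1 and dropped from out_coors.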
+
+void DynamicPointToVoxelBackwardMUSAKernelLauncher(
+    at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats,
+    const at::Tensor &feats, const at::Tensor &reduced_feats,
+    const at::Tensor &coors_map, const at::Tensor &reduce_count,
+    const reduce_t reduce_type) {
+  const int num_input = feats.size(0);
+  const int num_reduced = reduced_feats.size(0);
+  const int num_feats = feats.size(1);
+
+  grad_feats.fill_(0);
+  // copy voxel grad to points
+
+  if (num_input == 0 || num_reduced == 0) return;
+  c10::musa::MUSAGuard device_guard(feats.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  if (reduce_type == reduce_t::MEAN || reduce_type == reduce_t::SUM) {
+    AT_DISPATCH_FLOATING_TYPES(
+        grad_reduced_feats.scalar_type(), "add_reduce_traceback_grad_kernel",
+        ([&] {
+          dim3 blocks(std::min(
+              at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
+          dim3 threads(THREADS_PER_BLOCK);
+          add_reduce_traceback_grad_kernel<<<blocks, threads, 0, stream>>>(
+              grad_feats.data_ptr<scalar_t>(),
+              grad_reduced_feats.data_ptr<scalar_t>(),
+              coors_map.data_ptr<int32_t>(), reduce_count.data_ptr<int32_t>(),
+              num_input, num_feats, reduce_type);
+        }));
+
+    AT_MUSA_CHECK(musaGetLastError());
+  } else {
+    auto reduce_from = at::full({num_reduced, num_feats}, num_input,
+                                coors_map.options().dtype(torch::kInt32));
+    AT_DISPATCH_FLOATING_TYPES(
+        grad_reduced_feats.scalar_type(),
+        "max_reduce_traceback_scatter_idx_kernel", ([&] {
+          dim3 blocks(std::min(
+              at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
+          dim3 threads(THREADS_PER_BLOCK);
+          max_reduce_traceback_scatter_idx_kernel<<<blocks, threads, 0,
+                                                    stream>>>(
+              feats.data_ptr<scalar_t>(), reduced_feats.data_ptr<scalar_t>(),
+              reduce_from.data_ptr<int32_t>(), coors_map.data_ptr<int32_t>(),
+              num_input, num_feats);
+        }));
+
+    AT_MUSA_CHECK(musaGetLastError());
+
+    AT_DISPATCH_FLOATING_TYPES(
+        grad_reduced_feats.scalar_type(),
+        "max_reduce_traceback_scatter_idx_kernel", ([&] {
+          dim3 blocks(
+              std::min(at::musa::ATenCeilDiv(num_reduced, THREADS_PER_BLOCK),
+                       maxGridDim));
+          dim3 threads(THREADS_PER_BLOCK);
+          max_reduce_scatter_grad_kernel<<<blocks, threads, 0, stream>>>(
+              grad_feats.data_ptr<scalar_t>(),
+              grad_reduced_feats.data_ptr<scalar_t>(),
+              reduce_from.data_ptr<int32_t>(), num_reduced, num_feats);
+        }));
+
+    AT_MUSA_CHECK(musaGetLastError());
+  }
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/sparse_maxpool.mu b/mmcv/ops/csrc/pytorch/musa/sparse_maxpool.mu
new file mode 100644
index 000000000..67a69c176
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/sparse_maxpool.mu
@@ -0,0 +1,486 @@
+// Copyright 2019 Yan Yan
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <ATen/ATen.h>
+// clang-format off
+// TODO: make spconv_utils.h order agnostic
+#include "../spconv_utils.h"
+// clang-format on
+#include <utils/spconv/spconv/maxpool.h>
+#include <utils/spconv/spconv/mp_helper.h>
+#include <utils/spconv/tensorview/helper_launch.h>
+#include <utils/spconv/tensorview/tensorview.h>
+
+#include <chrono>
+#include <limits>
+#include <type_traits>
+#include <utils/spconv/tensorview/helper_kernel.muh>
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename scalar_t, typename Index, int NumTLP, int NumILP>
+__global__ void maxPoolFwdBlockKernel(scalar_t *outFeatures,
+                                      const scalar_t *inFeatures,
+                                      const Index *indicesIn,
+                                      const Index *indicesOut, int numHot,
+                                      int numPlanes) {
+  scalar_t in, out;
+  int ILPStrideY[NumILP];
+  Index idxo, idxi;
+#pragma unroll
+  for (int ilp = 0; ilp < NumILP; ilp++)
+    ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
+  outFeatures += blockIdx.y * NumTLP;
+  inFeatures += blockIdx.y * NumTLP;
+  for (int ix = blockIdx.x * blockDim.x; ix < numHot;
+       ix += blockDim.x * gridDim.x) {
+    {
+#pragma unroll
+      for (int ilp = 0; ilp < NumILP; ++ilp) {
+        idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
+        idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
+        in = inFeatures[idxi];
+        out = outFeatures[idxo];
+        if (in > out) {
+          outFeatures[idxo] = in;
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t, typename Index, int NumTLP, int NumILP>
+__global__ void maxPoolFwdGenericBlockKernel(scalar_t *outFeatures,
+                                             const scalar_t *inFeatures,
+                                             const Index *indicesIn,
+                                             const Index *indicesOut,
+                                             int numHot, int numPlanes) {
+  int ILPStrideX[NumILP];
+  Index RI[NumILP];
+  Index RO[NumILP];
+  scalar_t in, out;
+#pragma unroll
+  for (int ilp = 0; ilp < NumILP; ilp++)
+    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
+  for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
+#pragma unroll
+    for (int ilp = 0; ilp < NumILP; ilp++) {
+      RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
+      RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
+    }
+    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
+#pragma unroll
+      for (int ilp = 0; ilp < NumILP; ++ilp) {
+        in = inFeatures[RI[ilp] + iy];
+        out = outFeatures[RO[ilp] + iy];
+        if (in > out) {
+          outFeatures[RO[ilp] + iy] = in;
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t, typename Index, int NumTLP, int NumILP,
+          typename VecType>
+__global__ void maxPoolFwdVecBlockKernel(scalar_t *outFeatures,
+                                         const scalar_t *inFeatures,
+                                         const Index *indicesIn,
+                                         const Index *indicesOut, int numHot,
+                                         int numPlanes) {
+  int ILPStrideY[NumILP];
+  constexpr int vecloadFactor = sizeof(VecType) / sizeof(scalar_t);
+  scalar_t bufi[vecloadFactor];
+  scalar_t bufo[vecloadFactor];
+  Index idxi, idxo;
+#pragma unroll
+  for (int ilp = 0; ilp < NumILP; ilp++)
+    ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
+  outFeatures += blockIdx.y * NumTLP;
+  inFeatures += blockIdx.y * NumTLP;
+  for (int ix = blockIdx.x * blockDim.x * vecloadFactor; ix < numHot;
+       ix += blockDim.x * gridDim.x * vecloadFactor) {
+#pragma unroll
+    for (int ilp = 0; ilp < NumILP; ++ilp) {
+      idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
+      idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
+      reinterpret_cast<VecType *>(bufo)[0] =
+          reinterpret_cast<VecType *>(outFeatures)[idxo];
+      reinterpret_cast<VecType *>(bufi)[0] =
+          reinterpret_cast<const VecType *>(inFeatures)[idxi];
+#pragma unroll
+      for (int i = 0; i < vecloadFactor; i++) {
+        if (bufi[i] > bufo[i]) {
+          bufo[i] = bufi[i];
+        }
+      }
+      reinterpret_cast<VecType *>(outFeatures)[idxo] =
+          reinterpret_cast<VecType *>(bufo)[0];
+    }
+  }
+}
+
+template <typename scalar_t, typename Index, int NumTLP, int NumILP>
+__global__ void maxPoolFwdGenericKernel(scalar_t *outFeatures,
+                                        const scalar_t *inFeatures,
+                                        const Index *indicesIn,
+                                        const Index *indicesOut, int numHot,
+                                        int numPlanes) {
+  int ILPStrideX[NumILP];
+  Index RI[NumILP];
+  Index RO[NumILP];
+  scalar_t in, out;
+#pragma unroll
+  for (int ilp = 0; ilp < NumILP; ilp++)
+    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
+  for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
+#pragma unroll
+    for (int ilp = 0; ilp < NumILP; ilp++) {
+      if (ix + ILPStrideX[ilp] < numHot) {
+        RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
+        RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
+      }
+    }
+    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
+#pragma unroll
+      for (int ilp = 0; ilp < NumILP; ++ilp) {
+        if (ix + ILPStrideX[ilp] < numHot) {
+          in = inFeatures[RI[ilp] + iy];
+          out = outFeatures[RO[ilp] + iy];
+          if (in > out) {
+            outFeatures[RO[ilp] + iy] = in;
+          }
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t, typename Index, int NumTLP, int NumILP>
+__global__ void maxPoolBwdBlockKernel(const scalar_t *outFeatures,
+                                      const scalar_t *inFeatures,
+                                      const scalar_t *fout, scalar_t *fin,
+                                      const Index *indicesIn,
+                                      const Index *indicesOut, int numHot,
+                                      int numPlanes) {
+  scalar_t in, out;
+  Index idxo, idxi;
+  int ILPStrideY[NumILP];
+#pragma unroll
+  for (int ilp = 0; ilp < NumILP; ilp++)
+    ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
+  outFeatures += blockIdx.y * NumTLP;
+  inFeatures += blockIdx.y * NumTLP;
+  fout += blockIdx.y * NumTLP;
+  fin += blockIdx.y * NumTLP;
+  for (int ix = blockIdx.x * blockDim.x; ix < numHot;
+       ix += blockDim.x * gridDim.x) {
+    {
+#pragma unroll
+      for (int ilp = 0; ilp < NumILP; ++ilp) {
+        idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
+        idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
+        in = inFeatures[idxi];
+        out = outFeatures[idxo];
+        if (in == out) {
+          fin[idxi] += fout[idxo];
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t, typename Index, int NumTLP, int NumILP>
+__global__ void maxPoolBwdGenericBlockKernel(
+    const scalar_t *outFeatures, const scalar_t *inFeatures,
+    const scalar_t *fout, scalar_t *fin, const Index *indicesIn,
+    const Index *indicesOut, int numHot, int numPlanes) {
+  int ILPStrideX[NumILP];
+  Index RI[NumILP];
+  Index RO[NumILP];
+  scalar_t in, out;
+#pragma unroll
+  for (int ilp = 0; ilp < NumILP; ilp++)
+    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
+  for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
+#pragma unroll
+    for (int ilp = 0; ilp < NumILP; ilp++) {
+      RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
+      RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
+    }
+    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
+#pragma unroll
+      for (int ilp = 0; ilp < NumILP; ++ilp) {
+        in = inFeatures[RI[ilp] + iy];
+        out = outFeatures[RO[ilp] + iy];
+        if (in == out) {
+          fin[RI[ilp] + iy] += fout[RO[ilp] + iy];
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t, typename Index, int NumTLP, int NumILP,
+          typename VecType>
+__global__ void maxPoolBwdVecBlockKernel(const scalar_t *outFeatures,
+                                         const scalar_t *inFeatures,
+                                         const scalar_t *fout, scalar_t *fin,
+                                         const Index *indicesIn,
+                                         const Index *indicesOut, int numHot,
+                                         int numPlanes) {
+  int ILPStrideY[NumILP];
+  constexpr int vecloadFactor = sizeof(VecType) / sizeof(scalar_t);
+  scalar_t bufi[vecloadFactor];
+  scalar_t bufo[vecloadFactor];
+  scalar_t bufdi[vecloadFactor];
+  scalar_t bufdo[vecloadFactor];
+  Index idxi, idxo;
+#pragma unroll
+  for (int ilp = 0; ilp < NumILP; ilp++)
+    ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y;
+  outFeatures += blockIdx.y * NumTLP;
+  inFeatures += blockIdx.y * NumTLP;
+  for (int ix = blockIdx.x * blockDim.x * vecloadFactor; ix < numHot;
+       ix += blockDim.x * gridDim.x * vecloadFactor) {
+#pragma unroll
+    for (int ilp = 0; ilp < NumILP; ++ilp) {
+      idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
+      idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x;
+      reinterpret_cast<VecType *>(bufo)[0] =
+          reinterpret_cast<const VecType *>(outFeatures)[idxo];
+      reinterpret_cast<VecType *>(bufi)[0] =
+          reinterpret_cast<const VecType *>(inFeatures)[idxi];
+      reinterpret_cast<VecType *>(bufdo)[0] =
+          reinterpret_cast<const VecType *>(fout)[idxo];
+      reinterpret_cast<VecType *>(bufdi)[0] =
+          reinterpret_cast<VecType *>(fin)[idxi];
+
+#pragma unroll
+      for (int i = 0; i < vecloadFactor; i++) {
+        if (bufi[i] == bufo[i]) {
+          bufdi[i] += bufdo[i];
+        }
+      }
+      reinterpret_cast<VecType *>(fin)[idxi] =
+          reinterpret_cast<VecType *>(bufdi)[0];
+    }
+  }
+}
+
+template <typename scalar_t, typename Index, int NumTLP, int NumILP>
+__global__ void maxPoolBwdGenericKernel(const scalar_t *outFeatures,
+                                        const scalar_t *inFeatures,
+                                        const scalar_t *fout, scalar_t *fin,
+                                        const Index *indicesIn,
+                                        const Index *indicesOut, int numHot,
+                                        int numPlanes) {
+  int ILPStrideX[NumILP];
+  Index RI[NumILP];
+  Index RO[NumILP];
+  scalar_t in, out;
+#pragma unroll
+  for (int ilp = 0; ilp < NumILP; ilp++)
+    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
+  for (int ix : tv::KernelLoopX<int, NumILP>(numHot)) {
+#pragma unroll
+    for (int ilp = 0; ilp < NumILP; ilp++) {
+      if (ix + ILPStrideX[ilp] < numHot) {
+        RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes;
+        RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes;
+      }
+    }
+    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
+#pragma unroll
+      for (int ilp = 0; ilp < NumILP; ++ilp) {
+        if (ix + ILPStrideX[ilp] < numHot) {
+          in = inFeatures[RI[ilp] + iy];
+          out = outFeatures[RO[ilp] + iy];
+          if (in == out) {
+            fin[RI[ilp] + iy] += fout[RO[ilp] + iy];
+          }
+        }
+      }
+    }
+  }
+}
+
+namespace functor {
+template <typename scalar_t, typename Index>
+struct SparseMaxPoolForwardFunctor<tv::TorchGPU, scalar_t, Index> {
+  using vecload_type_t =
+      std::conditional_t<std::is_same<scalar_t, at::Half>::value, int2, int4>;
+  using kernel_block_t = mp_list_c<int, 64, 32, 16>;
+  void operator()(const tv::TorchGPU &d, tv::TensorView<scalar_t> outFeatures,
+                  tv::TensorView<const scalar_t> inFeatures,
+                  tv::TensorView<const Index> indices, int size) {
+    if (size <= 0) return;
+    int numPlanes = inFeatures.dim(1);
+    bool notFound = true;
+    constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(scalar_t);
+    mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &indices,
+                                 &notFound](auto NumTLP) {
+      constexpr int NumILP = NumTLP / 4;
+
+      int numHotBlock = (size / NumTLP) * NumTLP;
+      if (notFound) {
+        if (numPlanes % NumTLP == 0) {
+          if (numHotBlock >= NumTLP) {
+            maxPoolFwdVecBlockKernel<scalar_t, Index, int(NumTLP), NumILP,
+                                     vecload_type_t>
+                <<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP),
+                   dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
+                   d.getStream()>>>(outFeatures.data(), inFeatures.data(),
+                                    indices.subview(0).data(),
+                                    indices.subview(1).data(), numHotBlock,
+                                    numPlanes / vecloadFactor);
+            TV_CHECK_MUSA_ERR();
+          }
+
+          if (size > numHotBlock) {
+            maxPoolFwdGenericKernel<scalar_t, Index, int(NumTLP), NumILP>
+                <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
+                   0, d.getStream()>>>(outFeatures.data(), inFeatures.data(),
+                                       indices.subview(0).data() + numHotBlock,
+                                       indices.subview(1).data() + numHotBlock,
+                                       size - numHotBlock, numPlanes);
+            TV_CHECK_MUSA_ERR();
+          }
+          notFound = false;
+        }
+      }
+    });
+
+    if (notFound) {
+      constexpr int NumTLP = 64;
+      constexpr int NumILP = NumTLP / 4;
+      int numHotBlock = (size / NumTLP) * NumTLP;
+      if (numHotBlock >= NumTLP) {
+        maxPoolFwdGenericBlockKernel<scalar_t, Index, NumTLP, NumILP>
+            <<<dim3(size / NumTLP, tv::launch::DivUp(numPlanes, NumTLP)),
+               dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
+                outFeatures.data(), inFeatures.data(),
+                indices.subview(0).data(), indices.subview(1).data(),
+                numHotBlock, numPlanes);
+        TV_CHECK_MUSA_ERR();
+      }
+
+      if (size > numHotBlock) {
+        maxPoolFwdGenericKernel<scalar_t, Index, NumTLP, NumILP>
+            <<<dim3(1, tv::launch::DivUp(numPlanes, NumTLP)),
+               dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
+                outFeatures.data(), inFeatures.data(),
+                indices.subview(0).data() + numHotBlock,
+                indices.subview(1).data() + numHotBlock, size - numHotBlock,
+                numPlanes);
+        TV_CHECK_MUSA_ERR();
+      }
+    }
+  }
+};
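+
+// Kernel selection above (informal): NumTLP is tried in {64, 32, 16} and the
+// first value dividing numPlanes picks the vectorized block kernel (int4
+// loads for float/double, int2 for half, i.e. 4 lanes for float and half,
+// 2 for double), with the generic kernel covering the size % NumTLP tail.
+// If no candidate divides numPlanes, the notFound fallback uses the fully
+// generic block kernels.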
+
+template <typename scalar_t, typename Index>
+struct SparseMaxPoolBackwardFunctor<tv::TorchGPU, scalar_t, Index> {
+  using vecload_type_t =
+      std::conditional_t<std::is_same<scalar_t, at::Half>::value, int2, int4>;
+  using kernel_block_t = mp_list_c<int, 64, 32, 16>;
+  void operator()(const tv::TorchGPU &d,
+                  tv::TensorView<const scalar_t> outFeatures,
+                  tv::TensorView<const scalar_t> inFeatures,
+                  tv::TensorView<const scalar_t> fout,
+                  tv::TensorView<scalar_t> fin,
+                  tv::TensorView<const Index> indices, int size) {
+    if (size <= 0) return;
+    int numPlanes = inFeatures.dim(1);
+    bool notFound = true;
+    constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(scalar_t);
+    mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &fout, &fin,
+                                 &indices, &notFound](auto NumTLP) {
+      constexpr int NumILP = NumTLP / 4;
+
+      int numHotBlock = (size / NumTLP) * NumTLP;
+      if (notFound) {
+        if (numPlanes % NumTLP == 0) {
+          if (numHotBlock >= NumTLP) {
+            maxPoolBwdVecBlockKernel<scalar_t, Index, int(NumTLP), NumILP,
+                                     vecload_type_t>
+                <<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP),
+                   dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
+                   d.getStream()>>>(outFeatures.data(), inFeatures.data(),
+                                    fout.data(), fin.data(),
+                                    indices.subview(0).data(),
+                                    indices.subview(1).data(), numHotBlock,
+                                    numPlanes / vecloadFactor);
+            TV_CHECK_MUSA_ERR();
+          }
+
+          if (size > numHotBlock) {
+            maxPoolBwdGenericKernel<scalar_t, Index, int(NumTLP), NumILP>
+                <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
+                   0, d.getStream()>>>(outFeatures.data(), inFeatures.data(),
+                                       fout.data(), fin.data(),
+                                       indices.subview(0).data() + numHotBlock,
+                                       indices.subview(1).data() + numHotBlock,
+                                       size - numHotBlock, numPlanes);
+            TV_CHECK_MUSA_ERR();
+          }
+          notFound = false;
+        }
+      }
+    });
+
+    if (notFound) {
+      constexpr int NumTLP = 64;
+      constexpr int NumILP = NumTLP / 4;
+      int numHotBlock = (size / NumTLP) * NumTLP;
+      if (numHotBlock >= NumTLP) {
+        maxPoolBwdGenericBlockKernel<scalar_t, Index, NumTLP, NumILP>
+            <<<dim3(size / NumTLP, tv::launch::DivUp(numPlanes, NumTLP)),
+               dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
+                outFeatures.data(), inFeatures.data(), fout.data(), fin.data(),
+                indices.subview(0).data(), indices.subview(1).data(),
+                numHotBlock, numPlanes);
+        TV_CHECK_MUSA_ERR();
+      }
+
+      if (size > numHotBlock) {
+        maxPoolBwdGenericKernel<scalar_t, Index, NumTLP, NumILP>
+            <<<dim3(1, tv::launch::DivUp(numPlanes, NumTLP)),
+               dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
+                outFeatures.data(), inFeatures.data(), fout.data(), fin.data(),
+                indices.subview(0).data() + numHotBlock,
+                indices.subview(1).data() + numHotBlock, size - numHotBlock,
+                numPlanes);
+        TV_CHECK_MUSA_ERR();
+      }
+    }
+  }
+};
+
+}  // namespace functor
+
+#define DECLARE_GPU_SPECS_T_INDEX(scalar_t, Index)                             \
+  template struct functor::SparseMaxPoolForwardFunctor<tv::TorchGPU, scalar_t, \
+                                                       Index>;                 \
+  template struct functor::SparseMaxPoolBackwardFunctor<tv::TorchGPU,          \
+                                                        scalar_t, Index>;
+
+#define DECLARE_GPU_SPECS(scalar_t) DECLARE_GPU_SPECS_T_INDEX(scalar_t, int);
+
+DECLARE_GPU_SPECS(float);
+DECLARE_GPU_SPECS(double);
+DECLARE_GPU_SPECS(at::Half);
+
+#undef DECLARE_GPU_SPECS
+#undef DECLARE_GPU_SPECS_T_INDEX
diff --git a/mmcv/ops/csrc/pytorch/musa/sparse_pool_ops_musa.mu b/mmcv/ops/csrc/pytorch/musa/sparse_pool_ops_musa.mu
new file mode 100644
index 000000000..a4ce9b2d5
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/sparse_pool_ops_musa.mu
@@ -0,0 +1,91 @@
+#include <musa_runtime_api.h>
+#include <torch/script.h>
+// clang-format off
+// TODO: make spconv_utils.h order agnostic
+#include "../spconv_utils.h"
+// clang-format on
+#include <utils/spconv/spconv/maxpool.h>
+
+#include "pytorch_musa_helper.hpp"
+
+torch::Tensor IndiceMaxpoolForwardMUSAKernelLauncher(torch::Tensor features,
+                                                     torch::Tensor indicePairs,
+                                                     torch::Tensor indiceNum,
+                                                     int64_t numAct) {
+  c10::musa::MUSAGuard device_guard(features.device());
+  auto device = features.device().type();
+  auto kernelVolume = indicePairs.size(0);
+  auto numInPlanes = features.size(1);
+  auto indicePairNumCpu = indiceNum.to({torch::kCPU});
+  auto options =
+      torch::TensorOptions().dtype(features.dtype()).device(features.device());
+  torch::Tensor output = torch::zeros({numAct, numInPlanes}, options);
+  for (int i = 0; i < kernelVolume; ++i) {
+    auto nHot = indicePairNumCpu.data_ptr<int>()[i];
+    if (nHot <= 0) {
+      continue;
+    }
+    AT_DISPATCH_FLOATING_TYPES(
+        features.scalar_type(), "IndiceMaxpoolForwardKernel", [&] {
+          if (device == torch::kCPU) {
+            functor::SparseMaxPoolForwardFunctor<tv::CPU, scalar_t, int>
+                forwardFtor;
+            forwardFtor(tv::CPU(), tv::torch2tv<scalar_t>(output),
+                        tv::torch2tv<const scalar_t>(features),
+                        tv::torch2tv<const int>(indicePairs).subview(i), nHot);
+          } else {
+            functor::SparseMaxPoolForwardFunctor<tv::TorchGPU, scalar_t, int>
+                forwardFtor;
+            forwardFtor(tv::TorchGPU(), tv::torch2tv<scalar_t>(output),
+                        tv::torch2tv<const scalar_t>(features),
+                        tv::torch2tv<const int>(indicePairs).subview(i), nHot);
+            TV_CHECK_MUSA_ERR();
+          }
+        });
+  }
+  return output;
+}
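+
+// indicePairs is expected to be laid out as (kernelVolume, 2, numAct):
+// subview(i) yields the (input, output) index pairs for the i-th kernel
+// offset, and indiceNum[i] (nHot here) says how many of those pairs are valid.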
+
+torch::Tensor IndiceMaxpoolBackwardMUSAKernelLauncher(torch::Tensor features,
+                                                      torch::Tensor outFeatures,
+                                                      torch::Tensor outGrad,
+                                                      torch::Tensor indicePairs,
+                                                      torch::Tensor indiceNum) {
+  c10::musa::MUSAGuard device_guard(features.device());
+  auto device = features.device().type();
+  auto numInPlanes = features.size(1);
+  auto indicePairNumCpu = indiceNum.to({torch::kCPU});
+  auto options =
+      torch::TensorOptions().dtype(features.dtype()).device(features.device());
+  torch::Tensor inputGrad = torch::zeros(features.sizes(), options);
+  auto kernelVolume = indicePairs.size(0);
+  for (int i = 0; i < kernelVolume; ++i) {
+    auto nHot = indicePairNumCpu.data_ptr<int>()[i];
+    if (nHot <= 0) {
+      continue;
+    }
+    AT_DISPATCH_FLOATING_TYPES(
+        features.scalar_type(), "IndiceMaxpoolBackwardKernel", [&] {
+          if (device == torch::kCPU) {
+            functor::SparseMaxPoolBackwardFunctor<tv::CPU, scalar_t, int>
+                backwardFtor;
+            backwardFtor(tv::CPU(), tv::torch2tv<const scalar_t>(outFeatures),
+                         tv::torch2tv<const scalar_t>(features),
+                         tv::torch2tv<const scalar_t>(outGrad),
+                         tv::torch2tv<scalar_t>(inputGrad),
+                         tv::torch2tv<const int>(indicePairs).subview(i), nHot);
+          } else {
+            functor::SparseMaxPoolBackwardFunctor<tv::TorchGPU, scalar_t, int>
+                backwardFtor;
+            backwardFtor(tv::TorchGPU(),
+                         tv::torch2tv<const scalar_t>(outFeatures),
+                         tv::torch2tv<const scalar_t>(features),
+                         tv::torch2tv<const scalar_t>(outGrad),
+                         tv::torch2tv<scalar_t>(inputGrad),
+                         tv::torch2tv<const int>(indicePairs).subview(i), nHot);
+            TV_CHECK_MUSA_ERR();
+          }
+        });
+  }
+  return inputGrad;
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/sync_bn_musa.mu b/mmcv/ops/csrc/pytorch/musa/sync_bn_musa.mu
new file mode 100644
index 000000000..56327f4ed
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/sync_bn_musa.mu
@@ -0,0 +1,110 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#include "pytorch_musa_helper.hpp"
+#include "sync_bn_musa_kernel.muh"
+
+void SyncBNForwardMeanMUSAKernelLauncher(const Tensor input, Tensor mean) {
+  int num = input.size(0);
+  int channels = input.size(1);
+  int spatial = input.size(2);
+
+  c10::musa::MUSAGuard device_guard(input.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "sync_bn_forward_mean_musa_kernel", [&] {
+        sync_bn_forward_mean_musa_kernel<scalar_t>
+            <<<channels, THREADS_PER_BLOCK, 0, stream>>>(
+                input.data_ptr<scalar_t>(), mean.data_ptr<float>(), num,
+                channels, spatial);
+      });
+  AT_MUSA_CHECK(musaGetLastError());
+}
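+
+// Each launcher in this file uses one block per channel
+// (<<<channels, THREADS_PER_BLOCK>>>); the kernel reduces over the
+// num * spatial elements of that channel, e.g.
+//   mean[c] = sum_{n, s} input[n, c, s] / (num * spatial).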
+
+void SyncBNForwardVarMUSAKernelLauncher(const Tensor input, const Tensor mean,
+                                        Tensor var) {
+  int num = input.size(0);
+  int channels = input.size(1);
+  int spatial = input.size(2);
+
+  c10::musa::MUSAGuard device_guard(input.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "sync_bn_forward_mean_musa_kernel", [&] {
+        sync_bn_forward_var_musa_kernel<scalar_t>
+            <<<channels, THREADS_PER_BLOCK, 0, stream>>>(
+                input.data_ptr<scalar_t>(), mean.data_ptr<float>(),
+                var.data_ptr<float>(), num, channels, spatial);
+      });
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void SyncBNForwardOutputMUSAKernelLauncher(
+    const Tensor input, const Tensor mean, const Tensor var,
+    Tensor running_mean, Tensor running_var, const Tensor weight,
+    const Tensor bias, Tensor norm, Tensor std, Tensor output, float eps,
+    float momentum, int group_size) {
+  int num = input.size(0);
+  int channels = input.size(1);
+  int spatial = input.size(2);
+
+  c10::musa::MUSAGuard device_guard(input.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "sync_bn_forward_mean_musa_kernel", [&] {
+        sync_bn_forward_output_musa_kernel<scalar_t>
+            <<<channels, THREADS_PER_BLOCK, 0, stream>>>(
+                input.data_ptr<scalar_t>(), mean.data_ptr<float>(),
+                var.data_ptr<float>(), running_mean.data_ptr<float>(),
+                running_var.data_ptr<float>(), weight.data_ptr<float>(),
+                bias.data_ptr<float>(), norm.data_ptr<float>(),
+                std.data_ptr<float>(), output.data_ptr<scalar_t>(), num,
+                channels, spatial, eps, momentum, group_size);
+      });
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void SyncBNBackwardParamMUSAKernelLauncher(const Tensor grad_output,
+                                           const Tensor norm,
+                                           Tensor grad_weight,
+                                           Tensor grad_bias) {
+  int num = grad_output.size(0);
+  int channels = grad_output.size(1);
+  int spatial = grad_output.size(2);
+
+  c10::musa::MUSAGuard device_guard(grad_output.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES(
+      grad_output.scalar_type(), "sync_bn_backward_param_musa_kernel", [&] {
+        sync_bn_backward_param_musa_kernel<scalar_t>
+            <<<channels, THREADS_PER_BLOCK, 0, stream>>>(
+                grad_output.data_ptr<scalar_t>(), norm.data_ptr<float>(),
+                grad_weight.data_ptr<float>(), grad_bias.data_ptr<float>(), num,
+                channels, spatial);
+      });
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void SyncBNBackwardDataMUSAKernelLauncher(const Tensor grad_output,
+                                          const Tensor weight,
+                                          const Tensor grad_weight,
+                                          const Tensor grad_bias,
+                                          const Tensor norm, const Tensor std,
+                                          Tensor grad_input) {
+  int output_size = grad_input.numel();
+  int num = grad_input.size(0);
+  int channels = grad_input.size(1);
+  int spatial = grad_input.size(2);
+
+  c10::musa::MUSAGuard device_guard(grad_input.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES(
+      grad_output.scalar_type(), "sync_bn_backward_data_musa_kernel", [&] {
+        sync_bn_backward_data_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, grad_output.data_ptr<scalar_t>(),
+                weight.data_ptr<float>(), grad_weight.data_ptr<float>(),
+                grad_bias.data_ptr<float>(), norm.data_ptr<float>(),
+                std.data_ptr<float>(), grad_input.data_ptr<scalar_t>(), num,
+                channels, spatial);
+      });
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/three_interpolate_musa.mu b/mmcv/ops/csrc/pytorch/musa/three_interpolate_musa.mu
new file mode 100644
index 000000000..c48314cc3
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/three_interpolate_musa.mu
@@ -0,0 +1,66 @@
+// Modified from
+// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "pytorch_musa_helper.hpp"
+#include "three_interpolate_musa_kernel.muh"
+
+void ThreeInterpolateForwardMUSAKernelLauncher(int b, int c, int m, int n,
+                                               const Tensor points,
+                                               const Tensor idx,
+                                               const Tensor weight,
+                                               Tensor out) {
+  // points: (B, C, M)
+  // idx: (B, N, 3)
+  // weight: (B, N, 3)
+  // output:
+  //      out: (B, C, N)
+
+  c10::musa::MUSAGuard device_guard(points.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  // blockIdx.x(col), blockIdx.y(row)
+  dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b);
+  dim3 threads(THREADS_PER_BLOCK);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      points.scalar_type(), "three_interpolate_forward_musa_kernel", [&] {
+        three_interpolate_forward_musa_kernel<scalar_t>
+            <<<blocks, threads, 0, stream>>>(
+                b, c, m, n, points.data_ptr<scalar_t>(), idx.data_ptr<int>(),
+                weight.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
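+
+// Computation performed above (per the kernel's indexing):
+//   out[b, c, n] = sum_{k=0..2} weight[b, n, k] * points[b, c, idx[b, n, k]]
+// i.e. each target point is a weighted blend of its three nearest neighbors.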
+
+void ThreeInterpolateBackwardMUSAKernelLauncher(int b, int c, int n, int m,
+                                                const Tensor grad_out,
+                                                const Tensor idx,
+                                                const Tensor weight,
+                                                Tensor grad_points) {
+  // grad_out: (B, C, N)
+  // weight: (B, N, 3)
+  // output:
+  //      grad_points: (B, C, M)
+
+  c10::musa::MUSAGuard device_guard(grad_out.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  // blockIdx.x(col), blockIdx.y(row)
+  dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b);
+  dim3 threads(THREADS_PER_BLOCK);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_out.scalar_type(), "three_interpolate_backward_musa_kernel", [&] {
+        three_interpolate_backward_musa_kernel<scalar_t>
+            <<<blocks, threads, 0, stream>>>(
+                b, c, n, m, grad_out.data_ptr<scalar_t>(), idx.data_ptr<int>(),
+                weight.data_ptr<scalar_t>(), grad_points.data_ptr<scalar_t>());
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/three_nn_musa.mu b/mmcv/ops/csrc/pytorch/musa/three_nn_musa.mu
new file mode 100644
index 000000000..b69caa403
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/three_nn_musa.mu
@@ -0,0 +1,35 @@
+// Modified from
+// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "pytorch_musa_helper.hpp"
+#include "three_nn_musa_kernel.muh"
+
+void ThreeNNForwardMUSAKernelLauncher(int b, int n, int m, const Tensor unknown,
+                                      const Tensor known, Tensor dist2,
+                                      Tensor idx) {
+  // unknown: (B, N, 3)
+  // known: (B, M, 3)
+  // output:
+  //      dist2: (B, N, 3)
+  //      idx: (B, N, 3)
+
+  c10::musa::MUSAGuard device_guard(unknown.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  // blockIdx.x(col), blockIdx.y(row)
+  dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), b);
+  dim3 threads(THREADS_PER_BLOCK);
+
+  AT_DISPATCH_FLOATING_TYPES(
+      unknown.scalar_type(), "three_nn_forward_musa_kernel", [&] {
+        three_nn_forward_musa_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+            b, n, m, unknown.data_ptr<scalar_t>(), known.data_ptr<scalar_t>(),
+            dist2.data_ptr<scalar_t>(), idx.data_ptr<int>());
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
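+
+// three_nn is a brute-force search: for every point in `unknown` it scans all
+// M points in `known` and writes the squared distances to, and indices of,
+// the three closest ones into dist2 and idx.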
diff --git a/mmcv/ops/csrc/pytorch/musa/tin_shift_musa.mu b/mmcv/ops/csrc/pytorch/musa/tin_shift_musa.mu
new file mode 100644
index 000000000..705c208d6
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/tin_shift_musa.mu
@@ -0,0 +1,55 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#include "pytorch_musa_helper.hpp"
+#include "pytorch_device_registry.hpp"
+#include "tin_shift_musa_kernel.muh"
+
+void TINShiftForwardMUSAKernelLauncher(Tensor input, Tensor shift,
+                                       Tensor output) {
+  int output_size = output.numel();
+  int batch_size = input.size(0);
+  int t_size = input.size(1);
+  int channels = input.size(2);
+  int hw_size = input.size(3);
+  int group_size = shift.size(1);
+  int group_channel = channels / group_size;
+  int num_kernels = batch_size * hw_size * channels;
+
+  c10::musa::MUSAGuard device_guard(input.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.scalar_type(), "tin_shift_forward_musa_kernel", [&] {
+        tin_shift_forward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, input.data_ptr<scalar_t>(), shift.data_ptr<int>(),
+                output.data_ptr<scalar_t>(), batch_size, channels, t_size,
+                hw_size, group_size, group_channel);
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
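+
+// tin_shift moves each channel group along the temporal axis by the signed
+// offset in `shift` (one offset per batch and group); positions shifted in
+// from outside [0, t_size) are filled with zeros (see
+// tin_shift_musa_kernel.muh for the exact sign convention).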
+
+void TINShiftBackwardMUSAKernelLauncher(Tensor grad_output, Tensor shift,
+                                        Tensor grad_input) {
+  int output_size = grad_output.numel();
+  int batch_size = grad_output.size(0);
+  int t_size = grad_output.size(1);
+  int channels = grad_output.size(2);
+  int hw_size = grad_output.size(3);
+  int group_size = shift.size(1);
+  int group_channel = channels / group_size;
+  int num_kernels = batch_size * hw_size * channels;
+
+  c10::musa::MUSAGuard device_guard(grad_output.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_output.scalar_type(), "tin_shift_backward_musa_kernel", [&] {
+        tin_shift_backward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, grad_output.data_ptr<scalar_t>(),
+                shift.data_ptr<int>(), grad_input.data_ptr<scalar_t>(),
+                batch_size, channels, t_size, hw_size, group_size,
+                group_channel);
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/upfirdn2d_kernel.mu b/mmcv/ops/csrc/pytorch/musa/upfirdn2d_kernel.mu
new file mode 100644
index 000000000..9b9a2ffe8
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/upfirdn2d_kernel.mu
@@ -0,0 +1,749 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto.  Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+#include <c10/util/Half.h>
+#include <torch/types.h>
+
+#include "pytorch_musa_helper.hpp"
+#if MUSA_ARCH > 21
+struct upfirdn2d_kernel_params {
+  const void *x;
+  const float *f;
+  void *y;
+
+  int2 up;
+  int2 down;
+  int2 pad0;
+  int flip;
+  float gain;
+
+  int4 inSize;  // [width, height, channel, batch]
+  int4 inStride;
+  int2 filterSize;  // [width, height]
+  int2 filterStride;
+  int4 outSize;  // [width, height, channel, batch]
+  int4 outStride;
+  int sizeMinor;
+  int sizeMajor;
+
+  int loopMinor;
+  int loopMajor;
+  int loopX;
+  int launchMinor;
+  int launchMajor;
+};
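+
+// For reference, the output extent follows the usual upfirdn relation
+// (pad1 does not appear here; it is implied by outSize):
+//   outW = (inW * up.x + pad0.x + pad1.x - filterSize.x) / down.x + 1
+// and likewise for the Y axis.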
+
+//------------------------------------------------------------------------
+// MUSA kernel specialization.
+
+struct upfirdn2d_kernel_spec {
+  void *kernel;
+  int tileOutW;
+  int tileOutH;
+  int loopMinor;
+  int loopX;
+};
+
+//------------------------------------------------------------------------
+// MUSA kernel selection.
+
+template <class T>
+upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params &p);
+//------------------------------------------------------------------------
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T>
+struct InternalType;
+template <>
+struct InternalType<double> {
+  typedef double scalar_t;
+};
+template <>
+struct InternalType<float> {
+  typedef float scalar_t;
+};
+template <>
+struct InternalType<c10::Half> {
+  typedef float scalar_t;
+};
+
+static __device__ __forceinline__ int floor_div(int a, int b) {
+  int t = 1 - a / b;
+  return (a + t * b) / b - t;
+}
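+
+// floor_div rounds toward negative infinity, unlike C++ integer division,
+// which truncates toward zero: floor_div(-3, 2) == -2 whereas -3 / 2 == -1.
+// This keeps the receptive-field arithmetic below correct for negative
+// offsets produced by padding.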
+
+//------------------------------------------------------------------------
+// Generic MUSA implementation for large filters.
+
+template <class T>
+static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) {
+  typedef typename InternalType<T>::scalar_t scalar_t;
+
+  // Calculate thread index.
+  int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
+  int outY = minorBase / p.launchMinor;
+  minorBase -= outY * p.launchMinor;
+  int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+  int majorBase = blockIdx.z * p.loopMajor;
+  if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
+    return;
+
+  // Setup Y receptive field.
+  int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
+  int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
+  int h =
+      min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
+  int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
+  if (p.flip) filterY = p.filterSize.y - 1 - filterY;
+
+  // Loop over major, minor, and X.
+  for (int majorIdx = 0, major = majorBase;
+       majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+    for (int minorIdx = 0, minor = minorBase;
+         minorIdx < p.loopMinor & minor < p.sizeMinor;
+         minorIdx++, minor += p.launchMinor) {
+      int nc = major * p.sizeMinor + minor;
+      int n = nc / p.inSize.z;
+      int c = nc - n * p.inSize.z;
+      for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x;
+           loopX++, outX += blockDim.y) {
+        // Setup X receptive field.
+        int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
+        int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
+        int w =
+            min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) -
+            inX;
+        int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
+        if (p.flip) filterX = p.filterSize.x - 1 - filterX;
+
+        // Initialize pointers.
+        const T *xp =
+            &((const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y +
+                              c * p.inStride.z + n * p.inStride.w];
+        const float *fp =
+            &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
+        int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
+        int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
+
+        // Inner loop.
+        scalar_t v = 0;
+        for (int y = 0; y < h; y++) {
+          for (int x = 0; x < w; x++) {
+            v += (scalar_t)(*xp) * (scalar_t)(*fp);
+            xp += p.inStride.x;
+            fp += filterStepX;
+          }
+          xp += p.inStride.y - w * p.inStride.x;
+          fp += filterStepY - w * filterStepX;
+        }
+
+        // Store result.
+        v *= p.gain;
+        ((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y +
+                   c * p.outStride.z + n * p.outStride.w] = (T)v;
+      }
+    }
+}
+
+//------------------------------------------------------------------------
+// Specialized MUSA implementation for small filters.
+
+template <class T, int upx, int upy, int downx, int downy, int filterW,
+          int filterH, int tileOutW, int tileOutH, int loopMinor>
+static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) {
+  typedef typename InternalType<T>::scalar_t scalar_t;
+  const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
+  const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
+  __shared__ volatile scalar_t sf[filterH][filterW];
+  __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
+
+  // Calculate tile index.
+  int minorBase = blockIdx.x;
+  int tileOutY = minorBase / p.launchMinor;
+  minorBase -= tileOutY * p.launchMinor;
+  minorBase *= loopMinor;
+  tileOutY *= tileOutH;
+  int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+  int majorBase = blockIdx.z * p.loopMajor;
+  if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y |
+      majorBase >= p.sizeMajor)
+    return;
+
+  // Load filter (flipped).
+  for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW;
+       tapIdx += blockDim.x) {
+    int fy = tapIdx / filterW;
+    int fx = tapIdx - fy * filterW;
+    scalar_t v = 0;
+    if (fx < p.filterSize.x & fy < p.filterSize.y) {
+      int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
+      int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
+      v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
+    }
+    sf[fy][fx] = v;
+  }
+
+  // Loop over major and X.
+  for (int majorIdx = 0, major = majorBase;
+       majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) {
+    int baseNC = major * p.sizeMinor + minorBase;
+    int n = baseNC / p.inSize.z;
+    int baseC = baseNC - n * p.inSize.z;
+    for (int loopX = 0, tileOutX = tileOutXBase;
+         loopX < p.loopX & tileOutX < p.outSize.x;
+         loopX++, tileOutX += tileOutW) {
+      // Load input pixels.
+      int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
+      int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
+      int tileInX = floor_div(tileMidX, upx);
+      int tileInY = floor_div(tileMidY, upy);
+      __syncthreads();
+      for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor;
+           inIdx += blockDim.x) {
+        int relC = inIdx;
+        int relInX = relC / loopMinor;
+        int relInY = relInX / tileInW;
+        relC -= relInX * loopMinor;
+        relInX -= relInY * tileInW;
+        int c = baseC + relC;
+        int inX = tileInX + relInX;
+        int inY = tileInY + relInY;
+        scalar_t v = 0;
+        if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y &
+            c < p.inSize.z)
+          v = (scalar_t)(
+              (const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y +
+                              c * p.inStride.z + n * p.inStride.w];
+        sx[relInY][relInX][relC] = v;
+      }
+
+      // Loop over output pixels.
+      __syncthreads();
+      for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor;
+           outIdx += blockDim.x) {
+        int relC = outIdx;
+        int relOutX = relC / loopMinor;
+        int relOutY = relOutX / tileOutW;
+        relC -= relOutX * loopMinor;
+        relOutX -= relOutY * tileOutW;
+        int c = baseC + relC;
+        int outX = tileOutX + relOutX;
+        int outY = tileOutY + relOutY;
+
+        // Setup receptive field.
+        int midX = tileMidX + relOutX * downx;
+        int midY = tileMidY + relOutY * downy;
+        int inX = floor_div(midX, upx);
+        int inY = floor_div(midY, upy);
+        int relInX = inX - tileInX;
+        int relInY = inY - tileInY;
+        int filterX = (inX + 1) * upx - midX - 1;  // flipped
+        int filterY = (inY + 1) * upy - midY - 1;  // flipped
+
+        // Inner loop.
+        if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) {
+          scalar_t v = 0;
+#pragma unroll
+          for (int y = 0; y < filterH / upy; y++)
+#pragma unroll
+            for (int x = 0; x < filterW / upx; x++)
+              v += sx[relInY + y][relInX + x][relC] *
+                   sf[filterY + y * upy][filterX + x * upx];
+          v *= p.gain;
+          ((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y +
+                     c * p.outStride.z + n * p.outStride.w] = (T)v;
+        }
+      }
+    }
+  }
+}
+
+//------------------------------------------------------------------------
+// MUSA kernel selection.
+
+template <class T>
+upfirdn2d_kernel_spec choose_upfirdn2d_kernel(
+    const upfirdn2d_kernel_params &p) {
+  int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+  upfirdn2d_kernel_spec spec = {(void *)upfirdn2d_kernel_large<T>, -1, -1, 1,
+                                4};  // contiguous
+  if (s == 1)
+    spec = {(void *)upfirdn2d_kernel_large<T>, -1, -1, 4, 1};  // channels_last
+
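+  // Each matching `if` below overwrites `spec`, ordered from the loosest to
+  // the tightest filter-size bound, so the smallest specialization that
+  // still fits the filter wins; otherwise the large-kernel fallback above
+  // is kept.
+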
+  // No up/downsampling.
+  if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
+    // contiguous
+    if (s != 1 && fx <= 24 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 24, 64, 32, 1>,
+              64, 32, 1, 1};
+    if (s != 1 && fx <= 16 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 16, 64, 32, 1>,
+              64, 32, 1, 1};
+    if (s != 1 && fx <= 7 && fy <= 7)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 7, 7, 64, 16, 1>,
+              64, 16, 1, 1};
+    if (s != 1 && fx <= 6 && fy <= 6)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 6, 6, 64, 16, 1>,
+              64, 16, 1, 1};
+    if (s != 1 && fx <= 5 && fy <= 5)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 5, 5, 64, 16, 1>,
+              64, 16, 1, 1};
+    if (s != 1 && fx <= 4 && fy <= 4)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 4, 4, 64, 16, 1>,
+              64, 16, 1, 1};
+    if (s != 1 && fx <= 3 && fy <= 3)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 3, 3, 64, 16, 1>,
+              64, 16, 1, 1};
+    if (s != 1 && fx <= 24 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 1, 128, 8, 1>,
+              128, 8, 1, 1};
+    if (s != 1 && fx <= 16 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 1, 128, 8, 1>,
+              128, 8, 1, 1};
+    if (s != 1 && fx <= 8 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 8, 1, 128, 8, 1>,
+              128, 8, 1, 1};
+    if (s != 1 && fx <= 1 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 24, 32, 32, 1>,
+              32, 32, 1, 1};
+    if (s != 1 && fx <= 1 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 16, 32, 32, 1>,
+              32, 32, 1, 1};
+    if (s != 1 && fx <= 1 && fy <= 8)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 8, 32, 32, 1>,
+              32, 32, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 24 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 24, 32, 32, 1>,
+              32, 32, 1, 1};
+    if (s == 1 && fx <= 16 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 16, 32, 32, 1>,
+              32, 32, 1, 1};
+    if (s == 1 && fx <= 7 && fy <= 7)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 7, 7, 16, 16, 8>,
+              16, 16, 8, 1};
+    if (s == 1 && fx <= 6 && fy <= 6)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 6, 6, 16, 16, 8>,
+              16, 16, 8, 1};
+    if (s == 1 && fx <= 5 && fy <= 5)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 5, 5, 16, 16, 8>,
+              16, 16, 8, 1};
+    if (s == 1 && fx <= 4 && fy <= 4)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 4, 4, 16, 16, 8>,
+              16, 16, 8, 1};
+    if (s == 1 && fx <= 3 && fy <= 3)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 3, 3, 16, 16, 8>,
+              16, 16, 8, 1};
+    if (s == 1 && fx <= 24 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 24, 1, 128, 1, 16>,
+              128, 1, 16, 1};
+    if (s == 1 && fx <= 16 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 16, 1, 128, 1, 16>,
+              128, 1, 16, 1};
+    if (s == 1 && fx <= 8 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 8, 1, 128, 1, 16>,
+              128, 1, 16, 1};
+    if (s == 1 && fx <= 1 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 24, 1, 128, 16>,
+              1, 128, 16, 1};
+    if (s == 1 && fx <= 1 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 16, 1, 128, 16>,
+              1, 128, 16, 1};
+    if (s == 1 && fx <= 1 && fy <= 8)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 1, 1, 8, 1, 128, 16>,
+              1, 128, 16, 1};
+  }
+
+  // 2x upsampling.
+  if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) {
+    // contiguous
+    if (s != 1 && fx <= 24 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 24, 24, 64, 32, 1>,
+              64, 32, 1, 1};
+    if (s != 1 && fx <= 16 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 16, 16, 64, 32, 1>,
+              64, 32, 1, 1};
+    if (s != 1 && fx <= 8 && fy <= 8)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 8, 8, 64, 16, 1>,
+              64, 16, 1, 1};
+    if (s != 1 && fx <= 6 && fy <= 6)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 6, 6, 64, 16, 1>,
+              64, 16, 1, 1};
+    if (s != 1 && fx <= 4 && fy <= 4)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 4, 4, 64, 16, 1>,
+              64, 16, 1, 1};
+    if (s != 1 && fx <= 2 && fy <= 2)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 2, 2, 64, 16, 1>,
+              64, 16, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 24 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 24, 24, 32, 32, 1>,
+              32, 32, 1, 1};
+    if (s == 1 && fx <= 16 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 16, 16, 32, 32, 1>,
+              32, 32, 1, 1};
+    if (s == 1 && fx <= 8 && fy <= 8)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 8, 8, 16, 16, 8>,
+              16, 16, 8, 1};
+    if (s == 1 && fx <= 6 && fy <= 6)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 6, 6, 16, 16, 8>,
+              16, 16, 8, 1};
+    if (s == 1 && fx <= 4 && fy <= 4)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 4, 4, 16, 16, 8>,
+              16, 16, 8, 1};
+    if (s == 1 && fx <= 2 && fy <= 2)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 2, 1, 1, 2, 2, 16, 16, 8>,
+              16, 16, 8, 1};
+  }
+  if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
+    // contiguous
+    if (s != 1 && fx <= 24 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 24, 1, 128, 8, 1>,
+              128, 8, 1, 1};
+    if (s != 1 && fx <= 16 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 16, 1, 128, 8, 1>,
+              128, 8, 1, 1};
+    if (s != 1 && fx <= 8 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 8, 1, 128, 8, 1>,
+              128, 8, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 24 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 24, 1, 128, 1, 16>,
+              128, 1, 16, 1};
+    if (s == 1 && fx <= 16 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 16, 1, 128, 1, 16>,
+              128, 1, 16, 1};
+    if (s == 1 && fx <= 8 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 2, 1, 1, 1, 8, 1, 128, 1, 16>,
+              128, 1, 16, 1};
+  }
+  if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) {
+    // contiguous
+    if (s != 1 && fx <= 1 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 24, 32, 32, 1>,
+              32, 32, 1, 1};
+    if (s != 1 && fx <= 1 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 16, 32, 32, 1>,
+              32, 32, 1, 1};
+    if (s != 1 && fx <= 1 && fy <= 8)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 8, 32, 32, 1>,
+              32, 32, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 1 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 24, 1, 128, 16>,
+              1, 128, 16, 1};
+    if (s == 1 && fx <= 1 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 16, 1, 128, 16>,
+              1, 128, 16, 1};
+    if (s == 1 && fx <= 1 && fy <= 8)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 2, 1, 1, 1, 8, 1, 128, 16>,
+              1, 128, 16, 1};
+  }
+
+  // 2x downsampling.
+  if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) {
+    // contiguous
+    if (s != 1 && fx <= 24 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 24, 24, 32, 16, 1>,
+              32, 16, 1, 1};
+    if (s != 1 && fx <= 16 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 16, 16, 32, 16, 1>,
+              32, 16, 1, 1};
+    if (s != 1 && fx <= 8 && fy <= 8)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 8, 8, 32, 8, 1>, 32,
+              8, 1, 1};
+    if (s != 1 && fx <= 6 && fy <= 6)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 6, 6, 32, 8, 1>, 32,
+              8, 1, 1};
+    if (s != 1 && fx <= 4 && fy <= 4)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 4, 4, 32, 8, 1>, 32,
+              8, 1, 1};
+    if (s != 1 && fx <= 2 && fy <= 2)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 2, 2, 32, 8, 1>, 32,
+              8, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 24 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 24, 24, 16, 16, 1>,
+              16, 16, 1, 1};
+    if (s == 1 && fx <= 16 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 16, 16, 16, 16, 1>,
+              16, 16, 1, 1};
+    if (s == 1 && fx <= 8 && fy <= 8)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 8, 8, 8, 8, 8>, 8,
+              8, 8, 1};
+    if (s == 1 && fx <= 6 && fy <= 6)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 6, 6, 8, 8, 8>, 8,
+              8, 8, 1};
+    if (s == 1 && fx <= 4 && fy <= 4)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 4, 4, 8, 8, 8>, 8,
+              8, 8, 1};
+    if (s == 1 && fx <= 2 && fy <= 2)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 2, 2, 2, 8, 8, 8>, 8,
+              8, 8, 1};
+  }
+  if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) {
+    // contiguous
+    if (s != 1 && fx <= 24 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 24, 1, 64, 8, 1>,
+              64, 8, 1, 1};
+    if (s != 1 && fx <= 16 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 16, 1, 64, 8, 1>,
+              64, 8, 1, 1};
+    if (s != 1 && fx <= 8 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 8, 1, 64, 8, 1>, 64,
+              8, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 24 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 24, 1, 64, 1, 8>,
+              64, 1, 8, 1};
+    if (s == 1 && fx <= 16 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 16, 1, 64, 1, 8>,
+              64, 1, 8, 1};
+    if (s == 1 && fx <= 8 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 2, 1, 8, 1, 64, 1, 8>, 64,
+              1, 8, 1};
+  }
+  if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) {
+    // contiguous
+    if (s != 1 && fx <= 1 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 24, 32, 16, 1>,
+              32, 16, 1, 1};
+    if (s != 1 && fx <= 1 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 16, 32, 16, 1>,
+              32, 16, 1, 1};
+    if (s != 1 && fx <= 1 && fy <= 8)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 8, 32, 16, 1>,
+              32, 16, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 1 && fy <= 24)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 24, 1, 64, 8>, 1,
+              64, 8, 1};
+    if (s == 1 && fx <= 1 && fy <= 16)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 16, 1, 64, 8>, 1,
+              64, 8, 1};
+    if (s == 1 && fx <= 1 && fy <= 8)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 2, 1, 8, 1, 64, 8>, 1,
+              64, 8, 1};
+  }
+
+  // 4x upsampling.
+  if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) {
+    // contiguous
+    if (s != 1 && fx <= 48 && fy <= 48)
+      spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 48, 48, 64, 32, 1>,
+              64, 32, 1, 1};
+    if (s != 1 && fx <= 32 && fy <= 32)
+      spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 32, 32, 64, 32, 1>,
+              64, 32, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 48 && fy <= 48)
+      spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 48, 48, 32, 32, 1>,
+              32, 32, 1, 1};
+    if (s == 1 && fx <= 32 && fy <= 32)
+      spec = {(void *)upfirdn2d_kernel_small<T, 4, 4, 1, 1, 32, 32, 32, 32, 1>,
+              32, 32, 1, 1};
+  }
+  if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) {
+    // contiguous
+    if (s != 1 && fx <= 48 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 48, 1, 128, 8, 1>,
+              128, 8, 1, 1};
+    if (s != 1 && fx <= 32 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 32, 1, 128, 8, 1>,
+              128, 8, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 48 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 48, 1, 128, 1, 16>,
+              128, 1, 16, 1};
+    if (s == 1 && fx <= 32 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 4, 1, 1, 1, 32, 1, 128, 1, 16>,
+              128, 1, 16, 1};
+  }
+  if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) {
+    // contiguous
+    if (s != 1 && fx <= 1 && fy <= 48)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 48, 32, 32, 1>,
+              32, 32, 1, 1};
+    if (s != 1 && fx <= 1 && fy <= 32)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 32, 32, 32, 1>,
+              32, 32, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 1 && fy <= 48)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 48, 1, 128, 16>,
+              1, 128, 16, 1};
+    if (s == 1 && fx <= 1 && fy <= 32)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 4, 1, 1, 1, 32, 1, 128, 16>,
+              1, 128, 16, 1};
+  }
+
+  // 4x downsampling (inefficient).
+  if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) {
+    // contiguous
+    if (s != 1 && fx <= 48 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 48, 1, 32, 8, 1>,
+              32, 8, 1, 1};
+    if (s != 1 && fx <= 32 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 32, 1, 32, 8, 1>,
+              32, 8, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 48 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 48, 1, 32, 1, 8>,
+              32, 1, 8, 1};
+    if (s == 1 && fx <= 32 && fy <= 1)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 4, 1, 32, 1, 32, 1, 8>,
+              32, 1, 8, 1};
+  }
+  if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) {
+    // contiguous
+    if (s != 1 && fx <= 1 && fy <= 48)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 48, 32, 8, 1>,
+              32, 8, 1, 1};
+    if (s != 1 && fx <= 1 && fy <= 32)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 32, 32, 8, 1>,
+              32, 8, 1, 1};
+    // channels_last
+    if (s == 1 && fx <= 1 && fy <= 48)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 48, 1, 32, 8>, 1,
+              32, 8, 1};
+    if (s == 1 && fx <= 1 && fy <= 32)
+      spec = {(void *)upfirdn2d_kernel_small<T, 1, 1, 1, 4, 1, 32, 1, 32, 8>, 1,
+              32, 8, 1};
+  }
+  return spec;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>(
+    const upfirdn2d_kernel_params &p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>(
+    const upfirdn2d_kernel_params &p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(
+    const upfirdn2d_kernel_params &p);
+
+//------------------------------------------------------------------------
+
+//------------------------------------------------------------------------
+
+torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy,
+                           int downx, int downy, int padx0, int padx1,
+                           int pady0, int pady1, bool flip, float gain) {
+  // Validate arguments.
+  TORCH_CHECK(x.is_privateuseone(), "x must reside on MUSA device");
+  TORCH_CHECK(f.device() == x.device(),
+              "f must reside on the same device as x");
+  TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+  TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+  TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+  TORCH_CHECK(x.numel() > 0, "x has zero size");
+  TORCH_CHECK(f.numel() > 0, "f has zero size");
+  TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+  TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+  TORCH_CHECK((x.size(0) - 1) * x.stride(0) + (x.size(1) - 1) * x.stride(1) +
+                      (x.size(2) - 1) * x.stride(2) +
+                      (x.size(3) - 1) * x.stride(3) <=
+                  INT_MAX,
+              "x memory footprint is too large");
+  TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+  TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+  TORCH_CHECK(downx >= 1 && downy >= 1,
+              "downsampling factor must be at least 1");
+
+  // Create output tensor.
+  const at::musa::OptionalMUSAGuard device_guard(device_of(x));
+  int outW =
+      ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+  int outH =
+      ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
+  TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
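+  // E.g. (illustrative): x.size(3) = 16, upx = 2, padx0 = padx1 = 1 and a
+  // width-4 filter with downx = 1 give outW = (32 + 1 + 1 - 4 + 1) / 1 = 31.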
+  torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW},
+                                 x.options(), x.suggest_memory_format());
+  TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+  TORCH_CHECK((y.size(0) - 1) * y.stride(0) + (y.size(1) - 1) * y.stride(1) +
+                      (y.size(2) - 1) * y.stride(2) +
+                      (y.size(3) - 1) * y.stride(3) <=
+                  INT_MAX,
+              "output memory footprint is too large");
+
+  // Initialize MUSA kernel parameters.
+  upfirdn2d_kernel_params p;
+  p.x = x.data_ptr();
+  p.f = f.data_ptr<float>();
+  p.y = y.data_ptr();
+  p.up = make_int2(upx, upy);
+  p.down = make_int2(downx, downy);
+  p.pad0 = make_int2(padx0, pady0);
+  p.flip = (flip) ? 1 : 0;
+  p.gain = gain;
+  p.inSize =
+      make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+  p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1),
+                         (int)x.stride(0));
+  p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
+  p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
+  p.outSize =
+      make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+  p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1),
+                          (int)y.stride(0));
+  p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+  p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
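+  // channels_last input (inStride.z == 1) loops C as the minor dimension
+  // with N as major; a contiguous input flattens N * C into the major loop.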
+
+  // Choose MUSA kernel.
+  upfirdn2d_kernel_spec spec;
+  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "upfirdn2d_musa", [&] {
+    spec = choose_upfirdn2d_kernel<scalar_t>(p);
+  });
+
+  // Set looping options.
+  p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
+  p.loopMinor = spec.loopMinor;
+  p.loopX = spec.loopX;
+  p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
+  p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+  // Compute grid size.
+  dim3 blockSize, gridSize;
+  if (spec.tileOutW < 0)  // large
+  {
+    blockSize = dim3(4, 32, 1);
+    gridSize =
+        dim3(((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+             (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, p.launchMajor);
+  } else  // small
+  {
+    blockSize = dim3(256, 1, 1);
+    gridSize =
+        dim3(((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+             (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, p.launchMajor);
+  }
+
+  // Launch MUSA kernel.
+  void *args[] = {&p};
+#ifdef MMCV_WITH_HIP
+  AT_MUSA_CHECK(hipLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
+                                c10::musa::getCurrentMUSAStream()));
+#else
+  AT_MUSA_CHECK(musaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
+                                 c10::musa::getCurrentMUSAStream()));
+#endif
+
+  return y;
+}
+#else
+#warning "upfirdn2d is only supported when MUSA_ARCH > 21"
+#endif  // MUSA_ARCH
diff --git a/mmcv/ops/csrc/pytorch/musa/voxelization_musa.mu b/mmcv/ops/csrc/pytorch/musa/voxelization_musa.mu
new file mode 100644
index 000000000..b243871ca
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/voxelization_musa.mu
@@ -0,0 +1,286 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "pytorch_musa_helper.hpp"
+#include "voxelization_musa_kernel.muh"
+
+int HardVoxelizeForwardMUSAKernelLauncher(
+    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
+    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
+    const std::vector<float> coors_range, const int max_points,
+    const int max_voxels, const int NDim = 3) {
+  // the current version takes about 0.04s per frame on CPU
+  // check device
+
+  c10::musa::MUSAGuard device_guard(points.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  const int num_points = points.size(0);
+  const int num_features = points.size(1);
+
+  const float voxel_x = voxel_size[0];
+  const float voxel_y = voxel_size[1];
+  const float voxel_z = voxel_size[2];
+  const float coors_x_min = coors_range[0];
+  const float coors_y_min = coors_range[1];
+  const float coors_z_min = coors_range[2];
+  const float coors_x_max = coors_range[3];
+  const float coors_y_max = coors_range[4];
+  const float coors_z_max = coors_range[5];
+
+  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
+  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
+  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
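+  // E.g. (illustrative KITTI-like config): voxel_x = 0.16 over an x-range
+  // of [0, 70.4] yields grid_x = round(70.4 / 0.16) = 440 voxels along x.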
+
+  // map points to voxel coors
+  at::Tensor temp_coors =
+      at::zeros({num_points, NDim}, points.options().dtype(at::kInt));
+
+  dim3 grid(std::min(at::musa::ATenCeilDiv(num_points, 512), 4096));
+  dim3 block(512);
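+  // One thread per point: 512-thread blocks capped at 4096 blocks; the
+  // kernel is assumed to grid-stride over any remaining points.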
+
+  // 1. link point to corresponding voxel coors
+  AT_DISPATCH_ALL_TYPES(
+      points.scalar_type(), "hard_voxelize_kernel", ([&] {
+        dynamic_voxelize_kernel<scalar_t, int><<<grid, block, 0, stream>>>(
+            points.contiguous().data_ptr<scalar_t>(),
+            temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
+            coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
+            coors_z_max, grid_x, grid_y, grid_z, num_points, num_features,
+            NDim);
+      }));
+
+  AT_MUSA_CHECK(musaGetLastError());
+
+  // 2. map each point to the idx of its corresponding voxel, find duplicate
+  // coors
+  // create some temporary variables
+  auto point_to_pointidx = -at::ones(
+      {
+          num_points,
+      },
+      points.options().dtype(at::kInt));
+  auto point_to_voxelidx = -at::ones(
+      {
+          num_points,
+      },
+      points.options().dtype(at::kInt));
+
+  dim3 map_grid(std::min(at::musa::ATenCeilDiv(num_points, 512), 4096));
+  dim3 map_block(512);
+
+  AT_DISPATCH_ALL_TYPES(
+      temp_coors.scalar_type(), "determin_duplicate", ([&] {
+        point_to_voxelidx_kernel<int><<<map_grid, map_block, 0, stream>>>(
+            temp_coors.contiguous().data_ptr<int>(),
+            point_to_voxelidx.contiguous().data_ptr<int>(),
+            point_to_pointidx.contiguous().data_ptr<int>(), max_points,
+            max_voxels, num_points, NDim);
+      }));
+
+  AT_MUSA_CHECK(musaGetLastError());
+
+  // 3. determine the voxel count and each voxel's coor index
+  // running this logic on the MUSA device makes it roughly 10x faster
+  auto coor_to_voxelidx = -at::ones(
+      {
+          num_points,
+      },
+      points.options().dtype(at::kInt));
+  auto voxel_num = at::zeros(
+      {
+          1,
+      },
+      points.options().dtype(at::kInt));  // must be zero from the beginning
+
+  AT_DISPATCH_ALL_TYPES(temp_coors.scalar_type(), "determin_duplicate", ([&] {
+                          determin_voxel_num<int><<<1, 1, 0, stream>>>(
+                              num_points_per_voxel.contiguous().data_ptr<int>(),
+                              point_to_voxelidx.contiguous().data_ptr<int>(),
+                              point_to_pointidx.contiguous().data_ptr<int>(),
+                              coor_to_voxelidx.contiguous().data_ptr<int>(),
+                              voxel_num.contiguous().data_ptr<int>(),
+                              max_points, max_voxels, num_points);
+                        }));
+
+  AT_MUSA_CHECK(musaGetLastError());
+
+  // 4. copy point features to voxels
+  // Steps 4 & 5 could run in parallel
+  auto pts_output_size = num_points * num_features;
+  dim3 cp_grid(std::min(at::musa::ATenCeilDiv(pts_output_size, 512), 4096));
+  dim3 cp_block(512);
+  AT_DISPATCH_ALL_TYPES(
+      points.scalar_type(), "assign_point_to_voxel", ([&] {
+        assign_point_to_voxel<float, int><<<cp_grid, cp_block, 0, stream>>>(
+            pts_output_size, points.contiguous().data_ptr<float>(),
+            point_to_voxelidx.contiguous().data_ptr<int>(),
+            coor_to_voxelidx.contiguous().data_ptr<int>(),
+            voxels.contiguous().data_ptr<float>(), max_points, num_features,
+            num_points, NDim);
+      }));
+  //   musaDeviceSynchronize();
+  //   AT_MUSA_CHECK(musaGetLastError());
+
+  // 5. copy the coors of each voxel
+  auto coors_output_size = num_points * NDim;
+  dim3 coors_cp_grid(
+      std::min(at::musa::ATenCeilDiv(coors_output_size, 512), 4096));
+  dim3 coors_cp_block(512);
+  AT_DISPATCH_ALL_TYPES(
+      points.scalar_type(), "assign_point_to_voxel", ([&] {
+        assign_voxel_coors<float, int>
+            <<<coors_cp_grid, coors_cp_block, 0, stream>>>(
+                coors_output_size, temp_coors.contiguous().data_ptr<int>(),
+                point_to_voxelidx.contiguous().data_ptr<int>(),
+                coor_to_voxelidx.contiguous().data_ptr<int>(),
+                coors.contiguous().data_ptr<int>(), num_points, NDim);
+      }));
+
+  AT_MUSA_CHECK(musaGetLastError());
+
+  auto voxel_num_cpu = voxel_num.to(at::kCPU);
+  int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
+
+  return voxel_num_int;
+}
+
+int NondeterministicHardVoxelizeForwardMUSAKernelLauncher(
+    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
+    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
+    const std::vector<float> coors_range, const int max_points,
+    const int max_voxels, const int NDim = 3) {
+  c10::musa::MUSAGuard device_guard(points.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  const int num_points = points.size(0);
+  const int num_features = points.size(1);
+
+  if (num_points == 0) return 0;
+
+  dim3 blocks(
+      std::min(at::musa::ATenCeilDiv(num_points, THREADS_PER_BLOCK), 4096));
+  dim3 threads(THREADS_PER_BLOCK);
+
+  const float voxel_x = voxel_size[0];
+  const float voxel_y = voxel_size[1];
+  const float voxel_z = voxel_size[2];
+  const float coors_x_min = coors_range[0];
+  const float coors_y_min = coors_range[1];
+  const float coors_z_min = coors_range[2];
+  const float coors_x_max = coors_range[3];
+  const float coors_y_max = coors_range[4];
+  const float coors_z_max = coors_range[5];
+
+  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
+  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
+  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
+
+  // map points to voxel coors
+  at::Tensor temp_coors =
+      at::zeros({num_points, NDim}, points.options().dtype(at::kInt));
+
+  // 1. link point to corresponding voxel coors
+  AT_DISPATCH_ALL_TYPES(
+      points.scalar_type(), "hard_voxelize_kernel", ([&] {
+        dynamic_voxelize_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
+            points.contiguous().data_ptr<scalar_t>(),
+            temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
+            coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
+            coors_z_max, grid_x, grid_y, grid_z, num_points, num_features,
+            NDim);
+      }));
+
+  at::Tensor coors_map;
+  at::Tensor reduce_count;
+
+  auto coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, true), -1);
+
+  std::tie(temp_coors, coors_map, reduce_count) =
+      at::unique_dim(coors_clean, 0, true, true, false);
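+  // unique_dim(coors_clean, 0, /*sorted=*/true, /*return_inverse=*/true,
+  // /*return_counts=*/false) returns the sorted unique coors and, for each
+  // point, the index of its unique row; counts are recomputed on device
+  // below instead.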
+
+  if (temp_coors[0][0].lt(0).item<bool>()) {
+    // the first element of temp_coors is (-1,-1,-1) and should be removed
+    temp_coors = temp_coors.slice(0, 1);
+    coors_map = coors_map - 1;
+  }
+
+  int num_coors = temp_coors.size(0);
+  temp_coors = temp_coors.to(at::kInt);
+  coors_map = coors_map.to(at::kInt);
+
+  at::Tensor coors_count = at::zeros({1}, coors_map.options());
+  at::Tensor coors_order = at::empty({num_coors}, coors_map.options());
+  at::Tensor pts_id = at::zeros({num_points}, coors_map.options());
+  reduce_count = at::zeros({num_coors}, coors_map.options());
+
+  AT_DISPATCH_ALL_TYPES(
+      points.scalar_type(), "get_assign_pos", ([&] {
+        nondeterministic_get_assign_pos<<<blocks, threads, 0, stream>>>(
+            num_points, coors_map.contiguous().data_ptr<int32_t>(),
+            pts_id.contiguous().data_ptr<int32_t>(),
+            coors_count.contiguous().data_ptr<int32_t>(),
+            reduce_count.contiguous().data_ptr<int32_t>(),
+            coors_order.contiguous().data_ptr<int32_t>());
+      }));
+
+  AT_DISPATCH_ALL_TYPES(
+      points.scalar_type(), "assign_point_to_voxel", ([&] {
+        nondeterministic_assign_point_voxel<scalar_t>
+            <<<blocks, threads, 0, stream>>>(
+                num_points, points.contiguous().data_ptr<scalar_t>(),
+                coors_map.contiguous().data_ptr<int32_t>(),
+                pts_id.contiguous().data_ptr<int32_t>(),
+                temp_coors.contiguous().data_ptr<int32_t>(),
+                reduce_count.contiguous().data_ptr<int32_t>(),
+                coors_order.contiguous().data_ptr<int32_t>(),
+                voxels.contiguous().data_ptr<scalar_t>(),
+                coors.contiguous().data_ptr<int32_t>(),
+                num_points_per_voxel.contiguous().data_ptr<int32_t>(),
+                max_voxels, max_points, num_features, NDim);
+      }));
+  AT_MUSA_CHECK(musaGetLastError());
+  return max_voxels < num_coors ? max_voxels : num_coors;
+}
+
+void DynamicVoxelizeForwardMUSAKernelLauncher(
+    const at::Tensor &points, at::Tensor &coors,
+    const std::vector<float> voxel_size, const std::vector<float> coors_range,
+    const int NDim = 3) {
+  // the current version takes about 0.04s per frame on CPU
+  // check device
+
+  c10::musa::MUSAGuard device_guard(points.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+
+  const int num_points = points.size(0);
+  const int num_features = points.size(1);
+
+  const float voxel_x = voxel_size[0];
+  const float voxel_y = voxel_size[1];
+  const float voxel_z = voxel_size[2];
+  const float coors_x_min = coors_range[0];
+  const float coors_y_min = coors_range[1];
+  const float coors_z_min = coors_range[2];
+  const float coors_x_max = coors_range[3];
+  const float coors_y_max = coors_range[4];
+  const float coors_z_max = coors_range[5];
+
+  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
+  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
+  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
+
+  const int col_blocks = at::musa::ATenCeilDiv(num_points, THREADS_PER_BLOCK);
+  dim3 blocks(col_blocks);
+  dim3 threads(THREADS_PER_BLOCK);
+
+  AT_DISPATCH_ALL_TYPES(points.scalar_type(), "dynamic_voxelize_kernel", [&] {
+    dynamic_voxelize_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
+        points.contiguous().data_ptr<scalar_t>(),
+        coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
+        coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
+        coors_z_max, grid_x, grid_y, grid_z, num_points, num_features, NDim);
+  });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py
index e58a6e2a1..06be6c2a0 100644
--- a/mmcv/ops/points_in_boxes.py
+++ b/mmcv/ops/points_in_boxes.py
@@ -1,4 +1,5 @@
 import torch
+from mmengine.device import is_cuda_available, is_musa_available
 from torch import Tensor
 
 from ..utils import ext_loader
@@ -10,7 +11,7 @@ ext_module = ext_loader.load_ext('_ext', [
 
 
 def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
-    """Find the box in which each point is (CUDA).
+    """Find the box in which each point is (CUDA/MUSA).
 
     Args:
         points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate.
@@ -38,7 +39,7 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
 
     # If manually put the tensor 'points' or 'boxes' on a device
     # which is not the current device, some temporary variables
-    # will be created on the current device in the cuda op,
+    # will be created on the current device in the cuda/musa op,
     # and the output will be incorrect.
     # Therefore, we force the current device to be the same
     # as the device of the tensors if it was not.
@@ -48,8 +49,12 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
     assert points_device == boxes.get_device(), \
         'Points and boxes should be put on the same device'
     if points.device.type != 'npu':
-        if torch.cuda.current_device() != points_device:
-            torch.cuda.set_device(points_device)
+        if is_cuda_available():
+            if torch.cuda.current_device() != points_device:
+                torch.cuda.set_device(points_device)
+        elif is_musa_available():
+            if torch.musa.current_device() != points_device:
+                torch.musa.set_device(points_device)
     else:
         boxes[:, :, 2] += boxes[:, :, 5] / 2.0
 
@@ -99,7 +104,7 @@ def points_in_boxes_cpu(points: Tensor, boxes: Tensor) -> Tensor:
 
 
 def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor:
-    """Find all boxes in which each point is (CUDA).
+    """Find all boxes in which each point is (CUDA/MUSA).
 
     Args:
         points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
@@ -131,8 +136,12 @@ def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor:
     assert points_device == boxes.get_device(), \
         'Points and boxes should be put on the same device'
     if points.device.type != 'npu':
-        if torch.cuda.current_device() != points_device:
-            torch.cuda.set_device(points_device)
+        if is_cuda_available():
+            if torch.cuda.current_device() != points_device:
+                torch.cuda.set_device(points_device)
+        elif is_musa_available():
+            if torch.musa.current_device() != points_device:
+                torch.musa.set_device(points_device)
 
     ext_module.points_in_boxes_all_forward(boxes.contiguous(),
                                            points.contiguous(),
diff --git a/mmcv/ops/sync_bn.py b/mmcv/ops/sync_bn.py
index 2b14d3037..ef3c742ae 100644
--- a/mmcv/ops/sync_bn.py
+++ b/mmcv/ops/sync_bn.py
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from mmengine.device import is_cuda_available, is_musa_available
 from mmengine.registry import MODELS
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
@@ -47,10 +48,20 @@ class SyncBatchNormFunction(Function):
         self.group_size = group_size
         self.stats_mode = stats_mode
 
-        assert isinstance(
-                   input, (torch.HalfTensor, torch.FloatTensor,
-                           torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
-               f'only support Half or Float Tensor, but {input.type()}'
+        if is_cuda_available():
+            assert isinstance(
+                    input, (torch.HalfTensor, torch.FloatTensor,
+                            torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
+                f'only support Half or Float Tensor, but {input.type()}'
+        elif is_musa_available():
+            assert isinstance(
+                    input, (torch.HalfTensor, torch.FloatTensor,
+                            torch.musa.HalfTensor, torch.musa.FloatTensor)), \
+                f'only support Half or Float Tensor, but {input.type()}'
+        else:
+            assert isinstance(
+                    input, (torch.HalfTensor, torch.FloatTensor)), \
+                f'only support Half or Float Tensor, but {input.type()}'
         output = torch.zeros_like(input)
         input3d = input.flatten(start_dim=2)
         output3d = output.view_as(input3d)
diff --git a/mmcv/ops/upfirdn2d.py b/mmcv/ops/upfirdn2d.py
index e01509503..6f251720b 100644
--- a/mmcv/ops/upfirdn2d.py
+++ b/mmcv/ops/upfirdn2d.py
@@ -116,6 +116,13 @@ def upfirdn2d(input: torch.Tensor,
             padding=padding,
             flip_filter=flip_filter,
             gain=gain).apply(input, filter)
+    elif use_custom_op and input.device.type == 'musa':
+        return _upfirdn2d_musa(
+            up=up,
+            down=down,
+            padding=padding,
+            flip_filter=flip_filter,
+            gain=gain).apply(input, filter)
     return _upfirdn2d_ref(
         input,
         filter,
@@ -303,6 +310,101 @@ def _upfirdn2d_cuda(up: int = 1,
     return Upfirdn2dCuda
 
 
+_upfirdn2d_musa_cache: Dict = dict()
+
+
+def _upfirdn2d_musa(up: int = 1,
+                    down: int = 1,
+                    padding: Union[int, List[int]] = 0,
+                    flip_filter: bool = False,
+                    gain: Union[float, int] = 1):
+    """Fast MUSA implementation of `upfirdn2d()` using custom ops.
+
+    Args:
+        up (int | list[int]): Integer upsampling factor. Can be a single
+            int or a list/tuple `[x, y]`. Defaults to 1.
+        down (int | list[int]): Integer downsampling factor. Can be a
+            single int or a list/tuple `[x, y]`. Defaults to 1.
+        padding (int | tuple[int]): Padding with respect to the upsampled
+            image. Can be a single number or a list/tuple `[x, y]` or
+            `[x_before, x_after, y_before, y_after]`. Defaults to 0.
+        flip_filter (bool): False = convolution, True = correlation.
+            Defaults to False.
+        gain (float | int): Overall scaling factor for signal magnitude.
+            Defaults to 1.
+
+    Returns:
+        torch.Tensor: Tensor of the shape `[batch_size, num_channels,
+        out_height, out_width]`
+    """
+    # Parse arguments.
+    upx, upy = _parse_scaling(up)
+    downx, downy = _parse_scaling(down)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+    # Lookup from cache.
+    key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter,
+           gain)
+    if key in _upfirdn2d_musa_cache:
+        return _upfirdn2d_musa_cache[key]
+
+    # Forward op.
+    class Upfirdn2dMusa(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, x, f):  # pylint: disable=arguments-differ
+            assert isinstance(x, torch.Tensor) and x.ndim == 4
+            if f is None:
+                f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+            if f.ndim == 1 and f.shape[0] == 1:
+                f = f.square().unsqueeze(
+                    0)  # Convert separable-1 into full-1x1.
+            assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+            y = x
+            if f.ndim == 2:
+                y = ext_module.upfirdn2d(y, f, upx, upy, downx, downy, padx0,
+                                         padx1, pady0, pady1, flip_filter,
+                                         gain)
+            else:
+                y = ext_module.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1,
+                                         padx0, padx1, 0, 0, flip_filter, 1.0)
+                y = ext_module.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy,
+                                         0, 0, pady0, pady1, flip_filter, gain)
+            ctx.save_for_backward(f)
+            ctx.x_shape = x.shape
+            return y
+
+        @staticmethod
+        def backward(ctx, dy):  # pylint: disable=arguments-differ
+            f, = ctx.saved_tensors
+            _, _, ih, iw = ctx.x_shape
+            _, _, oh, ow = dy.shape
+            fw, fh = _get_filter_size(f)
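+            # The gradient of upfirdn2d is again upfirdn2d with up/down
+            # swapped and the filter flipped; this padding makes the
+            # transposed op cover exactly the original (ih, iw) extent.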
+            p = [
+                fw - padx0 - 1,
+                iw * upx - ow * downx + padx0 - upx + 1,
+                fh - pady0 - 1,
+                ih * upy - oh * downy + pady0 - upy + 1,
+            ]
+            dx = None
+            df = None
+
+            if ctx.needs_input_grad[0]:
+                dx = _upfirdn2d_musa(
+                    up=down,
+                    down=up,
+                    padding=p,
+                    flip_filter=(not flip_filter),
+                    gain=gain).apply(dy, f)
+
+            assert not ctx.needs_input_grad[1]
+            return dx, df
+
+    # Add to cache.
+    _upfirdn2d_musa_cache[key] = Upfirdn2dMusa
+    return Upfirdn2dMusa
+
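+# Minimal usage sketch (illustrative; assumes a MUSA device and a MUSA build
+# of the `_ext` upfirdn2d op):
+#
+#     x = torch.randn(1, 3, 16, 16, device='musa')
+#     f = torch.ones(4, 4, device='musa') / 16  # 4x4 box filter
+#     y = _upfirdn2d_musa(up=2, padding=1).apply(x, f)
+#     # y.shape == (1, 3, 31, 31): (16*2 + 1 + 1 - 4 + 1) // 1 = 31
+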
+
 def filter2d(input: torch.Tensor,
              filter: torch.Tensor,
              padding: Union[int, List[int]] = 0,
diff --git a/tests/test_ops/test_points_in_polygons.py b/tests/test_ops/test_points_in_polygons.py
index d224d1593..cdfbef579 100644
--- a/tests/test_ops/test_points_in_polygons.py
+++ b/tests/test_ops/test_points_in_polygons.py
@@ -4,7 +4,7 @@ import pytest
 import torch
 
 from mmcv.ops import points_in_polygons
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE
 
 
 @pytest.mark.parametrize('device', [
@@ -15,7 +15,11 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 def test_points_in_polygons(device):
     points = np.array([[300., 300.], [400., 400.], [100., 100], [300, 250],
diff --git a/tests/test_ops/test_prroi_pool.py b/tests/test_ops/test_prroi_pool.py
index 0535dfbe2..b7fd52f95 100644
--- a/tests/test_ops/test_prroi_pool.py
+++ b/tests/test_ops/test_prroi_pool.py
@@ -3,7 +3,7 @@ import numpy as np
 import pytest
 import torch
 
-from mmcv.utils import IS_CUDA_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE
 
 _USING_PARROTS = True
 try:
@@ -41,7 +41,11 @@ class TestPrRoiPool:
         pytest.param(
             'cuda',
             marks=pytest.mark.skipif(
-                not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'musa',
+            marks=pytest.mark.skipif(
+                not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
     ])
     def test_roipool_gradcheck(self, device):
         from mmcv.ops import PrRoIPool
@@ -92,7 +96,11 @@ class TestPrRoiPool:
         pytest.param(
             'cuda',
             marks=pytest.mark.skipif(
-                not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'musa',
+            marks=pytest.mark.skipif(
+                not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
     ])
     def test_roipool_allclose_float(self, device):
         self._test_roipool_allclose(device, dtype=torch.float)
diff --git a/tests/test_ops/test_psa_mask.py b/tests/test_ops/test_psa_mask.py
index b0fd86e8f..d692f4d78 100644
--- a/tests/test_ops/test_psa_mask.py
+++ b/tests/test_ops/test_psa_mask.py
@@ -4,7 +4,8 @@ import pytest
 import torch
 import torch.nn as nn
 
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)
 
 
 class Loss(nn.Module):
@@ -32,7 +33,11 @@ class TestPSAMask:
         pytest.param(
             'npu',
             marks=pytest.mark.skipif(
-                not IS_NPU_AVAILABLE, reason='requires NPU support'))
+                not IS_NPU_AVAILABLE, reason='requires NPU support')),
+        pytest.param(
+            'musa',
+            marks=pytest.mark.skipif(
+                not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
     ])
     def test_psa_mask_collect(self, device):
         from mmcv.ops import PSAMask
@@ -84,7 +89,11 @@ class TestPSAMask:
         pytest.param(
             'npu',
             marks=pytest.mark.skipif(
-                not IS_NPU_AVAILABLE, reason='requires NPU support'))
+                not IS_NPU_AVAILABLE, reason='requires NPU support')),
+        pytest.param(
+            'musa',
+            marks=pytest.mark.skipif(
+                not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
     ])
     def test_psa_mask_distribute(self, device):
         from mmcv.ops import PSAMask
diff --git a/tests/test_ops/test_roi_align.py b/tests/test_ops/test_roi_align.py
index dcd210346..f9a51eb28 100644
--- a/tests/test_ops/test_roi_align.py
+++ b/tests/test_ops/test_roi_align.py
@@ -3,7 +3,8 @@ import numpy as np
 import pytest
 import torch
 
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)
 
 _USING_PARROTS = True
 try:
@@ -107,7 +108,11 @@ def _test_roialign_allclose(device, dtype):
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 def test_roialign_float(device, dtype):
     _test_roialign_allclose(device=device, dtype=dtype)
diff --git a/tests/test_ops/test_roi_align_rotated.py b/tests/test_ops/test_roi_align_rotated.py
index 0d5ca432d..ddfacea0b 100644
--- a/tests/test_ops/test_roi_align_rotated.py
+++ b/tests/test_ops/test_roi_align_rotated.py
@@ -3,7 +3,7 @@ import numpy as np
 import pytest
 import torch
 
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
 
 _USING_PARROTS = True
 try:
@@ -132,15 +132,19 @@ def _test_roialign_rotated_allclose(device, dtype):
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 @pytest.mark.parametrize('dtype', [
     torch.float,
     pytest.param(
         torch.double,
         marks=pytest.mark.skipif(
-            IS_MLU_AVAILABLE,
-            reason='MLU does not support for 64-bit floating point')),
+            IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE,
+            reason='MLU and MUSA do not support 64-bit floating point')),
     torch.half
 ])
 def test_roialign_rotated(device, dtype):
diff --git a/tests/test_ops/test_roi_pool.py b/tests/test_ops/test_roi_pool.py
index 5ab04bce2..f6d38d5af 100644
--- a/tests/test_ops/test_roi_pool.py
+++ b/tests/test_ops/test_roi_pool.py
@@ -5,7 +5,8 @@ import numpy as np
 import pytest
 import torch
 
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)
 
 _USING_PARROTS = True
 try:
@@ -89,16 +90,20 @@ class TestRoiPool:
         pytest.param(
             'npu',
             marks=pytest.mark.skipif(
-                not IS_NPU_AVAILABLE, reason='requires NPU support'))
+                not IS_NPU_AVAILABLE, reason='requires NPU support')),
+        pytest.param(
+            'musa',
+            marks=pytest.mark.skipif(
+                not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
     ])
     @pytest.mark.parametrize('dtype', [
         torch.float,
         pytest.param(
             torch.double,
             marks=pytest.mark.skipif(
-                IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
-                reason='MLU, NPU does not support for 64-bit floating point')),
-        torch.half
+                IS_MLU_AVAILABLE or IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE,
+                reason='MLU, NPU and MUSA do not '
+                'support 64-bit floating point')), torch.half
     ])
     def test_roipool_allclose(self, device, dtype):
         self._test_roipool_allclose(device, dtype)
diff --git a/tests/test_ops/test_roiaware_pool3d.py b/tests/test_ops/test_roiaware_pool3d.py
index 189db33cc..338a1544c 100644
--- a/tests/test_ops/test_roiaware_pool3d.py
+++ b/tests/test_ops/test_roiaware_pool3d.py
@@ -5,7 +5,8 @@ import torch
 
 from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu,
                       points_in_boxes_part)
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)
 
 
 @pytest.mark.parametrize('dtype', [
@@ -13,7 +14,8 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         torch.double,
         marks=pytest.mark.skipif(
-            IS_MLU_AVAILABLE, reason='MLU does not support for double'))
+            IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE,
+            reason='MLU and MUSA do not support double'))
 ])
 @pytest.mark.parametrize('device', [
     pytest.param(
@@ -23,7 +25,11 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_RoIAwarePool3d(device, dtype):
     roiaware_pool3d_max = RoIAwarePool3d(
@@ -64,7 +70,11 @@ def test_RoIAwarePool3d(device, dtype):
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_points_in_boxes_part(device):
     boxes = torch.tensor(
diff --git a/tests/test_ops/test_roipoint_pool3d.py b/tests/test_ops/test_roipoint_pool3d.py
index f0ad5586a..c3109a363 100644
--- a/tests/test_ops/test_roipoint_pool3d.py
+++ b/tests/test_ops/test_roipoint_pool3d.py
@@ -3,7 +3,8 @@ import pytest
 import torch
 
 from mmcv.ops import RoIPointPool3d
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)
 
 
 @pytest.mark.parametrize('device', [
@@ -18,15 +19,19 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 @pytest.mark.parametrize('dtype', [
     torch.float, torch.half,
     pytest.param(
         torch.double,
         marks=pytest.mark.skipif(
-            IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
-            reason='MLU and NPU does not support for double'))
+            IS_MLU_AVAILABLE or IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE,
+            reason='MLU, NPU and MUSA do not support double'))
 ])
 def test_roipoint(device, dtype):
     points = torch.tensor(
diff --git a/tests/test_ops/test_rotated_feature_align.py b/tests/test_ops/test_rotated_feature_align.py
index 23de07e8e..6447f410a 100644
--- a/tests/test_ops/test_rotated_feature_align.py
+++ b/tests/test_ops/test_rotated_feature_align.py
@@ -3,7 +3,8 @@ import pytest
 import torch
 
 from mmcv.ops import rotated_feature_align
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)
 
 
 @pytest.mark.skipif(
@@ -21,6 +22,10 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
         'npu',
         marks=pytest.mark.skipif(
             not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
     pytest.param(
         'cpu',
         marks=pytest.mark.skipif(
diff --git a/tests/test_ops/test_scatter_points.py b/tests/test_ops/test_scatter_points.py
index b8b569481..46ab4430e 100644
--- a/tests/test_ops/test_scatter_points.py
+++ b/tests/test_ops/test_scatter_points.py
@@ -4,7 +4,7 @@ import torch
 from torch.autograd import gradcheck
 
 from mmcv.ops import DynamicScatter
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
 
 if torch.__version__ == 'parrots':
     pytest.skip('not supported in parrots now', allow_module_level=True)
@@ -18,7 +18,11 @@ if torch.__version__ == 'parrots':
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 def test_dynamic_scatter(device):
     dsmean = DynamicScatter([0.32, 0.32, 6],
diff --git a/tests/test_ops/test_spconv.py b/tests/test_ops/test_spconv.py
index 17ca5678e..6d958e021 100644
--- a/tests/test_ops/test_spconv.py
+++ b/tests/test_ops/test_spconv.py
@@ -10,7 +10,7 @@ from mmcv.ops import (SparseConvTensor, SparseInverseConv3d, SparseSequential,
 if torch.__version__ == 'parrots':
     pytest.skip('not supported in parrots now', allow_module_level=True)
 
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
 
 
 def make_sparse_convmodule(in_channels,
@@ -86,10 +86,17 @@ def make_sparse_convmodule(in_channels,
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 def test_make_sparse_convmodule(device):
-    torch.cuda.empty_cache()
+    if IS_CUDA_AVAILABLE:
+        torch.cuda.empty_cache()
+    elif IS_MUSA_AVAILABLE:
+        torch.musa.empty_cache()
     voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
                                    [6.8162713, -2.480431, -1.3616394, 0.36],
                                    [11.643568, -4.744306, -1.3580885, 0.16],
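
test_make_sparse_convmodule now branches on the available backend before
clearing the allocator cache. If the same if/elif pair spreads further, it
could be factored into a helper; the sketch below is a hypothetical
refactoring, not code from the patch (empty_device_cache is an invented name,
and torch.musa is only importable once torch_musa is installed):

import torch

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE


def empty_device_cache():
    """Release cached allocator blocks on whichever backend is present.

    Mirrors the branching used in test_make_sparse_convmodule.
    """
    if IS_CUDA_AVAILABLE:
        torch.cuda.empty_cache()
    elif IS_MUSA_AVAILABLE:
        torch.musa.empty_cache()
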
diff --git a/tests/test_ops/test_syncbn.py b/tests/test_ops/test_syncbn.py
index d1c1605ad..dd046b9e2 100644
--- a/tests/test_ops/test_syncbn.py
+++ b/tests/test_ops/test_syncbn.py
@@ -8,6 +8,8 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE
+
 if platform.system() == 'Windows':
     import regex as re
 else:
@@ -29,10 +31,24 @@ class TestSyncBN:
         os.environ['WORLD_SIZE'] = str(world_size)
         os.environ['RANK'] = str(rank)
 
-        dist.init_process_group('nccl')
-        torch.cuda.set_device(local_rank)
+        if IS_CUDA_AVAILABLE:
+            dist.init_process_group('nccl')
+            torch.cuda.set_device(local_rank)
+        elif IS_MUSA_AVAILABLE:
+            dist.init_process_group('mccl')
+            torch.musa.set_device(local_rank)
 
-    def _test_syncbn_train(self, size=1, half=False):
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'musa',
+            marks=pytest.mark.skipif(
+                not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
+    ])
+    def _test_syncbn_train(self, size=1, half=False, device='cuda'):
 
         if 'SLURM_NTASKS' not in os.environ or int(
                 os.environ['SLURM_NTASKS']) != 4:
@@ -49,10 +65,13 @@ class TestSyncBN:
         rank = dist.get_rank()
 
         torch.manual_seed(9)
-        torch.cuda.manual_seed(9)
+        if IS_CUDA_AVAILABLE:
+            torch.cuda.manual_seed(9)
+        elif IS_MUSA_AVAILABLE:
+            torch.musa.manual_seed(9)
 
-        self.x = torch.rand(16, 3, 2, 3).cuda()
-        self.y_bp = torch.rand(16, 3, 2, 3).cuda()
+        self.x = torch.rand(16, 3, 2, 3).to(device)
+        self.y_bp = torch.rand(16, 3, 2, 3).to(device)
 
         if half:
             self.x = self.x.half()
@@ -60,7 +79,10 @@ class TestSyncBN:
         dist.broadcast(self.x, src=0)
         dist.broadcast(self.y_bp, src=0)
 
-        torch.cuda.synchronize()
+        if IS_CUDA_AVAILABLE:
+            torch.cuda.synchronize()
+        elif IS_MUSA_AVAILABLE:
+            torch.musa.synchronize()
         if size == 1:
             groups = [None, None, None, None]
             groups[0] = dist.new_group([0])
@@ -75,13 +97,13 @@ class TestSyncBN:
             group = groups[rank]
         elif size == 4:
             group = dist.group.WORLD
-        syncbn = SyncBatchNorm(3, group=group).cuda()
+        syncbn = SyncBatchNorm(3, group=group).to(device)
         syncbn.weight.data[0] = 0.2
         syncbn.weight.data[1] = 0.5
         syncbn.weight.data[2] = 0.7
         syncbn.train()
 
-        bn = nn.BatchNorm2d(3).cuda()
+        bn = nn.BatchNorm2d(3).to(device)
         bn.weight.data[0] = 0.2
         bn.weight.data[1] = 0.5
         bn.weight.data[2] = 0.7
@@ -143,7 +165,17 @@ class TestSyncBN:
         assert np.allclose(x_grad.data.cpu().numpy(),
                            sx_grad.data.cpu().numpy(), 1e-2)
 
-    def _test_syncbn_empty_train(self, size=1, half=False):
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'musa',
+            marks=pytest.mark.skipif(
+                not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
+    ])
+    def _test_syncbn_empty_train(self, size=1, half=False, device='cuda'):
 
         if 'SLURM_NTASKS' not in os.environ or int(
                 os.environ['SLURM_NTASKS']) != 4:
@@ -160,10 +192,13 @@ class TestSyncBN:
         rank = dist.get_rank()
 
         torch.manual_seed(9)
-        torch.cuda.manual_seed(9)
+        if IS_CUDA_AVAILABLE:
+            torch.cuda.manual_seed(9)
+        elif IS_MUSA_AVAILABLE:
+            torch.musa.manual_seed(9)
 
-        self.x = torch.rand(0, 3, 2, 3).cuda()
-        self.y_bp = torch.rand(0, 3, 2, 3).cuda()
+        self.x = torch.rand(0, 3, 2, 3).to(device)
+        self.y_bp = torch.rand(0, 3, 2, 3).to(device)
 
         if half:
             self.x = self.x.half()
@@ -171,7 +206,10 @@ class TestSyncBN:
         dist.broadcast(self.x, src=0)
         dist.broadcast(self.y_bp, src=0)
 
-        torch.cuda.synchronize()
+        if IS_CUDA_AVAILABLE:
+            torch.cuda.synchronize()
+        elif IS_MUSA_AVAILABLE:
+            torch.musa.synchronize()
         if size == 1:
             groups = [None, None, None, None]
             groups[0] = dist.new_group([0])
@@ -187,13 +225,13 @@ class TestSyncBN:
         elif size == 4:
             group = dist.group.WORLD
 
-        syncbn = SyncBatchNorm(3, group=group, stats_mode='N').cuda()
+        syncbn = SyncBatchNorm(3, group=group, stats_mode='N').to(device)
         syncbn.weight.data[0] = 0.2
         syncbn.weight.data[1] = 0.5
         syncbn.weight.data[2] = 0.7
         syncbn.train()
 
-        bn = nn.BatchNorm2d(3).cuda()
+        bn = nn.BatchNorm2d(3).to(device)
         bn.weight.data[0] = 0.2
         bn.weight.data[1] = 0.5
         bn.weight.data[2] = 0.7
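
The TestSyncBN changes pick the collective backend to match the device: NCCL
for CUDA and MCCL, the collective library torch_musa registers, for MUSA; the
rest of the test then flows through the same to(device) path. A condensed
sketch of the process-group setup, assuming MASTER_ADDR, MASTER_PORT,
WORLD_SIZE and RANK are already exported as in dist_init:

import torch
import torch.distributed as dist

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE


def init_dist(local_rank):
    # NCCL is CUDA-only; torch_musa ships the analogous 'mccl' backend.
    if IS_CUDA_AVAILABLE:
        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)
    elif IS_MUSA_AVAILABLE:
        dist.init_process_group('mccl')
        torch.musa.set_device(local_rank)
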
diff --git a/tests/test_ops/test_three_interpolate.py b/tests/test_ops/test_three_interpolate.py
index d27a795ec..dd8e25c89 100644
--- a/tests/test_ops/test_three_interpolate.py
+++ b/tests/test_ops/test_three_interpolate.py
@@ -3,7 +3,7 @@ import pytest
 import torch
 
 from mmcv.ops import three_interpolate
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE
 
 
 @pytest.mark.parametrize('dtype', [
@@ -11,8 +11,8 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         torch.double,
         marks=pytest.mark.skipif(
-            IS_NPU_AVAILABLE,
-            reason='NPU does not support for 64-bit floating point'))
+            IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE,
+            reason='NPU and MUSA do not support 64-bit floating point'))
 ])
 @pytest.mark.parametrize('device', [
     pytest.param(
@@ -22,9 +22,15 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
-            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+            not IS_NPU_AVAILABLE, reason='requires NPU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_three_interpolate(dtype, device):
+    if IS_MUSA_AVAILABLE:
+        torch.musa.empty_cache()
     features = torch.tensor(
         [[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
           [3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],
diff --git a/tests/test_ops/test_three_nn.py b/tests/test_ops/test_three_nn.py
index 456188b91..9348dd0d5 100644
--- a/tests/test_ops/test_three_nn.py
+++ b/tests/test_ops/test_three_nn.py
@@ -3,7 +3,7 @@ import pytest
 import torch
 
 from mmcv.ops import three_nn
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
 
 known = [[[-1.8373, 3.5605, -0.7867], [0.7615, 2.9420, 0.2314],
           [-0.6503, 3.6637, -1.0622], [-1.8373, 3.5605, -0.7867],
@@ -48,7 +48,11 @@ expected_idx = [[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4], [2, 1, 0],
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 @pytest.mark.parametrize('dtype,rtol', [(torch.float, 1e-8),
                                         (torch.half, 1e-3)])
diff --git a/tests/test_ops/test_tin_shift.py b/tests/test_ops/test_tin_shift.py
index c8ce14465..3d82de73b 100755
--- a/tests/test_ops/test_tin_shift.py
+++ b/tests/test_ops/test_tin_shift.py
@@ -5,7 +5,7 @@ import numpy as np
 import pytest
 import torch
 
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
 
 _USING_PARROTS = True
 try:
@@ -209,15 +209,19 @@ def _test_tinshift_assert(device, dtype):
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
 ])
 @pytest.mark.parametrize('dtype', [
     torch.float,
     pytest.param(
         torch.double,
         marks=pytest.mark.skipif(
-            IS_MLU_AVAILABLE,
-            reason='MLU does not support for 64-bit floating point')),
+            IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE,
+            reason='MLU and MUSA do not support 64-bit floating point')),
     torch.half
 ])
 def test_tinshift(device, dtype):
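
As in the other dtype matrices, torch.double is gated on the availability
flag rather than on the active device param, so the presence of an MLU or
MUSA device removes the double case for every device in the list, CPU
included; that mirrors the pre-existing MLU behavior. A self-contained sketch
of the gating (the test body is illustrative):

import pytest
import torch

from mmcv.utils import IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE


@pytest.mark.parametrize('dtype', [
    torch.float,
    pytest.param(
        torch.double,
        marks=pytest.mark.skipif(
            IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE,
            reason='MLU and MUSA do not support 64-bit floating point')),
    torch.half,
])
def test_dtype_gating(dtype):
    # The double param disappears from collection as soon as either backend
    # is available, independent of the device being exercised.
    assert torch.tensor([1.0], dtype=dtype).dtype == dtype
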
diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py
index 78282a8ad..06167c1d3 100644
--- a/tests/test_ops/test_voxelization.py
+++ b/tests/test_ops/test_voxelization.py
@@ -4,7 +4,8 @@ import pytest
 import torch
 
 from mmcv.ops import Voxelization
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
+                        IS_NPU_AVAILABLE)
 
 
 def _get_voxel_points_indices(points, coors, voxel):
@@ -215,3 +216,38 @@ def test_voxelization_npu(device_type):
     assert np.all(coors == expected_coors)
     assert np.all(voxels == expected_voxels)
     assert np.all(num_points_per_voxel == expected_num_points_per_voxel)
+
+
+@pytest.mark.parametrize('device_type', [
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
+])
+def test_voxelization_musa(device_type):
+    voxel_size = [0.5, 0.5, 0.5]
+    point_cloud_range = [0, -40, -3, 70.4, 40, 1]
+
+    voxel_dict = np.load(
+        'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item()
+    expected_coors = voxel_dict['coors']
+    expected_voxels = voxel_dict['voxels']
+    expected_num_points_per_voxel = voxel_dict['num_points_per_voxel']
+    points = voxel_dict['points']
+
+    points = torch.tensor(points)
+    max_num_points = 1000
+    hard_voxelization = Voxelization(voxel_size, point_cloud_range,
+                                     max_num_points)
+
+    device = torch.device(device_type)
+
+    # test hard_voxelization on musa
+    points = points.contiguous().to(device)
+    coors, voxels, num_points_per_voxel = hard_voxelization.forward(points)
+    coors = coors.cpu().detach().numpy()
+    voxels = voxels.cpu().detach().numpy()
+    num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy()
+    assert np.all(coors == expected_coors)
+    assert np.all(voxels == expected_voxels)
+    assert np.all(num_points_per_voxel == expected_num_points_per_voxel)
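
test_voxelization_musa follows the existing CUDA/MLU/NPU variants: load the
reference arrays from tests/data/for_3d_ops/test_voxel.npy, run hard
voxelization on the device, and compare on the CPU. For orientation, a
minimal usage sketch of the Voxelization module itself, with random points
standing in for the fixture and the unpacking order copied from the test
above:

import torch

from mmcv.ops import Voxelization

# Hard voxelization: each voxel keeps at most max_num_points points, and
# points outside point_cloud_range are dropped.
voxelizer = Voxelization(
    voxel_size=[0.5, 0.5, 0.5],
    point_cloud_range=[0, -40, -3, 70.4, 40, 1],
    max_num_points=1000)

points = torch.rand(100, 4)           # (x, y, z, feature) per point
points[:, 0] = points[:, 0] * 70.4    # scale into the declared range
points[:, 1] = points[:, 1] * 80 - 40
points[:, 2] = points[:, 2] * 4 - 3

coors, voxels, num_points_per_voxel = voxelizer(points)
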