[MUSA] Support MUSA PR2

* squash commit for support musa pr2

* fix conv2d_gradfix:Conv2d is_cuda_available judgment

* Revert "fix conv2d_gradfix:Conv2d is_cuda_available judgment"

This reverts commit 38faea16c1.

* fix conv2d_gradfix:Conv2d comment
sunyanguomt 2025-03-11 12:28:07 +08:00 committed by GitHub
parent c4ea0e6d97
commit c7d3326d8a
41 changed files with 4569 additions and 71 deletions

View File

@@ -114,6 +114,81 @@ activation_funcs = {
has_2nd_grad=True),
}
activation_funcs_musa = {
'linear':
EasyDict(
func=lambda x, **_: x,
def_alpha=0,
def_gain=1,
musa_idx=1,
ref='',
has_2nd_grad=False),
'relu':
EasyDict(
func=lambda x, **_: torch.nn.functional.relu(x),
def_alpha=0,
def_gain=np.sqrt(2),
musa_idx=2,
ref='y',
has_2nd_grad=False),
'lrelu':
EasyDict(
func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
def_alpha=0.2,
def_gain=np.sqrt(2),
musa_idx=3,
ref='y',
has_2nd_grad=False),
'tanh':
EasyDict(
func=lambda x, **_: torch.tanh(x),
def_alpha=0,
def_gain=1,
musa_idx=4,
ref='y',
has_2nd_grad=True),
'sigmoid':
EasyDict(
func=lambda x, **_: torch.sigmoid(x),
def_alpha=0,
def_gain=1,
musa_idx=5,
ref='y',
has_2nd_grad=True),
'elu':
EasyDict(
func=lambda x, **_: torch.nn.functional.elu(x),
def_alpha=0,
def_gain=1,
musa_idx=6,
ref='y',
has_2nd_grad=True),
'selu':
EasyDict(
func=lambda x, **_: torch.nn.functional.selu(x),
def_alpha=0,
def_gain=1,
musa_idx=7,
ref='y',
has_2nd_grad=True),
'softplus':
EasyDict(
func=lambda x, **_: torch.nn.functional.softplus(x),
def_alpha=0,
def_gain=1,
musa_idx=8,
ref='y',
has_2nd_grad=True),
'swish':
EasyDict(
func=lambda x, **_: torch.sigmoid(x) * x,
def_alpha=0,
def_gain=np.sqrt(2),
musa_idx=9,
ref='x',
has_2nd_grad=True),
}
_null_tensor = torch.empty([0])
@@ -167,6 +242,13 @@ def bias_act(input: torch.Tensor,
return _bias_act_cuda(
dim=dim, act=act, alpha=alpha, gain=gain,
clamp=clamp).apply(input, bias)
try:
if use_custom_op and input.is_musa:
return _bias_act_musa(
dim=dim, act=act, alpha=alpha, gain=gain,
clamp=clamp).apply(input, bias)
except AttributeError:
pass
return _bias_act_ref(
input=input,
bias=bias,
@@ -373,3 +455,126 @@ def _bias_act_cuda(dim: int = 1,
# Add to cache.
_bias_act_cuda_cache[key] = BiasActCuda
return BiasActCuda
_bias_act_musa_cache: Dict = dict()
def _bias_act_musa(dim: int = 1,
act: str = 'linear',
alpha: Optional[Union[float, int]] = None,
gain: Optional[float] = None,
clamp: Optional[float] = None):
""""Fast MUSA implementation of `bias_act()` using custom ops.
Args:
dim (int): The dimension in `x` corresponding to the elements of `b`.
The value of `dim` is ignored if `b` is not specified.
Defaults to 1.
act (str): Name of the activation function to evaluate, or `"linear"`
to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid",
"swish", etc. See `activation_funcs_musa` for a full list. `None`
is not allowed. Defaults to `linear`.
alpha (float | int): Shape parameter for the activation
function, or `None` to use the default. Defaults to None.
gain (float): Scaling factor for the output tensor, or `None`
to use the default. See `activation_funcs_musa` for the default scaling
of each activation function. If unsure, consider specifying 1.
Defaults to None.
clamp (float): Clamp the output values to `[-clamp, +clamp]`,
or `None` to disable the clamping (default). Defaults to None.
Returns:
torch.Tensor: Tensor of the same shape and datatype as `x`.
"""
# Parse arguments.
assert clamp is None or clamp >= 0
spec = activation_funcs_musa[act]
alpha = float(alpha if alpha is not None else spec.def_alpha)
gain = float(gain if gain is not None else spec.def_gain)
clamp = float(clamp if clamp is not None else -1)
# Lookup from cache.
key = (dim, act, alpha, gain, clamp)
if key in _bias_act_musa_cache:
return _bias_act_musa_cache[key]
# Forward op.
class BiasActMusa(torch.autograd.Function):
@staticmethod
def forward(ctx, x, b): # pylint: disable=arguments-differ
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(
1) == 1 else torch.contiguous_format
x = x.contiguous(memory_format=ctx.memory_format)
b = b.contiguous() if b is not None else _null_tensor.to(x.device)
y = x
if act != 'linear' or gain != 1 or clamp >= 0 or (
b is not _null_tensor.to(x.device)):
y = ext_module.bias_act(x, b, _null_tensor.to(x.device),
_null_tensor.to(x.device),
_null_tensor.to(x.device), 0, dim,
spec.musa_idx, alpha, gain, clamp)
ctx.save_for_backward(
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor.to(
x.device), b if 'x' in spec.ref or spec.has_2nd_grad else
_null_tensor.to(x.device),
y if 'y' in spec.ref else _null_tensor.to(x.device))
return y
@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
dy = dy.contiguous(memory_format=ctx.memory_format)
x, b, y = ctx.saved_tensors
dx = None
db = None
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
dx = dy
if act != 'linear' or gain != 1 or clamp >= 0:
dx = BiasActMusaGrad.apply(dy, x, b, y)
if ctx.needs_input_grad[1]:
db = dx.sum([i for i in range(dx.ndim) if i != dim])
return dx, db
# Backward op.
class BiasActMusaGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
ctx.memory_format = torch.channels_last if dy.ndim > 2 and (
dy.stride(1) == 1) else torch.contiguous_format
dx = ext_module.bias_act(dy, b, x, y, _null_tensor.to(x.device), 1,
dim, spec.musa_idx, alpha, gain, clamp)
ctx.save_for_backward(
dy if spec.has_2nd_grad else _null_tensor.to(x.device), x, b,
y)
return dx
@staticmethod
def backward(ctx, d_dx): # pylint: disable=arguments-differ
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
dy, x, b, y = ctx.saved_tensors
d_dy = None
d_x = None
d_b = None
d_y = None
if ctx.needs_input_grad[0]:
d_dy = BiasActMusaGrad.apply(d_dx, x, b, y)
if spec.has_2nd_grad and (ctx.needs_input_grad[1]
or ctx.needs_input_grad[2]):
d_x = ext_module.bias_act(d_dx, b, x, y, dy, 2, dim,
spec.musa_idx, alpha, gain, clamp)
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
return d_dy, d_x, d_b, d_y
# Add to cache.
_bias_act_musa_cache[key] = BiasActMusa
return BiasActMusa
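For orientation, here is a minimal usage sketch of the new MUSA path. It is not part of the diff: it assumes a torch_musa build where tensors live on the `musa` device and that `bias_act` stays importable from `mmcv.ops` exactly as in the CUDA path.

import torch
from mmcv.ops import bias_act

# Hypothetical input: a feature map and a per-channel bias on a MUSA device.
x = torch.randn(4, 16, 8, 8, device='musa')
b = torch.randn(16, device='musa')

# 'lrelu' picks up def_alpha=0.2 and def_gain=sqrt(2) from activation_funcs_musa;
# clamp stays disabled (None is mapped to -1 internally).
y = bias_act(x, b, dim=1, act='lrelu')
assert y.shape == x.shape and y.dtype == x.dtype

With the custom op enabled and `input.is_musa` true, this goes through the cached `BiasActMusa` autograd Function defined above; on builds without torch_musa the `AttributeError` guard makes `bias_act()` fall back to `_bias_act_ref`.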

View File

@@ -65,7 +65,7 @@ class CARAFENaiveFunction(Function):
def backward(
ctx,
grad_output: Tensor) -> Tuple[Tensor, Tensor, None, None, None]:
assert grad_output.is_cuda
assert grad_output.is_cuda or grad_output.is_musa
features, masks = ctx.saved_tensors
kernel_size = ctx.kernel_size

View File

@@ -15,6 +15,7 @@ import warnings
from typing import Dict, Optional, Tuple, Union
import torch
from mmengine.device import is_musa_available
from mmengine.utils import digit_version
from mmengine.utils.dl_utils.parrots_wrapper import is_rocm_pytorch
@@ -95,6 +96,8 @@ def conv_transpose2d(input: torch.Tensor,
def _should_use_custom_op(input):
assert isinstance(input, torch.Tensor)
if enabled and is_musa_available():
return True
if (not enabled) or (not torch.backends.cudnn.enabled):
return False
if input.device.type != 'cuda':
@@ -177,7 +180,8 @@ def _conv2d_gradfix(
ctx.input_shape = input.shape
# Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
if weight_shape[2:] == stride == dilation == (
if (not is_musa_available()
) and weight_shape[2:] == stride == dilation == (
1, 1) and padding == (
0, 0) and torch.cuda.get_device_capability(
input.device) < (8, 0):
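As a rough restatement (a simplified sketch under stated assumptions, not mmcv's actual code): after this change the custom-op gate checks MUSA first and only then falls through to the original cuDNN/CUDA checks, while the 1x1 cuBLAS shortcut is skipped entirely on MUSA builds because `torch.cuda.get_device_capability` has no meaning there.

import torch
from mmengine.device import is_musa_available

def should_use_custom_op(input, enabled=True):
    # Hypothetical restatement of the dispatch order shown above.
    if enabled and is_musa_available():
        return True  # MUSA devices always take the gradfix path when enabled
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False
    return input.device.type == 'cuda'  # remaining CUDA-only checks omitted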

View File

@@ -0,0 +1,192 @@
// Copyright (c) OpenMMLab. All rights reserved
// modified from
// https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/csrc/border_align/border_align_kernel.cu.
// the main difference: (1) use `argmax_idx` for fast computing of gradient
// during the backward. (2) `wh` is directly computed by `boxes`, rather than
// passing it as argument to forward or backward functions.
#ifndef BORDER_ALIGN_MUSA_KERNEL_MUH
#define BORDER_ALIGN_MUSA_KERNEL_MUH
#include <float.h>
#include "pytorch_musa_helper.hpp"
enum BorderMode { Top = 0, Left = 1, Bottom = 2, Right = 3 };
/*** Forward ***/
template <typename T>
__global__ void border_align_forward_musa_kernel(
const int nthreads, const T* input, const T* boxes, T* output,
int* argmax_idx, const int channels, const int box_size, const int height,
const int width, const int pool_size) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (batch_idx, c_idx, box_idx) is an element paralleled for computing
// output, and `extreme_idx` is in range [0,3]
int batch_idx, c_idx, box_idx, extreme_idx, maxidx, *offset_argmax_idx;
const T *offset_box, *offset_input, *offset_box_x;
T *offset_output, box_width, box_height, stride, x_stride, y_stride, x, y,
val, maxval;
extreme_idx = threadIdx.y;
// shape (N, C, box_size, 4) for output
batch_idx = index / channels / box_size;
// shape (N, box_size, 4) for boxes
box_idx = index % box_size + batch_idx * box_size;
c_idx = (index / box_size) % channels;
offset_box = boxes + box_idx * 4;
box_width = *(offset_box + 2) - *offset_box;
box_height = *(offset_box + 3) - *(offset_box + 1);
offset_output = output + index * 4 + extreme_idx;
offset_argmax_idx = argmax_idx + index * 4 + extreme_idx;
// shape (N, 4C, h, w) for input.
// [0,C) for top feature, [C,2C) for left feature,
// [2C,3C) for bottom feature, [3C,4C) for right feature
offset_input =
input + (batch_idx * channels * 4 + extreme_idx * channels + c_idx) *
height * width;
// extreme_idx in [0,1] -> offset_box_x indexed at x1
// extreme_idx in [2,3] -> offset_box_x indexed at x2
offset_box_x = offset_box + extreme_idx / 2 * 2;
// (x1,y1) or (x2,y2) for (x,y)
x = *offset_box_x;
y = *(offset_box_x + 1);
switch (extreme_idx) {
// top
case BorderMode::Top:
stride = box_width / pool_size;
x_stride = stride;
y_stride = 0;
break;
// left
case BorderMode::Left:
stride = box_height / pool_size;
x_stride = 0;
y_stride = stride;
break;
// bottom
case BorderMode::Bottom:
stride = box_width / pool_size;
x_stride = -stride;
y_stride = 0;
break;
// right
case BorderMode::Right:
stride = box_height / pool_size;
x_stride = 0;
y_stride = -stride;
break;
}
// initialize maxval and maxidx with the start position (e.g. (x1,y1) or
// (x2,y2))
maxval = bilinear_interpolate(offset_input, height, width, y, x, index);
maxidx = 0;
// do max_pool along the border
for (int i = 1; i <= pool_size; i++) {
x += x_stride;
y += y_stride;
val = bilinear_interpolate(offset_input, height, width, y, x, index);
if (val > maxval) {
maxval = val;
maxidx = i;
}
}
// update output and argmax_idx
*offset_output = maxval;
*offset_argmax_idx = maxidx;
}
}
/*** Backward ***/
template <typename T>
__global__ void border_align_backward_musa_kernel(
const int nthreads, const T* grad_output, const T* boxes,
const int* argmax_idx, T* grad_input, const int channels,
const int box_size, const int height, const int width,
const int pool_size) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (batch_idx, c_idx, box_idx) is an element paralleled for computing
// output, and `extreme_idx` is in range [0,3]
int batch_idx, c_idx, box_idx, extreme_idx;
const int* offset_argmax_idx;
const T *offset_grad_output, *offset_box, *offset_box_x;
T *offset_grad_input, box_width, box_height, stride, x_stride, y_stride, x,
y;
extreme_idx = threadIdx.y;
batch_idx = index / channels / box_size;
box_idx = index % box_size + batch_idx * box_size;
c_idx = (index / box_size) % channels;
offset_box = boxes + box_idx * 4;
box_width = *(offset_box + 2) - *offset_box;
box_height = *(offset_box + 3) - *(offset_box + 1);
offset_grad_output = grad_output + index * 4 + extreme_idx;
offset_argmax_idx = argmax_idx + index * 4 + extreme_idx;
// [0,C) for top feature grad, [C,2C) for left feature grad,
// [2C,3C) for bottom feature grad, [3C,4C) for right feature grad
offset_grad_input = grad_input + (batch_idx * channels * 4 +
extreme_idx * channels + c_idx) *
height * width;
// extreme_idx in [0,1] -> offset_box_x indexed at x1
// extreme_idx in [2,3] -> offset_box_x indexed at x2
offset_box_x = offset_box + extreme_idx / 2 * 2;
switch (extreme_idx) {
// top
case BorderMode::Top:
stride = box_width / pool_size;
x_stride = stride;
y_stride = 0;
break;
// left
case BorderMode::Left:
stride = box_height / pool_size;
x_stride = 0;
y_stride = stride;
break;
// bottom
case BorderMode::Bottom:
stride = box_width / pool_size;
x_stride = -stride;
y_stride = 0;
break;
// right
case BorderMode::Right:
stride = box_height / pool_size;
x_stride = 0;
y_stride = -stride;
break;
}
// get position (x,y) which has maximum value during forward
x = *offset_box_x;
y = *(offset_box_x + 1);
x += x_stride * (T)(*offset_argmax_idx);
y += y_stride * (T)(*offset_argmax_idx);
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, x_low,
x_high, y_low, y_high, index);
// update grad_output
atomicAdd(offset_grad_input + y_low * width + x_low,
*offset_grad_output * w1);
atomicAdd(offset_grad_input + y_low * width + x_high,
*offset_grad_output * w2);
atomicAdd(offset_grad_input + y_high * width + x_low,
*offset_grad_output * w3);
atomicAdd(offset_grad_input + y_high * width + x_high,
*offset_grad_output * w4);
}
}
#endif // BORDER_ALIGN_MUSA_KERNEL_MUH

View File

@@ -0,0 +1,88 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#ifndef BOX_IOU_QUADRI_MUSA_MUH
#define BOX_IOU_QUADRI_MUSA_MUH
#include "pytorch_musa_helper.hpp"
#include "box_iou_rotated_utils.hpp"
// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;
inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }
template <typename T>
__global__ void box_iou_quadri_musa_kernel(
const int n_boxes1, const int n_boxes2, const T* dev_boxes1,
const T* dev_boxes2, T* dev_ious, const int mode_flag, const bool aligned) {
if (aligned) {
MUSA_1D_KERNEL_LOOP(index, n_boxes1) {
int b1 = index;
int b2 = index;
int base1 = b1 * 8;
float block_boxes1[8];
float block_boxes2[8];
block_boxes1[0] = dev_boxes1[base1 + 0];
block_boxes1[1] = dev_boxes1[base1 + 1];
block_boxes1[2] = dev_boxes1[base1 + 2];
block_boxes1[3] = dev_boxes1[base1 + 3];
block_boxes1[4] = dev_boxes1[base1 + 4];
block_boxes1[5] = dev_boxes1[base1 + 5];
block_boxes1[6] = dev_boxes1[base1 + 6];
block_boxes1[7] = dev_boxes1[base1 + 7];
int base2 = b2 * 8;
block_boxes2[0] = dev_boxes2[base2 + 0];
block_boxes2[1] = dev_boxes2[base2 + 1];
block_boxes2[2] = dev_boxes2[base2 + 2];
block_boxes2[3] = dev_boxes2[base2 + 3];
block_boxes2[4] = dev_boxes2[base2 + 4];
block_boxes2[5] = dev_boxes2[base2 + 5];
block_boxes2[6] = dev_boxes2[base2 + 6];
block_boxes2[7] = dev_boxes2[base2 + 7];
dev_ious[index] =
single_box_iou_quadri<T>(block_boxes1, block_boxes2, mode_flag);
}
} else {
MUSA_1D_KERNEL_LOOP(index, n_boxes1 * n_boxes2) {
int b1 = index / n_boxes2;
int b2 = index % n_boxes2;
int base1 = b1 * 8;
float block_boxes1[8];
float block_boxes2[8];
block_boxes1[0] = dev_boxes1[base1 + 0];
block_boxes1[1] = dev_boxes1[base1 + 1];
block_boxes1[2] = dev_boxes1[base1 + 2];
block_boxes1[3] = dev_boxes1[base1 + 3];
block_boxes1[4] = dev_boxes1[base1 + 4];
block_boxes1[5] = dev_boxes1[base1 + 5];
block_boxes1[6] = dev_boxes1[base1 + 6];
block_boxes1[7] = dev_boxes1[base1 + 7];
int base2 = b2 * 8;
block_boxes2[0] = dev_boxes2[base2 + 0];
block_boxes2[1] = dev_boxes2[base2 + 1];
block_boxes2[2] = dev_boxes2[base2 + 2];
block_boxes2[3] = dev_boxes2[base2 + 3];
block_boxes2[4] = dev_boxes2[base2 + 4];
block_boxes2[5] = dev_boxes2[base2 + 5];
block_boxes2[6] = dev_boxes2[base2 + 6];
block_boxes2[7] = dev_boxes2[base2 + 7];
dev_ious[index] =
single_box_iou_quadri<T>(block_boxes1, block_boxes2, mode_flag);
}
}
}
#endif

View File

@@ -0,0 +1,77 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
#ifndef BOX_IOU_ROTATED_MUSA_MUH
#define BOX_IOU_ROTATED_MUSA_MUH
#include "pytorch_musa_helper.hpp"
#include "box_iou_rotated_utils.hpp"
// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;
inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }
template <typename T>
__global__ void box_iou_rotated_musa_kernel(
const int n_boxes1, const int n_boxes2, const T* dev_boxes1,
const T* dev_boxes2, T* dev_ious, const int mode_flag, const bool aligned) {
if (aligned) {
MUSA_1D_KERNEL_LOOP(index, n_boxes1) {
int b1 = index;
int b2 = index;
int base1 = b1 * 5;
float block_boxes1[5];
float block_boxes2[5];
block_boxes1[0] = dev_boxes1[base1 + 0];
block_boxes1[1] = dev_boxes1[base1 + 1];
block_boxes1[2] = dev_boxes1[base1 + 2];
block_boxes1[3] = dev_boxes1[base1 + 3];
block_boxes1[4] = dev_boxes1[base1 + 4];
int base2 = b2 * 5;
block_boxes2[0] = dev_boxes2[base2 + 0];
block_boxes2[1] = dev_boxes2[base2 + 1];
block_boxes2[2] = dev_boxes2[base2 + 2];
block_boxes2[3] = dev_boxes2[base2 + 3];
block_boxes2[4] = dev_boxes2[base2 + 4];
dev_ious[index] =
single_box_iou_rotated<T>(block_boxes1, block_boxes2, mode_flag);
}
} else {
MUSA_1D_KERNEL_LOOP(index, n_boxes1 * n_boxes2) {
int b1 = index / n_boxes2;
int b2 = index % n_boxes2;
int base1 = b1 * 5;
float block_boxes1[5];
float block_boxes2[5];
block_boxes1[0] = dev_boxes1[base1 + 0];
block_boxes1[1] = dev_boxes1[base1 + 1];
block_boxes1[2] = dev_boxes1[base1 + 2];
block_boxes1[3] = dev_boxes1[base1 + 3];
block_boxes1[4] = dev_boxes1[base1 + 4];
int base2 = b2 * 5;
block_boxes2[0] = dev_boxes2[base2 + 0];
block_boxes2[1] = dev_boxes2[base2 + 1];
block_boxes2[2] = dev_boxes2[base2 + 2];
block_boxes2[3] = dev_boxes2[base2 + 3];
block_boxes2[4] = dev_boxes2[base2 + 4];
dev_ious[index] =
single_box_iou_rotated<T>(block_boxes1, block_boxes2, mode_flag);
}
}
}
#endif

View File

@@ -0,0 +1,332 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef CARAFE_MUSA_KERNEL_MUH
#define CARAFE_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
#ifdef MMCV_WITH_HIP
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#define THREADS_PER_PIXEL 32
#define MAX_SHARED_MEMORY 49152
#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144
#define MAXIMIZE_KERNEL_SIZE true
#define kTileDim 32
#define kBlockRows 8
#define FULL_MASK 0xffffffff
inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }
__device__ inline int Loc2Index(const int n, const int c, const int h,
const int w, const int channel_num,
const int height, const int width) {
int index = w + (h + (c + n * channel_num) * height) * width;
return index;
}
#ifndef MMCV_WITH_HIP
/* TODO: move this to a common place */
template <typename scalar_t>
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
return a < b ? a : b;
}
template <typename scalar_t>
__device__ inline scalar_t max(scalar_t a, scalar_t b) {
return a > b ? a : b;
}
#endif
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef MMCV_WITH_HIP
val += __shfl_down(val, offset);
#else
val += __shfl_down_sync(FULL_MASK, val, offset);
#endif
return val;
}
template <>
__device__ __forceinline__ phalf warpReduceSum(phalf val) {
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef MMCV_WITH_HIP
// Using PyTorch's macro for half support
__PHALF(val) += WARP_SHFL_DOWN(val, offset);
#else
__PHALF(val) +=
__shfl_down_sync(FULL_MASK, __PHALF(val).operator __half(), offset);
#endif
return val;
}
// Splits the original matrix into submatrices with size 32 * 32.
// Each block transposes one submatrix by loading it into shared memory.
// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/
template <typename scalar_t>
__global__ void BatchTranspose2DMUSAKernel(const int N, const int H,
const int W, const int dh,
const int dw,
const scalar_t *__restrict__ X,
scalar_t *__restrict__ Y) {
__shared__ scalar_t tile[kTileDim][kTileDim + 1];
const int n = blockIdx.x / (dh * dw);
const int k = blockIdx.x % (dh * dw);
const int r = k / dw;
const int c = k % dw;
const int offset = n * H * W;
int x = c * kTileDim + threadIdx.x;
int y = r * kTileDim + threadIdx.y;
if (x < W) {
for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) {
tile[threadIdx.y + i][threadIdx.x] = X[offset + (y + i) * W + x];
}
}
__syncthreads();
x = r * kTileDim + threadIdx.x;
y = c * kTileDim + threadIdx.y;
if (x < H) {
for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) {
Y[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i];
}
}
}
template <typename scalar_t>
__global__ void CARAFEForward(
const int num_kernels, const scalar_t *__restrict__ bottom_data,
const scalar_t *__restrict__ bottom_masks, const int kernel_size,
const int group_size, const int scale_factor, const int channels,
const int down_height, const int down_width, const int height,
const int width, const int mask_channels, scalar_t *__restrict__ top_data) {
#if MAXIMIZE_KERNEL_SIZE
__shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2];
#else
__shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T];
#endif
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index > num_kernels - 1) {
return;
}
const int pixel_id = threadIdx.x / THREADS_PER_PIXEL;
const int split_id = threadIdx.x % THREADS_PER_PIXEL;
index = index / THREADS_PER_PIXEL;
const int pw = index % width;
const int ph = (index / width) % height;
const int n = index / width / height;
const int down_pw = pw / scale_factor;
const int down_ph = ph / scale_factor;
const int start_w = down_pw - (kernel_size - 1) / 2;
const int end_w = down_pw + (kernel_size - 1) / 2 + 1;
const int start_h = down_ph - (kernel_size - 1) / 2;
const int end_h = down_ph + (kernel_size - 1) / 2 + 1;
for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) {
int mask_index = Loc2Index(n, ph, pw, c, height, width, mask_channels);
shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index];
}
__syncthreads();
const int channels_per_group = ceilf(channels / (float)group_size);
#pragma unroll
for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
int mask_group = c / channels_per_group;
scalar_t output_val = 0;
#pragma unroll
for (int iy = start_h; iy < end_h; iy++) {
#pragma unroll
for (int ix = start_w; ix < end_w; ix++) {
if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) {
continue;
}
int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
int mask_c =
(mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
int feat_index =
Loc2Index(n, iy, ix, c, down_height, down_width, channels);
output_val += bottom_data[feat_index] *
shared_mask[mask_c * WARP_SIZE + pixel_id];
}
}
int top_index = Loc2Index(n, ph, pw, c, height, width, channels);
top_data[top_index] = output_val;
}
}
template <typename scalar_t>
__global__ void CARAFEBackward_Feature(
const int num_kernels, const scalar_t *__restrict__ top_diff,
const scalar_t *__restrict__ bottom_masks, const int kernel_size,
const int group_size, const int scale_factor, const int channels,
const int down_height, const int down_width, const int height,
const int width, const int mask_channels,
scalar_t *__restrict__ bottom_diff) {
#if MAXIMIZE_KERNEL_SIZE
__shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2];
#else
__shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T];
#endif
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index > num_kernels - 1) {
return;
}
const int pixel_id = threadIdx.x / THREADS_PER_PIXEL;
const int split_id = threadIdx.x % THREADS_PER_PIXEL;
// (n, c, ph, pw) is an element in the bottom_data
index = index / THREADS_PER_PIXEL;
const int pw = index % width;
const int ph = (index / width) % height;
const int n = index / width / height;
const int start_w = pw - (kernel_size - 1) * scale_factor / 2;
const int end_w = pw + (kernel_size - 1) * scale_factor / 2 + 1;
const int start_h = ph - (kernel_size - 1) * scale_factor / 2;
const int end_h = ph + (kernel_size - 1) * scale_factor / 2 + 1;
for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) {
const int mask_w = (c % kernel_size) * scale_factor;
const int mask_h = (c / kernel_size % kernel_size) * scale_factor;
const int mask_x = start_w + mask_w;
const int mask_y = start_h + mask_h;
if (mask_y < 0 || mask_y > height - 1 || mask_x < 0 || mask_x > width - 1) {
shared_mask[c * WARP_SIZE + pixel_id] = 0;
continue;
}
const int mask_group = c / (kernel_size * kernel_size);
const int mask_c = (2 * mask_group + 1) * kernel_size * kernel_size - c - 1;
int mask_index =
Loc2Index(n, mask_c, mask_y, mask_x, mask_channels, height, width);
shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index];
}
__syncthreads();
const int channels_per_group = ceilf(channels / (float)group_size);
#pragma unroll
for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
int mask_group = c / channels_per_group;
int top_index = Loc2Index(n, ph, pw, c, height, width, channels);
scalar_t output_val = 0;
#pragma unroll
for (int iy = start_h; iy < end_h; iy += scale_factor) {
#pragma unroll
for (int ix = start_w; ix < end_w; ix += scale_factor) {
if (iy < 0 || iy > height - 1 || ix < 0 || ix > width - 1) {
continue;
}
int mask_iy =
(iy - ph + (kernel_size - 1) * scale_factor / 2) / scale_factor;
int mask_ix =
(ix - pw + (kernel_size - 1) * scale_factor / 2) / scale_factor;
int mask_c =
(mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
int feat_index = Loc2Index(n, iy, ix, c, height, width, channels);
output_val +=
shared_mask[mask_c * WARP_SIZE + pixel_id] * top_diff[feat_index];
}
}
bottom_diff[top_index] = output_val;
}
}
template <typename scalar_t>
__global__ void FeatureSum(const int num_kernels,
const scalar_t *__restrict__ input_data,
const int scale_factor, const int channels,
const int height, const int width,
scalar_t *__restrict__ output_data) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index > num_kernels - 1) {
return;
}
const int split_id = threadIdx.x % THREADS_PER_PIXEL;
index = index / THREADS_PER_PIXEL;
const int pw = index % width;
const int ph = (index / width) % height;
const int n = index / width / height;
for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
scalar_t output_val = 0;
for (int iy = ph * scale_factor; iy < (ph + 1) * scale_factor; iy++) {
for (int ix = pw * scale_factor; ix < (pw + 1) * scale_factor; ix++) {
int input_id = Loc2Index(n, iy, ix, c, height * scale_factor,
width * scale_factor, channels);
output_val += input_data[input_id];
}
}
const int output_id = Loc2Index(n, ph, pw, c, height, width, channels);
output_data[output_id] = output_val;
}
}
template <typename scalar_t>
__global__ void CARAFEBackward_Mask(const int num_kernels,
const scalar_t *__restrict__ top_diff,
const scalar_t *__restrict__ bottom_data,
const int kernel_size, const int group_size,
const int scale_factor, const int channels,
const int down_height, const int down_width,
const int height, const int width,
const int mask_channels,
scalar_t *__restrict__ mask_diff) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index > num_kernels - 1) {
return;
}
const int lane_id = index % WARP_SIZE;
index = index / WARP_SIZE;
const int mask_c = index % mask_channels;
// (n, c, ph, pw) is an element in the bottom_data
index = index / mask_channels;
const int pw = index % width;
const int ph = (index / width) % height;
const int n = index / width / height;
const int down_pw = pw / scale_factor;
const int down_ph = ph / scale_factor;
const int mask_group = mask_c / (kernel_size * kernel_size);
const int mask_loc = mask_c % (kernel_size * kernel_size);
const int offset_x = mask_loc % kernel_size - (kernel_size - 1) / 2;
const int offset_y =
mask_loc / kernel_size % kernel_size - (kernel_size - 1) / 2;
const int down_x = down_pw + offset_x;
const int down_y = down_ph + offset_y;
scalar_t output_val = 0;
if (down_y >= 0 && down_y <= down_height - 1 && down_x >= 0 &&
down_x <= down_width - 1) {
const int channels_per_mask = ceilf(channels / (float)group_size);
const int start = channels_per_mask * mask_group;
const int end = min(channels_per_mask * (mask_group + 1), channels);
for (int c = start + lane_id; c < end; c += WARP_SIZE) {
int bottom_id =
Loc2Index(n, down_y, down_x, c, down_height, down_width, channels);
int top_id = Loc2Index(n, ph, pw, c, height, width, channels);
output_val += top_diff[top_id] * bottom_data[bottom_id];
}
}
#ifdef MMCV_WITH_HIP
__syncthreads();
#else
__syncwarp();
#endif
output_val = warpReduceSum(output_val);
if (lane_id == 0) {
const int mask_id =
Loc2Index(n, ph, pw, mask_c, height, width, mask_channels);
mask_diff[mask_id] = output_val;
}
}
#endif // CARAFE_MUSA_KERNEL_MUH

View File

@@ -0,0 +1,107 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef CARAFE_NAIVE_MUSA_KERNEL_MUH
#define CARAFE_NAIVE_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
__device__ inline int Loc2Index(const int n, const int c, const int h,
const int w, const int channel_num,
const int height, const int width) {
int index = w + (h + (c + n * channel_num) * height) * width;
return index;
}
template <typename scalar_t>
__global__ void carafe_naive_forward_musa_kernel(
const int nthreads, const scalar_t *bottom_data,
const scalar_t *bottom_masks, scalar_t *top_data, const int kernel_size,
const int group_size, const int scale_factor, const int channels,
const int height, const int width) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the bottom_data
int pw = index % width;
int ph = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int mask_channels = kernel_size * kernel_size * group_size;
int mask_group = c / (channels / group_size);
int down_pw = pw / scale_factor;
int down_ph = ph / scale_factor;
int down_width = width / scale_factor;
int down_height = height / scale_factor;
int start_w = down_pw - (kernel_size - 1) / 2;
int end_w = down_pw + (kernel_size - 1) / 2 + 1;
int start_h = down_ph - (kernel_size - 1) / 2;
int end_h = down_ph + (kernel_size - 1) / 2 + 1;
scalar_t output_val = 0;
for (int iy = start_h; iy < end_h; iy++) {
for (int ix = start_w; ix < end_w; ix++) {
if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) {
continue;
}
int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
int mask_c =
(mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
int feat_index =
Loc2Index(n, c, iy, ix, channels, down_height, down_width);
int mask_index =
Loc2Index(n, mask_c, ph, pw, mask_channels, height, width);
output_val += bottom_data[feat_index] * bottom_masks[mask_index];
}
}
top_data[index] = output_val;
}
}
template <typename scalar_t>
__global__ void carafe_naive_backward_musa_kernel(
const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_data,
const scalar_t *bottom_masks, scalar_t *bottom_diff, scalar_t *mask_diff,
const int kernel_size, const int group_size, const int scale_factor,
const int channels, const int height, const int width) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the bottom_data
int pw = index % width;
int ph = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int mask_channels = kernel_size * kernel_size * group_size;
int mask_group = c / (channels / group_size);
int down_pw = pw / scale_factor;
int down_ph = ph / scale_factor;
int down_width = width / scale_factor;
int down_height = height / scale_factor;
int start_w = down_pw - (kernel_size - 1) / 2;
int end_w = down_pw + (kernel_size - 1) / 2 + 1;
int start_h = down_ph - (kernel_size - 1) / 2;
int end_h = down_ph + (kernel_size - 1) / 2 + 1;
for (int iy = start_h; iy < end_h; iy++) {
for (int ix = start_w; ix < end_w; ix++) {
if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) {
continue;
}
int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
int mask_c =
(mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
int feat_index =
Loc2Index(n, c, iy, ix, channels, down_height, down_width);
int mask_index =
Loc2Index(n, mask_c, ph, pw, mask_channels, height, width);
atomicAdd(bottom_diff + feat_index,
bottom_masks[mask_index] * top_diff[index]);
atomicAdd(mask_diff + mask_index,
bottom_data[feat_index] * top_diff[index]);
}
}
}
}
#endif // CARAFE_NAIVE_MUSA_KERNEL_MUH

View File

@@ -0,0 +1,101 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cu
#ifndef CHAMFER_DISTANCE_MUSA_KERNEL_MUH
#define CHAMFER_DISTANCE_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144
#if MUSA_ARCH > 21
template <typename scalar_t>
__global__ void chamfer_distance_forward_musa_kernel(int b, int n,
const scalar_t* xyz, int m,
const scalar_t* xyz2,
scalar_t* result,
int* result_i) {
__shared__ scalar_t buf[MAX_SHARED_SCALAR_T];
for (int i = blockIdx.x; i < b; i += gridDim.x) {
for (int k2 = 0; k2 < m; k2 += THREADS_PER_BLOCK) {
int end_k = min(m, k2 + THREADS_PER_BLOCK) - k2;
for (int j = threadIdx.x; j < end_k * 2; j += blockDim.x) {
buf[j] = xyz2[(i * m + k2) * 2 + j];
}
__syncthreads();
for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) {
scalar_t x1 = xyz[(i * n + j) * 2 + 0];
scalar_t y1 = xyz[(i * n + j) * 2 + 1];
int best_i = 0;
scalar_t best = 1e10;
int end_ka = end_k & (~3);
if (end_ka == THREADS_PER_BLOCK) {
for (int k = 0; k < THREADS_PER_BLOCK; k += 4) {
#pragma unroll
for (int j = 0; j < 4; ++j) {
scalar_t x2 = buf[(k + j) * 2] - x1;
scalar_t y2 = buf[(k + j) * 2 + 1] - y1;
scalar_t d = x2 * x2 + y2 * y2;
if (d < best) {
best = d;
best_i = k + k2 + j;
}
}
}
} else {
for (int k = 0; k < end_ka; k += 4) {
#pragma unroll
for (int j = 0; j < 4; ++j) {
scalar_t x2 = buf[(k + j) * 2] - x1;
scalar_t y2 = buf[(k + j) * 2 + 1] - y1;
scalar_t d = x2 * x2 + y2 * y2;
if (d < best) {
best = d;
best_i = k + k2 + j;
}
}
}
}
for (int k = end_ka; k < end_k; k++) {
scalar_t x2 = buf[k * 2 + 0] - x1;
scalar_t y2 = buf[k * 2 + 1] - y1;
scalar_t d = x2 * x2 + y2 * y2;
if (k == 0 || d < best) {
best = d;
best_i = k + k2;
}
}
if (k2 == 0 || result[(i * n + j)] > best) {
result[(i * n + j)] = best;
result_i[(i * n + j)] = best_i;
}
}
__syncthreads();
}
}
}
template <typename scalar_t>
__global__ void chamfer_distance_backward_musa_kernel(
int b, int n, const scalar_t* xyz1, int m, const scalar_t* xyz2,
const scalar_t* grad_dist1, const int* idx1, scalar_t* grad_xyz1,
scalar_t* grad_xyz2) {
for (int i = blockIdx.x; i < b; i += gridDim.x) {
for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) {
scalar_t x1 = xyz1[(i * n + j) * 2 + 0];
scalar_t y1 = xyz1[(i * n + j) * 2 + 1];
int j2 = idx1[i * n + j];
scalar_t x2 = xyz2[(i * m + j2) * 2 + 0];
scalar_t y2 = xyz2[(i * m + j2) * 2 + 1];
scalar_t g = grad_dist1[i * n + j] * 2;
atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 0]), g * (x1 - x2));
atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 1]), g * (y1 - y2));
atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 0]), -(g * (x1 - x2)));
atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 1]), -(g * (y1 - y2)));
}
}
}
#else
#warning "chamfer_distance is only supported when MUSA_ARCH > 21"
#endif //MUSA_ARCH
#endif // CHAMFER_DISTANCE_MUSA_KERNEL_MUH

View File

@@ -0,0 +1,827 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef CONVEX_IOU_MUSA_KERNEL_MUH
#define CONVEX_IOU_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
#define MAXN 100
#define NMAX 512
__device__ const double EPS = 1E-8;
__device__ inline int sig(double d) { return (d > EPS) - (d < -EPS); }
struct Point {
double x, y;
__device__ Point() {}
__device__ Point(double x, double y) : x(x), y(y) {}
};
__device__ inline bool point_same(Point& a, Point& b) {
return sig(a.x - b.x) == 0 && sig(a.y - b.y) == 0;
}
__device__ inline void swap1(Point* a, Point* b) {
Point temp;
temp.x = a->x;
temp.y = a->y;
a->x = b->x;
a->y = b->y;
b->x = temp.x;
b->y = temp.y;
}
__device__ inline void reverse1(Point* a, const int n) {
for (int i = 0; i < (n - 1) / 2.0; i++) {
Point* j = &(a[i]);
Point* k = &(a[n - 1 - i]);
swap1(j, k);
}
}
__device__ inline double cross(Point o, Point a, Point b) {
return (a.x - o.x) * (b.y - o.y) - (b.x - o.x) * (a.y - o.y);
}
__device__ inline double dis(Point a, Point b) {
return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
}
__device__ inline double area(Point* ps, int n) {
ps[n] = ps[0];
double res = 0;
for (int i = 0; i < n; i++) {
res += ps[i].x * ps[i + 1].y - ps[i].y * ps[i + 1].x;
}
return res / 2.0;
}
__device__ inline double polygon_area_grad(Point* ps, int n,
int* polygon_to_pred_index,
int n_pred, double* grad_C) {
ps[n] = ps[0];
double partion_grad[4 * 30 + 2];
double res = 0;
for (int i = 0; i < n; i++) {
res += ps[i].x * ps[i + 1].y - ps[i].y * ps[i + 1].x;
partion_grad[i * 4 + 2] = ps[i + 1].y;
partion_grad[i * 4 + 3] = -ps[i + 1].x;
if (i != n - 1) {
partion_grad[i * 4 + 4] = -ps[i].y;
partion_grad[i * 4 + 5] = ps[i].x;
} else {
partion_grad[0] = -ps[i].y;
partion_grad[1] = ps[i].x;
}
}
for (int i = 0; i < n; i++) {
for (int j = 0; j < n_pred; j++) {
if (i == polygon_to_pred_index[j]) {
grad_C[2 * polygon_to_pred_index[j + n_pred]] =
(partion_grad[i * 4] + partion_grad[i * 4 + 2]) / 2;
break;
}
}
for (int j = 0; j < n_pred; j++) {
if (i == polygon_to_pred_index[j]) {
grad_C[2 * polygon_to_pred_index[j + n_pred] + 1] =
(partion_grad[i * 4 + 1] + partion_grad[i * 4 + 1 + 2]) / 2;
break;
}
}
}
return res / 2.0;
}
__device__ inline int lineCross(Point a, Point b, Point c, Point d, Point& p,
double* cut_grad, int m, int n, int i) {
double s1, s2;
double s2_s1_2;
double ds1_dxc, ds1_dyc, ds2_dxd, ds2_dyd;
double dxp_dxc, dxp_dyc, dxp_dxd, dxp_dyd, dyp_dxc, dyp_dyc, dyp_dxd, dyp_dyd;
s1 = cross(a, b, c);
s2 = cross(a, b, d);
ds1_dxc = -(b.y - a.y);
ds1_dyc = b.x - a.x;
ds2_dxd = ds1_dxc;
ds2_dyd = ds1_dyc;
s2_s1_2 = (s2 - s1) * (s2 - s1);
if (sig(s1) == 0 && sig(s2) == 0) return 2;
if (sig(s2 - s1) == 0) return 0;
dxp_dxc =
((s2 - d.x * ds1_dxc) * (s2 - s1) - (c.x * s2 - d.x * s1) * (-ds1_dxc)) /
(s2_s1_2);
dxp_dyc =
((0 - d.x * ds1_dyc) * (s2 - s1) - (c.x * s2 - d.x * s1) * (-ds1_dyc)) /
(s2_s1_2);
dxp_dxd =
((c.x * ds2_dxd - s1) * (s2 - s1) - (c.x * s2 - d.x * s1) * (ds2_dxd)) /
(s2_s1_2);
dxp_dyd =
((c.x * ds2_dyd - 0) * (s2 - s1) - (c.x * s2 - d.x * s1) * (ds2_dyd)) /
(s2_s1_2);
dyp_dxc =
((0 - d.y * ds1_dxc) * (s2 - s1) - (c.y * s2 - d.y * s1) * (-ds1_dxc)) /
(s2_s1_2);
dyp_dyc =
((s2 - d.y * ds1_dyc) * (s2 - s1) - (c.y * s2 - d.y * s1) * (-ds1_dyc)) /
(s2_s1_2);
dyp_dxd =
((c.y * ds2_dxd - 0) * (s2 - s1) - (c.y * s2 - d.y * s1) * (ds2_dxd)) /
(s2_s1_2);
dyp_dyd =
((c.y * ds2_dyd - s1) * (s2 - s1) - (c.y * s2 - d.y * s1) * (ds2_dyd)) /
(s2_s1_2);
p.x = (c.x * s2 - d.x * s1) / (s2 - s1);
p.y = (c.y * s2 - d.y * s1) / (s2 - s1);
if (i == n - 1) {
cut_grad[4 * n * m + 4 * i] = dxp_dxc; // + dyp_dxc;
cut_grad[4 * n * m + 4 * i + 1] = dyp_dxc;
cut_grad[4 * n * m + 4 * i + 2] = dxp_dyc; // + dyp_dyc;
cut_grad[4 * n * m + 4 * i + 3] = dyp_dyc;
cut_grad[4 * n * m + 0] = dxp_dxd; // + dyp_dxd;
cut_grad[4 * n * m + 1] = dyp_dxd;
cut_grad[4 * n * m + 2] = dxp_dyd; // + dyp_dyd;
cut_grad[4 * n * m + 3] = dyp_dyd;
} else {
cut_grad[4 * n * m + 4 * i] = dxp_dxc; // + dyp_dxc;
cut_grad[4 * n * m + 4 * i + 1] = dyp_dxc;
cut_grad[4 * n * m + 4 * i + 2] = dxp_dyc; // + dyp_dyc;
cut_grad[4 * n * m + 4 * i + 3] = dyp_dyc;
cut_grad[4 * n * m + 4 * (i + 1)] = dxp_dxd; // + dyp_dxd;
cut_grad[4 * n * m + 4 * (i + 1) + 1] = dyp_dxd;
cut_grad[4 * n * m + 4 * (i + 1) + 2] = dxp_dyd; // + dyp_dyd;
cut_grad[4 * n * m + 4 * (i + 1) + 3] = dyp_dyd;
}
return 1;
}
__device__ inline void polygon_cut(Point* p, int& n, Point a, Point b,
double* cut_grad) {
Point pp[MAXN];
double ccur_grad[MAXN] = {};
int m = 0;
p[n] = p[0];
int k = n;
for (int i = 0; i < n; i++) {
if (sig(cross(a, b, p[i])) > 0) {
pp[m] = p[i];
ccur_grad[4 * n * m + 4 * i] = 1.0;
ccur_grad[4 * n * m + 4 * i + 3] = 1.0;
m++;
}
if (sig(cross(a, b, p[i])) != sig(cross(a, b, p[i + 1]))) {
lineCross(a, b, p[i], p[i + 1], pp[m], ccur_grad, m, n, i);
m++;
}
}
n = 0;
for (int i = 0; i < m; i++) {
if (!i || !(point_same(pp[i], pp[i - 1]))) {
p[n] = pp[i];
for (int j = 0; j < 4 * k; j++) {
cut_grad[4 * k * n + j] = ccur_grad[4 * k * i + j];
}
n++;
}
}
while (n > 1 && point_same(p[n - 1], p[0])) n--;
}
__device__ inline double intersectArea(Point a, Point b, Point c, Point d,
double* grad_AB, int order,
int convex_n) {
Point o(0, 0);
int res_flag = 0;
int s1 = sig(cross(o, a, b));
int s2 = sig(cross(o, c, d));
if (s1 == 0 || s2 == 0) return 0.0;
if (s1 == -1) {
Point* i = &a;
Point* j = &b;
swap1(i, j);
res_flag = 1;
}
if (s2 == -1) {
Point* i = &c;
Point* j = &d;
swap1(i, j);
}
Point p[10] = {o, a, b};
int n = 3, n0 = 3, n1, n2, n3;
double cut_grad1[MAXN] = {};
double cut_grad2[MAXN] = {};
double cut_grad3[MAXN] = {};
double p1_p_grad[10][10] = {};
double p2_p1_grad[10][10] = {};
double p3_p2_grad[10][10] = {};
double p3_p1_grad[10][10] = {};
double p3_p_grad[10][10] = {};
// 1
polygon_cut(p, n, o, c, cut_grad1);
n1 = n;
for (int i = 0; i < n; i++) {
for (int j = 0; j < 4 * n0; j++) {
if (!(j % 2)) {
p1_p_grad[2 * i][j / 2] = cut_grad1[4 * n0 * i + j];
} else {
p1_p_grad[2 * i + 1][j / 2] = cut_grad1[4 * n0 * i + j];
}
}
}
// 2
polygon_cut(p, n, c, d, cut_grad2);
n2 = n;
for (int i = 0; i < n; i++) {
for (int j = 0; j < 4 * n1; j++) {
if (!(j % 2)) {
p2_p1_grad[2 * i][j / 2] = cut_grad2[4 * n1 * i + j];
} else {
p2_p1_grad[2 * i + 1][j / 2] = cut_grad2[4 * n1 * i + j];
}
}
}
// 3
polygon_cut(p, n, d, o, cut_grad3);
n3 = n;
for (int i = 0; i < n; i++) {
for (int j = 0; j < 4 * n2; j++) {
if (!(j % 2)) {
p3_p2_grad[2 * i][j / 2] = cut_grad3[4 * n2 * i + j];
} else {
p3_p2_grad[2 * i + 1][j / 2] = cut_grad3[4 * n2 * i + j];
}
}
}
// mul
// p3_p2(n3 * n2) * p2_p1(n2 * n1) = p3_p1 (n3 * n1)
for (int i = 0; i < 2 * n3; i++) {
for (int j = 0; j < 2 * n1; j++) {
double sum = 0.0;
for (int m = 0; m < 2 * n2; m++) {
sum = sum + p3_p2_grad[i][m] * p2_p1_grad[m][j];
}
p3_p1_grad[i][j] = sum;
}
}
// p3_p1 (n3 * n1) * p1_p (n1 * n0) = p3_p (n3 * n0)
for (int i = 0; i < 2 * n3; i++) {
for (int j = 0; j < 2 * n0; j++) {
double sum = 0.0;
for (int m = 0; m < 2 * n1; m++) {
sum = sum + p3_p1_grad[i][m] * p1_p_grad[m][j];
}
p3_p_grad[i][j] = sum;
}
}
// calculate S_grad
int polygon_index_box_index[20];
double grad_polygon[20];
double S_grad[6];
for (int i = 0; i < n3; i++) {
polygon_index_box_index[i] = i;
polygon_index_box_index[i + n3] = i;
}
double res =
polygon_area_grad(p, n3, polygon_index_box_index, n3, grad_polygon);
if (s1 * s2 == -1) {
for (int j = 0; j < 2 * 3; j++) {
double sum = 0.0;
for (int m = 0; m < 2 * n3; m++) {
sum = sum - grad_polygon[m] * p3_p_grad[m][j];
}
S_grad[j] = sum;
}
if (order != convex_n - 1) {
if (res_flag) {
grad_AB[2 * order] += S_grad[4];
grad_AB[2 * order + 1] += S_grad[5];
grad_AB[2 * order + 2] += S_grad[2];
grad_AB[2 * order + 3] += S_grad[3];
} else {
grad_AB[2 * order] += S_grad[2];
grad_AB[2 * order + 1] += S_grad[3];
grad_AB[2 * order + 2] += S_grad[4];
grad_AB[2 * order + 3] += S_grad[5];
}
} else {
if (res_flag) {
grad_AB[2 * order] += S_grad[4];
grad_AB[2 * order + 1] += S_grad[5];
grad_AB[0] += S_grad[2];
grad_AB[1] += S_grad[3];
} else {
grad_AB[2 * order] += S_grad[2];
grad_AB[2 * order + 1] += S_grad[3];
grad_AB[0] += S_grad[4];
grad_AB[1] += S_grad[5];
}
}
res = -res;
} else {
for (int j = 0; j < 2 * 3; j++) {
double sum = 0.0;
for (int m = 0; m < 2 * n3; m++) {
sum = sum + grad_polygon[m] * p3_p_grad[m][j];
}
S_grad[j] = sum;
}
if (order != convex_n - 1) {
if (res_flag) {
grad_AB[2 * order] += S_grad[4];
grad_AB[2 * order + 1] += S_grad[5];
grad_AB[2 * order + 2] += S_grad[2];
grad_AB[2 * order + 3] += S_grad[3];
} else {
grad_AB[2 * order] += S_grad[2];
grad_AB[2 * order + 1] += S_grad[3];
grad_AB[2 * order + 2] += S_grad[4];
grad_AB[2 * order + 3] += S_grad[5];
}
} else {
if (res_flag) {
grad_AB[2 * order] += S_grad[4];
grad_AB[2 * order + 1] += S_grad[5];
grad_AB[0] += S_grad[2];
grad_AB[1] += S_grad[3];
} else {
grad_AB[2 * order] += S_grad[2];
grad_AB[2 * order + 1] += S_grad[3];
grad_AB[0] += S_grad[4];
grad_AB[1] += S_grad[5];
}
}
}
return res;
}
__device__ inline double intersectAreaO(Point* ps1, int n1, Point* ps2, int n2,
double* grad_AB) {
if (area(ps1, n1) < 0) reverse1(ps1, n1);
if (area(ps2, n2) < 0) reverse1(ps2, n2);
ps1[n1] = ps1[0];
ps2[n2] = ps2[0];
double res = 0;
for (int i = 0; i < n1; i++) {
for (int j = 0; j < n2; j++) {
res +=
intersectArea(ps1[i], ps1[i + 1], ps2[j], ps2[j + 1], grad_AB, i, n1);
}
}
return res;
}
__device__ inline void Jarvis(Point* in_poly, int& n_poly) {
Point p_max, p_k;
int max_index, k_index;
int Stack[NMAX] = {}, top1, top2;
double sign;
Point right_point[10], left_point[10];
for (int i = 0; i < n_poly; i++) {
if (in_poly[i].y < in_poly[0].y ||
in_poly[i].y == in_poly[0].y && in_poly[i].x < in_poly[0].x) {
Point* j = &(in_poly[0]);
Point* k = &(in_poly[i]);
swap1(j, k);
}
if (i == 0) {
p_max = in_poly[0];
max_index = 0;
}
if (in_poly[i].y > p_max.y ||
in_poly[i].y == p_max.y && in_poly[i].x > p_max.x) {
p_max = in_poly[i];
max_index = i;
}
}
if (max_index == 0) {
max_index = 1;
p_max = in_poly[max_index];
}
k_index = 0, Stack[0] = 0, top1 = 0;
while (k_index != max_index) {
p_k = p_max;
k_index = max_index;
for (int i = 1; i < n_poly; i++) {
sign = cross(in_poly[Stack[top1]], in_poly[i], p_k);
if ((sign > 0) || ((sign == 0) && (dis(in_poly[Stack[top1]], in_poly[i]) >
dis(in_poly[Stack[top1]], p_k)))) {
p_k = in_poly[i];
k_index = i;
}
}
top1++;
Stack[top1] = k_index;
}
for (int i = 0; i <= top1; i++) right_point[i] = in_poly[Stack[i]];
k_index = 0, Stack[0] = 0, top2 = 0;
while (k_index != max_index) {
p_k = p_max;
k_index = max_index;
for (int i = 1; i < n_poly; i++) {
sign = cross(in_poly[Stack[top2]], in_poly[i], p_k);
if ((sign < 0) || (sign == 0) && (dis(in_poly[Stack[top2]], in_poly[i]) >
dis(in_poly[Stack[top2]], p_k))) {
p_k = in_poly[i];
k_index = i;
}
}
top2++;
Stack[top2] = k_index;
}
for (int i = top2 - 1; i >= 0; i--) left_point[i] = in_poly[Stack[i]];
for (int i = 0; i < top1 + top2; i++) {
if (i <= top1) {
in_poly[i] = right_point[i];
} else {
in_poly[i] = left_point[top2 - (i - top1)];
}
}
n_poly = top1 + top2;
}
__device__ inline double intersectAreaPoly(Point* ps1, int n1, Point* ps2,
int n2, double* grad_C) {
Point polygon[MAXN];
int n = n1 + n2, n_poly = 0;
for (int i = 0; i < n1; i++) {
for (int j = 0; j < n - n1; j++) {
if (point_same(ps1[i], ps2[j])) {
for (int k = j; k < n - n1 - 1; k++) {
ps2[k] = ps2[k + 1];
}
n2--;
break;
}
}
}
n_poly = n1 + n2;
for (int i = 0; i < n_poly; i++) {
if (i < n1) {
polygon[i] = ps1[i];
} else {
polygon[i] = ps2[i - n1];
}
}
Jarvis(polygon, n_poly);
int polygon_to_pred_index[18] = {-1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1};
int n_pred = 0;
for (int i = 0; i < n_poly; i++) {
for (int j = 0; j < n1; j++) {
if (polygon[i].x == ps1[j].x && polygon[i].y == ps1[j].y) {
polygon_to_pred_index[n_pred] = i;
polygon_to_pred_index[n_pred + n1] = j;
n_pred += 1;
break;
}
}
}
if (n_pred == 0) {
double polygon_area = fabs(area(polygon, n_poly));
for (int i = 0; i < 18; i++) {
grad_C[i] = 0.0;
}
return polygon_area;
} else {
double polygon_area =
polygon_area_grad(polygon, n_poly, polygon_to_pred_index, n1, grad_C);
if (polygon_area < 0) {
for (int i = 0; i < 18; i++) {
grad_C[i] = -grad_C[i];
}
}
return fabs(polygon_area);
}
}
// convex_find and get the polygon_index_box_index
__device__ inline void Jarvis_and_index(Point* in_poly, int& n_poly,
int* points_to_convex_ind) {
int n_input = n_poly;
Point input_poly[20];
for (int i = 0; i < n_input; i++) {
input_poly[i].x = in_poly[i].x;
input_poly[i].y = in_poly[i].y;
}
Point p_max, p_k;
int max_index, k_index;
int Stack[20], top1, top2;
double sign;
Point right_point[10], left_point[10];
for (int i = 0; i < n_poly; i++) {
if (in_poly[i].y < in_poly[0].y ||
in_poly[i].y == in_poly[0].y && in_poly[i].x < in_poly[0].x) {
Point* j = &(in_poly[0]);
Point* k = &(in_poly[i]);
swap1(j, k);
}
if (i == 0) {
p_max = in_poly[0];
max_index = 0;
}
if (in_poly[i].y > p_max.y ||
in_poly[i].y == p_max.y && in_poly[i].x > p_max.x) {
p_max = in_poly[i];
max_index = i;
}
}
if (max_index == 0) {
max_index = 1;
p_max = in_poly[max_index];
}
k_index = 0, Stack[0] = 0, top1 = 0;
while (k_index != max_index) {
p_k = p_max;
k_index = max_index;
for (int i = 1; i < n_poly; i++) {
sign = cross(in_poly[Stack[top1]], in_poly[i], p_k);
if ((sign > 0) || ((sign == 0) && (dis(in_poly[Stack[top1]], in_poly[i]) >
dis(in_poly[Stack[top1]], p_k)))) {
p_k = in_poly[i];
k_index = i;
}
}
top1++;
Stack[top1] = k_index;
}
for (int i = 0; i <= top1; i++) {
right_point[i] = in_poly[Stack[i]];
}
k_index = 0, Stack[0] = 0, top2 = 0;
while (k_index != max_index) {
p_k = p_max;
k_index = max_index;
for (int i = 1; i < n_poly; i++) {
sign = cross(in_poly[Stack[top2]], in_poly[i], p_k);
if ((sign < 0) || (sign == 0) && (dis(in_poly[Stack[top2]], in_poly[i]) >
dis(in_poly[Stack[top2]], p_k))) {
p_k = in_poly[i];
k_index = i;
}
}
top2++;
Stack[top2] = k_index;
}
for (int i = top2 - 1; i >= 0; i--) {
left_point[i] = in_poly[Stack[i]];
}
for (int i = 0; i < top1 + top2; i++) {
if (i <= top1) {
in_poly[i] = right_point[i];
} else {
in_poly[i] = left_point[top2 - (i - top1)];
}
}
n_poly = top1 + top2;
for (int i = 0; i < n_poly; i++) {
for (int j = 0; j < n_input; j++) {
if (point_same(in_poly[i], input_poly[j])) {
points_to_convex_ind[i] = j;
break;
}
}
}
}
template <typename T>
__device__ inline float devrIoU(T const* const p, T const* const q,
T* point_grad, const int idx) {
Point ps1[MAXN], ps2[MAXN];
Point convex[MAXN];
for (int i = 0; i < 9; i++) {
convex[i].x = (double)p[i * 2];
convex[i].y = (double)p[i * 2 + 1];
}
int n_convex = 9;
int points_to_convex_ind[9] = {-1, -1, -1, -1, -1, -1, -1, -1, -1};
Jarvis_and_index(convex, n_convex, points_to_convex_ind);
int n1 = n_convex;
int n2 = 4;
for (int i = 0; i < n1; i++) {
ps1[i].x = (double)convex[i].x;
ps1[i].y = (double)convex[i].y;
}
for (int i = 0; i < n2; i++) {
ps2[i].x = (double)q[i * 2];
ps2[i].y = (double)q[i * 2 + 1];
}
int polygon_index_box_index[18];
for (int i = 0; i < n1; i++) {
polygon_index_box_index[i] = i;
polygon_index_box_index[i + n1] = i;
}
double grad_A[18] = {};
double grad_AB[18] = {};
double grad_C[18] = {};
double inter_area = intersectAreaO(ps1, n1, ps2, n2, grad_AB);
double S_pred =
polygon_area_grad(ps1, n1, polygon_index_box_index, n1, grad_A);
if (S_pred < 0) {
for (int i = 0; i < n_convex * 2; i++) {
grad_A[i] = -grad_A[i];
}
}
double union_area = fabs(S_pred) + fabs(area(ps2, n2)) - inter_area;
double iou = inter_area / union_area;
double polygon_area = intersectAreaPoly(ps1, n1, ps2, n2, grad_C);
// printf("%d:live\n", idx);
double rot_giou = iou - (polygon_area - union_area) / polygon_area;
float grad_point_temp[18] = {};
for (int i = 0; i < n_convex; i++) {
int grad_point = points_to_convex_ind[i];
grad_point_temp[2 * grad_point] =
(float)((union_area + inter_area) / (union_area * union_area) *
grad_AB[2 * i] -
iou / union_area * grad_A[2 * i] -
1 / polygon_area * (grad_AB[2 * i] - grad_A[2 * i]) -
(union_area) / polygon_area / polygon_area * grad_C[2 * i]);
grad_point_temp[2 * grad_point + 1] =
(float)((union_area + inter_area) / (union_area * union_area) *
grad_AB[2 * i + 1] -
iou / union_area * grad_A[2 * i + 1] -
1 / polygon_area * (grad_AB[2 * i + 1] - grad_A[2 * i + 1]) -
(union_area) / polygon_area / polygon_area * grad_C[2 * i + 1]);
}
for (int i = 0; i < 9; i++) {
point_grad[2 * i] = grad_point_temp[2 * i];
point_grad[2 * i + 1] = grad_point_temp[2 * i + 1];
}
return (float)rot_giou;
}
template <typename T>
__global__ void convex_giou_musa_kernel(const int ex_n_boxes,
const int gt_n_boxes, const T* ex_boxes,
const T* gt_boxes, T* point_grad) {
MUSA_1D_KERNEL_LOOP(index, ex_n_boxes) {
const T* cur_box = ex_boxes + index * 18;
const T* cur_gt_box = gt_boxes + index * 8;
T* cur_grad = point_grad + index * 19;
T giou = devrIoU(cur_box, cur_gt_box, cur_grad, threadIdx.x);
cur_grad[18] = giou;
}
}
__device__ inline int lineCross(Point a, Point b, Point c, Point d, Point& p) {
double s1, s2;
s1 = cross(a, b, c);
s2 = cross(a, b, d);
if (sig(s1) == 0 && sig(s2) == 0) return 2;
if (sig(s2 - s1) == 0) return 0;
p.x = (c.x * s2 - d.x * s1) / (s2 - s1);
p.y = (c.y * s2 - d.y * s1) / (s2 - s1);
return 1;
}
__device__ inline void polygon_cut(Point* p, int& n, Point a, Point b) {
Point pp[MAXN];
int m = 0;
p[n] = p[0];
for (int i = 0; i < n; i++) {
if (sig(cross(a, b, p[i])) > 0) {
pp[m] = p[i];
m++;
}
if (sig(cross(a, b, p[i])) != sig(cross(a, b, p[i + 1]))) {
lineCross(a, b, p[i], p[i + 1], pp[m]);
m++;
}
}
n = 0;
for (int i = 0; i < m; i++) {
if (!i || !(point_same(pp[i], pp[i - 1]))) {
p[n] = pp[i];
n++;
}
}
while (n > 1 && point_same(p[n - 1], p[0])) n--;
}
__device__ inline double intersectArea(Point a, Point b, Point c, Point d) {
Point o(0, 0);
int s1 = sig(cross(o, a, b));
int s2 = sig(cross(o, c, d));
if (s1 == 0 || s2 == 0) return 0.0;
if (s1 == -1) {
Point* i = &a;
Point* j = &b;
swap1(i, j);
}
if (s2 == -1) {
Point* i = &c;
Point* j = &d;
swap1(i, j);
}
Point p[10] = {o, a, b};
int n = 3;
polygon_cut(p, n, o, c);
polygon_cut(p, n, c, d);
polygon_cut(p, n, d, o);
double res = area(p, n);
if (s1 * s2 == -1) res = -res;
return res;
}
__device__ inline double intersectAreaO(Point* ps1, int n1, Point* ps2,
int n2) {
if (area(ps1, n1) < 0) reverse1(ps1, n1);
if (area(ps2, n2) < 0) reverse1(ps2, n2);
ps1[n1] = ps1[0];
ps2[n2] = ps2[0];
double res = 0;
for (int i = 0; i < n1; i++) {
for (int j = 0; j < n2; j++) {
res += intersectArea(ps1[i], ps1[i + 1], ps2[j], ps2[j + 1]);
}
}
return res;
}
template <typename T>
__device__ inline float devrIoU(T const* const p, T const* const q) {
Point ps1[MAXN], ps2[MAXN];
Point convex[MAXN];
for (int i = 0; i < 9; i++) {
convex[i].x = (double)p[i * 2];
convex[i].y = (double)p[i * 2 + 1];
}
int n_convex = 9;
int points_to_convex_ind[9] = {-1, -1, -1, -1, -1, -1, -1, -1, -1};
Jarvis_and_index(convex, n_convex, points_to_convex_ind);
int n1 = n_convex;
for (int i = 0; i < n1; i++) {
ps1[i].x = (double)convex[i].x;
ps1[i].y = (double)convex[i].y;
}
int n2 = 4;
for (int i = 0; i < n2; i++) {
ps2[i].x = (double)q[i * 2];
ps2[i].y = (double)q[i * 2 + 1];
}
double inter_area = intersectAreaO(ps1, n1, ps2, n2);
double S_pred = area(ps1, n1);
double union_area = fabs(S_pred) + fabs(area(ps2, n2)) - inter_area;
double iou = inter_area / union_area;
return (float)iou;
}
template <typename T>
__global__ void convex_iou_musa_kernel(const int ex_n_boxes,
const int gt_n_boxes, const T* ex_boxes,
const T* gt_boxes, T* iou) {
MUSA_1D_KERNEL_LOOP(index, ex_n_boxes) {
const T* cur_box = ex_boxes + index * 18;
for (int i = 0; i < gt_n_boxes; i++) {
iou[index * gt_n_boxes + i] = devrIoU(cur_box, gt_boxes + i * 8);
}
}
}
#endif // CONVEX_IOU_MUSA_KERNEL_MUH

View File

@@ -0,0 +1,227 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_cuda_kernel.cu
// Original licence: Under MIT License
#ifndef CORRELATION_MUSA
#define CORRELATION_MUSA
#include "pytorch_musa_helper.hpp"
#include <musa.h>
#include <musa_runtime.h>
// Using <torch/extension.h> is recommended in the official documentation in
// https://pytorch.org/tutorials/advanced/cpp_extension.html#writing-the-c-op.
// However, we use <torch/types.h> for compatibility with MUSA 9.0
// Read https://github.com/pytorch/extension-cpp/issues/35 for more details.
#include <torch/types.h>
#include <iostream>
#include <vector>
using namespace torch;
#define TensorAcc4R PackedTensorAccessor32<scalar_t, 4, RestrictPtrTraits>
#define TensorAcc5R PackedTensorAccessor32<scalar_t, 5, RestrictPtrTraits>
#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)
#define WARP_SIZE 32
#define FULL_MASK 0xffffffff
template <typename scalar_t>
__global__ void correlation_forward_musa_kernel(
const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output,
int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH,
int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW,
int oH, int oW) {
const int iH = rInput1.size(1);
const int iW = rInput1.size(2);
const int C = rInput1.size(3);
const int n = blockIdx.x;
const int h = blockIdx.y * blockDim.y + threadIdx.y;
const int w = blockIdx.z * blockDim.z + threadIdx.z;
if (h >= oH || w >= oW) return;
const int thread = threadIdx.x;
const int start_i = -padH + h * dH;
const int start_j = -padW + w * dW;
const int patchRadH = dilation_patchH * (patchH - 1) / 2;
const int patchRadW = dilation_patchW * (patchW - 1) / 2;
for (int ph = 0; ph < patchH; ++ph) {
int ph_dilated = ph * dilation_patchH - patchRadH;
for (int pw = 0; pw < patchW; ++pw) {
int pw_dilated = pw * dilation_patchW - patchRadW;
scalar_t prod_sum = 0.0f;
for (int i = 0; i < kH; ++i) {
int i1 = start_i + i * dilationH;
int i2 = i1 + ph_dilated;
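// i1 and i2 are rows in input1 and input2 respectively, so both are
// checked against iH.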
if (WITHIN_BOUNDS(i1, i2, iH, iH)) {
for (int j = 0; j < kW; ++j) {
int j1 = start_j + j * dilationW;
int j2 = j1 + pw_dilated;
if (WITHIN_BOUNDS(j1, j2, iW, iW)) {
for (int c = thread; c < C; c += WARP_SIZE) {
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
prod_sum += v1 * v2;
}
}
}
}
}
// Warp-level reduction: sum the per-lane partial products across the 32
// lanes of the warp.
for (int offset = 16; offset > 0; offset /= 2)
#ifdef MMCV_WITH_HIP
prod_sum += __shfl_down(float(prod_sum), offset);
#else
prod_sum += __shfl_down_sync(FULL_MASK, float(prod_sum), offset);
#endif
if (thread == 0) {
output[n][ph][pw][h][w] = prod_sum;
}
}
}
}
template <typename scalar_t>
__global__ void correlation_backward_musa_kernel_input1(
const TensorAcc5R grad_output, const TensorAcc4R input2,
TensorAcc4R grad_input1, const int kH, const int kW, const int patchH,
const int patchW, const int padH, const int padW, const int dilationH,
const int dilationW, const int dilation_patchH, const int dilation_patchW,
const int dH, const int dW) {
const int iH = input2.size(1);
const int iW = input2.size(2);
const int C = input2.size(3);
const int H = grad_output.size(3);
const int W = grad_output.size(4);
const int patchRadH = (patchH - 1) / 2;
const int patchRadW = (patchW - 1) / 2;
const int n = blockIdx.x;
const int h = blockIdx.y;
const int w = blockIdx.z;
const int h_2 = h + padH;
const int w_2 = w + padW;
const int min_h = h_2 - kH * dilationH;
const int min_w = w_2 - kW * dilationW;
extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[];
scalar_t *grad_cache = reinterpret_cast<scalar_t *>(grad_cache_char);
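// Stage 1: for every patch displacement (ph, pw), sum the grad_output values
// of all output positions whose kernel window covers input pixel (h, w);
// one cached value per displacement is stored in shared memory.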
for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) {
const int ph = i / patchW;
const int pw = i % patchW;
int i1 = h + dilation_patchH * (ph - patchRadH);
int j1 = w + dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
scalar_t grad_val = 0.0f;
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
int i2 = (h_3) / dH;
if (i2 * dH != h_3) continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
int j2 = (w_3) / dW;
if (j2 * dW != w_3) continue;
if (WITHIN_BOUNDS(i2, j2, H, W)) {
grad_val += grad_output[n][ph][pw][i2][j2];
}
}
}
grad_cache[i] = grad_val;
}
}
__syncthreads();
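// Stage 2: weight the cached sums by the corresponding input2 values and
// accumulate the gradient w.r.t. input1 for each channel.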
for (int c = threadIdx.x; c < C; c += blockDim.x) {
scalar_t grad_input_val = 0.0f;
for (int ph = 0; ph < patchH; ++ph) {
int i1 = h + dilation_patchH * (ph - patchRadH);
for (int pw = 0; pw < patchW; ++pw) {
int j1 = w + dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
grad_input_val += input2[n][i1][j1][c] * grad_cache[ph * patchW + pw];
}
}
}
grad_input1[n][c][h][w] = grad_input_val;
}
}
template <typename scalar_t>
__global__ void correlation_backward_musa_kernel_input2(
const TensorAcc5R grad_output, const TensorAcc4R input1,
TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW) {
const int iH = input1.size(1);
const int iW = input1.size(2);
const int C = input1.size(3);
const int patchRadH = (patchH - 1) / 2;
const int patchRadW = (patchW - 1) / 2;
const int H = grad_output.size(3);
const int W = grad_output.size(4);
const int dilatedKH = kH * dilationH;
const int dilatedKW = kW * dilationW;
const int n = blockIdx.x;
const int h = blockIdx.y;
const int w = blockIdx.z;
extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[];
scalar_t *grad_cache = reinterpret_cast<scalar_t *>(grad_cache_char);
for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) {
const int ph = i / patchW;
const int pw = i % patchW;
int i1 = h - dilation_patchH * (ph - patchRadH);
int j1 = w - dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
scalar_t grad_val = 0.0f;
const int h_2 = i1 + padH;
const int w_2 = j1 + padW;
const int min_h = h_2 - dilatedKH;
const int min_w = w_2 - dilatedKW;
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
int i2 = (h_3) / dH;
if (i2 * dH != h_3) continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
int j2 = (w_3) / dW;
if (j2 * dW != w_3) continue;
if (WITHIN_BOUNDS(i2, j2, H, W)) {
grad_val += grad_output[n][ph][pw][i2][j2];
}
}
}
grad_cache[i] = grad_val;
}
}
__syncthreads();
for (int c = threadIdx.x; c < C; c += blockDim.x) {
scalar_t grad_input_val = 0.0f;
for (int ph = 0; ph < patchH; ++ph) {
int i1 = h - dilation_patchH * (ph - patchRadH);
for (int pw = 0; pw < patchW; ++pw) {
int j1 = w - dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
grad_input_val += input1[n][i1][j1][c] * grad_cache[ph * patchW + pw];
}
}
}
grad_input2[n][c][h][w] = grad_input_val;
}
}
#endif

View File

@ -0,0 +1,360 @@
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer
*****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
*FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
*DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
*SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
*CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
*OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
*OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer
*********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.muh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#ifndef DEFORM_CONV_MUSA_KERNEL_MUH
#define DEFORM_CONV_MUSA_KERNEL_MUH
#include <float.h>
#include "pytorch_musa_helper.hpp"
template <typename T>
__device__ T deformable_im2col_bilinear(const T *input, const int data_width,
const int height, const int width, T h,
T w) {
if (h <= -1 || height <= h || w <= -1 || width <= w) {
return 0;
}
int h_low = floorf(h);
int w_low = floorf(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh, hw = 1 - lw;
T v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
T v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = input[h_low * data_width + w_high];
T v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = input[h_high * data_width + w_low];
T v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = input[h_high * data_width + w_high];
T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
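// Bilinear weight of the integer pixel (h, w) for the fractional sampling
// location (argmax_h, argmax_w); used when scattering gradients back to the
// four neighbouring pixels.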
template <typename T>
__device__ T get_gradient_weight(T argmax_h, T argmax_w, const int h,
const int w, const int height,
const int width) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
int argmax_h_low = floorf(argmax_h);
int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
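// Gradient of the bilinearly interpolated value with respect to the sampling
// coordinate: bp_dir == 0 differentiates w.r.t. h, bp_dir == 1 w.r.t. w.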
template <typename T>
__device__ T get_coordinate_weight(T argmax_h, T argmax_w, const int height,
const int width, const T *im_data,
const int data_width, const int bp_dir) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
int argmax_h_low = floorf(argmax_h);
int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
if (bp_dir == 0) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
} else if (bp_dir == 1) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template <typename T>
__global__ void deformable_im2col_gpu_kernel(
const int n, const T *data_im, const T *data_offset, const int height,
const int width, const int kernel_h, const int kernel_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size,
const int num_channels, const int deformable_group, const int height_col,
const int width_col, T *data_col) {
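// data_col layout: (c_im * kernel_h * kernel_w, batch, height_col,
// width_col); each thread writes the kernel_h * kernel_w column entries for
// one (input channel, output location) pair, sampled at the deformed
// positions.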
MUSA_1D_KERNEL_LOOP(index, n) {
// index: linear index into the output column matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T *data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T *data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const T *data_offset_ptr =
data_offset + (b_col * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
val = deformable_im2col_bilinear(data_im_ptr, width, height, width,
h_im, w_im);
*data_col_ptr = val;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
template <typename T>
__global__ void deformable_col2im_gpu_kernel(
const int n, const T *data_col, const T *data_offset, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size,
const int deformable_group, const int height_col, const int width_col,
T *grad_im) {
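// Inverse of im2col: each thread takes one column entry and scatters its
// gradient to the input pixels around its fractional sampling location via
// atomicAdd.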
MUSA_1D_KERNEL_LOOP(index, n) {
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i =
(index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c =
index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const T *data_offset_ptr =
data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
const T cur_top_grad = data_col[index];
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
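// Scatter the column gradient to the integer pixels around the fractional
// sampling location; only the four bilinear neighbours receive a non-zero
// weight, and the +/-2 window conservatively covers them because the int
// cast truncates toward zero.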
for (int dy = -2; dy <= 2; dy++) {
for (int dx = -2; dx <= 2; dx++) {
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
int cur_bottom_grad_pos =
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
T weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data,
cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
template <typename T>
__global__ void deformable_col2im_coord_gpu_kernel(
const int n, const T *data_col, const T *data_im, const T *data_offset,
const int channels, const int height, const int width, const int kernel_h,
const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
const int stride_w, const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size,
const int offset_channels, const int deformable_group, const int height_col,
const int width_col, T *grad_offset) {
MUSA_1D_KERNEL_LOOP(index, n) {
T val = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const T *data_col_ptr = data_col + deformable_group_index *
channel_per_deformable_group *
batch_size * width_col * height_col;
const T *data_im_ptr =
data_im + (b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h / kernel_w *
height * width;
const T *data_offset_ptr =
data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
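// Each thread owns one offset channel, i.e. one (kernel position, direction)
// pair; its gradient is accumulated over all input channels of the
// deformable group, with bp_dir = 0 for the h offset and 1 for the w offset.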
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
col_c += col_step) {
const int col_pos =
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i =
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr =
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr =
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out);
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
T inv_h = h_in + i * dilation_h + offset_h;
T inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
inv_h = inv_w = -2;
const T weight = get_coordinate_weight(inv_h, inv_w, height, width,
data_im_ptr + cnt * height * width,
width, bp_dir);
val += weight * data_col_ptr[col_pos];
cnt += 1;
}
grad_offset[index] = val;
}
}
#endif // DEFORM_CONV_MUSA_KERNEL_MUH

View File

@ -0,0 +1,181 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef DEFORM_ROI_POOL_MUSA_KERNEL_MUH
#define DEFORM_ROI_POOL_MUSA_KERNEL_MUH
#include "pytorch_musa_helper.hpp"
template <typename T>
__global__ void deform_roi_pool_forward_musa_kernel(
const int nthreads, const T* input, const T* rois, const T* offset,
T* output, const int pooled_height, const int pooled_width,
const T spatial_scale, const int sampling_ratio, const T gamma,
const int channels, const int height, const int width) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - 0.5;
T roi_start_h = offset_rois[2] * spatial_scale - 0.5;
T roi_end_w = offset_rois[3] * spatial_scale - 0.5;
T roi_end_h = offset_rois[4] * spatial_scale - 0.5;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_w =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
// Compute roi offset
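// The offset tensor is laid out as (n, 2, pooled_h, pooled_w): x (w)
// offsets in the first half, y (h) offsets in the second.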
if (offset != NULL) {
const T* offset_cur_w = offset + n * pooled_width * pooled_height * 2 +
ph * pooled_width + pw;
T offset_roi_w = gamma * roi_width * offset_cur_w[0];
T offset_roi_h =
gamma * roi_height * offset_cur_w[pooled_width * pooled_height];
roi_start_w += offset_roi_w;
roi_start_h += offset_roi_h;
}
// We do average pooling inside a bin
const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(offset_input, height, width, y, x, index);
output_val += val;
}
}
output[index] = output_val / count;
}
}
template <typename T>
__global__ void deform_roi_pool_backward_musa_kernel(
const int nthreads, const T* grad_output, const T* input, const T* rois,
const T* offset, T* grad_input, T* grad_offset, const int pooled_height,
const int pooled_width, const T spatial_scale, const int sampling_ratio,
const T gamma, const int channels, const int height, const int width) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
const T* offset_input =
input + ((roi_batch_ind * channels + c) * height * width);
T* offset_grad_input =
grad_input + ((roi_batch_ind * channels + c) * height * width);
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - 0.5;
T roi_start_h = offset_rois[2] * spatial_scale - 0.5;
T roi_end_w = offset_rois[3] * spatial_scale - 0.5;
T roi_end_h = offset_rois[4] * spatial_scale - 0.5;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_w =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
// Compute roi offset
if (offset != NULL) {
const T* offset_cur_w = offset + n * pooled_width * pooled_height * 2 +
ph * pooled_width + pw;
T offset_roi_w = gamma * roi_width * offset_cur_w[0];
T offset_roi_h =
gamma * roi_height * offset_cur_w[pooled_width * pooled_height];
roi_start_w += offset_roi_w;
roi_start_h += offset_roi_h;
}
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
const T grad_output_this_bin = grad_output[index] / count;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
x_low, x_high, y_low, y_high, index);
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(offset_grad_input + y_low * width + x_low,
grad_output_this_bin * w1);
atomicAdd(offset_grad_input + y_low * width + x_high,
grad_output_this_bin * w2);
atomicAdd(offset_grad_input + y_high * width + x_low,
grad_output_this_bin * w3);
atomicAdd(offset_grad_input + y_high * width + x_high,
grad_output_this_bin * w4);
if (offset != NULL) {
T input_00 = offset_input[y_low * width + x_low];
T input_10 = offset_input[y_low * width + x_high];
T input_01 = offset_input[y_high * width + x_low];
T input_11 = offset_input[y_high * width + x_high];
T ogx = gamma * roi_width * grad_output_this_bin *
(input_11 * (y - y_low) + input_10 * (y_high - y) +
input_01 * (y_low - y) + input_00 * (y - y_high));
T ogy = gamma * roi_height * grad_output_this_bin *
(input_11 * (x - x_low) + input_01 * (x_high - x) +
input_10 * (x_low - x) + input_00 * (x - x_high));
atomicAdd(grad_offset + n * pooled_width * pooled_height * 2 +
ph * pooled_width + pw,
ogx);
atomicAdd(grad_offset + n * pooled_width * pooled_height * 2 +
pooled_width * pooled_height + ph * pooled_width + pw,
ogy);
}
}
}
}
}
}
#endif // DEFORM_ROI_POOL_MUSA_KERNEL_MUH

View File

@ -0,0 +1,133 @@
// Copyright (c) OpenMMLab. All rights reserved
// Adapted from
// https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa
#include "pytorch_musa_helper.hpp"
#define MAX_NUM_VERT_IDX 9
#define INTERSECTION_OFFSET 8
#define EPSILON 1e-8
inline int opt_n_thread(int work_size) {
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
return max(min(1 << pow_2, THREADS_PER_BLOCK), 1);
}
/*
compare normalized vertices (vertices centered around (0, 0));
return true if vertex1 < vertex2.
order: minimum at the x-axis, increasing in the anti-clockwise direction
*/
__device__ bool compare_vertices(float x1, float y1, float x2, float y2) {
if (fabs(x1 - x2) < EPSILON && fabs(y2 - y1) < EPSILON)
return false; // if equal, return false
if (y1 > 0 && y2 < 0) return true;
if (y1 < 0 && y2 > 0) return false;
float n1 = x1 * x1 + y1 * y1 + EPSILON;
float n2 = x2 * x2 + y2 * y2 + EPSILON;
float diff = fabs(x1) * x1 / n1 - fabs(x2) * x2 / n2;
if (y1 > 0 && y2 > 0) {
if (diff > EPSILON)
return true;
else
return false;
}
if (y1 < 0 && y2 < 0) {
if (diff < EPSILON)
return true;
else
return false;
}
return false;
}
__global__ void diff_iou_rotated_sort_vertices_forward_musa_kernel(
int b, int n, int m, const float *__restrict__ vertices,
const bool *__restrict__ mask, const int *__restrict__ num_valid,
int *__restrict__ idx) {
int batch_idx = blockIdx.x;
vertices += batch_idx * n * m * 2;
mask += batch_idx * n * m;
num_valid += batch_idx * n;
idx += batch_idx * n * MAX_NUM_VERT_IDX;
int index = threadIdx.x; // index of polygon
int stride = blockDim.x;
for (int i = index; i < n; i += stride) {
int pad; // index of arbitrary invalid intersection point (not box corner!)
for (int j = INTERSECTION_OFFSET; j < m; ++j) {
if (!mask[i * m + j]) {
pad = j;
break;
}
}
if (num_valid[i] < 3) {
// not enough vertices, take an invalid intersection point
// (zero padding)
for (int j = 0; j < MAX_NUM_VERT_IDX; ++j) {
idx[i * MAX_NUM_VERT_IDX + j] = pad;
}
} else {
// sort the valid vertices
// note the number of valid vertices is known
// note: check that num_valid[i] < MAX_NUM_VERT_IDX
for (int j = 0; j < num_valid[i]; ++j) {
// initialize with a "big" value
float x_min = 1;
float y_min = -EPSILON;
int i_take = 0;
int i2;
float x2, y2;
if (j != 0) {
i2 = idx[i * MAX_NUM_VERT_IDX + j - 1];
x2 = vertices[i * m * 2 + i2 * 2 + 0];
y2 = vertices[i * m * 2 + i2 * 2 + 1];
}
for (int k = 0; k < m; ++k) {
float x = vertices[i * m * 2 + k * 2 + 0];
float y = vertices[i * m * 2 + k * 2 + 1];
if (mask[i * m + k] && compare_vertices(x, y, x_min, y_min)) {
if ((j == 0) || (j != 0 && compare_vertices(x2, y2, x, y))) {
x_min = x;
y_min = y;
i_take = k;
}
}
}
idx[i * MAX_NUM_VERT_IDX + j] = i_take;
}
// duplicate the first idx
idx[i * MAX_NUM_VERT_IDX + num_valid[i]] = idx[i * MAX_NUM_VERT_IDX + 0];
// pad zeros
for (int j = num_valid[i] + 1; j < MAX_NUM_VERT_IDX; ++j) {
idx[i * MAX_NUM_VERT_IDX + j] = pad;
}
// Corner case: the two boxes are exactly the same. In that case idx
// contains duplicate elements, which breaks the shoelace formula. By
// definition, the duplicates can only appear in the first 8 positions
// (they are corners of a box, not intersections of edges).
if (num_valid[i] == 8) {
int counter = 0;
for (int j = 0; j < 4; ++j) {
int check = idx[i * MAX_NUM_VERT_IDX + j];
for (int k = 4; k < INTERSECTION_OFFSET; ++k) {
if (idx[i * MAX_NUM_VERT_IDX + k] == check) counter++;
}
}
if (counter == 4) {
idx[i * MAX_NUM_VERT_IDX + 4] = idx[i * MAX_NUM_VERT_IDX + 0];
for (int j = 5; j < MAX_NUM_VERT_IDX; ++j) {
idx[i * MAX_NUM_VERT_IDX + j] = pad;
}
}
}
// TODO: still might need to cover some other corner cases :(
}
}
}

View File

@ -153,7 +153,9 @@ void deform_conv_forward(Tensor input, Tensor weight, Tensor offset,
#else
AT_ERROR("DeformConv is not compiled with GPU support");
#endif
} else {
}
#ifndef MMCV_WITH_MUSA
else {
CHECK_CPU_INPUT(input);
CHECK_CPU_INPUT(offset);
CHECK_CPU_INPUT(weight);
@ -161,7 +163,7 @@ void deform_conv_forward(Tensor input, Tensor weight, Tensor offset,
CHECK_CPU_INPUT(columns);
CHECK_CPU_INPUT(ones);
}
#endif
deform_conv_shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH,
padW, dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
@ -275,7 +277,9 @@ void deform_conv_backward_input(Tensor input, Tensor offset, Tensor gradOutput,
#else
AT_ERROR("DeformConv is not compiled with GPU support");
#endif
} else {
}
#ifndef MMCV_WITH_MUSA
else {
CHECK_CPU_INPUT(input);
CHECK_CPU_INPUT(offset);
CHECK_CPU_INPUT(gradOutput);
@ -284,6 +288,7 @@ void deform_conv_backward_input(Tensor input, Tensor offset, Tensor gradOutput,
CHECK_CPU_INPUT(weight);
CHECK_CPU_INPUT(columns);
}
#endif
deform_conv_shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW,
padH, padW, dilationH, dilationW, group,
deformable_group);
@ -407,7 +412,9 @@ void deform_conv_backward_parameters(Tensor input, Tensor offset,
#else
AT_ERROR("DeformConv is not compiled with GPU support");
#endif
} else {
}
#ifndef MMCV_WITH_MUSA
else {
CHECK_CPU_INPUT(input);
CHECK_CPU_INPUT(offset);
CHECK_CPU_INPUT(gradOutput);
@ -415,6 +422,7 @@ void deform_conv_backward_parameters(Tensor input, Tensor offset,
CHECK_CPU_INPUT(columns);
CHECK_CPU_INPUT(ones);
}
#endif
deform_conv_shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH,
dW, padH, padW, dilationH, dilationW, group,

View File

@ -0,0 +1,301 @@
// Modified from
// https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/bias_act.cpp
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include <musa_runtime.h>
#include <torch/types.h>
#include "pytorch_musa_helper.hpp"
struct bias_act_kernel_params {
const void *x; // [sizeX]
const void *b; // [sizeB] or NULL
const void *xref; // [sizeX] or NULL
const void *yref; // [sizeX] or NULL
const void *dy; // [sizeX] or NULL
void *y; // [sizeX]
int grad;
int act;
float alpha;
float gain;
float clamp;
int sizeX;
int sizeB;
int stepB;
int loopX;
};
// MUSA kernel selection.
template <class T>
void *choose_bias_act_kernel(const bias_act_kernel_params &p);
//------------------------------------------------------------------------
// Helpers.
template <class T>
struct InternalType;
template <>
struct InternalType<double> {
typedef double scalar_t;
};
template <>
struct InternalType<float> {
typedef float scalar_t;
};
template <>
struct InternalType<c10::Half> {
typedef float scalar_t;
};
//------------------------------------------------------------------------
// MUSA kernel.
template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p) {
typedef typename InternalType<T>::scalar_t scalar_t;
int G = p.grad;
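// p.grad selects the pass: 0 = forward, 1 = first-order gradient,
// 2 = second-order gradient.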
scalar_t alpha = (scalar_t)p.alpha;
scalar_t gain = (scalar_t)p.gain;
scalar_t clamp = (scalar_t)p.clamp;
scalar_t one = (scalar_t)1;
scalar_t two = (scalar_t)2;
scalar_t expRange = (scalar_t)80;
scalar_t halfExpRange = (scalar_t)40;
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
// Loop over elements.
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX;
loopIdx++, xi += blockDim.x) {
// Load.
scalar_t x = (scalar_t)((const T *)p.x)[xi];
scalar_t b =
(p.b) ? (scalar_t)((const T *)p.b)[(xi / p.stepB) % p.sizeB] : 0;
scalar_t xref = (p.xref) ? (scalar_t)((const T *)p.xref)[xi] : 0;
scalar_t yref = (p.yref) ? (scalar_t)((const T *)p.yref)[xi] : 0;
scalar_t dy = (p.dy) ? (scalar_t)((const T *)p.dy)[xi] : one;
scalar_t yy = (gain != 0) ? yref / gain : 0;
scalar_t y = 0;
// Apply bias.
((G == 0) ? x : xref) += b;
// linear
if (A == 1) {
if (G == 0) y = x;
if (G == 1) y = x;
}
// relu
if (A == 2) {
if (G == 0) y = (x > 0) ? x : 0;
if (G == 1) y = (yy > 0) ? x : 0;
}
// lrelu
if (A == 3) {
if (G == 0) y = (x > 0) ? x : x * alpha;
if (G == 1) y = (yy > 0) ? x : x * alpha;
}
// tanh
if (A == 4) {
if (G == 0) {
scalar_t c = exp(x);
scalar_t d = one / c;
y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d);
}
if (G == 1) y = x * (one - yy * yy);
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
}
// sigmoid
if (A == 5) {
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
if (G == 1) y = x * yy * (one - yy);
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
}
// elu
if (A == 6) {
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
}
// selu
if (A == 7) {
if (G == 0)
y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
if (G == 1)
y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
}
// softplus
if (A == 8) {
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
if (G == 1) y = x * (one - exp(-yy));
if (G == 2) {
scalar_t c = exp(-yy);
y = x * c * (one - c);
}
}
// swish
if (A == 9) {
if (G == 0)
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
else {
scalar_t c = exp(xref);
scalar_t d = c + one;
if (G == 1)
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
else
y = (xref > halfExpRange)
? 0
: x * c * (xref * (two - d) + two * d) / (d * d * d);
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
}
}
// Apply gain.
y *= gain * dy;
// Clamp.
if (clamp >= 0) {
if (G == 0)
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
else
y = (yref > -clamp & yref < clamp) ? y : 0;
}
// Store.
((T *)p.y)[xi] = (T)y;
}
}
//------------------------------------------------------------------------
// MUSA kernel selection.
template <class T>
void *choose_bias_act_kernel(const bias_act_kernel_params &p) {
if (p.act == 1) return (void *)bias_act_kernel<T, 1>;
if (p.act == 2) return (void *)bias_act_kernel<T, 2>;
if (p.act == 3) return (void *)bias_act_kernel<T, 3>;
if (p.act == 4) return (void *)bias_act_kernel<T, 4>;
if (p.act == 5) return (void *)bias_act_kernel<T, 5>;
if (p.act == 6) return (void *)bias_act_kernel<T, 6>;
if (p.act == 7) return (void *)bias_act_kernel<T, 7>;
if (p.act == 8) return (void *)bias_act_kernel<T, 8>;
if (p.act == 9) return (void *)bias_act_kernel<T, 9>;
return NULL;
}
//------------------------------------------------------------------------
static bool has_same_layout(torch::Tensor x, torch::Tensor y) {
if (x.dim() != y.dim()) return false;
for (int64_t i = 0; i < x.dim(); i++) {
if (x.size(i) != y.size(i)) return false;
if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) return false;
}
return true;
}
//------------------------------------------------------------------------
torch::Tensor bias_act_op(const torch::Tensor &x, const torch::Tensor &b,
const torch::Tensor &xref, const torch::Tensor &yref,
const torch::Tensor &dy, int grad, int dim, int act,
float alpha, float gain, float clamp) {
// Validate arguments.
TORCH_CHECK(x.is_privateuseone(), "x must reside on MUSA device");
TORCH_CHECK(
b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()),
"b must have the same dtype and device as x");
TORCH_CHECK(xref.numel() == 0 ||
(xref.sizes() == x.sizes() && xref.dtype() == x.dtype() &&
xref.device() == x.device()),
"xref must have the same shape, dtype, and device as x");
TORCH_CHECK(yref.numel() == 0 ||
(yref.sizes() == x.sizes() && yref.dtype() == x.dtype() &&
yref.device() == x.device()),
"yref must have the same shape, dtype, and device as x");
TORCH_CHECK(
dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() &&
dy.device() == x.device()),
"dy must have the same dtype and device as x");
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()),
"dim is out of bounds");
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim),
"b has wrong number of elements");
TORCH_CHECK(grad >= 0, "grad must be non-negative");
// Validate layout.
TORCH_CHECK(x.is_non_overlapping_and_dense(),
"x must be non-overlapping and dense");
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x),
"xref must have the same layout as x");
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x),
"yref must have the same layout as x");
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x),
"dy must have the same layout as x");
// Create output tensor.
const at::musa::OptionalMUSAGuard device_guard(device_of(x));
torch::Tensor y = torch::empty_like(x);
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
// Initialize MUSA kernel parameters.
bias_act_kernel_params p;
p.x = x.data_ptr();
p.b = (b.numel()) ? b.data_ptr() : NULL;
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
p.y = y.data_ptr();
p.grad = grad;
p.act = act;
p.alpha = alpha;
p.gain = gain;
p.clamp = clamp;
p.sizeX = (int)x.numel();
p.sizeB = (int)b.numel();
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
// Choose MUSA kernel.
void *kernel;
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "bias_act_musa", [&] {
kernel = choose_bias_act_kernel<scalar_t>(p);
});
TORCH_CHECK(kernel, "no MUSA kernel found for the specified activation func");
// Launch MUSA kernel.
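// Each block covers loopX * blockSize consecutive elements; each thread
// handles up to loopX of them, strided by blockSize.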
p.loopX = 4;
int blockSize = 4 * 32;
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
void *args[] = {&p};
#ifdef MMCV_WITH_HIP
AT_MUSA_CHECK(hipLaunchKernel(kernel, gridSize, blockSize, args, 0,
c10::musa::getCurrentMUSAStream()));
#else
AT_MUSA_CHECK(musaLaunchKernel(kernel, gridSize, blockSize, args, 0,
c10::musa::getCurrentMUSAStream()));
#endif
return y;
}

View File

@ -0,0 +1,68 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "border_align_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"
void BorderAlignForwardMUSAKernelLauncher(const Tensor &input,
const Tensor &boxes, Tensor output,
Tensor argmax_idx,
const int pool_size) {
// shape assertion
AT_ASSERTM(input.ndimension() == 4,
"non-empty 4D(batch mode) tensor expected for input feature");
AT_ASSERTM(boxes.ndimension() == 3,
"boxes must be 3D tensor with size of [B, H*W, 4]");
int batch_size = input.size(0);
int feat_channels = input.size(1);
int channels = feat_channels / 4;
int height = input.size(2);
int width = input.size(3);
// shape [N, box_size, 4] for boxes. (x1, y1, x2, y2) format
int box_size = boxes.size(1);
// shape [N, channels, box_size, 4] for output
int nthreads = batch_size * channels * box_size;
c10::musa::MUSAGuard device_guard(input.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
dim3 block(128, 4);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "border_align_forward_musa_kernel", [&] {
border_align_forward_musa_kernel<scalar_t>
<<<GET_BLOCKS(nthreads), block, 0, stream>>>(
nthreads, input.data_ptr<scalar_t>(),
boxes.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
argmax_idx.data_ptr<int>(), channels, box_size, height, width,
pool_size);
});
AT_MUSA_CHECK(musaGetLastError());
}
void BorderAlignBackwardMUSAKernelLauncher(const Tensor &grad_output,
const Tensor &boxes,
const Tensor &argmax_idx,
Tensor grad_input,
const int pool_size) {
int batch_size = grad_input.size(0);
int feat_channels = grad_input.size(1);
int channels = feat_channels / 4;
int height = grad_input.size(2);
int width = grad_input.size(3);
int box_size = boxes.size(1);
int nthreads = batch_size * channels * box_size;
c10::musa::MUSAGuard device_guard(grad_output.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
dim3 block(128, 4);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "border_align_backward_musa_kernel", [&] {
border_align_backward_musa_kernel<scalar_t>
<<<GET_BLOCKS(nthreads), block, 0, stream>>>(
nthreads, grad_output.data_ptr<scalar_t>(),
boxes.data_ptr<scalar_t>(), argmax_idx.data_ptr<int>(),
grad_input.data_ptr<scalar_t>(), channels, box_size, height,
width, pool_size);
});
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,23 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "box_iou_quadri_musa.muh"
#include "pytorch_musa_helper.hpp"
void box_iou_quadri_musa(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
using scalar_t = float;
AT_ASSERTM(boxes1.is_privateuseone(), "boxes1 must be a MUSA tensor");
AT_ASSERTM(boxes2.is_privateuseone(), "boxes2 must be a MUSA tensor");
int output_size = ious.numel();
int num_boxes1 = boxes1.size(0);
int num_boxes2 = boxes2.size(0);
c10::musa::MUSAGuard device_guard(boxes1.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
box_iou_quadri_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
mode_flag, aligned);
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,25 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_musa.cu
#include "box_iou_rotated_musa.muh"
#include "pytorch_musa_helper.hpp"
void box_iou_rotated_musa(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
using scalar_t = float;
AT_ASSERTM(boxes1.is_privateuseone(), "boxes1 must be a MUSA tensor");
AT_ASSERTM(boxes2.is_privateuseone(), "boxes2 must be a MUSA tensor");
int output_size = ious.numel();
int num_boxes1 = boxes1.size(0);
int num_boxes2 = boxes2.size(0);
c10::musa::MUSAGuard device_guard(boxes1.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
box_iou_rotated_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
mode_flag, aligned);
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,182 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "carafe_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"
#include <iostream>
#if MUSA_ARCH > 21
void CARAFEForwardMUSAKernelLauncher(const Tensor features, const Tensor masks,
Tensor rfeatures, Tensor routput,
Tensor rmasks, Tensor output,
const int kernel_size,
const int group_size,
const int scale_factor) {
const int batch_size = output.size(0);
const int channels = output.size(1);
const int output_height = output.size(2);
const int output_width = output.size(3);
const int input_height = features.size(2);
const int input_width = features.size(3);
const int mask_channels = masks.size(1);
rfeatures.resize_({batch_size, input_height, input_width, channels});
routput.resize_({batch_size, output_height, output_width, channels});
rmasks.resize_({batch_size, output_height, output_width, mask_channels});
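// rfeatures / routput / rmasks are NHWC staging buffers: inputs are
// transposed to channels-last, the CARAFE kernel runs in NHWC, and the
// result is transposed back to NCHW below.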
// one warp per pixel
c10::musa::MUSAGuard device_guard(features.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "NCHW2NHWC_Feature", ([&] {
const scalar_t *bottom_data = features.data_ptr<scalar_t>();
scalar_t *top_data = rfeatures.data_ptr<scalar_t>();
const int dh = divideUP(channels, kTileDim);
const int dw = divideUP(input_height * input_width, kTileDim);
BatchTranspose2DMUSAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, channels, input_height * input_width, dh, dw,
bottom_data, top_data);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "NCHW2NHWC_Masks", ([&] {
const scalar_t *bottom_data = masks.data_ptr<scalar_t>();
scalar_t *top_data = rmasks.data_ptr<scalar_t>();
const int dh = divideUP(mask_channels, kTileDim);
const int dw = divideUP(output_height * output_width, kTileDim);
BatchTranspose2DMUSAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, mask_channels, output_height * output_width, dh, dw,
bottom_data, top_data);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "CARAFELaucherForward", ([&] {
const int num_kernels =
batch_size * output_height * output_width * THREADS_PER_PIXEL;
const scalar_t *bottom_data = rfeatures.data_ptr<scalar_t>();
const scalar_t *bottom_masks = rmasks.data_ptr<scalar_t>();
scalar_t *top_data = routput.data_ptr<scalar_t>();
CARAFEForward<scalar_t><<<divideUP(num_kernels, THREADS_PER_BLOCK),
THREADS_PER_BLOCK, 0, stream>>>(
num_kernels, bottom_data, bottom_masks, kernel_size, group_size,
scale_factor, channels, input_height, input_width, output_height,
output_width, mask_channels, top_data);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "NHWC2NCHW", ([&] {
const scalar_t *bottom_data = routput.data_ptr<scalar_t>();
scalar_t *top_data = output.data_ptr<scalar_t>();
const int dh = divideUP(output_height * output_width, kTileDim);
const int dw = divideUP(channels, kTileDim);
BatchTranspose2DMUSAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, output_height * output_width, channels, dh, dw,
bottom_data, top_data);
}));
AT_MUSA_CHECK(musaGetLastError());
}
void CARAFEBackwardMUSAKernelLauncher(
const Tensor top_grad, const Tensor rfeatures, const Tensor masks,
Tensor rtop_grad, Tensor rbottom_grad_hs, Tensor rbottom_grad,
Tensor rmask_grad, Tensor bottom_grad, Tensor mask_grad,
const int kernel_size, const int group_size, const int scale_factor) {
const int batch_size = top_grad.size(0);
const int channels = top_grad.size(1);
const int output_height = top_grad.size(2);
const int output_width = top_grad.size(3);
const int input_height = bottom_grad.size(2);
const int input_width = bottom_grad.size(3);
const int mask_channels = masks.size(1);
rtop_grad.resize_({batch_size, output_height, output_width, channels});
rbottom_grad.resize_({batch_size, input_height, input_width, channels});
rbottom_grad_hs.resize_({batch_size, output_height, output_width, channels});
rmask_grad.resize_({batch_size, output_height, output_width, mask_channels});
c10::musa::MUSAGuard device_guard(top_grad.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "NCHW2NHWC_Top_Grad", ([&] {
const scalar_t *bottom_data = top_grad.data_ptr<scalar_t>();
scalar_t *top_data = rtop_grad.data_ptr<scalar_t>();
const int dh = divideUP(channels, kTileDim);
const int dw = divideUP(output_height * output_width, kTileDim);
BatchTranspose2DMUSAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, channels, output_height * output_width, dh, dw,
bottom_data, top_data);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "CARAFELaucherBackward_Feature", ([&] {
const int num_kernels =
batch_size * output_height * output_width * THREADS_PER_PIXEL;
const scalar_t *top_diff = rtop_grad.data_ptr<scalar_t>();
const scalar_t *bottom_masks = masks.data_ptr<scalar_t>();
scalar_t *bottom_diff = rbottom_grad_hs.data_ptr<scalar_t>();
CARAFEBackward_Feature<scalar_t>
<<<divideUP(num_kernels, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0,
stream>>>(num_kernels, top_diff, bottom_masks, kernel_size,
group_size, scale_factor, channels, input_height,
input_width, output_height, output_width,
mask_channels, bottom_diff);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "FeatureSum", ([&] {
const int num_kernels =
batch_size * input_height * input_width * THREADS_PER_PIXEL;
const scalar_t *bottom_diff_hs = rbottom_grad_hs.data_ptr<scalar_t>();
scalar_t *bottom_diff = rbottom_grad.data_ptr<scalar_t>();
FeatureSum<scalar_t>
<<<divideUP(num_kernels, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0,
stream>>>(num_kernels, bottom_diff_hs, scale_factor, channels,
input_height, input_width, bottom_diff);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "NHWC2NCHW_Bottom_Grad", ([&] {
const scalar_t *bottom_data = rbottom_grad.data_ptr<scalar_t>();
scalar_t *top_data = bottom_grad.data_ptr<scalar_t>();
const int dh = divideUP(input_height * input_width, kTileDim);
const int dw = divideUP(channels, kTileDim);
BatchTranspose2DMUSAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, input_height * input_width, channels, dh, dw,
bottom_data, top_data);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "CARAFELaucherBackward_Mask", ([&] {
const int num_kernels = batch_size * output_height * output_width *
mask_channels * WARP_SIZE;
const scalar_t *top_diff = rtop_grad.data_ptr<scalar_t>();
const scalar_t *bottom_data = rfeatures.data_ptr<scalar_t>();
scalar_t *mask_diff = rmask_grad.data_ptr<scalar_t>();
CARAFEBackward_Mask<scalar_t>
<<<divideUP(num_kernels, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0,
stream>>>(num_kernels, top_diff, bottom_data, kernel_size,
group_size, scale_factor, channels, input_height,
input_width, output_height, output_width,
mask_channels, mask_diff);
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "NHWC2NCHW_Mask_Grad", ([&] {
const scalar_t *bottom_data = rmask_grad.data_ptr<scalar_t>();
scalar_t *top_data = mask_grad.data_ptr<scalar_t>();
const int dh = divideUP(output_height * output_width, kTileDim);
const int dw = divideUP(mask_channels, kTileDim);
BatchTranspose2DMUSAKernel<scalar_t>
<<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
batch_size, output_height * output_width, mask_channels, dh, dw,
bottom_data, top_data);
}));
AT_MUSA_CHECK(musaGetLastError());
}
#endif //MUSA_ARCH

View File

@ -0,0 +1,52 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "carafe_naive_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"
void CARAFENAIVEForwardMUSAKernelLauncher(const Tensor features,
const Tensor masks, Tensor output,
const int kernel_size,
const int group_size,
const int scale_factor) {
int output_size = output.numel();
int channels = output.size(1);
int height = output.size(2);
int width = output.size(3);
c10::musa::MUSAGuard device_guard(features.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
features.scalar_type(), "CARAFENAIVEForward", ([&] {
carafe_naive_forward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, features.data_ptr<scalar_t>(),
masks.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
kernel_size, group_size, scale_factor, channels, height, width);
}));
AT_MUSA_CHECK(musaGetLastError());
}
void CARAFENAIVEBackwardMUSAKernelLauncher(
const Tensor top_grad, const Tensor features, const Tensor masks,
Tensor bottom_grad, Tensor mask_grad, const int kernel_size,
const int group_size, const int scale_factor) {
int output_size = top_grad.numel();
int channels = top_grad.size(1);
int height = top_grad.size(2);
int width = top_grad.size(3);
c10::musa::MUSAGuard device_guard(top_grad.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
top_grad.scalar_type(), "CARAFENAIVEBackward", ([&] {
carafe_naive_backward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, top_grad.data_ptr<scalar_t>(),
features.data_ptr<scalar_t>(), masks.data_ptr<scalar_t>(),
bottom_grad.data_ptr<scalar_t>(),
mask_grad.data_ptr<scalar_t>(), kernel_size, group_size,
scale_factor, channels, height, width);
}));
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,66 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp
#include "chamfer_distance_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"
#if MUSA_ARCH > 21
void ChamferDistanceForwardMUSAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, const Tensor dist1,
const Tensor dist2, const Tensor idx1, const Tensor idx2) {
int batch_size = xyz1.size(0);
int n = xyz1.size(1);
int m = xyz2.size(1);
c10::musa::MUSAGuard device_guard(xyz1.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
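// The same forward kernel is launched once per direction: xyz1 -> xyz2
// fills dist1 / idx1, then xyz2 -> xyz1 fills dist2 / idx2.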
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
xyz1.scalar_type(), "chamfer_distance_forward_musa_kernel", [&] {
chamfer_distance_forward_musa_kernel<scalar_t>
<<<GET_BLOCKS(batch_size * n), THREADS_PER_BLOCK, 0, stream>>>(
batch_size, n, xyz1.data_ptr<scalar_t>(), m,
xyz2.data_ptr<scalar_t>(), dist1.data_ptr<scalar_t>(),
idx1.data_ptr<int>());
});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
xyz1.scalar_type(), "chamfer_distance_forward_musa_kernel", [&] {
chamfer_distance_forward_musa_kernel<scalar_t>
<<<GET_BLOCKS(batch_size * m), THREADS_PER_BLOCK, 0, stream>>>(
batch_size, m, xyz2.data_ptr<scalar_t>(), n,
xyz1.data_ptr<scalar_t>(), dist2.data_ptr<scalar_t>(),
idx2.data_ptr<int>());
});
AT_MUSA_CHECK(musaGetLastError());
}
void ChamferDistanceBackwardMUSAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2,
Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2) {
int batch_size = xyz1.size(0);
int n = xyz1.size(1);
int m = xyz2.size(1);
c10::musa::MUSAGuard device_guard(xyz1.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
xyz1.scalar_type(), "chamfer_distance_backward_musa_kernel", [&] {
chamfer_distance_backward_musa_kernel<scalar_t>
<<<GET_BLOCKS(batch_size * n), THREADS_PER_BLOCK / 2, 0, stream>>>(
batch_size, m, xyz1.data_ptr<scalar_t>(), n,
xyz2.data_ptr<scalar_t>(), grad_dist1.data_ptr<scalar_t>(),
idx1.data_ptr<int>(), grad_xyz1.data_ptr<scalar_t>(),
grad_xyz2.data_ptr<scalar_t>());
});
AT_DISPATCH_FLOATING_TYPES(
xyz1.scalar_type(), "chamfer_distance_backward_musa_kernel", [&] {
chamfer_distance_backward_musa_kernel<scalar_t>
<<<GET_BLOCKS(batch_size * m), THREADS_PER_BLOCK / 2, 0, stream>>>(
batch_size, n, xyz2.data_ptr<scalar_t>(), m,
xyz1.data_ptr<scalar_t>(), grad_dist2.data_ptr<scalar_t>(),
idx2.data_ptr<int>(), grad_xyz2.data_ptr<scalar_t>(),
grad_xyz1.data_ptr<scalar_t>());
});
AT_MUSA_CHECK(musaGetLastError());
}
#else
#warning "chamfer_distance is only supported when MUSA_ARCH > 21"
#endif //MUSA_ARCH

View File

@ -0,0 +1,41 @@
// Copyright (c) OpenMMLab. All rights reserved
// modified from
// https://github.com/SDL-GuoZonghao/BeyondBoundingBox/blob/main/mmdet/ops/iou/src/convex_iou_kernel.cu
#include "convex_iou_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"
void ConvexIoUMUSAKernelLauncher(const Tensor pointsets, const Tensor polygons,
Tensor ious) {
int output_size = ious.numel();
int num_pointsets = pointsets.size(0);
int num_polygons = polygons.size(0);
c10::musa::MUSAGuard device_guard(pointsets.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
pointsets.scalar_type(), "convex_iou_musa_kernel", ([&] {
convex_iou_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK / 2, 0, stream>>>(
num_pointsets, num_polygons, pointsets.data_ptr<scalar_t>(),
polygons.data_ptr<scalar_t>(), ious.data_ptr<scalar_t>());
}));
AT_MUSA_CHECK(musaGetLastError());
}
void ConvexGIoUMUSAKernelLauncher(const Tensor pointsets, const Tensor polygons,
Tensor output) {
int output_size = output.numel();
int num_pointsets = pointsets.size(0);
int num_polygons = polygons.size(0);
c10::musa::MUSAGuard device_guard(pointsets.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
AT_DISPATCH_FLOATING_TYPES(
pointsets.scalar_type(), "convex_giou_musa_kernel", ([&] {
convex_giou_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK / 2, 0, stream>>>(
num_pointsets, num_polygons, pointsets.data_ptr<scalar_t>(),
polygons.data_ptr<scalar_t>(), output.data_ptr<scalar_t>());
}));
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,94 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_cuda_kernel.cu
// Original licence: Under MIT License
#include "correlation_musa.muh"
#include "pytorch_musa_helper.hpp"
void CorrelationForwardMUSAKernelLauncher(Tensor input1, Tensor input2,
Tensor output, int kH, int kW,
int patchH, int patchW, int padH,
int padW, int dilationH,
int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW) {
const int batch_size = input1.size(0);
const int iH = input1.size(2);
const int iW = input1.size(3);
const int dilatedKH = (kH - 1) * dilationH + 1;
const int dilatedKW = (kW - 1) * dilationW + 1;
const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;
auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();
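// Inputs are converted to contiguous channel-last (NHWC) layout before the
// launch so that per-position channel reads stay contiguous in the kernel.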
const dim3 threads(WARP_SIZE, 4, 4);
const dim3 blocks(batch_size, (oH + 3) >> 2, (oW + 3) >> 2);
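// Launch geometry: one block per (sample, 4x4 tile of output positions),
// with WARP_SIZE x 4 x 4 threads per block.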
c10::musa::MUSAGuard device_guard(input1.device());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input1.scalar_type(), "correlation_forward_musa", ([&] {
TensorAcc4R trInput1_acc =
trInput1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R trInput2_acc =
trInput2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc5R output_acc =
output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();
correlation_forward_musa_kernel<scalar_t>
<<<blocks, threads, 0, c10::musa::getCurrentMUSAStream()>>>(
trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW,
padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW, oH, oW);
}));
}
void CorrelationBackwardMUSAKernelLauncher(
Tensor grad_output, Tensor input1, Tensor input2, Tensor grad_input1,
Tensor grad_input2, int kH, int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW) {
const int batch_size = input1.size(0);
const int iH = input1.size(2);
const int iW = input1.size(3);
const int C = input1.size(1);
auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();
const dim3 blocks(batch_size, iH, iW);
const dim3 threads(THREADS_PER_BLOCK);
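// Backward pass: one block per input pixel of every sample; dynamic shared
// memory caches a patchH x patchW slice of the incoming gradient.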
c10::musa::MUSAGuard device_guard(input1.device());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input1.scalar_type(), "correlation_backward_musa", ([&] {
const int grad_cache_size = patchH * patchW * sizeof(scalar_t);
TensorAcc4R input1_acc =
trInput1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R input2_acc =
trInput2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R grad_input1_acc =
grad_input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R grad_input2_acc =
grad_input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc5R grad_output_acc =
grad_output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();
correlation_backward_musa_kernel_input1<scalar_t>
<<<blocks, threads, grad_cache_size,
c10::musa::getCurrentMUSAStream()>>>(
grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH,
patchW, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW);
correlation_backward_musa_kernel_input2<scalar_t>
<<<blocks, threads, grad_cache_size,
c10::musa::getCurrentMUSAStream()>>>(
grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH,
patchW, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW);
}));
}

View File

@ -0,0 +1,105 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "deform_conv_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"
void deformable_im2col_musa(Tensor data_im, Tensor data_offset,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
Tensor data_col) {
// num_axes should be smaller than block size
// todo: check parallel_imgs is correctly passed in
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
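// One thread per column element: channels x output positions x images
// processed in this call.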
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels),
THREADS_PER_BLOCK, 0,
c10::musa::getCurrentMUSAStream()>>>(
num_kernels, data_im_, data_offset_, height, width, ksize_h,
ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels,
deformable_group, height_col, width_col, data_col_);
}));
AT_MUSA_CHECK(musaGetLastError());
}
void deformable_col2im_musa(Tensor data_col, Tensor data_offset,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
Tensor grad_im) {
// todo: make sure parallel_imgs is passed in correctly
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels =
channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
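// One thread per column entry (channels x kernel taps x output positions x
// images); each thread scatters its value back into grad_im.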
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels),
THREADS_PER_BLOCK, 0,
c10::musa::getCurrentMUSAStream()>>>(
num_kernels, data_col_, data_offset_, channels, height, width,
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
dilation_w, channel_per_deformable_group, parallel_imgs,
deformable_group, height_col, width_col, grad_im_);
}));
AT_MUSA_CHECK(musaGetLastError());
}
void deformable_col2im_coord_musa(
Tensor data_col, Tensor data_im, Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h, const int ksize_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, Tensor grad_offset) {
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w *
deformable_group * parallel_imgs;
int channel_per_deformable_group =
channels * ksize_h * ksize_w / deformable_group;
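// One thread per offset gradient value: 2 coordinates per kernel tap per
// deformable group per output position.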
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
deformable_col2im_coord_gpu_kernel<<<
GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0,
c10::musa::getCurrentMUSAStream()>>>(
num_kernels, data_col_, data_im_, data_offset_, channels, height,
width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs,
2 * ksize_h * ksize_w * deformable_group, deformable_group,
height_col, width_col, grad_offset_);
}));
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,55 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "deform_roi_pool_musa_kernel.muh"
#include "pytorch_musa_helper.hpp"
void DeformRoIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois,
Tensor offset, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale,
int sampling_ratio, float gamma) {
int output_size = output.numel();
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
c10::musa::MUSAGuard device_guard(input.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
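// One thread per element of the pooled output (every bin of every RoI).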
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "deform_roi_pool_forward_musa_kernel", [&] {
deform_roi_pool_forward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(),
rois.data_ptr<scalar_t>(), offset.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(), pooled_height, pooled_width,
static_cast<scalar_t>(spatial_scale), sampling_ratio,
static_cast<scalar_t>(gamma), channels, height, width);
});
AT_MUSA_CHECK(musaGetLastError());
}
void DeformRoIPoolBackwardMUSAKernelLauncher(
Tensor grad_output, Tensor input, Tensor rois, Tensor offset,
Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width,
float spatial_scale, int sampling_ratio, float gamma) {
int output_size = grad_output.numel();
int channels = grad_input.size(1);
int height = grad_input.size(2);
int width = grad_input.size(3);
c10::musa::MUSAGuard device_guard(grad_output.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
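// One thread per grad_output element; gradients are accumulated into
// grad_input and grad_offset.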
AT_DISPATCH_FLOATING_TYPES(
grad_output.scalar_type(), "deform_roi_pool_backward_musa_kernel", [&] {
deform_roi_pool_backward_musa_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(), rois.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
grad_offset.data_ptr<scalar_t>(), pooled_height, pooled_width,
static_cast<scalar_t>(spatial_scale), sampling_ratio,
static_cast<scalar_t>(gamma), channels, height, width);
});
AT_MUSA_CHECK(musaGetLastError());
}

View File

@ -0,0 +1,35 @@
// Copyright (c) OpenMMLab. All rights reserved
// Adapted from
// https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa
#include "diff_iou_rotated_musa_kernel.muh"
#include "pytorch_cpp_helper.hpp"
#include "pytorch_musa_helper.hpp"
at::Tensor DiffIoURotatedSortVerticesMUSAKernelLauncher(at::Tensor vertices,
at::Tensor mask,
at::Tensor num_valid) {
c10::musa::MUSAGuard device_guard(vertices.device());
musaStream_t stream = c10::musa::getCurrentMUSAStream();
CHECK_CONTIGUOUS(vertices);
CHECK_CONTIGUOUS(mask);
CHECK_CONTIGUOUS(num_valid);
CHECK_MUSA(vertices);
CHECK_MUSA(mask);
CHECK_MUSA(num_valid);
int b = vertices.size(0);
int n = vertices.size(1);
int m = vertices.size(2);
at::Tensor idx =
torch::zeros({b, n, MAX_NUM_VERT_IDX},
at::device(vertices.device()).dtype(at::ScalarType::Int));
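// `idx` receives the sorted vertex order of each intersection polygon
// (up to MAX_NUM_VERT_IDX entries per polygon).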
diff_iou_rotated_sort_vertices_forward_musa_kernel<<<b, opt_n_thread(n), 0,
stream>>>(
b, n, m, vertices.data_ptr<float>(), mask.data_ptr<bool>(),
num_valid.data_ptr<int>(), idx.data_ptr<int>());
AT_MUSA_CHECK(musaGetLastError());
return idx;
}

View File

@ -103,6 +103,303 @@ void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset);
REGISTER_DEVICE_IMPL(bbox_overlaps_impl, MUSA, bbox_overlaps_musa);
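// Pattern used throughout this file: declare the MUSA kernel launcher, wrap
// it in a thin *_musa function, and register that wrapper as the MUSA backend
// of the corresponding *_impl dispatcher.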
void BorderAlignForwardMUSAKernelLauncher(const Tensor &input,
const Tensor &boxes, Tensor output,
Tensor argmax_idx,
const int pool_size);
void BorderAlignBackwardMUSAKernelLauncher(const Tensor &grad_output,
const Tensor &boxes,
const Tensor &argmax_idx,
Tensor grad_input,
const int pool_size);
void border_align_forward_musa(const Tensor &input, const Tensor &boxes,
Tensor output, Tensor argmax_idx,
const int pool_size) {
BorderAlignForwardMUSAKernelLauncher(input, boxes, output, argmax_idx,
pool_size);
}
void border_align_backward_musa(const Tensor &grad_output, const Tensor &boxes,
const Tensor &argmax_idx, Tensor grad_input,
const int pool_size) {
BorderAlignBackwardMUSAKernelLauncher(grad_output, boxes, argmax_idx,
grad_input, pool_size);
}
void border_align_forward_impl(const Tensor &input, const Tensor &boxes,
Tensor output, Tensor argmax_idx,
const int pool_size);
void border_align_backward_impl(const Tensor &grad_output, const Tensor &boxes,
const Tensor &argmax_idx, Tensor grad_input,
const int pool_size);
REGISTER_DEVICE_IMPL(border_align_forward_impl, MUSA,
border_align_forward_musa);
REGISTER_DEVICE_IMPL(border_align_backward_impl, MUSA,
border_align_backward_musa);
void box_iou_rotated_musa(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
REGISTER_DEVICE_IMPL(box_iou_rotated_impl, MUSA, box_iou_rotated_musa);
void box_iou_quadri_musa(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
void box_iou_quadri_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
REGISTER_DEVICE_IMPL(box_iou_quadri_impl, MUSA, box_iou_quadri_musa);
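// CARAFE is only compiled when MUSA_ARCH is undefined or greater than 21.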
#if ((!defined(MUSA_ARCH)) || (defined(MUSA_ARCH)) && (MUSA_ARCH > 21))
void CARAFEForwardMUSAKernelLauncher(const Tensor features, const Tensor masks,
Tensor rfeatures, Tensor routput,
Tensor rmasks, Tensor output,
const int kernel_size,
const int group_size,
const int scale_factor);
void CARAFEBackwardMUSAKernelLauncher(
const Tensor top_grad, const Tensor rfeatures, const Tensor masks,
Tensor rtop_grad, Tensor rbottom_grad_hs, Tensor rbottom_grad,
Tensor rmask_grad, Tensor bottom_grad, Tensor mask_grad,
const int kernel_size, const int group_size, const int scale_factor);
void carafe_forward_musa(Tensor features, Tensor masks, Tensor rfeatures,
Tensor routput, Tensor rmasks, Tensor output,
int kernel_size, int group_size, int scale_factor) {
CARAFEForwardMUSAKernelLauncher(features, masks, rfeatures, routput, rmasks,
output, kernel_size, group_size,
scale_factor);
}
void carafe_backward_musa(Tensor top_grad, Tensor rfeatures, Tensor masks,
Tensor rtop_grad, Tensor rbottom_grad_hs,
Tensor rbottom_grad, Tensor rmask_grad,
Tensor bottom_grad, Tensor mask_grad, int kernel_size,
int group_size, int scale_factor) {
CARAFEBackwardMUSAKernelLauncher(top_grad, rfeatures, masks, rtop_grad,
rbottom_grad_hs, rbottom_grad, rmask_grad,
bottom_grad, mask_grad, kernel_size,
group_size, scale_factor);
}
void carafe_forward_impl(Tensor features, Tensor masks, Tensor rfeatures,
Tensor routput, Tensor rmasks, Tensor output,
int kernel_size, int group_size, int scale_factor);
void carafe_backward_impl(Tensor top_grad, Tensor rfeatures, Tensor masks,
Tensor rtop_grad, Tensor rbottom_grad_hs,
Tensor rbottom_grad, Tensor rmask_grad,
Tensor bottom_grad, Tensor mask_grad, int kernel_size,
int group_size, int scale_factor);
REGISTER_DEVICE_IMPL(carafe_forward_impl, MUSA, carafe_forward_musa);
REGISTER_DEVICE_IMPL(carafe_backward_impl, MUSA, carafe_backward_musa);
#endif
void CARAFENAIVEForwardMUSAKernelLauncher(const Tensor features,
const Tensor masks, Tensor output,
const int kernel_size,
const int group_size,
const int scale_factor);
void CARAFENAIVEBackwardMUSAKernelLauncher(
const Tensor top_grad, const Tensor features, const Tensor masks,
Tensor bottom_grad, Tensor mask_grad, const int kernel_size,
const int group_size, const int scale_factor);
void carafe_naive_forward_musa(Tensor features, Tensor masks, Tensor output,
int kernel_size, int group_size,
int scale_factor) {
CARAFENAIVEForwardMUSAKernelLauncher(features, masks, output, kernel_size,
group_size, scale_factor);
}
void carafe_naive_backward_musa(Tensor top_grad, Tensor features, Tensor masks,
Tensor bottom_grad, Tensor mask_grad,
int kernel_size, int group_size,
int scale_factor) {
CARAFENAIVEBackwardMUSAKernelLauncher(top_grad, features, masks, bottom_grad,
mask_grad, kernel_size, group_size,
scale_factor);
}
void carafe_naive_forward_impl(Tensor features, Tensor masks, Tensor output,
int kernel_size, int group_size,
int scale_factor);
void carafe_naive_backward_impl(Tensor top_grad, Tensor features, Tensor masks,
Tensor bottom_grad, Tensor mask_grad,
int kernel_size, int group_size,
int scale_factor);
REGISTER_DEVICE_IMPL(carafe_naive_forward_impl, MUSA,
carafe_naive_forward_musa);
REGISTER_DEVICE_IMPL(carafe_naive_backward_impl, MUSA,
carafe_naive_backward_musa);
void CorrelationForwardMUSAKernelLauncher(Tensor input1, Tensor input2,
Tensor output, int kH, int kW,
int patchH, int patchW, int padH,
int padW, int dilationH,
int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW);
void CorrelationBackwardMUSAKernelLauncher(Tensor grad_output, Tensor input1,
Tensor input2, Tensor grad_input1,
Tensor grad_input2, int kH, int kW,
int patchH, int patchW, int padH,
int padW, int dilationH,
int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW);
void correlation_forward_musa(Tensor input1, Tensor input2, Tensor output,
int kH, int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW, int dH,
int dW) {
CorrelationForwardMUSAKernelLauncher(
input1, input2, output, kH, kW, patchH, patchW, padH, padW, dilationH,
dilationW, dilation_patchH, dilation_patchW, dH, dW);
}
void correlation_backward_musa(Tensor grad_output, Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2, int kH,
int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW, int dH,
int dW) {
CorrelationBackwardMUSAKernelLauncher(
grad_output, input1, input2, grad_input1, grad_input2, kH, kW, patchH,
patchW, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW);
}
void correlation_forward_impl(Tensor input1, Tensor input2, Tensor output,
int kH, int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW, int dH,
int dW);
void correlation_backward_impl(Tensor grad_output, Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2, int kH,
int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW, int dH,
int dW);
REGISTER_DEVICE_IMPL(correlation_forward_impl, MUSA, correlation_forward_musa);
REGISTER_DEVICE_IMPL(correlation_backward_impl, MUSA,
correlation_backward_musa);
void deformable_im2col_musa(Tensor data_im, Tensor data_offset,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
Tensor data_col);
void deformable_col2im_musa(Tensor data_col, Tensor data_offset,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
Tensor grad_im);
void deformable_col2im_coord_musa(
Tensor data_col, Tensor data_im, Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h, const int ksize_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, Tensor grad_offset);
void deformable_im2col_impl(Tensor data_im, Tensor data_offset,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
Tensor data_col);
void deformable_col2im_impl(Tensor data_col, Tensor data_offset,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
Tensor grad_im);
void deformable_col2im_coord_impl(
Tensor data_col, Tensor data_im, Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h, const int ksize_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, Tensor grad_offset);
REGISTER_DEVICE_IMPL(deformable_im2col_impl, MUSA, deformable_im2col_musa);
REGISTER_DEVICE_IMPL(deformable_col2im_impl, MUSA, deformable_col2im_musa);
REGISTER_DEVICE_IMPL(deformable_col2im_coord_impl, MUSA,
deformable_col2im_coord_musa);
void DeformRoIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois,
Tensor offset, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale,
int sampling_ratio, float gamma);
void DeformRoIPoolBackwardMUSAKernelLauncher(
Tensor grad_output, Tensor input, Tensor rois, Tensor offset,
Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width,
float spatial_scale, int sampling_ratio, float gamma);
void deform_roi_pool_forward_musa(Tensor input, Tensor rois, Tensor offset,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale,
int sampling_ratio, float gamma) {
DeformRoIPoolForwardMUSAKernelLauncher(input, rois, offset, output,
pooled_height, pooled_width,
spatial_scale, sampling_ratio, gamma);
}
void deform_roi_pool_backward_musa(Tensor grad_output, Tensor input,
Tensor rois, Tensor offset,
Tensor grad_input, Tensor grad_offset,
int pooled_height, int pooled_width,
float spatial_scale, int sampling_ratio,
float gamma) {
DeformRoIPoolBackwardMUSAKernelLauncher(
grad_output, input, rois, offset, grad_input, grad_offset, pooled_height,
pooled_width, spatial_scale, sampling_ratio, gamma);
}
void deform_roi_pool_forward_impl(Tensor input, Tensor rois, Tensor offset,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale,
int sampling_ratio, float gamma);
void deform_roi_pool_backward_impl(Tensor grad_output, Tensor input,
Tensor rois, Tensor offset,
Tensor grad_input, Tensor grad_offset,
int pooled_height, int pooled_width,
float spatial_scale, int sampling_ratio,
float gamma);
REGISTER_DEVICE_IMPL(deform_roi_pool_forward_impl, MUSA,
deform_roi_pool_forward_musa);
REGISTER_DEVICE_IMPL(deform_roi_pool_backward_impl, MUSA,
deform_roi_pool_backward_musa);
void ActiveRotatedFilterForwardMUSAKernelLauncher(const Tensor input,
const Tensor indices,
Tensor output);
@ -132,6 +429,88 @@ REGISTER_DEVICE_IMPL(active_rotated_filter_forward_impl, MUSA,
REGISTER_DEVICE_IMPL(active_rotated_filter_backward_impl, MUSA,
active_rotated_filter_backward_musa);
void ConvexIoUMUSAKernelLauncher(const Tensor pointsets, const Tensor polygons,
Tensor ious);
void ConvexGIoUMUSAKernelLauncher(const Tensor pointsets, const Tensor polygons,
Tensor output);
void convex_iou_musa(const Tensor pointsets, const Tensor polygons,
Tensor ious) {
ConvexIoUMUSAKernelLauncher(pointsets, polygons, ious);
}
void convex_giou_musa(const Tensor pointsets, const Tensor polygons,
Tensor output) {
ConvexGIoUMUSAKernelLauncher(pointsets, polygons, output);
}
void convex_iou_impl(const Tensor pointsets, const Tensor polygons,
Tensor ious);
void convex_giou_impl(const Tensor pointsets, const Tensor polygons,
Tensor output);
REGISTER_DEVICE_IMPL(convex_iou_impl, MUSA, convex_iou_musa);
REGISTER_DEVICE_IMPL(convex_giou_impl, MUSA, convex_giou_musa);
Tensor DiffIoURotatedSortVerticesMUSAKernelLauncher(Tensor vertices,
Tensor mask,
Tensor num_valid);
Tensor diff_iou_rotated_sort_vertices_forward_musa(Tensor vertices, Tensor mask,
Tensor num_valid) {
return DiffIoURotatedSortVerticesMUSAKernelLauncher(vertices, mask,
num_valid);
}
Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid);
REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, MUSA,
diff_iou_rotated_sort_vertices_forward_musa);
#if ((!defined(MUSA_ARCH)) || (defined(MUSA_ARCH)) && (MUSA_ARCH > 21))
void ChamferDistanceForwardMUSAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, const Tensor dist1,
const Tensor dist2, const Tensor idx1, const Tensor idx2);
#endif
void ChamferDistanceBackwardMUSAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2,
Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2);
#if ((!defined(MUSA_ARCH)) || (defined(MUSA_ARCH)) && (MUSA_ARCH > 21))
void chamfer_distance_forward_musa(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2) {
ChamferDistanceForwardMUSAKernelLauncher(xyz1, xyz2, dist1, dist2, idx1,
idx2);
}
void chamfer_distance_backward_musa(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
ChamferDistanceBackwardMUSAKernelLauncher(xyz1, xyz2, idx1, idx2, graddist1,
graddist2, gradxyz1, gradxyz2);
}
void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2);
void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);
REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, MUSA,
chamfer_distance_forward_musa);
REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, MUSA,
chamfer_distance_backward_musa);
#endif
void BezierAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois,
Tensor output, int aligned_height,
int aligned_width,

View File

@ -4,6 +4,7 @@ import torch
from mmcv.ops import bias_act
from mmcv.ops.bias_act import EasyDict
from mmcv.utils import IS_MUSA_AVAILABLE
_USING_PARROTS = True
try:
@ -131,6 +132,74 @@ class TestBiasAct:
assert out1.max() <= 0.5
assert out2.max() <= 0.5
@pytest.mark.skipif(not IS_MUSA_AVAILABLE, reason='requires musa')
def test_bias_act_musa(self):
if _USING_PARROTS:
gradcheck(
bias_act, (self.input_tensor.musa(), self.bias.musa()),
delta=1e-4,
pt_atol=1e-3)
else:
gradcheck(
bias_act, (self.input_tensor.musa(), self.bias.musa()),
eps=1e-4,
atol=1e-3)
gradgradcheck(
bias_act, (self.input_tensor.musa(), self.bias.musa()),
eps=1e-4,
atol=1e-3)
out = bias_act(self.input_tensor.musa(), self.bias.musa())
assert out.shape == (1, 3)
# test with different dim
input_tensor = torch.randn((1, 1, 3), requires_grad=True).musa()
bias = torch.randn(3, requires_grad=True).musa()
out = bias_act(input_tensor, bias, dim=2)
assert out.shape == (1, 1, 3)
# test with different act
out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='relu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='lrelu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='tanh')
assert out.shape == (1, 3)
out = bias_act(
self.input_tensor.musa(), self.bias.musa(), act='sigmoid')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='elu')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='selu')
assert out.shape == (1, 3)
out = bias_act(
self.input_tensor.musa(), self.bias.musa(), act='softplus')
assert out.shape == (1, 3)
out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='swish')
assert out.shape == (1, 3)
# test with different alpha
out = bias_act(
self.input_tensor.musa(), self.bias.musa(), act='lrelu', alpha=0.1)
assert out.shape == (1, 3)
# test with different gain
out1 = bias_act(
self.input_tensor.musa(), self.bias.musa(), act='lrelu', gain=0.2)
out2 = bias_act(
self.input_tensor.musa(), self.bias.musa(), act='lrelu', gain=0.1)
assert torch.allclose(out1, out2 * 2)
# test with different clamp
out1 = bias_act(
self.input_tensor.musa(), self.bias.musa(), act='lrelu', clamp=0.5)
out2 = bias_act(
self.input_tensor.musa(), self.bias.musa(), act='lrelu', clamp=0.2)
assert out1.max() <= 0.5
assert out2.max() <= 0.5
def test_easy_dict(self):
easy_dict = EasyDict(
func=lambda x, **_: x,
@ -142,3 +211,15 @@ class TestBiasAct:
_ = easy_dict.def_alpha
easy_dict.def_alpha = 1
del easy_dict.def_alpha
def test_easy_dict_musa(self):
easy_dict = EasyDict(
func=lambda x, **_: x,
def_alpha=0,
def_gain=1,
musa_idx=1,
ref='',
has_2nd_grad=False)
_ = easy_dict.def_alpha
easy_dict.def_alpha = 1
del easy_dict.def_alpha

View File

@ -5,6 +5,8 @@ import numpy as np
import pytest
import torch
from mmcv.utils import IS_MUSA_AVAILABLE
# [1,4c,h,w]
input_arr = [[[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]],
[[6, 7, 5, 8], [2, 1, 3, 4], [12, 9, 11, 10]],
@ -51,6 +53,8 @@ input_grad_dict = {
def _test_border_align_allclose(device, dtype, pool_size):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip('test requires GPU')
elif not IS_MUSA_AVAILABLE and device == 'musa':
pytest.skip('test requires MUSA')
try:
from mmcv.ops import BorderAlign, border_align
except ModuleNotFoundError:
@ -84,8 +88,16 @@ def _test_border_align_allclose(device, dtype, pool_size):
input.grad.data.type(dtype).cpu().numpy(), np_grad, atol=1e-5)
@pytest.mark.parametrize('device', ['cuda'])
@pytest.mark.parametrize('dtype', [torch.float, torch.half, torch.double])
@pytest.mark.parametrize('device', ['cuda', 'musa'])
@pytest.mark.parametrize('dtype', [
torch.float,
torch.half,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MUSA_AVAILABLE,
reason='MUSA does not support 64-bit floating point')),
])
@pytest.mark.parametrize('pool_size', [1, 2])
def test_border_align(device, dtype, pool_size):
_test_border_align_allclose(device, dtype, pool_size)

View File

@ -3,7 +3,7 @@ import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE
class TestBoxIoUQuadri:
@ -17,7 +17,11 @@ class TestBoxIoUQuadri:
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_box_iou_quadri_cuda(self, device):
from mmcv.ops import box_iou_quadri
@ -55,7 +59,11 @@ class TestBoxIoUQuadri:
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_box_iou_quadri_iof_cuda(self, device):
from mmcv.ops import box_iou_quadri

View File

@ -4,7 +4,8 @@ import pytest
import torch
from mmcv.ops import box_iou_rotated
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
IS_NPU_AVAILABLE)
class TestBoxIoURotated:
@ -58,7 +59,11 @@ class TestBoxIoURotated:
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_box_iou_rotated(self, device):
np_boxes1 = np.asarray(
@ -145,7 +150,11 @@ class TestBoxIoURotated:
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
])
def test_box_iou_rotated_iof(self, device):
np_boxes1 = np.asarray(

View File

@ -4,32 +4,34 @@ import pytest
import torch
from torch.autograd import gradcheck
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
class TestCarafe:
def test_carafe_naive_gradcheck(self):
if not torch.cuda.is_available():
if (not torch.cuda.is_available()) and (not IS_MUSA_AVAILABLE):
return
from mmcv.ops import CARAFENaive
feat = torch.randn(
2, 64, 3, 3, requires_grad=True, device='cuda').double()
mask = torch.randn(
2, 100, 6, 6, requires_grad=True,
device='cuda').sigmoid().double()
gradcheck(CARAFENaive(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
if IS_CUDA_AVAILABLE:
feat = torch.randn(
2, 64, 3, 3, requires_grad=True, device='cuda').double()
mask = torch.randn(
2, 100, 6, 6, requires_grad=True,
device='cuda').sigmoid().double()
gradcheck(CARAFENaive(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
def test_carafe_gradcheck(self):
if not torch.cuda.is_available():
if (not torch.cuda.is_available()) and (not IS_MUSA_AVAILABLE):
return
from mmcv.ops import CARAFE
feat = torch.randn(
2, 64, 3, 3, requires_grad=True, device='cuda').double()
mask = torch.randn(
2, 100, 6, 6, requires_grad=True,
device='cuda').sigmoid().double()
gradcheck(CARAFE(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
if IS_CUDA_AVAILABLE:
feat = torch.randn(
2, 64, 3, 3, requires_grad=True, device='cuda').double()
mask = torch.randn(
2, 100, 6, 6, requires_grad=True,
device='cuda').sigmoid().double()
gradcheck(CARAFE(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
@pytest.mark.parametrize('device', [
pytest.param(
@ -39,7 +41,11 @@ class TestCarafe:
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_carafe_allclose(self, device):
try:
@ -64,14 +70,43 @@ class TestCarafe:
np_feat_grad = np_feat_grad.reshape((2, 64, 3, 3))
np_mask_grad = np_mask_grad.reshape((2, 100, 6, 6))
feat = torch.tensor(
np_feat, dtype=torch.float, device=device, requires_grad=True)
mask = torch.tensor(
np_mask, dtype=torch.float, device=device, requires_grad=True)
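# build the inputs on CPU, move them to the target device, then enable
# gradients on the device copies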
feat_cpu = torch.FloatTensor(np_feat)
mask_cpu = torch.FloatTensor(np_mask)
feat = feat_cpu.to(device)
mask = mask_cpu.to(device)
feat.requires_grad = True
mask.requires_grad = True
carafe = CARAFE(5, 4, 2)
carafe.to(device)
carafe.train()
output = carafe(feat, mask)
output.backward(torch.ones_like(output))
assert np.allclose(
output.data.type(torch.float).cpu().numpy(), np_output, atol=1e-3)

View File

@ -3,6 +3,8 @@ import numpy as np
import torch
import torch.nn as nn
from mmcv.utils import IS_MUSA_AVAILABLE
class Loss(nn.Module):
@ -20,6 +22,9 @@ class TestCrissCrossAttention:
def test_cc_attention(self):
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if IS_MUSA_AVAILABLE:
device = torch.device('musa:0')
from mmcv.ops import CrissCrossAttention
loss_func = Loss()

View File

@ -4,7 +4,7 @@ import pytest
import torch
from mmcv.ops import chamfer_distance
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE
def chamfer_distance_forward_groundtruth(xyz1, xyz2, dtype):
@ -51,7 +51,11 @@ def torch_to_np_type(dtype):
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_NPU_AVAILABLE, reason='requires NPU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
@pytest.mark.parametrize('dtype', [torch.half, torch.float32])
@pytest.mark.parametrize('shape', [(2, 600, 2), (1, 1, 2), (7, 7, 2)])

View File

@ -5,6 +5,7 @@ import torch.nn as nn
from torch.autograd import gradcheck, gradgradcheck
from mmcv.ops import conv2d, conv_transpose2d
from mmcv.utils import IS_MUSA_AVAILABLE
class TestCond2d:
@ -23,6 +24,15 @@ class TestCond2d:
gradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=0.1)
gradgradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=0.1)
@pytest.mark.skipif(not IS_MUSA_AVAILABLE, reason='requires musa')
def test_conv2d_musa(self):
x = self.input.musa()
weight = self.weight.musa()
res = conv2d(x, weight, None, 1, 1)
assert res.shape == (1, 1, 32, 32)
gradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=0.1)
gradgradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=0.1)
class TestCond2dTansposed:
@ -41,3 +51,14 @@ class TestCond2dTansposed:
conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)
gradgradcheck(
conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)
@pytest.mark.skipif(not IS_MUSA_AVAILABLE, reason='requires musa')
def test_conv2d_transposed_musa(self):
x = self.input.musa()
weight = self.weight.musa()
res = conv_transpose2d(x, weight, None, 1, 1)
assert res.shape == (1, 1, 32, 32)
gradcheck(
conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)
gradgradcheck(
conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2)

View File

@ -4,6 +4,7 @@ import pytest
import torch
from mmcv.ops import convex_giou, convex_iou
from mmcv.utils import IS_MUSA_AVAILABLE
np_pointsets = np.asarray([[
1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 1.0, 1.0, 3.0, 3.0, 1.0, 2.0, 3.0, 3.0,
@ -54,3 +55,12 @@ def test_convex_giou():
giou, grad = convex_giou(pointsets, polygons)
assert torch.allclose(giou, expected_giou, atol=1e-3)
assert torch.allclose(grad, expected_grad, atol=1e-3)
@pytest.mark.skipif(not IS_MUSA_AVAILABLE, reason='requires musa')
def test_convex_iou_musa():
pointsets = torch.from_numpy(np_pointsets).musa().float()
polygons = torch.from_numpy(np_polygons).musa().float()
expected_iou = torch.from_numpy(np_expected_iou).musa().float()
assert torch.allclose(
convex_iou(pointsets, polygons), expected_iou, atol=1e-3)

View File

@ -3,6 +3,7 @@ import pytest
import torch
from mmcv.ops import Correlation
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE
_input1 = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
_input2 = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]]
@ -23,24 +24,33 @@ class TestCorrelation:
layer = Correlation(max_displacement=0)
input1 = torch.tensor(_input1, dtype=dtype).cuda()
input2 = torch.tensor(_input2, dtype=dtype).cuda()
if IS_CUDA_AVAILABLE:
input1 = torch.tensor(_input1, dtype=dtype).cuda()
input2 = torch.tensor(_input2, dtype=dtype).cuda()
elif IS_MUSA_AVAILABLE:
input1 = torch.tensor(_input1, dtype=dtype).musa()
input2 = torch.tensor(_input2, dtype=dtype).musa()
input1.requires_grad = True
input2.requires_grad = True
out = layer(input1, input2)
out.backward(torch.ones_like(out))
# `eq_cpu` is not implemented for 'Half' in torch1.5.0,
# so we need to make a comparison for cuda tensor
# so we need to make a comparison for cuda/musa tensor
# rather than cpu tensor
gt_out = torch.tensor(_gt_out, dtype=dtype).cuda()
if IS_CUDA_AVAILABLE:
gt_out = torch.tensor(_gt_out, dtype=dtype).cuda()
elif IS_MUSA_AVAILABLE:
gt_out = torch.tensor(_gt_out, dtype=dtype).musa()
assert_equal_tensor(out, gt_out)
assert_equal_tensor(input1.grad.detach(), input2)
assert_equal_tensor(input2.grad.detach(), input1)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
(not torch.cuda.is_available()) and (not IS_MUSA_AVAILABLE),
reason='requires CUDA/MUSA support')
def test_correlation(self):
self._test_correlation(torch.float)
self._test_correlation(torch.double)
if IS_CUDA_AVAILABLE:
self._test_correlation(torch.double)
self._test_correlation(torch.half)

View File

@ -5,17 +5,24 @@ import torch
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
if IS_MLU_AVAILABLE:
torch.backends.cnnl.allow_tf32 = False
try:
# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
# would be imported and used; we should test if our modules support it.
from torch.cuda.amp import autocast
except ImportError:
pass
if IS_MUSA_AVAILABLE:
try:
from torch_musa.core.amp import autocast
except ImportError:
pass
else:
try:
# If PyTorch version >= 1.6.0 and fp16 is enabled
# torch.cuda.amp.autocast would be imported and used
# we should test if our modules support it.
from torch.cuda.amp import autocast
except ImportError:
pass
input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]],
@ -79,6 +86,8 @@ class TestDeformconv:
model.cuda()
elif device == 'mlu':
model.mlu()
elif device == 'musa':
model.musa()
model.type(dtype)
out = model(x)
@ -161,6 +170,8 @@ class TestDeformconv:
model.cuda()
elif device == 'mlu':
model.mlu()
elif device == 'musa':
model.musa()
out = model(x)
out.backward(torch.ones_like(out))
@ -194,19 +205,24 @@ class TestDeformconv:
with pytest.raises(AssertionError):
model = DeformConv2d(3, 4, 3, groups=3)
@pytest.mark.parametrize('device, threshold', [
('cpu', 1e-1),
pytest.param(
'cuda',
1e-3,
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
1e-3,
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')),
])
@pytest.mark.parametrize(
'device, threshold',
[('cpu', 1e-1),
pytest.param(
'cuda',
1e-3,
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
1e-3,
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
1e-3,
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))])
def test_deformconv_float(self, device, threshold):
self._test_deformconv(torch.float, device=device, threshold=threshold)
# test batch_size < im2col_step
@ -244,6 +260,11 @@ class TestDeformconv:
1e-1,
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
1e-1,
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_deformconv_half(self, device, threshold):
self._test_deformconv(torch.half, device=device, threshold=threshold)

View File

@ -5,7 +5,8 @@ import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE,
IS_NPU_AVAILABLE)
_USING_PARROTS = True
try:
@ -137,16 +138,20 @@ class TestDeformRoIPool:
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support')),
])
@pytest.mark.parametrize('dtype', [
torch.float,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE,
reason='MLU does not support for 64-bit floating point')),
torch.half
IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE,
reason='MLU and MUSA do not support 64-bit floating point'),
), torch.half
])
def test_deform_roi_pool_allclose(self, device, dtype):
self._test_deform_roi_pool_allclose(device, dtype)

View File

@ -4,12 +4,13 @@ import pytest
import torch
from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
if IS_MLU_AVAILABLE:
torch.backends.mlu.matmul.allow_tf32 = False
# TODO @MTAI: there are some bugs on MUSA!
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
@ -18,7 +19,11 @@ if IS_MLU_AVAILABLE:
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_diff_iou_rotated_2d(device):
np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
@ -47,7 +52,11 @@ def test_diff_iou_rotated_2d(device):
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'musa',
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
def test_diff_iou_rotated_3d(device):
np_boxes1 = np.asarray(