From c7d3326d8a8578a82b405b3dd719bf48c6a86abe Mon Sep 17 00:00:00 2001 From: sunyanguomt Date: Tue, 11 Mar 2025 12:28:07 +0800 Subject: [PATCH] [MUSA] Support MUSA PR2 * squash commit for support musa pr2 * fix conv2d_gradfix:Conv2d is_cuda_available judgment * Revert "fix conv2d_gradfix:Conv2d is_cuda_available judgment" This reverts commit 38faea16c126612bae0897f65cf722144b7cb24e. * fix conv2d_gradfix:Conv2d comment --- mmcv/ops/bias_act.py | 205 +++++ mmcv/ops/carafe.py | 2 +- mmcv/ops/conv2d_gradfix.py | 6 +- .../common/musa/border_align_musa_kernel.muh | 192 ++++ .../csrc/common/musa/box_iou_quadri_musa.muh | 88 ++ .../csrc/common/musa/box_iou_rotated_musa.muh | 77 ++ .../csrc/common/musa/carafe_musa_kernel.muh | 332 +++++++ .../common/musa/carafe_naive_musa_kernel.muh | 107 +++ .../musa/chamfer_distance_musa_kernel.muh | 101 +++ .../common/musa/convex_iou_musa_kernel.muh | 827 ++++++++++++++++++ .../ops/csrc/common/musa/correlation_musa.muh | 227 +++++ .../common/musa/deform_conv_musa_kernel.muh | 360 ++++++++ .../musa/deform_roi_pool_musa_kernel.muh | 181 ++++ .../musa/diff_iou_rotated_musa_kernel.muh | 133 +++ mmcv/ops/csrc/pytorch/deform_conv.cpp | 16 +- mmcv/ops/csrc/pytorch/musa/bias_act_musa.mu | 301 +++++++ .../csrc/pytorch/musa/border_align_musa.mu | 68 ++ .../csrc/pytorch/musa/box_iou_quadri_musa.mu | 23 + .../csrc/pytorch/musa/box_iou_rotated_musa.mu | 25 + mmcv/ops/csrc/pytorch/musa/carafe_musa.mu | 182 ++++ .../csrc/pytorch/musa/carafe_naive_musa.mu | 52 ++ .../pytorch/musa/chamfer_distance_musa.mu | 66 ++ mmcv/ops/csrc/pytorch/musa/convex_iou.mu | 41 + .../ops/csrc/pytorch/musa/correlation_musa.mu | 94 ++ .../ops/csrc/pytorch/musa/deform_conv_musa.mu | 105 +++ .../csrc/pytorch/musa/deform_roi_pool_musa.mu | 55 ++ .../pytorch/musa/diff_iou_rotated_musa.mu | 35 + mmcv/ops/csrc/pytorch/musa/musabind.cpp | 379 ++++++++ tests/test_ops/test_bias_act.py | 81 ++ tests/test_ops/test_border_align.py | 16 +- tests/test_ops/test_box_iou_quadri.py | 14 +- tests/test_ops/test_box_iou_rotated.py | 15 +- tests/test_ops/test_carafe.py | 77 +- tests/test_ops/test_cc_attention.py | 5 + tests/test_ops/test_chamfer_distance.py | 8 +- tests/test_ops/test_conv_gradfix.py | 21 + tests/test_ops/test_convex_iou.py | 10 + tests/test_ops/test_correlation.py | 22 +- tests/test_ops/test_deform_conv.py | 61 +- tests/test_ops/test_deform_roi_pool.py | 15 +- tests/test_ops/test_diff_iou_rotated.py | 15 +- 41 files changed, 4569 insertions(+), 71 deletions(-) create mode 100644 mmcv/ops/csrc/common/musa/border_align_musa_kernel.muh create mode 100644 mmcv/ops/csrc/common/musa/box_iou_quadri_musa.muh create mode 100644 mmcv/ops/csrc/common/musa/box_iou_rotated_musa.muh create mode 100644 mmcv/ops/csrc/common/musa/carafe_musa_kernel.muh create mode 100644 mmcv/ops/csrc/common/musa/carafe_naive_musa_kernel.muh create mode 100644 mmcv/ops/csrc/common/musa/chamfer_distance_musa_kernel.muh create mode 100644 mmcv/ops/csrc/common/musa/convex_iou_musa_kernel.muh create mode 100644 mmcv/ops/csrc/common/musa/correlation_musa.muh create mode 100644 mmcv/ops/csrc/common/musa/deform_conv_musa_kernel.muh create mode 100644 mmcv/ops/csrc/common/musa/deform_roi_pool_musa_kernel.muh create mode 100644 mmcv/ops/csrc/common/musa/diff_iou_rotated_musa_kernel.muh create mode 100644 mmcv/ops/csrc/pytorch/musa/bias_act_musa.mu create mode 100644 mmcv/ops/csrc/pytorch/musa/border_align_musa.mu create mode 100644 mmcv/ops/csrc/pytorch/musa/box_iou_quadri_musa.mu create mode 100644 
mmcv/ops/csrc/pytorch/musa/box_iou_rotated_musa.mu create mode 100644 mmcv/ops/csrc/pytorch/musa/carafe_musa.mu create mode 100644 mmcv/ops/csrc/pytorch/musa/carafe_naive_musa.mu create mode 100644 mmcv/ops/csrc/pytorch/musa/chamfer_distance_musa.mu create mode 100644 mmcv/ops/csrc/pytorch/musa/convex_iou.mu create mode 100644 mmcv/ops/csrc/pytorch/musa/correlation_musa.mu create mode 100644 mmcv/ops/csrc/pytorch/musa/deform_conv_musa.mu create mode 100644 mmcv/ops/csrc/pytorch/musa/deform_roi_pool_musa.mu create mode 100644 mmcv/ops/csrc/pytorch/musa/diff_iou_rotated_musa.mu diff --git a/mmcv/ops/bias_act.py b/mmcv/ops/bias_act.py index 3dfa55743..5ee02f128 100644 --- a/mmcv/ops/bias_act.py +++ b/mmcv/ops/bias_act.py @@ -114,6 +114,81 @@ activation_funcs = { has_2nd_grad=True), } +activation_funcs_musa = { + 'linear': + EasyDict( + func=lambda x, **_: x, + def_alpha=0, + def_gain=1, + musa_idx=1, + ref='', + has_2nd_grad=False), + 'relu': + EasyDict( + func=lambda x, **_: torch.nn.functional.relu(x), + def_alpha=0, + def_gain=np.sqrt(2), + musa_idx=2, + ref='y', + has_2nd_grad=False), + 'lrelu': + EasyDict( + func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), + def_alpha=0.2, + def_gain=np.sqrt(2), + musa_idx=3, + ref='y', + has_2nd_grad=False), + 'tanh': + EasyDict( + func=lambda x, **_: torch.tanh(x), + def_alpha=0, + def_gain=1, + musa_idx=4, + ref='y', + has_2nd_grad=True), + 'sigmoid': + EasyDict( + func=lambda x, **_: torch.sigmoid(x), + def_alpha=0, + def_gain=1, + musa_idx=5, + ref='y', + has_2nd_grad=True), + 'elu': + EasyDict( + func=lambda x, **_: torch.nn.functional.elu(x), + def_alpha=0, + def_gain=1, + musa_idx=6, + ref='y', + has_2nd_grad=True), + 'selu': + EasyDict( + func=lambda x, **_: torch.nn.functional.selu(x), + def_alpha=0, + def_gain=1, + musa_idx=7, + ref='y', + has_2nd_grad=True), + 'softplus': + EasyDict( + func=lambda x, **_: torch.nn.functional.softplus(x), + def_alpha=0, + def_gain=1, + musa_idx=8, + ref='y', + has_2nd_grad=True), + 'swish': + EasyDict( + func=lambda x, **_: torch.sigmoid(x) * x, + def_alpha=0, + def_gain=np.sqrt(2), + musa_idx=9, + ref='x', + has_2nd_grad=True), +} + _null_tensor = torch.empty([0]) @@ -167,6 +242,13 @@ def bias_act(input: torch.Tensor, return _bias_act_cuda( dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(input, bias) + try: + if use_custom_op and input.is_musa: + return _bias_act_musa( + dim=dim, act=act, alpha=alpha, gain=gain, + clamp=clamp).apply(input, bias) + except AttributeError: + pass return _bias_act_ref( input=input, bias=bias, @@ -373,3 +455,126 @@ def _bias_act_cuda(dim: int = 1, # Add to cache. _bias_act_cuda_cache[key] = BiasActCuda return BiasActCuda + + +_bias_act_musa_cache: Dict = dict() + + +def _bias_act_musa(dim: int = 1, + act: str = 'linear', + alpha: Optional[Union[float, int]] = None, + gain: Optional[float] = None, + clamp: Optional[float] = None): + """"Fast MUSA implementation of `bias_act()` using custom ops. + + Args: + dim (int): The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + Defaults to 1. + act (str): Name of the activation function to evaluate, or `"linear"` + to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid", + "swish", etc. See `activation_funcs_musa` for a full list. `None` + is not allowed. Defaults to `linear`. + alpha (float | int): Shape parameter for the activation + function, or `None` to use the default. Defaults to None. 
+ gain (float): Scaling factor for the output tensor, or `None` + to use default. See `activation_funcs_musa` for the default scaling + of each activation function. If unsure, consider specifying 1. + Defaults to None. + clamp (float): Clamp the output values to `[-clamp, +clamp]`, + or `None` to disable the clamping (default). Defaults to None. + + Returns: + torch.Tensor: Tensor of the same shape and datatype as `x`. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs_musa[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_musa_cache: + return _bias_act_musa_cache[key] + + # Forward op. + class BiasActMusa(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride( + 1) == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor.to(x.device) + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or ( + b is not _null_tensor.to(x.device)): + y = ext_module.bias_act(x, b, _null_tensor.to(x.device), + _null_tensor.to(x.device), + _null_tensor.to(x.device), 0, dim, + spec.musa_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor.to( + x.device), b if 'x' in spec.ref or spec.has_2nd_grad else + _null_tensor.to(x.device), + y if 'y' in spec.ref else _null_tensor.to(x.device)) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActMusaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActMusaGrad(torch.autograd.Function): + + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and ( + dy.stride(1) == 1) else torch.contiguous_format + dx = ext_module.bias_act(dy, b, x, y, _null_tensor.to(x.device), 1, + dim, spec.musa_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor.to(x.device), x, b, + y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActMusaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] + or ctx.needs_input_grad[2]): + d_x = ext_module.bias_act(d_dx, b, x, y, dy, 2, dim, + spec.musa_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. 
+ _bias_act_musa_cache[key] = BiasActMusa + return BiasActMusa diff --git a/mmcv/ops/carafe.py b/mmcv/ops/carafe.py index f7e79c275..30f3c38a0 100644 --- a/mmcv/ops/carafe.py +++ b/mmcv/ops/carafe.py @@ -65,7 +65,7 @@ class CARAFENaiveFunction(Function): def backward( ctx, grad_output: Tensor) -> Tuple[Tensor, Tensor, None, None, None]: - assert grad_output.is_cuda + assert grad_output.is_cuda or grad_output.is_musa features, masks = ctx.saved_tensors kernel_size = ctx.kernel_size diff --git a/mmcv/ops/conv2d_gradfix.py b/mmcv/ops/conv2d_gradfix.py index b93a76a84..070f2918f 100644 --- a/mmcv/ops/conv2d_gradfix.py +++ b/mmcv/ops/conv2d_gradfix.py @@ -15,6 +15,7 @@ import warnings from typing import Dict, Optional, Tuple, Union import torch +from mmengine.device import is_musa_available from mmengine.utils import digit_version from mmengine.utils.dl_utils.parrots_wrapper import is_rocm_pytorch @@ -95,6 +96,8 @@ def conv_transpose2d(input: torch.Tensor, def _should_use_custom_op(input): assert isinstance(input, torch.Tensor) + if enabled and is_musa_available(): + return True if (not enabled) or (not torch.backends.cudnn.enabled): return False if input.device.type != 'cuda': @@ -177,7 +180,8 @@ def _conv2d_gradfix( ctx.input_shape = input.shape # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). - if weight_shape[2:] == stride == dilation == ( + if (not is_musa_available() + ) and weight_shape[2:] == stride == dilation == ( 1, 1) and padding == ( 0, 0) and torch.cuda.get_device_capability( input.device) < (8, 0): diff --git a/mmcv/ops/csrc/common/musa/border_align_musa_kernel.muh b/mmcv/ops/csrc/common/musa/border_align_musa_kernel.muh new file mode 100644 index 000000000..553e6affa --- /dev/null +++ b/mmcv/ops/csrc/common/musa/border_align_musa_kernel.muh @@ -0,0 +1,192 @@ +// Copyright (c) OpenMMLab. All rights reserved +// modified from +// https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/csrc/border_align/border_align_kernel.cu. +// the main difference: (1) use `argmax_idx` for fast computing of gradient +// during the backward. (2) `wh` is directly computed by `boxes`, rather than +// passing it as argument to forward or backward functions. 
+ +#ifndef BORDER_ALIGN_MUSA_KERNEL_MUH +#define BORDER_ALIGN_MUSA_KERNEL_MUH + +#include +#include "pytorch_musa_helper.hpp" + +enum BorderMode { Top = 0, Left = 1, Bottom = 2, Right = 3 }; + +/*** Forward ***/ +template +__global__ void border_align_forward_musa_kernel( + const int nthreads, const T* input, const T* boxes, T* output, + int* argmax_idx, const int channels, const int box_size, const int height, + const int width, const int pool_size) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (batch_idx, c_idx, box_idx) is an element paralleled for computing + // output, and `extreme_idx` is in range [0,3] + int batch_idx, c_idx, box_idx, extreme_idx, maxidx, *offset_argmax_idx; + const T *offset_box, *offset_input, *offset_box_x; + T *offset_output, box_width, box_height, stride, x_stride, y_stride, x, y, + val, maxval; + + extreme_idx = threadIdx.y; + // shape (N, C, box_size, 4) for output + batch_idx = index / channels / box_size; + // shape (N, box_size, 4) for boxes + box_idx = index % box_size + batch_idx * box_size; + c_idx = (index / box_size) % channels; + + offset_box = boxes + box_idx * 4; + box_width = *(offset_box + 2) - *offset_box; + box_height = *(offset_box + 3) - *(offset_box + 1); + offset_output = output + index * 4 + extreme_idx; + offset_argmax_idx = argmax_idx + index * 4 + extreme_idx; + // shape (N, 4C, h, w) for input. + // [0,C) for top feature, [C,2C) for left feature, + // [2C,3C) for bottom feature, [3C,4C) for right feature + offset_input = + input + (batch_idx * channels * 4 + extreme_idx * channels + c_idx) * + height * width; + + // extreme_idx in [0,1] -> offset_box_x indexed at x1 + // extreme_idx in [2,3] -> offset_box_x indexed at x2 + offset_box_x = offset_box + extreme_idx / 2 * 2; + + // (x1,y1) or (x2,y2) for (x,y) + x = *offset_box_x; + y = *(offset_box_x + 1); + + switch (extreme_idx) { + // top + case BorderMode::Top: + stride = box_width / pool_size; + x_stride = stride; + y_stride = 0; + break; + // left + case BorderMode::Left: + stride = box_height / pool_size; + x_stride = 0; + y_stride = stride; + break; + // bottom + case BorderMode::Bottom: + stride = box_width / pool_size; + x_stride = -stride; + y_stride = 0; + break; + // right + case BorderMode::Right: + stride = box_height / pool_size; + x_stride = 0; + y_stride = -stride; + break; + } + + // initialize maxval and maxidx with the start position (e.g. 
(x1,y1) or + // (x2,y2)) + maxval = bilinear_interpolate(offset_input, height, width, y, x, index); + maxidx = 0; + + // do max_pool along the border + for (int i = 1; i <= pool_size; i++) { + x += x_stride; + y += y_stride; + val = bilinear_interpolate(offset_input, height, width, y, x, index); + if (val > maxval) { + maxval = val; + maxidx = i; + } + } + + // update output and argmax_idx + *offset_output = maxval; + *offset_argmax_idx = maxidx; + } +} + +/*** Backward ***/ +template +__global__ void border_align_backward_musa_kernel( + const int nthreads, const T* grad_output, const T* boxes, + const int* argmax_idx, T* grad_input, const int channels, + const int box_size, const int height, const int width, + const int pool_size) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (batch_idx, c_idx, box_idx) is an element paralleled for computing + // output, and `extreme_idx` is in range [0,3] + int batch_idx, c_idx, box_idx, extreme_idx; + const int* offset_argmax_idx; + const T *offset_grad_output, *offset_box, *offset_box_x; + T *offset_grad_input, box_width, box_height, stride, x_stride, y_stride, x, + y; + + extreme_idx = threadIdx.y; + batch_idx = index / channels / box_size; + box_idx = index % box_size + batch_idx * box_size; + c_idx = (index / box_size) % channels; + + offset_box = boxes + box_idx * 4; + box_width = *(offset_box + 2) - *offset_box; + box_height = *(offset_box + 3) - *(offset_box + 1); + offset_grad_output = grad_output + index * 4 + extreme_idx; + offset_argmax_idx = argmax_idx + index * 4 + extreme_idx; + // [0,C) for top feature grad, [C,2C) for left feature grad, + // [2C,3C) for bottom feature grad, [3C,4C) for right feature grad + offset_grad_input = grad_input + (batch_idx * channels * 4 + + extreme_idx * channels + c_idx) * + height * width; + + // extreme_idx in [0,1] -> offset_box_x indexed at x1 + // extreme_idx in [2,3] -> offset_box_x indexed at x2 + offset_box_x = offset_box + extreme_idx / 2 * 2; + + switch (extreme_idx) { + // top + case BorderMode::Top: + stride = box_width / pool_size; + x_stride = stride; + y_stride = 0; + break; + // left + case BorderMode::Left: + stride = box_height / pool_size; + x_stride = 0; + y_stride = stride; + break; + // bottom + case BorderMode::Bottom: + stride = box_width / pool_size; + x_stride = -stride; + y_stride = 0; + break; + // right + case BorderMode::Right: + stride = box_height / pool_size; + x_stride = 0; + y_stride = -stride; + break; + } + + // get position (x,y) which has maximum value during forward + x = *offset_box_x; + y = *(offset_box_x + 1); + x += x_stride * (T)(*offset_argmax_idx); + y += y_stride * (T)(*offset_argmax_idx); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, x_low, + x_high, y_low, y_high, index); + + // update grad_output + atomicAdd(offset_grad_input + y_low * width + x_low, + *offset_grad_output * w1); + atomicAdd(offset_grad_input + y_low * width + x_high, + *offset_grad_output * w2); + atomicAdd(offset_grad_input + y_high * width + x_low, + *offset_grad_output * w3); + atomicAdd(offset_grad_input + y_high * width + x_high, + *offset_grad_output * w4); + } +} + +#endif // BORDER_ALIGN_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/box_iou_quadri_musa.muh b/mmcv/ops/csrc/common/musa/box_iou_quadri_musa.muh new file mode 100644 index 000000000..2e5b1e167 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/box_iou_quadri_musa.muh @@ -0,0 +1,88 @@ +// Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +#ifndef BOX_IOU_QUADRI_MUSA_MUH +#define BOX_IOU_QUADRI_MUSA_MUH + + +#include "pytorch_musa_helper.hpp" +#include "box_iou_rotated_utils.hpp" + +// 2D block with 32 * 16 = 512 threads per block +const int BLOCK_DIM_X = 32; +const int BLOCK_DIM_Y = 16; + +inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); } + +template +__global__ void box_iou_quadri_musa_kernel( + const int n_boxes1, const int n_boxes2, const T* dev_boxes1, + const T* dev_boxes2, T* dev_ious, const int mode_flag, const bool aligned) { + if (aligned) { + MUSA_1D_KERNEL_LOOP(index, n_boxes1) { + int b1 = index; + int b2 = index; + + int base1 = b1 * 8; + + float block_boxes1[8]; + float block_boxes2[8]; + + block_boxes1[0] = dev_boxes1[base1 + 0]; + block_boxes1[1] = dev_boxes1[base1 + 1]; + block_boxes1[2] = dev_boxes1[base1 + 2]; + block_boxes1[3] = dev_boxes1[base1 + 3]; + block_boxes1[4] = dev_boxes1[base1 + 4]; + block_boxes1[5] = dev_boxes1[base1 + 5]; + block_boxes1[6] = dev_boxes1[base1 + 6]; + block_boxes1[7] = dev_boxes1[base1 + 7]; + + int base2 = b2 * 8; + + block_boxes2[0] = dev_boxes2[base2 + 0]; + block_boxes2[1] = dev_boxes2[base2 + 1]; + block_boxes2[2] = dev_boxes2[base2 + 2]; + block_boxes2[3] = dev_boxes2[base2 + 3]; + block_boxes2[4] = dev_boxes2[base2 + 4]; + block_boxes2[5] = dev_boxes2[base2 + 5]; + block_boxes2[6] = dev_boxes2[base2 + 6]; + block_boxes2[7] = dev_boxes2[base2 + 7]; + + dev_ious[index] = + single_box_iou_quadri(block_boxes1, block_boxes2, mode_flag); + } + } else { + MUSA_1D_KERNEL_LOOP(index, n_boxes1 * n_boxes2) { + int b1 = index / n_boxes2; + int b2 = index % n_boxes2; + + int base1 = b1 * 8; + + float block_boxes1[8]; + float block_boxes2[8]; + + block_boxes1[0] = dev_boxes1[base1 + 0]; + block_boxes1[1] = dev_boxes1[base1 + 1]; + block_boxes1[2] = dev_boxes1[base1 + 2]; + block_boxes1[3] = dev_boxes1[base1 + 3]; + block_boxes1[4] = dev_boxes1[base1 + 4]; + block_boxes1[5] = dev_boxes1[base1 + 5]; + block_boxes1[6] = dev_boxes1[base1 + 6]; + block_boxes1[7] = dev_boxes1[base1 + 7]; + + int base2 = b2 * 8; + + block_boxes2[0] = dev_boxes2[base2 + 0]; + block_boxes2[1] = dev_boxes2[base2 + 1]; + block_boxes2[2] = dev_boxes2[base2 + 2]; + block_boxes2[3] = dev_boxes2[base2 + 3]; + block_boxes2[4] = dev_boxes2[base2 + 4]; + block_boxes2[5] = dev_boxes2[base2 + 5]; + block_boxes2[6] = dev_boxes2[base2 + 6]; + block_boxes2[7] = dev_boxes2[base2 + 7]; + + dev_ious[index] = + single_box_iou_quadri(block_boxes1, block_boxes2, mode_flag); + } + } +} + +#endif diff --git a/mmcv/ops/csrc/common/musa/box_iou_rotated_musa.muh b/mmcv/ops/csrc/common/musa/box_iou_rotated_musa.muh new file mode 100644 index 000000000..70802449a --- /dev/null +++ b/mmcv/ops/csrc/common/musa/box_iou_rotated_musa.muh @@ -0,0 +1,77 @@ +// Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +// modified from +// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu +#ifndef BOX_IOU_ROTATED_MUSA_MUH +#define BOX_IOU_ROTATED_MUSA_MUH + +#include "pytorch_musa_helper.hpp" +#include "box_iou_rotated_utils.hpp" + +// 2D block with 32 * 16 = 512 threads per block +const int BLOCK_DIM_X = 32; +const int BLOCK_DIM_Y = 16; + +inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); } + +template +__global__ void box_iou_rotated_musa_kernel( + const int n_boxes1, const int n_boxes2, const T* dev_boxes1, + const T* dev_boxes2, T* dev_ious, const int mode_flag, const bool aligned) { + if (aligned) { + MUSA_1D_KERNEL_LOOP(index, n_boxes1) { + int b1 = index; + int b2 = index; + + int base1 = b1 * 5; + + float block_boxes1[5]; + float block_boxes2[5]; + + block_boxes1[0] = dev_boxes1[base1 + 0]; + block_boxes1[1] = dev_boxes1[base1 + 1]; + block_boxes1[2] = dev_boxes1[base1 + 2]; + block_boxes1[3] = dev_boxes1[base1 + 3]; + block_boxes1[4] = dev_boxes1[base1 + 4]; + + int base2 = b2 * 5; + + block_boxes2[0] = dev_boxes2[base2 + 0]; + block_boxes2[1] = dev_boxes2[base2 + 1]; + block_boxes2[2] = dev_boxes2[base2 + 2]; + block_boxes2[3] = dev_boxes2[base2 + 3]; + block_boxes2[4] = dev_boxes2[base2 + 4]; + + dev_ious[index] = + single_box_iou_rotated(block_boxes1, block_boxes2, mode_flag); + } + } else { + MUSA_1D_KERNEL_LOOP(index, n_boxes1 * n_boxes2) { + int b1 = index / n_boxes2; + int b2 = index % n_boxes2; + + int base1 = b1 * 5; + + float block_boxes1[5]; + float block_boxes2[5]; + + block_boxes1[0] = dev_boxes1[base1 + 0]; + block_boxes1[1] = dev_boxes1[base1 + 1]; + block_boxes1[2] = dev_boxes1[base1 + 2]; + block_boxes1[3] = dev_boxes1[base1 + 3]; + block_boxes1[4] = dev_boxes1[base1 + 4]; + + int base2 = b2 * 5; + + block_boxes2[0] = dev_boxes2[base2 + 0]; + block_boxes2[1] = dev_boxes2[base2 + 1]; + block_boxes2[2] = dev_boxes2[base2 + 2]; + block_boxes2[3] = dev_boxes2[base2 + 3]; + block_boxes2[4] = dev_boxes2[base2 + 4]; + + dev_ious[index] = + single_box_iou_rotated(block_boxes1, block_boxes2, mode_flag); + } + } +} + +#endif diff --git a/mmcv/ops/csrc/common/musa/carafe_musa_kernel.muh b/mmcv/ops/csrc/common/musa/carafe_musa_kernel.muh new file mode 100644 index 000000000..5fa2398cc --- /dev/null +++ b/mmcv/ops/csrc/common/musa/carafe_musa_kernel.muh @@ -0,0 +1,332 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef CARAFE_MUSA_KERNEL_MUH +#define CARAFE_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +#ifdef MMCV_WITH_HIP +#define WARP_SIZE 64 +#else +#define WARP_SIZE 32 +#endif +#define THREADS_PER_PIXEL 32 +#define MAX_SHARED_MEMORY 49152 +#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144 +#define MAXIMIZE_KERNEL_SIZE true +#define kTileDim 32 +#define kBlockRows 8 +#define FULL_MASK 0xffffffff + +inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); } + +__device__ inline int Loc2Index(const int n, const int c, const int h, + const int w, const int channel_num, + const int height, const int width) { + int index = w + (h + (c + n * channel_num) * height) * width; + return index; +} +#ifndef MMCV_WITH_HIP +/* TODO: move this to a common place */ +template +__device__ inline scalar_t min(scalar_t a, scalar_t b) { + return a < b ? a : b; +} + +template +__device__ inline scalar_t max(scalar_t a, scalar_t b) { + return a > b ? 
a : b; +} +#endif +template +__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) +#ifdef MMCV_WITH_HIP + val += __shfl_down(val, offset); +#else + val += __shfl_down_sync(FULL_MASK, val, offset); +#endif + return val; +} + +template <> +__device__ __forceinline__ phalf warpReduceSum(phalf val) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) +#ifdef MMCV_WITH_HIP + // Using PyTorch's macro for half support + __PHALF(val) += WARP_SHFL_DOWN(val, offset); +#else + __PHALF(val) += + __shfl_down_sync(FULL_MASK, __PHALF(val).operator __half(), offset); +#endif + return val; +} + +// Splits the original matrix into submatrices with size 32 * 32. +// Each block transposes one submatrix by loading it into shared memory. +// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/ +template +__global__ void BatchTranspose2DMUSAKernel(const int N, const int H, + const int W, const int dh, + const int dw, + const scalar_t *__restrict__ X, + scalar_t *__restrict__ Y) { + __shared__ scalar_t tile[kTileDim][kTileDim + 1]; + const int n = blockIdx.x / (dh * dw); + const int k = blockIdx.x % (dh * dw); + const int r = k / dw; + const int c = k % dw; + const int offset = n * H * W; + int x = c * kTileDim + threadIdx.x; + int y = r * kTileDim + threadIdx.y; + if (x < W) { + for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) { + tile[threadIdx.y + i][threadIdx.x] = X[offset + (y + i) * W + x]; + } + } + __syncthreads(); + x = r * kTileDim + threadIdx.x; + y = c * kTileDim + threadIdx.y; + if (x < H) { + for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) { + Y[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i]; + } + } +} + +template +__global__ void CARAFEForward( + const int num_kernels, const scalar_t *__restrict__ bottom_data, + const scalar_t *__restrict__ bottom_masks, const int kernel_size, + const int group_size, const int scale_factor, const int channels, + const int down_height, const int down_width, const int height, + const int width, const int mask_channels, scalar_t *__restrict__ top_data) { +#if MAXIMIZE_KERNEL_SIZE + __shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2]; +#else + __shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T]; +#endif + + + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index > num_kernels - 1) { + return; + } + const int pixel_id = threadIdx.x / THREADS_PER_PIXEL; + const int split_id = threadIdx.x % THREADS_PER_PIXEL; + index = index / THREADS_PER_PIXEL; + const int pw = index % width; + const int ph = (index / width) % height; + const int n = index / width / height; + + const int down_pw = pw / scale_factor; + const int down_ph = ph / scale_factor; + + const int start_w = down_pw - (kernel_size - 1) / 2; + const int end_w = down_pw + (kernel_size - 1) / 2 + 1; + const int start_h = down_ph - (kernel_size - 1) / 2; + const int end_h = down_ph + (kernel_size - 1) / 2 + 1; + for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) { + int mask_index = Loc2Index(n, ph, pw, c, height, width, mask_channels); + shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index]; + } + __syncthreads(); + + + const int channels_per_group = ceilf(channels / (float)group_size); +#pragma unroll + for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) { + int mask_group = c / channels_per_group; + scalar_t output_val = 0; +#pragma unroll + for (int iy = start_h; iy < end_h; iy++) { +#pragma unroll + 
for (int ix = start_w; ix < end_w; ix++) { + if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) { + continue; + } + int mask_iy = iy - down_ph + (kernel_size - 1) / 2; + int mask_ix = ix - down_pw + (kernel_size - 1) / 2; + int mask_c = + (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix; + int feat_index = + Loc2Index(n, iy, ix, c, down_height, down_width, channels); + + output_val += bottom_data[feat_index] * + shared_mask[mask_c * WARP_SIZE + pixel_id]; + } + } + + int top_index = Loc2Index(n, ph, pw, c, height, width, channels); + top_data[top_index] = output_val; + } +} + +template +__global__ void CARAFEBackward_Feature( + const int num_kernels, const scalar_t *__restrict__ top_diff, + const scalar_t *__restrict__ bottom_masks, const int kernel_size, + const int group_size, const int scale_factor, const int channels, + const int down_height, const int down_width, const int height, + const int width, const int mask_channels, + scalar_t *__restrict__ bottom_diff) { +#if MAXIMIZE_KERNEL_SIZE + __shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2]; +#else + __shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T]; +#endif + + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index > num_kernels - 1) { + return; + } + + const int pixel_id = threadIdx.x / THREADS_PER_PIXEL; + const int split_id = threadIdx.x % THREADS_PER_PIXEL; + // (n, c, ph, pw) is an element in the bottom_data + index = index / THREADS_PER_PIXEL; + const int pw = index % width; + const int ph = (index / width) % height; + const int n = index / width / height; + + const int start_w = pw - (kernel_size - 1) * scale_factor / 2; + const int end_w = pw + (kernel_size - 1) * scale_factor / 2 + 1; + const int start_h = ph - (kernel_size - 1) * scale_factor / 2; + const int end_h = ph + (kernel_size - 1) * scale_factor / 2 + 1; + for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) { + const int mask_w = (c % kernel_size) * scale_factor; + const int mask_h = (c / kernel_size % kernel_size) * scale_factor; + const int mask_x = start_w + mask_w; + const int mask_y = start_h + mask_h; + if (mask_y < 0 || mask_y > height - 1 || mask_x < 0 || mask_x > width - 1) { + shared_mask[c * WARP_SIZE + pixel_id] = 0; + continue; + } + const int mask_group = c / (kernel_size * kernel_size); + const int mask_c = (2 * mask_group + 1) * kernel_size * kernel_size - c - 1; + int mask_index = + Loc2Index(n, mask_c, mask_y, mask_x, mask_channels, height, width); + shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index]; + } + __syncthreads(); + const int channels_per_group = ceilf(channels / (float)group_size); +#pragma unroll + for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) { + int mask_group = c / channels_per_group; + int top_index = Loc2Index(n, ph, pw, c, height, width, channels); + scalar_t output_val = 0; +#pragma unroll + for (int iy = start_h; iy < end_h; iy += scale_factor) { +#pragma unroll + for (int ix = start_w; ix < end_w; ix += scale_factor) { + if (iy < 0 || iy > height - 1 || ix < 0 || ix > width - 1) { + continue; + } + int mask_iy = + (iy - ph + (kernel_size - 1) * scale_factor / 2) / scale_factor; + int mask_ix = + (ix - pw + (kernel_size - 1) * scale_factor / 2) / scale_factor; + int mask_c = + (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix; + int feat_index = Loc2Index(n, iy, ix, c, height, width, channels); + output_val += + shared_mask[mask_c * WARP_SIZE + pixel_id] * top_diff[feat_index]; + } + } + bottom_diff[top_index] = output_val; + } +} 
+ +template +__global__ void FeatureSum(const int num_kernels, + const scalar_t *__restrict__ input_data, + const int scale_factor, const int channels, + const int height, const int width, + scalar_t *__restrict__ output_data) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index > num_kernels - 1) { + return; + } + const int split_id = threadIdx.x % THREADS_PER_PIXEL; + index = index / THREADS_PER_PIXEL; + const int pw = index % width; + const int ph = (index / width) % height; + const int n = index / width / height; + for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) { + scalar_t output_val = 0; + for (int iy = ph * scale_factor; iy < (ph + 1) * scale_factor; iy++) { + for (int ix = pw * scale_factor; ix < (pw + 1) * scale_factor; ix++) { + int input_id = Loc2Index(n, iy, ix, c, height * scale_factor, + width * scale_factor, channels); + output_val += input_data[input_id]; + } + } + const int output_id = Loc2Index(n, ph, pw, c, height, width, channels); + output_data[output_id] = output_val; + } +} + +template +__global__ void CARAFEBackward_Mask(const int num_kernels, + const scalar_t *__restrict__ top_diff, + const scalar_t *__restrict__ bottom_data, + const int kernel_size, const int group_size, + const int scale_factor, const int channels, + const int down_height, const int down_width, + const int height, const int width, + const int mask_channels, + scalar_t *__restrict__ mask_diff) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index > num_kernels - 1) { + return; + } + + const int lane_id = index % WARP_SIZE; + index = index / WARP_SIZE; + const int mask_c = index % mask_channels; + // (n, c, ph, pw) is an element in the bottom_data + index = index / mask_channels; + const int pw = index % width; + const int ph = (index / width) % height; + const int n = index / width / height; + + const int down_pw = pw / scale_factor; + const int down_ph = ph / scale_factor; + + const int mask_group = mask_c / (kernel_size * kernel_size); + const int mask_loc = mask_c % (kernel_size * kernel_size); + + const int offset_x = mask_loc % kernel_size - (kernel_size - 1) / 2; + const int offset_y = + mask_loc / kernel_size % kernel_size - (kernel_size - 1) / 2; + + const int down_x = down_pw + offset_x; + const int down_y = down_ph + offset_y; + + scalar_t output_val = 0; + + if (down_y >= 0 && down_y <= down_height - 1 && down_x >= 0 && + down_x <= down_width - 1) { + const int channels_per_mask = ceilf(channels / (float)group_size); + const int start = channels_per_mask * mask_group; + const int end = min(channels_per_mask * (mask_group + 1), channels); + for (int c = start + lane_id; c < end; c += WARP_SIZE) { + int bottom_id = + Loc2Index(n, down_y, down_x, c, down_height, down_width, channels); + int top_id = Loc2Index(n, ph, pw, c, height, width, channels); + output_val += top_diff[top_id] * bottom_data[bottom_id]; + } + } +#ifdef MMCV_WITH_HIP + __syncthreads(); +#else + __syncwarp(); +#endif + output_val = warpReduceSum(output_val); + if (lane_id == 0) { + const int mask_id = + Loc2Index(n, ph, pw, mask_c, height, width, mask_channels); + mask_diff[mask_id] = output_val; + } +} + +#endif // CARAFE_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/carafe_naive_musa_kernel.muh b/mmcv/ops/csrc/common/musa/carafe_naive_musa_kernel.muh new file mode 100644 index 000000000..a05e3992d --- /dev/null +++ b/mmcv/ops/csrc/common/musa/carafe_naive_musa_kernel.muh @@ -0,0 +1,107 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#ifndef CARAFE_NAIVE_MUSA_KERNEL_MUH +#define CARAFE_NAIVE_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +__device__ inline int Loc2Index(const int n, const int c, const int h, + const int w, const int channel_num, + const int height, const int width) { + int index = w + (h + (c + n * channel_num) * height) * width; + return index; +} + +template +__global__ void carafe_naive_forward_musa_kernel( + const int nthreads, const scalar_t *bottom_data, + const scalar_t *bottom_masks, scalar_t *top_data, const int kernel_size, + const int group_size, const int scale_factor, const int channels, + const int height, const int width) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the bottom_data + int pw = index % width; + int ph = (index / width) % height; + int c = (index / width / height) % channels; + int n = index / width / height / channels; + + int mask_channels = kernel_size * kernel_size * group_size; + int mask_group = c / (channels / group_size); + + int down_pw = pw / scale_factor; + int down_ph = ph / scale_factor; + int down_width = width / scale_factor; + int down_height = height / scale_factor; + int start_w = down_pw - (kernel_size - 1) / 2; + int end_w = down_pw + (kernel_size - 1) / 2 + 1; + int start_h = down_ph - (kernel_size - 1) / 2; + int end_h = down_ph + (kernel_size - 1) / 2 + 1; + + scalar_t output_val = 0; + for (int iy = start_h; iy < end_h; iy++) { + for (int ix = start_w; ix < end_w; ix++) { + if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) { + continue; + } + int mask_iy = iy - down_ph + (kernel_size - 1) / 2; + int mask_ix = ix - down_pw + (kernel_size - 1) / 2; + int mask_c = + (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix; + int feat_index = + Loc2Index(n, c, iy, ix, channels, down_height, down_width); + int mask_index = + Loc2Index(n, mask_c, ph, pw, mask_channels, height, width); + output_val += bottom_data[feat_index] * bottom_masks[mask_index]; + } + } + top_data[index] = output_val; + } +} + +template +__global__ void carafe_naive_backward_musa_kernel( + const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_data, + const scalar_t *bottom_masks, scalar_t *bottom_diff, scalar_t *mask_diff, + const int kernel_size, const int group_size, const int scale_factor, + const int channels, const int height, const int width) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the bottom_data + int pw = index % width; + int ph = (index / width) % height; + int c = (index / width / height) % channels; + int n = index / width / height / channels; + + int mask_channels = kernel_size * kernel_size * group_size; + int mask_group = c / (channels / group_size); + + int down_pw = pw / scale_factor; + int down_ph = ph / scale_factor; + int down_width = width / scale_factor; + int down_height = height / scale_factor; + int start_w = down_pw - (kernel_size - 1) / 2; + int end_w = down_pw + (kernel_size - 1) / 2 + 1; + int start_h = down_ph - (kernel_size - 1) / 2; + int end_h = down_ph + (kernel_size - 1) / 2 + 1; + + for (int iy = start_h; iy < end_h; iy++) { + for (int ix = start_w; ix < end_w; ix++) { + if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) { + continue; + } + int mask_iy = iy - down_ph + (kernel_size - 1) / 2; + int mask_ix = ix - down_pw + (kernel_size - 1) / 2; + int mask_c = + (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix; + int feat_index = + Loc2Index(n, c, iy, ix, channels, down_height, 
down_width); + int mask_index = + Loc2Index(n, mask_c, ph, pw, mask_channels, height, width); + atomicAdd(bottom_diff + feat_index, + bottom_masks[mask_index] * top_diff[index]); + atomicAdd(mask_diff + mask_index, + bottom_data[feat_index] * top_diff[index]); + } + } + } +} + +#endif // CARAFE_NAIVE_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/chamfer_distance_musa_kernel.muh b/mmcv/ops/csrc/common/musa/chamfer_distance_musa_kernel.muh new file mode 100644 index 000000000..f59bdebc7 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/chamfer_distance_musa_kernel.muh @@ -0,0 +1,101 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cu +#ifndef CHAMFER_DISTANCE_MUSA_KERNEL_MUH +#define CHAMFER_DISTANCE_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" +#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144 + +#if MUSA_ARCH > 21 +template +__global__ void chamfer_distance_forward_musa_kernel(int b, int n, + const scalar_t* xyz, int m, + const scalar_t* xyz2, + scalar_t* result, + int* result_i) { + __shared__ scalar_t buf[MAX_SHARED_SCALAR_T]; + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int k2 = 0; k2 < m; k2 += THREADS_PER_BLOCK) { + int end_k = min(m, k2 + THREADS_PER_BLOCK) - k2; + for (int j = threadIdx.x; j < end_k * 2; j += blockDim.x) { + buf[j] = xyz2[(i * m + k2) * 2 + j]; + } + __syncthreads(); + for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) { + scalar_t x1 = xyz[(i * n + j) * 2 + 0]; + scalar_t y1 = xyz[(i * n + j) * 2 + 1]; + int best_i = 0; + scalar_t best = 1e10; + int end_ka = end_k & (~3); + if (end_ka == THREADS_PER_BLOCK) { + for (int k = 0; k < THREADS_PER_BLOCK; k += 4) { +#pragma unroll + for (int j = 0; j < 4; ++j) { + scalar_t x2 = buf[(k + j) * 2] - x1; + scalar_t y2 = buf[(k + j) * 2 + 1] - y1; + scalar_t d = x2 * x2 + y2 * y2; + if (d < best) { + best = d; + best_i = k + k2 + j; + } + } + } + } else { + for (int k = 0; k < end_ka; k += 4) { +#pragma unroll + for (int j = 0; j < 4; ++j) { + scalar_t x2 = buf[(k + j) * 2] - x1; + scalar_t y2 = buf[(k + j) * 2 + 1] - y1; + scalar_t d = x2 * x2 + y2 * y2; + if (d < best) { + best = d; + best_i = k + k2 + j; + } + } + } + } + for (int k = end_ka; k < end_k; k++) { + scalar_t x2 = buf[k * 2 + 0] - x1; + scalar_t y2 = buf[k * 2 + 1] - y1; + scalar_t d = x2 * x2 + y2 * y2; + if (k == 0 || d < best) { + best = d; + best_i = k + k2; + } + } + if (k2 == 0 || result[(i * n + j)] > best) { + result[(i * n + j)] = best; + result_i[(i * n + j)] = best_i; + } + } + __syncthreads(); + } + } +} + +template +__global__ void chamfer_distance_backward_musa_kernel( + int b, int n, const scalar_t* xyz1, int m, const scalar_t* xyz2, + const scalar_t* grad_dist1, const int* idx1, scalar_t* grad_xyz1, + scalar_t* grad_xyz2) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) { + scalar_t x1 = xyz1[(i * n + j) * 2 + 0]; + scalar_t y1 = xyz1[(i * n + j) * 2 + 1]; + int j2 = idx1[i * n + j]; + scalar_t x2 = xyz2[(i * m + j2) * 2 + 0]; + scalar_t y2 = xyz2[(i * m + j2) * 2 + 1]; + scalar_t g = grad_dist1[i * n + j] * 2; + atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 0]), g * (x1 - x2)); + atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 1]), g * (y1 - y2)); + atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 0]), -(g * (x1 - x2))); + atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 1]), -(g * (y1 - y2))); + } + } +} +#else +#warning "chamfer_distance is 
supported when MUSA_ARCH > 21" +#endif //MUSA_ARCH + +#endif // CHAMFER_DISTANCE_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/convex_iou_musa_kernel.muh b/mmcv/ops/csrc/common/musa/convex_iou_musa_kernel.muh new file mode 100644 index 000000000..fd708bc10 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/convex_iou_musa_kernel.muh @@ -0,0 +1,827 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef CONVEX_IOU_MUSA_KERNEL_MUH +#define CONVEX_IOU_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +#define MAXN 100 +#define NMAX 512 +__device__ const double EPS = 1E-8; + +__device__ inline int sig(double d) { return (d > EPS) - (d < -EPS); } + +struct Point { + double x, y; + __device__ Point() {} + __device__ Point(double x, double y) : x(x), y(y) {} +}; + +__device__ inline bool point_same(Point& a, Point& b) { + return sig(a.x - b.x) == 0 && sig(a.y - b.y) == 0; +} + +__device__ inline void swap1(Point* a, Point* b) { + Point temp; + temp.x = a->x; + temp.y = a->y; + + a->x = b->x; + a->y = b->y; + + b->x = temp.x; + b->y = temp.y; +} + +__device__ inline void reverse1(Point* a, const int n) { + for (int i = 0; i < (n - 1) / 2.0; i++) { + Point* j = &(a[i]); + Point* k = &(a[n - 1 - i]); + swap1(j, k); + } +} + +__device__ inline double cross(Point o, Point a, Point b) { + return (a.x - o.x) * (b.y - o.y) - (b.x - o.x) * (a.y - o.y); +} + +__device__ inline double dis(Point a, Point b) { + return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y); +} +__device__ inline double area(Point* ps, int n) { + ps[n] = ps[0]; + double res = 0; + for (int i = 0; i < n; i++) { + res += ps[i].x * ps[i + 1].y - ps[i].y * ps[i + 1].x; + } + return res / 2.0; +} +__device__ inline double polygon_area_grad(Point* ps, int n, + int* polygon_to_pred_index, + int n_pred, double* grad_C) { + ps[n] = ps[0]; + double partion_grad[4 * 30 + 2]; + double res = 0; + for (int i = 0; i < n; i++) { + res += ps[i].x * ps[i + 1].y - ps[i].y * ps[i + 1].x; + partion_grad[i * 4 + 2] = ps[i + 1].y; + partion_grad[i * 4 + 3] = -ps[i + 1].x; + if (i != n - 1) { + partion_grad[i * 4 + 4] = -ps[i].y; + partion_grad[i * 4 + 5] = ps[i].x; + } else { + partion_grad[0] = -ps[i].y; + partion_grad[1] = ps[i].x; + } + } + for (int i = 0; i < n; i++) { + for (int j = 0; j < n_pred; j++) { + if (i == polygon_to_pred_index[j]) { + grad_C[2 * polygon_to_pred_index[j + n_pred]] = + (partion_grad[i * 4] + partion_grad[i * 4 + 2]) / 2; + break; + } + } + for (int j = 0; j < n_pred; j++) { + if (i == polygon_to_pred_index[j]) { + grad_C[2 * polygon_to_pred_index[j + n_pred] + 1] = + (partion_grad[i * 4 + 1] + partion_grad[i * 4 + 1 + 2]) / 2; + break; + } + } + } + + return res / 2.0; +} + +__device__ inline int lineCross(Point a, Point b, Point c, Point d, Point& p, + double* cut_grad, int m, int n, int i) { + double s1, s2; + double s2_s1_2; + double ds1_dxc, ds1_dyc, ds2_dxd, ds2_dyd; + double dxp_dxc, dxp_dyc, dxp_dxd, dxp_dyd, dyp_dxc, dyp_dyc, dyp_dxd, dyp_dyd; + s1 = cross(a, b, c); + s2 = cross(a, b, d); + + ds1_dxc = -(b.y - a.y); + ds1_dyc = b.x - a.x; + ds2_dxd = ds1_dxc; + ds2_dyd = ds1_dyc; + s2_s1_2 = (s2 - s1) * (s2 - s1); + + if (sig(s1) == 0 && sig(s2) == 0) return 2; + if (sig(s2 - s1) == 0) return 0; + + dxp_dxc = + ((s2 - d.x * ds1_dxc) * (s2 - s1) - (c.x * s2 - d.x * s1) * (-ds1_dxc)) / + (s2_s1_2); + dxp_dyc = + ((0 - d.x * ds1_dyc) * (s2 - s1) - (c.x * s2 - d.x * s1) * (-ds1_dyc)) / + (s2_s1_2); + dxp_dxd = + ((c.x * ds2_dxd - s1) * (s2 - s1) - (c.x * s2 - d.x * s1) * (ds2_dxd)) / + (s2_s1_2); + 
dxp_dyd = + ((c.x * ds2_dyd - 0) * (s2 - s1) - (c.x * s2 - d.x * s1) * (ds2_dyd)) / + (s2_s1_2); + + dyp_dxc = + ((0 - d.y * ds1_dxc) * (s2 - s1) - (c.y * s2 - d.y * s1) * (-ds1_dxc)) / + (s2_s1_2); + dyp_dyc = + ((s2 - d.y * ds1_dyc) * (s2 - s1) - (c.y * s2 - d.y * s1) * (-ds1_dyc)) / + (s2_s1_2); + dyp_dxd = + ((c.y * ds2_dxd - 0) * (s2 - s1) - (c.y * s2 - d.y * s1) * (ds2_dxd)) / + (s2_s1_2); + dyp_dyd = + ((c.y * ds2_dyd - s1) * (s2 - s1) - (c.y * s2 - d.y * s1) * (ds2_dyd)) / + (s2_s1_2); + + p.x = (c.x * s2 - d.x * s1) / (s2 - s1); + p.y = (c.y * s2 - d.y * s1) / (s2 - s1); + if (i == n - 1) { + cut_grad[4 * n * m + 4 * i] = dxp_dxc; // + dyp_dxc; + cut_grad[4 * n * m + 4 * i + 1] = dyp_dxc; + cut_grad[4 * n * m + 4 * i + 2] = dxp_dyc; // + dyp_dyc; + cut_grad[4 * n * m + 4 * i + 3] = dyp_dyc; + cut_grad[4 * n * m + 0] = dxp_dxd; // + dyp_dxd; + cut_grad[4 * n * m + 1] = dyp_dxd; + cut_grad[4 * n * m + 2] = dxp_dyd; // + dyp_dyd; + cut_grad[4 * n * m + 3] = dyp_dyd; + } else { + cut_grad[4 * n * m + 4 * i] = dxp_dxc; // + dyp_dxc; + cut_grad[4 * n * m + 4 * i + 1] = dyp_dxc; + cut_grad[4 * n * m + 4 * i + 2] = dxp_dyc; // + dyp_dyc; + cut_grad[4 * n * m + 4 * i + 3] = dyp_dyc; + cut_grad[4 * n * m + 4 * (i + 1)] = dxp_dxd; // + dyp_dxd; + cut_grad[4 * n * m + 4 * (i + 1) + 1] = dyp_dxd; + cut_grad[4 * n * m + 4 * (i + 1) + 2] = dxp_dyd; // + dyp_dyd; + cut_grad[4 * n * m + 4 * (i + 1) + 3] = dyp_dyd; + } + + return 1; +} +__device__ inline void polygon_cut(Point* p, int& n, Point a, Point b, + double* cut_grad) { + Point pp[MAXN]; + double ccur_grad[MAXN] = {}; + int m = 0; + p[n] = p[0]; + int k = n; + for (int i = 0; i < n; i++) { + if (sig(cross(a, b, p[i])) > 0) { + pp[m] = p[i]; + ccur_grad[4 * n * m + 4 * i] = 1.0; + ccur_grad[4 * n * m + 4 * i + 3] = 1.0; + m++; + } + if (sig(cross(a, b, p[i])) != sig(cross(a, b, p[i + 1]))) { + lineCross(a, b, p[i], p[i + 1], pp[m], ccur_grad, m, n, i); + m++; + } + } + + n = 0; + for (int i = 0; i < m; i++) { + if (!i || !(point_same(pp[i], pp[i - 1]))) { + p[n] = pp[i]; + for (int j = 0; j < 4 * k; j++) { + cut_grad[4 * k * n + j] = ccur_grad[4 * k * i + j]; + } + n++; + } + } + + while (n > 1 && point_same(p[n - 1], p[0])) n--; +} + +__device__ inline double intersectArea(Point a, Point b, Point c, Point d, + double* grad_AB, int order, + int convex_n) { + Point o(0, 0); + int res_flag = 0; + int s1 = sig(cross(o, a, b)); + int s2 = sig(cross(o, c, d)); + if (s1 == 0 || s2 == 0) return 0.0; + if (s1 == -1) { + Point* i = &a; + Point* j = &b; + swap1(i, j); + res_flag = 1; + } + if (s2 == -1) { + Point* i = &c; + Point* j = &d; + swap1(i, j); + } + Point p[10] = {o, a, b}; + int n = 3, n0 = 3, n1, n2, n3; + double cut_grad1[MAXN] = {}; + double cut_grad2[MAXN] = {}; + double cut_grad3[MAXN] = {}; + double p1_p_grad[10][10] = {}; + double p2_p1_grad[10][10] = {}; + double p3_p2_grad[10][10] = {}; + + double p3_p1_grad[10][10] = {}; + double p3_p_grad[10][10] = {}; + + // 1 + polygon_cut(p, n, o, c, cut_grad1); + n1 = n; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 4 * n0; j++) { + if (!(j % 2)) { + p1_p_grad[2 * i][j / 2] = cut_grad1[4 * n0 * i + j]; + } else { + p1_p_grad[2 * i + 1][j / 2] = cut_grad1[4 * n0 * i + j]; + } + } + } + + // 2 + polygon_cut(p, n, c, d, cut_grad2); + n2 = n; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 4 * n1; j++) { + if (!(j % 2)) { + p2_p1_grad[2 * i][j / 2] = cut_grad2[4 * n1 * i + j]; + } else { + p2_p1_grad[2 * i + 1][j / 2] = cut_grad2[4 * n1 * i + j]; + } + } + } + // 3 + polygon_cut(p, 
n, d, o, cut_grad3); + n3 = n; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 4 * n2; j++) { + if (!(j % 2)) { + p3_p2_grad[2 * i][j / 2] = cut_grad3[4 * n2 * i + j]; + } else { + p3_p2_grad[2 * i + 1][j / 2] = cut_grad3[4 * n2 * i + j]; + } + } + } + + // mul + // p3_p2(n3 * n2) * p2_p1(n2 * n1) = p3_p1 (n3 * n1) + for (int i = 0; i < 2 * n3; i++) { + for (int j = 0; j < 2 * n1; j++) { + double sum = 0.0; + for (int m = 0; m < 2 * n2; m++) { + sum = sum + p3_p2_grad[i][m] * p2_p1_grad[m][j]; + } + p3_p1_grad[i][j] = sum; + } + } + + // p3_p1 (n3 * n1) * p1_p (n1 * n0) = p3_p (n3 * n0) + for (int i = 0; i < 2 * n3; i++) { + for (int j = 0; j < 2 * n0; j++) { + double sum = 0.0; + for (int m = 0; m < 2 * n1; m++) { + sum = sum + p3_p1_grad[i][m] * p1_p_grad[m][j]; + } + p3_p_grad[i][j] = sum; + } + } + + // calculate S_grad + int polygon_index_box_index[20]; + double grad_polygon[20]; + double S_grad[6]; + + for (int i = 0; i < n3; i++) { + polygon_index_box_index[i] = i; + polygon_index_box_index[i + n3] = i; + } + + double res = + polygon_area_grad(p, n3, polygon_index_box_index, n3, grad_polygon); + + if (s1 * s2 == -1) { + for (int j = 0; j < 2 * 3; j++) { + double sum = 0.0; + for (int m = 0; m < 2 * n3; m++) { + sum = sum - grad_polygon[m] * p3_p_grad[m][j]; + } + S_grad[j] = sum; + } + + if (order != convex_n - 1) { + if (res_flag) { + grad_AB[2 * order] += S_grad[4]; + grad_AB[2 * order + 1] += S_grad[5]; + grad_AB[2 * order + 2] += S_grad[2]; + grad_AB[2 * order + 3] += S_grad[3]; + + } else { + grad_AB[2 * order] += S_grad[2]; + grad_AB[2 * order + 1] += S_grad[3]; + grad_AB[2 * order + 2] += S_grad[4]; + grad_AB[2 * order + 3] += S_grad[5]; + } + } else { + if (res_flag) { + grad_AB[2 * order] += S_grad[4]; + grad_AB[2 * order + 1] += S_grad[5]; + grad_AB[0] += S_grad[2]; + grad_AB[1] += S_grad[3]; + + } else { + grad_AB[2 * order] += S_grad[2]; + grad_AB[2 * order + 1] += S_grad[3]; + grad_AB[0] += S_grad[4]; + grad_AB[1] += S_grad[5]; + } + } + res = -res; + } else { + for (int j = 0; j < 2 * 3; j++) { + double sum = 0.0; + for (int m = 0; m < 2 * n3; m++) { + sum = sum + grad_polygon[m] * p3_p_grad[m][j]; + } + S_grad[j] = sum; + } + + if (order != convex_n - 1) { + if (res_flag) { + grad_AB[2 * order] += S_grad[4]; + grad_AB[2 * order + 1] += S_grad[5]; + grad_AB[2 * order + 2] += S_grad[2]; + grad_AB[2 * order + 3] += S_grad[3]; + } else { + grad_AB[2 * order] += S_grad[2]; + grad_AB[2 * order + 1] += S_grad[3]; + grad_AB[2 * order + 2] += S_grad[4]; + grad_AB[2 * order + 3] += S_grad[5]; + } + } else { + if (res_flag) { + grad_AB[2 * order] += S_grad[4]; + grad_AB[2 * order + 1] += S_grad[5]; + grad_AB[0] += S_grad[2]; + grad_AB[1] += S_grad[3]; + } else { + grad_AB[2 * order] += S_grad[2]; + grad_AB[2 * order + 1] += S_grad[3]; + grad_AB[0] += S_grad[4]; + grad_AB[1] += S_grad[5]; + } + } + } + return res; +} + +__device__ inline double intersectAreaO(Point* ps1, int n1, Point* ps2, int n2, + double* grad_AB) { + if (area(ps1, n1) < 0) reverse1(ps1, n1); + if (area(ps2, n2) < 0) reverse1(ps2, n2); + ps1[n1] = ps1[0]; + ps2[n2] = ps2[0]; + double res = 0; + for (int i = 0; i < n1; i++) { + for (int j = 0; j < n2; j++) { + res += + intersectArea(ps1[i], ps1[i + 1], ps2[j], ps2[j + 1], grad_AB, i, n1); + } + } + return res; +} + +__device__ inline void Jarvis(Point* in_poly, int& n_poly) { + Point p_max, p_k; + int max_index, k_index; + int Stack[NMAX] = {}, top1, top2; + double sign; + Point right_point[10], left_point[10]; + + for (int i = 0; i < n_poly; i++) { + if 
(in_poly[i].y < in_poly[0].y || + in_poly[i].y == in_poly[0].y && in_poly[i].x < in_poly[0].x) { + Point* j = &(in_poly[0]); + Point* k = &(in_poly[i]); + swap1(j, k); + } + if (i == 0) { + p_max = in_poly[0]; + max_index = 0; + } + if (in_poly[i].y > p_max.y || + in_poly[i].y == p_max.y && in_poly[i].x > p_max.x) { + p_max = in_poly[i]; + max_index = i; + } + } + + if (max_index == 0) { + max_index = 1; + p_max = in_poly[max_index]; + } + + k_index = 0, Stack[0] = 0, top1 = 0; + while (k_index != max_index) { + p_k = p_max; + k_index = max_index; + for (int i = 1; i < n_poly; i++) { + sign = cross(in_poly[Stack[top1]], in_poly[i], p_k); + if ((sign > 0) || ((sign == 0) && (dis(in_poly[Stack[top1]], in_poly[i]) > + dis(in_poly[Stack[top1]], p_k)))) { + p_k = in_poly[i]; + k_index = i; + } + } + top1++; + Stack[top1] = k_index; + } + for (int i = 0; i <= top1; i++) right_point[i] = in_poly[Stack[i]]; + + k_index = 0, Stack[0] = 0, top2 = 0; + + while (k_index != max_index) { + p_k = p_max; + k_index = max_index; + for (int i = 1; i < n_poly; i++) { + sign = cross(in_poly[Stack[top2]], in_poly[i], p_k); + if ((sign < 0) || (sign == 0) && (dis(in_poly[Stack[top2]], in_poly[i]) > + dis(in_poly[Stack[top2]], p_k))) { + p_k = in_poly[i]; + k_index = i; + } + } + top2++; + Stack[top2] = k_index; + } + for (int i = top2 - 1; i >= 0; i--) left_point[i] = in_poly[Stack[i]]; + + for (int i = 0; i < top1 + top2; i++) { + if (i <= top1) { + in_poly[i] = right_point[i]; + } else { + in_poly[i] = left_point[top2 - (i - top1)]; + } + } + n_poly = top1 + top2; +} + +__device__ inline double intersectAreaPoly(Point* ps1, int n1, Point* ps2, + int n2, double* grad_C) { + Point polygon[MAXN]; + int n = n1 + n2, n_poly = 0; + for (int i = 0; i < n1; i++) { + for (int j = 0; j < n - n1; j++) { + if (point_same(ps1[i], ps2[j])) { + for (int k = j; k < n - n1 - 1; k++) { + ps2[k] = ps2[k + 1]; + } + n2--; + break; + } + } + } + n_poly = n1 + n2; + for (int i = 0; i < n_poly; i++) { + if (i < n1) { + polygon[i] = ps1[i]; + } else { + polygon[i] = ps2[i - n1]; + } + } + + Jarvis(polygon, n_poly); + + int polygon_to_pred_index[18] = {-1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1}; + int n_pred = 0; + for (int i = 0; i < n_poly; i++) { + for (int j = 0; j < n1; j++) { + if (polygon[i].x == ps1[j].x && polygon[i].y == ps1[j].y) { + polygon_to_pred_index[n_pred] = i; + polygon_to_pred_index[n_pred + n1] = j; + n_pred += 1; + break; + } + } + } + if (n_pred == 0) { + double polygon_area = fabs(area(polygon, n_poly)); + for (int i = 0; i < 18; i++) { + grad_C[i] = 0.0; + } + return polygon_area; + } else { + double polygon_area = + polygon_area_grad(polygon, n_poly, polygon_to_pred_index, n1, grad_C); + if (polygon_area < 0) { + for (int i = 0; i < 18; i++) { + grad_C[i] = -grad_C[i]; + } + } + return fabs(polygon_area); + } +} + +// convex_find and get the polygon_index_box_index +__device__ inline void Jarvis_and_index(Point* in_poly, int& n_poly, + int* points_to_convex_ind) { + int n_input = n_poly; + Point input_poly[20]; + for (int i = 0; i < n_input; i++) { + input_poly[i].x = in_poly[i].x; + input_poly[i].y = in_poly[i].y; + } + Point p_max, p_k; + int max_index, k_index; + int Stack[20], top1, top2; + double sign; + Point right_point[10], left_point[10]; + + for (int i = 0; i < n_poly; i++) { + if (in_poly[i].y < in_poly[0].y || + in_poly[i].y == in_poly[0].y && in_poly[i].x < in_poly[0].x) { + Point* j = &(in_poly[0]); + Point* k = &(in_poly[i]); + swap1(j, k); + } + if (i == 0) { + 
p_max = in_poly[0]; + max_index = 0; + } + if (in_poly[i].y > p_max.y || + in_poly[i].y == p_max.y && in_poly[i].x > p_max.x) { + p_max = in_poly[i]; + max_index = i; + } + } + if (max_index == 0) { + max_index = 1; + p_max = in_poly[max_index]; + } + + k_index = 0, Stack[0] = 0, top1 = 0; + while (k_index != max_index) { + p_k = p_max; + k_index = max_index; + for (int i = 1; i < n_poly; i++) { + sign = cross(in_poly[Stack[top1]], in_poly[i], p_k); + if ((sign > 0) || ((sign == 0) && (dis(in_poly[Stack[top1]], in_poly[i]) > + dis(in_poly[Stack[top1]], p_k)))) { + p_k = in_poly[i]; + k_index = i; + } + } + top1++; + Stack[top1] = k_index; + } + for (int i = 0; i <= top1; i++) { + right_point[i] = in_poly[Stack[i]]; + } + + k_index = 0, Stack[0] = 0, top2 = 0; + + while (k_index != max_index) { + p_k = p_max; + k_index = max_index; + for (int i = 1; i < n_poly; i++) { + sign = cross(in_poly[Stack[top2]], in_poly[i], p_k); + if ((sign < 0) || (sign == 0) && (dis(in_poly[Stack[top2]], in_poly[i]) > + dis(in_poly[Stack[top2]], p_k))) { + p_k = in_poly[i]; + k_index = i; + } + } + top2++; + Stack[top2] = k_index; + } + + for (int i = top2 - 1; i >= 0; i--) { + left_point[i] = in_poly[Stack[i]]; + } + + for (int i = 0; i < top1 + top2; i++) { + if (i <= top1) { + in_poly[i] = right_point[i]; + } else { + in_poly[i] = left_point[top2 - (i - top1)]; + } + } + n_poly = top1 + top2; + for (int i = 0; i < n_poly; i++) { + for (int j = 0; j < n_input; j++) { + if (point_same(in_poly[i], input_poly[j])) { + points_to_convex_ind[i] = j; + break; + } + } + } +} + +template +__device__ inline float devrIoU(T const* const p, T const* const q, + T* point_grad, const int idx) { + Point ps1[MAXN], ps2[MAXN]; + + Point convex[MAXN]; + for (int i = 0; i < 9; i++) { + convex[i].x = (double)p[i * 2]; + convex[i].y = (double)p[i * 2 + 1]; + } + int n_convex = 9; + int points_to_convex_ind[9] = {-1, -1, -1, -1, -1, -1, -1, -1, -1}; + Jarvis_and_index(convex, n_convex, points_to_convex_ind); + + int n1 = n_convex; + int n2 = 4; + + for (int i = 0; i < n1; i++) { + ps1[i].x = (double)convex[i].x; + ps1[i].y = (double)convex[i].y; + } + + for (int i = 0; i < n2; i++) { + ps2[i].x = (double)q[i * 2]; + ps2[i].y = (double)q[i * 2 + 1]; + } + + int polygon_index_box_index[18]; + for (int i = 0; i < n1; i++) { + polygon_index_box_index[i] = i; + polygon_index_box_index[i + n1] = i; + } + + double grad_A[18] = {}; + double grad_AB[18] = {}; + double grad_C[18] = {}; + + double inter_area = intersectAreaO(ps1, n1, ps2, n2, grad_AB); + double S_pred = + polygon_area_grad(ps1, n1, polygon_index_box_index, n1, grad_A); + if (S_pred < 0) { + for (int i = 0; i < n_convex * 2; i++) { + grad_A[i] = -grad_A[i]; + } + } + double union_area = fabs(S_pred) + fabs(area(ps2, n2)) - inter_area; + + double iou = inter_area / union_area; + double polygon_area = intersectAreaPoly(ps1, n1, ps2, n2, grad_C); + + // printf("%d:live\n", idx); + double rot_giou = iou - (polygon_area - union_area) / polygon_area; + + float grad_point_temp[18] = {}; + + for (int i = 0; i < n_convex; i++) { + int grad_point = points_to_convex_ind[i]; + grad_point_temp[2 * grad_point] = + (float)((union_area + inter_area) / (union_area * union_area) * + grad_AB[2 * i] - + iou / union_area * grad_A[2 * i] - + 1 / polygon_area * (grad_AB[2 * i] - grad_A[2 * i]) - + (union_area) / polygon_area / polygon_area * grad_C[2 * i]); + grad_point_temp[2 * grad_point + 1] = + (float)((union_area + inter_area) / (union_area * union_area) * + grad_AB[2 * i + 1] - + iou / 
union_area * grad_A[2 * i + 1] - + 1 / polygon_area * (grad_AB[2 * i + 1] - grad_A[2 * i + 1]) - + (union_area) / polygon_area / polygon_area * grad_C[2 * i + 1]); + } + + for (int i = 0; i < 9; i++) { + point_grad[2 * i] = grad_point_temp[2 * i]; + point_grad[2 * i + 1] = grad_point_temp[2 * i + 1]; + } + return (float)rot_giou; +} + +template +__global__ void convex_giou_musa_kernel(const int ex_n_boxes, + const int gt_n_boxes, const T* ex_boxes, + const T* gt_boxes, T* point_grad) { + MUSA_1D_KERNEL_LOOP(index, ex_n_boxes) { + const T* cur_box = ex_boxes + index * 18; + const T* cur_gt_box = gt_boxes + index * 8; + T* cur_grad = point_grad + index * 19; + T giou = devrIoU(cur_box, cur_gt_box, cur_grad, threadIdx.x); + cur_grad[18] = giou; + } +} + +__device__ inline int lineCross(Point a, Point b, Point c, Point d, Point& p) { + double s1, s2; + s1 = cross(a, b, c); + s2 = cross(a, b, d); + if (sig(s1) == 0 && sig(s2) == 0) return 2; + if (sig(s2 - s1) == 0) return 0; + p.x = (c.x * s2 - d.x * s1) / (s2 - s1); + p.y = (c.y * s2 - d.y * s1) / (s2 - s1); + return 1; +} + +__device__ inline void polygon_cut(Point* p, int& n, Point a, Point b) { + Point pp[MAXN]; + int m = 0; + p[n] = p[0]; + for (int i = 0; i < n; i++) { + if (sig(cross(a, b, p[i])) > 0) { + pp[m] = p[i]; + m++; + } + if (sig(cross(a, b, p[i])) != sig(cross(a, b, p[i + 1]))) { + lineCross(a, b, p[i], p[i + 1], pp[m]); + m++; + } + } + n = 0; + for (int i = 0; i < m; i++) { + if (!i || !(point_same(pp[i], pp[i - 1]))) { + p[n] = pp[i]; + n++; + } + } + + while (n > 1 && point_same(p[n - 1], p[0])) n--; +} + +__device__ inline double intersectArea(Point a, Point b, Point c, Point d) { + Point o(0, 0); + int s1 = sig(cross(o, a, b)); + int s2 = sig(cross(o, c, d)); + if (s1 == 0 || s2 == 0) return 0.0; + if (s1 == -1) { + Point* i = &a; + Point* j = &b; + swap1(i, j); + } + if (s2 == -1) { + Point* i = &c; + Point* j = &d; + swap1(i, j); + } + Point p[10] = {o, a, b}; + int n = 3; + + polygon_cut(p, n, o, c); + polygon_cut(p, n, c, d); + polygon_cut(p, n, d, o); + double res = area(p, n); + if (s1 * s2 == -1) res = -res; + return res; +} +__device__ inline double intersectAreaO(Point* ps1, int n1, Point* ps2, + int n2) { + if (area(ps1, n1) < 0) reverse1(ps1, n1); + if (area(ps2, n2) < 0) reverse1(ps2, n2); + ps1[n1] = ps1[0]; + ps2[n2] = ps2[0]; + double res = 0; + for (int i = 0; i < n1; i++) { + for (int j = 0; j < n2; j++) { + res += intersectArea(ps1[i], ps1[i + 1], ps2[j], ps2[j + 1]); + } + } + return res; +} + +template +__device__ inline float devrIoU(T const* const p, T const* const q) { + Point ps1[MAXN], ps2[MAXN]; + Point convex[MAXN]; + for (int i = 0; i < 9; i++) { + convex[i].x = (double)p[i * 2]; + convex[i].y = (double)p[i * 2 + 1]; + } + int n_convex = 9; + int points_to_convex_ind[9] = {-1, -1, -1, -1, -1, -1, -1, -1, -1}; + Jarvis_and_index(convex, n_convex, points_to_convex_ind); + int n1 = n_convex; + for (int i = 0; i < n1; i++) { + ps1[i].x = (double)convex[i].x; + ps1[i].y = (double)convex[i].y; + } + int n2 = 4; + for (int i = 0; i < n2; i++) { + ps2[i].x = (double)q[i * 2]; + ps2[i].y = (double)q[i * 2 + 1]; + } + double inter_area = intersectAreaO(ps1, n1, ps2, n2); + double S_pred = area(ps1, n1); + double union_area = fabs(S_pred) + fabs(area(ps2, n2)) - inter_area; + double iou = inter_area / union_area; + return (float)iou; +} + +template +__global__ void convex_iou_musa_kernel(const int ex_n_boxes, + const int gt_n_boxes, const T* ex_boxes, + const T* gt_boxes, T* iou) { + 
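The gradient block above is the expanded derivative of rot_giou = I/U - (C - U)/C, where I is the intersection area (grad_AB), A is the predicted-polygon area (grad_A, with U = |A| + |A_gt| - I) and C is the area of the enclosing convex hull (grad_C). A small stand-alone check of that chain rule; the helper name and the numbers are made up, only the algebra mirrors the kernel:

#include <cstdio>

// d(GIoU)/dx from the areas I (intersection), U (union), C (enclosing hull)
// and their partials w.r.t. one predicted coordinate. dU = dA - dI because the
// ground-truth area does not depend on the predicted points.
static double giou_grad(double I, double U, double C,
                        double dI, double dA, double dC) {
  double dU = dA - dI;
  double d_iou = (dI * U - I * dU) / (U * U);        // d(I/U)
  double d_penalty = (U * dC - C * dU) / (C * C);    // d((C - U)/C) = d(1 - U/C)
  return d_iou - d_penalty;
}

int main() {
  double I = 2, U = 5, C = 9, dI = 0.3, dA = 0.7, dC = 1.1;
  // Expanded form, as written in the kernel's grad_point_temp update.
  double expanded = (U + I) / (U * U) * dI - I / (U * U) * dA
                    - (dI - dA) / C - U * dC / (C * C);
  printf("chain rule: %.6f  expanded: %.6f\n", giou_grad(I, U, C, dI, dA, dC), expanded);
}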
MUSA_1D_KERNEL_LOOP(index, ex_n_boxes) { + const T* cur_box = ex_boxes + index * 18; + for (int i = 0; i < gt_n_boxes; i++) { + iou[index * gt_n_boxes + i] = devrIoU(cur_box, gt_boxes + i * 8); + } + } +} +#endif // CONVEX_IOU_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/correlation_musa.muh b/mmcv/ops/csrc/common/musa/correlation_musa.muh new file mode 100644 index 000000000..f5714cbe6 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/correlation_musa.muh @@ -0,0 +1,227 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_cuda_kernel.cu +// Original licence: Under MIT License + +#ifndef CORRELATION_MUSA +#define CORRELATION_MUSA + +#include "pytorch_musa_helper.hpp" + +#include +#include +// Using is recommended in the official documentation in +// https://pytorch.org/tutorials/advanced/cpp_extension.html#writing-the-c-op. +// However, we use for compatibility with MUSA 9.0 +// Read https://github.com/pytorch/extension-cpp/issues/35 for more details. +#include + +#include +#include + +using namespace torch; + +#define TensorAcc4R PackedTensorAccessor32 +#define TensorAcc5R PackedTensorAccessor32 +#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W) + +#define WARP_SIZE 32 +#define FULL_MASK 0xffffffff + +template +__global__ void correlation_forward_musa_kernel( + const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output, + int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH, + int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW, + int oH, int oW) { + const int iH = rInput1.size(1); + const int iW = rInput1.size(2); + const int C = rInput1.size(3); + + const int n = blockIdx.x; + const int h = blockIdx.y * blockDim.y + threadIdx.y; + const int w = blockIdx.z * blockDim.z + threadIdx.z; + + if (h >= oH || w >= oW) return; + + const int thread = threadIdx.x; + + const int start_i = -padH + h * dH; + const int start_j = -padW + w * dW; + + const int patchRadH = dilation_patchH * (patchH - 1) / 2; + const int patchRadW = dilation_patchW * (patchW - 1) / 2; + + for (int ph = 0; ph < patchH; ++ph) { + int ph_dilated = ph * dilation_patchH - patchRadH; + for (int pw = 0; pw < patchW; ++pw) { + int pw_dilated = pw * dilation_patchW - patchRadW; + scalar_t prod_sum = 0.0f; + for (int i = 0; i < kH; ++i) { + int i1 = start_i + i * dilationH; + int i2 = i1 + ph_dilated; + if (WITHIN_BOUNDS(i1, i2, iH, iH)) { + for (int j = 0; j < kW; ++j) { + int j1 = start_j + j * dilationW; + int j2 = j1 + pw_dilated; + if (WITHIN_BOUNDS(j1, j2, iW, iW)) { + for (int c = thread; c < C; c += WARP_SIZE) { + scalar_t v1 = rInput1[n][i1][j1][c]; + scalar_t v2 = rInput2[n][i2][j2][c]; + prod_sum += v1 * v2; + } + } + } + } + } + // accumulate + for (int offset = 16; offset > 0; offset /= 2) +#ifdef MMCV_WITH_HIP + prod_sum += __shfl_down(float(prod_sum), offset); +#else + prod_sum += __shfl_down_sync(FULL_MASK, float(prod_sum), offset); +#endif + if (thread == 0) { + output[n][ph][pw][h][w] = prod_sum; + } + } + } +} + +template +__global__ void correlation_backward_musa_kernel_input1( + const TensorAcc5R grad_output, const TensorAcc4R input2, + TensorAcc4R grad_input1, const int kH, const int kW, const int patchH, + const int patchW, const int padH, const int padW, const int dilationH, + const int dilationW, const int dilation_patchH, const int dilation_patchW, + const int dH, const int dW) { + const int iH = 
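In the correlation forward kernel each of the 32 lanes accumulates a partial product over the channels it strides across (c += WARP_SIZE), and the partials are folded together with shfl_down at offsets 16, 8, 4, 2, 1 so that lane 0 holds the complete sum before the single write to output. A plain-C++ model of that halving reduction, with no warp intrinsics and purely for illustration:

#include <cstdio>

int main() {
  // One value per "lane"; in the kernel each lane holds the partial sum of the
  // channels it visited with stride WARP_SIZE.
  float lane[32];
  for (int i = 0; i < 32; ++i) lane[i] = 1.0f;  // 32 partial sums of 1.0

  // Same schedule as the __shfl_down_sync loop with offsets 16, 8, 4, 2, 1:
  // each step adds the value held "offset" lanes to the right. Reading only
  // higher-indexed lanes keeps this sequential loop equivalent to the
  // simultaneous warp shuffle.
  for (int offset = 16; offset > 0; offset /= 2)
    for (int i = 0; i + offset < 32; ++i)
      lane[i] += lane[i + offset];

  printf("lane 0 holds the warp-wide sum: %g\n", lane[0]);  // prints 32
}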
input2.size(1); + const int iW = input2.size(2); + const int C = input2.size(3); + + const int H = grad_output.size(3); + const int W = grad_output.size(4); + + const int patchRadH = (patchH - 1) / 2; + const int patchRadW = (patchW - 1) / 2; + + const int n = blockIdx.x; + const int h = blockIdx.y; + const int w = blockIdx.z; + + const int h_2 = h + padH; + const int w_2 = w + padW; + const int min_h = h_2 - kH * dilationH; + const int min_w = w_2 - kW * dilationW; + + extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[]; + scalar_t *grad_cache = reinterpret_cast(grad_cache_char); + for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) { + const int ph = i / patchW; + const int pw = i % patchW; + int i1 = h + dilation_patchH * (ph - patchRadH); + int j1 = w + dilation_patchW * (pw - patchRadW); + + if (WITHIN_BOUNDS(i1, j1, iH, iW)) { + scalar_t grad_val = 0.0f; + for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) { + int i2 = (h_3) / dH; + if (i2 * dH != h_3) continue; + for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) { + int j2 = (w_3) / dW; + if (j2 * dW != w_3) continue; + if (WITHIN_BOUNDS(i2, j2, H, W)) { + grad_val += grad_output[n][ph][pw][i2][j2]; + } + } + } + grad_cache[i] = grad_val; + } + } + __syncthreads(); + + for (int c = threadIdx.x; c < C; c += blockDim.x) { + scalar_t grad_input_val = 0.0f; + for (int ph = 0; ph < patchH; ++ph) { + int i1 = h + dilation_patchH * (ph - patchRadH); + for (int pw = 0; pw < patchW; ++pw) { + int j1 = w + dilation_patchW * (pw - patchRadW); + if (WITHIN_BOUNDS(i1, j1, iH, iW)) { + grad_input_val += input2[n][i1][j1][c] * grad_cache[ph * patchW + pw]; + } + } + } + grad_input1[n][c][h][w] = grad_input_val; + } +} + +template +__global__ void correlation_backward_musa_kernel_input2( + const TensorAcc5R grad_output, const TensorAcc4R input1, + TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW) { + const int iH = input1.size(1); + const int iW = input1.size(2); + const int C = input1.size(3); + + const int patchRadH = (patchH - 1) / 2; + const int patchRadW = (patchW - 1) / 2; + + const int H = grad_output.size(3); + const int W = grad_output.size(4); + + const int dilatedKH = kH * dilationH; + const int dilatedKW = kW * dilationW; + + const int n = blockIdx.x; + const int h = blockIdx.y; + const int w = blockIdx.z; + + extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[]; + scalar_t *grad_cache = reinterpret_cast(grad_cache_char); + for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) { + const int ph = i / patchW; + const int pw = i % patchW; + int i1 = h - dilation_patchH * (ph - patchRadH); + int j1 = w - dilation_patchW * (pw - patchRadW); + + if (WITHIN_BOUNDS(i1, j1, iH, iW)) { + scalar_t grad_val = 0.0f; + + const int h_2 = i1 + padH; + const int w_2 = j1 + padW; + const int min_h = h_2 - dilatedKH; + const int min_w = w_2 - dilatedKW; + + for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) { + int i2 = (h_3) / dH; + if (i2 * dH != h_3) continue; + for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) { + int j2 = (w_3) / dW; + if (j2 * dW != w_3) continue; + if (WITHIN_BOUNDS(i2, j2, H, W)) { + grad_val += grad_output[n][ph][pw][i2][j2]; + } + } + } + grad_cache[i] = grad_val; + } + } + __syncthreads(); + + for (int c = threadIdx.x; c < C; c += blockDim.x) { + scalar_t grad_input_val = 0.0f; + for (int ph = 0; ph < patchH; ++ph) { + int i1 = h - 
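The backward kernels invert the forward mapping: a padded input row h_2 = h + padH can only have been read while producing output row i2 if h_2 minus some multiple of the dilation is a non-negative multiple of the stride, which is exactly what the "i2 * dH != h_3" test filters out. A small sketch of that enumeration for one spatial axis; all values below are made up:

#include <cstdio>

int main() {
  // Which output rows does input row h feed, for one spatial dimension?
  // Mirrors the backward loop: candidates h_3 = (h + padH) - step * dilationH,
  // kept only when h_3 is a non-negative multiple of the stride dH.
  const int h = 7, padH = 1, dilationH = 2, dH = 2, kH = 3, oH = 6;
  const int h_2 = h + padH;
  for (int step = 0; step < kH; ++step) {
    int h_3 = h_2 - step * dilationH;
    if (h_3 < 0) break;
    int i2 = h_3 / dH;
    if (i2 * dH != h_3 || i2 >= oH) continue;  // not a valid output row
    printf("kernel step %d: input row %d contributes to output row %d\n", step, h, i2);
  }
}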
dilation_patchH * (ph - patchRadH); + for (int pw = 0; pw < patchW; ++pw) { + int j1 = w - dilation_patchW * (pw - patchRadW); + if (WITHIN_BOUNDS(i1, j1, iH, iW)) { + grad_input_val += input1[n][i1][j1][c] * grad_cache[ph * patchW + pw]; + } + } + } + grad_input2[n][c][h][w] = grad_input_val; + } +} +#endif diff --git a/mmcv/ops/csrc/common/musa/deform_conv_musa_kernel.muh b/mmcv/ops/csrc/common/musa/deform_conv_musa_kernel.muh new file mode 100644 index 000000000..c636eaaa7 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/deform_conv_musa_kernel.muh @@ -0,0 +1,360 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer + ***************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer + ********************* + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.muh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. 
+ * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#ifndef DEFORM_CONV_MUSA_KERNEL_MUH +#define DEFORM_CONV_MUSA_KERNEL_MUH + +#include +#include "pytorch_musa_helper.hpp" + + +template +__device__ T deformable_im2col_bilinear(const T *input, const int data_width, + const int height, const int width, T h, + T w) { + if (h <= -1 || height <= h || w <= -1 || width <= w) { + return 0; + } + + int h_low = floorf(h); + int w_low = floorf(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh, hw = 1 - lw; + + T v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low]; + T v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = input[h_low * data_width + w_high]; + T v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = input[h_high * data_width + w_low]; + T v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = input[h_high * data_width + w_high]; + + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ T get_gradient_weight(T argmax_h, T argmax_w, const int h, + const int w, const int height, + const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floorf(argmax_h); + int argmax_w_low = floorf(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ T get_coordinate_weight(T argmax_h, T argmax_w, const int height, + const int width, const T *im_data, + const int data_width, const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floorf(argmax_h); + int argmax_w_low = floorf(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + + if (bp_dir == 0) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } else if (bp_dir == 1) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * + 
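deformable_im2col_bilinear gathers the four integer neighbours of a fractional sampling location, blends them with the opposite fractional areas, and returns 0 for neighbours that fall outside the image. A CPU mirror of the same routine, single channel and illustrative only:

#include <cmath>
#include <cstdio>

// Sample `img` (height x width, row-major) at fractional (h, w), with zero
// padding outside the image, following the same corner tests as the kernel.
static float bilinear(const float* img, int height, int width, float h, float w) {
  if (h <= -1 || h >= height || w <= -1 || w >= width) return 0.f;
  int h_low = (int)std::floor(h), w_low = (int)std::floor(w);
  int h_high = h_low + 1, w_high = w_low + 1;
  float lh = h - h_low, lw = w - w_low, hh = 1 - lh, hw = 1 - lw;
  float v1 = 0, v2 = 0, v3 = 0, v4 = 0;
  if (h_low >= 0 && w_low >= 0)                    v1 = img[h_low * width + w_low];
  if (h_low >= 0 && w_high <= width - 1)           v2 = img[h_low * width + w_high];
  if (h_high <= height - 1 && w_low >= 0)          v3 = img[h_high * width + w_low];
  if (h_high <= height - 1 && w_high <= width - 1) v4 = img[h_high * width + w_high];
  return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4;
}

int main() {
  float img[4] = {0, 1, 2, 3};                        // 2x2 image
  printf("%g\n", bilinear(img, 2, 2, 0.5f, 0.5f));    // centre of the image -> 1.5
}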
im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel( + const int n, const T *data_im, const T *data_offset, const int height, + const int width, const int kernel_h, const int kernel_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int num_channels, const int deformable_group, const int height_col, + const int width_col, T *data_col) { + MUSA_1D_KERNEL_LOOP(index, n) { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + T *data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T *data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const T *data_offset_ptr = + data_offset + (b_col * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T val = static_cast(0); + const T h_im = h_in + i * dilation_h + offset_h; + const T w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, + h_im, w_im); + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const T *data_col, const T *data_offset, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int deformable_group, const int height_col, const int width_col, + T *grad_im) { + MUSA_1D_KERNEL_LOOP(index, n) { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = + (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / 
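The pointer arithmetic in deformable_im2col_gpu_kernel flattens several layouts at once: the column matrix is laid out as (c_im * kH * kW + tap, batch, h_col, w_col), and the offset tensor stores two maps per kernel tap (dy then dx) for every deformable group, each of size height_col x width_col. A short sketch of the same index bookkeeping; the sizes and positions below are made up:

#include <cstdio>

int main() {
  const int kH = 3, kW = 3, hCol = 8, wCol = 8, B = 2;
  const int b = 1, c_im = 5, i = 2, j = 1, h = 4, w = 3;  // one tap at one output position

  // Each input channel expands into kH*kW rows of the column matrix; a column
  // is addressed by (batch, output y, output x).
  int c_col = c_im * kH * kW + i * kW + j;
  long col_index = ((long)(c_col * B + b) * hCol + h) * wCol + w;

  // Offsets, relative to the per-(batch, deformable group) base pointer:
  // channel 2*tap holds the y offset map, channel 2*tap + 1 the x offset map.
  int off_h = ((2 * (i * kW + j)) * hCol + h) * wCol + w;
  int off_w = ((2 * (i * kW + j) + 1) * hCol + h) * wCol + w;

  printf("column element %ld, offset_h at %d, offset_w at %d\n",
         col_index, off_h, off_w);
}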
height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const T *data_offset_ptr = + data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T cur_inv_h_data = h_in + i * dilation_h + offset_h; + const T cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const T cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + T weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, + cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel( + const int n, const T *data_col, const T *data_im, const T *data_offset, + const int channels, const int height, const int width, const int kernel_h, + const int kernel_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int offset_channels, const int deformable_group, const int height_col, + const int width_col, T *grad_offset) { + MUSA_1D_KERNEL_LOOP(index, n) { + T val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const T *data_col_ptr = data_col + deformable_group_index * + channel_per_deformable_group * + batch_size * width_col * height_col; + const T *data_im_ptr = + data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * + height * width; + const T *data_offset_ptr = + data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * 
height_col + h_out) * width_col + + w_out); + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T inv_h = h_in + i * dilation_h + offset_h; + T inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + inv_h = inv_w = -2; + const T weight = get_coordinate_weight(inv_h, inv_w, height, width, + data_im_ptr + cnt * height * width, + width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +#endif // DEFORM_CONV_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/deform_roi_pool_musa_kernel.muh b/mmcv/ops/csrc/common/musa/deform_roi_pool_musa_kernel.muh new file mode 100644 index 000000000..c206c729b --- /dev/null +++ b/mmcv/ops/csrc/common/musa/deform_roi_pool_musa_kernel.muh @@ -0,0 +1,181 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef DEFORM_ROI_POOL_MUSA_KERNEL_MUH +#define DEFORM_ROI_POOL_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" +template +__global__ void deform_roi_pool_forward_musa_kernel( + const int nthreads, const T* input, const T* rois, const T* offset, + T* output, const int pooled_height, const int pooled_width, + const T spatial_scale, const int sampling_ratio, const T gamma, + const int channels, const int height, const int width) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - 0.5; + T roi_start_h = offset_rois[2] * spatial_scale - 0.5; + T roi_end_w = offset_rois[3] * spatial_scale - 0.5; + T roi_end_h = offset_rois[4] * spatial_scale - 0.5; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_height / pooled_height)); + int roi_bin_grid_w = + (sampling_ratio > 0) + ? 
sampling_ratio + : static_cast(ceilf(roi_width / pooled_width)); + + // Compute roi offset + if (offset != NULL) { + const T* offset_cur_w = offset + n * pooled_width * pooled_height * 2 + + ph * pooled_width + pw; + T offset_roi_w = gamma * roi_width * offset_cur_w[0]; + T offset_roi_h = + gamma * roi_height * offset_cur_w[pooled_width * pooled_height]; + roi_start_w += offset_roi_w; + roi_start_h += offset_roi_h; + } + + // We do average pooling inside a bin + const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + output_val += val; + } + } + output[index] = output_val / count; + } +} + +template +__global__ void deform_roi_pool_backward_musa_kernel( + const int nthreads, const T* grad_output, const T* input, const T* rois, + const T* offset, T* grad_input, T* grad_offset, const int pooled_height, + const int pooled_width, const T spatial_scale, const int sampling_ratio, + const T gamma, const int channels, const int height, const int width) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + const T* offset_input = + input + ((roi_batch_ind * channels + c) * height * width); + T* offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width); + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - 0.5; + T roi_start_h = offset_rois[2] * spatial_scale - 0.5; + T roi_end_w = offset_rois[3] * spatial_scale - 0.5; + T roi_end_h = offset_rois[4] * spatial_scale - 0.5; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_height / pooled_height)); + int roi_bin_grid_w = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_width / pooled_width)); + + // Compute roi offset + if (offset != NULL) { + const T* offset_cur_w = offset + n * pooled_width * pooled_height * 2 + + ph * pooled_width + pw; + T offset_roi_w = gamma * roi_width * offset_cur_w[0]; + T offset_roi_h = + gamma * roi_height * offset_cur_w[pooled_width * pooled_height]; + roi_start_w += offset_roi_w; + roi_start_h += offset_roi_h; + } + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
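Each output bin of the deformable RoI pool averages roi_bin_grid_h x roi_bin_grid_w bilinear samples, and the learned offset for that bin shifts the whole RoI origin by gamma * roi_size * delta before sampling. A sketch of the coordinates one bin samples from; every value below is made up and the output value would be the mean of the bilinear reads at these points:

#include <cstdio>

int main() {
  // Continuous RoI convention used above: scale, shift by -0.5, no rounding.
  const float spatial_scale = 0.25f, gamma = 0.1f;
  const int pooled_h = 7, pooled_w = 7, sampling_ratio = 2;
  const float roi[5] = {0, 8, 8, 120, 72};          // batch_ind, x1, y1, x2, y2
  const int ph = 3, pw = 2;                         // which bin we are filling
  const float offset_dx = 0.2f, offset_dy = -0.1f;  // learned offsets for this bin

  float x1 = roi[1] * spatial_scale - 0.5f, y1 = roi[2] * spatial_scale - 0.5f;
  float x2 = roi[3] * spatial_scale - 0.5f, y2 = roi[4] * spatial_scale - 0.5f;
  float roi_w = x2 - x1, roi_h = y2 - y1;
  float bin_w = roi_w / pooled_w, bin_h = roi_h / pooled_h;

  // The offset shifts the RoI origin, scaled by gamma and the RoI size.
  x1 += gamma * roi_w * offset_dx;
  y1 += gamma * roi_h * offset_dy;

  for (int iy = 0; iy < sampling_ratio; ++iy)
    for (int ix = 0; ix < sampling_ratio; ++ix) {
      float y = y1 + ph * bin_h + (iy + 0.5f) * bin_h / sampling_ratio;
      float x = x1 + pw * bin_w + (ix + 0.5f) * bin_w / sampling_ratio;
      printf("sample (%d,%d) of bin (%d,%d): y=%.3f x=%.3f\n", iy, ix, ph, pw, y, x);
    }
}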
= 4 + const T grad_output_this_bin = grad_output[index] / count; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, + x_low, x_high, y_low, y_high, index); + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd(offset_grad_input + y_low * width + x_low, + grad_output_this_bin * w1); + atomicAdd(offset_grad_input + y_low * width + x_high, + grad_output_this_bin * w2); + atomicAdd(offset_grad_input + y_high * width + x_low, + grad_output_this_bin * w3); + atomicAdd(offset_grad_input + y_high * width + x_high, + grad_output_this_bin * w4); + if (offset != NULL) { + T input_00 = offset_input[y_low * width + x_low]; + T input_10 = offset_input[y_low * width + x_high]; + T input_01 = offset_input[y_high * width + x_low]; + T input_11 = offset_input[y_high * width + x_high]; + T ogx = gamma * roi_width * grad_output_this_bin * + (input_11 * (y - y_low) + input_10 * (y_high - y) + + input_01 * (y_low - y) + input_00 * (y - y_high)); + T ogy = gamma * roi_height * grad_output_this_bin * + (input_11 * (x - x_low) + input_01 * (x_high - x) + + input_10 * (x_low - x) + input_00 * (x - x_high)); + atomicAdd(grad_offset + n * pooled_width * pooled_height * 2 + + ph * pooled_width + pw, + ogx); + atomicAdd(grad_offset + n * pooled_width * pooled_height * 2 + + pooled_width * pooled_height + ph * pooled_width + pw, + ogy); + } + } + } + } + } +} + +#endif // DEFORM_ROI_POOL_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/diff_iou_rotated_musa_kernel.muh b/mmcv/ops/csrc/common/musa/diff_iou_rotated_musa_kernel.muh new file mode 100644 index 000000000..3bb7c1c0c --- /dev/null +++ b/mmcv/ops/csrc/common/musa/diff_iou_rotated_musa_kernel.muh @@ -0,0 +1,133 @@ +// Copyright (c) OpenMMLab. All rights reserved +// Adapted from +// https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa +#include "pytorch_musa_helper.hpp" + +#define MAX_NUM_VERT_IDX 9 +#define INTERSECTION_OFFSET 8 +#define EPSILON 1e-8 + +inline int opt_n_thread(int work_size) { + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + return max(min(1 << pow_2, THREADS_PER_BLOCK), 1); +} + +/* +compare normalized vertices (vertices around (0,0)) +if vertex1 < vertex2 return true. 
+order: minimum at x-aixs, become larger in anti-clockwise direction +*/ +__device__ bool compare_vertices(float x1, float y1, float x2, float y2) { + if (fabs(x1 - x2) < EPSILON && fabs(y2 - y1) < EPSILON) + return false; // if equal, return false + + if (y1 > 0 && y2 < 0) return true; + if (y1 < 0 && y2 > 0) return false; + + float n1 = x1 * x1 + y1 * y1 + EPSILON; + float n2 = x2 * x2 + y2 * y2 + EPSILON; + float diff = fabs(x1) * x1 / n1 - fabs(x2) * x2 / n2; + + if (y1 > 0 && y2 > 0) { + if (diff > EPSILON) + return true; + else + return false; + } + if (y1 < 0 && y2 < 0) { + if (diff < EPSILON) + return true; + else + return false; + } + return false; +} + +__global__ void diff_iou_rotated_sort_vertices_forward_musa_kernel( + int b, int n, int m, const float *__restrict__ vertices, + const bool *__restrict__ mask, const int *__restrict__ num_valid, + int *__restrict__ idx) { + int batch_idx = blockIdx.x; + vertices += batch_idx * n * m * 2; + mask += batch_idx * n * m; + num_valid += batch_idx * n; + idx += batch_idx * n * MAX_NUM_VERT_IDX; + + int index = threadIdx.x; // index of polygon + int stride = blockDim.x; + for (int i = index; i < n; i += stride) { + int pad; // index of arbitrary invalid intersection point (not box corner!) + for (int j = INTERSECTION_OFFSET; j < m; ++j) { + if (!mask[i * m + j]) { + pad = j; + break; + } + } + if (num_valid[i] < 3) { + // not enough vertices, take an invalid intersection point + // (zero padding) + for (int j = 0; j < MAX_NUM_VERT_IDX; ++j) { + idx[i * MAX_NUM_VERT_IDX + j] = pad; + } + } else { + // sort the valid vertices + // note the number of valid vertices is known + // note: check that num_valid[i] < MAX_NUM_VERT_IDX + for (int j = 0; j < num_valid[i]; ++j) { + // initialize with a "big" value + float x_min = 1; + float y_min = -EPSILON; + int i_take = 0; + int i2; + float x2, y2; + if (j != 0) { + i2 = idx[i * MAX_NUM_VERT_IDX + j - 1]; + x2 = vertices[i * m * 2 + i2 * 2 + 0]; + y2 = vertices[i * m * 2 + i2 * 2 + 1]; + } + for (int k = 0; k < m; ++k) { + float x = vertices[i * m * 2 + k * 2 + 0]; + float y = vertices[i * m * 2 + k * 2 + 1]; + if (mask[i * m + k] && compare_vertices(x, y, x_min, y_min)) { + if ((j == 0) || (j != 0 && compare_vertices(x2, y2, x, y))) { + x_min = x; + y_min = y; + i_take = k; + } + } + } + idx[i * MAX_NUM_VERT_IDX + j] = i_take; + } + // duplicate the first idx + idx[i * MAX_NUM_VERT_IDX + num_valid[i]] = idx[i * MAX_NUM_VERT_IDX + 0]; + + // pad zeros + for (int j = num_valid[i] + 1; j < MAX_NUM_VERT_IDX; ++j) { + idx[i * MAX_NUM_VERT_IDX + j] = pad; + } + + // for corner case: the two boxes are exactly the same. 
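compare_vertices orders the intersection points counter-clockwise without calling atan2: points in the upper half-plane come before points in the lower one, and within each half it compares sign(x) * x^2 / (x^2 + y^2), i.e. cos(theta) * |cos(theta)|, which is monotone in the polar angle there. A simplified CPU version of the same predicate (EPSILON handling dropped, points on the x axis avoided for brevity):

#include <cmath>
#include <cstdio>
#include <algorithm>

struct V { float x, y; };

// Upper half first; within a half, order by the cos*|cos| surrogate key,
// decreasing in the upper half and increasing in the lower half.
static bool ccw_before(const V& a, const V& b) {
  if (a.y > 0 && b.y < 0) return true;
  if (a.y < 0 && b.y > 0) return false;
  float ka = std::fabs(a.x) * a.x / (a.x * a.x + a.y * a.y);
  float kb = std::fabs(b.x) * b.x / (b.x * b.x + b.y * b.y);
  return (a.y > 0) ? ka > kb : ka < kb;
}

int main() {
  V v[] = {{-1, -1}, {1, 1}, {-1, 1}, {1, -1}};     // corners around the origin
  std::sort(v, v + 4, ccw_before);
  for (const V& p : v)                              // prints 45, 135, -135, -45 degrees
    printf("(%g, %g) angle %.1f deg\n", p.x, p.y,
           std::atan2(p.y, p.x) * 180.0 / 3.14159265358979);
}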
+ // in this case, idx would have duplicate elements, which makes the + // shoelace formula broken because of the definition, the duplicate + // elements only appear in the first 8 positions (they are "corners in + // box", not "intersection of edges") + if (num_valid[i] == 8) { + int counter = 0; + for (int j = 0; j < 4; ++j) { + int check = idx[i * MAX_NUM_VERT_IDX + j]; + for (int k = 4; k < INTERSECTION_OFFSET; ++k) { + if (idx[i * MAX_NUM_VERT_IDX + k] == check) counter++; + } + } + if (counter == 4) { + idx[i * MAX_NUM_VERT_IDX + 4] = idx[i * MAX_NUM_VERT_IDX + 0]; + for (int j = 5; j < MAX_NUM_VERT_IDX; ++j) { + idx[i * MAX_NUM_VERT_IDX + j] = pad; + } + } + } + + // TODO: still might need to cover some other corner cases :( + } + } +} diff --git a/mmcv/ops/csrc/pytorch/deform_conv.cpp b/mmcv/ops/csrc/pytorch/deform_conv.cpp index 86690b939..0c8958826 100644 --- a/mmcv/ops/csrc/pytorch/deform_conv.cpp +++ b/mmcv/ops/csrc/pytorch/deform_conv.cpp @@ -153,7 +153,9 @@ void deform_conv_forward(Tensor input, Tensor weight, Tensor offset, #else AT_ERROR("DeformConv is not compiled with GPU support"); #endif - } else { + } +#ifndef MMCV_WITH_MUSA + else { CHECK_CPU_INPUT(input); CHECK_CPU_INPUT(offset); CHECK_CPU_INPUT(weight); @@ -161,7 +163,7 @@ void deform_conv_forward(Tensor input, Tensor weight, Tensor offset, CHECK_CPU_INPUT(columns); CHECK_CPU_INPUT(ones); } - +#endif deform_conv_shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, dilationH, dilationW, group, deformable_group); at::DeviceGuard guard(input.device()); @@ -275,7 +277,9 @@ void deform_conv_backward_input(Tensor input, Tensor offset, Tensor gradOutput, #else AT_ERROR("DeformConv is not compiled with GPU support"); #endif - } else { + } +#ifndef MMCV_WITH_MUSA + else { CHECK_CPU_INPUT(input); CHECK_CPU_INPUT(offset); CHECK_CPU_INPUT(gradOutput); @@ -284,6 +288,7 @@ void deform_conv_backward_input(Tensor input, Tensor offset, Tensor gradOutput, CHECK_CPU_INPUT(weight); CHECK_CPU_INPUT(columns); } +#endif deform_conv_shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, dilationH, dilationW, group, deformable_group); @@ -407,7 +412,9 @@ void deform_conv_backward_parameters(Tensor input, Tensor offset, #else AT_ERROR("DeformConv is not compiled with GPU support"); #endif - } else { + } +#ifndef MMCV_WITH_MUSA + else { CHECK_CPU_INPUT(input); CHECK_CPU_INPUT(offset); CHECK_CPU_INPUT(gradOutput); @@ -415,6 +422,7 @@ void deform_conv_backward_parameters(Tensor input, Tensor offset, CHECK_CPU_INPUT(columns); CHECK_CPU_INPUT(ones); } +#endif deform_conv_shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, padW, dilationH, dilationW, group, diff --git a/mmcv/ops/csrc/pytorch/musa/bias_act_musa.mu b/mmcv/ops/csrc/pytorch/musa/bias_act_musa.mu new file mode 100644 index 000000000..16ac7122c --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/bias_act_musa.mu @@ -0,0 +1,301 @@ +// Modified from +// https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/bias_act.cpp + +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +#include +#include +#include + +#include "pytorch_musa_helper.hpp" + +struct bias_act_kernel_params { + const void *x; // [sizeX] + const void *b; // [sizeB] or NULL + const void *xref; // [sizeX] or NULL + const void *yref; // [sizeX] or NULL + const void *dy; // [sizeX] or NULL + void *y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +// MUSA kernel selection. + +template +void *choose_bias_act_kernel(const bias_act_kernel_params &p); +//------------------------------------------------------------------------ +// Helpers. + +template +struct InternalType; +template <> +struct InternalType { + typedef double scalar_t; +}; +template <> +struct InternalType { + typedef float scalar_t; +}; +template <> +struct InternalType { + typedef float scalar_t; +}; + +//------------------------------------------------------------------------ +// MUSA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) { + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; + loopIdx++, xi += blockDim.x) { + // Load. + scalar_t x = (scalar_t)((const T *)p.x)[xi]; + scalar_t b = + (p.b) ? (scalar_t)((const T *)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T *)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T *)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T *)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) { + if (G == 0) { + scalar_t c = exp(x); + scalar_t d = one / c; + y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); + } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) { + if (G == 0) + y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) + y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) { + if (G == 0) y = (x > expRange) ? 
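Throughout this kernel the derivative branches (G == 1, G == 2) are written in terms of the saved forward output yy = yref / gain instead of re-evaluating the activation, relying on identities such as tanh'(x) = 1 - tanh(x)^2 and sigmoid'(x) = s(x) * (1 - s(x)). A quick numerical check of those two identities:

#include <cmath>
#include <cstdio>

int main() {
  double x = 0.3, h = 1e-6;
  double t = std::tanh(x);
  double s = 1.0 / (1.0 + std::exp(-x));
  // Derivative from the saved output vs. a central finite difference.
  printf("tanh':    from output %.6f, numeric %.6f\n",
         1 - t * t, (std::tanh(x + h) - std::tanh(x - h)) / (2 * h));
  printf("sigmoid': from output %.6f, numeric %.6f\n",
         s * (1 - s),
         (1 / (1 + std::exp(-(x + h))) - 1 / (1 + std::exp(-(x - h)))) / (2 * h));
}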
x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { + scalar_t c = exp(-yy); + y = x * c * (one - c); + } + } + + // swish + if (A == 9) { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) + ? 0 + : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T *)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// MUSA kernel selection. + +template +void *choose_bias_act_kernel(const bias_act_kernel_params &p) { + if (p.act == 1) return (void *)bias_act_kernel; + if (p.act == 2) return (void *)bias_act_kernel; + if (p.act == 3) return (void *)bias_act_kernel; + if (p.act == 4) return (void *)bias_act_kernel; + if (p.act == 5) return (void *)bias_act_kernel; + if (p.act == 6) return (void *)bias_act_kernel; + if (p.act == 7) return (void *)bias_act_kernel; + if (p.act == 8) return (void *)bias_act_kernel; + if (p.act == 9) return (void *)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) { + if (x.dim() != y.dim()) return false; + for (int64_t i = 0; i < x.dim(); i++) { + if (x.size(i) != y.size(i)) return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) return false; + } + return true; +} + +//------------------------------------------------------------------------ +torch::Tensor bias_act_op(const torch::Tensor &x, const torch::Tensor &b, + const torch::Tensor &xref, const torch::Tensor &yref, + const torch::Tensor &dy, int grad, int dim, int act, + float alpha, float gain, float clamp) { + // Validate arguments. + TORCH_CHECK(x.is_privateuseone(), "x must reside on MUSA device"); + TORCH_CHECK( + b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), + "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || + (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && + xref.device() == x.device()), + "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || + (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && + yref.device() == x.device()), + "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK( + dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && + dy.device() == x.device()), + "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), + "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), + "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. 
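choose_bias_act_kernel bridges a runtime activation id to a compile-time template parameter by returning the address of the matching instantiation, and the host code then launches whatever pointer comes back. The same pattern in ordinary C++, with illustrative names only:

#include <cstdio>

template <int A>
void act_kernel(float x) { printf("running activation %d on %g\n", A, x); }

using kernel_fn = void (*)(float);

// Map a runtime id to the address of the matching template instantiation;
// a null result means "unknown id" and the caller is expected to check it.
kernel_fn choose_kernel(int act) {
  if (act == 1) return &act_kernel<1>;
  if (act == 2) return &act_kernel<2>;
  if (act == 3) return &act_kernel<3>;
  return nullptr;
}

int main() {
  kernel_fn k = choose_kernel(2);
  if (k) k(0.5f);
}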
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), + "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), + "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), + "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), + "dy must have the same layout as x"); + + // Create output tensor. + const at::musa::OptionalMUSAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize MUSA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose MUSA kernel. + void *kernel; + AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "upfirdn2d_musa", [&] { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no MUSA kernel found for the specified activation func"); + + // Launch MUSA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void *args[] = {&p}; +#ifdef MMCV_WITH_HIP + AT_MUSA_CHECK(hipLaunchKernel(kernel, gridSize, blockSize, args, 0, + c10::musa::getCurrentMUSAStream())); +#else + AT_MUSA_CHECK(musaLaunchKernel(kernel, gridSize, blockSize, args, 0, + c10::musa::getCurrentMUSAStream())); +#endif + + return y; +} diff --git a/mmcv/ops/csrc/pytorch/musa/border_align_musa.mu b/mmcv/ops/csrc/pytorch/musa/border_align_musa.mu new file mode 100644 index 000000000..b4e25e15d --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/border_align_musa.mu @@ -0,0 +1,68 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "border_align_musa_kernel.muh" +#include "pytorch_musa_helper.hpp" + +void BorderAlignForwardMUSAKernelLauncher(const Tensor &input, + const Tensor &boxes, Tensor output, + Tensor argmax_idx, + const int pool_size) { + // shape assertion + AT_ASSERTM(input.ndimension() == 4, + "non-empty 4D(batch mode) tensor expected for input feature"); + AT_ASSERTM(boxes.ndimension() == 3, + "boxes must be 3D tensor with size of [B, H*W, 4]"); + + int batch_size = input.size(0); + int feat_channels = input.size(1); + int channels = feat_channels / 4; + int height = input.size(2); + int width = input.size(3); + // shape [N, box_size, 4] for boxes. 
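stepB is set to x.stride(dim) so that the kernel's (xi / stepB) % sizeB expression recovers, for every flattened element index, the matching entry of the 1-D bias; for a contiguous (N, C, H, W) tensor with dim = 1 that stride is H*W. A tiny demonstration of the index math, with made-up sizes:

#include <cstdio>

int main() {
  // x has shape (N, C, H, W) = (2, 3, 4, 5), contiguous, bias applied on dim 1.
  const int C = 3, H = 4, W = 5;
  const int stepB = H * W, sizeB = C;
  const long samples[] = {0, 19, 20, 59, 60, 119};
  for (long xi : samples) {
    int channel = (int)((xi / stepB) % sizeB);   // same formula as the kernel
    printf("flat index %ld -> bias/channel index %d\n", xi, channel);
  }
}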
(x1, y1, x2, y2) format + int box_size = boxes.size(1); + // shape [N, channels, box_size, 4] for output + int nthreads = batch_size * channels * box_size; + + c10::musa::MUSAGuard device_guard(input.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + dim3 block(128, 4); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "border_align_forward_musa_kernel", [&] { + border_align_forward_musa_kernel + <<>>( + nthreads, input.data_ptr(), + boxes.data_ptr(), output.data_ptr(), + argmax_idx.data_ptr(), channels, box_size, height, width, + pool_size); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} + +void BorderAlignBackwardMUSAKernelLauncher(const Tensor &grad_output, + const Tensor &boxes, + const Tensor &argmax_idx, + Tensor grad_input, + const int pool_size) { + int batch_size = grad_input.size(0); + int feat_channels = grad_input.size(1); + int channels = feat_channels / 4; + int height = grad_input.size(2); + int width = grad_input.size(3); + int box_size = boxes.size(1); + int nthreads = batch_size * channels * box_size; + + c10::musa::MUSAGuard device_guard(grad_output.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + dim3 block(128, 4); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "border_align_backward_musa_kernel", [&] { + border_align_backward_musa_kernel + <<>>( + nthreads, grad_output.data_ptr(), + boxes.data_ptr(), argmax_idx.data_ptr(), + grad_input.data_ptr(), channels, box_size, height, + width, pool_size); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/box_iou_quadri_musa.mu b/mmcv/ops/csrc/pytorch/musa/box_iou_quadri_musa.mu new file mode 100644 index 000000000..48c3570b9 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/box_iou_quadri_musa.mu @@ -0,0 +1,23 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +#include "box_iou_quadri_musa.muh" +#include "pytorch_musa_helper.hpp" + +void box_iou_quadri_musa(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode_flag, const bool aligned) { + using scalar_t = float; + AT_ASSERTM(boxes1.is_privateuseone(), "boxes1 must be a MUSA tensor"); + AT_ASSERTM(boxes2.is_privateuseone(), "boxes2 must be a MUSA tensor"); + + int output_size = ious.numel(); + int num_boxes1 = boxes1.size(0); + int num_boxes2 = boxes2.size(0); + + c10::musa::MUSAGuard device_guard(boxes1.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + box_iou_quadri_musa_kernel + <<>>( + num_boxes1, num_boxes2, boxes1.data_ptr(), + boxes2.data_ptr(), (scalar_t*)ious.data_ptr(), + mode_flag, aligned); + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/box_iou_rotated_musa.mu b/mmcv/ops/csrc/pytorch/musa/box_iou_rotated_musa.mu new file mode 100644 index 000000000..b52c10d46 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/box_iou_rotated_musa.mu @@ -0,0 +1,25 @@ +// Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +// modified from +// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_musa.cu +#include "box_iou_rotated_musa.muh" +#include "pytorch_musa_helper.hpp" + +void box_iou_rotated_musa(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode_flag, const bool aligned) { + using scalar_t = float; + AT_ASSERTM(boxes1.is_privateuseone(), "boxes1 must be a MUSA tensor"); + AT_ASSERTM(boxes2.is_privateuseone(), "boxes2 must be a MUSA tensor"); + + int output_size = ious.numel(); + int num_boxes1 = boxes1.size(0); + int num_boxes2 = boxes2.size(0); + + c10::musa::MUSAGuard device_guard(boxes1.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + box_iou_rotated_musa_kernel + <<>>( + num_boxes1, num_boxes2, boxes1.data_ptr(), + boxes2.data_ptr(), (scalar_t*)ious.data_ptr(), + mode_flag, aligned); + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/carafe_musa.mu b/mmcv/ops/csrc/pytorch/musa/carafe_musa.mu new file mode 100644 index 000000000..1de44d07c --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/carafe_musa.mu @@ -0,0 +1,182 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "carafe_musa_kernel.muh" +#include "pytorch_musa_helper.hpp" + +#include + +#if MUSA_ARCH > 21 +void CARAFEForwardMUSAKernelLauncher(const Tensor features, const Tensor masks, + Tensor rfeatures, Tensor routput, + Tensor rmasks, Tensor output, + const int kernel_size, + const int group_size, + const int scale_factor) { + const int batch_size = output.size(0); + const int channels = output.size(1); + const int output_height = output.size(2); + const int output_width = output.size(3); + + const int input_height = features.size(2); + const int input_width = features.size(3); + + const int mask_channels = masks.size(1); + + rfeatures.resize_({batch_size, input_height, input_width, channels}); + routput.resize_({batch_size, output_height, output_width, channels}); + rmasks.resize_({batch_size, output_height, output_width, mask_channels}); + + // one warp per pixel + c10::musa::MUSAGuard device_guard(features.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + features.scalar_type(), "NCHW2NHWC_Feature", ([&] { + const scalar_t *bottom_data = features.data_ptr(); + scalar_t *top_data = rfeatures.data_ptr(); + const int dh = divideUP(channels, kTileDim); + const int dw = divideUP(input_height * input_width, kTileDim); + BatchTranspose2DMUSAKernel + <<>>( + batch_size, channels, input_height * input_width, dh, dw, + bottom_data, top_data); + })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + features.scalar_type(), "NCHW2NHWC_Masks", ([&] { + const scalar_t *bottom_data = masks.data_ptr(); + scalar_t *top_data = rmasks.data_ptr(); + const int dh = divideUP(mask_channels, kTileDim); + const int dw = divideUP(output_height * output_width, kTileDim); + BatchTranspose2DMUSAKernel + <<>>( + batch_size, mask_channels, output_height * output_width, dh, dw, + bottom_data, top_data); + })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + features.scalar_type(), "CARAFELaucherForward", ([&] { + const int num_kernels = + batch_size * output_height * output_width * THREADS_PER_PIXEL; + const scalar_t *bottom_data = rfeatures.data_ptr(); + const scalar_t *bottom_masks = rmasks.data_ptr(); + scalar_t *top_data = routput.data_ptr(); + CARAFEForward<<>>( + num_kernels, bottom_data, bottom_masks, kernel_size, group_size, + scale_factor, channels, 
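The CARAFE launcher first repacks features and masks from NCHW to NHWC with a tiled batch transpose, sized in units of kTileDim via divideUP (a ceiling division), and converts the result back afterwards, presumably so that the threads handling one pixel touch its channels contiguously. A minimal CPU illustration of the ceiling division and of the layout change itself; divide_up below is an assumed stand-in for divideUP:

#include <cstdio>
#include <vector>

static int divide_up(int a, int b) { return (a + b - 1) / b; }

int main() {
  printf("tiles for 65 channels with 32-wide tiles: %d\n", divide_up(65, 32));  // 3

  // Minimal CPU NCHW -> NHWC reorder for a 1x2x2x2 tensor.
  const int N = 1, C = 2, H = 2, W = 2;
  std::vector<float> nchw = {/*c0*/ 0, 1, 2, 3, /*c1*/ 10, 11, 12, 13};
  std::vector<float> nhwc(N * H * W * C);
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          nhwc[((n * H + h) * W + w) * C + c] = nchw[((n * C + c) * H + h) * W + w];
  for (float v : nhwc) printf("%g ", v);   // 0 10 1 11 2 12 3 13
  printf("\n");
}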
input_height, input_width, output_height, + output_width, mask_channels, top_data); + })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + features.scalar_type(), "NHWC2NCHW", ([&] { + const scalar_t *bottom_data = routput.data_ptr(); + scalar_t *top_data = output.data_ptr(); + const int dh = divideUP(output_height * output_width, kTileDim); + const int dw = divideUP(channels, kTileDim); + BatchTranspose2DMUSAKernel + <<>>( + batch_size, output_height * output_width, channels, dh, dw, + bottom_data, top_data); + })); + AT_MUSA_CHECK(musaGetLastError()); +} + +void CARAFEBackwardMUSAKernelLauncher( + const Tensor top_grad, const Tensor rfeatures, const Tensor masks, + Tensor rtop_grad, Tensor rbottom_grad_hs, Tensor rbottom_grad, + Tensor rmask_grad, Tensor bottom_grad, Tensor mask_grad, + const int kernel_size, const int group_size, const int scale_factor) { + const int batch_size = top_grad.size(0); + const int channels = top_grad.size(1); + const int output_height = top_grad.size(2); + const int output_width = top_grad.size(3); + + const int input_height = bottom_grad.size(2); + const int input_width = bottom_grad.size(3); + + const int mask_channels = masks.size(1); + + rtop_grad.resize_({batch_size, output_height, output_width, channels}); + rbottom_grad.resize_({batch_size, input_height, input_width, channels}); + rbottom_grad_hs.resize_({batch_size, output_height, output_width, channels}); + rmask_grad.resize_({batch_size, output_height, output_width, mask_channels}); + + c10::musa::MUSAGuard device_guard(top_grad.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.scalar_type(), "NCHW2NHWC_Top_Grad", ([&] { + const scalar_t *bottom_data = top_grad.data_ptr(); + scalar_t *top_data = rtop_grad.data_ptr(); + const int dh = divideUP(channels, kTileDim); + const int dw = divideUP(output_height * output_width, kTileDim); + BatchTranspose2DMUSAKernel + <<>>( + batch_size, channels, output_height * output_width, dh, dw, + bottom_data, top_data); + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.scalar_type(), "CARAFELaucherBackward_Feature", ([&] { + const int num_kernels = + batch_size * output_height * output_width * THREADS_PER_PIXEL; + const scalar_t *top_diff = rtop_grad.data_ptr(); + const scalar_t *bottom_masks = masks.data_ptr(); + scalar_t *bottom_diff = rbottom_grad_hs.data_ptr(); + + CARAFEBackward_Feature + <<>>(num_kernels, top_diff, bottom_masks, kernel_size, + group_size, scale_factor, channels, input_height, + input_width, output_height, output_width, + mask_channels, bottom_diff); + })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.scalar_type(), "FeatureSum", ([&] { + const int num_kernels = + batch_size * input_height * input_width * THREADS_PER_PIXEL; + const scalar_t *bottom_diff_hs = rbottom_grad_hs.data_ptr(); + scalar_t *bottom_diff = rbottom_grad.data_ptr(); + + FeatureSum + <<>>(num_kernels, bottom_diff_hs, scale_factor, channels, + input_height, input_width, bottom_diff); + })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.scalar_type(), "NHWC2NCHW_Bottom_Grad", ([&] { + const scalar_t *bottom_data = rbottom_grad.data_ptr(); + scalar_t *top_data = bottom_grad.data_ptr(); + const int dh = divideUP(input_height * input_width, kTileDim); + const int dw = divideUP(channels, kTileDim); + BatchTranspose2DMUSAKernel + <<>>( + batch_size, input_height * input_width, channels, dh, dw, + bottom_data, top_data); + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.scalar_type(), 
"CARAFELaucherBackward_Mask", ([&] { + const int num_kernels = batch_size * output_height * output_width * + mask_channels * WARP_SIZE; + const scalar_t *top_diff = rtop_grad.data_ptr(); + const scalar_t *bottom_data = rfeatures.data_ptr(); + scalar_t *mask_diff = rmask_grad.data_ptr(); + + CARAFEBackward_Mask + <<>>(num_kernels, top_diff, bottom_data, kernel_size, + group_size, scale_factor, channels, input_height, + input_width, output_height, output_width, + mask_channels, mask_diff); + })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.scalar_type(), "NHWC2NCHW_Mask_Grad", ([&] { + const scalar_t *bottom_data = rmask_grad.data_ptr(); + scalar_t *top_data = mask_grad.data_ptr(); + const int dh = divideUP(output_height * output_width, kTileDim); + const int dw = divideUP(mask_channels, kTileDim); + BatchTranspose2DMUSAKernel + <<>>( + batch_size, output_height * output_width, mask_channels, dh, dw, + bottom_data, top_data); + })); + + AT_MUSA_CHECK(musaGetLastError()); +} +#endif //MUSA_ARCH diff --git a/mmcv/ops/csrc/pytorch/musa/carafe_naive_musa.mu b/mmcv/ops/csrc/pytorch/musa/carafe_naive_musa.mu new file mode 100644 index 000000000..cf288a32b --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/carafe_naive_musa.mu @@ -0,0 +1,52 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "carafe_naive_musa_kernel.muh" +#include "pytorch_musa_helper.hpp" + +void CARAFENAIVEForwardMUSAKernelLauncher(const Tensor features, + const Tensor masks, Tensor output, + const int kernel_size, + const int group_size, + const int scale_factor) { + int output_size = output.numel(); + int channels = output.size(1); + int height = output.size(2); + int width = output.size(3); + + c10::musa::MUSAGuard device_guard(features.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + features.scalar_type(), "CARAFENAIVEForward", ([&] { + carafe_naive_forward_musa_kernel + <<>>( + output_size, features.data_ptr(), + masks.data_ptr(), output.data_ptr(), + kernel_size, group_size, scale_factor, channels, height, width); + })); + + AT_MUSA_CHECK(musaGetLastError()); +} + +void CARAFENAIVEBackwardMUSAKernelLauncher( + const Tensor top_grad, const Tensor features, const Tensor masks, + Tensor bottom_grad, Tensor mask_grad, const int kernel_size, + const int group_size, const int scale_factor) { + int output_size = top_grad.numel(); + int channels = top_grad.size(1); + int height = top_grad.size(2); + int width = top_grad.size(3); + + c10::musa::MUSAGuard device_guard(top_grad.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + top_grad.scalar_type(), "CARAFENAIVEBackward", ([&] { + carafe_naive_backward_musa_kernel + <<>>( + output_size, top_grad.data_ptr(), + features.data_ptr(), masks.data_ptr(), + bottom_grad.data_ptr(), + mask_grad.data_ptr(), kernel_size, group_size, + scale_factor, channels, height, width); + })); + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/chamfer_distance_musa.mu b/mmcv/ops/csrc/pytorch/musa/chamfer_distance_musa.mu new file mode 100644 index 000000000..d318f8cf2 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/chamfer_distance_musa.mu @@ -0,0 +1,66 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+// Modified from +// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp +#include "chamfer_distance_musa_kernel.muh" +#include "pytorch_musa_helper.hpp" +#if MUSA_ARCH > 21 +void ChamferDistanceForwardMUSAKernelLauncher( + const Tensor xyz1, const Tensor xyz2, const Tensor dist1, + const Tensor dist2, const Tensor idx1, const Tensor idx2) { + int batch_size = xyz1.size(0); + int n = xyz1.size(1); + int m = xyz2.size(1); + + c10::musa::MUSAGuard device_guard(xyz1.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + xyz1.scalar_type(), "chamfer_distance_forward_musa_kernel", [&] { + chamfer_distance_forward_musa_kernel + <<>>( + batch_size, n, xyz1.data_ptr(), m, + xyz2.data_ptr(), dist1.data_ptr(), + idx1.data_ptr()); + }); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + xyz1.scalar_type(), "chamfer_distance_forward_musa_kernel", [&] { + chamfer_distance_forward_musa_kernel + <<>>( + batch_size, m, xyz2.data_ptr(), n, + xyz1.data_ptr(), dist2.data_ptr(), + idx2.data_ptr()); + }); + AT_MUSA_CHECK(musaGetLastError()); +} + +void ChamferDistanceBackwardMUSAKernelLauncher( + const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2, + Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2) { + int batch_size = xyz1.size(0); + int n = xyz1.size(1); + int m = xyz2.size(1); + + c10::musa::MUSAGuard device_guard(xyz1.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + xyz1.scalar_type(), "chamfer_distance_backward_musa_kernel", [&] { + chamfer_distance_backward_musa_kernel + <<>>( + batch_size, m, xyz1.data_ptr(), n, + xyz2.data_ptr(), grad_dist1.data_ptr(), + idx1.data_ptr(), grad_xyz1.data_ptr(), + grad_xyz2.data_ptr()); + }); + AT_DISPATCH_FLOATING_TYPES( + xyz1.scalar_type(), "chamfer_distance_backward_musa_kernel", [&] { + chamfer_distance_backward_musa_kernel + <<>>( + batch_size, n, xyz2.data_ptr(), m, + xyz1.data_ptr(), grad_dist2.data_ptr(), + idx2.data_ptr(), grad_xyz2.data_ptr(), + grad_xyz1.data_ptr()); + }); + AT_MUSA_CHECK(musaGetLastError()); +} +#else +#warning "chamfer_distance is supported when MUSA_ARCH > 21" +#endif //MUSA_ARCH diff --git a/mmcv/ops/csrc/pytorch/musa/convex_iou.mu b/mmcv/ops/csrc/pytorch/musa/convex_iou.mu new file mode 100644 index 000000000..605857358 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/convex_iou.mu @@ -0,0 +1,41 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +// modified from +// https://github.com/SDL-GuoZonghao/BeyondBoundingBox/blob/main/mmdet/ops/iou/src/convex_iou_kernel.cu +#include "convex_iou_musa_kernel.muh" +#include "pytorch_musa_helper.hpp" + +void ConvexIoUMUSAKernelLauncher(const Tensor pointsets, const Tensor polygons, + Tensor ious) { + int output_size = ious.numel(); + int num_pointsets = pointsets.size(0); + int num_polygons = polygons.size(0); + + c10::musa::MUSAGuard device_guard(pointsets.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + pointsets.scalar_type(), "convex_iou_musa_kernel", ([&] { + convex_iou_musa_kernel + <<>>( + num_pointsets, num_polygons, pointsets.data_ptr(), + polygons.data_ptr(), ious.data_ptr()); + })); + AT_MUSA_CHECK(musaGetLastError()); +} + +void ConvexGIoUMUSAKernelLauncher(const Tensor pointsets, const Tensor polygons, + Tensor output) { + int output_size = output.numel(); + int num_pointsets = pointsets.size(0); + int num_polygons = polygons.size(0); + + c10::musa::MUSAGuard device_guard(pointsets.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + pointsets.scalar_type(), "convex_giou_musa_kernel", ([&] { + convex_giou_musa_kernel + <<>>( + num_pointsets, num_polygons, pointsets.data_ptr(), + polygons.data_ptr(), output.data_ptr()); + })); + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/correlation_musa.mu b/mmcv/ops/csrc/pytorch/musa/correlation_musa.mu new file mode 100644 index 000000000..f24845057 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/correlation_musa.mu @@ -0,0 +1,94 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_musa_kernel.cu +// Original licence: Under MIT License + +#include "correlation_musa.muh" +#include "pytorch_musa_helper.hpp" + +void CorrelationForwardMUSAKernelLauncher(Tensor input1, Tensor input2, + Tensor output, int kH, int kW, + int patchH, int patchW, int padH, + int padW, int dilationH, + int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW) { + const int batch_size = input1.size(0); + const int iH = input1.size(2); + const int iW = input1.size(3); + const int dilatedKH = (kH - 1) * dilationH + 1; + const int dilatedKW = (kW - 1) * dilationW + 1; + + const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1; + const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1; + + auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous(); + auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous(); + + const dim3 threads(WARP_SIZE, 4, 4); + const dim3 blocks(batch_size, (oH + 3) >> 2, (oW + 3) >> 2); + + c10::musa::MUSAGuard device_guard(input1.device()); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input1.scalar_type(), "correlation_forward_musa", ([&] { + TensorAcc4R trInput1_acc = + trInput1.packed_accessor32(); + TensorAcc4R trInput2_acc = + trInput2.packed_accessor32(); + TensorAcc5R output_acc = + output.packed_accessor32(); + + correlation_forward_musa_kernel + <<>>( + trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW, + padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW, oH, oW); + })); +} + +void CorrelationBackwardMUSAKernelLauncher( + Tensor grad_output, Tensor input1, Tensor input2, Tensor grad_input1, + Tensor grad_input2, int kH, int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, int 
dilation_patchH, + int dilation_patchW, int dH, int dW) { + const int batch_size = input1.size(0); + const int iH = input1.size(2); + const int iW = input1.size(3); + const int C = input1.size(1); + + auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous(); + auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous(); + const dim3 blocks(batch_size, iH, iW); + const dim3 threads(THREADS_PER_BLOCK); + + c10::musa::MUSAGuard device_guard(input1.device()); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input1.scalar_type(), "correlation_backward_musa", ([&] { + const int grad_cache_size = patchH * patchW * sizeof(scalar_t); + TensorAcc4R input1_acc = + trInput1.packed_accessor32(); + TensorAcc4R input2_acc = + trInput2.packed_accessor32(); + TensorAcc4R grad_input1_acc = + grad_input1.packed_accessor32(); + TensorAcc4R grad_input2_acc = + grad_input2.packed_accessor32(); + TensorAcc5R grad_output_acc = + grad_output.packed_accessor32(); + + correlation_backward_musa_kernel_input1 + <<>>( + grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH, + patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); + + correlation_backward_musa_kernel_input2 + <<>>( + grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH, + patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); + })); +} diff --git a/mmcv/ops/csrc/pytorch/musa/deform_conv_musa.mu b/mmcv/ops/csrc/pytorch/musa/deform_conv_musa.mu new file mode 100644 index 000000000..544e58cae --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/deform_conv_musa.mu @@ -0,0 +1,105 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "deform_conv_musa_kernel.muh" +#include "pytorch_musa_helper.hpp" + +void deformable_im2col_musa(Tensor data_im, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor data_col) { + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, + deformable_group, height_col, width_col, data_col_); + })); + AT_MUSA_CHECK(musaGetLastError()); +} + +void deformable_col2im_musa(Tensor data_col, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor grad_im) { + // todo: make sure parallel_imgs is passed in correctly + int height_col = + (height + 2 * 
pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = + channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, + dilation_w, channel_per_deformable_group, parallel_imgs, + deformable_group, height_col, width_col, grad_im_); + })); + AT_MUSA_CHECK(musaGetLastError()); +} + +void deformable_col2im_coord_musa( + Tensor data_col, Tensor data_im, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, Tensor grad_offset) { + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * + deformable_group * parallel_imgs; + int channel_per_deformable_group = + channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<< + GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, + c10::musa::getCurrentMUSAStream()>>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, + width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs, + 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/deform_roi_pool_musa.mu b/mmcv/ops/csrc/pytorch/musa/deform_roi_pool_musa.mu new file mode 100644 index 000000000..0312a6bfe --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/deform_roi_pool_musa.mu @@ -0,0 +1,55 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include "deform_roi_pool_musa_kernel.muh" +#include "pytorch_musa_helper.hpp" + +void DeformRoIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, + Tensor offset, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale, + int sampling_ratio, float gamma) { + int output_size = output.numel(); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + c10::musa::MUSAGuard device_guard(input.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "deform_roi_pool_forward_musa_kernel", [&] { + deform_roi_pool_forward_musa_kernel + <<>>( + output_size, input.data_ptr(), + rois.data_ptr(), offset.data_ptr(), + output.data_ptr(), pooled_height, pooled_width, + static_cast(spatial_scale), sampling_ratio, + static_cast(gamma), channels, height, width); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} + +void DeformRoIPoolBackwardMUSAKernelLauncher( + Tensor grad_output, Tensor input, Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, float gamma) { + int output_size = grad_output.numel(); + int channels = grad_input.size(1); + int height = grad_input.size(2); + int width = grad_input.size(3); + + c10::musa::MUSAGuard device_guard(grad_output.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + grad_output.scalar_type(), "deform_roi_pool_backward_musa_kernel", [&] { + deform_roi_pool_backward_musa_kernel + <<>>( + output_size, grad_output.data_ptr(), + input.data_ptr(), rois.data_ptr(), + offset.data_ptr(), grad_input.data_ptr(), + grad_offset.data_ptr(), pooled_height, pooled_width, + static_cast(spatial_scale), sampling_ratio, + static_cast(gamma), channels, height, width); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/diff_iou_rotated_musa.mu b/mmcv/ops/csrc/pytorch/musa/diff_iou_rotated_musa.mu new file mode 100644 index 000000000..228813e23 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/diff_iou_rotated_musa.mu @@ -0,0 +1,35 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +// Adapted from +// https://github.com/lilanxiao/Rotated_IoU/musa_op/sort_vert_kernel.cu # noqa +#include "diff_iou_rotated_musa_kernel.muh" +#include "pytorch_cpp_helper.hpp" +#include "pytorch_musa_helper.hpp" + +at::Tensor DiffIoURotatedSortVerticesMUSAKernelLauncher(at::Tensor vertices, + at::Tensor mask, + at::Tensor num_valid) { + c10::musa::MUSAGuard device_guard(vertices.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + CHECK_CONTIGUOUS(vertices); + CHECK_CONTIGUOUS(mask); + CHECK_CONTIGUOUS(num_valid); + CHECK_MUSA(vertices); + CHECK_MUSA(mask); + CHECK_MUSA(num_valid); + + int b = vertices.size(0); + int n = vertices.size(1); + int m = vertices.size(2); + at::Tensor idx = + torch::zeros({b, n, MAX_NUM_VERT_IDX}, + at::device(vertices.device()).dtype(at::ScalarType::Int)); + + diff_iou_rotated_sort_vertices_forward_musa_kernel<<>>( + b, n, m, vertices.data_ptr(), mask.data_ptr(), + num_valid.data_ptr(), idx.data_ptr()); + AT_MUSA_CHECK(musaGetLastError()); + + return idx; +} diff --git a/mmcv/ops/csrc/pytorch/musa/musabind.cpp b/mmcv/ops/csrc/pytorch/musa/musabind.cpp index 5aff12ccf..4b193e372 100644 --- a/mmcv/ops/csrc/pytorch/musa/musabind.cpp +++ b/mmcv/ops/csrc/pytorch/musa/musabind.cpp @@ -103,6 +103,303 @@ void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode, const bool aligned, const int offset); REGISTER_DEVICE_IMPL(bbox_overlaps_impl, MUSA, bbox_overlaps_musa); +void BorderAlignForwardMUSAKernelLauncher(const Tensor &input, + const Tensor &boxes, Tensor output, + Tensor argmax_idx, + const int pool_size); + +void BorderAlignBackwardMUSAKernelLauncher(const Tensor &grad_output, + const Tensor &boxes, + const Tensor &argmax_idx, + Tensor grad_input, + const int pool_size); + +void border_align_forward_musa(const Tensor &input, const Tensor &boxes, + Tensor output, Tensor argmax_idx, + const int pool_size) { + BorderAlignForwardMUSAKernelLauncher(input, boxes, output, argmax_idx, + pool_size); +} + +void border_align_backward_musa(const Tensor &grad_output, const Tensor &boxes, + const Tensor &argmax_idx, Tensor grad_input, + const int pool_size) { + BorderAlignBackwardMUSAKernelLauncher(grad_output, boxes, argmax_idx, + grad_input, pool_size); +} + +void border_align_forward_impl(const Tensor &input, const Tensor &boxes, + Tensor output, Tensor argmax_idx, + const int pool_size); + +void border_align_backward_impl(const Tensor &grad_output, const Tensor &boxes, + const Tensor &argmax_idx, Tensor grad_input, + const int pool_size); + +REGISTER_DEVICE_IMPL(border_align_forward_impl, MUSA, + border_align_forward_musa); +REGISTER_DEVICE_IMPL(border_align_backward_impl, MUSA, + border_align_backward_musa); + +void box_iou_rotated_musa(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode_flag, const bool aligned); + +void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode_flag, const bool aligned); +REGISTER_DEVICE_IMPL(box_iou_rotated_impl, MUSA, box_iou_rotated_musa); + +void box_iou_quadri_musa(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode_flag, const bool aligned); + +void box_iou_quadri_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode_flag, const bool aligned); +REGISTER_DEVICE_IMPL(box_iou_quadri_impl, MUSA, box_iou_quadri_musa); + +#if ((!defined(MUSA_ARCH)) || (defined(MUSA_ARCH)) && (MUSA_ARCH > 21)) + +void CARAFEForwardMUSAKernelLauncher(const Tensor features, const 
Tensor masks, + Tensor rfeatures, Tensor routput, + Tensor rmasks, Tensor output, + const int kernel_size, + const int group_size, + const int scale_factor); + +void CARAFEBackwardMUSAKernelLauncher( + const Tensor top_grad, const Tensor rfeatures, const Tensor masks, + Tensor rtop_grad, Tensor rbottom_grad_hs, Tensor rbottom_grad, + Tensor rmask_grad, Tensor bottom_grad, Tensor mask_grad, + const int kernel_size, const int group_size, const int scale_factor); + +void carafe_forward_musa(Tensor features, Tensor masks, Tensor rfeatures, + Tensor routput, Tensor rmasks, Tensor output, + int kernel_size, int group_size, int scale_factor) { + CARAFEForwardMUSAKernelLauncher(features, masks, rfeatures, routput, rmasks, + output, kernel_size, group_size, + scale_factor); +} + +void carafe_backward_musa(Tensor top_grad, Tensor rfeatures, Tensor masks, + Tensor rtop_grad, Tensor rbottom_grad_hs, + Tensor rbottom_grad, Tensor rmask_grad, + Tensor bottom_grad, Tensor mask_grad, int kernel_size, + int group_size, int scale_factor) { + CARAFEBackwardMUSAKernelLauncher(top_grad, rfeatures, masks, rtop_grad, + rbottom_grad_hs, rbottom_grad, rmask_grad, + bottom_grad, mask_grad, kernel_size, + group_size, scale_factor); +} + +void carafe_forward_impl(Tensor features, Tensor masks, Tensor rfeatures, + Tensor routput, Tensor rmasks, Tensor output, + int kernel_size, int group_size, int scale_factor); + +void carafe_backward_impl(Tensor top_grad, Tensor rfeatures, Tensor masks, + Tensor rtop_grad, Tensor rbottom_grad_hs, + Tensor rbottom_grad, Tensor rmask_grad, + Tensor bottom_grad, Tensor mask_grad, int kernel_size, + int group_size, int scale_factor); + +REGISTER_DEVICE_IMPL(carafe_forward_impl, MUSA, carafe_forward_musa); +REGISTER_DEVICE_IMPL(carafe_backward_impl, MUSA, carafe_backward_musa); +#endif + +void CARAFENAIVEForwardMUSAKernelLauncher(const Tensor features, + const Tensor masks, Tensor output, + const int kernel_size, + const int group_size, + const int scale_factor); + +void CARAFENAIVEBackwardMUSAKernelLauncher( + const Tensor top_grad, const Tensor features, const Tensor masks, + Tensor bottom_grad, Tensor mask_grad, const int kernel_size, + const int group_size, const int scale_factor); + +void carafe_naive_forward_musa(Tensor features, Tensor masks, Tensor output, + int kernel_size, int group_size, + int scale_factor) { + CARAFENAIVEForwardMUSAKernelLauncher(features, masks, output, kernel_size, + group_size, scale_factor); +} + +void carafe_naive_backward_musa(Tensor top_grad, Tensor features, Tensor masks, + Tensor bottom_grad, Tensor mask_grad, + int kernel_size, int group_size, + int scale_factor) { + CARAFENAIVEBackwardMUSAKernelLauncher(top_grad, features, masks, bottom_grad, + mask_grad, kernel_size, group_size, + scale_factor); +} +void carafe_naive_forward_impl(Tensor features, Tensor masks, Tensor output, + int kernel_size, int group_size, + int scale_factor); + +void carafe_naive_backward_impl(Tensor top_grad, Tensor features, Tensor masks, + Tensor bottom_grad, Tensor mask_grad, + int kernel_size, int group_size, + int scale_factor); + +REGISTER_DEVICE_IMPL(carafe_naive_forward_impl, MUSA, + carafe_naive_forward_musa); +REGISTER_DEVICE_IMPL(carafe_naive_backward_impl, MUSA, + carafe_naive_backward_musa); + +void CorrelationForwardMUSAKernelLauncher(Tensor input1, Tensor input2, + Tensor output, int kH, int kW, + int patchH, int patchW, int padH, + int padW, int dilationH, + int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); + +void 
CorrelationBackwardMUSAKernelLauncher(Tensor grad_output, Tensor input1, + Tensor input2, Tensor grad_input1, + Tensor grad_input2, int kH, int kW, + int patchH, int patchW, int padH, + int padW, int dilationH, + int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); + +void correlation_forward_musa(Tensor input1, Tensor input2, Tensor output, + int kH, int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, int dH, + int dW) { + CorrelationForwardMUSAKernelLauncher( + input1, input2, output, kH, kW, patchH, patchW, padH, padW, dilationH, + dilationW, dilation_patchH, dilation_patchW, dH, dW); +} + +void correlation_backward_musa(Tensor grad_output, Tensor input1, Tensor input2, + Tensor grad_input1, Tensor grad_input2, int kH, + int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, int dH, + int dW) { + CorrelationBackwardMUSAKernelLauncher( + grad_output, input1, input2, grad_input1, grad_input2, kH, kW, patchH, + patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); +} + +void correlation_forward_impl(Tensor input1, Tensor input2, Tensor output, + int kH, int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, int dH, + int dW); + +void correlation_backward_impl(Tensor grad_output, Tensor input1, Tensor input2, + Tensor grad_input1, Tensor grad_input2, int kH, + int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, int dH, + int dW); + +REGISTER_DEVICE_IMPL(correlation_forward_impl, MUSA, correlation_forward_musa); +REGISTER_DEVICE_IMPL(correlation_backward_impl, MUSA, + correlation_backward_musa); + +void deformable_im2col_musa(Tensor data_im, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor data_col); + +void deformable_col2im_musa(Tensor data_col, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor grad_im); + +void deformable_col2im_coord_musa( + Tensor data_col, Tensor data_im, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, Tensor grad_offset); + +void deformable_im2col_impl(Tensor data_im, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor data_col); + +void deformable_col2im_impl(Tensor data_col, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + 
const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor grad_im); + +void deformable_col2im_coord_impl( + Tensor data_col, Tensor data_im, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, Tensor grad_offset); + +REGISTER_DEVICE_IMPL(deformable_im2col_impl, MUSA, deformable_im2col_musa); +REGISTER_DEVICE_IMPL(deformable_col2im_impl, MUSA, deformable_col2im_musa); +REGISTER_DEVICE_IMPL(deformable_col2im_coord_impl, MUSA, + deformable_col2im_coord_musa); + +void DeformRoIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, + Tensor offset, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale, + int sampling_ratio, float gamma); + +void DeformRoIPoolBackwardMUSAKernelLauncher( + Tensor grad_output, Tensor input, Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, float gamma); + +void deform_roi_pool_forward_musa(Tensor input, Tensor rois, Tensor offset, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma) { + DeformRoIPoolForwardMUSAKernelLauncher(input, rois, offset, output, + pooled_height, pooled_width, + spatial_scale, sampling_ratio, gamma); +} + +void deform_roi_pool_backward_musa(Tensor grad_output, Tensor input, + Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, + int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, + float gamma) { + DeformRoIPoolBackwardMUSAKernelLauncher( + grad_output, input, rois, offset, grad_input, grad_offset, pooled_height, + pooled_width, spatial_scale, sampling_ratio, gamma); +} + +void deform_roi_pool_forward_impl(Tensor input, Tensor rois, Tensor offset, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma); + +void deform_roi_pool_backward_impl(Tensor grad_output, Tensor input, + Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, + int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, + float gamma); + +REGISTER_DEVICE_IMPL(deform_roi_pool_forward_impl, MUSA, + deform_roi_pool_forward_musa); +REGISTER_DEVICE_IMPL(deform_roi_pool_backward_impl, MUSA, + deform_roi_pool_backward_musa); + void ActiveRotatedFilterForwardMUSAKernelLauncher(const Tensor input, const Tensor indices, Tensor output); @@ -132,6 +429,88 @@ REGISTER_DEVICE_IMPL(active_rotated_filter_forward_impl, MUSA, REGISTER_DEVICE_IMPL(active_rotated_filter_backward_impl, MUSA, active_rotated_filter_backward_musa); +void ConvexIoUMUSAKernelLauncher(const Tensor pointsets, const Tensor polygons, + Tensor ious); + +void ConvexGIoUMUSAKernelLauncher(const Tensor pointsets, const Tensor polygons, + Tensor output); + +void convex_iou_musa(const Tensor pointsets, const Tensor polygons, + Tensor ious) { + ConvexIoUMUSAKernelLauncher(pointsets, polygons, ious); +} + +void convex_giou_musa(const Tensor pointsets, const Tensor polygons, + Tensor output) { + ConvexGIoUMUSAKernelLauncher(pointsets, polygons, output); +} + +void convex_iou_impl(const Tensor pointsets, const Tensor polygons, + Tensor ious); + +void convex_giou_impl(const Tensor 
pointsets, const Tensor polygons, + Tensor output); + +REGISTER_DEVICE_IMPL(convex_iou_impl, MUSA, convex_iou_musa); +REGISTER_DEVICE_IMPL(convex_giou_impl, MUSA, convex_giou_musa); + +Tensor DiffIoURotatedSortVerticesMUSAKernelLauncher(Tensor vertices, + Tensor mask, + Tensor num_valid); + +Tensor diff_iou_rotated_sort_vertices_forward_musa(Tensor vertices, Tensor mask, + Tensor num_valid) { + return DiffIoURotatedSortVerticesMUSAKernelLauncher(vertices, mask, + num_valid); +} + +Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask, + Tensor num_valid); + +REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, MUSA, + diff_iou_rotated_sort_vertices_forward_musa); + +#if ((!defined(MUSA_ARCH)) || (defined(MUSA_ARCH)) && (MUSA_ARCH > 21)) +void ChamferDistanceForwardMUSAKernelLauncher( + const Tensor xyz1, const Tensor xyz2, const Tensor dist1, + const Tensor dist2, const Tensor idx1, const Tensor idx2); +#endif + +void ChamferDistanceBackwardMUSAKernelLauncher( + const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2, + Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2); + +#if ((!defined(MUSA_ARCH)) || (defined(MUSA_ARCH)) && (MUSA_ARCH > 21)) +void chamfer_distance_forward_musa(const Tensor xyz1, const Tensor xyz2, + const Tensor dist1, const Tensor dist2, + const Tensor idx1, const Tensor idx2) { + ChamferDistanceForwardMUSAKernelLauncher(xyz1, xyz2, dist1, dist2, idx1, + idx2); +}; + +void chamfer_distance_backward_musa(const Tensor xyz1, const Tensor xyz2, + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2) { + ChamferDistanceBackwardMUSAKernelLauncher(xyz1, xyz2, idx1, idx2, graddist1, + graddist2, gradxyz1, gradxyz2); +}; + +void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2, + const Tensor dist1, const Tensor dist2, + const Tensor idx1, const Tensor idx2); + +void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2, + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2); + +REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, MUSA, + chamfer_distance_forward_musa); +REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, MUSA, + chamfer_distance_backward_musa); +#endif + void BezierAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output, int aligned_height, int aligned_width, diff --git a/tests/test_ops/test_bias_act.py b/tests/test_ops/test_bias_act.py index 01b57c4ae..154534ac2 100644 --- a/tests/test_ops/test_bias_act.py +++ b/tests/test_ops/test_bias_act.py @@ -4,6 +4,7 @@ import torch from mmcv.ops import bias_act from mmcv.ops.bias_act import EasyDict +from mmcv.utils import IS_MUSA_AVAILABLE _USING_PARROTS = True try: @@ -131,6 +132,74 @@ class TestBiasAct: assert out1.max() <= 0.5 assert out2.max() <= 0.5 + @pytest.mark.skipif(not IS_MUSA_AVAILABLE, reason='requires musa') + def test_bias_act_musa(self): + if _USING_PARROTS: + gradcheck( + bias_act, (self.input_tensor.musa(), self.bias.musa()), + delta=1e-4, + pt_atol=1e-3) + else: + gradcheck( + bias_act, (self.input_tensor.musa(), self.bias.musa()), + eps=1e-4, + atol=1e-3) + + gradgradcheck( + bias_act, (self.input_tensor.musa(), self.bias.musa()), + eps=1e-4, + atol=1e-3) + + out = bias_act(self.input_tensor.musa(), self.bias.musa()) + assert out.shape == (1, 3) + + # test with different dim + input_tensor = torch.randn((1, 1, 3), requires_grad=True).musa() + bias = torch.randn(3, requires_grad=True).musa() + 
out = bias_act(input_tensor, bias, dim=2) + assert out.shape == (1, 1, 3) + + # test with different act + out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='relu') + assert out.shape == (1, 3) + + out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='lrelu') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='tanh') + assert out.shape == (1, 3) + out = bias_act( + self.input_tensor.musa(), self.bias.musa(), act='sigmoid') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='elu') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='selu') + assert out.shape == (1, 3) + out = bias_act( + self.input_tensor.musa(), self.bias.musa(), act='softplus') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor.musa(), self.bias.musa(), act='swish') + assert out.shape == (1, 3) + + # test with different alpha + out = bias_act( + self.input_tensor.musa(), self.bias.musa(), act='lrelu', alpha=0.1) + assert out.shape == (1, 3) + + # test with different gain + out1 = bias_act( + self.input_tensor.musa(), self.bias.musa(), act='lrelu', gain=0.2) + out2 = bias_act( + self.input_tensor.musa(), self.bias.musa(), act='lrelu', gain=0.1) + assert torch.allclose(out1, out2 * 2) + + # test with different clamp + out1 = bias_act( + self.input_tensor.musa(), self.bias.musa(), act='lrelu', clamp=0.5) + out2 = bias_act( + self.input_tensor.musa(), self.bias.musa(), act='lrelu', clamp=0.2) + assert out1.max() <= 0.5 + assert out2.max() <= 0.5 + def test_easy_dict(self): easy_dict = EasyDict( func=lambda x, **_: x, @@ -142,3 +211,15 @@ class TestBiasAct: _ = easy_dict.def_alpha easy_dict.def_alpha = 1 del easy_dict.def_alpha + + def test_easy_dict_musa(self): + easy_dict = EasyDict( + func=lambda x, **_: x, + def_alpha=0, + def_gain=1, + musa_idx=1, + ref='', + has_2nd_grad=False) + _ = easy_dict.def_alpha + easy_dict.def_alpha = 1 + del easy_dict.def_alpha diff --git a/tests/test_ops/test_border_align.py b/tests/test_ops/test_border_align.py index 71518ce96..f3ca7bd97 100644 --- a/tests/test_ops/test_border_align.py +++ b/tests/test_ops/test_border_align.py @@ -5,6 +5,8 @@ import numpy as np import pytest import torch +from mmcv.utils import IS_MUSA_AVAILABLE + # [1,4c,h,w] input_arr = [[[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]], [[6, 7, 5, 8], [2, 1, 3, 4], [12, 9, 11, 10]], @@ -51,6 +53,8 @@ input_grad_dict = { def _test_border_align_allclose(device, dtype, pool_size): if not torch.cuda.is_available() and device == 'cuda': pytest.skip('test requires GPU') + elif not IS_MUSA_AVAILABLE and device == 'musa': + pytest.skip('test requires GPU') try: from mmcv.ops import BorderAlign, border_align except ModuleNotFoundError: @@ -84,8 +88,16 @@ def _test_border_align_allclose(device, dtype, pool_size): input.grad.data.type(dtype).cpu().numpy(), np_grad, atol=1e-5) -@pytest.mark.parametrize('device', ['cuda']) -@pytest.mark.parametrize('dtype', [torch.float, torch.half, torch.double]) +@pytest.mark.parametrize('device', ['cuda', 'musa']) +@pytest.mark.parametrize('dtype', [ + torch.float, + torch.half, + pytest.param( + torch.double, + marks=pytest.mark.skipif( + IS_MUSA_AVAILABLE, + reason='MUSA does not support for 64-bit floating point')), +]) @pytest.mark.parametrize('pool_size', [1, 2]) def test_border_align(device, dtype, pool_size): _test_border_align_allclose(device, dtype, pool_size) diff --git a/tests/test_ops/test_box_iou_quadri.py 
b/tests/test_ops/test_box_iou_quadri.py index 006f04c5c..3f3e60b68 100644 --- a/tests/test_ops/test_box_iou_quadri.py +++ b/tests/test_ops/test_box_iou_quadri.py @@ -3,7 +3,7 @@ import numpy as np import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE class TestBoxIoUQuadri: @@ -17,7 +17,11 @@ class TestBoxIoUQuadri: pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) def test_box_iou_quadri_cuda(self, device): from mmcv.ops import box_iou_quadri @@ -55,7 +59,11 @@ class TestBoxIoUQuadri: pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) def test_box_iou_quadri_iof_cuda(self, device): from mmcv.ops import box_iou_quadri diff --git a/tests/test_ops/test_box_iou_rotated.py b/tests/test_ops/test_box_iou_rotated.py index 3af811d0f..04210320e 100644 --- a/tests/test_ops/test_box_iou_rotated.py +++ b/tests/test_ops/test_box_iou_rotated.py @@ -4,7 +4,8 @@ import pytest import torch from mmcv.ops import box_iou_rotated -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE, + IS_NPU_AVAILABLE) class TestBoxIoURotated: @@ -58,7 +59,11 @@ class TestBoxIoURotated: pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) def test_box_iou_rotated(self, device): np_boxes1 = np.asarray( @@ -145,7 +150,11 @@ class TestBoxIoURotated: pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), ]) def test_box_iou_rotated_iof(self, device): np_boxes1 = np.asarray( diff --git a/tests/test_ops/test_carafe.py b/tests/test_ops/test_carafe.py index 02d00f1ff..b5e7e6bdb 100644 --- a/tests/test_ops/test_carafe.py +++ b/tests/test_ops/test_carafe.py @@ -4,32 +4,34 @@ import pytest import torch from torch.autograd import gradcheck -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE class TestCarafe: def test_carafe_naive_gradcheck(self): - if not torch.cuda.is_available(): + if (not torch.cuda.is_available()) and (not IS_MUSA_AVAILABLE): return from mmcv.ops import CARAFENaive - feat = torch.randn( - 2, 64, 3, 3, requires_grad=True, device='cuda').double() - mask = torch.randn( - 2, 100, 6, 6, requires_grad=True, - device='cuda').sigmoid().double() - gradcheck(CARAFENaive(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4) + if IS_CUDA_AVAILABLE: + feat = torch.randn( + 2, 64, 3, 3, requires_grad=True, device='cuda').double() + mask = torch.randn( + 2, 100, 6, 6, requires_grad=True, + device='cuda').sigmoid().double() + gradcheck(CARAFENaive(5, 4, 2), (feat, mask), 
atol=1e-4, eps=1e-4)
 
     def test_carafe_gradcheck(self):
-        if not torch.cuda.is_available():
+        if (not torch.cuda.is_available()) and (not IS_MUSA_AVAILABLE):
             return
         from mmcv.ops import CARAFE
-        feat = torch.randn(
-            2, 64, 3, 3, requires_grad=True, device='cuda').double()
-        mask = torch.randn(
-            2, 100, 6, 6, requires_grad=True,
-            device='cuda').sigmoid().double()
-        gradcheck(CARAFE(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
+        if IS_CUDA_AVAILABLE:
+            feat = torch.randn(
+                2, 64, 3, 3, requires_grad=True, device='cuda').double()
+            mask = torch.randn(
+                2, 100, 6, 6, requires_grad=True,
+                device='cuda').sigmoid().double()
+            gradcheck(CARAFE(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
 
     @pytest.mark.parametrize('device', [
         pytest.param(
@@ -39,7 +41,11 @@ class TestCarafe:
         pytest.param(
             'mlu',
             marks=pytest.mark.skipif(
-                not IS_MLU_AVAILABLE, reason='requires MLU support'))
+                not IS_MLU_AVAILABLE, reason='requires MLU support')),
+        pytest.param(
+            'musa',
+            marks=pytest.mark.skipif(
+                not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
     ])
     def test_carafe_allclose(self, device):
         try:
@@ -64,14 +70,18 @@ class TestCarafe:
         np_feat_grad = np_feat_grad.reshape((2, 64, 3, 3))
         np_mask_grad = np_mask_grad.reshape((2, 100, 6, 6))
 
-        feat = torch.tensor(
-            np_feat, dtype=torch.float, device=device, requires_grad=True)
-        mask = torch.tensor(
-            np_mask, dtype=torch.float, device=device, requires_grad=True)
+        feat_cpu = torch.FloatTensor(np_feat)
+        mask_cpu = torch.FloatTensor(np_mask)
+        feat = feat_cpu.to(device)
+        mask = mask_cpu.to(device)
+        feat.requires_grad = True
+        mask.requires_grad = True
         carafe = CARAFE(5, 4, 2)
-
+        carafe.to(device)
+        carafe.train()
         output = carafe(feat, mask)
+        output.backward(torch.ones_like(output))
         assert np.allclose(
             output.data.type(torch.float).cpu().numpy(), np_output, atol=1e-3)
diff --git a/tests/test_ops/test_cc_attention.py b/tests/test_ops/test_cc_attention.py
index b2a8d22a3..c9851cde2 100644
--- a/tests/test_ops/test_cc_attention.py
+++ b/tests/test_ops/test_cc_attention.py
@@ -3,6 +3,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 
+from mmcv.utils import IS_MUSA_AVAILABLE
+
 
 class Loss(nn.Module):
 
@@ -20,6 +22,9 @@ class TestCrissCrossAttention:
     def test_cc_attention(self):
         device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
+        if IS_MUSA_AVAILABLE:
+            device = torch.device('musa:0')
+
         from mmcv.ops import CrissCrossAttention
         loss_func = Loss()
 
diff --git a/tests/test_ops/test_chamfer_distance.py b/tests/test_ops/test_chamfer_distance.py
index e5c020ddf..1d20e25c4 100644
--- a/tests/test_ops/test_chamfer_distance.py
+++ b/tests/test_ops/test_chamfer_distance.py
@@ -4,7 +4,7 @@ import pytest
 import torch
from mmcv.ops import chamfer_distance -from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE def chamfer_distance_forward_groundtruth(xyz1, xyz2, dtype): @@ -51,7 +51,11 @@ def torch_to_np_type(dtype): pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) @pytest.mark.parametrize('dtype', [torch.half, torch.float32]) @pytest.mark.parametrize('shape', [(2, 600, 2), (1, 1, 2), (7, 7, 2)]) diff --git a/tests/test_ops/test_conv_gradfix.py b/tests/test_ops/test_conv_gradfix.py index ff2f35c55..aea9f2b80 100644 --- a/tests/test_ops/test_conv_gradfix.py +++ b/tests/test_ops/test_conv_gradfix.py @@ -5,6 +5,7 @@ import torch.nn as nn from torch.autograd import gradcheck, gradgradcheck from mmcv.ops import conv2d, conv_transpose2d +from mmcv.utils import IS_MUSA_AVAILABLE class TestCond2d: @@ -23,6 +24,15 @@ class TestCond2d: gradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=0.1) gradgradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=0.1) + @pytest.mark.skipif(not IS_MUSA_AVAILABLE, reason='requires musa') + def test_conv2d_musa(self): + x = self.input.musa() + weight = self.weight.musa() + res = conv2d(x, weight, None, 1, 1) + assert res.shape == (1, 1, 32, 32) + gradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=0.1) + gradgradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=0.1) + class TestCond2dTansposed: @@ -41,3 +51,14 @@ class TestCond2dTansposed: conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2) gradgradcheck( conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2) + + @pytest.mark.skipif(not IS_MUSA_AVAILABLE, reason='requires musa') + def test_conv2d_transposed_musa(self): + x = self.input.musa() + weight = self.weight.musa() + res = conv_transpose2d(x, weight, None, 1, 1) + assert res.shape == (1, 1, 32, 32) + gradcheck( + conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2) + gradgradcheck( + conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2) diff --git a/tests/test_ops/test_convex_iou.py b/tests/test_ops/test_convex_iou.py index 95dc48243..99ac60ec3 100644 --- a/tests/test_ops/test_convex_iou.py +++ b/tests/test_ops/test_convex_iou.py @@ -4,6 +4,7 @@ import pytest import torch from mmcv.ops import convex_giou, convex_iou +from mmcv.utils import IS_MUSA_AVAILABLE np_pointsets = np.asarray([[ 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 1.0, 1.0, 3.0, 3.0, 1.0, 2.0, 3.0, 3.0, @@ -54,3 +55,12 @@ def test_convex_giou(): giou, grad = convex_giou(pointsets, polygons) assert torch.allclose(giou, expected_giou, atol=1e-3) assert torch.allclose(grad, expected_grad, atol=1e-3) + + +@pytest.mark.skipif(not IS_MUSA_AVAILABLE, reason='requires musa') +def test_convex_miou(): + pointsets = torch.from_numpy(np_pointsets).musa().float() + polygons = torch.from_numpy(np_polygons).musa().float() + expected_iou = torch.from_numpy(np_expected_iou).musa().float() + assert torch.allclose( + convex_iou(pointsets, polygons), expected_iou, atol=1e-3) diff --git a/tests/test_ops/test_correlation.py b/tests/test_ops/test_correlation.py index 6cf5f9f72..5b011abc4 100644 --- a/tests/test_ops/test_correlation.py +++ b/tests/test_ops/test_correlation.py @@ -3,6 +3,7 @@ import pytest import torch from mmcv.ops import Correlation +from mmcv.utils import 
IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE _input1 = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]] _input2 = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]] @@ -23,24 +24,33 @@ class TestCorrelation: layer = Correlation(max_displacement=0) - input1 = torch.tensor(_input1, dtype=dtype).cuda() - input2 = torch.tensor(_input2, dtype=dtype).cuda() + if IS_CUDA_AVAILABLE: + input1 = torch.tensor(_input1, dtype=dtype).cuda() + input2 = torch.tensor(_input2, dtype=dtype).cuda() + elif IS_MUSA_AVAILABLE: + input1 = torch.tensor(_input1, dtype=dtype).musa() + input2 = torch.tensor(_input2, dtype=dtype).musa() input1.requires_grad = True input2.requires_grad = True out = layer(input1, input2) out.backward(torch.ones_like(out)) # `eq_cpu` is not implemented for 'Half' in torch1.5.0, - # so we need to make a comparison for cuda tensor + # so we need to make a comparison for cuda/musa tensor # rather than cpu tensor - gt_out = torch.tensor(_gt_out, dtype=dtype).cuda() + if IS_CUDA_AVAILABLE: + gt_out = torch.tensor(_gt_out, dtype=dtype).cuda() + elif IS_MUSA_AVAILABLE: + gt_out = torch.tensor(_gt_out, dtype=dtype).musa() assert_equal_tensor(out, gt_out) assert_equal_tensor(input1.grad.detach(), input2) assert_equal_tensor(input2.grad.detach(), input1) @pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') + (not torch.cuda.is_available()) and (not IS_MUSA_AVAILABLE), + reason='requires CUDA/MUSA support') def test_correlation(self): self._test_correlation(torch.float) - self._test_correlation(torch.double) + if IS_CUDA_AVAILABLE: + self._test_correlation(torch.double) self._test_correlation(torch.half) diff --git a/tests/test_ops/test_deform_conv.py b/tests/test_ops/test_deform_conv.py index 7f2801fdc..db62dd441 100644 --- a/tests/test_ops/test_deform_conv.py +++ b/tests/test_ops/test_deform_conv.py @@ -5,17 +5,24 @@ import torch from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE if IS_MLU_AVAILABLE: torch.backends.cnnl.allow_tf32 = False -try: - # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast - # would be imported and used; we should test if our modules support it. - from torch.cuda.amp import autocast -except ImportError: - pass +if IS_MUSA_AVAILABLE: + try: + from torch_musa.core.amp import autocast + except ImportError: + pass +else: + try: + # If PyTorch version >= 1.6.0 and fp16 is enabled + # torch.cuda.amp.autocast would be imported and used + # we should test if our modules support it. 
+ from torch.cuda.amp import autocast + except ImportError: + pass input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]] offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]], @@ -79,6 +86,8 @@ class TestDeformconv: model.cuda() elif device == 'mlu': model.mlu() + elif device == 'musa': + model.musa() model.type(dtype) out = model(x) @@ -161,6 +170,8 @@ class TestDeformconv: model.cuda() elif device == 'mlu': model.mlu() + elif device == 'musa': + model.musa() out = model(x) out.backward(torch.ones_like(out)) @@ -194,19 +205,24 @@ class TestDeformconv: with pytest.raises(AssertionError): model = DeformConv2d(3, 4, 3, groups=3) - @pytest.mark.parametrize('device, threshold', [ - ('cpu', 1e-1), - pytest.param( - 'cuda', - 1e-3, - marks=pytest.mark.skipif( - not IS_CUDA_AVAILABLE, reason='requires CUDA support')), - pytest.param( - 'mlu', - 1e-3, - marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')), - ]) + @pytest.mark.parametrize( + 'device, threshold', + [('cpu', 1e-1), + pytest.param( + 'cuda', + 1e-3, + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + 1e-3, + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'musa', + 1e-3, + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support'))]) def test_deformconv_float(self, device, threshold): self._test_deformconv(torch.float, device=device, threshold=threshold) # test batch_size < im2col_step @@ -244,6 +260,11 @@ class TestDeformconv: 1e-1, marks=pytest.mark.skipif( not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'musa', + 1e-1, + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) def test_deformconv_half(self, device, threshold): self._test_deformconv(torch.half, device=device, threshold=threshold) diff --git a/tests/test_ops/test_deform_roi_pool.py b/tests/test_ops/test_deform_roi_pool.py index 346301fe4..55df5f967 100644 --- a/tests/test_ops/test_deform_roi_pool.py +++ b/tests/test_ops/test_deform_roi_pool.py @@ -5,7 +5,8 @@ import numpy as np import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE, + IS_NPU_AVAILABLE) _USING_PARROTS = True try: @@ -137,16 +138,20 @@ class TestDeformRoIPool: pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), ]) @pytest.mark.parametrize('dtype', [ torch.float, pytest.param( torch.double, marks=pytest.mark.skipif( - IS_MLU_AVAILABLE, - reason='MLU does not support for 64-bit floating point')), - torch.half + IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE, + reason='MLU, MUSA does not support for 64-bit floating point'), + ), torch.half ]) def test_deform_roi_pool_allclose(self, device, dtype): self._test_deform_roi_pool_allclose(device, dtype) diff --git a/tests/test_ops/test_diff_iou_rotated.py b/tests/test_ops/test_diff_iou_rotated.py index 5688bb876..653236fba 100644 --- a/tests/test_ops/test_diff_iou_rotated.py +++ b/tests/test_ops/test_diff_iou_rotated.py @@ -4,12 +4,13 @@ import pytest import torch from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from 
mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE
 
 if IS_MLU_AVAILABLE:
     torch.backends.mlu.matmul.allow_tf32 = False
 
+
+# TODO(@MTAI): diff_iou_rotated still has known bugs on MUSA.
 @pytest.mark.parametrize('device', [
     pytest.param(
         'cuda',
         marks=pytest.mark.skipif(
@@ -18,7 +19,11 @@ if IS_MLU_AVAILABLE:
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_diff_iou_rotated_2d(device):
     np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
@@ -47,7 +52,11 @@ def test_diff_iou_rotated_2d(device):
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'musa',
+        marks=pytest.mark.skipif(
+            not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
 def test_diff_iou_rotated_3d(device):
     np_boxes1 = np.asarray(