[Feature] Add diff_iiou_rotated op in parrots (#1911)

pull/1935/head
pc 2022-04-30 09:43:33 +08:00 committed by GitHub
parent 057c032347
commit 9f5a03dc2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 85 additions and 15 deletions

View File

@ -924,20 +924,20 @@ REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);
REGISTER_DEVICE_IMPL(roi_align_backward_impl, CUDA, roi_align_backward_cuda);
void ROIAlignRotatedForwardCUDAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const at::Tensor input, const at::Tensor rois, const float spatial_scale,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor output);
void ROIAlignRotatedBackwardCUDAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor bottom_grad);
void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_cuda(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
@ -947,11 +947,11 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
AT_ERROR("wrong roi size");
}
int num_channels = features.size(1);
int data_height = features.size(2);
int data_width = features.size(3);
int num_channels = input.size(1);
int data_height = input.size(2);
int data_width = input.size(3);
ROIAlignRotatedForwardCUDAKernelLauncher(
features, rois, spatial_scale, sample_ratio, aligned, clockwise,
input, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, output);
}
@ -959,7 +959,7 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sample_ratio, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
@ -972,20 +972,20 @@ void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
int data_height = bottom_grad.size(2);
int data_width = bottom_grad.size(3);
ROIAlignRotatedBackwardCUDAKernelLauncher(
top_grad, rois, spatial_scale, sample_ratio, aligned, clockwise,
top_grad, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, bottom_grad);
}
void roi_align_rotated_forward_impl(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);
void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sample_ratio, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise);
REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CUDA,
roi_align_rotated_forward_cuda);
@ -1564,3 +1564,19 @@ void convex_giou_impl(const Tensor pointsets, const Tensor polygons,
REGISTER_DEVICE_IMPL(convex_iou_impl, CUDA, convex_iou_cuda);
REGISTER_DEVICE_IMPL(convex_giou_impl, CUDA, convex_giou_cuda);
Tensor DiffIoURotatedSortVerticesCUDAKernelLauncher(Tensor vertices,
Tensor mask,
Tensor num_valid);
Tensor diff_iou_rotated_sort_vertices_forward_cuda(Tensor vertices, Tensor mask,
Tensor num_valid) {
return DiffIoURotatedSortVerticesCUDAKernelLauncher(vertices, mask,
num_valid);
}
Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid);
REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, CUDA,
diff_iou_rotated_sort_vertices_forward_cuda);

View File

@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid) {
return DISPATCH_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl,
vertices, mask, num_valid);
}
Tensor diff_iou_rotated_sort_vertices_forward(Tensor vertices, Tensor mask,
Tensor num_valid) {
return diff_iou_rotated_sort_vertices_forward_impl(vertices, mask, num_valid);
}

View File

@ -0,0 +1,28 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
#include "diff_iou_rotated_pytorch.h"
using namespace parrots;
#ifdef MMCV_WITH_CUDA
void diff_iou_rotated_sort_vertices_forward_cuda_parrots(
CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
at::Tensor boxes, scores, dets;
auto vertices = buildATensor(ctx, ins[0]);
auto mask = buildATensor(ctx, ins[1]);
auto num_valid = buildATensor(ctx, ins[2]);
auto out =
diff_iou_rotated_sort_vertices_forward_cuda(vertices, mask, num_valid);
updateDArray(ctx, out, outs[0]);
}
PARROTS_EXTENSION_REGISTER(diff_iou_rotated_sort_vertices_forward)
.input(3)
.output(1)
.apply(diff_iou_rotated_sort_vertices_forward_cuda_parrots)
.done();
#endif

View File

@ -0,0 +1,10 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef DIFF_IOU_ROTATED_PYTORCH_H
#define DIFF_IOU_ROTATED_PYTORCH_H
#include <torch/extension.h>
using namespace at;
Tensor diff_iou_rotated_sort_vertices_forward_cuda(Tensor vertices, Tensor mask,
Tensor num_valid);
#endif // DIFF_IOU_ROTATED_PYTORCH_H

View File

@ -17,7 +17,8 @@ class SortVertices(Function):
def forward(ctx, vertices, mask, num_valid):
idx = ext_module.diff_iou_rotated_sort_vertices_forward(
vertices, mask, num_valid)
ctx.mark_non_differentiable(idx)
if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx)
return idx
@staticmethod

View File

@ -36,6 +36,7 @@ else:
'ms_deform_attn_forward',
'pixel_group',
'contour_expand',
'diff_iou_rotated_sort_vertices_forward',
]
def get_fake_func(name, e):