mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] support dipu_mock_cuda=False in dipu for mmcv ext ops with cpu fallback (#2839)
* Support cuda_mock * Fix * Support roi_align * Fix lint * Support bboxpull/2851/head
parent
fc038a386a
commit
99a8d05766
|
@ -33,7 +33,9 @@ void bbox_overlaps_diopi(const Tensor bboxes1, const Tensor bboxes2,
|
||||||
diopiContextHandle_t ch = &ctx;
|
diopiContextHandle_t ch = &ctx;
|
||||||
auto bboxes2_p = toDiopiTensorHandle(bboxes2);
|
auto bboxes2_p = toDiopiTensorHandle(bboxes2);
|
||||||
auto ious_p = toDiopiTensorHandle(ious);
|
auto ious_p = toDiopiTensorHandle(ious);
|
||||||
if (reinterpret_cast<void *>(diopiBboxOverlapsMmcv) != nullptr) {
|
bool is_mock_cuda = bboxes1.device().type() == c10::DeviceType::PrivateUse1;
|
||||||
|
if (is_mock_cuda &&
|
||||||
|
reinterpret_cast<void *>(diopiBboxOverlapsMmcv) != nullptr) {
|
||||||
auto ret = diopiBboxOverlapsMmcv(ch, ious_p, bboxes1_p, bboxes2_p, mode,
|
auto ret = diopiBboxOverlapsMmcv(ch, ious_p, bboxes1_p, bboxes2_p, mode,
|
||||||
offset, aligned);
|
offset, aligned);
|
||||||
if (ret == diopiSuccess) return;
|
if (ret == diopiSuccess) return;
|
||||||
|
|
|
@ -42,7 +42,8 @@ Tensor nms_diopi(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
|
||||||
auto outp = toDiopiTensorHandle(out);
|
auto outp = toDiopiTensorHandle(out);
|
||||||
diopiTensorHandle_t* outhandle = &outp;
|
diopiTensorHandle_t* outhandle = &outp;
|
||||||
auto scores_p = toDiopiTensorHandle(scores);
|
auto scores_p = toDiopiTensorHandle(scores);
|
||||||
if (reinterpret_cast<void*>(diopiNmsMmcv) != nullptr) {
|
bool is_mock_cuda = boxes.device().type() == c10::DeviceType::PrivateUse1;
|
||||||
|
if (is_mock_cuda && reinterpret_cast<void*>(diopiNmsMmcv) != nullptr) {
|
||||||
auto ret =
|
auto ret =
|
||||||
diopiNmsMmcv(ch, outhandle, boxes_p, scores_p, iou_threshold, offset);
|
diopiNmsMmcv(ch, outhandle, boxes_p, scores_p, iou_threshold, offset);
|
||||||
if (ret == diopiSuccess) {
|
if (ret == diopiSuccess) {
|
||||||
|
|
|
@ -53,7 +53,8 @@ void roi_align_forward_diopi(Tensor input, Tensor rois, Tensor output,
|
||||||
auto out_p = toDiopiTensorHandle(output);
|
auto out_p = toDiopiTensorHandle(output);
|
||||||
auto argmax_y_p = toDiopiTensorHandle(argmax_y);
|
auto argmax_y_p = toDiopiTensorHandle(argmax_y);
|
||||||
auto argmax_x_p = toDiopiTensorHandle(argmax_x);
|
auto argmax_x_p = toDiopiTensorHandle(argmax_x);
|
||||||
if (reinterpret_cast<void*>(diopiRoiAlignMmcv) != nullptr) {
|
bool is_mock_cuda = input.device().type() == c10::DeviceType::PrivateUse1;
|
||||||
|
if (is_mock_cuda && reinterpret_cast<void*>(diopiRoiAlignMmcv) != nullptr) {
|
||||||
auto ret = diopiRoiAlignMmcv(
|
auto ret = diopiRoiAlignMmcv(
|
||||||
ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
|
ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
|
||||||
aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
|
aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
|
||||||
|
@ -91,7 +92,10 @@ void roi_align_backward_diopi(Tensor grad_output, Tensor rois, Tensor argmax_y,
|
||||||
auto grad_input_ = toDiopiTensorHandle(grad_input);
|
auto grad_input_ = toDiopiTensorHandle(grad_input);
|
||||||
diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
|
diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
|
||||||
diopiContextHandle_t ch = &ctx;
|
diopiContextHandle_t ch = &ctx;
|
||||||
if (reinterpret_cast<void*>(diopiRoiAlignBackwardMmcv) != nullptr) {
|
bool is_mock_cuda =
|
||||||
|
grad_output.device().type() == c10::DeviceType::PrivateUse1;
|
||||||
|
if (is_mock_cuda &&
|
||||||
|
reinterpret_cast<void*>(diopiRoiAlignBackwardMmcv) != nullptr) {
|
||||||
auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
|
auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
|
||||||
argmax_y_, argmax_x_, aligned_height,
|
argmax_y_, argmax_x_, aligned_height,
|
||||||
aligned_width, sampling_ratio,
|
aligned_width, sampling_ratio,
|
||||||
|
|
Loading…
Reference in New Issue