mirror of https://github.com/open-mmlab/mmcv.git
parent
4e8972fbf9
commit
1de3aeffd7
|
@ -116,6 +116,10 @@ def bbox_overlaps(bboxes1: torch.Tensor,
|
|||
if rows * cols == 0:
|
||||
return ious
|
||||
|
||||
if bboxes1.device.type == 'cpu' and torch.__version__ == 'parrots':
|
||||
return _bbox_overlaps_cpu(
|
||||
bboxes1, bboxes2, mode=mode, aligned=aligned, offset=offset)
|
||||
|
||||
ext_module.bbox_overlaps(
|
||||
bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset)
|
||||
|
||||
|
|
1
setup.py
1
setup.py
|
@ -212,6 +212,7 @@ def get_extensions():
|
|||
glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') +\
|
||||
glob.glob('./mmcv/ops/csrc/parrots/*.cpp')
|
||||
op_files.remove('./mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu')
|
||||
op_files.remove('./mmcv/ops/csrc/pytorch/cpu/bbox_overlaps_cpu.cpp')
|
||||
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
|
||||
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
|
||||
cuda_args = os.getenv('MMCV_CUDA_ARGS')
|
||||
|
|
Loading…
Reference in New Issue