diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py index 737179a13..4915e6b57 100644 --- a/mmcv/ops/points_in_boxes.py +++ b/mmcv/ops/points_in_boxes.py @@ -1,4 +1,5 @@ import torch +from torch import Tensor from ..utils import ext_loader @@ -8,7 +9,7 @@ ext_module = ext_loader.load_ext('_ext', [ ]) -def points_in_boxes_part(points, boxes): +def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor: """Find the box in which each point is (CUDA). Args: @@ -56,7 +57,7 @@ def points_in_boxes_part(points, boxes): return box_idxs_of_pts -def points_in_boxes_cpu(points, boxes): +def points_in_boxes_cpu(points: Tensor, boxes: Tensor) -> Tensor: """Find all boxes in which each point is (CPU). The CPU version of :meth:`points_in_boxes_all`. @@ -94,7 +95,7 @@ def points_in_boxes_cpu(points, boxes): return point_indices -def points_in_boxes_all(points, boxes): +def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor: """Find all boxes in which each point is (CUDA). Args: