Add type hints in mmcv/ops/points_sampler.py (#2015)

* add typehint in mmcv/ops/points_sampler.py

* Update mmcv/ops/points_sampler.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/points_sampler.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
gy77 2022-05-28 15:35:17 +08:00 committed by GitHub
parent 1211b06b9e
commit 19902d897a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
from typing import List
import torch
from torch import Tensor
from torch import nn as nn
from mmcv.runner import force_fp32
@ -8,7 +9,9 @@ from .furthest_point_sample import (furthest_point_sample,
furthest_point_sample_with_dist)
def calc_square_dist(point_feat_a, point_feat_b, norm=True):
def calc_square_dist(point_feat_a: Tensor,
point_feat_b: Tensor,
norm: bool = True) -> Tensor:
"""Calculating square distance between a and b.
Args:
@ -34,7 +37,7 @@ def calc_square_dist(point_feat_a, point_feat_b, norm=True):
return dist
def get_sampler_cls(sampler_type):
def get_sampler_cls(sampler_type: str) -> nn.Module:
"""Get the type and mode of points sampler.
Args:
@ -74,7 +77,7 @@ class PointsSampler(nn.Module):
def __init__(self,
num_point: List[int],
fps_mod_list: List[str] = ['D-FPS'],
fps_sample_range_list: List[int] = [-1]):
fps_sample_range_list: List[int] = [-1]) -> None:
super().__init__()
# FPS would be applied to different fps_mod in the list,
# so the length of the num_point should be equal to
@ -89,7 +92,7 @@ class PointsSampler(nn.Module):
self.fp16_enabled = False
@force_fp32()
def forward(self, points_xyz, features):
def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor:
"""
Args:
points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of
@ -134,10 +137,10 @@ class PointsSampler(nn.Module):
class DFPSSampler(nn.Module):
"""Using Euclidean distances of points for FPS."""
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, points, features, npoint):
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with D-FPS."""
fps_idx = furthest_point_sample(points.contiguous(), npoint)
return fps_idx
@ -146,10 +149,10 @@ class DFPSSampler(nn.Module):
class FFPSSampler(nn.Module):
"""Using feature distances for FPS."""
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, points, features, npoint):
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with F-FPS."""
assert features is not None, \
'feature input to FFPS_Sampler should not be None'
@ -163,10 +166,10 @@ class FFPSSampler(nn.Module):
class FSSampler(nn.Module):
"""Using F-FPS and D-FPS simultaneously."""
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, points, features, npoint):
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with FS_Sampling."""
assert features is not None, \
'feature input to FS_Sampler should not be None'