Add type hints in mmcv/ops/points_sampler.py (#2015)

* add typehint in mmcv/ops/points_sampler.py

* Update mmcv/ops/points_sampler.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/points_sampler.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
gy77 2022-05-28 15:35:17 +08:00 committed by GitHub
parent 1211b06b9e
commit 19902d897a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
from typing import List from typing import List
import torch import torch
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmcv.runner import force_fp32 from mmcv.runner import force_fp32
@ -8,7 +9,9 @@ from .furthest_point_sample import (furthest_point_sample,
furthest_point_sample_with_dist) furthest_point_sample_with_dist)
def calc_square_dist(point_feat_a, point_feat_b, norm=True): def calc_square_dist(point_feat_a: Tensor,
point_feat_b: Tensor,
norm: bool = True) -> Tensor:
"""Calculating square distance between a and b. """Calculating square distance between a and b.
Args: Args:
@ -34,7 +37,7 @@ def calc_square_dist(point_feat_a, point_feat_b, norm=True):
return dist return dist
def get_sampler_cls(sampler_type): def get_sampler_cls(sampler_type: str) -> nn.Module:
"""Get the type and mode of points sampler. """Get the type and mode of points sampler.
Args: Args:
@ -74,7 +77,7 @@ class PointsSampler(nn.Module):
def __init__(self, def __init__(self,
num_point: List[int], num_point: List[int],
fps_mod_list: List[str] = ['D-FPS'], fps_mod_list: List[str] = ['D-FPS'],
fps_sample_range_list: List[int] = [-1]): fps_sample_range_list: List[int] = [-1]) -> None:
super().__init__() super().__init__()
# FPS would be applied to different fps_mod in the list, # FPS would be applied to different fps_mod in the list,
# so the length of the num_point should be equal to # so the length of the num_point should be equal to
@ -89,7 +92,7 @@ class PointsSampler(nn.Module):
self.fp16_enabled = False self.fp16_enabled = False
@force_fp32() @force_fp32()
def forward(self, points_xyz, features): def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor:
""" """
Args: Args:
points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of
@ -134,10 +137,10 @@ class PointsSampler(nn.Module):
class DFPSSampler(nn.Module): class DFPSSampler(nn.Module):
"""Using Euclidean distances of points for FPS.""" """Using Euclidean distances of points for FPS."""
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
def forward(self, points, features, npoint): def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with D-FPS.""" """Sampling points with D-FPS."""
fps_idx = furthest_point_sample(points.contiguous(), npoint) fps_idx = furthest_point_sample(points.contiguous(), npoint)
return fps_idx return fps_idx
@ -146,10 +149,10 @@ class DFPSSampler(nn.Module):
class FFPSSampler(nn.Module): class FFPSSampler(nn.Module):
"""Using feature distances for FPS.""" """Using feature distances for FPS."""
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
def forward(self, points, features, npoint): def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with F-FPS.""" """Sampling points with F-FPS."""
assert features is not None, \ assert features is not None, \
'feature input to FFPS_Sampler should not be None' 'feature input to FFPS_Sampler should not be None'
@ -163,10 +166,10 @@ class FFPSSampler(nn.Module):
class FSSampler(nn.Module): class FSSampler(nn.Module):
"""Using F-FPS and D-FPS simultaneously.""" """Using F-FPS and D-FPS simultaneously."""
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
def forward(self, points, features, npoint): def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with FS_Sampling.""" """Sampling points with FS_Sampling."""
assert features is not None, \ assert features is not None, \
'feature input to FS_Sampler should not be None' 'feature input to FS_Sampler should not be None'