mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Add type hints in mmcv/ops/points_sampler.py (#2015)
* add typehint in mmcv/ops/points_sampler.py * Update mmcv/ops/points_sampler.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/ops/points_sampler.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
1211b06b9e
commit
19902d897a
@ -1,6 +1,7 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn as nn
|
||||
|
||||
from mmcv.runner import force_fp32
|
||||
@ -8,7 +9,9 @@ from .furthest_point_sample import (furthest_point_sample,
|
||||
furthest_point_sample_with_dist)
|
||||
|
||||
|
||||
def calc_square_dist(point_feat_a, point_feat_b, norm=True):
|
||||
def calc_square_dist(point_feat_a: Tensor,
|
||||
point_feat_b: Tensor,
|
||||
norm: bool = True) -> Tensor:
|
||||
"""Calculating square distance between a and b.
|
||||
|
||||
Args:
|
||||
@ -34,7 +37,7 @@ def calc_square_dist(point_feat_a, point_feat_b, norm=True):
|
||||
return dist
|
||||
|
||||
|
||||
def get_sampler_cls(sampler_type):
|
||||
def get_sampler_cls(sampler_type: str) -> nn.Module:
|
||||
"""Get the type and mode of points sampler.
|
||||
|
||||
Args:
|
||||
@ -74,7 +77,7 @@ class PointsSampler(nn.Module):
|
||||
def __init__(self,
|
||||
num_point: List[int],
|
||||
fps_mod_list: List[str] = ['D-FPS'],
|
||||
fps_sample_range_list: List[int] = [-1]):
|
||||
fps_sample_range_list: List[int] = [-1]) -> None:
|
||||
super().__init__()
|
||||
# FPS would be applied to different fps_mod in the list,
|
||||
# so the length of the num_point should be equal to
|
||||
@ -89,7 +92,7 @@ class PointsSampler(nn.Module):
|
||||
self.fp16_enabled = False
|
||||
|
||||
@force_fp32()
|
||||
def forward(self, points_xyz, features):
|
||||
def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of
|
||||
@ -134,10 +137,10 @@ class PointsSampler(nn.Module):
|
||||
class DFPSSampler(nn.Module):
|
||||
"""Using Euclidean distances of points for FPS."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, points, features, npoint):
|
||||
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
|
||||
"""Sampling points with D-FPS."""
|
||||
fps_idx = furthest_point_sample(points.contiguous(), npoint)
|
||||
return fps_idx
|
||||
@ -146,10 +149,10 @@ class DFPSSampler(nn.Module):
|
||||
class FFPSSampler(nn.Module):
|
||||
"""Using feature distances for FPS."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, points, features, npoint):
|
||||
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
|
||||
"""Sampling points with F-FPS."""
|
||||
assert features is not None, \
|
||||
'feature input to FFPS_Sampler should not be None'
|
||||
@ -163,10 +166,10 @@ class FFPSSampler(nn.Module):
|
||||
class FSSampler(nn.Module):
|
||||
"""Using F-FPS and D-FPS simultaneously."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, points, features, npoint):
|
||||
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
|
||||
"""Sampling points with FS_Sampling."""
|
||||
assert features is not None, \
|
||||
'feature input to FS_Sampler should not be None'
|
||||
|
Loading…
x
Reference in New Issue
Block a user