mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Add type hints in mmcv/ops/points_sampler.py (#2015)
* add typehint in mmcv/ops/points_sampler.py * Update mmcv/ops/points_sampler.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/ops/points_sampler.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
1211b06b9e
commit
19902d897a
@ -1,6 +1,7 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from mmcv.runner import force_fp32
|
from mmcv.runner import force_fp32
|
||||||
@ -8,7 +9,9 @@ from .furthest_point_sample import (furthest_point_sample,
|
|||||||
furthest_point_sample_with_dist)
|
furthest_point_sample_with_dist)
|
||||||
|
|
||||||
|
|
||||||
def calc_square_dist(point_feat_a, point_feat_b, norm=True):
|
def calc_square_dist(point_feat_a: Tensor,
|
||||||
|
point_feat_b: Tensor,
|
||||||
|
norm: bool = True) -> Tensor:
|
||||||
"""Calculating square distance between a and b.
|
"""Calculating square distance between a and b.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -34,7 +37,7 @@ def calc_square_dist(point_feat_a, point_feat_b, norm=True):
|
|||||||
return dist
|
return dist
|
||||||
|
|
||||||
|
|
||||||
def get_sampler_cls(sampler_type):
|
def get_sampler_cls(sampler_type: str) -> nn.Module:
|
||||||
"""Get the type and mode of points sampler.
|
"""Get the type and mode of points sampler.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -74,7 +77,7 @@ class PointsSampler(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_point: List[int],
|
num_point: List[int],
|
||||||
fps_mod_list: List[str] = ['D-FPS'],
|
fps_mod_list: List[str] = ['D-FPS'],
|
||||||
fps_sample_range_list: List[int] = [-1]):
|
fps_sample_range_list: List[int] = [-1]) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# FPS would be applied to different fps_mod in the list,
|
# FPS would be applied to different fps_mod in the list,
|
||||||
# so the length of the num_point should be equal to
|
# so the length of the num_point should be equal to
|
||||||
@ -89,7 +92,7 @@ class PointsSampler(nn.Module):
|
|||||||
self.fp16_enabled = False
|
self.fp16_enabled = False
|
||||||
|
|
||||||
@force_fp32()
|
@force_fp32()
|
||||||
def forward(self, points_xyz, features):
|
def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of
|
points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of
|
||||||
@ -134,10 +137,10 @@ class PointsSampler(nn.Module):
|
|||||||
class DFPSSampler(nn.Module):
|
class DFPSSampler(nn.Module):
|
||||||
"""Using Euclidean distances of points for FPS."""
|
"""Using Euclidean distances of points for FPS."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, points, features, npoint):
|
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
|
||||||
"""Sampling points with D-FPS."""
|
"""Sampling points with D-FPS."""
|
||||||
fps_idx = furthest_point_sample(points.contiguous(), npoint)
|
fps_idx = furthest_point_sample(points.contiguous(), npoint)
|
||||||
return fps_idx
|
return fps_idx
|
||||||
@ -146,10 +149,10 @@ class DFPSSampler(nn.Module):
|
|||||||
class FFPSSampler(nn.Module):
|
class FFPSSampler(nn.Module):
|
||||||
"""Using feature distances for FPS."""
|
"""Using feature distances for FPS."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, points, features, npoint):
|
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
|
||||||
"""Sampling points with F-FPS."""
|
"""Sampling points with F-FPS."""
|
||||||
assert features is not None, \
|
assert features is not None, \
|
||||||
'feature input to FFPS_Sampler should not be None'
|
'feature input to FFPS_Sampler should not be None'
|
||||||
@ -163,10 +166,10 @@ class FFPSSampler(nn.Module):
|
|||||||
class FSSampler(nn.Module):
|
class FSSampler(nn.Module):
|
||||||
"""Using F-FPS and D-FPS simultaneously."""
|
"""Using F-FPS and D-FPS simultaneously."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, points, features, npoint):
|
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
|
||||||
"""Sampling points with FS_Sampling."""
|
"""Sampling points with FS_Sampling."""
|
||||||
assert features is not None, \
|
assert features is not None, \
|
||||||
'feature input to FS_Sampler should not be None'
|
'feature input to FS_Sampler should not be None'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user