Add type hints for mmcv/ops (#1995)

* Merge Master

* Add typehint in mmcv/ops/*

* Fix

* Update mmcv/ops/roi_align.py

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Fix

* Fix

* Fix

* Update mmcv/ops/riroi_align_rotated.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/riroi_align_rotated.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* remove type hints of all symbolic methods

* remove type hints of all symbolic methods

* minor refinement

* minor refinement

* minor fix

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zhouzaida <zhouzaida@163.com>
pull/2070/head
tripleMu 2022-06-20 14:50:28 +08:00 committed by GitHub
parent 3dd2a21b45
commit 2d3e42fc41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 239 additions and 145 deletions

View File

@ -1,4 +1,7 @@
# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa
from typing import Optional, Tuple
import torch
from torch import nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair
@ -20,7 +23,8 @@ class PSAMaskFunction(Function):
mask_size_i=mask_size)
@staticmethod
def forward(ctx, input, psa_type, mask_size):
def forward(ctx, input: torch.Tensor, psa_type: str,
mask_size: int) -> torch.Tensor:
ctx.psa_type = psa_type
ctx.mask_size = _pair(mask_size)
ctx.save_for_backward(input)
@ -45,7 +49,9 @@ class PSAMaskFunction(Function):
return output
@staticmethod
def backward(ctx, grad_output):
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[torch.Tensor, None, None, None]:
input = ctx.saved_tensors[0]
psa_type = ctx.psa_type
h_mask, w_mask = ctx.mask_size
@ -71,7 +77,7 @@ psa_mask = PSAMaskFunction.apply
class PSAMask(nn.Module):
def __init__(self, psa_type, mask_size=None):
def __init__(self, psa_type: str, mask_size: Optional[tuple] = None):
super().__init__()
assert psa_type in ['collect', 'distribute']
if psa_type == 'collect':
@ -82,7 +88,7 @@ class PSAMask(nn.Module):
self.mask_size = mask_size
self.psa_type = psa_type
def forward(self, input):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return psa_mask(input, self.psa_type_enum, self.mask_size)
def __repr__(self):

View File

@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.autograd import Function
@ -11,14 +14,14 @@ ext_module = ext_loader.load_ext(
class RiRoIAlignRotatedFunction(Function):
@staticmethod
def forward(ctx,
features,
rois,
out_size,
spatial_scale,
num_samples=0,
num_orientations=8,
clockwise=False):
def forward(ctx: Any,
features: torch.Tensor,
rois: torch.Tensor,
out_size: Union[int, tuple],
spatial_scale: float,
num_samples: int = 0,
num_orientations: int = 8,
clockwise: bool = False) -> torch.Tensor:
if isinstance(out_size, int):
out_h = out_size
out_w = out_size
@ -54,7 +57,9 @@ class RiRoIAlignRotatedFunction(Function):
return output
@staticmethod
def backward(ctx, grad_output):
def backward(
ctx: Any, grad_output: torch.Tensor
) -> Optional[Tuple[torch.Tensor, None, None, None, None, None, None]]:
feature_size = ctx.feature_size
spatial_scale = ctx.spatial_scale
num_orientations = ctx.num_orientations
@ -67,7 +72,7 @@ class RiRoIAlignRotatedFunction(Function):
out_w = grad_output.size(3)
out_h = grad_output.size(2)
grad_input = grad_rois = None
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = rois.new_zeros(batch_size, num_channels, feature_h,
@ -83,7 +88,8 @@ class RiRoIAlignRotatedFunction(Function):
num_orientations=num_orientations,
clockwise=clockwise)
return grad_input, grad_rois, None, None, None, None, None
return grad_input, None, None, None, None, None, None
return None
riroi_align_rotated = RiRoIAlignRotatedFunction.apply
@ -111,11 +117,11 @@ class RiRoIAlignRotated(nn.Module):
"""
def __init__(self,
out_size,
spatial_scale,
num_samples=0,
num_orientations=8,
clockwise=False):
out_size: tuple,
spatial_scale: float,
num_samples: int = 0,
num_orientations: int = 8,
clockwise: bool = False):
super().__init__()
self.out_size = out_size
@ -124,7 +130,8 @@ class RiRoIAlignRotated(nn.Module):
self.num_orientations = int(num_orientations)
self.clockwise = clockwise
def forward(self, features, rois):
def forward(self, features: torch.Tensor,
rois: torch.Tensor) -> torch.Tensor:
return RiRoIAlignRotatedFunction.apply(features, rois, self.out_size,
self.spatial_scale,
self.num_samples,

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any
import torch
import torch.nn as nn
from torch.autograd import Function
@ -62,14 +64,14 @@ class RoIAlignFunction(Function):
mode_s=pool_mode)
@staticmethod
def forward(ctx,
input,
rois,
output_size,
spatial_scale=1.0,
sampling_ratio=0,
pool_mode='avg',
aligned=True):
def forward(ctx: Any,
input: torch.Tensor,
rois: torch.Tensor,
output_size: int,
spatial_scale: float = 1.0,
sampling_ratio: int = 0,
pool_mode: str = 'avg',
aligned: bool = True) -> torch.Tensor:
ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale
ctx.sampling_ratio = sampling_ratio
@ -108,7 +110,7 @@ class RoIAlignFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
rois, argmax_y, argmax_x = ctx.saved_tensors
grad_input = grad_output.new_zeros(ctx.input_shape)
# complex head architecture may cause grad_output uncontiguous.
@ -175,12 +177,12 @@ class RoIAlign(nn.Module):
},
cls_name='RoIAlign')
def __init__(self,
output_size,
spatial_scale=1.0,
sampling_ratio=0,
pool_mode='avg',
aligned=True,
use_torchvision=False):
output_size: tuple,
spatial_scale: float = 1.0,
sampling_ratio: int = 0,
pool_mode: str = 'avg',
aligned: bool = True,
use_torchvision: bool = False):
super().__init__()
self.output_size = _pair(output_size)
@ -190,7 +192,7 @@ class RoIAlign(nn.Module):
self.aligned = aligned
self.use_torchvision = use_torchvision
def forward(self, input, rois):
def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
"""
Args:
input: NCHW images

View File

@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair
@ -37,14 +40,14 @@ class RoIAlignRotatedFunction(Function):
clockwise_i=clockwise)
@staticmethod
def forward(ctx,
input,
rois,
output_size,
spatial_scale,
sampling_ratio=0,
aligned=True,
clockwise=False):
def forward(ctx: Any,
input: torch.Tensor,
rois: torch.Tensor,
output_size: Union[int, tuple],
spatial_scale: float,
sampling_ratio: int = 0,
aligned: bool = True,
clockwise: bool = False) -> torch.Tensor:
ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale
ctx.sampling_ratio = sampling_ratio
@ -71,7 +74,10 @@ class RoIAlignRotatedFunction(Function):
return output
@staticmethod
def backward(ctx, grad_output):
def backward(
ctx: Any, grad_output: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], None, None,
None, None, None]:
feature_size = ctx.feature_size
rois = ctx.saved_tensors[0]
assert feature_size is not None
@ -151,11 +157,11 @@ class RoIAlignRotated(nn.Module):
},
cls_name='RoIAlignRotated')
def __init__(self,
output_size,
spatial_scale,
sampling_ratio=0,
aligned=True,
clockwise=False):
output_size: Union[int, tuple],
spatial_scale: float,
sampling_ratio: int = 0,
aligned: bool = True,
clockwise: bool = False):
super().__init__()
self.output_size = _pair(output_size)
@ -164,7 +170,7 @@ class RoIAlignRotated(nn.Module):
self.aligned = aligned
self.clockwise = clockwise
def forward(self, input, rois):
def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
return RoIAlignRotatedFunction.apply(input, rois, self.output_size,
self.spatial_scale,
self.sampling_ratio, self.aligned,

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Tuple, Union
import torch
import torch.nn as nn
from torch.autograd import Function
@ -23,7 +25,11 @@ class RoIPoolFunction(Function):
spatial_scale_f=spatial_scale)
@staticmethod
def forward(ctx, input, rois, output_size, spatial_scale=1.0):
def forward(ctx: Any,
input: torch.Tensor,
rois: torch.Tensor,
output_size: Union[int, tuple],
spatial_scale: float = 1.0) -> torch.Tensor:
ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale
ctx.input_shape = input.size()
@ -49,7 +55,9 @@ class RoIPoolFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(
ctx: Any, grad_output: torch.Tensor
) -> Tuple[torch.Tensor, None, None, None]:
rois, argmax = ctx.saved_tensors
grad_input = grad_output.new_zeros(ctx.input_shape)
@ -70,13 +78,15 @@ roi_pool = RoIPoolFunction.apply
class RoIPool(nn.Module):
def __init__(self, output_size, spatial_scale=1.0):
def __init__(self,
output_size: Union[int, tuple],
spatial_scale: float = 1.0):
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)
def forward(self, input, rois):
def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
return roi_pool(input, rois, self.output_size, self.spatial_scale)
def __repr__(self):

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Tuple, Union
import torch
from torch import nn as nn
from torch.autograd import Function
@ -25,7 +27,10 @@ class RoIAwarePool3d(nn.Module):
Default: 'max'.
"""
def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
def __init__(self,
out_size: Union[int, tuple],
max_pts_per_voxel: int = 128,
mode: str = 'max'):
super().__init__()
self.out_size = out_size
@ -34,7 +39,8 @@ class RoIAwarePool3d(nn.Module):
pool_mapping = {'max': 0, 'avg': 1}
self.mode = pool_mapping[mode]
def forward(self, rois, pts, pts_feature):
def forward(self, rois: torch.Tensor, pts: torch.Tensor,
pts_feature: torch.Tensor) -> torch.Tensor:
"""
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
@ -55,8 +61,9 @@ class RoIAwarePool3d(nn.Module):
class RoIAwarePool3dFunction(Function):
@staticmethod
def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
mode):
def forward(ctx: Any, rois: torch.Tensor, pts: torch.Tensor,
pts_feature: torch.Tensor, out_size: Union[int, tuple],
max_pts_per_voxel: int, mode: int) -> torch.Tensor:
"""
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
@ -108,7 +115,9 @@ class RoIAwarePool3dFunction(Function):
return pooled_features
@staticmethod
def backward(ctx, grad_out):
def backward(
ctx: Any, grad_out: torch.Tensor
) -> Tuple[None, None, torch.Tensor, None, None, None]:
ret = ctx.roiaware_pool3d_for_backward
pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret

View File

@ -1,3 +1,6 @@
from typing import Any, Tuple
import torch
from torch import nn as nn
from torch.autograd import Function
@ -17,11 +20,12 @@ class RoIPointPool3d(nn.Module):
Default: 512.
"""
def __init__(self, num_sampled_points=512):
def __init__(self, num_sampled_points: int = 512):
super().__init__()
self.num_sampled_points = num_sampled_points
def forward(self, points, point_features, boxes3d):
def forward(self, points: torch.Tensor, point_features: torch.Tensor,
boxes3d: torch.Tensor) -> Tuple[torch.Tensor]:
"""
Args:
points (torch.Tensor): Input points whose shape is (B, N, C).
@ -41,7 +45,13 @@ class RoIPointPool3d(nn.Module):
class RoIPointPool3dFunction(Function):
@staticmethod
def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
def forward(
ctx: Any,
points: torch.Tensor,
point_features: torch.Tensor,
boxes3d: torch.Tensor,
num_sampled_points: int = 512
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
points (torch.Tensor): Input points whose shape is (B, N, C).
@ -73,5 +83,5 @@ class RoIPointPool3dFunction(Function):
return pooled_features, pooled_empty_flag
@staticmethod
def backward(ctx, grad_out):
def backward(ctx: Any, grad_out: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
@ -31,7 +33,8 @@ class RotatedFeatureAlignFunction(Function):
points_i=points)
@staticmethod
def forward(ctx, features, best_rbboxes, spatial_scale, points):
def forward(ctx: Any, features: torch.Tensor, best_rbboxes: torch.Tensor,
spatial_scale: float, points: int) -> torch.Tensor:
"""
Args:
features (torch.Tensor): Input features with shape [N,C,H,W].
@ -60,7 +63,7 @@ class RotatedFeatureAlignFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
"""
Args:
grad_output (torch.Tensor): The gradiant of output features
@ -84,9 +87,9 @@ class RotatedFeatureAlignFunction(Function):
return grad_input, None, None, None
def rotated_feature_align(features,
best_rbboxes,
spatial_scale=1 / 8,
points=1):
def rotated_feature_align(features: torch.Tensor,
best_rbboxes: torch.Tensor,
spatial_scale: float = 1 / 8,
points: int = 1) -> torch.Tensor:
return RotatedFeatureAlignFunction.apply(features, best_rbboxes,
spatial_scale, points)

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
@ -14,7 +16,10 @@ ext_module = ext_loader.load_ext(
class _DynamicScatter(Function):
@staticmethod
def forward(ctx, feats, coors, reduce_type='max'):
def forward(ctx: Any,
feats: torch.Tensor,
coors: torch.Tensor,
reduce_type: str = 'max') -> Tuple[torch.Tensor, torch.Tensor]:
"""convert kitti points(N, >=3) to voxels.
Args:
@ -42,7 +47,9 @@ class _DynamicScatter(Function):
return voxel_feats, voxel_coors
@staticmethod
def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
def backward(ctx: Any,
grad_voxel_feats: torch.Tensor,
grad_voxel_coors: Optional[torch.Tensor] = None) -> tuple:
(feats, voxel_feats, point2voxel_map,
voxel_points_count) = ctx.saved_tensors
grad_feats = torch.zeros_like(feats)
@ -73,14 +80,17 @@ class DynamicScatter(nn.Module):
into voxel.
"""
def __init__(self, voxel_size, point_cloud_range, average_points: bool):
def __init__(self, voxel_size: List, point_cloud_range: List,
average_points: bool):
super().__init__()
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
self.average_points = average_points
def forward_single(self, points, coors):
def forward_single(
self, points: torch.Tensor,
coors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Scatters points into voxels.
Args:
@ -97,7 +107,8 @@ class DynamicScatter(nn.Module):
reduce = 'mean' if self.average_points else 'max'
return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce)
def forward(self, points, coors):
def forward(self, points: torch.Tensor,
coors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Scatters points/features into voxels.
Args:

View File

@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from torch.autograd import Function
from . import sparse_ops as ops
@ -25,8 +27,9 @@ class SparseConvFunction(Function):
"""
@staticmethod
def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out):
def forward(ctx: Any, features: torch.Tensor, filters: torch.nn.Parameter,
indice_pairs: torch.Tensor, indice_pair_num: torch.Tensor,
num_activate_out: torch.Tensor) -> torch.Tensor:
"""
Args:
features (torch.Tensor): Features that needs to convolute.
@ -44,7 +47,7 @@ class SparseConvFunction(Function):
indice_pair_num, num_activate_out, False)
@staticmethod
def backward(ctx, grad_output):
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward(
features, filters, grad_output, indice_pairs, indice_pair_num,
@ -56,8 +59,9 @@ class SparseConvFunction(Function):
class SparseInverseConvFunction(Function):
@staticmethod
def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out):
def forward(ctx: Any, features: torch.Tensor, filters: torch.nn.Parameter,
indice_pairs: torch.Tensor, indice_pair_num: torch.Tensor,
num_activate_out: torch.Tensor) -> torch.Tensor:
"""
Args:
features (torch.Tensor): Features that needs to convolute.
@ -75,7 +79,7 @@ class SparseInverseConvFunction(Function):
indice_pair_num, num_activate_out, True, False)
@staticmethod
def backward(ctx, grad_output):
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward(
features, filters, grad_output, indice_pairs, indice_pair_num,
@ -87,8 +91,9 @@ class SparseInverseConvFunction(Function):
class SubMConvFunction(Function):
@staticmethod
def forward(ctx, features, filters, indice_pairs, indice_pair_num,
num_activate_out):
def forward(ctx: Any, features: torch.Tensor, filters: torch.nn.Parameter,
indice_pairs: torch.Tensor, indice_pair_num: torch.Tensor,
num_activate_out: torch.Tensor) -> torch.Tensor:
"""
Args:
features (torch.Tensor): Features that needs to convolute.
@ -106,7 +111,7 @@ class SubMConvFunction(Function):
indice_pair_num, num_activate_out, False, True)
@staticmethod
def backward(ctx, grad_output):
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
input_bp, filters_bp = ops.indice_conv_backward(
features, filters, grad_output, indice_pairs, indice_pair_num,
@ -118,8 +123,9 @@ class SubMConvFunction(Function):
class SparseMaxPoolFunction(Function):
@staticmethod
def forward(ctx, features, indice_pairs, indice_pair_num,
num_activate_out):
def forward(ctx, features: torch.Tensor, indice_pairs: torch.Tensor,
indice_pair_num: torch.Tensor,
num_activate_out: torch.Tensor) -> torch.Tensor:
"""
Args:
features (torch.Tensor): Features that needs to convolute.
@ -137,7 +143,7 @@ class SparseMaxPoolFunction(Function):
return out
@staticmethod
def backward(ctx, grad_output):
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
indice_pairs, indice_pair_num, features, out = ctx.saved_tensors
input_bp = ops.indice_maxpool_backward(features, out, grad_output,
indice_pairs, indice_pair_num)

View File

@ -13,6 +13,7 @@
# limitations under the License.
import sys
from collections import OrderedDict
from typing import Any, List, Optional, Union
import torch
from torch import nn
@ -20,17 +21,18 @@ from torch import nn
from .sparse_structure import SparseConvTensor
def is_spconv_module(module):
def is_spconv_module(module: nn.Module) -> bool:
spconv_modules = (SparseModule, )
return isinstance(module, spconv_modules)
def is_sparse_conv(module):
def is_sparse_conv(module: nn.Module) -> bool:
from .sparse_conv import SparseConvolution
return isinstance(module, SparseConvolution)
def _mean_update(vals, m_vals, t):
def _mean_update(vals: Union[int, List], m_vals: Union[int, List],
t: float) -> List:
outputs = []
if not isinstance(vals, list):
vals = [vals]
@ -101,7 +103,7 @@ class SparseSequential(SparseModule):
self.add_module(name, module)
self._sparity_dict = {}
def __getitem__(self, idx):
def __getitem__(self, idx: int) -> torch.Tensor:
if not (-len(self) <= idx < len(self)):
raise IndexError(f'index {idx} is out of range')
if idx < 0:
@ -118,14 +120,14 @@ class SparseSequential(SparseModule):
def sparity_dict(self):
return self._sparity_dict
def add(self, module, name=None):
def add(self, module: Any, name: Optional[str] = None) -> None:
if name is None:
name = str(len(self._modules))
if name in self._modules:
raise KeyError('name exists')
self.add_module(name, module)
def forward(self, input):
def forward(self, input: torch.Tensor) -> torch.Tensor:
for k, module in self._modules.items():
if is_spconv_module(module):
assert isinstance(input, SparseConvTensor)

View File

@ -1,8 +1,11 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
def scatter_nd(indices, updates, shape):
def scatter_nd(indices: torch.Tensor, updates: torch.Tensor,
shape: torch.Tensor) -> torch.Tensor:
"""pytorch edition of tensorflow scatter_nd.
this function don't contain except handle code. so use this carefully when
@ -21,18 +24,18 @@ def scatter_nd(indices, updates, shape):
class SparseConvTensor:
def __init__(self,
features,
indices,
spatial_shape,
batch_size,
grid=None):
features: torch.Tensor,
indices: torch.Tensor,
spatial_shape: Union[List, Tuple],
batch_size: int,
grid: Optional[torch.Tensor] = None):
self.features = features
self.indices = indices
if self.indices.dtype != torch.int32:
self.indices.int()
self.spatial_shape = spatial_shape
self.batch_size = batch_size
self.indice_dict = {}
self.indice_dict: dict = {}
self.grid = grid
@property
@ -46,7 +49,7 @@ class SparseConvTensor:
return self.indice_dict[key]
return None
def dense(self, channels_first=True):
def dense(self, channels_first: bool = True) -> torch.Tensor:
output_shape = [self.batch_size] + list(
self.spatial_shape) + [self.features.shape[1]]
res = scatter_nd(self.indices.long(), self.features, output_shape)

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
@ -35,8 +37,10 @@ class SyncBatchNormFunction(Function):
stats_mode=stats_mode)
@staticmethod
def forward(self, input, running_mean, running_var, weight, bias, momentum,
eps, group, group_size, stats_mode):
def forward(self, input: torch.Tensor, running_mean: torch.Tensor,
running_var: torch.Tensor, weight: torch.Tensor,
bias: torch.Tensor, momentum: float, eps: float, group: int,
group_size: int, stats_mode: str) -> torch.Tensor:
self.momentum = momentum
self.eps = eps
self.group = group
@ -126,7 +130,7 @@ class SyncBatchNormFunction(Function):
@staticmethod
@once_differentiable
def backward(self, grad_output):
def backward(self, grad_output: torch.Tensor) -> tuple:
norm, std, weight = self.saved_tensors
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(weight)
@ -191,13 +195,13 @@ class SyncBatchNorm(Module):
"""
def __init__(self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True,
group=None,
stats_mode='default'):
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
group: Optional[int] = None,
stats_mode: str = 'default'):
super().__init__()
self.num_features = num_features
self.eps = eps
@ -239,7 +243,7 @@ class SyncBatchNorm(Module):
self.weight.data.uniform_() # pytorch use ones_()
self.bias.data.zero_()
def forward(self, input):
def forward(self, input: torch.Tensor) -> torch.Tensor:
if input.dim() < 2:
raise ValueError(
f'expected at least 2D input, got {input.dim()}D input')

View File

@ -1,4 +1,4 @@
from typing import Tuple
from typing import Any, Tuple
import torch
from torch.autograd import Function
@ -17,7 +17,7 @@ class ThreeInterpolate(Function):
"""
@staticmethod
def forward(ctx, features: torch.Tensor, indices: torch.Tensor,
def forward(ctx: Any, features: torch.Tensor, indices: torch.Tensor,
weight: torch.Tensor) -> torch.Tensor:
"""
Args:

View File

@ -1,4 +1,4 @@
from typing import Tuple
from typing import Any, Tuple
import torch
from torch.autograd import Function
@ -16,7 +16,7 @@ class ThreeNN(Function):
"""
@staticmethod
def forward(ctx, target: torch.Tensor,
def forward(ctx: Any, target: torch.Tensor,
source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:

View File

@ -95,6 +95,8 @@
# =======================================================================
from typing import Any, List, Tuple, Union
import torch
from torch.autograd import Function
from torch.nn import functional as F
@ -108,8 +110,10 @@ upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
class UpFirDn2dBackward(Function):
@staticmethod
def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
in_size, out_size):
def forward(ctx: Any, grad_output: torch.Tensor, kernel: torch.Tensor,
grad_kernel: torch.Tensor, up: tuple, down: tuple, pad: tuple,
g_pad: tuple, in_size: Union[List, Tuple],
out_size: Union[List, Tuple]) -> torch.Tensor:
up_x, up_y = up
down_x, down_y = down
@ -149,7 +153,7 @@ class UpFirDn2dBackward(Function):
return grad_input
@staticmethod
def backward(ctx, gradgrad_input):
def backward(ctx: Any, gradgrad_input: torch.Tensor) -> tuple:
kernel, = ctx.saved_tensors
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
@ -177,7 +181,8 @@ class UpFirDn2dBackward(Function):
class UpFirDn2d(Function):
@staticmethod
def forward(ctx, input, kernel, up, down, pad):
def forward(ctx: Any, input: torch.Tensor, kernel: torch.Tensor, up: tuple,
down: tuple, pad: tuple) -> torch.Tensor:
up_x, up_y = up
down_x, down_y = down
pad_x0, pad_x1, pad_y0, pad_y1 = pad
@ -222,7 +227,7 @@ class UpFirDn2d(Function):
return out
@staticmethod
def backward(ctx, grad_output):
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
kernel, grad_kernel = ctx.saved_tensors
grad_input = UpFirDn2dBackward.apply(
@ -240,7 +245,12 @@ class UpFirDn2d(Function):
return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
def upfirdn2d(
input: torch.Tensor,
kernel: torch.Tensor,
up: Union[int, tuple] = 1,
down: Union[int, tuple] = 1,
pad: tuple = (0, 0)) -> torch.Tensor: # noqa E125
"""UpFRIDn for 2d features.
UpFIRDn is short for upsample, apply FIR filter and downsample. More
@ -264,14 +274,14 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
"""
if input.device.type == 'cpu':
if len(pad) == 2:
pad = (pad[0], pad[1], pad[0], pad[1])
pad = (pad[0], pad[1], pad[0], pad[1]) # type: ignore
up = to_2tuple(up)
_up = to_2tuple(up)
down = to_2tuple(down)
_down = to_2tuple(down)
out = upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1],
pad[0], pad[1], pad[2], pad[3])
out = upfirdn2d_native(input, kernel, _up[0], _up[1], _down[0],
_down[1], pad[0], pad[1], pad[2], pad[3])
else:
_up = to_2tuple(up)
@ -287,8 +297,9 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return out
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
pad_y0, pad_y1):
def upfirdn2d_native(input: torch.Tensor, kernel: torch.Tensor, up_x: int,
up_y: int, down_x: int, down_y: int, pad_x0: int,
pad_x1: int, pad_y0: int, pad_y1: int) -> torch.Tensor:
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Tuple, Union
import torch
from torch import nn
from torch.autograd import Function
@ -13,13 +15,14 @@ ext_module = ext_loader.load_ext(
class _Voxelization(Function):
@staticmethod
def forward(ctx,
points,
voxel_size,
coors_range,
max_points=35,
max_voxels=20000,
deterministic=True):
def forward(
ctx: Any,
points: torch.Tensor,
voxel_size: Union[tuple, float],
coors_range: Union[tuple, float],
max_points: int = 35,
max_voxels: int = 20000,
deterministic: bool = True) -> Union[Tuple[torch.Tensor], Tuple]:
"""Convert kitti points(N, >=3) to voxels.
Args:
@ -111,11 +114,11 @@ class Voxelization(nn.Module):
"""
def __init__(self,
voxel_size,
point_cloud_range,
max_num_points,
max_voxels=20000,
deterministic=True):
voxel_size: List,
point_cloud_range: List,
max_num_points: int,
max_voxels: Union[tuple, int] = 20000,
deterministic: bool = True):
"""
Args:
voxel_size (list): list [x, y, z] size of three dimension
@ -149,8 +152,9 @@ class Voxelization(nn.Module):
point_cloud_range = torch.tensor(
point_cloud_range, dtype=torch.float32)
voxel_size = torch.tensor(voxel_size, dtype=torch.float32)
grid_size = (point_cloud_range[3:] -
point_cloud_range[:3]) / voxel_size
grid_size = (
point_cloud_range[3:] - # type: ignore
point_cloud_range[:3]) / voxel_size # type: ignore
grid_size = torch.round(grid_size).long()
input_feat_shape = grid_size[:2]
self.grid_size = grid_size
@ -158,7 +162,7 @@ class Voxelization(nn.Module):
# [w, h, d] -> [d, h, w]
self.pcd_shape = [*input_feat_shape, 1][::-1]
def forward(self, input):
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.training:
max_voxels = self.max_voxels[0]
else: