mirror of https://github.com/open-mmlab/mmcv.git
Add type hints for mmcv/ops (#1987)
* add type hints for mmcv/ops/... * add type hints for mmcv/ops/... * add type hints for mmcv/ops/...pull/2005/head^2
parent
84a544fb3e
commit
699398ad86
|
@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|||
from mmcv.cnn import PLUGIN_LAYERS, Scale
|
||||
|
||||
|
||||
def NEG_INF_DIAG(n, device):
|
||||
def NEG_INF_DIAG(n: int, device: torch.device) -> torch.Tensor:
|
||||
"""Returns a diagonal matrix of size [n, n].
|
||||
|
||||
The diagonal are all "-inf". This is for avoiding calculating the
|
||||
|
@ -41,7 +41,7 @@ class CrissCrossAttention(nn.Module):
|
|||
in_channels (int): Channels of the input feature map.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels):
|
||||
def __init__(self, in_channels: int) -> None:
|
||||
super().__init__()
|
||||
self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
|
||||
self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
|
||||
|
@ -49,7 +49,7 @@ class CrissCrossAttention(nn.Module):
|
|||
self.gamma = Scale(0.)
|
||||
self.in_channels = in_channels
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""forward function of Criss-Cross Attention.
|
||||
|
||||
Args:
|
||||
|
@ -78,7 +78,7 @@ class CrissCrossAttention(nn.Module):
|
|||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
s = self.__class__.__name__
|
||||
s += f'(in_channels={self.in_channels})'
|
||||
return s
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
@ -7,8 +9,9 @@ from ..utils import ext_loader
|
|||
ext_module = ext_loader.load_ext('_ext', ['contour_expand'])
|
||||
|
||||
|
||||
def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area,
|
||||
kernel_num):
|
||||
def contour_expand(kernel_mask: Union[np.array, torch.Tensor],
|
||||
internal_kernel_label: Union[np.array, torch.Tensor],
|
||||
min_kernel_area: int, kernel_num: int) -> list:
|
||||
"""Expand kernel contours so that foreground pixels are assigned into
|
||||
instances.
|
||||
|
||||
|
@ -42,7 +45,7 @@ def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area,
|
|||
internal_kernel_label,
|
||||
min_kernel_area=min_kernel_area,
|
||||
kernel_num=kernel_num)
|
||||
label = label.tolist()
|
||||
label = label.tolist() # type: ignore
|
||||
else:
|
||||
label = ext_module.contour_expand(kernel_mask, internal_kernel_label,
|
||||
min_kernel_area, kernel_num)
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext('_ext', ['convex_iou', 'convex_giou'])
|
||||
|
||||
|
||||
def convex_giou(pointsets, polygons):
|
||||
def convex_giou(pointsets: torch.Tensor,
|
||||
polygons: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Return generalized intersection-over-union (Jaccard index) between point
|
||||
sets and polygons.
|
||||
|
||||
|
@ -26,7 +31,8 @@ def convex_giou(pointsets, polygons):
|
|||
return convex_giou, points_grad
|
||||
|
||||
|
||||
def convex_iou(pointsets, polygons):
|
||||
def convex_iou(pointsets: torch.Tensor,
|
||||
polygons: torch.Tensor) -> torch.Tensor:
|
||||
"""Return intersection-over-union (Jaccard index) between point sets and
|
||||
polygons.
|
||||
|
||||
|
|
Loading…
Reference in New Issue