Add type hints for mmcv/cnn/bricks (#1993)

* Add type hint

* Add typehint in mmcv/cnn/bricks*

* Deal conflict0

* Fix

* fix

* minor fix

* minor fix

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
tripleMu 2022-06-20 14:52:45 +08:00 committed by GitHub
parent 2d3e42fc41
commit 305c2a3025
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 207 additions and 155 deletions

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -28,12 +30,12 @@ class Clamp(nn.Module):
Default to 1.
"""
def __init__(self, min=-1., max=1.):
def __init__(self, min: float = -1., max: float = 1.):
super().__init__()
self.min = min
self.max = max
def forward(self, x):
def forward(self, x) -> torch.Tensor:
"""Forward function.
Args:
@ -67,7 +69,7 @@ class GELU(nn.Module):
>>> output = m(input)
"""
def forward(self, input):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.gelu(input)
@ -78,7 +80,7 @@ else:
ACTIVATION_LAYERS.register_module(module=nn.GELU)
def build_activation_layer(cfg):
def build_activation_layer(cfg: Dict) -> nn.Module:
"""Build activation layer.
Args:

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import torch
from torch import nn
@ -6,7 +8,7 @@ from ..utils import constant_init, kaiming_init
from .registry import PLUGIN_LAYERS
def last_zero_init(m):
def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
if isinstance(m, nn.Sequential):
constant_init(m[-1], val=0)
else:
@ -34,10 +36,10 @@ class ContextBlock(nn.Module):
_abbr_ = 'context_block'
def __init__(self,
in_channels,
ratio,
pooling_type='att',
fusion_types=('channel_add', )):
in_channels: int,
ratio: float,
pooling_type: str = 'att',
fusion_types: tuple = ('channel_add', )):
super().__init__()
assert pooling_type in ['avg', 'att']
assert isinstance(fusion_types, (list, tuple))
@ -82,7 +84,7 @@ class ContextBlock(nn.Module):
if self.channel_mul_conv is not None:
last_zero_init(self.channel_mul_conv)
def spatial_pool(self, x):
def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
batch, channel, height, width = x.size()
if self.pooling_type == 'att':
input_x = x
@ -108,7 +110,7 @@ class ContextBlock(nn.Module):
return context
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# [N, C, 1, 1]
context = self.spatial_pool(x)

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
from torch import nn
from .registry import CONV_LAYERS
@ -9,7 +11,7 @@ CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
def build_conv_layer(cfg, *args, **kwargs):
def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
"""Build convolution layer.
Args:

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
@ -31,18 +33,18 @@ class Conv2dAdaptivePadding(nn.Conv2d):
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True):
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True):
super().__init__(in_channels, out_channels, kernel_size, stride, 0,
dilation, groups, bias)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
img_h, img_w = x.size()[-2:]
kernel_h, kernel_w = self.weight.size()[-2:]
stride_h, stride_w = self.stride

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmcv.utils import _BatchNorm, _InstanceNorm
@ -68,21 +70,21 @@ class ConvModule(nn.Module):
_abbr_ = 'conv_block'
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias='auto',
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
inplace=True,
with_spectral_norm=False,
padding_mode='zeros',
order=('conv', 'norm', 'act')):
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: Union[bool, str] = 'auto',
conv_cfg: Optional[Dict] = None,
norm_cfg: Optional[Dict] = None,
act_cfg: Optional[Dict] = dict(type='ReLU'),
inplace: bool = True,
with_spectral_norm: bool = False,
padding_mode: str = 'zeros',
order: tuple = ('conv', 'norm', 'act')):
super().__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict)
@ -143,18 +145,19 @@ class ConvModule(nn.Module):
norm_channels = out_channels
else:
norm_channels = in_channels
self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
self.norm_name, norm = build_norm_layer(
norm_cfg, norm_channels) # type: ignore
self.add_module(self.norm_name, norm)
if self.with_bias:
if isinstance(norm, (_BatchNorm, _InstanceNorm)):
warnings.warn(
'Unnecessary conv bias before batch/instance norm')
else:
self.norm_name = None
self.norm_name = None # type: ignore
# build activation layer
if self.with_activation:
act_cfg_ = act_cfg.copy()
act_cfg_ = act_cfg.copy() # type: ignore
# nn.Tanh has no 'inplace' argument
if act_cfg_['type'] not in [
'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
@ -193,7 +196,10 @@ class ConvModule(nn.Module):
if self.with_norm:
constant_init(self.norm, 1, bias=0)
def forward(self, x, activate=True, norm=True):
def forward(self,
x: torch.Tensor,
activate: bool = True,
norm: bool = True) -> torch.Tensor:
for layer in self.order:
if layer == 'conv':
if self.with_explicit_padding:

View File

@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -6,14 +9,14 @@ import torch.nn.functional as F
from .registry import CONV_LAYERS
def conv_ws_2d(input,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
eps=1e-5):
def conv_ws_2d(input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
eps: float = 1e-5) -> torch.Tensor:
c_in = weight.size(0)
weight_flat = weight.view(c_in, -1)
mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
@ -26,15 +29,15 @@ def conv_ws_2d(input,
class ConvWS2d(nn.Conv2d):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
eps=1e-5):
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
eps: float = 1e-5):
super().__init__(
in_channels,
out_channels,
@ -46,7 +49,7 @@ class ConvWS2d(nn.Conv2d):
bias=bias)
self.eps = eps
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups, self.eps)
@ -76,14 +79,14 @@ class ConvAWS2d(nn.Conv2d):
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True):
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True):
super().__init__(
in_channels,
out_channels,
@ -98,7 +101,7 @@ class ConvAWS2d(nn.Conv2d):
self.register_buffer('weight_beta',
torch.zeros(self.out_channels, 1, 1, 1))
def _get_weight(self, weight):
def _get_weight(self, weight: torch.Tensor) -> torch.Tensor:
weight_flat = weight.view(weight.size(0), -1)
mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
@ -106,13 +109,16 @@ class ConvAWS2d(nn.Conv2d):
weight = self.weight_gamma * weight + self.weight_beta
return weight
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
weight = self._get_weight(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding,
self.dilation, self.groups)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
local_metadata: Dict, strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str]) -> None:
"""Override default load function.
AWS overrides the function _load_from_state_dict to recover
@ -124,7 +130,7 @@ class ConvAWS2d(nn.Conv2d):
"""
self.weight_gamma.data.fill_(-1)
local_missing_keys = []
local_missing_keys: List = []
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, local_missing_keys,
unexpected_keys, error_msgs)

View File

@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from .conv_module import ConvModule
@ -46,27 +49,27 @@ class DepthwiseSeparableConvModule(nn.Module):
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
dw_norm_cfg='default',
dw_act_cfg='default',
pw_norm_cfg='default',
pw_act_cfg='default',
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
norm_cfg: Optional[Dict] = None,
act_cfg: Dict = dict(type='ReLU'),
dw_norm_cfg: Union[Dict, str] = 'default',
dw_act_cfg: Union[Dict, str] = 'default',
pw_norm_cfg: Union[Dict, str] = 'default',
pw_act_cfg: Union[Dict, str] = 'default',
**kwargs):
super().__init__()
assert 'groups' not in kwargs, 'groups should not be specified'
# if norm/activation config of depthwise/pointwise ConvModule is not
# specified, use default config.
dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg # type: ignore # noqa E501
dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg # type: ignore # noqa E501
pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
# depthwise convolution
@ -78,19 +81,19 @@ class DepthwiseSeparableConvModule(nn.Module):
padding=padding,
dilation=dilation,
groups=in_channels,
norm_cfg=dw_norm_cfg,
act_cfg=dw_act_cfg,
norm_cfg=dw_norm_cfg, # type: ignore
act_cfg=dw_act_cfg, # type: ignore
**kwargs)
self.pointwise_conv = ConvModule(
in_channels,
out_channels,
1,
norm_cfg=pw_norm_cfg,
act_cfg=pw_act_cfg,
norm_cfg=pw_norm_cfg, # type: ignore
act_cfg=pw_act_cfg, # type: ignore
**kwargs)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
return x

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
@ -6,7 +8,9 @@ from mmcv import build_from_cfg
from .registry import DROPOUT_LAYERS
def drop_path(x, drop_prob=0., training=False):
def drop_path(x: torch.Tensor,
drop_prob: float = 0.,
training: bool = False) -> torch.Tensor:
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
residual blocks).
@ -36,11 +40,11 @@ class DropPath(nn.Module):
drop_prob (float): Probability of the path to be zeroed. Default: 0.1
"""
def __init__(self, drop_prob=0.1):
def __init__(self, drop_prob: float = 0.1):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
@ -56,10 +60,10 @@ class Dropout(nn.Dropout):
inplace (bool): Do the operation inplace or not. Default: False.
"""
def __init__(self, drop_prob=0.5, inplace=False):
def __init__(self, drop_prob: float = 0.5, inplace: bool = False):
super().__init__(p=drop_prob, inplace=inplace)
def build_dropout(cfg, default_args=None):
def build_dropout(cfg: Dict, default_args: Optional[Dict] = None) -> Any:
"""Builder for drop out layers."""
return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)

View File

@ -45,14 +45,14 @@ class GeneralizedAttention(nn.Module):
_abbr_ = 'gen_attention_block'
def __init__(self,
in_channels,
spatial_range=-1,
num_heads=9,
position_embedding_dim=-1,
position_magnitude=1,
kv_stride=2,
q_stride=1,
attention_type='1111'):
in_channels: int,
spatial_range: int = -1,
num_heads: int = 9,
position_embedding_dim: int = -1,
position_magnitude: int = 1,
kv_stride: int = 2,
q_stride: int = 1,
attention_type: str = '1111'):
super().__init__()
@ -213,7 +213,7 @@ class GeneralizedAttention(nn.Module):
return embedding_x, embedding_y
def forward(self, x_input):
def forward(self, x_input: torch.Tensor) -> torch.Tensor:
num_heads = self.num_heads
# use empirical_attention

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
from .registry import ACTIVATION_LAYERS
@ -26,7 +27,11 @@ class HSigmoid(nn.Module):
Tensor: The output tensor.
"""
def __init__(self, bias=3.0, divisor=6.0, min_value=0.0, max_value=1.0):
def __init__(self,
bias: float = 3.0,
divisor: float = 6.0,
min_value: float = 0.0,
max_value: float = 1.0):
super().__init__()
warnings.warn(
'In MMCV v1.4.4, we modified the default value of args to align '
@ -40,7 +45,7 @@ class HSigmoid(nn.Module):
self.min_value = min_value
self.max_value = max_value
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = (x + self.bias) / self.divisor
return x.clamp_(self.min_value, self.max_value)

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.utils import TORCH_VERSION, digit_version
@ -21,11 +22,11 @@ class HSwish(nn.Module):
Tensor: The output tensor.
"""
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super().__init__()
self.act = nn.ReLU6(inplace)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * self.act(x + 3) / 6

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta
from typing import Dict, Optional
import torch
import torch.nn as nn
@ -33,12 +34,12 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
"""
def __init__(self,
in_channels,
reduction=2,
use_scale=True,
conv_cfg=None,
norm_cfg=None,
mode='embedded_gaussian',
in_channels: int,
reduction: int = 2,
use_scale: bool = True,
conv_cfg: Optional[Dict] = None,
norm_cfg: Optional[Dict] = None,
mode: str = 'embedded_gaussian',
**kwargs):
super().__init__()
self.in_channels = in_channels
@ -61,7 +62,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
self.inter_channels,
kernel_size=1,
conv_cfg=conv_cfg,
act_cfg=None)
act_cfg=None) # type: ignore
self.conv_out = ConvModule(
self.inter_channels,
self.in_channels,
@ -96,7 +97,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
self.init_weights(**kwargs)
def init_weights(self, std=0.01, zeros_init=True):
def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
if self.mode != 'gaussian':
for m in [self.g, self.theta, self.phi]:
normal_init(m.conv, std=std)
@ -113,7 +114,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
else:
normal_init(self.conv_out.norm, std=std)
def gaussian(self, theta_x, phi_x):
def gaussian(self, theta_x: torch.Tensor,
phi_x: torch.Tensor) -> torch.Tensor:
# NonLocal1d pairwise_weight: [N, H, H]
# NonLocal2d pairwise_weight: [N, HxW, HxW]
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@ -121,7 +123,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
pairwise_weight = pairwise_weight.softmax(dim=-1)
return pairwise_weight
def embedded_gaussian(self, theta_x, phi_x):
def embedded_gaussian(self, theta_x: torch.Tensor,
phi_x: torch.Tensor) -> torch.Tensor:
# NonLocal1d pairwise_weight: [N, H, H]
# NonLocal2d pairwise_weight: [N, HxW, HxW]
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@ -132,7 +135,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
pairwise_weight = pairwise_weight.softmax(dim=-1)
return pairwise_weight
def dot_product(self, theta_x, phi_x):
def dot_product(self, theta_x: torch.Tensor,
phi_x: torch.Tensor) -> torch.Tensor:
# NonLocal1d pairwise_weight: [N, H, H]
# NonLocal2d pairwise_weight: [N, HxW, HxW]
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@ -140,7 +144,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
pairwise_weight /= pairwise_weight.shape[-1]
return pairwise_weight
def concatenation(self, theta_x, phi_x):
def concatenation(self, theta_x: torch.Tensor,
phi_x: torch.Tensor) -> torch.Tensor:
# NonLocal1d pairwise_weight: [N, H, H]
# NonLocal2d pairwise_weight: [N, HxW, HxW]
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@ -157,7 +162,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
return pairwise_weight
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Assume `reduction = 1`, then `inter_channels = C`
# or `inter_channels = C` when `mode="gaussian"`
@ -224,9 +229,9 @@ class NonLocal1d(_NonLocalNd):
"""
def __init__(self,
in_channels,
sub_sample=False,
conv_cfg=dict(type='Conv1d'),
in_channels: int,
sub_sample: bool = False,
conv_cfg: Dict = dict(type='Conv1d'),
**kwargs):
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
@ -257,9 +262,9 @@ class NonLocal2d(_NonLocalNd):
_abbr_ = 'nonlocal_block'
def __init__(self,
in_channels,
sub_sample=False,
conv_cfg=dict(type='Conv2d'),
in_channels: int,
sub_sample: bool = False,
conv_cfg: Dict = dict(type='Conv2d'),
**kwargs):
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
@ -287,9 +292,9 @@ class NonLocal3d(_NonLocalNd):
"""
def __init__(self,
in_channels,
sub_sample=False,
conv_cfg=dict(type='Conv3d'),
in_channels: int,
sub_sample: bool = False,
conv_cfg: Dict = dict(type='Conv3d'),
**kwargs):
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
self.sub_sample = sub_sample

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict, Tuple, Union
import torch.nn as nn
@ -69,7 +70,9 @@ def infer_abbr(class_type):
return 'norm_layer'
def build_norm_layer(cfg, num_features, postfix=''):
def build_norm_layer(cfg: Dict,
num_features: int,
postfix: Union[int, str] = '') -> Tuple[str, nn.Module]:
"""Build normalization layer.
Args:
@ -119,7 +122,8 @@ def build_norm_layer(cfg, num_features, postfix=''):
return name, layer
def is_norm(layer, exclude=None):
def is_norm(layer: nn.Module,
exclude: Union[type, tuple, None] = None) -> bool:
"""Check if a layer is a normalization layer.
Args:

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
import torch.nn as nn
from .registry import PADDING_LAYERS
@ -8,11 +10,11 @@ PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
def build_padding_layer(cfg, *args, **kwargs):
def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
"""Build padding layer.
Args:
cfg (None or dict): The padding layer config, which should contain:
cfg (dict): The padding layer config, which should contain:
- type (str): Layer type.
- layer args: Args needed to instantiate a padding layer.

View File

@ -1,6 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import platform
from typing import Dict, Tuple, Union
import torch.nn as nn
from .registry import PLUGIN_LAYERS
@ -10,7 +13,7 @@ else:
import re # type: ignore
def infer_abbr(class_type):
def infer_abbr(class_type: type) -> str:
"""Infer abbreviation from the class name.
This method will infer the abbreviation to map class types to
@ -48,16 +51,18 @@ def infer_abbr(class_type):
raise TypeError(
f'class_type must be a type, but got {type(class_type)}')
if hasattr(class_type, '_abbr_'):
return class_type._abbr_
return class_type._abbr_ # type: ignore
else:
return camel2snack(class_type.__name__)
def build_plugin_layer(cfg, postfix='', **kwargs):
def build_plugin_layer(cfg: Dict,
postfix: Union[int, str] = '',
**kwargs) -> Tuple[str, nn.Module]:
"""Build plugin layer.
Args:
cfg (None or dict): cfg should contain:
cfg (dict): cfg should contain:
- type (str): identify plugin layer type.
- layer args: args needed to instantiate a plugin layer.

View File

@ -13,9 +13,9 @@ class Scale(nn.Module):
scale (float): Initial value of scale factor. Default: 1.0
"""
def __init__(self, scale=1.0):
def __init__(self, scale: float = 1.0):
super().__init__()
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * self.scale

View File

@ -21,5 +21,5 @@ class Swish(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)

View File

@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -24,8 +27,8 @@ class PixelShufflePack(nn.Module):
channels.
"""
def __init__(self, in_channels, out_channels, scale_factor,
upsample_kernel):
def __init__(self, in_channels: int, out_channels: int, scale_factor: int,
upsample_kernel: int):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -41,13 +44,13 @@ class PixelShufflePack(nn.Module):
def init_weights(self):
xavier_init(self.upsample_conv, distribution='uniform')
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.upsample_conv(x)
x = F.pixel_shuffle(x, self.scale_factor)
return x
def build_upsample_layer(cfg, *args, **kwargs):
def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
"""Build upsample layer.
Args:

View File

@ -21,19 +21,19 @@ else:
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
def obsolete_torch_version(torch_version, version_threshold):
def obsolete_torch_version(torch_version, version_threshold) -> bool:
return torch_version == 'parrots' or torch_version <= version_threshold
class NewEmptyTensorOp(torch.autograd.Function):
@staticmethod
def forward(ctx, x, new_shape):
def forward(ctx, x: torch.Tensor, new_shape: tuple) -> torch.Tensor:
ctx.shape = x.shape
return x.new_empty(new_shape)
@staticmethod
def backward(ctx, grad):
def backward(ctx, grad: torch.Tensor) -> tuple:
shape = ctx.shape
return NewEmptyTensorOp.apply(grad, shape), None
@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
@CONV_LAYERS.register_module('Conv', force=True)
class Conv2d(nn.Conv2d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d):
@CONV_LAYERS.register_module('Conv3d', force=True)
class Conv3d(nn.Conv3d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
@ -85,7 +85,7 @@ class Conv3d(nn.Conv3d):
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
class ConvTranspose2d(nn.ConvTranspose2d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
@ -108,7 +108,7 @@ class ConvTranspose2d(nn.ConvTranspose2d):
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
class ConvTranspose3d(nn.ConvTranspose3d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
@ -128,7 +128,7 @@ class ConvTranspose3d(nn.ConvTranspose3d):
class MaxPool2d(nn.MaxPool2d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
out_shape = list(x.shape[:2])
@ -146,7 +146,7 @@ class MaxPool2d(nn.MaxPool2d):
class MaxPool3d(nn.MaxPool3d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
out_shape = list(x.shape[:2])
@ -165,7 +165,7 @@ class MaxPool3d(nn.MaxPool3d):
class Linear(torch.nn.Linear):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# empty tensor forward of Linear layer is supported in Pytorch 1.6
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
out_shape = [x.shape[0], self.out_features]