mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Add type hints for mmcv/cnn/bricks (#1993)
* Add type hint * Add typehint in mmcv/cnn/bricks* * Deal conflict0 * Fix * fix * minor fix * minor fix Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
2d3e42fc41
commit
305c2a3025
@ -1,4 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -28,12 +30,12 @@ class Clamp(nn.Module):
|
||||
Default to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, min=-1., max=1.):
|
||||
def __init__(self, min: float = -1., max: float = 1.):
|
||||
super().__init__()
|
||||
self.min = min
|
||||
self.max = max
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
@ -67,7 +69,7 @@ class GELU(nn.Module):
|
||||
>>> output = m(input)
|
||||
"""
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.gelu(input)
|
||||
|
||||
|
||||
@ -78,7 +80,7 @@ else:
|
||||
ACTIVATION_LAYERS.register_module(module=nn.GELU)
|
||||
|
||||
|
||||
def build_activation_layer(cfg):
|
||||
def build_activation_layer(cfg: Dict) -> nn.Module:
|
||||
"""Build activation layer.
|
||||
|
||||
Args:
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@ -6,7 +8,7 @@ from ..utils import constant_init, kaiming_init
|
||||
from .registry import PLUGIN_LAYERS
|
||||
|
||||
|
||||
def last_zero_init(m):
|
||||
def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
|
||||
if isinstance(m, nn.Sequential):
|
||||
constant_init(m[-1], val=0)
|
||||
else:
|
||||
@ -34,10 +36,10 @@ class ContextBlock(nn.Module):
|
||||
_abbr_ = 'context_block'
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
ratio,
|
||||
pooling_type='att',
|
||||
fusion_types=('channel_add', )):
|
||||
in_channels: int,
|
||||
ratio: float,
|
||||
pooling_type: str = 'att',
|
||||
fusion_types: tuple = ('channel_add', )):
|
||||
super().__init__()
|
||||
assert pooling_type in ['avg', 'att']
|
||||
assert isinstance(fusion_types, (list, tuple))
|
||||
@ -82,7 +84,7 @@ class ContextBlock(nn.Module):
|
||||
if self.channel_mul_conv is not None:
|
||||
last_zero_init(self.channel_mul_conv)
|
||||
|
||||
def spatial_pool(self, x):
|
||||
def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch, channel, height, width = x.size()
|
||||
if self.pooling_type == 'att':
|
||||
input_x = x
|
||||
@ -108,7 +110,7 @@ class ContextBlock(nn.Module):
|
||||
|
||||
return context
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# [N, C, 1, 1]
|
||||
context = self.spatial_pool(x)
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .registry import CONV_LAYERS
|
||||
@ -9,7 +11,7 @@ CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
|
||||
CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
|
||||
|
||||
|
||||
def build_conv_layer(cfg, *args, **kwargs):
|
||||
def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
|
||||
"""Build convolution layer.
|
||||
|
||||
Args:
|
||||
|
@ -1,6 +1,8 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
@ -31,18 +33,18 @@ class Conv2dAdaptivePadding(nn.Conv2d):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True):
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True):
|
||||
super().__init__(in_channels, out_channels, kernel_size, stride, 0,
|
||||
dilation, groups, bias)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
img_h, img_w = x.size()[-2:]
|
||||
kernel_h, kernel_w = self.weight.size()[-2:]
|
||||
stride_h, stride_w = self.stride
|
||||
|
@ -1,6 +1,8 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmcv.utils import _BatchNorm, _InstanceNorm
|
||||
@ -68,21 +70,21 @@ class ConvModule(nn.Module):
|
||||
_abbr_ = 'conv_block'
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias='auto',
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
inplace=True,
|
||||
with_spectral_norm=False,
|
||||
padding_mode='zeros',
|
||||
order=('conv', 'norm', 'act')):
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: Union[bool, str] = 'auto',
|
||||
conv_cfg: Optional[Dict] = None,
|
||||
norm_cfg: Optional[Dict] = None,
|
||||
act_cfg: Optional[Dict] = dict(type='ReLU'),
|
||||
inplace: bool = True,
|
||||
with_spectral_norm: bool = False,
|
||||
padding_mode: str = 'zeros',
|
||||
order: tuple = ('conv', 'norm', 'act')):
|
||||
super().__init__()
|
||||
assert conv_cfg is None or isinstance(conv_cfg, dict)
|
||||
assert norm_cfg is None or isinstance(norm_cfg, dict)
|
||||
@ -143,18 +145,19 @@ class ConvModule(nn.Module):
|
||||
norm_channels = out_channels
|
||||
else:
|
||||
norm_channels = in_channels
|
||||
self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
|
||||
self.norm_name, norm = build_norm_layer(
|
||||
norm_cfg, norm_channels) # type: ignore
|
||||
self.add_module(self.norm_name, norm)
|
||||
if self.with_bias:
|
||||
if isinstance(norm, (_BatchNorm, _InstanceNorm)):
|
||||
warnings.warn(
|
||||
'Unnecessary conv bias before batch/instance norm')
|
||||
else:
|
||||
self.norm_name = None
|
||||
self.norm_name = None # type: ignore
|
||||
|
||||
# build activation layer
|
||||
if self.with_activation:
|
||||
act_cfg_ = act_cfg.copy()
|
||||
act_cfg_ = act_cfg.copy() # type: ignore
|
||||
# nn.Tanh has no 'inplace' argument
|
||||
if act_cfg_['type'] not in [
|
||||
'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
|
||||
@ -193,7 +196,10 @@ class ConvModule(nn.Module):
|
||||
if self.with_norm:
|
||||
constant_init(self.norm, 1, bias=0)
|
||||
|
||||
def forward(self, x, activate=True, norm=True):
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
activate: bool = True,
|
||||
norm: bool = True) -> torch.Tensor:
|
||||
for layer in self.order:
|
||||
if layer == 'conv':
|
||||
if self.with_explicit_padding:
|
||||
|
@ -1,4 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -6,14 +9,14 @@ import torch.nn.functional as F
|
||||
from .registry import CONV_LAYERS
|
||||
|
||||
|
||||
def conv_ws_2d(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
eps=1e-5):
|
||||
def conv_ws_2d(input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
stride: Union[int, Tuple[int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
eps: float = 1e-5) -> torch.Tensor:
|
||||
c_in = weight.size(0)
|
||||
weight_flat = weight.view(c_in, -1)
|
||||
mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
|
||||
@ -26,15 +29,15 @@ def conv_ws_2d(input,
|
||||
class ConvWS2d(nn.Conv2d):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
eps=1e-5):
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
eps: float = 1e-5):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
@ -46,7 +49,7 @@ class ConvWS2d(nn.Conv2d):
|
||||
bias=bias)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
|
||||
self.dilation, self.groups, self.eps)
|
||||
|
||||
@ -76,14 +79,14 @@ class ConvAWS2d(nn.Conv2d):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True):
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
@ -98,7 +101,7 @@ class ConvAWS2d(nn.Conv2d):
|
||||
self.register_buffer('weight_beta',
|
||||
torch.zeros(self.out_channels, 1, 1, 1))
|
||||
|
||||
def _get_weight(self, weight):
|
||||
def _get_weight(self, weight: torch.Tensor) -> torch.Tensor:
|
||||
weight_flat = weight.view(weight.size(0), -1)
|
||||
mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
|
||||
std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
|
||||
@ -106,13 +109,16 @@ class ConvAWS2d(nn.Conv2d):
|
||||
weight = self.weight_gamma * weight + self.weight_beta
|
||||
return weight
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
weight = self._get_weight(self.weight)
|
||||
return F.conv2d(x, weight, self.bias, self.stride, self.padding,
|
||||
self.dilation, self.groups)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
|
||||
local_metadata: Dict, strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str]) -> None:
|
||||
"""Override default load function.
|
||||
|
||||
AWS overrides the function _load_from_state_dict to recover
|
||||
@ -124,7 +130,7 @@ class ConvAWS2d(nn.Conv2d):
|
||||
"""
|
||||
|
||||
self.weight_gamma.data.fill_(-1)
|
||||
local_missing_keys = []
|
||||
local_missing_keys: List = []
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata,
|
||||
strict, local_missing_keys,
|
||||
unexpected_keys, error_msgs)
|
||||
|
@ -1,4 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .conv_module import ConvModule
|
||||
@ -46,27 +49,27 @@ class DepthwiseSeparableConvModule(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
dw_norm_cfg='default',
|
||||
dw_act_cfg='default',
|
||||
pw_norm_cfg='default',
|
||||
pw_act_cfg='default',
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
norm_cfg: Optional[Dict] = None,
|
||||
act_cfg: Dict = dict(type='ReLU'),
|
||||
dw_norm_cfg: Union[Dict, str] = 'default',
|
||||
dw_act_cfg: Union[Dict, str] = 'default',
|
||||
pw_norm_cfg: Union[Dict, str] = 'default',
|
||||
pw_act_cfg: Union[Dict, str] = 'default',
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
assert 'groups' not in kwargs, 'groups should not be specified'
|
||||
|
||||
# if norm/activation config of depthwise/pointwise ConvModule is not
|
||||
# specified, use default config.
|
||||
dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
|
||||
dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg # type: ignore # noqa E501
|
||||
dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
|
||||
pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
|
||||
pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg # type: ignore # noqa E501
|
||||
pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
|
||||
|
||||
# depthwise convolution
|
||||
@ -78,19 +81,19 @@ class DepthwiseSeparableConvModule(nn.Module):
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=in_channels,
|
||||
norm_cfg=dw_norm_cfg,
|
||||
act_cfg=dw_act_cfg,
|
||||
norm_cfg=dw_norm_cfg, # type: ignore
|
||||
act_cfg=dw_act_cfg, # type: ignore
|
||||
**kwargs)
|
||||
|
||||
self.pointwise_conv = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
norm_cfg=pw_norm_cfg,
|
||||
act_cfg=pw_act_cfg,
|
||||
norm_cfg=pw_norm_cfg, # type: ignore
|
||||
act_cfg=pw_act_cfg, # type: ignore
|
||||
**kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.pointwise_conv(x)
|
||||
return x
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -6,7 +8,9 @@ from mmcv import build_from_cfg
|
||||
from .registry import DROPOUT_LAYERS
|
||||
|
||||
|
||||
def drop_path(x, drop_prob=0., training=False):
|
||||
def drop_path(x: torch.Tensor,
|
||||
drop_prob: float = 0.,
|
||||
training: bool = False) -> torch.Tensor:
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
|
||||
residual blocks).
|
||||
|
||||
@ -36,11 +40,11 @@ class DropPath(nn.Module):
|
||||
drop_prob (float): Probability of the path to be zeroed. Default: 0.1
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob=0.1):
|
||||
def __init__(self, drop_prob: float = 0.1):
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
@ -56,10 +60,10 @@ class Dropout(nn.Dropout):
|
||||
inplace (bool): Do the operation inplace or not. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob=0.5, inplace=False):
|
||||
def __init__(self, drop_prob: float = 0.5, inplace: bool = False):
|
||||
super().__init__(p=drop_prob, inplace=inplace)
|
||||
|
||||
|
||||
def build_dropout(cfg, default_args=None):
|
||||
def build_dropout(cfg: Dict, default_args: Optional[Dict] = None) -> Any:
|
||||
"""Builder for drop out layers."""
|
||||
return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
|
||||
|
@ -45,14 +45,14 @@ class GeneralizedAttention(nn.Module):
|
||||
_abbr_ = 'gen_attention_block'
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
spatial_range=-1,
|
||||
num_heads=9,
|
||||
position_embedding_dim=-1,
|
||||
position_magnitude=1,
|
||||
kv_stride=2,
|
||||
q_stride=1,
|
||||
attention_type='1111'):
|
||||
in_channels: int,
|
||||
spatial_range: int = -1,
|
||||
num_heads: int = 9,
|
||||
position_embedding_dim: int = -1,
|
||||
position_magnitude: int = 1,
|
||||
kv_stride: int = 2,
|
||||
q_stride: int = 1,
|
||||
attention_type: str = '1111'):
|
||||
|
||||
super().__init__()
|
||||
|
||||
@ -213,7 +213,7 @@ class GeneralizedAttention(nn.Module):
|
||||
|
||||
return embedding_x, embedding_y
|
||||
|
||||
def forward(self, x_input):
|
||||
def forward(self, x_input: torch.Tensor) -> torch.Tensor:
|
||||
num_heads = self.num_heads
|
||||
|
||||
# use empirical_attention
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .registry import ACTIVATION_LAYERS
|
||||
@ -26,7 +27,11 @@ class HSigmoid(nn.Module):
|
||||
Tensor: The output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, bias=3.0, divisor=6.0, min_value=0.0, max_value=1.0):
|
||||
def __init__(self,
|
||||
bias: float = 3.0,
|
||||
divisor: float = 6.0,
|
||||
min_value: float = 0.0,
|
||||
max_value: float = 1.0):
|
||||
super().__init__()
|
||||
warnings.warn(
|
||||
'In MMCV v1.4.4, we modified the default value of args to align '
|
||||
@ -40,7 +45,7 @@ class HSigmoid(nn.Module):
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = (x + self.bias) / self.divisor
|
||||
|
||||
return x.clamp_(self.min_value, self.max_value)
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmcv.utils import TORCH_VERSION, digit_version
|
||||
@ -21,11 +22,11 @@ class HSwish(nn.Module):
|
||||
Tensor: The output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, inplace=False):
|
||||
def __init__(self, inplace: bool = False):
|
||||
super().__init__()
|
||||
self.act = nn.ReLU6(inplace)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * self.act(x + 3) / 6
|
||||
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -33,12 +34,12 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
reduction=2,
|
||||
use_scale=True,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
mode='embedded_gaussian',
|
||||
in_channels: int,
|
||||
reduction: int = 2,
|
||||
use_scale: bool = True,
|
||||
conv_cfg: Optional[Dict] = None,
|
||||
norm_cfg: Optional[Dict] = None,
|
||||
mode: str = 'embedded_gaussian',
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
@ -61,7 +62,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||
self.inter_channels,
|
||||
kernel_size=1,
|
||||
conv_cfg=conv_cfg,
|
||||
act_cfg=None)
|
||||
act_cfg=None) # type: ignore
|
||||
self.conv_out = ConvModule(
|
||||
self.inter_channels,
|
||||
self.in_channels,
|
||||
@ -96,7 +97,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||
|
||||
self.init_weights(**kwargs)
|
||||
|
||||
def init_weights(self, std=0.01, zeros_init=True):
|
||||
def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
|
||||
if self.mode != 'gaussian':
|
||||
for m in [self.g, self.theta, self.phi]:
|
||||
normal_init(m.conv, std=std)
|
||||
@ -113,7 +114,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||
else:
|
||||
normal_init(self.conv_out.norm, std=std)
|
||||
|
||||
def gaussian(self, theta_x, phi_x):
|
||||
def gaussian(self, theta_x: torch.Tensor,
|
||||
phi_x: torch.Tensor) -> torch.Tensor:
|
||||
# NonLocal1d pairwise_weight: [N, H, H]
|
||||
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
||||
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
|
||||
@ -121,7 +123,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||
pairwise_weight = pairwise_weight.softmax(dim=-1)
|
||||
return pairwise_weight
|
||||
|
||||
def embedded_gaussian(self, theta_x, phi_x):
|
||||
def embedded_gaussian(self, theta_x: torch.Tensor,
|
||||
phi_x: torch.Tensor) -> torch.Tensor:
|
||||
# NonLocal1d pairwise_weight: [N, H, H]
|
||||
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
||||
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
|
||||
@ -132,7 +135,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||
pairwise_weight = pairwise_weight.softmax(dim=-1)
|
||||
return pairwise_weight
|
||||
|
||||
def dot_product(self, theta_x, phi_x):
|
||||
def dot_product(self, theta_x: torch.Tensor,
|
||||
phi_x: torch.Tensor) -> torch.Tensor:
|
||||
# NonLocal1d pairwise_weight: [N, H, H]
|
||||
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
||||
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
|
||||
@ -140,7 +144,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||
pairwise_weight /= pairwise_weight.shape[-1]
|
||||
return pairwise_weight
|
||||
|
||||
def concatenation(self, theta_x, phi_x):
|
||||
def concatenation(self, theta_x: torch.Tensor,
|
||||
phi_x: torch.Tensor) -> torch.Tensor:
|
||||
# NonLocal1d pairwise_weight: [N, H, H]
|
||||
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
||||
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
|
||||
@ -157,7 +162,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||
|
||||
return pairwise_weight
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Assume `reduction = 1`, then `inter_channels = C`
|
||||
# or `inter_channels = C` when `mode="gaussian"`
|
||||
|
||||
@ -224,9 +229,9 @@ class NonLocal1d(_NonLocalNd):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
sub_sample=False,
|
||||
conv_cfg=dict(type='Conv1d'),
|
||||
in_channels: int,
|
||||
sub_sample: bool = False,
|
||||
conv_cfg: Dict = dict(type='Conv1d'),
|
||||
**kwargs):
|
||||
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||
|
||||
@ -257,9 +262,9 @@ class NonLocal2d(_NonLocalNd):
|
||||
_abbr_ = 'nonlocal_block'
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
sub_sample=False,
|
||||
conv_cfg=dict(type='Conv2d'),
|
||||
in_channels: int,
|
||||
sub_sample: bool = False,
|
||||
conv_cfg: Dict = dict(type='Conv2d'),
|
||||
**kwargs):
|
||||
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||
|
||||
@ -287,9 +292,9 @@ class NonLocal3d(_NonLocalNd):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
sub_sample=False,
|
||||
conv_cfg=dict(type='Conv3d'),
|
||||
in_channels: int,
|
||||
sub_sample: bool = False,
|
||||
conv_cfg: Dict = dict(type='Conv3d'),
|
||||
**kwargs):
|
||||
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||
self.sub_sample = sub_sample
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import inspect
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
@ -69,7 +70,9 @@ def infer_abbr(class_type):
|
||||
return 'norm_layer'
|
||||
|
||||
|
||||
def build_norm_layer(cfg, num_features, postfix=''):
|
||||
def build_norm_layer(cfg: Dict,
|
||||
num_features: int,
|
||||
postfix: Union[int, str] = '') -> Tuple[str, nn.Module]:
|
||||
"""Build normalization layer.
|
||||
|
||||
Args:
|
||||
@ -119,7 +122,8 @@ def build_norm_layer(cfg, num_features, postfix=''):
|
||||
return name, layer
|
||||
|
||||
|
||||
def is_norm(layer, exclude=None):
|
||||
def is_norm(layer: nn.Module,
|
||||
exclude: Union[type, tuple, None] = None) -> bool:
|
||||
"""Check if a layer is a normalization layer.
|
||||
|
||||
Args:
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .registry import PADDING_LAYERS
|
||||
@ -8,11 +10,11 @@ PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
|
||||
PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
|
||||
|
||||
|
||||
def build_padding_layer(cfg, *args, **kwargs):
|
||||
def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
|
||||
"""Build padding layer.
|
||||
|
||||
Args:
|
||||
cfg (None or dict): The padding layer config, which should contain:
|
||||
cfg (dict): The padding layer config, which should contain:
|
||||
- type (str): Layer type.
|
||||
- layer args: Args needed to instantiate a padding layer.
|
||||
|
||||
|
@ -1,6 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import inspect
|
||||
import platform
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .registry import PLUGIN_LAYERS
|
||||
|
||||
@ -10,7 +13,7 @@ else:
|
||||
import re # type: ignore
|
||||
|
||||
|
||||
def infer_abbr(class_type):
|
||||
def infer_abbr(class_type: type) -> str:
|
||||
"""Infer abbreviation from the class name.
|
||||
|
||||
This method will infer the abbreviation to map class types to
|
||||
@ -48,16 +51,18 @@ def infer_abbr(class_type):
|
||||
raise TypeError(
|
||||
f'class_type must be a type, but got {type(class_type)}')
|
||||
if hasattr(class_type, '_abbr_'):
|
||||
return class_type._abbr_
|
||||
return class_type._abbr_ # type: ignore
|
||||
else:
|
||||
return camel2snack(class_type.__name__)
|
||||
|
||||
|
||||
def build_plugin_layer(cfg, postfix='', **kwargs):
|
||||
def build_plugin_layer(cfg: Dict,
|
||||
postfix: Union[int, str] = '',
|
||||
**kwargs) -> Tuple[str, nn.Module]:
|
||||
"""Build plugin layer.
|
||||
|
||||
Args:
|
||||
cfg (None or dict): cfg should contain:
|
||||
cfg (dict): cfg should contain:
|
||||
|
||||
- type (str): identify plugin layer type.
|
||||
- layer args: args needed to instantiate a plugin layer.
|
||||
|
@ -13,9 +13,9 @@ class Scale(nn.Module):
|
||||
scale (float): Initial value of scale factor. Default: 1.0
|
||||
"""
|
||||
|
||||
def __init__(self, scale=1.0):
|
||||
def __init__(self, scale: float = 1.0):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * self.scale
|
||||
|
@ -21,5 +21,5 @@ class Swish(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
@ -1,4 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
@ -24,8 +27,8 @@ class PixelShufflePack(nn.Module):
|
||||
channels.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, scale_factor,
|
||||
upsample_kernel):
|
||||
def __init__(self, in_channels: int, out_channels: int, scale_factor: int,
|
||||
upsample_kernel: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
@ -41,13 +44,13 @@ class PixelShufflePack(nn.Module):
|
||||
def init_weights(self):
|
||||
xavier_init(self.upsample_conv, distribution='uniform')
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.upsample_conv(x)
|
||||
x = F.pixel_shuffle(x, self.scale_factor)
|
||||
return x
|
||||
|
||||
|
||||
def build_upsample_layer(cfg, *args, **kwargs):
|
||||
def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
|
||||
"""Build upsample layer.
|
||||
|
||||
Args:
|
||||
|
@ -21,19 +21,19 @@ else:
|
||||
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
|
||||
|
||||
|
||||
def obsolete_torch_version(torch_version, version_threshold):
|
||||
def obsolete_torch_version(torch_version, version_threshold) -> bool:
|
||||
return torch_version == 'parrots' or torch_version <= version_threshold
|
||||
|
||||
|
||||
class NewEmptyTensorOp(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, new_shape):
|
||||
def forward(ctx, x: torch.Tensor, new_shape: tuple) -> torch.Tensor:
|
||||
ctx.shape = x.shape
|
||||
return x.new_empty(new_shape)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
def backward(ctx, grad: torch.Tensor) -> tuple:
|
||||
shape = ctx.shape
|
||||
return NewEmptyTensorOp.apply(grad, shape), None
|
||||
|
||||
@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
|
||||
@CONV_LAYERS.register_module('Conv', force=True)
|
||||
class Conv2d(nn.Conv2d):
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
|
||||
out_shape = [x.shape[0], self.out_channels]
|
||||
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
|
||||
@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d):
|
||||
@CONV_LAYERS.register_module('Conv3d', force=True)
|
||||
class Conv3d(nn.Conv3d):
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
|
||||
out_shape = [x.shape[0], self.out_channels]
|
||||
for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
|
||||
@ -85,7 +85,7 @@ class Conv3d(nn.Conv3d):
|
||||
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
|
||||
class ConvTranspose2d(nn.ConvTranspose2d):
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
|
||||
out_shape = [x.shape[0], self.out_channels]
|
||||
for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
|
||||
@ -108,7 +108,7 @@ class ConvTranspose2d(nn.ConvTranspose2d):
|
||||
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
|
||||
class ConvTranspose3d(nn.ConvTranspose3d):
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
|
||||
out_shape = [x.shape[0], self.out_channels]
|
||||
for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
|
||||
@ -128,7 +128,7 @@ class ConvTranspose3d(nn.ConvTranspose3d):
|
||||
|
||||
class MaxPool2d(nn.MaxPool2d):
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# PyTorch 1.9 does not support empty tensor inference yet
|
||||
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
|
||||
out_shape = list(x.shape[:2])
|
||||
@ -146,7 +146,7 @@ class MaxPool2d(nn.MaxPool2d):
|
||||
|
||||
class MaxPool3d(nn.MaxPool3d):
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# PyTorch 1.9 does not support empty tensor inference yet
|
||||
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
|
||||
out_shape = list(x.shape[:2])
|
||||
@ -165,7 +165,7 @@ class MaxPool3d(nn.MaxPool3d):
|
||||
|
||||
class Linear(torch.nn.Linear):
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# empty tensor forward of Linear layer is supported in Pytorch 1.6
|
||||
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
|
||||
out_shape = [x.shape[0], self.out_features]
|
||||
|
Loading…
x
Reference in New Issue
Block a user