improvement: add typehints and docs to timm/models/resnet.py

a-r-r-o-w 2023-10-19 19:32:26 +05:30 committed by Ross Wightman
parent 564db019f6
commit 05b0aaca51


@ -9,10 +9,12 @@ Copyright 2019, Ross Wightman
"""
import math
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \
@ -20,11 +22,12 @@ from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupN
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
from ._typing import LayerType
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
def get_padding(kernel_size, stride, dilation=1):
def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
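As a quick aside (a standalone sketch, not part of the commit), the formula gives 'same'-style padding at stride 1 and halves even-sized inputs at stride 2:

```python
import torch
import torch.nn as nn

def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2

x = torch.randn(1, 8, 32, 32)
# dilated stride-1 conv keeps spatial size; stride-2 conv halves it
same = nn.Conv2d(8, 8, 3, stride=1, dilation=2, padding=get_padding(3, 1, 2))
half = nn.Conv2d(8, 8, 3, stride=2, padding=get_padding(3, 2))
assert same(x).shape[-2:] == (32, 32)
assert half(x).shape[-2:] == (16, 16)
```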
@ -43,22 +46,40 @@ class BasicBlock(nn.Module):
def __init__(
self,
inplanes,
planes,
stride=1,
downsample=None,
cardinality=1,
base_width=64,
reduce_first=1,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
cardinality: int = 1,
base_width: int = 64,
reduce_first: int = 1,
dilation: int = 1,
first_dilation: Optional[int] = None,
act_layer: Type[nn.Module] = nn.ReLU,
norm_layer: Type[nn.Module] = nn.BatchNorm2d,
attn_layer: Optional[Type[nn.Module]] = None,
aa_layer: Optional[Type[nn.Module]] = None,
drop_block: Optional[Type[nn.Module]] = None,
drop_path: Optional[nn.Module] = None,
):
"""
Args:
inplanes: Input channel dimensionality.
planes: Used to determine output channel dimensionalities.
stride: Stride used in convolution layers.
downsample: Optional downsample layer for residual path.
cardinality: Number of convolution groups.
base_width: Base width used to determine output channel dimensionality.
reduce_first: Reduction factor for first convolution output width of residual blocks.
dilation: Dilation rate for convolution layers.
first_dilation: Dilation rate for first convolution layer.
act_layer: Activation layer.
norm_layer: Normalization layer.
attn_layer: Attention layer.
aa_layer: Anti-aliasing layer.
drop_block: Class for DropBlock layer.
drop_path: Optional DropPath layer.
"""
super().__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@ -92,7 +113,7 @@ class BasicBlock(nn.Module):
if getattr(self.bn2, 'weight', None) is not None:
nn.init.zeros_(self.bn2.weight)
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
shortcut = x
x = self.conv1(x)
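For illustration (hypothetical usage, assuming a timm install where this module lives at timm.models.resnet), the typed block runs as a standalone module:

```python
import torch
from timm.models.resnet import BasicBlock

# default BasicBlock: two 3x3 convs, identity shortcut, channels preserved
block = BasicBlock(inplanes=64, planes=64)
y = block(torch.randn(1, 64, 56, 56))
assert y.shape == (1, 64, 56, 56)
```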
@ -123,22 +144,40 @@ class Bottleneck(nn.Module):
def __init__(
self,
inplanes,
planes,
stride=1,
downsample=None,
cardinality=1,
base_width=64,
reduce_first=1,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
cardinality: int = 1,
base_width: int = 64,
reduce_first: int = 1,
dilation: int = 1,
first_dilation: Optional[int] = None,
act_layer: Type[nn.Module] = nn.ReLU,
norm_layer: Type[nn.Module] = nn.BatchNorm2d,
attn_layer: Optional[Type[nn.Module]] = None,
aa_layer: Optional[Type[nn.Module]] = None,
drop_block: Optional[Type[nn.Module]] = None,
drop_path: Optional[nn.Module] = None,
):
"""
Args:
inplanes: Input channel dimensionality.
planes: Used to determine output channel dimensionalities.
stride: Stride used in convolution layers.
downsample: Optional downsample layer for residual path.
cardinality: Number of convolution groups.
base_width: Base width used to determine output channel dimensionality.
reduce_first: Reduction factor for first convolution output width of residual blocks.
dilation: Dilation rate for convolution layers.
first_dilation: Dilation rate for first convolution layer.
act_layer: Activation layer.
norm_layer: Normalization layer.
attn_layer: Attention layer.
aa_layer: Anti-aliasing layer.
drop_block: Class for DropBlock layer.
drop_path: Optional DropPath layer.
"""
super().__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
@ -174,7 +213,7 @@ class Bottleneck(nn.Module):
if getattr(self.bn3, 'weight', None) is not None:
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
shortcut = x
x = self.conv1(x)
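Likewise for the bottleneck (a sketch; expansion is 4, so the identity shortcut only applies when inplanes already equals planes * 4, otherwise a downsample module must be supplied):

```python
import torch
from timm.models.resnet import Bottleneck

# 1x1 reduce -> 3x3 -> 1x1 expand; planes=64 yields 64 * 4 = 256 output channels
btl = Bottleneck(inplanes=256, planes=64)
y = btl(torch.randn(1, 256, 56, 56))
assert y.shape == (1, 256, 56, 56)
```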
@ -205,14 +244,14 @@ class Bottleneck(nn.Module):
def downsample_conv(
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
first_dilation=None,
norm_layer=None,
):
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
first_dilation: Optional[int] = None,
norm_layer: Optional[Type[nn.Module]] = None,
) -> nn.Module:
norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
@ -226,14 +265,14 @@ def downsample_conv(
def downsample_avg(
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
first_dilation=None,
norm_layer=None,
):
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
first_dilation: Optional[int] = None,
norm_layer: Optional[Type[nn.Module]] = None,
) -> nn.Module:
norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1
if stride == 1 and dilation == 1:
@ -249,7 +288,7 @@ def downsample_avg(
])
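Both helpers are plain module-level functions, so a small shape check is possible (a sketch, assuming a current timm checkout); the avg variant is the 'D' bag-of-tricks shortcut that pools before the 1x1 conv:

```python
import torch
from timm.models.resnet import downsample_conv, downsample_avg

x = torch.randn(1, 64, 56, 56)
# strided 1x1 conv shortcut vs. avg-pool + 1x1 conv shortcut
print(downsample_conv(64, 256, kernel_size=1, stride=2)(x).shape)  # (1, 256, 28, 28)
print(downsample_avg(64, 256, kernel_size=1, stride=2)(x).shape)   # (1, 256, 28, 28)
```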
def drop_blocks(drop_prob=0.):
def drop_blocks(drop_prob: float = 0.):
return [
None, None,
partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None,
@ -257,18 +296,18 @@ def drop_blocks(drop_prob=0.):
def make_blocks(
block_fn,
channels,
block_repeats,
inplanes,
reduce_first=1,
output_stride=32,
down_kernel_size=1,
avg_down=False,
drop_block_rate=0.,
drop_path_rate=0.,
block_fn: Type[nn.Module],
channels: List[int],
block_repeats: List[int],
inplanes: int,
reduce_first: int = 1,
output_stride: int = 32,
down_kernel_size: int = 1,
avg_down: bool = False,
drop_block_rate: float = 0.,
drop_path_rate: float = 0.,
**kwargs,
):
) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]:
stages = []
feature_info = []
net_num_blocks = sum(block_repeats)
@ -356,28 +395,28 @@ class ResNet(nn.Module):
def __init__(
self,
block,
layers,
num_classes=1000,
in_chans=3,
output_stride=32,
global_pool='avg',
cardinality=1,
base_width=64,
stem_width=64,
stem_type='',
replace_stem_pool=False,
block_reduce_first=1,
down_kernel_size=1,
avg_down=False,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
aa_layer=None,
drop_rate=0.0,
drop_path_rate=0.,
drop_block_rate=0.,
zero_init_last=True,
block_args=None,
block: Type[nn.Module],
layers: List[int],
num_classes: int = 1000,
in_chans: int = 3,
output_stride: int = 32,
global_pool: str = 'avg',
cardinality: int = 1,
base_width: int = 64,
stem_width: int = 64,
stem_type: str = '',
replace_stem_pool: bool = False,
block_reduce_first: int = 1,
down_kernel_size: int = 1,
avg_down: bool = False,
act_layer: LayerType = nn.ReLU,
norm_layer: LayerType = nn.BatchNorm2d,
aa_layer: Optional[Type[nn.Module]] = None,
drop_rate: float = 0.0,
drop_path_rate: float = 0.,
drop_block_rate: float = 0.,
zero_init_last: bool = True,
block_args: Optional[Dict[str, Any]] = None,
):
"""
Args:
@ -490,7 +529,7 @@ class ResNet(nn.Module):
self.init_weights(zero_init_last=zero_init_last)
@torch.jit.ignore
def init_weights(self, zero_init_last=True):
def init_weights(self, zero_init_last: bool = True):
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
@ -500,23 +539,23 @@ class ResNet(nn.Module):
m.zero_init_last()
@torch.jit.ignore
def group_matcher(self, coarse=False):
def group_matcher(self, coarse: bool = False):
matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
def set_grad_checkpointing(self, enable: bool = True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self, name_only=False):
def get_classifier(self, name_only: bool = False):
return 'fc' if name_only else self.fc
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
def forward_features(self, x: Tensor) -> Tensor:
x = self.conv1(x)
x = self.bn1(x)
x = self.act1(x)
@ -531,19 +570,19 @@ class ResNet(nn.Module):
x = self.layer4(x)
return x
def forward_head(self, x, pre_logits: bool = False):
def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
return x if pre_logits else self.fc(x)
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.forward_features(x)
x = self.forward_head(x)
return x
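Tying the typed pieces together (a usage sketch, not part of the diff): the class can be constructed directly, and inference splits cleanly into the feature and head stages typed above:

```python
import torch
from timm.models.resnet import ResNet, Bottleneck

# a ResNet-50 from the typed constructor
model = ResNet(block=Bottleneck, layers=[3, 4, 6, 3], num_classes=10)
feats = model.forward_features(torch.randn(2, 3, 224, 224))  # (2, 2048, 7, 7)
logits = model.forward_head(feats)                           # (2, 10)
```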
def _create_resnet(variant, pretrained=False, **kwargs):
def _create_resnet(variant: str, pretrained: bool = False, **kwargs) -> ResNet:
return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
@ -1204,7 +1243,7 @@ default_cfgs = generate_default_cfgs({
@register_model
def resnet10t(pretrained=False, **kwargs) -> ResNet:
def resnet10t(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-10-T model.
"""
model_args = dict(block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True)
@ -1212,7 +1251,7 @@ def resnet10t(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet14t(pretrained=False, **kwargs) -> ResNet:
def resnet14t(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-14-T model.
"""
model_args = dict(block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True)
@ -1220,7 +1259,7 @@ def resnet14t(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet18(pretrained=False, **kwargs) -> ResNet:
def resnet18(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-18 model.
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
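Since each entrypoint is decorated with @register_model, the usual factory route also reaches these definitions (a sketch; the parameter count is approximate):

```python
import timm

model = timm.create_model('resnet18', pretrained=False, num_classes=10)
print(sum(p.numel() for p in model.parameters()))  # ~11.2M
```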
@ -1228,7 +1267,7 @@ def resnet18(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet18d(pretrained=False, **kwargs) -> ResNet:
def resnet18d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-18-D model.
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True)
@ -1236,7 +1275,7 @@ def resnet18d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet34(pretrained=False, **kwargs) -> ResNet:
def resnet34(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-34 model.
"""
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3])
@ -1244,7 +1283,7 @@ def resnet34(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet34d(pretrained=False, **kwargs) -> ResNet:
def resnet34d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-34-D model.
"""
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True)
@ -1252,7 +1291,7 @@ def resnet34d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet26(pretrained=False, **kwargs) -> ResNet:
def resnet26(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-26 model.
"""
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2])
@ -1260,7 +1299,7 @@ def resnet26(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet26t(pretrained=False, **kwargs) -> ResNet:
def resnet26t(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-26-T model.
"""
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True)
@ -1268,7 +1307,7 @@ def resnet26t(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet26d(pretrained=False, **kwargs) -> ResNet:
def resnet26d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-26-D model.
"""
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True)
@ -1276,7 +1315,7 @@ def resnet26d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet50(pretrained=False, **kwargs) -> ResNet:
def resnet50(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
@ -1284,7 +1323,7 @@ def resnet50(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet50c(pretrained=False, **kwargs) -> ResNet:
def resnet50c(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-C model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep')
@ -1292,7 +1331,7 @@ def resnet50c(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet50d(pretrained=False, **kwargs) -> ResNet:
def resnet50d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True)
@ -1300,7 +1339,7 @@ def resnet50d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet50s(pretrained=False, **kwargs) -> ResNet:
def resnet50s(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-S model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep')
@ -1308,7 +1347,7 @@ def resnet50s(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet50t(pretrained=False, **kwargs) -> ResNet:
def resnet50t(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-T model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True)
@ -1316,7 +1355,7 @@ def resnet50t(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet101(pretrained=False, **kwargs) -> ResNet:
def resnet101(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3])
@ -1324,7 +1363,7 @@ def resnet101(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet101c(pretrained=False, **kwargs) -> ResNet:
def resnet101c(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-C model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep')
@ -1332,7 +1371,7 @@ def resnet101c(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet101d(pretrained=False, **kwargs) -> ResNet:
def resnet101d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-D model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True)
@ -1340,7 +1379,7 @@ def resnet101d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet101s(pretrained=False, **kwargs) -> ResNet:
def resnet101s(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-S model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep')
@ -1348,7 +1387,7 @@ def resnet101s(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet152(pretrained=False, **kwargs) -> ResNet:
def resnet152(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-152 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3])
@ -1356,7 +1395,7 @@ def resnet152(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet152c(pretrained=False, **kwargs) -> ResNet:
def resnet152c(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-152-C model.
"""
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep')
@ -1364,7 +1403,7 @@ def resnet152c(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet152d(pretrained=False, **kwargs) -> ResNet:
def resnet152d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-152-D model.
"""
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True)
@ -1372,7 +1411,7 @@ def resnet152d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet152s(pretrained=False, **kwargs) -> ResNet:
def resnet152s(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-152-S model.
"""
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep')
@ -1380,7 +1419,7 @@ def resnet152s(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet200(pretrained=False, **kwargs) -> ResNet:
def resnet200(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-200 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3])
@ -1388,7 +1427,7 @@ def resnet200(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet200d(pretrained=False, **kwargs) -> ResNet:
def resnet200d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-200-D model.
"""
model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True)
@ -1396,7 +1435,7 @@ def resnet200d(pretrained=False, **kwargs) -> ResNet:
@register_model
def wide_resnet50_2(pretrained=False, **kwargs) -> ResNet:
def wide_resnet50_2(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a Wide ResNet-50-2 model.
The model is the same as ResNet except for the number of bottleneck channels,
which is twice as large in every block. The number of channels in outer 1x1
@ -1408,7 +1447,7 @@ def wide_resnet50_2(pretrained=False, **kwargs) -> ResNet:
@register_model
def wide_resnet101_2(pretrained=False, **kwargs) -> ResNet:
def wide_resnet101_2(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a Wide ResNet-101-2 model.
The model is the same as ResNet except for the number of bottleneck channels,
which is twice as large in every block. The number of channels in outer 1x1
@ -1419,7 +1458,7 @@ def wide_resnet101_2(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnet50_gn(pretrained=False, **kwargs) -> ResNet:
def resnet50_gn(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50 model w/ GroupNorm
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
@ -1427,7 +1466,7 @@ def resnet50_gn(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnext50_32x4d(pretrained=False, **kwargs) -> ResNet:
def resnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt50-32x4d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4)
@ -1435,7 +1474,7 @@ def resnext50_32x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnext50d_32x4d(pretrained=False, **kwargs) -> ResNet:
def resnext50d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample
"""
model_args = dict(
@ -1445,7 +1484,7 @@ def resnext50d_32x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnext101_32x4d(pretrained=False, **kwargs) -> ResNet:
def resnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt-101 32x4d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4)
@ -1453,7 +1492,7 @@ def resnext101_32x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnext101_32x8d(pretrained=False, **kwargs) -> ResNet:
def resnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt-101 32x8d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8)
@ -1461,7 +1500,7 @@ def resnext101_32x8d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnext101_32x16d(pretrained=False, **kwargs) -> ResNet:
def resnext101_32x16d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt-101 32x16d model
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16)
@ -1469,7 +1508,7 @@ def resnext101_32x16d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnext101_32x32d(pretrained=False, **kwargs) -> ResNet:
def resnext101_32x32d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt-101 32x32d model
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32)
@ -1477,7 +1516,7 @@ def resnext101_32x32d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnext101_64x4d(pretrained=False, **kwargs) -> ResNet:
def resnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt101-64x4d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4)
@ -1485,7 +1524,7 @@ def resnext101_64x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def ecaresnet26t(pretrained=False, **kwargs) -> ResNet:
def ecaresnet26t(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs an ECA-ResNeXt-26-T model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem and ECA attn.
@ -1497,7 +1536,7 @@ def ecaresnet26t(pretrained=False, **kwargs) -> ResNet:
@register_model
def ecaresnet50d(pretrained=False, **kwargs) -> ResNet:
def ecaresnet50d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D model with eca.
"""
model_args = dict(
@ -1507,7 +1546,7 @@ def ecaresnet50d(pretrained=False, **kwargs) -> ResNet:
@register_model
def ecaresnet50d_pruned(pretrained=False, **kwargs) -> ResNet:
def ecaresnet50d_pruned(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D model pruned with eca.
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
"""
@ -1518,7 +1557,7 @@ def ecaresnet50d_pruned(pretrained=False, **kwargs) -> ResNet:
@register_model
def ecaresnet50t(pretrained=False, **kwargs) -> ResNet:
def ecaresnet50t(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs an ECA-ResNet-50-T model.
Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn.
"""
@ -1529,7 +1568,7 @@ def ecaresnet50t(pretrained=False, **kwargs) -> ResNet:
@register_model
def ecaresnetlight(pretrained=False, **kwargs) -> ResNet:
def ecaresnetlight(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D light model with eca.
"""
model_args = dict(
@ -1539,7 +1578,7 @@ def ecaresnetlight(pretrained=False, **kwargs) -> ResNet:
@register_model
def ecaresnet101d(pretrained=False, **kwargs) -> ResNet:
def ecaresnet101d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-D model with eca.
"""
model_args = dict(
@ -1549,7 +1588,7 @@ def ecaresnet101d(pretrained=False, **kwargs) -> ResNet:
@register_model
def ecaresnet101d_pruned(pretrained=False, **kwargs) -> ResNet:
def ecaresnet101d_pruned(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-D model pruned with eca.
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
"""
@ -1560,7 +1599,7 @@ def ecaresnet101d_pruned(pretrained=False, **kwargs) -> ResNet:
@register_model
def ecaresnet200d(pretrained=False, **kwargs) -> ResNet:
def ecaresnet200d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-200-D model with ECA.
"""
model_args = dict(
@ -1570,7 +1609,7 @@ def ecaresnet200d(pretrained=False, **kwargs) -> ResNet:
@register_model
def ecaresnet269d(pretrained=False, **kwargs) -> ResNet:
def ecaresnet269d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-269-D model with ECA.
"""
model_args = dict(
@ -1580,7 +1619,7 @@ def ecaresnet269d(pretrained=False, **kwargs) -> ResNet:
@register_model
def ecaresnext26t_32x4d(pretrained=False, **kwargs) -> ResNet:
def ecaresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs an ECA-ResNeXt-26-T model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem. This model replaces the SE module with the ECA module
@ -1592,7 +1631,7 @@ def ecaresnext26t_32x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def ecaresnext50t_32x4d(pretrained=False, **kwargs) -> ResNet:
def ecaresnext50t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs an ECA-ResNeXt-50-T model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem. This model replaces the SE module with the ECA module
@ -1604,25 +1643,25 @@ def ecaresnext50t_32x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnet18(pretrained=False, **kwargs) -> ResNet:
def seresnet18(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet18', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet34(pretrained=False, **kwargs) -> ResNet:
def seresnet34(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet34', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet50(pretrained=False, **kwargs) -> ResNet:
def seresnet50(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet50', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet50t(pretrained=False, **kwargs) -> ResNet:
def seresnet50t(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered',
avg_down=True, block_args=dict(attn_layer='se'))
@ -1630,19 +1669,19 @@ def seresnet50t(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnet101(pretrained=False, **kwargs) -> ResNet:
def seresnet101(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet101', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet152(pretrained=False, **kwargs) -> ResNet:
def seresnet152(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet152', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet152d(pretrained=False, **kwargs) -> ResNet:
def seresnet152d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep',
avg_down=True, block_args=dict(attn_layer='se'))
@ -1650,7 +1689,7 @@ def seresnet152d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnet200d(pretrained=False, **kwargs) -> ResNet:
def seresnet200d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-200-D model with SE attn.
"""
model_args = dict(
@ -1660,7 +1699,7 @@ def seresnet200d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnet269d(pretrained=False, **kwargs) -> ResNet:
def seresnet269d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-269-D model with SE attn.
"""
model_args = dict(
@ -1670,7 +1709,7 @@ def seresnet269d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnext26d_32x4d(pretrained=False, **kwargs) -> ResNet:
def seresnext26d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a SE-ResNeXt-26-D model.`
This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
combination of deep stem and avg_pool in downsample.
@ -1682,7 +1721,7 @@ def seresnext26d_32x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnext26t_32x4d(pretrained=False, **kwargs) -> ResNet:
def seresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a SE-ResNet-26-T model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem.
@ -1694,7 +1733,7 @@ def seresnext26t_32x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnext26tn_32x4d(pretrained=False, **kwargs) -> ResNet:
def seresnext26tn_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a SE-ResNeXt-26-T model.
NOTE I deprecated previous 't' model defs and replaced 't' with 'tn', this was the only tn model of note
so keeping this def for backwards compat with any uses out there. Old 't' model is lost.
@ -1703,7 +1742,7 @@ def seresnext26tn_32x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnext50_32x4d(pretrained=False, **kwargs) -> ResNet:
def seresnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
block_args=dict(attn_layer='se'))
@ -1711,7 +1750,7 @@ def seresnext50_32x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnext101_32x4d(pretrained=False, **kwargs) -> ResNet:
def seresnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
block_args=dict(attn_layer='se'))
@ -1719,7 +1758,7 @@ def seresnext101_32x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnext101_32x8d(pretrained=False, **kwargs) -> ResNet:
def seresnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
block_args=dict(attn_layer='se'))
@ -1727,7 +1766,7 @@ def seresnext101_32x8d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnext101d_32x8d(pretrained=False, **kwargs) -> ResNet:
def seresnext101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
stem_width=32, stem_type='deep', avg_down=True,
@ -1736,7 +1775,7 @@ def seresnext101d_32x8d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnext101_64x4d(pretrained=False, **kwargs) -> ResNet:
def seresnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4,
block_args=dict(attn_layer='se'))
@ -1744,7 +1783,7 @@ def seresnext101_64x4d(pretrained=False, **kwargs) -> ResNet:
@register_model
def senet154(pretrained=False, **kwargs) -> ResNet:
def senet154(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'))
@ -1752,7 +1791,7 @@ def senet154(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetblur18(pretrained=False, **kwargs) -> ResNet:
def resnetblur18(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-18 model with blur anti-aliasing
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d)
@ -1760,7 +1799,7 @@ def resnetblur18(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetblur50(pretrained=False, **kwargs) -> ResNet:
def resnetblur50(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50 model with blur anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d)
@ -1768,7 +1807,7 @@ def resnetblur50(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetblur50d(pretrained=False, **kwargs) -> ResNet:
def resnetblur50d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D model with blur anti-aliasing
"""
model_args = dict(
@ -1778,7 +1817,7 @@ def resnetblur50d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetblur101d(pretrained=False, **kwargs) -> ResNet:
def resnetblur101d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-D model with blur anti-aliasing
"""
model_args = dict(
@ -1788,7 +1827,7 @@ def resnetblur101d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetaa34d(pretrained=False, **kwargs) -> ResNet:
def resnetaa34d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-34-D model w/ avgpool anti-aliasing
"""
model_args = dict(
@ -1797,7 +1836,7 @@ def resnetaa34d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetaa50(pretrained=False, **kwargs) -> ResNet:
def resnetaa50(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50 model with avgpool anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d)
@ -1805,7 +1844,7 @@ def resnetaa50(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetaa50d(pretrained=False, **kwargs) -> ResNet:
def resnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
@ -1815,7 +1854,7 @@ def resnetaa50d(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetaa101d(pretrained=False, **kwargs) -> ResNet:
def resnetaa101d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-D model with avgpool anti-aliasing
"""
model_args = dict(
@ -1825,7 +1864,7 @@ def resnetaa101d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnetaa50d(pretrained=False, **kwargs) -> ResNet:
def seresnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a SE=ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
@ -1835,7 +1874,7 @@ def seresnetaa50d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnextaa101d_32x8d(pretrained=False, **kwargs) -> ResNet:
def seresnextaa101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing
"""
model_args = dict(
@ -1846,7 +1885,7 @@ def seresnextaa101d_32x8d(pretrained=False, **kwargs) -> ResNet:
@register_model
def seresnextaa201d_32x8d(pretrained=False, **kwargs):
def seresnextaa201d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a SE-ResNeXt-201-D 32x8d model with avgpool anti-aliasing
"""
model_args = dict(
@ -1857,7 +1896,7 @@ def seresnextaa201d_32x8d(pretrained=False, **kwargs):
@register_model
def resnetrs50(pretrained=False, **kwargs) -> ResNet:
def resnetrs50(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-RS-50 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
@ -1870,7 +1909,7 @@ def resnetrs50(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetrs101(pretrained=False, **kwargs) -> ResNet:
def resnetrs101(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-RS-101 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
@ -1883,7 +1922,7 @@ def resnetrs101(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetrs152(pretrained=False, **kwargs) -> ResNet:
def resnetrs152(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-RS-152 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
@ -1896,7 +1935,7 @@ def resnetrs152(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetrs200(pretrained=False, **kwargs) -> ResNet:
def resnetrs200(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-RS-200 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
@ -1909,7 +1948,7 @@ def resnetrs200(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetrs270(pretrained=False, **kwargs) -> ResNet:
def resnetrs270(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-RS-270 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
@ -1923,7 +1962,7 @@ def resnetrs270(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetrs350(pretrained=False, **kwargs) -> ResNet:
def resnetrs350(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-RS-350 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
@ -1936,7 +1975,7 @@ def resnetrs350(pretrained=False, **kwargs) -> ResNet:
@register_model
def resnetrs420(pretrained=False, **kwargs) -> ResNet:
def resnetrs420(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-RS-420 model
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs