|
|
|
@ -7,7 +7,6 @@ ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered ste
|
|
|
|
|
Copyright 2020 Ross Wightman
|
|
|
|
|
"""
|
|
|
|
|
import math
|
|
|
|
|
import copy
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
@ -58,24 +57,18 @@ default_cfgs = {
|
|
|
|
|
'resnet101': _cfg(url='', interpolation='bicubic'),
|
|
|
|
|
'resnet101d': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
|
|
|
|
'resnet101d_320': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
|
|
|
|
|
crop_pct=1.0, test_input_size=(3, 320, 320)),
|
|
|
|
|
'resnet152': _cfg(url='', interpolation='bicubic'),
|
|
|
|
|
'resnet152d': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
|
|
|
|
'resnet152d_320': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
|
|
|
|
|
crop_pct=1.0, test_input_size=(3, 320, 320)),
|
|
|
|
|
'resnet200': _cfg(url='', interpolation='bicubic'),
|
|
|
|
|
'resnet200d': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet200d_ra2-bdba9bf9.pth',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
|
|
|
|
'resnet200d_320': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet200d_ra2-bdba9bf9.pth',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
|
|
|
|
|
crop_pct=1.0, test_input_size=(3, 320, 320)),
|
|
|
|
|
'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
|
|
|
|
|
'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
|
|
|
|
|
'tv_resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
|
|
|
|
@ -146,7 +139,7 @@ default_cfgs = {
|
|
|
|
|
'seresnet50': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet50_ra_224-8efdb4bb.pth',
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'seresnet50tn': _cfg(
|
|
|
|
|
'seresnet50t': _cfg(
|
|
|
|
|
url='',
|
|
|
|
|
interpolation='bicubic',
|
|
|
|
|
first_conv='conv1.0'),
|
|
|
|
@ -158,10 +151,9 @@ default_cfgs = {
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'seresnet152d': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
|
|
|
|
'seresnet152d_320': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
|
|
|
|
|
crop_pct=1.0, test_input_size=(3, 320, 320)
|
|
|
|
|
),
|
|
|
|
|
'seresnet200d': _cfg(
|
|
|
|
|
url='',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
|
|
|
@ -171,18 +163,11 @@ default_cfgs = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
|
|
|
|
|
'seresnext26_32x4d': _cfg(
|
|
|
|
|
url='',
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'seresnext26d_32x4d': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth',
|
|
|
|
|
interpolation='bicubic',
|
|
|
|
|
first_conv='conv1.0'),
|
|
|
|
|
'seresnext26t_32x4d': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26t_32x4d-361bc1c4.pth',
|
|
|
|
|
interpolation='bicubic',
|
|
|
|
|
first_conv='conv1.0'),
|
|
|
|
|
'seresnext26tn_32x4d': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
|
|
|
|
|
interpolation='bicubic',
|
|
|
|
|
first_conv='conv1.0'),
|
|
|
|
@ -201,8 +186,10 @@ default_cfgs = {
|
|
|
|
|
first_conv='conv1.0'),
|
|
|
|
|
|
|
|
|
|
# Efficient Channel Attention ResNets
|
|
|
|
|
'ecaresnet18': _cfg(),
|
|
|
|
|
'ecaresnet50': _cfg(),
|
|
|
|
|
'ecaresnet26t': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
|
|
|
|
|
crop_pct=0.95, test_input_size=(3, 320, 320)),
|
|
|
|
|
'ecaresnetlight': _cfg(
|
|
|
|
|
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth',
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
@ -214,10 +201,13 @@ default_cfgs = {
|
|
|
|
|
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth',
|
|
|
|
|
interpolation='bicubic',
|
|
|
|
|
first_conv='conv1.0'),
|
|
|
|
|
'ecaresnet50t': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
|
|
|
|
|
crop_pct=0.95, test_input_size=(3, 320, 320)),
|
|
|
|
|
'ecaresnet101d': _cfg(
|
|
|
|
|
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth',
|
|
|
|
|
interpolation='bicubic',
|
|
|
|
|
first_conv='conv1.0'),
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0'),
|
|
|
|
|
'ecaresnet101d_pruned': _cfg(
|
|
|
|
|
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
|
|
|
|
|
interpolation='bicubic',
|
|
|
|
@ -226,17 +216,17 @@ default_cfgs = {
|
|
|
|
|
url='',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
|
|
|
|
'ecaresnet269d': _cfg(
|
|
|
|
|
url='',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth',
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10),
|
|
|
|
|
crop_pct=1.0, test_input_size=(3, 352, 352)),
|
|
|
|
|
|
|
|
|
|
# Efficient Channel Attention ResNeXts
|
|
|
|
|
'ecaresnext26tn_32x4d': _cfg(
|
|
|
|
|
'ecaresnext26t_32x4d': _cfg(
|
|
|
|
|
url='',
|
|
|
|
|
interpolation='bicubic',
|
|
|
|
|
first_conv='conv1.0'),
|
|
|
|
|
'ecaresnext50_32x4d': _cfg(
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0'),
|
|
|
|
|
'ecaresnext50t_32x4d': _cfg(
|
|
|
|
|
url='',
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
interpolation='bicubic', first_conv='conv1.0'),
|
|
|
|
|
|
|
|
|
|
# ResNets with anti-aliasing blur pool
|
|
|
|
|
'resnetblur18': _cfg(
|
|
|
|
@ -529,8 +519,7 @@ class ResNet(nn.Module):
|
|
|
|
|
The type of stem:
|
|
|
|
|
* '', default - a single 7x7 conv with a width of stem_width
|
|
|
|
|
* 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
|
|
|
|
|
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width//4 * 6, stem_width * 2
|
|
|
|
|
* 'deep_tiered_narrow' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
|
|
|
|
|
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
|
|
|
|
|
block_reduce_first: int, default 1
|
|
|
|
|
Reduction factor for first convolution output width of residual blocks,
|
|
|
|
|
1 for all archs except senets, where 2
|
|
|
|
@ -564,18 +553,17 @@ class ResNet(nn.Module):
|
|
|
|
|
deep_stem = 'deep' in stem_type
|
|
|
|
|
inplanes = stem_width * 2 if deep_stem else 64
|
|
|
|
|
if deep_stem:
|
|
|
|
|
stem_chs_1 = stem_chs_2 = stem_width
|
|
|
|
|
stem_chs = (stem_width, stem_width)
|
|
|
|
|
if 'tiered' in stem_type:
|
|
|
|
|
stem_chs_1 = 3 * (stem_width // 4)
|
|
|
|
|
stem_chs_2 = stem_width if 'narrow' in stem_type else 6 * (stem_width // 4)
|
|
|
|
|
stem_chs = (3 * (stem_width // 4), stem_width)
|
|
|
|
|
self.conv1 = nn.Sequential(*[
|
|
|
|
|
nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False),
|
|
|
|
|
norm_layer(stem_chs_1),
|
|
|
|
|
nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
|
|
|
|
|
norm_layer(stem_chs[0]),
|
|
|
|
|
act_layer(inplace=True),
|
|
|
|
|
nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
|
|
|
|
|
norm_layer(stem_chs_2),
|
|
|
|
|
nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
|
|
|
|
|
norm_layer(stem_chs[1]),
|
|
|
|
|
act_layer(inplace=True),
|
|
|
|
|
nn.Conv2d(stem_chs_2, inplanes, 3, stride=1, padding=1, bias=False)])
|
|
|
|
|
nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)])
|
|
|
|
|
else:
|
|
|
|
|
self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
|
|
|
|
self.bn1 = norm_layer(inplanes)
|
|
|
|
@ -732,14 +720,6 @@ def resnet101d(pretrained=False, **kwargs):
|
|
|
|
|
return _create_resnet('resnet101d', pretrained, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def resnet101d_320(pretrained=False, **kwargs):
    """Build the ResNet-101-D network registered under the 'resnet101d_320' config.

    Uses Bottleneck blocks in a [3, 4, 23, 3] stage layout with a 32-wide
    'deep' stem and avg_down enabled.

    Args:
        pretrained: Forwarded unchanged to `_create_resnet`.
        **kwargs: Extra keyword arguments forwarded to `_create_resnet`.
            Repeating one of the fixed architecture keys raises TypeError.

    Returns:
        Whatever `_create_resnet` returns for this variant.
    """
    # Fixed architecture settings are kept apart from caller kwargs so that a
    # clashing key is rejected at the call below rather than silently winning.
    arch = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        stem_width=32,
        stem_type='deep',
        avg_down=True,
    )
    return _create_resnet('resnet101d_320', pretrained, **arch, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def resnet152(pretrained=False, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-152 model.
|
|
|
|
@ -757,15 +737,6 @@ def resnet152d(pretrained=False, **kwargs):
|
|
|
|
|
return _create_resnet('resnet152d', pretrained, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def resnet152d_320(pretrained=False, **kwargs):
    """Build the ResNet-152-D network registered under the 'resnet152d_320' config.

    Uses Bottleneck blocks in a [3, 8, 36, 3] stage layout with a 32-wide
    'deep' stem and avg_down enabled.

    Args:
        pretrained: Forwarded unchanged to `_create_resnet`.
        **kwargs: Extra keyword arguments forwarded to `_create_resnet`.
            Repeating one of the fixed architecture keys raises TypeError.

    Returns:
        Whatever `_create_resnet` returns for this variant.
    """
    fixed = dict(
        block=Bottleneck,
        layers=[3, 8, 36, 3],
        stem_width=32,
        stem_type='deep',
        avg_down=True,
    )
    # dict(**a, **b) refuses duplicate keys, same as the single dict(...) call would.
    merged = dict(**fixed, **kwargs)
    return _create_resnet('resnet152d_320', pretrained, **merged)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def resnet200(pretrained=False, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-200 model.
|
|
|
|
@ -783,15 +754,6 @@ def resnet200d(pretrained=False, **kwargs):
|
|
|
|
|
return _create_resnet('resnet200d', pretrained, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def resnet200d_320(pretrained=False, **kwargs):
    """Build the ResNet-200-D network registered under the 'resnet200d_320' config.

    NOTE: intended as a duplicate of the 200D definition, differing only in the
    default cfg (320x320) selected by the variant name.

    Uses Bottleneck blocks in a [3, 24, 36, 3] stage layout with a 32-wide
    'deep' stem and avg_down enabled.

    Args:
        pretrained: Forwarded unchanged to `_create_resnet`.
        **kwargs: Extra keyword arguments forwarded to `_create_resnet`.
            Repeating one of the fixed architecture keys raises TypeError.

    Returns:
        Whatever `_create_resnet` returns for this variant.
    """
    arch = dict(
        block=Bottleneck,
        layers=[3, 24, 36, 3],
        stem_width=32,
        stem_type='deep',
        avg_down=True,
    )
    return _create_resnet('resnet200d_320', pretrained, **arch, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def tv_resnet34(pretrained=False, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-34 model with original Torchvision weights.
|
|
|
|
@ -1068,19 +1030,15 @@ def swsl_resnext101_32x16d(pretrained=True, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def ecaresnet18(pretrained=False, **kwargs):
|
|
|
|
|
""" Constructs an ECA-ResNet-18 model.
|
|
|
|
|
def ecaresnet26t(pretrained=False, **kwargs):
|
|
|
|
|
"""Constructs an ECA-ResNeXt-26-T model.
|
|
|
|
|
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
|
|
|
|
|
in the deep stem and ECA attn.
|
|
|
|
|
"""
|
|
|
|
|
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='eca'), **kwargs)
|
|
|
|
|
return _create_resnet('ecaresnet18', pretrained, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def ecaresnet50(pretrained=False, **kwargs):
|
|
|
|
|
"""Constructs an ECA-ResNet-50 model.
|
|
|
|
|
"""
|
|
|
|
|
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='eca'), **kwargs)
|
|
|
|
|
return _create_resnet('ecaresnet50', pretrained, **model_args)
|
|
|
|
|
model_args = dict(
|
|
|
|
|
block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32,
|
|
|
|
|
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
|
|
|
|
|
return _create_resnet('ecaresnet26t', pretrained, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@ -1104,6 +1062,17 @@ def ecaresnet50d_pruned(pretrained=False, **kwargs):
|
|
|
|
|
return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def ecaresnet50t(pretrained=False, **kwargs):
    """Build the ECA-ResNet-50-T network.

    Like a 'D' bag-of-tricks model, but with a tiered deep stem (24, 32, 64
    channels) and ECA attention in every Bottleneck block.

    Args:
        pretrained: Forwarded unchanged to `_create_resnet`.
        **kwargs: Extra keyword arguments forwarded to `_create_resnet`.
            Repeating one of the fixed architecture keys raises TypeError.

    Returns:
        Whatever `_create_resnet` returns for the 'ecaresnet50t' variant.
    """
    attention = dict(attn_layer='eca')
    arch = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        stem_width=32,
        stem_type='deep_tiered',
        avg_down=True,
        block_args=attention,
    )
    return _create_resnet('ecaresnet50t', pretrained, **arch, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def ecaresnetlight(pretrained=False, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-50-D light model with eca.
|
|
|
|
@ -1156,16 +1125,27 @@ def ecaresnet269d(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def ecaresnext26tn_32x4d(pretrained=False, **kwargs):
|
|
|
|
|
"""Constructs an ECA-ResNeXt-26-TN model.
|
|
|
|
|
def ecaresnext26t_32x4d(pretrained=False, **kwargs):
|
|
|
|
|
"""Constructs an ECA-ResNeXt-26-T model.
|
|
|
|
|
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
|
|
|
|
|
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
|
|
|
|
|
this model replaces SE module with the ECA module
|
|
|
|
|
in the deep stem. This model replaces SE module with the ECA module
|
|
|
|
|
"""
|
|
|
|
|
model_args = dict(
|
|
|
|
|
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
|
|
|
|
|
stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
|
|
|
|
|
return _create_resnet('ecaresnext26tn_32x4d', pretrained, **model_args)
|
|
|
|
|
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
|
|
|
|
|
return _create_resnet('ecaresnext26t_32x4d', pretrained, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def ecaresnext50t_32x4d(pretrained=False, **kwargs):
    """Constructs an ECA-ResNeXt-50-T model.

    Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem,
    32x4d grouped Bottleneck blocks, and the ECA module in place of the SE module.

    Args:
        pretrained: Forwarded unchanged to `_create_resnet`.
        **kwargs: Extra keyword arguments forwarded to `_create_resnet`. Repeating one of
            the fixed architecture keys raises TypeError.

    Returns:
        Whatever `_create_resnet` returns for the 'ecaresnext50t_32x4d' variant.
    """
    # Stage depths are the 50-layer layout [3, 4, 6, 3], as used by the other *50t
    # definitions in this file. The previous [2, 2, 2, 2] was copied from the 26-layer
    # variant and built a ResNeXt-26 under a ResNeXt-50 name.
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, stem_width=32,
        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
    return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@ -1203,11 +1183,11 @@ def seresnet50(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def seresnet50tn(pretrained=False, **kwargs):
|
|
|
|
|
def seresnet50t(pretrained=False, **kwargs):
|
|
|
|
|
model_args = dict(
|
|
|
|
|
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
|
|
|
|
|
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True,
|
|
|
|
|
block_args=dict(attn_layer='se'), **kwargs)
|
|
|
|
|
return _create_resnet('seresnet50tn', pretrained, **model_args)
|
|
|
|
|
return _create_resnet('seresnet50t', pretrained, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@ -1250,22 +1230,6 @@ def seresnet269d(pretrained=False, **kwargs):
|
|
|
|
|
return _create_resnet('seresnet269d', pretrained, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def seresnet152d_320(pretrained=False, **kwargs):
    """Build the SE-ResNet-152-D network registered under the 'seresnet152d_320' config.

    Uses Bottleneck blocks with 'se' attention in a [3, 8, 36, 3] stage layout,
    a 32-wide 'deep' stem, and avg_down enabled.

    Args:
        pretrained: Forwarded unchanged to `_create_resnet`.
        **kwargs: Extra keyword arguments forwarded to `_create_resnet`.
            Repeating one of the fixed architecture keys raises TypeError.

    Returns:
        Whatever `_create_resnet` returns for this variant.
    """
    arch = dict(
        block=Bottleneck,
        layers=[3, 8, 36, 3],
        stem_width=32,
        stem_type='deep',
        avg_down=True,
        block_args={'attn_layer': 'se'},
    )
    return _create_resnet('seresnet152d_320', pretrained, **arch, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def seresnext26_32x4d(pretrained=False, **kwargs):
    """Build the SE-ResNeXt-26 (32x4d) network.

    Uses Bottleneck blocks with 'se' attention in a [2, 2, 2, 2] stage layout,
    with cardinality 32 and base width 4.

    Args:
        pretrained: Forwarded unchanged to `_create_resnet`.
        **kwargs: Extra keyword arguments forwarded to `_create_resnet`.
            Repeating one of the fixed architecture keys raises TypeError.

    Returns:
        Whatever `_create_resnet` returns for the 'seresnext26_32x4d' variant.
    """
    fixed = dict(
        block=Bottleneck,
        layers=[2, 2, 2, 2],
        cardinality=32,
        base_width=4,
        block_args={'attn_layer': 'se'},
    )
    # dict(**a, **b) refuses duplicate keys, same as the single dict(...) call would.
    merged = dict(**fixed, **kwargs)
    return _create_resnet('seresnext26_32x4d', pretrained, **merged)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def seresnext26d_32x4d(pretrained=False, **kwargs):
|
|
|
|
|
"""Constructs a SE-ResNeXt-26-D model.`
|
|
|
|
@ -1281,7 +1245,7 @@ def seresnext26d_32x4d(pretrained=False, **kwargs):
|
|
|
|
|
@register_model
|
|
|
|
|
def seresnext26t_32x4d(pretrained=False, **kwargs):
|
|
|
|
|
"""Constructs a SE-ResNet-26-T model.
|
|
|
|
|
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
|
|
|
|
|
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
|
|
|
|
|
in the deep stem.
|
|
|
|
|
"""
|
|
|
|
|
model_args = dict(
|
|
|
|
@ -1292,14 +1256,11 @@ def seresnext26t_32x4d(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def seresnext26tn_32x4d(pretrained=False, **kwargs):
|
|
|
|
|
"""Constructs a SE-ResNeXt-26-TN model.
|
|
|
|
|
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
|
|
|
|
|
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
|
|
|
|
|
"""Constructs a SE-ResNeXt-26-T model.
|
|
|
|
|
NOTE I deprecated previous 't' model defs and replaced 't' with 'tn', this was the only tn model of note
|
|
|
|
|
so keeping this def for backwards compat with any uses out there. Old 't' model is lost.
|
|
|
|
|
"""
|
|
|
|
|
model_args = dict(
|
|
|
|
|
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
|
|
|
|
|
stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
|
|
|
|
|
return _create_resnet('seresnext26tn_32x4d', pretrained, **model_args)
|
|
|
|
|
return seresnext26t_32x4d(pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|