mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Change default_cfg names for senet to include the legacy and match model names
This commit is contained in:
parent
6e9d6172c8
commit
d5145fa4d5
@ -112,7 +112,7 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
|
||||
if 'GITHUB_ACTIONS' not in os.environ:
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models())
|
||||
@pytest.mark.parametrize('model_name', list_models(pretrained=True))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_load_pretrained(model_name, batch_size):
|
||||
"""Run a single forward pass with each model"""
|
||||
|
@ -36,25 +36,25 @@ def _cfg(url='', **kwargs):
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'senet154':
|
||||
'legacy_senet154':
|
||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'),
|
||||
'seresnet18': _cfg(
|
||||
'legacy_seresnet18': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth',
|
||||
interpolation='bicubic'),
|
||||
'seresnet34': _cfg(
|
||||
'legacy_seresnet34': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'),
|
||||
'seresnet50': _cfg(
|
||||
'legacy_seresnet50': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'),
|
||||
'seresnet101': _cfg(
|
||||
'legacy_seresnet101': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'),
|
||||
'seresnet152': _cfg(
|
||||
'legacy_seresnet152': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'),
|
||||
'seresnext26_32x4d': _cfg(
|
||||
'legacy_seresnext26_32x4d': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
|
||||
interpolation='bicubic'),
|
||||
'seresnext50_32x4d':
|
||||
'legacy_seresnext50_32x4d':
|
||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
|
||||
'seresnext101_32x4d':
|
||||
'legacy_seresnext101_32x4d':
|
||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth'),
|
||||
}
|
||||
|
||||
@ -408,35 +408,35 @@ def _create_senet(variant, pretrained=False, **kwargs):
|
||||
def legacy_seresnet18(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
block=SEResNetBlock, layers=[2, 2, 2, 2], groups=1, reduction=16, **kwargs)
|
||||
return _create_senet('seresnet18', pretrained, **model_args)
|
||||
return _create_senet('legacy_seresnet18', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def legacy_seresnet34(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
block=SEResNetBlock, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs)
|
||||
return _create_senet('seresnet34', pretrained, **model_args)
|
||||
return _create_senet('legacy_seresnet34', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def legacy_seresnet50(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
block=SEResNetBottleneck, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs)
|
||||
return _create_senet('seresnet50', pretrained, **model_args)
|
||||
return _create_senet('legacy_seresnet50', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def legacy_seresnet101(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
block=SEResNetBottleneck, layers=[3, 4, 23, 3], groups=1, reduction=16, **kwargs)
|
||||
return _create_senet('seresnet101', pretrained, **model_args)
|
||||
return _create_senet('legacy_seresnet101', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def legacy_seresnet152(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
block=SEResNetBottleneck, layers=[3, 8, 36, 3], groups=1, reduction=16, **kwargs)
|
||||
return _create_senet('seresnet152', pretrained, **model_args)
|
||||
return _create_senet('legacy_seresnet152', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
@ -444,25 +444,25 @@ def legacy_senet154(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
block=SEBottleneck, layers=[3, 8, 36, 3], groups=64, reduction=16,
|
||||
downsample_kernel_size=3, downsample_padding=1, inplanes=128, input_3x3=True, **kwargs)
|
||||
return _create_senet('senet154', pretrained, **model_args)
|
||||
return _create_senet('legacy_senet154', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def legacy_seresnext26_32x4d(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
block=SEResNeXtBottleneck, layers=[2, 2, 2, 2], groups=32, reduction=16, **kwargs)
|
||||
return _create_senet('seresnext26_32x4d', pretrained, **model_args)
|
||||
return _create_senet('legacy_seresnext26_32x4d', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def legacy_seresnext50_32x4d(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
block=SEResNeXtBottleneck, layers=[3, 4, 6, 3], groups=32, reduction=16, **kwargs)
|
||||
return _create_senet('seresnext50_32x4d', pretrained, **model_args)
|
||||
return _create_senet('legacy_seresnext50_32x4d', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def legacy_seresnext101_32x4d(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
block=SEResNeXtBottleneck, layers=[3, 4, 23, 3], groups=32, reduction=16, **kwargs)
|
||||
return _create_senet('seresnext101_32x4d', pretrained, **model_args)
|
||||
return _create_senet('legacy_seresnext101_32x4d', pretrained, **model_args)
|
||||
|
Loading…
x
Reference in New Issue
Block a user