mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add backward and default_cfg tests and fix a few issues found. Fix #153
This commit is contained in:
parent
ea2e59ca36
commit
afb6bd0669
@ -1,19 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from timm import list_models, create_model
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(300)
|
|
||||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*'))
|
|
||||||
@pytest.mark.parametrize('batch_size', [1])
|
|
||||||
def test_model_forward(model_name, batch_size):
|
|
||||||
"""Run a single forward pass with each model"""
|
|
||||||
model = create_model(model_name, pretrained=False)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
inputs = torch.randn((batch_size, *model.default_cfg['input_size']))
|
|
||||||
outputs = model(inputs)
|
|
||||||
|
|
||||||
assert outputs.shape[0] == batch_size
|
|
||||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
70
tests/test_models.py
Normal file
70
tests/test_models.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from timm import list_models, create_model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(120)
|
||||||
|
@pytest.mark.parametrize('model_name', list_models())
|
||||||
|
@pytest.mark.parametrize('batch_size', [1])
|
||||||
|
def test_model_forward(model_name, batch_size):
|
||||||
|
"""Run a single forward pass with each model"""
|
||||||
|
model = create_model(model_name, pretrained=False)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
input_size = model.default_cfg['input_size']
|
||||||
|
if any([x > 448 for x in input_size]):
|
||||||
|
# cap forward test at max res 448 * 448 to keep resource down
|
||||||
|
input_size = tuple([min(x, 448) for x in input_size])
|
||||||
|
inputs = torch.randn((batch_size, *input_size))
|
||||||
|
outputs = model(inputs)
|
||||||
|
|
||||||
|
assert outputs.shape[0] == batch_size
|
||||||
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(120)
|
||||||
|
@pytest.mark.parametrize('model_name', list_models(exclude_filters='dla*')) # DLA models have an issue TBD
|
||||||
|
@pytest.mark.parametrize('batch_size', [2])
|
||||||
|
def test_model_backward(model_name, batch_size):
|
||||||
|
"""Run a single forward pass with each model"""
|
||||||
|
model = create_model(model_name, pretrained=False, num_classes=42)
|
||||||
|
num_params = sum([x.numel() for x in model.parameters()])
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
input_size = model.default_cfg['input_size']
|
||||||
|
if any([x > 128 for x in input_size]):
|
||||||
|
# cap backward test at 128 * 128 to keep resource usage down
|
||||||
|
input_size = tuple([min(x, 128) for x in input_size])
|
||||||
|
inputs = torch.randn((batch_size, *input_size))
|
||||||
|
outputs = model(inputs)
|
||||||
|
outputs.mean().backward()
|
||||||
|
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
|
||||||
|
|
||||||
|
assert outputs.shape[-1] == 42
|
||||||
|
assert num_params == num_grad, 'Some parameters are missing gradients'
|
||||||
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(120)
|
||||||
|
@pytest.mark.parametrize('model_name', list_models())
|
||||||
|
@pytest.mark.parametrize('batch_size', [1])
|
||||||
|
def test_model_default_cfgs(model_name, batch_size):
|
||||||
|
"""Run a single forward pass with each model"""
|
||||||
|
model = create_model(model_name, pretrained=False)
|
||||||
|
model.eval()
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
cfg = model.default_cfg
|
||||||
|
|
||||||
|
classifier = cfg['classifier']
|
||||||
|
first_conv = cfg['first_conv']
|
||||||
|
pool_size = cfg['pool_size']
|
||||||
|
input_size = model.default_cfg['input_size']
|
||||||
|
|
||||||
|
if all([x <= 448 for x in input_size]):
|
||||||
|
# pool size only checked if default res <= 448 * 448 to keep resource down
|
||||||
|
input_size = tuple([min(x, 448) for x in input_size])
|
||||||
|
outputs = model.forward_features(torch.randn((batch_size, *input_size)))
|
||||||
|
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
|
||||||
|
assert any([k.startswith(cfg['classifier']) for k in state_dict.keys()]), f'{classifier} not in model params'
|
||||||
|
assert any([k.startswith(cfg['first_conv']) for k in state_dict.keys()]), f'{first_conv} not in model params'
|
@ -237,8 +237,11 @@ class DlaTree(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, residual=None, children=None):
|
def forward(self, x, residual=None, children=None):
|
||||||
children = [] if children is None else children
|
children = [] if children is None else children
|
||||||
bottom = self.downsample(x) if self.downsample else x
|
# FIXME the way downsample / project are used here and residual is passed to next level up
|
||||||
residual = self.project(bottom) if self.project else bottom
|
# the tree, the residual is overridden and some project weights are thus never used and
|
||||||
|
# have no gradients. This appears to be an issue with the original model / weights.
|
||||||
|
bottom = self.downsample(x) if self.downsample is not None else x
|
||||||
|
residual = self.project(bottom) if self.project is not None else bottom
|
||||||
if self.level_root:
|
if self.level_root:
|
||||||
children.append(bottom)
|
children.append(bottom)
|
||||||
x1 = self.tree1(x, residual)
|
x1 = self.tree1(x, residual)
|
||||||
@ -354,7 +357,8 @@ def dla60_res2next(pretrained=None, num_classes=1000, in_chans=3, **kwargs):
|
|||||||
@register_model
|
@register_model
|
||||||
def dla34(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-34
|
def dla34(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-34
|
||||||
default_cfg = default_cfgs['dla34']
|
default_cfg = default_cfgs['dla34']
|
||||||
model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic, **kwargs)
|
model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic,
|
||||||
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
model.default_cfg = default_cfg
|
model.default_cfg = default_cfg
|
||||||
if pretrained:
|
if pretrained:
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
@ -36,7 +36,7 @@ default_cfgs = {
|
|||||||
'url': '',
|
'url': '',
|
||||||
'input_size': (3, 299, 299),
|
'input_size': (3, 299, 299),
|
||||||
'crop_pct': 0.875,
|
'crop_pct': 0.875,
|
||||||
'pool_size': (10, 10),
|
'pool_size': (5, 5),
|
||||||
'interpolation': 'bicubic',
|
'interpolation': 'bicubic',
|
||||||
'mean': IMAGENET_DEFAULT_MEAN,
|
'mean': IMAGENET_DEFAULT_MEAN,
|
||||||
'std': IMAGENET_DEFAULT_STD,
|
'std': IMAGENET_DEFAULT_STD,
|
||||||
|
@ -34,7 +34,7 @@ def _cfg(url='', **kwargs):
|
|||||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
'first_conv': 'conv1', 'classifier': 'fc',
|
'first_conv': 'conv1', 'classifier': 'classifier',
|
||||||
**kwargs
|
**kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ def _cfg(url='', **kwargs):
|
|||||||
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||||
'first_conv': 'conv1', 'classifier': 'fc',
|
'first_conv': 'Conv2d_1a_3x3', 'classifier': 'fc',
|
||||||
**kwargs
|
**kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ __all__ = ['MobileNetV3']
|
|||||||
|
|
||||||
def _cfg(url='', **kwargs):
|
def _cfg(url='', **kwargs):
|
||||||
return {
|
return {
|
||||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
|
||||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
'first_conv': 'conv_stem', 'classifier': 'classifier',
|
'first_conv': 'conv_stem', 'classifier': 'classifier',
|
||||||
|
@ -19,7 +19,7 @@ default_cfgs = {
|
|||||||
'mean': (0.5, 0.5, 0.5),
|
'mean': (0.5, 0.5, 0.5),
|
||||||
'std': (0.5, 0.5, 0.5),
|
'std': (0.5, 0.5, 0.5),
|
||||||
'num_classes': 1001,
|
'num_classes': 1001,
|
||||||
'first_conv': 'conv_0.conv',
|
'first_conv': 'conv0.conv',
|
||||||
'classifier': 'last_linear',
|
'classifier': 'last_linear',
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -612,7 +612,7 @@ def nasnetalarge(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|||||||
"""NASNet-A large model architecture.
|
"""NASNet-A large model architecture.
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['nasnetalarge']
|
default_cfg = default_cfgs['nasnetalarge']
|
||||||
model = NASNetALarge(num_classes=1000, in_chans=in_chans, **kwargs)
|
model = NASNetALarge(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
model.default_cfg = default_cfg
|
model.default_cfg = default_cfg
|
||||||
if pretrained:
|
if pretrained:
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
@ -38,11 +38,14 @@ default_cfgs = {
|
|||||||
'resnest50d': _cfg(
|
'resnest50d': _cfg(
|
||||||
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'),
|
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'),
|
||||||
'resnest101e': _cfg(
|
'resnest101e': _cfg(
|
||||||
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', input_size=(3, 256, 256)),
|
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth',
|
||||||
|
input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||||
'resnest200e': _cfg(
|
'resnest200e': _cfg(
|
||||||
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)),
|
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth',
|
||||||
|
input_size=(3, 320, 320), pool_size=(10, 10)),
|
||||||
'resnest269e': _cfg(
|
'resnest269e': _cfg(
|
||||||
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)),
|
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth',
|
||||||
|
input_size=(3, 416, 416), pool_size=(13, 13)),
|
||||||
'resnest50d_4s2x40d': _cfg(
|
'resnest50d_4s2x40d': _cfg(
|
||||||
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_4s2x40d-41d14ed0.pth',
|
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_4s2x40d-41d14ed0.pth',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
|
@ -26,7 +26,7 @@ __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this
|
|||||||
def _cfg(url='', **kwargs):
|
def _cfg(url='', **kwargs):
|
||||||
return {
|
return {
|
||||||
'url': url,
|
'url': url,
|
||||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (3, 3),
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4),
|
||||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
'first_conv': 'stem', 'classifier': 'fc',
|
'first_conv': 'stem', 'classifier': 'fc',
|
||||||
|
@ -28,7 +28,7 @@ def _cfg(url='', **kwargs):
|
|||||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||||
'mean': (0, 0, 0), 'std': (1, 1, 1),
|
'mean': (0, 0, 0), 'std': (1, 1, 1),
|
||||||
'first_conv': 'layer0.conv1', 'classifier': 'head.fc',
|
'first_conv': 'body.conv1', 'classifier': 'head.fc',
|
||||||
**kwargs
|
**kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,13 +41,13 @@ default_cfgs = {
|
|||||||
'tresnet_xl': _cfg(
|
'tresnet_xl': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'),
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'),
|
||||||
'tresnet_m_448': _cfg(
|
'tresnet_m_448': _cfg(
|
||||||
input_size=(3, 448, 448),
|
input_size=(3, 448, 448), pool_size=(14, 14),
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'),
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'),
|
||||||
'tresnet_l_448': _cfg(
|
'tresnet_l_448': _cfg(
|
||||||
input_size=(3, 448, 448),
|
input_size=(3, 448, 448), pool_size=(14, 14),
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'),
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'),
|
||||||
'tresnet_xl_448': _cfg(
|
'tresnet_xl_448': _cfg(
|
||||||
input_size=(3, 448, 448),
|
input_size=(3, 448, 448), pool_size=(14, 14),
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth')
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,6 +37,7 @@ default_cfgs = {
|
|||||||
'xception': {
|
'xception': {
|
||||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
|
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
|
||||||
'input_size': (3, 299, 299),
|
'input_size': (3, 299, 299),
|
||||||
|
'pool_size': (10, 10),
|
||||||
'crop_pct': 0.8975,
|
'crop_pct': 0.8975,
|
||||||
'interpolation': 'bicubic',
|
'interpolation': 'bicubic',
|
||||||
'mean': (0.5, 0.5, 0.5),
|
'mean': (0.5, 0.5, 0.5),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user