mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add model registry and model listing fns, refactor model_factory/create_model fn
This commit is contained in:
parent
8512436436
commit
171c0b88b6
@ -1,2 +1,2 @@
|
||||
from .version import __version__
|
||||
from .models import create_model
|
||||
from .models import create_model, list_models, is_model, list_modules, model_entrypoint
|
||||
|
@ -1,4 +1,16 @@
|
||||
from .model_factory import create_model
|
||||
from .inception_v4 import *
|
||||
from .inception_resnet_v2 import *
|
||||
from .densenet import *
|
||||
from .resnet import *
|
||||
from .dpn import *
|
||||
from .senet import *
|
||||
from .xception import *
|
||||
from .pnasnet import *
|
||||
from .gen_efficientnet import *
|
||||
from .inception_v3 import *
|
||||
from .gluon_resnet import *
|
||||
|
||||
from .registry import *
|
||||
from .factory import create_model
|
||||
from .helpers import load_checkpoint, resume_checkpoint
|
||||
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||
|
||||
|
@ -4,13 +4,17 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import *
|
||||
from .adaptive_avgmax_pool import select_adaptive_pool2d
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
import re
|
||||
|
||||
_models = ['densenet121', 'densenet169', 'densenet201', 'densenet161']
|
||||
__all__ = ['DenseNet'] + _models
|
||||
__all__ = ['DenseNet']
|
||||
|
||||
|
||||
def _cfg(url=''):
|
||||
@ -30,71 +34,6 @@ default_cfgs = {
|
||||
}
|
||||
|
||||
|
||||
def _filter_pretrained(state_dict):
|
||||
pattern = re.compile(
|
||||
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
res = pattern.match(key)
|
||||
if res:
|
||||
new_key = res.group(1) + res.group(2)
|
||||
state_dict[new_key] = state_dict[key]
|
||||
del state_dict[key]
|
||||
return state_dict
|
||||
|
||||
|
||||
def densenet121(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""Densenet-121 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
default_cfg = default_cfgs['densenet121']
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
|
||||
return model
|
||||
|
||||
|
||||
def densenet169(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""Densenet-169 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
default_cfg = default_cfgs['densenet169']
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
|
||||
return model
|
||||
|
||||
|
||||
def densenet201(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""Densenet-201 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
default_cfg = default_cfgs['densenet201']
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
|
||||
return model
|
||||
|
||||
|
||||
def densenet161(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""Densenet-201 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
default_cfg = default_cfgs['densenet161']
|
||||
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
|
||||
return model
|
||||
|
||||
|
||||
class _DenseLayer(nn.Sequential):
|
||||
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
|
||||
super(_DenseLayer, self).__init__()
|
||||
@ -205,3 +144,72 @@ class DenseNet(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.classifier(self.forward_features(x, pool=True))
|
||||
|
||||
|
||||
def _filter_pretrained(state_dict):
|
||||
pattern = re.compile(
|
||||
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
res = pattern.match(key)
|
||||
if res:
|
||||
new_key = res.group(1) + res.group(2)
|
||||
state_dict[new_key] = state_dict[key]
|
||||
del state_dict[key]
|
||||
return state_dict
|
||||
|
||||
|
||||
|
||||
@register_model
|
||||
def densenet121(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""Densenet-121 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
default_cfg = default_cfgs['densenet121']
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def densenet169(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""Densenet-169 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
default_cfg = default_cfgs['densenet169']
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def densenet201(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""Densenet-201 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
default_cfg = default_cfgs['densenet201']
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def densenet161(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""Densenet-201 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
default_cfg = default_cfgs['densenet161']
|
||||
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
|
||||
return model
|
||||
|
@ -14,12 +14,13 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from collections import OrderedDict
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import select_adaptive_pool2d
|
||||
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
|
||||
|
||||
_models = ['dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
|
||||
__all__ = ['DPN'] + _models
|
||||
|
||||
__all__ = ['DPN']
|
||||
|
||||
|
||||
def _cfg(url=''):
|
||||
@ -47,78 +48,6 @@ default_cfgs = {
|
||||
}
|
||||
|
||||
|
||||
def dpn68(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn68']
|
||||
model = DPN(
|
||||
small=True, num_init_features=10, k_r=128, groups=32,
|
||||
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def dpn68b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn68b_extra']
|
||||
model = DPN(
|
||||
small=True, num_init_features=10, k_r=128, groups=32,
|
||||
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def dpn92(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn92_extra']
|
||||
model = DPN(
|
||||
num_init_features=64, k_r=96, groups=32,
|
||||
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def dpn98(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn98']
|
||||
model = DPN(
|
||||
num_init_features=96, k_r=160, groups=40,
|
||||
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def dpn131(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn131']
|
||||
model = DPN(
|
||||
num_init_features=128, k_r=160, groups=40,
|
||||
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def dpn107(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn107_extra']
|
||||
model = DPN(
|
||||
num_init_features=128, k_r=200, groups=50,
|
||||
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
class CatBnAct(nn.Module):
|
||||
def __init__(self, in_chs, activation_fn=nn.ReLU(inplace=True)):
|
||||
super(CatBnAct, self).__init__()
|
||||
@ -317,3 +246,78 @@ class DPN(nn.Module):
|
||||
return out.view(out.size(0), -1)
|
||||
|
||||
|
||||
@register_model
|
||||
def dpn68(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn68']
|
||||
model = DPN(
|
||||
small=True, num_init_features=10, k_r=128, groups=32,
|
||||
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def dpn68b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn68b_extra']
|
||||
model = DPN(
|
||||
small=True, num_init_features=10, k_r=128, groups=32,
|
||||
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def dpn92(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn92_extra']
|
||||
model = DPN(
|
||||
num_init_features=64, k_r=96, groups=32,
|
||||
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def dpn98(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn98']
|
||||
model = DPN(
|
||||
num_init_features=96, k_r=160, groups=40,
|
||||
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def dpn131(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn131']
|
||||
model = DPN(
|
||||
num_init_features=128, k_r=160, groups=40,
|
||||
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def dpn107(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['dpn107_extra']
|
||||
model = DPN(
|
||||
num_init_features=128, k_r=200, groups=50,
|
||||
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
44
timm/models/factory.py
Normal file
44
timm/models/factory.py
Normal file
@ -0,0 +1,44 @@
|
||||
from .registry import is_model, is_model_in_modules, model_entrypoint
|
||||
from .helpers import load_checkpoint
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name,
|
||||
pretrained=False,
|
||||
num_classes=1000,
|
||||
in_chans=3,
|
||||
checkpoint_path='',
|
||||
**kwargs):
|
||||
"""Create a model
|
||||
|
||||
Args:
|
||||
model_name (str): name of model to instantiate
|
||||
pretrained (bool): load pretrained ImageNet-1k weights if true
|
||||
num_classes (int): number of classes for final fully connected layer (default: 1000)
|
||||
in_chans (int): number of input channels / colors (default: 3)
|
||||
checkpoint_path (str): path of checkpoint to load after model is initialized
|
||||
|
||||
Keyword Args:
|
||||
drop_rate (float): dropout rate for training (default: 0.0)
|
||||
global_pool (str): global pool type (default: 'avg')
|
||||
**: other kwargs are model specific
|
||||
"""
|
||||
margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
|
||||
|
||||
# Not all models have support for batchnorm params passed as args, only gen_efficientnet variants
|
||||
supports_bn_params = is_model_in_modules(model_name, ['gen_efficientnet'])
|
||||
if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
|
||||
kwargs.pop('bn_tf', None)
|
||||
kwargs.pop('bn_momentum', None)
|
||||
kwargs.pop('bn_eps', None)
|
||||
|
||||
if is_model(model_name):
|
||||
create_fn = model_entrypoint(model_name)
|
||||
model = create_fn(**margs, **kwargs)
|
||||
else:
|
||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||
|
||||
if checkpoint_path:
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
|
||||
return model
|
@ -23,19 +23,15 @@ from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from .conv2d_same import sconv2d
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
_models = [
|
||||
'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
|
||||
'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
|
||||
'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', 'chamnetv1_100', 'chamnetv2_100',
|
||||
'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100', 'efficientnet_b0', 'efficientnet_b1',
|
||||
'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'tf_efficientnet_b0',
|
||||
'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', 'tf_efficientnet_b4', 'tf_efficientnet_b5']
|
||||
__all__ = ['GenEfficientNet', 'gen_efficientnet_model_names'] + _models
|
||||
|
||||
__all__ = ['GenEfficientNet']
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
@ -1157,6 +1153,7 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet B1, depth multiplier of 0.5. """
|
||||
default_cfg = default_cfgs['mnasnet_050']
|
||||
@ -1167,6 +1164,7 @@ def mnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet B1, depth multiplier of 0.75. """
|
||||
default_cfg = default_cfgs['mnasnet_075']
|
||||
@ -1177,6 +1175,7 @@ def mnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet B1, depth multiplier of 1.0. """
|
||||
default_cfg = default_cfgs['mnasnet_100']
|
||||
@ -1187,11 +1186,13 @@ def mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mnasnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet B1, depth multiplier of 1.0. """
|
||||
return mnasnet_100(num_classes, in_chans, pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def tflite_mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet B1, depth multiplier of 1.0. """
|
||||
default_cfg = default_cfgs['tflite_mnasnet_100']
|
||||
@ -1205,6 +1206,7 @@ def tflite_mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet B1, depth multiplier of 1.4 """
|
||||
default_cfg = default_cfgs['mnasnet_140']
|
||||
@ -1215,6 +1217,7 @@ def mnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def semnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet A1 (w/ SE), depth multiplier of 0.5 """
|
||||
default_cfg = default_cfgs['semnasnet_050']
|
||||
@ -1225,6 +1228,7 @@ def semnasnet_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def semnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet A1 (w/ SE), depth multiplier of 0.75. """
|
||||
default_cfg = default_cfgs['semnasnet_075']
|
||||
@ -1235,6 +1239,7 @@ def semnasnet_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet A1 (w/ SE), depth multiplier of 1.0. """
|
||||
default_cfg = default_cfgs['semnasnet_100']
|
||||
@ -1245,11 +1250,13 @@ def semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mnasnet_a1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet A1 (w/ SE), depth multiplier of 1.0. """
|
||||
return semnasnet_100(num_classes, in_chans, pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def tflite_semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet A1, depth multiplier of 1.0. """
|
||||
default_cfg = default_cfgs['tflite_semnasnet_100']
|
||||
@ -1263,6 +1270,7 @@ def tflite_semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwarg
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def semnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """
|
||||
default_cfg = default_cfgs['semnasnet_140']
|
||||
@ -1273,6 +1281,7 @@ def semnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mnasnet_small(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MNASNet Small, depth multiplier of 1.0. """
|
||||
default_cfg = default_cfgs['mnasnet_small']
|
||||
@ -1283,6 +1292,7 @@ def mnasnet_small(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MobileNet V1 """
|
||||
default_cfg = default_cfgs['mobilenetv1_100']
|
||||
@ -1293,6 +1303,7 @@ def mobilenetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MobileNet V2 """
|
||||
default_cfg = default_cfgs['mobilenetv2_100']
|
||||
@ -1303,6 +1314,7 @@ def mobilenetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
default_cfg = default_cfgs['mobilenetv3_050']
|
||||
@ -1313,6 +1325,7 @@ def mobilenetv3_050(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
default_cfg = default_cfgs['mobilenetv3_075']
|
||||
@ -1323,6 +1336,7 @@ def mobilenetv3_075(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
default_cfg = default_cfgs['mobilenetv3_100']
|
||||
@ -1336,6 +1350,7 @@ def mobilenetv3_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def fbnetc_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" FBNet-C """
|
||||
default_cfg = default_cfgs['fbnetc_100']
|
||||
@ -1349,6 +1364,7 @@ def fbnetc_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def chamnetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" ChamNet """
|
||||
default_cfg = default_cfgs['chamnetv1_100']
|
||||
@ -1359,6 +1375,7 @@ def chamnetv1_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def chamnetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" ChamNet """
|
||||
default_cfg = default_cfgs['chamnetv2_100']
|
||||
@ -1369,6 +1386,7 @@ def chamnetv2_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def spnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" Single-Path NAS Pixel1"""
|
||||
default_cfg = default_cfgs['spnasnet_100']
|
||||
@ -1379,6 +1397,7 @@ def spnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B0 """
|
||||
default_cfg = default_cfgs['efficientnet_b0']
|
||||
@ -1392,6 +1411,7 @@ def efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B1 """
|
||||
default_cfg = default_cfgs['efficientnet_b1']
|
||||
@ -1405,6 +1425,7 @@ def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B2 """
|
||||
default_cfg = default_cfgs['efficientnet_b2']
|
||||
@ -1418,6 +1439,7 @@ def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B3 """
|
||||
default_cfg = default_cfgs['efficientnet_b3']
|
||||
@ -1431,6 +1453,7 @@ def efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B4 """
|
||||
default_cfg = default_cfgs['efficientnet_b4']
|
||||
@ -1444,6 +1467,7 @@ def efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B5 """
|
||||
# NOTE for train, drop_rate should be 0.4
|
||||
@ -1457,6 +1481,7 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B0. Tensorflow compatible variant """
|
||||
default_cfg = default_cfgs['tf_efficientnet_b0']
|
||||
@ -1471,6 +1496,7 @@ def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B1. Tensorflow compatible variant """
|
||||
default_cfg = default_cfgs['tf_efficientnet_b1']
|
||||
@ -1485,6 +1511,7 @@ def tf_efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B2. Tensorflow compatible variant """
|
||||
default_cfg = default_cfgs['tf_efficientnet_b2']
|
||||
@ -1499,6 +1526,7 @@ def tf_efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B3. Tensorflow compatible variant """
|
||||
default_cfg = default_cfgs['tf_efficientnet_b3']
|
||||
@ -1513,6 +1541,7 @@ def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B4. Tensorflow compatible variant """
|
||||
default_cfg = default_cfgs['tf_efficientnet_b4']
|
||||
@ -1527,6 +1556,7 @@ def tf_efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B5. Tensorflow compatible variant """
|
||||
default_cfg = default_cfgs['tf_efficientnet_b5']
|
||||
|
@ -3,21 +3,19 @@ This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-R
|
||||
and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py)
|
||||
by Ross Wightman
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
_models = [
|
||||
'gluon_resnet18_v1b', 'gluon_resnet34_v1b', 'gluon_resnet50_v1b', 'gluon_resnet101_v1b', 'gluon_resnet152_v1b',
|
||||
'gluon_resnet50_v1c', 'gluon_resnet101_v1c', 'gluon_resnet152_v1c', 'gluon_resnet50_v1d', 'gluon_resnet101_v1d',
|
||||
'gluon_resnet152_v1d', 'gluon_resnet50_v1e', 'gluon_resnet101_v1e', 'gluon_resnet152_v1e', 'gluon_resnet50_v1s',
|
||||
'gluon_resnet101_v1s', 'gluon_resnet152_v1s', 'gluon_resnext50_32x4d', 'gluon_resnext101_32x4d',
|
||||
'gluon_resnext101_64x4d', 'gluon_resnext152_32x4d', 'gluon_seresnext50_32x4d', 'gluon_seresnext101_32x4d',
|
||||
'gluon_seresnext101_64x4d', 'gluon_seresnext152_32x4d', 'gluon_senet154']
|
||||
__all__ = ['GluonResNet'] + _models
|
||||
|
||||
__all__ = ['GluonResNet']
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
@ -361,6 +359,7 @@ class GluonResNet(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet18_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-18 model.
|
||||
"""
|
||||
@ -372,6 +371,7 @@ def gluon_resnet18_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet34_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-34 model.
|
||||
"""
|
||||
@ -383,6 +383,7 @@ def gluon_resnet34_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet50_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-50 model.
|
||||
"""
|
||||
@ -394,6 +395,7 @@ def gluon_resnet50_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet101_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-101 model.
|
||||
"""
|
||||
@ -405,6 +407,7 @@ def gluon_resnet101_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet152_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-152 model.
|
||||
"""
|
||||
@ -416,6 +419,7 @@ def gluon_resnet152_v1b(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet50_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-50 model.
|
||||
"""
|
||||
@ -428,6 +432,7 @@ def gluon_resnet50_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet101_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-101 model.
|
||||
"""
|
||||
@ -440,6 +445,7 @@ def gluon_resnet101_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet152_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-152 model.
|
||||
"""
|
||||
@ -452,6 +458,7 @@ def gluon_resnet152_v1c(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet50_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-50 model.
|
||||
"""
|
||||
@ -464,6 +471,7 @@ def gluon_resnet50_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet101_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-101 model.
|
||||
"""
|
||||
@ -476,6 +484,7 @@ def gluon_resnet101_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet152_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-152 model.
|
||||
"""
|
||||
@ -488,6 +497,7 @@ def gluon_resnet152_v1d(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet50_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-50-V1e model. No pretrained weights for any 'e' variants
|
||||
"""
|
||||
@ -500,6 +510,7 @@ def gluon_resnet50_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet101_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-101 model.
|
||||
"""
|
||||
@ -512,6 +523,7 @@ def gluon_resnet101_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet152_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-152 model.
|
||||
"""
|
||||
@ -524,6 +536,7 @@ def gluon_resnet152_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet50_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-50 model.
|
||||
"""
|
||||
@ -536,6 +549,7 @@ def gluon_resnet50_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet101_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-101 model.
|
||||
"""
|
||||
@ -548,6 +562,7 @@ def gluon_resnet101_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnet152_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-152 model.
|
||||
"""
|
||||
@ -560,6 +575,7 @@ def gluon_resnet152_v1s(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt50-32x4d model.
|
||||
"""
|
||||
@ -573,6 +589,7 @@ def gluon_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwar
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt-101 model.
|
||||
"""
|
||||
@ -586,6 +603,7 @@ def gluon_resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt-101 model.
|
||||
"""
|
||||
@ -599,6 +617,7 @@ def gluon_resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt152-32x4d model.
|
||||
"""
|
||||
@ -612,6 +631,7 @@ def gluon_resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a SEResNeXt50-32x4d model.
|
||||
"""
|
||||
@ -625,6 +645,7 @@ def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kw
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a SEResNeXt-101-32x4d model.
|
||||
"""
|
||||
@ -638,6 +659,7 @@ def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a SEResNeXt-101-64x4d model.
|
||||
"""
|
||||
@ -651,6 +673,7 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_seresnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a SEResNeXt152-32x4d model.
|
||||
"""
|
||||
@ -664,6 +687,7 @@ def gluon_seresnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs an SENet-154 model.
|
||||
"""
|
||||
|
@ -2,12 +2,16 @@
|
||||
Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
|
||||
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import *
|
||||
from .adaptive_avgmax_pool import select_adaptive_pool2d
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
|
||||
_models = ['inception_resnet_v2', 'ens_adv_inception_resnet_v2']
|
||||
__all__ = ['InceptionResnetV2'] + _models
|
||||
__all__ = ['InceptionResnetV2']
|
||||
|
||||
default_cfgs = {
|
||||
# ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
|
||||
@ -328,6 +332,7 @@ class InceptionResnetV2(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""InceptionResnetV2 model architecture from the
|
||||
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
|
||||
@ -341,6 +346,7 @@ def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ens_adv_inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
|
||||
As per https://arxiv.org/abs/1705.07204 and
|
||||
|
@ -1,9 +1,9 @@
|
||||
from torchvision.models import Inception3
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
|
||||
_models = ['inception_v3', 'tf_inception_v3', 'adv_inception_v3', 'gluon_inception_v3']
|
||||
__all__ = _models
|
||||
__all__ = []
|
||||
|
||||
default_cfgs = {
|
||||
# original PyTorch weights, ported from Tensorflow but modified
|
||||
@ -66,6 +66,7 @@ def _assert_default_kwargs(kwargs):
|
||||
assert kwargs.pop('drop_rate', 0.) == 0.
|
||||
|
||||
|
||||
@register_model
|
||||
def inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
# original PyTorch weights, ported from Tensorflow but modified
|
||||
default_cfg = default_cfgs['inception_v3']
|
||||
@ -78,6 +79,7 @@ def inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
||||
default_cfg = default_cfgs['tf_inception_v3']
|
||||
@ -90,6 +92,7 @@ def tf_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def adv_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
# my port of Tensorflow adversarially trained Inception V3 from
|
||||
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
||||
@ -103,6 +106,7 @@ def adv_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
||||
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
||||
|
@ -2,12 +2,16 @@
|
||||
Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
|
||||
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import *
|
||||
from .adaptive_avgmax_pool import select_adaptive_pool2d
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
|
||||
_models = ['inception_v4']
|
||||
__all__ = ['InceptionV4'] + _models
|
||||
__all__ = ['InceptionV4']
|
||||
|
||||
default_cfgs = {
|
||||
'inception_v4': {
|
||||
@ -293,6 +297,7 @@ class InceptionV4(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def inception_v4(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['inception_v4']
|
||||
model = InceptionV4(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
|
@ -1,42 +0,0 @@
|
||||
from .inception_v4 import *
|
||||
from .inception_resnet_v2 import *
|
||||
from .densenet import *
|
||||
from .resnet import *
|
||||
from .dpn import *
|
||||
from .senet import *
|
||||
from .xception import *
|
||||
from .pnasnet import *
|
||||
from .gen_efficientnet import *
|
||||
from .inception_v3 import *
|
||||
from .gluon_resnet import *
|
||||
|
||||
from .helpers import load_checkpoint
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name,
|
||||
pretrained=False,
|
||||
num_classes=1000,
|
||||
in_chans=3,
|
||||
checkpoint_path='',
|
||||
**kwargs):
|
||||
|
||||
margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
|
||||
|
||||
# Not all models have support for batchnorm params passed as args, only gen_efficientnet variants
|
||||
supports_bn_params = model_name in gen_efficientnet_model_names()
|
||||
if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
|
||||
kwargs.pop('bn_tf', None)
|
||||
kwargs.pop('bn_momentum', None)
|
||||
kwargs.pop('bn_eps', None)
|
||||
|
||||
if model_name in globals():
|
||||
create_fn = globals()[model_name]
|
||||
model = create_fn(**margs, **kwargs)
|
||||
else:
|
||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||
|
||||
if checkpoint_path:
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
|
||||
return model
|
@ -12,11 +12,11 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
|
||||
_models = ['pnasnet5large']
|
||||
__all__ = ['PNASNet5Large'] + _models
|
||||
__all__ = ['PNASNet5Large']
|
||||
|
||||
default_cfgs = {
|
||||
'pnasnet5large': {
|
||||
@ -385,6 +385,7 @@ class PNASNet5Large(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def pnasnet5large(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""PNASNet-5 model architecture from the
|
||||
`"Progressive Neural Architecture Search"
|
||||
|
78
timm/models/registry.py
Normal file
78
timm/models/registry.py
Normal file
@ -0,0 +1,78 @@
|
||||
import sys
|
||||
import re
|
||||
import fnmatch
|
||||
from collections import defaultdict
|
||||
|
||||
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules']
|
||||
|
||||
_module_to_models = defaultdict(set)
|
||||
_model_to_module = {}
|
||||
_model_entrypoints = {}
|
||||
|
||||
|
||||
def register_model(fn):
|
||||
mod = sys.modules[fn.__module__]
|
||||
module_name_split = fn.__module__.split('.')
|
||||
module_name = module_name_split[-1] if len(module_name_split) else ''
|
||||
if hasattr(mod, '__all__'):
|
||||
mod.__all__.append(fn.__name__)
|
||||
else:
|
||||
mod.__all__ = [fn.__name__]
|
||||
_model_entrypoints[fn.__name__] = fn
|
||||
_model_to_module[fn.__name__] = module_name
|
||||
_module_to_models[module_name].add(fn.__name__)
|
||||
return fn
|
||||
|
||||
|
||||
def _natural_key(string_):
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
|
||||
|
||||
def list_models(filter='', module=''):
|
||||
""" Return list of available model names, sorted alphabetically
|
||||
|
||||
Args:
|
||||
filter (str) - Wildcard filter string that works with fnmatch
|
||||
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
|
||||
|
||||
Example:
|
||||
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
||||
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
||||
"""
|
||||
if module:
|
||||
models = list(_module_to_models[module])
|
||||
else:
|
||||
models = _model_entrypoints.keys()
|
||||
if filter:
|
||||
models = fnmatch.filter(models, filter)
|
||||
return list(sorted(models, key=_natural_key))
|
||||
|
||||
|
||||
def is_model(model_name):
|
||||
""" Check if a model name exists
|
||||
"""
|
||||
return model_name in _model_entrypoints
|
||||
|
||||
|
||||
def model_entrypoint(model_name):
|
||||
"""Fetch a model entrypoint for specified model name
|
||||
"""
|
||||
return _model_entrypoints[model_name]
|
||||
|
||||
|
||||
def list_modules():
|
||||
""" Return list of module names that contain models / model entrypoints
|
||||
"""
|
||||
modules = _module_to_models.keys()
|
||||
return list(sorted(modules))
|
||||
|
||||
|
||||
def is_model_in_modules(model_name, module_names):
|
||||
"""Check if a model exists within a subset of modules
|
||||
Args:
|
||||
model_name (str) - name of model to check
|
||||
module_names (tuple, list, set) - names of modules to search in
|
||||
"""
|
||||
assert isinstance(module_names, (tuple, list, set))
|
||||
return any(model_name in _module_to_models[n] for n in module_names)
|
||||
|
@ -4,17 +4,18 @@ additional dropout and dynamic global avg/max pool.
|
||||
|
||||
ResNext additions added by Ross Wightman
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
_models = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
|
||||
'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d',
|
||||
'ig_resnext101_32x8d', 'ig_resnext101_32x16d', 'ig_resnext101_32x32d', 'ig_resnext101_32x48d']
|
||||
__all__ = ['ResNet'] + _models
|
||||
|
||||
__all__ = ['ResNet'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
@ -224,6 +225,7 @@ class ResNet(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-18 model.
|
||||
"""
|
||||
@ -235,6 +237,7 @@ def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-34 model.
|
||||
"""
|
||||
@ -246,6 +249,7 @@ def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-50 model.
|
||||
"""
|
||||
@ -257,6 +261,7 @@ def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-101 model.
|
||||
"""
|
||||
@ -268,6 +273,7 @@ def resnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-152 model.
|
||||
"""
|
||||
@ -279,6 +285,7 @@ def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt50-32x4d model.
|
||||
"""
|
||||
@ -292,6 +299,7 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt-101 model.
|
||||
"""
|
||||
@ -305,6 +313,7 @@ def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt101-64x4d model.
|
||||
"""
|
||||
@ -318,6 +327,7 @@ def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt152-32x4d model.
|
||||
"""
|
||||
@ -331,6 +341,7 @@ def resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ig_resnext101_32x8d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data
|
||||
and finetuned on ImageNet from Figure 5 in
|
||||
@ -349,6 +360,7 @@ def ig_resnext101_32x8d(pretrained=True, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ig_resnext101_32x16d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data
|
||||
and finetuned on ImageNet from Figure 5 in
|
||||
@ -367,6 +379,7 @@ def ig_resnext101_32x16d(pretrained=True, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ig_resnext101_32x32d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data
|
||||
and finetuned on ImageNet from Figure 5 in
|
||||
@ -385,6 +398,7 @@ def ig_resnext101_32x32d(pretrained=True, num_classes=1000, in_chans=3, **kwargs
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ig_resnext101_32x48d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data
|
||||
and finetuned on ImageNet from Figure 5 in
|
||||
|
@ -8,20 +8,18 @@ Original model: https://github.com/hujie-frank/SENet
|
||||
ResNet code gently borrowed from
|
||||
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
|
||||
"""
|
||||
from __future__ import print_function, division, absolute_import
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
_models = ['seresnet18', 'seresnet34', 'seresnet50', 'seresnet101', 'seresnet152', 'senet154',
|
||||
'seresnext26_32x4d', 'seresnext50_32x4d', 'seresnext101_32x4d']
|
||||
__all__ = ['SENet'] + _models
|
||||
__all__ = ['SENet']
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
@ -400,6 +398,7 @@ class SENet(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['seresnet18']
|
||||
model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
|
||||
@ -412,6 +411,7 @@ def seresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['seresnet34']
|
||||
model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16,
|
||||
@ -424,6 +424,7 @@ def seresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['seresnet50']
|
||||
model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
|
||||
@ -436,6 +437,7 @@ def seresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['seresnet101']
|
||||
model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
|
||||
@ -448,6 +450,7 @@ def seresnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['seresnet152']
|
||||
model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
|
||||
@ -460,6 +463,7 @@ def seresnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['senet154']
|
||||
model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
|
||||
@ -470,6 +474,7 @@ def senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnext26_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['seresnext26_32x4d']
|
||||
model = SENet(SEResNeXtBottleneck, [2, 2, 2, 2], groups=32, reduction=16,
|
||||
@ -482,6 +487,7 @@ def seresnext26_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['seresnext50_32x4d']
|
||||
model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
|
||||
@ -494,6 +500,7 @@ def seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['seresnext101_32x4d']
|
||||
model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
|
||||
|
@ -21,17 +21,17 @@ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
|
||||
|
||||
The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
|
||||
"""
|
||||
from __future__ import print_function, division, absolute_import
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import select_adaptive_pool2d
|
||||
|
||||
_models = ['xception']
|
||||
__all__ = ['Xception'] + _models
|
||||
__all__ = ['Xception']
|
||||
|
||||
default_cfgs = {
|
||||
'xception': {
|
||||
@ -228,6 +228,7 @@ class Xception(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def xception(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['xception']
|
||||
model = Xception(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
|
24
validate.py
24
validate.py
@ -13,7 +13,7 @@ import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
from collections import OrderedDict
|
||||
|
||||
from timm.models import create_model, apply_test_time_pool, load_checkpoint
|
||||
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
||||
from timm.data import Dataset, create_loader, resolve_data_config
|
||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
|
||||
|
||||
@ -144,22 +144,26 @@ def validate(args):
|
||||
def main():
|
||||
setup_default_logging()
|
||||
args = parser.parse_args()
|
||||
if args.model == 'all':
|
||||
# validate all models in a list of names with pretrained checkpoints
|
||||
args.pretrained = True
|
||||
# FIXME just an example list, need to add model name collections for
|
||||
# batch testing of various pretrained combinations by arg string
|
||||
models = ['tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3']
|
||||
model_cfgs = [(n, '') for n in models]
|
||||
elif os.path.isdir(args.checkpoint):
|
||||
model_cfgs = []
|
||||
model_names = []
|
||||
if os.path.isdir(args.checkpoint):
|
||||
# validate all checkpoints in a path with same model
|
||||
checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
|
||||
checkpoints += glob.glob(args.checkpoint + '/*.pth')
|
||||
model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
|
||||
else:
|
||||
model_cfgs = []
|
||||
if args.model == 'all':
|
||||
# validate all models in a list of names with pretrained checkpoints
|
||||
args.pretrained = True
|
||||
model_names = list_models()
|
||||
model_cfgs = [(n, '') for n in model_names]
|
||||
elif not is_model(args.model):
|
||||
# model name doesn't exist, try as wildcard filter
|
||||
model_names = list_models(args.model)
|
||||
model_cfgs = [(n, '') for n in model_names]
|
||||
|
||||
if len(model_cfgs):
|
||||
print('Running bulk validation on these pretrained models:', ', '.join(model_names))
|
||||
header_written = False
|
||||
with open('./results-all.csv', mode='w') as cf:
|
||||
for m, c in model_cfgs:
|
||||
|
Loading…
x
Reference in New Issue
Block a user