Merge pull request #501 from rwightman/hf_hub_revisit
Support for huggingface hub via create_model and default_cfgs.
commit 3eac7dc5a3
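For orientation, the feature this commit adds can be exercised roughly as follows (a minimal sketch; it assumes `huggingface_hub` is installed and that the referenced hub repo is published):

```python
import timm

# Unchanged: create by architecture name, weights come from the release URL in default_cfgs.
model = timm.create_model('resnet50', pretrained=True)

# New: an 'hf_hub:' prefixed name loads config.json (the default_cfg) and
# pytorch_model.bin from the Hugging Face hub. An optional '#revision' suffix
# pins a branch, tag, or commit.
model = timm.create_model('hf_hub:timm/vit_huge_patch14_224_in21k', pretrained=True)
```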
@@ -509,7 +509,7 @@ for m in model_list:
     model.eval()
     with torch.no_grad():
         # warmup
-        input = torch.randn((batch_size,) + data_config['input_size']).cuda()
+        input = torch.randn((batch_size,) + tuple(data_config['input_size'])).cuda()
         model(input)

         bar = tqdm(desc="Evaluation", mininterval=5, total=50000)
@@ -72,8 +72,8 @@ class RandomResizedCropAndInterpolation:
     def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                  interpolation='bilinear'):
-        if isinstance(size, tuple):
-            self.size = size
+        if isinstance(size, (list, tuple)):
+            self.size = tuple(size)
         else:
             self.size = (size, size)
         if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):

@@ -78,7 +78,7 @@ def transforms_imagenet_train(
     secondary_tfl = []
     if auto_augment:
         assert isinstance(auto_augment, str)
-        if isinstance(img_size, tuple):
+        if isinstance(img_size, (tuple, list)):
             img_size_min = min(img_size)
         else:
             img_size_min = img_size

@@ -136,7 +136,7 @@ def transforms_imagenet_eval(
         std=IMAGENET_DEFAULT_STD):
     crop_pct = crop_pct or DEFAULT_CROP_PCT

-    if isinstance(img_size, tuple):
+    if isinstance(img_size, (tuple, list)):
         assert len(img_size) == 2
         if img_size[-1] == img_size[-2]:
             # fall-back to older behaviour so Resize scales to shortest edge if target is square

@@ -186,7 +186,7 @@ def create_transform(
         tf_preprocessing=False,
         separate=False):

-    if isinstance(input_size, tuple):
+    if isinstance(input_size, (tuple, list)):
         img_size = input_size[-2:]
     else:
         img_size = input_size
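A plausible motivation for the `(tuple, list)` checks above: size fields deserialized from JSON configs (e.g. a hub-hosted `config.json`) arrive as lists rather than tuples, so both should behave the same. A small sketch:

```python
from timm.data import create_transform

# Both now take the same branch and resize to the (H, W) part of input_size.
tf_a = create_transform(input_size=(3, 224, 224))
tf_b = create_transform(input_size=[3, 224, 224])
```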
@@ -31,7 +31,7 @@ from .xception import *
 from .xception_aligned import *
 from .hardcorenas import *

-from .factory import create_model
+from .factory import create_model, split_model_name, safe_model_name
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model
@@ -409,8 +409,10 @@ class CspNet(nn.Module):
 def _create_cspnet(variant, pretrained=False, **kwargs):
     cfg_variant = variant.split('_')[0]
     return build_model_with_cfg(
-        CspNet, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant], **kwargs)
+        CspNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant],
+        **kwargs)


 @register_model

@@ -287,8 +287,10 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs):
     kwargs['growth_rate'] = growth_rate
     kwargs['block_config'] = block_config
     return build_model_with_cfg(
-        DenseNet, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained, **kwargs)
+        DenseNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained,
+        **kwargs)


 @register_model

@@ -338,8 +338,11 @@ class DLA(nn.Module):

 def _create_dla(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        DLA, variant, pretrained, default_cfg=default_cfgs[variant],
-        pretrained_strict=False, feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)), **kwargs)
+        DLA, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        pretrained_strict=False,
+        feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),
+        **kwargs)


 @register_model

@@ -262,8 +262,10 @@ class DPN(nn.Module):

 def _create_dpn(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        DPN, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(feature_concat=True, flatten_sequential=True), **kwargs)
+        DPN, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=dict(feature_concat=True, flatten_sequential=True),
+        **kwargs)


 @register_model
@@ -452,18 +452,20 @@ class EfficientNetFeatures(nn.Module):
         return list(out.values())


-def _create_effnet(model_kwargs, variant, pretrained=False):
+def _create_effnet(variant, pretrained=False, **kwargs):
     features_only = False
     model_cls = EfficientNet
-    if model_kwargs.pop('features_only', False):
+    kwargs_filter = None
+    if kwargs.pop('features_only', False):
         features_only = True
-        model_kwargs.pop('num_classes', 0)
-        model_kwargs.pop('num_features', 0)
-        model_kwargs.pop('head_conv', None)
+        kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
         model_cls = EfficientNetFeatures
     model = build_model_with_cfg(
-        model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        pretrained_strict=not features_only, **model_kwargs)
+        model_cls, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        pretrained_strict=not features_only,
+        kwargs_filter=kwargs_filter,
+        **kwargs)
     if features_only:
         model.default_cfg = default_cfg_for_features(model.default_cfg)
     return model
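The `kwargs_filter` mechanism replaces the ad-hoc `model_kwargs.pop(...)` calls: head-related arguments are stripped inside `build_model_with_cfg` when a features-only variant is built. From the user side this path is exercised as usual (a sketch):

```python
import timm

# EfficientNetFeatures is constructed; classifier-head kwargs such as num_classes
# are filtered out via kwargs_filter instead of being popped in each builder.
feat_model = timm.create_model('efficientnet_b0', features_only=True)
print(feat_model.feature_info.channels())
```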
@@ -501,7 +503,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -537,7 +539,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -566,7 +568,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs,variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -595,7 +597,7 @@ def _gen_mobilenet_v2(
         act_layer=resolve_act_layer(kwargs, 'relu6'),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -625,7 +627,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -660,7 +662,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -706,7 +708,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -735,7 +737,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
         act_layer=resolve_act_layer(kwargs, 'relu'),
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -765,7 +767,7 @@ def _gen_efficientnet_condconv(
         act_layer=resolve_act_layer(kwargs, 'swish'),
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -806,7 +808,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -839,7 +841,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -872,7 +874,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, variant, pretrained)
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model
@@ -1,6 +1,25 @@
 from .registry import is_model, is_model_in_modules, model_entrypoint
 from .helpers import load_checkpoint
 from .layers import set_layer_config
+from .hub import load_model_config_from_hf
+
+
+def split_model_name(model_name):
+    model_split = model_name.split(':', 1)
+    if len(model_split) == 1:
+        return '', model_split[0]
+    else:
+        source_name, model_name = model_split
+        assert source_name in ('timm', 'hf_hub')
+        return source_name, model_name
+
+
+def safe_model_name(model_name, remove_source=True):
+    def make_safe(name):
+        return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
+    if remove_source:
+        model_name = split_model_name(model_name)[-1]
+    return make_safe(model_name)


 def create_model(
@@ -26,7 +45,7 @@ def create_model(
         global_pool (str): global pool type (default: 'avg')
         **: other kwargs are model specific
     """
-    model_args = dict(pretrained=pretrained)
+    source_name, model_name = split_model_name(model_name)

     # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
     is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
@@ -47,12 +66,19 @@ def create_model(
     # non-supporting models don't break and default args remain in effect.
     kwargs = {k: v for k, v in kwargs.items() if v is not None}

+    if source_name == 'hf_hub':
+        # For model names specified in the form `hf_hub:path/architecture_name#revision`,
+        # load model weights + default_cfg from Hugging Face hub.
+        hf_default_cfg, model_name = load_model_config_from_hf(model_name)
+        kwargs['external_default_cfg'] = hf_default_cfg  # FIXME revamp default_cfg interface someday
+
+    if is_model(model_name):
+        create_fn = model_entrypoint(model_name)
+    else:
+        raise RuntimeError('Unknown model (%s)' % model_name)
+
     with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
-        if is_model(model_name):
-            create_fn = model_entrypoint(model_name)
-            model = create_fn(**model_args, **kwargs)
-        else:
-            raise RuntimeError('Unknown model (%s)' % model_name)
+        model = create_fn(pretrained=pretrained, **kwargs)

     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)
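Derived directly from the helpers added above, the name handling behaves roughly like this:

```python
from timm.models import split_model_name, safe_model_name

split_model_name('hf_hub:timm/vit_huge_patch14_224_in21k')
# -> ('hf_hub', 'timm/vit_huge_patch14_224_in21k')
split_model_name('resnet50')
# -> ('', 'resnet50')
safe_model_name('hf_hub:timm/vit_huge_patch14_224_in21k')
# -> 'timm_vit_huge_patch14_224_in21k'  (filesystem-safe, source prefix removed)
```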
@@ -58,7 +58,10 @@ default_cfgs = {


 def _create_resnet(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
+    return build_model_with_cfg(
+        ResNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        **kwargs)


 @register_model
@@ -233,8 +233,10 @@ class Xception65(nn.Module):

 def _create_gluon_xception(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        Xception65, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(feature_cls='hook'), **kwargs)
+        Xception65, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=dict(feature_cls='hook'),
+        **kwargs)


 @register_model
@@ -35,7 +35,6 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):

     """
     num_features = 1280
-    act_layer = resolve_act_layer(kwargs, 'hard_swish')

     model_kwargs = dict(
         block_args=decode_arch_def(arch_def),

@@ -43,23 +42,24 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
         stem_size=32,
         channel_multiplier=1,
         norm_kwargs=resolve_bn_args(kwargs),
-        act_layer=act_layer,
+        act_layer=resolve_act_layer(kwargs, 'hard_swish'),
         se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
         **kwargs,
     )

     features_only = False
     model_cls = MobileNetV3
+    kwargs_filter = None
     if model_kwargs.pop('features_only', False):
         features_only = True
-        model_kwargs.pop('num_classes', 0)
-        model_kwargs.pop('num_features', 0)
-        model_kwargs.pop('head_conv', None)
-        model_kwargs.pop('head_bias', None)
+        kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias', 'global_pool')
         model_cls = MobileNetV3Features
     model = build_model_with_cfg(
-        model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        pretrained_strict=not features_only, **model_kwargs)
+        model_cls, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        pretrained_strict=not features_only,
+        kwargs_filter=kwargs_filter,
+        **model_kwargs)
     if features_only:
         model.default_cfg = default_cfg_for_features(model.default_cfg)
     return model
@@ -7,17 +7,14 @@ import os
 import math
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Callable
+from typing import Any, Callable, Optional, Tuple

 import torch
 import torch.nn as nn
-from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
-try:
-    from torch.hub import get_dir
-except ImportError:
-    from torch.hub import _get_torch_home as get_dir


 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
+from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
 from .layers import Conv2dSame, Linear
@@ -92,7 +89,7 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None,
         raise FileNotFoundError()


-def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
+def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False):
     r"""Loads a custom (read non .pth) weight file

     Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls

@@ -104,7 +101,7 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_

     Args:
         model: The instantiated model to load weights into
-        cfg (dict): Default pretrained model cfg
+        default_cfg (dict): Default pretrained model cfg
         load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
             'laod_pretrained' on the model will be called if it exists
         progress (bool, optional): whether or not to display a progress bar to stderr. Default: False

@@ -113,31 +110,12 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
         digits of the SHA256 hash of the contents of the file. The hash is used to
         ensure unique names and to verify the contents of the file. Default: False
     """
-    cfg = cfg or getattr(model, 'default_cfg')
-    if cfg is None or not cfg.get('url', None):
+    default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
+    pretrained_url = default_cfg.get('url', None)
+    if not pretrained_url:
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
-    url = cfg['url']
-
-    # Issue warning to move data if old env is set
-    if os.getenv('TORCH_MODEL_ZOO'):
-        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
-
-    hub_dir = get_dir()
-    model_dir = os.path.join(hub_dir, 'checkpoints')
-
-    os.makedirs(model_dir, exist_ok=True)
-
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
-    cached_file = os.path.join(model_dir, filename)
-    if not os.path.exists(cached_file):
-        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
-        hash_prefix = None
-        if check_hash:
-            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
-            hash_prefix = r.group(1) if r else None
-        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+    cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress)

     if load_fn is not None:
         load_fn(model, cached_file)
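`load_custom_pretrained` now funnels downloads through the shared `download_cached_file` helper from `hub.py`. A hedged sketch of how a non-`.pth` checkpoint loader plugs in (the `_load_npz` helper below is illustrative, not part of this commit):

```python
def _load_npz(model, cached_file):
    # e.g. ResNetV2 exposes a model-level load_pretrained() for .npz (BiT) weights
    model.load_pretrained(cached_file)

load_custom_pretrained(model, load_fn=_load_npz)
```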
@@ -172,17 +150,39 @@ def adapt_input_conv(in_chans, conv_weight):
     return conv_weight


-def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
-    cfg = cfg or getattr(model, 'default_cfg')
-    if cfg is None or not cfg.get('url', None):
+def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+    """ Load pretrained checkpoint
+
+    Args:
+        model (nn.Module) : PyTorch model module
+        default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
+        num_classes (int): num_classes for model
+        in_chans (int): in_chans for model
+        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
+        strict (bool): strict load of checkpoint
+        progress (bool): enable progress bar for weight download
+
+    """
+    default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
+    pretrained_url = default_cfg.get('url', None)
+    hf_hub_id = default_cfg.get('hf_hub', None)
+    if not pretrained_url and not hf_hub_id:
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return

-    state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
+    if hf_hub_id and has_hf_hub(necessary=not pretrained_url):
+        _logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
+        state_dict = load_state_dict_from_hf(hf_hub_id)
+    else:
+        _logger.info(f'Loading pretrained weights from url ({pretrained_url})')
+        state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
     if filter_fn is not None:
-        state_dict = filter_fn(state_dict)
+        # for backwards compat with filter fn that take one arg, try one first, the two
+        try:
+            state_dict = filter_fn(state_dict)
+        except TypeError:
+            state_dict = filter_fn(state_dict, model)

-    input_convs = cfg.get('first_conv', None)
+    input_convs = default_cfg.get('first_conv', None)
     if input_convs is not None and in_chans != 3:
         if isinstance(input_convs, str):
             input_convs = (input_convs,)
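The try/except above keeps old single-argument filter functions working while allowing new two-argument ones; both of these are acceptable `filter_fn` values after this change (a sketch):

```python
def filter_one(state_dict):            # legacy signature
    return {k.replace('module.', ''): v for k, v in state_dict.items()}

def filter_two(state_dict, model):     # new signature, may consult the model
    return state_dict

load_pretrained(model, filter_fn=filter_two, num_classes=1000, in_chans=3)
```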
@@ -198,19 +198,20 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
             _logger.warning(
                 f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')

-    classifier_name = cfg['classifier']
-    label_offset = cfg.get('label_offset', 0)
-    if num_classes != cfg['num_classes']:
-        # completely discard fully connected if model num_classes doesn't match pretrained weights
-        del state_dict[classifier_name + '.weight']
-        del state_dict[classifier_name + '.bias']
-        strict = False
-    elif label_offset > 0:
-        # special case for pretrained weights with an extra background class in pretrained weights
-        classifier_weight = state_dict[classifier_name + '.weight']
-        state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
-        classifier_bias = state_dict[classifier_name + '.bias']
-        state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
+    classifier_name = default_cfg.get('classifier', None)
+    label_offset = default_cfg.get('label_offset', 0)
+    if classifier_name is not None:
+        if num_classes != default_cfg['num_classes']:
+            # completely discard fully connected if model num_classes doesn't match pretrained weights
+            del state_dict[classifier_name + '.weight']
+            del state_dict[classifier_name + '.bias']
+            strict = False
+        elif label_offset > 0:
+            # special case for pretrained weights with an extra background class in pretrained weights
+            classifier_weight = state_dict[classifier_name + '.weight']
+            state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
+            classifier_bias = state_dict[classifier_name + '.bias']
+            state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]

     model.load_state_dict(state_dict, strict=strict)
@@ -316,40 +317,123 @@ def adapt_model_from_file(parent_module, model_variant):
 def default_cfg_for_features(default_cfg):
     default_cfg = deepcopy(default_cfg)
     # remove default pretrained cfg fields that don't have much relevance for feature backbone
-    to_remove = ('num_classes', 'crop_pct', 'classifier')  # add default final pool size?
+    to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool')  # add default final pool size?
     for tr in to_remove:
         default_cfg.pop(tr, None)
     return default_cfg


+def overlay_external_default_cfg(default_cfg, kwargs):
+    """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
+    """
+    external_default_cfg = kwargs.pop('external_default_cfg', None)
+    if external_default_cfg:
+        default_cfg.pop('url', None)  # url should come from external cfg
+        default_cfg.pop('hf_hub', None)  # hf hub id should come from external cfg
+        default_cfg.update(external_default_cfg)
+
+
+def set_default_kwargs(kwargs, names, default_cfg):
+    for n in names:
+        # for legacy reasons, model __init__args uses img_size + in_chans as separate args while
+        # default_cfg has one input_size=(C, H ,W) entry
+        if n == 'img_size':
+            input_size = default_cfg.get('input_size', None)
+            if input_size is not None:
+                assert len(input_size) == 3
+                kwargs.setdefault(n, input_size[-2:])
+        elif n == 'in_chans':
+            input_size = default_cfg.get('input_size', None)
+            if input_size is not None:
+                assert len(input_size) == 3
+                kwargs.setdefault(n, input_size[0])
+        else:
+            default_val = default_cfg.get(n, None)
+            if default_val is not None:
+                kwargs.setdefault(n, default_cfg[n])
+
+
+def filter_kwargs(kwargs, names):
+    if not kwargs or not names:
+        return
+    for n in names:
+        kwargs.pop(n, None)
+
+
+def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
+    """ Update the default_cfg and kwargs before passing to model
+
+    FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
+    could/should be replaced by an improved configuration mechanism
+
+    Args:
+        default_cfg: input default_cfg (updated in-place)
+        kwargs: keyword args passed to model build fn (updated in-place)
+        kwargs_filter: keyword arg keys that must be removed before model __init__
+    """
+    # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
+    overlay_external_default_cfg(default_cfg, kwargs)
+    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
+    set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
+    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
+    filter_kwargs(kwargs, names=kwargs_filter)
+
+
 def build_model_with_cfg(
         model_cls: Callable,
         variant: str,
         pretrained: bool,
         default_cfg: dict,
-        model_cfg: dict = None,
-        feature_cfg: dict = None,
+        model_cfg: Optional[Any] = None,
+        feature_cfg: Optional[dict] = None,
         pretrained_strict: bool = True,
-        pretrained_filter_fn: Callable = None,
+        pretrained_filter_fn: Optional[Callable] = None,
         pretrained_custom_load: bool = False,
+        kwargs_filter: Optional[Tuple[str]] = None,
         **kwargs):
+    """ Build model with specified default_cfg and optional model_cfg
+
+    This helper fn aids in the construction of a model including:
+      * handling default_cfg and associated pretained weight loading
+      * passing through optional model_cfg for models with config based arch spec
+      * features_only model adaptation
+      * pruning config / model adaptation
+
+    Args:
+        model_cls (nn.Module): model class
+        variant (str): model variant name
+        pretrained (bool): load pretrained weights
+        default_cfg (dict): model's default pretrained/task config
+        model_cfg (Optional[Dict]): model's architecture config
+        feature_cfg (Optional[Dict]: feature extraction adapter config
+        pretrained_strict (bool): load pretrained weights strictly
+        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
+        pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
+        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
+        **kwargs: model args passed through to model __init__
+    """
     pruned = kwargs.pop('pruned', False)
     features = False
     feature_cfg = feature_cfg or {}
+    default_cfg = deepcopy(default_cfg) if default_cfg else {}
+    update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
+    default_cfg.setdefault('architecture', variant)

+    # Setup for feature extraction wrapper done at end of this fn
     if kwargs.pop('features_only', False):
         features = True
         feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
         if 'out_indices' in kwargs:
             feature_cfg['out_indices'] = kwargs.pop('out_indices')

+    # Build the model
     model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
-    model.default_cfg = deepcopy(default_cfg)
+    model.default_cfg = default_cfg

     if pruned:
         model = adapt_model_from_file(model, variant)

-    # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
+    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
     if pretrained:
         if pretrained_custom_load:
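Roughly how the three new helpers compose inside `update_default_cfg_and_kwargs` (values illustrative; `'org/model'` is a made-up hub id):

```python
default_cfg = {'url': 'https://example.com/w.pth', 'num_classes': 1000, 'input_size': (3, 224, 224)}
kwargs = {'external_default_cfg': {'hf_hub': 'org/model', 'num_classes': 21843}}

overlay_external_default_cfg(default_cfg, kwargs)
# default_cfg: url/hf_hub dropped first, then external values merged in ->
# {'num_classes': 21843, 'input_size': (3, 224, 224), 'hf_hub': 'org/model'}

set_default_kwargs(kwargs, names=('num_classes', 'in_chans'), default_cfg=default_cfg)
# kwargs gains num_classes=21843 and in_chans=3 (from input_size) unless already given

filter_kwargs(kwargs, names=('head_conv',))
# drops any kwargs a task-specific model variant's __init__ cannot accept, if present
```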
@@ -357,9 +441,12 @@ def build_model_with_cfg(
         else:
             load_pretrained(
                 model,
-                num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
-                filter_fn=pretrained_filter_fn, strict=pretrained_strict)
+                num_classes=num_classes_pretrained,
+                in_chans=kwargs.get('in_chans', 3),
+                filter_fn=pretrained_filter_fn,
+                strict=pretrained_strict)

+    # Wrap the model in a feature extraction module if enabled
     if features:
         feature_cls = FeatureListNet
         if 'feature_cls' in feature_cfg:
@@ -774,13 +774,18 @@ class HighResolutionNetFeatures(HighResolutionNet):
 def _create_hrnet(variant, pretrained, **model_kwargs):
     model_cls = HighResolutionNet
     features_only = False
+    kwargs_filter = None
     if model_kwargs.pop('features_only', False):
         model_cls = HighResolutionNetFeatures
         model_kwargs['num_classes'] = 0
+        kwargs_filter = ('num_classes', 'global_pool')
         features_only = True
     model = build_model_with_cfg(
-        model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs)
+        model_cls, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        model_cfg=cfg_cls[variant],
+        pretrained_strict=not features_only,
+        kwargs_filter=kwargs_filter,
+        **model_kwargs)
     if features_only:
         model.default_cfg = default_cfg_for_features(model.default_cfg)
     return model
@@ -0,0 +1,96 @@
+import json
+import logging
+import os
+from functools import partial
+from typing import Union, Optional
+
+import torch
+from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
+try:
+    from torch.hub import get_dir
+except ImportError:
+    from torch.hub import _get_torch_home as get_dir
+
+from timm import __version__
+try:
+    from huggingface_hub import hf_hub_url
+    from huggingface_hub import cached_download
+    cached_download = partial(cached_download, library_name="timm", library_version=__version__)
+except ImportError:
+    hf_hub_url = None
+    cached_download = None
+
+_logger = logging.getLogger(__name__)
+
+
+def get_cache_dir(child_dir=''):
+    """
+    Returns the location of the directory where models are cached (and creates it if necessary).
+    """
+    # Issue warning to move data if old env is set
+    if os.getenv('TORCH_MODEL_ZOO'):
+        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+
+    hub_dir = get_dir()
+    child_dir = () if not child_dir else (child_dir,)
+    model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
+    os.makedirs(model_dir, exist_ok=True)
+    return model_dir
+
+
+def download_cached_file(url, check_hash=True, progress=False):
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    cached_file = os.path.join(get_cache_dir(), filename)
+    if not os.path.exists(cached_file):
+        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
+        hash_prefix = None
+        if check_hash:
+            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
+            hash_prefix = r.group(1) if r else None
+        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+    return cached_file
+
+
+def has_hf_hub(necessary=False):
+    if hf_hub_url is None and necessary:
+        # if no HF Hub module installed and it is necessary to continue, raise error
+        raise RuntimeError(
+            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
+    return hf_hub_url is not None
+
+
+def hf_split(hf_id):
+    rev_split = hf_id.split('#')
+    assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one # character to identify revision.'
+    hf_model_id = rev_split[0]
+    hf_revision = rev_split[-1] if len(rev_split) > 1 else None
+    return hf_model_id, hf_revision
+
+
+def load_cfg_from_json(json_file: Union[str, os.PathLike]):
+    with open(json_file, "r", encoding="utf-8") as reader:
+        text = reader.read()
+    return json.loads(text)
+
+
+def _download_from_hf(model_id: str, filename: str):
+    hf_model_id, hf_revision = hf_split(model_id)
+    url = hf_hub_url(hf_model_id, filename, revision=hf_revision)
+    return cached_download(url, cache_dir=get_cache_dir('hf'))
+
+
+def load_model_config_from_hf(model_id: str):
+    assert has_hf_hub(True)
+    cached_file = _download_from_hf(model_id, 'config.json')
+    default_cfg = load_cfg_from_json(cached_file)
+    default_cfg['hf_hub'] = model_id  # insert hf_hub id for pretrained weight load during model creation
+    model_name = default_cfg.get('architecture')
+    return default_cfg, model_name
+
+
+def load_state_dict_from_hf(model_id: str):
+    assert has_hf_hub(True)
+    cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
+    state_dict = torch.load(cached_file, map_location='cpu')
+    return state_dict
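From `hf_split` above, hub ids accept an optional `#revision` suffix:

```python
hf_split('timm/vit_huge_patch14_224_in21k')
# -> ('timm/vit_huge_patch14_224_in21k', None)
hf_split('timm/vit_huge_patch14_224_in21k#main')
# -> ('timm/vit_huge_patch14_224_in21k', 'main')
```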
@@ -336,7 +336,9 @@ class InceptionResnetV2(nn.Module):

 def _create_inception_resnet_v2(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        InceptionResnetV2, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs)
+        InceptionResnetV2, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        **kwargs)


 @register_model
@@ -434,8 +434,10 @@ def _create_inception_v3(variant, pretrained=False, **kwargs):
     model_cls = InceptionV3
     load_strict = not default_cfg['has_aux']
     return build_model_with_cfg(
-        model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        pretrained_strict=load_strict, **kwargs)
+        model_cls, variant, pretrained,
+        default_cfg=default_cfg,
+        pretrained_strict=load_strict,
+        **kwargs)


 @register_model

@@ -305,8 +305,10 @@ class InceptionV4(nn.Module):

 def _create_inception_v4(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        InceptionV4, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(flatten_sequential=True), **kwargs)
+        InceptionV4, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=dict(flatten_sequential=True),
+        **kwargs)


 @register_model
@@ -200,19 +200,20 @@ class MobileNetV3Features(nn.Module):
         return list(out.values())


-def _create_mnv3(model_kwargs, variant, pretrained=False):
+def _create_mnv3(variant, pretrained=False, **kwargs):
     features_only = False
     model_cls = MobileNetV3
-    if model_kwargs.pop('features_only', False):
+    kwargs_filter = None
+    if kwargs.pop('features_only', False):
         features_only = True
-        model_kwargs.pop('num_classes', 0)
-        model_kwargs.pop('num_features', 0)
-        model_kwargs.pop('head_conv', None)
-        model_kwargs.pop('head_bias', None)
+        kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
         model_cls = MobileNetV3Features
     model = build_model_with_cfg(
-        model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        pretrained_strict=not features_only, **model_kwargs)
+        model_cls, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        pretrained_strict=not features_only,
+        kwargs_filter=kwargs_filter,
+        **kwargs)
     if features_only:
         model.default_cfg = default_cfg_for_features(model.default_cfg)
     return model

@@ -252,7 +253,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
         se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1),
         **kwargs,
     )
-    model = _create_mnv3(model_kwargs, variant, pretrained)
+    model = _create_mnv3(variant, pretrained, **model_kwargs)
     return model

@@ -348,7 +349,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
         se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
         **kwargs,
     )
-    model = _create_mnv3(model_kwargs, variant, pretrained)
+    model = _create_mnv3(variant, pretrained, **model_kwargs)
     return model
@@ -553,7 +553,8 @@ class NASNetALarge(nn.Module):

 def _create_nasnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        NASNetALarge, variant, pretrained, default_cfg=default_cfgs[variant],
+        NASNetALarge, variant, pretrained,
+        default_cfg=default_cfgs[variant],
         feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
         **kwargs)

@@ -334,7 +334,8 @@ class PNASNet5Large(nn.Module):

 def _create_pnasnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        PNASNet5Large, variant, pretrained, default_cfg=default_cfgs[variant],
+        PNASNet5Large, variant, pretrained,
+        default_cfg=default_cfgs[variant],
         feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
         **kwargs)

@@ -330,7 +330,10 @@ class RegNet(nn.Module):

 def _create_regnet(variant, pretrained, **kwargs):
     return build_model_with_cfg(
-        RegNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], **kwargs)
+        RegNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        model_cfg=model_cfgs[variant],
+        **kwargs)


 @register_model
@@ -134,7 +134,9 @@ class Bottle2neck(nn.Module):

 def _create_res2net(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        ResNet, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs)
+        ResNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        **kwargs)


 @register_model

@@ -141,7 +141,9 @@ class ResNestBottleneck(nn.Module):

 def _create_resnest(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
+        ResNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        **kwargs)


 @register_model

@@ -634,7 +634,9 @@ class ResNet(nn.Module):

 def _create_resnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
+        ResNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        **kwargs)


 @register_model
@@ -413,8 +413,11 @@ class ResNetV2(nn.Module):
 def _create_resnetv2(variant, pretrained=False, **kwargs):
     feature_cfg = dict(flatten_sequential=True)
     return build_model_with_cfg(
-        ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True,
-        feature_cfg=feature_cfg, **kwargs)
+        ResNetV2, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=feature_cfg,
+        pretrained_custom_load=True,
+        **kwargs)


 @register_model

@@ -199,7 +199,10 @@ class ReXNetV1(nn.Module):
 def _create_rexnet(variant, pretrained, **kwargs):
     feature_cfg = dict(flatten_sequential=True)
     return build_model_with_cfg(
-        ReXNetV1, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, **kwargs)
+        ReXNetV1, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=feature_cfg,
+        **kwargs)


 @register_model
@@ -196,7 +196,7 @@ class SelecSLS(nn.Module):
         return x


-def _create_selecsls(variant, pretrained, model_kwargs):
+def _create_selecsls(variant, pretrained, **kwargs):
     cfg = {}
     feature_info = [dict(num_chs=32, reduction=2, module='stem.2')]
     if variant.startswith('selecsls42'):

@@ -320,40 +320,43 @@ def _create_selecsls(variant, pretrained, model_kwargs):

     # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises?
     return build_model_with_cfg(
-        SelecSLS, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=cfg,
-        feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), **model_kwargs)
+        SelecSLS, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        model_cfg=cfg,
+        feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True),
+        **kwargs)


 @register_model
 def selecsls42(pretrained=False, **kwargs):
     """Constructs a SelecSLS42 model.
     """
-    return _create_selecsls('selecsls42', pretrained, kwargs)
+    return _create_selecsls('selecsls42', pretrained, **kwargs)


 @register_model
 def selecsls42b(pretrained=False, **kwargs):
     """Constructs a SelecSLS42_B model.
     """
-    return _create_selecsls('selecsls42b', pretrained, kwargs)
+    return _create_selecsls('selecsls42b', pretrained, **kwargs)


 @register_model
 def selecsls60(pretrained=False, **kwargs):
     """Constructs a SelecSLS60 model.
     """
-    return _create_selecsls('selecsls60', pretrained, kwargs)
+    return _create_selecsls('selecsls60', pretrained, **kwargs)


 @register_model
 def selecsls60b(pretrained=False, **kwargs):
     """Constructs a SelecSLS60_B model.
     """
-    return _create_selecsls('selecsls60b', pretrained, kwargs)
+    return _create_selecsls('selecsls60b', pretrained, **kwargs)


 @register_model
 def selecsls84(pretrained=False, **kwargs):
     """Constructs a SelecSLS84 model.
     """
-    return _create_selecsls('selecsls84', pretrained, kwargs)
+    return _create_selecsls('selecsls84', pretrained, **kwargs)
@@ -398,7 +398,9 @@ class SENet(nn.Module):

 def _create_senet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        SENet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
+        SENet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        **kwargs)


 @register_model

@@ -141,7 +141,9 @@ class SelectiveKernelBottleneck(nn.Module):

 def _create_skresnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
+        ResNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        **kwargs)


 @register_model

@@ -253,8 +253,10 @@ class TResNet(nn.Module):

 def _create_tresnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        TResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained,
-        feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True), **kwargs)
+        TResNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True),
+        **kwargs)


 @register_model
@@ -180,9 +180,9 @@ def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
     # NOTE: VGG is one of the only models with stride==1 features, so indices are offset from other models
     out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5))
     model = build_model_with_cfg(
-        VGG, variant, pretrained=pretrained,
-        model_cfg=cfgs[cfg],
+        VGG, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        model_cfg=cfgs[cfg],
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         pretrained_filter_fn=_filter_fn,
         **kwargs)
@@ -21,13 +21,14 @@ import math
 import logging
 from functools import partial
 from collections import OrderedDict
+from copy import deepcopy

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import load_pretrained
+from .helpers import build_model_with_cfg, overlay_external_default_cfg
 from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2

@@ -94,7 +95,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_huge_patch14_224_in21k': _cfg(
-        url='',  # FIXME I have weights for this but > 2GB limit for github release binaries
+        hf_hub='timm/vit_huge_patch14_224_in21k',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

     # hybrid models (weights ported from official Google JAX impl)
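With the `hf_hub` entry replacing the empty `url`, pretrained weights for this variant should resolve through the Hugging Face hub (a sketch, assuming the `timm/vit_huge_patch14_224_in21k` repo is published and `huggingface_hub` is installed):

```python
import timm

model = timm.create_model('vit_huge_patch14_224_in21k', pretrained=True)
# load_pretrained() sees no 'url' but an 'hf_hub' id and calls load_state_dict_from_hf().
```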
@@ -462,9 +463,10 @@ def checkpoint_filter_fn(state_dict, model):


 def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
-    default_cfg = default_cfgs[variant]
+    default_cfg = deepcopy(default_cfgs[variant])
+    overlay_external_default_cfg(default_cfg, kwargs)
     default_num_classes = default_cfg['num_classes']
-    default_img_size = default_cfg['input_size'][-1]
+    default_img_size = default_cfg['input_size'][-2:]

     num_classes = kwargs.pop('num_classes', default_num_classes)
     img_size = kwargs.pop('img_size', default_img_size)

@@ -475,14 +477,19 @@ def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwa
         _logger.warning("Removing representation layer for fine-tuning.")
         repr_size = None

-    model_cls = DistilledVisionTransformer if distilled else VisionTransformer
-    model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
-    model.default_cfg = default_cfg
+    if kwargs.get('features_only', None):
+        raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+    model_cls = DistilledVisionTransformer if distilled else VisionTransformer
+    model = build_model_with_cfg(
+        model_cls, variant, pretrained,
+        default_cfg=default_cfg,
+        img_size=img_size,
+        num_classes=num_classes,
+        representation_size=repr_size,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        **kwargs)

-    if pretrained:
-        load_pretrained(
-            model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
-            filter_fn=partial(checkpoint_filter_fn, model=model))
     return model

@@ -338,8 +338,11 @@ class VovNet(nn.Module):

 def _create_vovnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        VovNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant],
-        feature_cfg=dict(flatten_sequential=True), **kwargs)
+        VovNet, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        model_cfg=model_cfgs[variant],
+        feature_cfg=dict(flatten_sequential=True),
+        **kwargs)


 @register_model

@@ -221,8 +221,10 @@ class Xception(nn.Module):

 def _xception(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        Xception, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(feature_cls='hook'), **kwargs)
+        Xception, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=dict(feature_cls='hook'),
+        **kwargs)


 @register_model

@@ -173,8 +173,10 @@ class XceptionAligned(nn.Module):

 def _xception(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        XceptionAligned, variant, pretrained, default_cfg=default_cfgs[variant],
-        feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), **kwargs)
+        XceptionAligned, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        feature_cfg=dict(flatten_sequential=True, feature_cls='hook'),
+        **kwargs)


 @register_model
train.py
@@ -29,7 +29,8 @@ import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

 from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
+from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
+    convert_splitbn_model, model_parameters
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer

@@ -345,8 +346,8 @@ def main():
     args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

     if args.local_rank == 0:
-        _logger.info('Model %s created, param count: %d' %
-                     (args.model, sum([m.numel() for m in model.parameters()])))
+        _logger.info(
+            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

     data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

@@ -543,7 +544,7 @@ def main():
     output_base = args.output if args.output else './output'
     exp_name = '-'.join([
         datetime.now().strftime("%Y%m%d-%H%M%S"),
-        args.model,
+        safe_model_name(args.model),
         str(data_config['input_size'][-1])
     ])
     output_dir = get_outdir(output_base, 'train', exp_name)
@@ -211,7 +211,7 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
+        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
         model(input)