Improve resolve_pretrained_cfg behaviour when no cfg exists, warn instead of crash. Improve usability ex #1311
parent 879df47c0a
commit 7d657d2ef4
@@ -455,18 +455,27 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
     filter_kwargs(kwargs, names=kwargs_filter)
 
 
-def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None):
+def resolve_pretrained_cfg(variant: str, **kwargs):
+    pretrained_cfg = kwargs.pop('pretrained_cfg', None)
     if pretrained_cfg and isinstance(pretrained_cfg, dict):
-        # highest priority, pretrained_cfg available and passed explicitly
+        # highest priority, pretrained_cfg available and passed in args
         return deepcopy(pretrained_cfg)
-    if kwargs and 'pretrained_cfg' in kwargs:
-        # next highest, pretrained_cfg in a kwargs dict, pop and return
-        pretrained_cfg = kwargs.pop('pretrained_cfg', {})
-        if pretrained_cfg:
-            return deepcopy(pretrained_cfg)
-    # lookup pretrained cfg in model registry by variant
+    # fallback to looking up pretrained cfg in model registry by variant identifier
     pretrained_cfg = get_pretrained_cfg(variant)
-    assert pretrained_cfg
+    if not pretrained_cfg:
+        _logger.warning(
+            f"No pretrained configuration specified for {variant} model. Using a default."
+            f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
+        pretrained_cfg = dict(
+            url='',
+            num_classes=1000,
+            input_size=(3, 224, 224),
+            pool_size=None,
+            crop_pct=.9,
+            interpolation='bicubic',
+            first_conv='',
+            classifier='',
+        )
     return pretrained_cfg
 
 
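In effect the resolution order is now: an explicitly passed pretrained_cfg dict wins, then a registry lookup by variant name, and finally the default stub above plus a warning, where the old code tripped "assert pretrained_cfg". A minimal sketch of the three paths, assuming timm at this commit; the variant names are illustrative only:

import logging
from timm.models.helpers import resolve_pretrained_cfg

logging.basicConfig(level=logging.WARNING)

# Path 1: an explicit dict has highest priority and is returned as a deep copy.
cfg = resolve_pretrained_cfg('any_variant', pretrained_cfg={'num_classes': 10, 'url': ''})
assert cfg['num_classes'] == 10

# Path 2: nothing passed, fall back to the registry lookup by variant name.
cfg = resolve_pretrained_cfg('resnet18')

# Path 3: unknown variant. Previously the assert raised here; now a warning
# is logged and the default stub config is returned instead.
cfg = resolve_pretrained_cfg('not_a_registered_variant')
assert cfg['num_classes'] == 1000 and cfg['input_size'] == (3, 224, 224)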
@@ -428,7 +428,7 @@ class InceptionV3Aux(InceptionV3):
 
 
 def _create_inception_v3(variant, pretrained=False, **kwargs):
-    pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
+    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
     aux_logits = kwargs.pop('aux_logits', False)
     if aux_logits:
         assert not kwargs.pop('features_only', False)
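The same mechanical change repeats at every model factory: rather than handing the whole kwargs dict to the resolver, each call site pops pretrained_cfg itself and passes only that. A sketch of the pattern in isolation, assuming timm at this commit; MyModel and _create_mymodel are placeholder names, not part of the commit:

import torch.nn as nn
from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg

class MyModel(nn.Module):
    # Stand-in for InceptionV3 etc.; accepts the kwargs that
    # build_model_with_cfg forwards from the resolved pretrained cfg.
    def __init__(self, num_classes=1000, in_chans=3, **kwargs):
        super().__init__()
        self.head = nn.Linear(16, num_classes)

def _create_mymodel(variant, pretrained=False, **kwargs):
    # Pop pretrained_cfg out of the user kwargs before resolving; the
    # resolver no longer needs to see (or mutate) the rest of the kwargs.
    pretrained_cfg = resolve_pretrained_cfg(
        variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
    return build_model_with_cfg(
        MyModel, variant, pretrained,
        pretrained_cfg=pretrained_cfg,
        **kwargs)

# Unregistered variant: logs the new warning and uses the default stub cfg.
model = _create_mymodel('mymodel_base', num_classes=10)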
@@ -633,7 +633,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
 
-    pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
+    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
     model = build_model_with_cfg(
         VisionTransformer, variant, pretrained,
         pretrained_cfg=pretrained_cfg,
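Because an explicitly passed dict keeps top priority, a caller can now replace a registered config wholesale through create_model kwargs. A hedged example, assuming timm at this commit forwards pretrained_cfg through to the factory above; the overridden values are arbitrary:

import timm

# Full replacement config; resolve_pretrained_cfg returns a deep copy of
# this dict instead of consulting the registry entry for the variant.
custom_cfg = dict(
    url='',  # no checkpoint associated, so keep pretrained=False
    num_classes=10,
    input_size=(3, 224, 224),
    pool_size=None,
    crop_pct=0.9,
    interpolation='bicubic',
    first_conv='patch_embed.proj',
    classifier='head',
)
model = timm.create_model('vit_tiny_patch16_224', pretrained=False, pretrained_cfg=custom_cfg)
assert model.pretrained_cfg['num_classes'] == 10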
@@ -16,7 +16,7 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
+from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
 from .registry import register_model
 