Throw when pretrained weights not available and pretrained=True (principle of least surprise).
parent
8ce9a2c00a
commit
ff2464e2a0
|
@ -152,8 +152,7 @@ def load_pretrained(
|
||||||
"""
|
"""
|
||||||
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
|
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
|
||||||
if not pretrained_cfg:
|
if not pretrained_cfg:
|
||||||
_logger.warning("Invalid pretrained config, cannot load weights.")
|
raise RuntimeError("Invalid pretrained config, cannot load weights. Use `pretrained=False` for random init.")
|
||||||
return
|
|
||||||
|
|
||||||
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
|
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
|
||||||
if load_from == 'state_dict':
|
if load_from == 'state_dict':
|
||||||
|
@ -186,8 +185,8 @@ def load_pretrained(
|
||||||
else:
|
else:
|
||||||
state_dict = load_state_dict_from_hf(pretrained_loc)
|
state_dict = load_state_dict_from_hf(pretrained_loc)
|
||||||
else:
|
else:
|
||||||
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
|
model_name = pretrained_cfg.get('architecture', 'this model')
|
||||||
return
|
raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
|
||||||
|
|
||||||
if filter_fn is not None:
|
if filter_fn is not None:
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Reference in New Issue