Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add support for passing model args via hf hub config
parent 23e7f17724
commit a604011935
@@ -99,7 +99,10 @@ def create_model(
         assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
         # For model names specified in the form `hf-hub:path/architecture_name@revision`,
         # load model weights + pretrained_cfg from Hugging Face hub.
-        pretrained_cfg, model_name = load_model_config_from_hf(model_name)
+        pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
+        if model_args:
+            for k, v in model_args.items():
+                kwargs.setdefault(k, v)
     else:
         model_name, pretrained_tag = split_model_name_tag(model_name)
         if pretrained_tag and not pretrained_cfg:
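With this change, any model_args stored in a hub checkpoint's config are merged into create_model's keyword arguments with setdefault, so values the caller passes explicitly still win. A minimal usage sketch (the repo id and argument values below are hypothetical):

import timm

# Hypothetical hub checkpoint whose config.json contains:
#   "model_args": {"img_size": 384, "drop_path_rate": 0.1}
model = timm.create_model('hf-hub:your-org/your-vit', pretrained=True)

# Explicit kwargs take precedence over hub-provided model_args because the
# merge uses kwargs.setdefault(k, v); drop_path_rate ends up as 0.2 here.
model = timm.create_model('hf-hub:your-org/your-vit', pretrained=True, drop_path_rate=0.2)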
@@ -164,8 +164,9 @@ def load_model_config_from_hf(model_id: str):
     if 'label_descriptions' in hf_config:
         pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
 
+    model_args = hf_config.get('model_args', {})
     model_name = hf_config['architecture']
-    return pretrained_cfg, model_name
+    return pretrained_cfg, model_name, model_args
 
 
 def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
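load_model_config_from_hf now returns a third value; hf_config.get('model_args', {}) means checkpoints whose config.json predates this field simply yield an empty dict. A small sketch of the new call shape, assuming the helper lives in timm.models._hub as in current timm (repo id hypothetical):

from timm.models._hub import load_model_config_from_hf

pretrained_cfg, model_name, model_args = load_model_config_from_hf('your-org/your-vit')
if model_args:  # empty dict for older checkpoints without 'model_args'
    print('model kwargs stored on the hub:', model_args)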
@@ -193,19 +194,23 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
 def save_config_for_hf(
         model,
         config_path: str,
-        model_config: Optional[dict] = None
+        model_config: Optional[dict] = None,
+        model_args: Optional[dict] = None
 ):
     model_config = model_config or {}
     hf_config = {}
     pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
     # set some values at root config level
     hf_config['architecture'] = pretrained_cfg.pop('architecture')
-    hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
-    hf_config['num_features'] = model_config.get('num_features', model.num_features)
-    global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
+    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
+
+    # NOTE these attr saved for informational purposes, do not impact model build
+    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
+    global_pool_type = model_config.pop('global_pool', getattr(model, 'global_pool', None))
     if isinstance(global_pool_type, str) and global_pool_type:
         hf_config['global_pool'] = global_pool_type
 
     # Save class label info
     if 'labels' in model_config:
         _logger.warning(
             "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
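Switching the root-level fields from model_config.get(...) to model_config.pop(...) means that by the time hf_config.update(model_config) runs later in this function, only the genuinely extra keys are left to merge. A standalone toy illustration of that pop-then-update flow (not timm code):

# Toy example: consume known keys at the root, then merge whatever remains.
model_config = {'num_classes': 10, 'some_extra_field': True}
hf_config = {'num_classes': model_config.pop('num_classes', 1000)}
hf_config.update(model_config)  # only the leftover key survives to be merged here
assert hf_config == {'num_classes': 10, 'some_extra_field': True}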
@@ -225,6 +230,9 @@ def save_config_for_hf(
         # maps label names -> descriptions
         hf_config['label_descriptions'] = label_descriptions
 
+    if model_args:
+        hf_config['model_args'] = model_args
+
     hf_config['pretrained_cfg'] = pretrained_cfg
     hf_config.update(model_config)
 
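With this hunk, a config.json written by save_config_for_hf can carry model_args at the root, alongside architecture, num_classes and pretrained_cfg. A rough sketch of the resulting dict before serialization (all values hypothetical):

# Approximate shape of hf_config as written to config.json; values are made up.
hf_config = {
    'architecture': 'vit_base_patch16_224',
    'num_classes': 1000,
    'num_features': 768,        # informational, does not affect the model build
    'global_pool': 'token',
    'model_args': {             # re-applied as create_model kwargs on load
        'img_size': 384,
        'drop_path_rate': 0.1,
    },
    'pretrained_cfg': {
        # filtered pretrained_cfg entries (input size, mean/std, etc.)
    },
}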
@@ -236,6 +244,7 @@ def save_for_hf(
         model,
         save_directory: str,
         model_config: Optional[dict] = None,
+        model_args: Optional[dict] = None,
         safe_serialization: Union[bool, Literal["both"]] = False,
 ):
     assert has_hf_hub(True)
@@ -251,11 +260,16 @@ def save_for_hf(
         torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
 
     config_path = save_directory / 'config.json'
-    save_config_for_hf(model, config_path, model_config=model_config)
+    save_config_for_hf(
+        model,
+        config_path,
+        model_config=model_config,
+        model_args=model_args,
+    )
 
 
 def push_to_hf_hub(
-        model,
+        model: torch.nn.Module,
         repo_id: str,
         commit_message: str = 'Add model',
         token: Optional[str] = None,
@@ -264,6 +278,7 @@ def push_to_hf_hub(
         create_pr: bool = False,
         model_config: Optional[dict] = None,
         model_card: Optional[dict] = None,
+        model_args: Optional[dict] = None,
         safe_serialization: Union[bool, Literal["both"]] = False,
 ):
     """
@@ -291,7 +306,13 @@ def push_to_hf_hub(
     # Dump model and push to Hub
     with TemporaryDirectory() as tmpdir:
         # Save model weights and config.
-        save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)
+        save_for_hf(
+            model,
+            tmpdir,
+            model_config=model_config,
+            model_args=model_args,
+            safe_serialization=safe_serialization,
+        )
 
         # Add readme if it does not exist
         if not has_readme:
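End to end, model_args passed to push_to_hf_hub are written into the uploaded config.json and then fed back into create_model when the checkpoint is loaded. A hedged usage sketch, assuming the helper is importable from timm.models._hub as in current timm (repo id and values are made up):

import timm
from timm.models._hub import push_to_hf_hub

model = timm.create_model('resnet50', pretrained=False, num_classes=10)

# Hypothetical push: model_args lands under 'model_args' in the uploaded config.json.
push_to_hf_hub(
    model,
    'your-org/resnet50-custom',
    model_args={'drop_path_rate': 0.05},
)

# Anyone loading the checkpoint later gets those kwargs applied by default.
reloaded = timm.create_model('hf-hub:your-org/resnet50-custom', pretrained=True)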