mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Improve support for custom dataset label name/description through HF hub export, via pretrained_cfg
This commit is contained in:
parent
1e0b347227
commit
9c14654a0d
@ -4,7 +4,7 @@ from .config import resolve_data_config, resolve_model_data_config
|
||||
from .constants import *
|
||||
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
|
||||
from .dataset_factory import create_dataset
|
||||
from .dataset_info import DatasetInfo
|
||||
from .dataset_info import DatasetInfo, CustomDatasetInfo
|
||||
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
|
||||
from .loader import create_loader
|
||||
from .mixup import Mixup, FastCollateMixup
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
|
||||
class DatasetInfo(ABC):
|
||||
@ -29,4 +29,45 @@ class DatasetInfo(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def label_name_to_description(self, label: str, detailed: bool = False) -> str:
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class CustomDatasetInfo(DatasetInfo):
    """ DatasetInfo that wraps passed values for custom datasets.

    Args:
        label_names: Label index => label name mapping. Either a list (index is the
            list position) or a dict keyed by integer label index (may contain gaps).
        label_descriptions: Optional label name => label description mapping. When
            provided, it must contain an entry for every label name.

    Raises:
        AssertionError: If ``label_names`` is empty, ``label_descriptions`` is not a
            dict, or a label name has no matching description.
    """

    def __init__(
            self,
            label_names: Union[List[str], Dict[int, str]],
            label_descriptions: Optional[Dict[str, str]] = None
    ):
        super().__init__()
        assert len(label_names) > 0, 'label_names must not be empty'
        self._label_names = label_names  # label index => label name mapping
        self._label_descriptions = label_descriptions  # label name => label description mapping
        if self._label_descriptions is not None:
            # validate descriptions (label names required)
            assert isinstance(self._label_descriptions, dict), 'label_descriptions must be a dict'
            # NOTE iterate the *names*; for a dict mapping these are the values, not the int keys
            for n in self._iter_label_names():
                assert n in self._label_descriptions, f'Missing description for label name {n!r}'

    def _iter_label_names(self):
        """ Iterate label names regardless of whether they are stored as a list or a dict."""
        if isinstance(self._label_names, dict):
            return self._label_names.values()
        return self._label_names

    def num_classes(self):
        """ Return the number of label names held by this instance."""
        return len(self._label_names)

    def label_names(self):
        """ Return the label index => label name mapping exactly as it was passed in."""
        return self._label_names

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        """ Return the stored label name => description mapping.

        NOTE ``detailed`` and ``as_dict`` are accepted for interface compatibility but are
        not used here; the stored mapping (or None if no descriptions were provided) is
        returned as-is.
        """
        return self._label_descriptions

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        """ Return the description for a label name, or the name itself without descriptions."""
        if self._label_descriptions:
            return self._label_descriptions[label]
        return label  # return label name itself if a description is not present

    def index_to_label_name(self, index) -> str:
        """ Map a label index to its label name.

        Raises:
            AssertionError: If ``index`` is not a valid label index.
        """
        if isinstance(self._label_names, dict):
            # dict mappings may have id gaps, so membership (not a range check) defines validity
            assert index in self._label_names, f'Invalid label index {index!r}'
        else:
            assert 0 <= index < len(self._label_names), f'Invalid label index {index!r}'
        return self._label_names[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        """ Map a label index to its description (falls back to the label name)."""
        label = self.index_to_label_name(index)
        return self.label_name_to_description(label, detailed=detailed)
|
||||
|
@ -16,6 +16,7 @@ except ImportError:
|
||||
from torch.hub import _get_torch_home as get_dir
|
||||
|
||||
from timm import __version__
|
||||
from timm.layers import ClassifierHead, NormMlpClassifierHead
|
||||
from timm.models._pretrained import filter_pretrained_cfg
|
||||
|
||||
try:
|
||||
@ -96,7 +97,7 @@ def has_hf_hub(necessary=False):
|
||||
return _has_hf_hub
|
||||
|
||||
|
||||
def hf_split(hf_id):
|
||||
def hf_split(hf_id: str):
|
||||
# FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
|
||||
rev_split = hf_id.split('@')
|
||||
assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
|
||||
@ -127,19 +128,26 @@ def load_model_config_from_hf(model_id: str):
|
||||
hf_config = {}
|
||||
hf_config['architecture'] = pretrained_cfg.pop('architecture')
|
||||
hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
|
||||
if 'labels' in pretrained_cfg:
|
||||
hf_config['label_name'] = pretrained_cfg.pop('labels')
|
||||
if 'labels' in pretrained_cfg: # deprecated name for 'label_names'
|
||||
pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
|
||||
hf_config['pretrained_cfg'] = pretrained_cfg
|
||||
|
||||
# NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
|
||||
pretrained_cfg = hf_config['pretrained_cfg']
|
||||
pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation
|
||||
pretrained_cfg['source'] = 'hf-hub'
|
||||
if 'num_classes' in hf_config:
|
||||
# model should be created with parent num_classes if they exist
|
||||
pretrained_cfg['num_classes'] = hf_config['num_classes']
|
||||
model_name = hf_config['architecture']
|
||||
|
||||
# model should be created with base config num_classes if its exist
|
||||
if 'num_classes' in hf_config:
|
||||
pretrained_cfg['num_classes'] = hf_config['num_classes']
|
||||
|
||||
# label meta-data in base config overrides saved pretrained_cfg on load
|
||||
if 'label_names' in hf_config:
|
||||
pretrained_cfg['label_names'] = hf_config.pop('label_names')
|
||||
if 'label_descriptions' in hf_config:
|
||||
pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
|
||||
|
||||
model_name = hf_config['architecture']
|
||||
return pretrained_cfg, model_name
|
||||
|
||||
|
||||
@ -150,7 +158,7 @@ def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_config_for_hf(model, config_path, model_config=None):
|
||||
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
|
||||
model_config = model_config or {}
|
||||
hf_config = {}
|
||||
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
|
||||
@ -164,22 +172,22 @@ def save_config_for_hf(model, config_path, model_config=None):
|
||||
|
||||
if 'labels' in model_config:
|
||||
_logger.warning(
|
||||
"'labels' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
|
||||
"Using provided 'label' field as 'label_name'.")
|
||||
model_config['label_name'] = model_config.pop('labels')
|
||||
"'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
|
||||
" Renaming provided 'labels' field to 'label_names'.")
|
||||
model_config.setdefault('label_names', model_config.pop('labels'))
|
||||
|
||||
label_name = model_config.pop('label_name', None)
|
||||
if label_name:
|
||||
assert isinstance(label_name, (dict, list, tuple))
|
||||
label_names = model_config.pop('label_names', None)
|
||||
if label_names:
|
||||
assert isinstance(label_names, (dict, list, tuple))
|
||||
# map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
|
||||
# can be a dict id: name if there are id gaps, or tuple/list if no gaps.
|
||||
hf_config['label_name'] = model_config['label_name']
|
||||
hf_config['label_names'] = label_names
|
||||
|
||||
display_name = model_config.pop('display_name', None)
|
||||
if display_name:
|
||||
assert isinstance(display_name, dict)
|
||||
# map label_name -> user interface display name
|
||||
hf_config['display_name'] = model_config['display_name']
|
||||
label_descriptions = model_config.pop('label_descriptions', None)
|
||||
if label_descriptions:
|
||||
assert isinstance(label_descriptions, dict)
|
||||
# maps label names -> descriptions
|
||||
hf_config['label_descriptions'] = label_descriptions
|
||||
|
||||
hf_config['pretrained_cfg'] = pretrained_cfg
|
||||
hf_config.update(model_config)
|
||||
@ -188,7 +196,7 @@ def save_config_for_hf(model, config_path, model_config=None):
|
||||
json.dump(hf_config, f, indent=2)
|
||||
|
||||
|
||||
def save_for_hf(model, save_directory, model_config=None):
|
||||
def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
|
||||
assert has_hf_hub(True)
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(exist_ok=True, parents=True)
|
||||
@ -249,7 +257,7 @@ def push_to_hf_hub(
|
||||
)
|
||||
|
||||
|
||||
def generate_readme(model_card, model_name):
|
||||
def generate_readme(model_card: dict, model_name: str):
|
||||
readme_text = "---\n"
|
||||
readme_text += "tags:\n- image-classification\n- timm\n"
|
||||
readme_text += "library_tag: timm\n"
|
||||
|
@ -34,9 +34,11 @@ class PretrainedCfg:
|
||||
mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
|
||||
std: Tuple[float, ...] = (0.229, 0.224, 0.225)
|
||||
|
||||
# head config
|
||||
# head / classifier config and meta-data
|
||||
num_classes: int = 1000
|
||||
label_offset: Optional[int] = None
|
||||
label_names: Optional[Tuple[str]] = None
|
||||
label_descriptions: Optional[Dict[str, str]] = None
|
||||
|
||||
# model attributes that vary with above or required for pretrained adaptation
|
||||
pool_size: Optional[Tuple[int, ...]] = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user