Mirror of https://github.com/huggingface/pytorch-image-models.git

Merge pull request #1662 from rwightman/dataset_info
ImageNet metadata (info) and labelling update
Commit 88a5b8491d
MANIFEST.in

@@ -1,2 +1,3 @@
-include timm/models/pruned/*.txt
+include timm/models/_pruned/*.txt
+include timm/data/_info/*.txt
+include timm/data/_info/*.json
inference.py (56 lines changed)

@@ -17,7 +17,7 @@ import numpy as np
 import pandas as pd
 import torch
 
-from timm.data import create_dataset, create_loader, resolve_data_config
+from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset
 from timm.layers import apply_test_time_pool
 from timm.models import create_model
 from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
@@ -46,6 +46,7 @@ has_compile = hasattr(torch, 'compile')
 
 _FMT_EXT = {
     'json': '.json',
+    'json-record': '.json',
     'json-split': '.json',
     'parquet': '.parquet',
     'csv': '.csv',
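The _FMT_EXT table maps each results-format choice to a file extension. Presumably it is applied roughly like the sketch below when building the output filename; the variable names here are illustrative, not the script's exact code:

fmt = 'parquet'                                           # hypothetical --results-format value
_FMT_EXT = {'json': '.json', 'json-record': '.json', 'json-split': '.json', 'parquet': '.parquet', 'csv': '.csv'}
results_filename = 'inference_results' + _FMT_EXT[fmt]    # 'inference_results.parquet'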
@@ -122,7 +123,7 @@ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd support.")
 
-parser.add_argument('--results-dir',type=str, default=None,
+parser.add_argument('--results-dir', type=str, default=None,
                     help='folder for output results')
 parser.add_argument('--results-file', type=str, default=None,
                     help='results filename (relative to results-dir)')
@@ -134,14 +135,20 @@ parser.add_argument('--topk', default=1, type=int,
                     metavar='N', help='Top-k to output to CSV')
 parser.add_argument('--fullname', action='store_true', default=False,
                     help='use full sample name in output (not just basename).')
-parser.add_argument('--filename-col', default='filename',
+parser.add_argument('--filename-col', type=str, default='filename',
                     help='name for filename / sample name column')
-parser.add_argument('--index-col', default='index',
+parser.add_argument('--index-col', type=str, default='index',
                     help='name for output indices column(s)')
-parser.add_argument('--output-col', default=None,
+parser.add_argument('--label-col', type=str, default='label',
+                    help='name for output label column(s)')
+parser.add_argument('--output-col', type=str, default=None,
                     help='name for logit/probs output column(s)')
-parser.add_argument('--output-type', default='prob',
+parser.add_argument('--output-type', type=str, default='prob',
                     help='output type column ("prob" for probabilities, "logit" for raw logits)')
+parser.add_argument('--label-type', type=str, default='description',
+                    help='type of label to output, one of "none", "name", "description", "detailed"')
+parser.add_argument('--include-index', action='store_true', default=False,
+                    help='include the class index in results')
 parser.add_argument('--exclude-output', action='store_true', default=False,
                     help='exclude logits/probs from results, just indices. topk must be set !=0.')
@@ -237,10 +244,26 @@ def main():
         **data_config,
     )
 
+    to_label = None
+    if args.label_type in ('name', 'description', 'detail'):
+        imagenet_subset = infer_imagenet_subset(model)
+        if imagenet_subset is not None:
+            dataset_info = ImageNetInfo(imagenet_subset)
+            if args.label_type == 'name':
+                to_label = lambda x: dataset_info.index_to_label_name(x)
+            elif args.label_type == 'detail':
+                to_label = lambda x: dataset_info.index_to_description(x, detailed=True)
+            else:
+                to_label = lambda x: dataset_info.index_to_description(x)
+            to_label = np.vectorize(to_label)
+        else:
+            _logger.error("Cannot deduce ImageNet subset from model, no labelling will be performed.")
+
     top_k = min(args.topk, args.num_classes)
     batch_time = AverageMeter()
     end = time.time()
     all_indices = []
+    all_labels = []
     all_outputs = []
     use_probs = args.output_type == 'prob'
     with torch.no_grad():
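Taken together, the new block wires model -> inferred subset -> vectorized labeller. A minimal standalone sketch of the same flow, assuming a pretrained 1000-class classifier ('resnet50' is just a placeholder model name):

import numpy as np
import timm
from timm.data import ImageNetInfo, infer_imagenet_subset

model = timm.create_model('resnet50', pretrained=True)   # placeholder; any 1000-class model works
subset = infer_imagenet_subset(model)                    # 'imagenet-1k' for num_classes == 1000
if subset is not None:
    dataset_info = ImageNetInfo(subset)
    to_label = np.vectorize(lambda x: dataset_info.index_to_description(x))
    print(to_label(np.array([[0, 1, 2]])))               # lemma strings for a stand-in top-3 row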
@@ -254,7 +277,12 @@
 
             if top_k:
                 output, indices = output.topk(top_k)
-                all_indices.append(indices.cpu().numpy())
+                np_indices = indices.cpu().numpy()
+                if args.include_index:
+                    all_indices.append(np_indices)
+                if to_label is not None:
+                    np_labels = to_label(np_indices)
+                    all_labels.append(np_labels)
 
             all_outputs.append(output.cpu().numpy())
 
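For reference, Tensor.topk returns values and indices sorted by descending score, which is what keeps the per-row label lookup aligned with the output columns. A tiny illustration:

import torch

logits = torch.tensor([[0.1, 2.5, 0.3, 1.7]])
values, indices = logits.topk(2)
print(values)                         # tensor([[2.5000, 1.7000]])
print(indices)                       # tensor([[1, 3]])
np_indices = indices.cpu().numpy()   # shape (batch, k), feedable to the vectorized to_label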
@@ -267,6 +295,7 @@
                 batch_idx, len(loader), batch_time=batch_time))
 
     all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
+    all_labels = np.concatenate(all_labels, axis=0) if all_labels else None
     all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
     filenames = loader.dataset.filenames(basename=not args.fullname)
 
@@ -276,6 +305,9 @@
         if all_indices is not None:
             for i in range(all_indices.shape[-1]):
                 data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
+        if all_labels is not None:
+            for i in range(all_labels.shape[-1]):
+                data_dict[f'{args.label_col}_{i}'] = all_labels[:, i]
         for i in range(all_outputs.shape[-1]):
             data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
     else:
@@ -283,6 +315,10 @@
             if all_indices.shape[-1] == 1:
                 all_indices = all_indices.squeeze(-1)
             data_dict[args.index_col] = list(all_indices)
+        if all_labels is not None:
+            if all_labels.shape[-1] == 1:
+                all_labels = all_labels.squeeze(-1)
+            data_dict[args.label_col] = list(all_labels)
         if all_outputs.shape[-1] == 1:
             all_outputs = all_outputs.squeeze(-1)
         data_dict[output_col] = list(all_outputs)
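The two branches above produce either suffixed columns (index_0 ... index_{k-1}) or single flat columns when k == 1. An illustrative reconstruction of the multi-column case with made-up values:

import numpy as np
import pandas as pd

filenames = ['a.jpg', 'b.jpg']                 # made-up sample names
all_indices = np.array([[1, 3], [2, 0]])       # top-2 class indices for two samples
data_dict = {'filename': filenames}
for i in range(all_indices.shape[-1]):
    data_dict[f'index_{i}'] = all_indices[:, i]
print(pd.DataFrame(data_dict))                 # columns: filename, index_0, index_1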
@@ -291,7 +327,7 @@
 
     results_filename = args.results_file
     if results_filename:
-        filename_no_ext, ext = os.path.splitext(results_filename)[-1]
+        filename_no_ext, ext = os.path.splitext(results_filename)
         if ext and ext in _FMT_EXT.values():
             # if filename provided with one of expected ext,
             # remove it as it will be added back
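The removed line was a genuine bug: os.path.splitext() returns a (root, ext) tuple, and indexing it with [-1] keeps only the extension string, so the two-name unpack either raises or silently mis-splits. A quick demonstration:

import os

print(os.path.splitext('results.csv'))        # ('results', '.csv')
print(os.path.splitext('results.csv')[-1])    # '.csv'
# filename_no_ext, ext = '.csv'  ->  ValueError: too many values to unpack
filename_no_ext, ext = os.path.splitext('results.csv')   # the corrected form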
@@ -308,7 +344,7 @@
         save_results(df, results_filename, fmt)
 
     print(f'--result')
-    print(json.dumps(dict(filename=results_filename)))
+    print(df.set_index(args.filename_col).to_json(orient='index', indent=4))
 
 
 def save_results(df, results_filename, results_format='csv', filename_col='filename'):
@@ -316,6 +352,8 @@ def save_results(df, results_filename, results_format='csv', filename_col='filename'):
     if results_format == 'parquet':
         df.set_index(filename_col).to_parquet(results_filename)
     elif results_format == 'json':
         df.set_index(filename_col).to_json(results_filename, indent=4, orient='index')
+    elif results_format == 'json-records':
+        df.to_json(results_filename, lines=True, orient='records')
     elif results_format == 'json-split':
         df.to_json(results_filename, indent=4, orient='split', index=False)
 
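The new json-records branch writes one JSON object per line, which streams better for large result sets than the nested index orient. A small comparison of the two pandas calls on a toy frame (not the script's data):

import pandas as pd

df = pd.DataFrame({'filename': ['a.jpg', 'b.jpg'], 'label': ['tench', 'goldfish']})
print(df.set_index('filename').to_json(orient='index', indent=4))
# {"a.jpg": {"label": "tench"}, "b.jpg": {"label": "goldfish"}}
print(df.to_json(orient='records', lines=True))
# {"filename":"a.jpg","label":"tench"}
# {"filename":"b.jpg","label":"goldfish"}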
timm/data/__init__.py

@@ -4,6 +4,8 @@ from .config import resolve_data_config, resolve_model_data_config
 from .constants import *
 from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset
+from .dataset_info import DatasetInfo
+from .imagenet_info import ImageNetInfo, infer_imagenet_subset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
 from .readers import create_reader
timm/data/_info/imagenet21k_miil_synsets.txt (new file, 11221 lines; diff too large to display)
timm/data/_info/imagenet21k_miil_w21_synsets.txt (new file, 10450 lines; diff too large to display)
timm/data/_info/imagenet_synset_to_definition.txt (new file, 21844 lines; diff too large to display)
timm/data/_info/imagenet_synset_to_lemma.txt (new file, 21844 lines; diff too large to display)
timm/data/dataset_info.py (new file, 32 lines)

@@ -0,0 +1,32 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, Union
+
+
+class DatasetInfo(ABC):
+
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def num_classes(self):
+        pass
+
+    @abstractmethod
+    def label_names(self):
+        pass
+
+    @abstractmethod
+    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
+        pass
+
+    @abstractmethod
+    def index_to_label_name(self, index) -> str:
+        pass
+
+    @abstractmethod
+    def index_to_description(self, index: int, detailed: bool = False) -> str:
+        pass
+
+    @abstractmethod
+    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
+        pass
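DatasetInfo is a small ABC, so third-party datasets can plug into the same labelling path. A hypothetical minimal subclass, purely illustrative (the class and its label data are invented):

from typing import Dict, List, Union

from timm.data import DatasetInfo


class TinyDatasetInfo(DatasetInfo):
    """Invented 3-class example implementing the interface above."""
    _DESC = {'cat': 'a small domesticated felid', 'dog': 'a domesticated canid', 'bird': 'a feathered vertebrate'}

    def num_classes(self):
        return len(self._DESC)

    def label_names(self):
        return list(self._DESC)

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        return dict(self._DESC) if as_dict else list(self._DESC.values())

    def index_to_label_name(self, index) -> str:
        return list(self._DESC)[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        return self.label_name_to_description(self.index_to_label_name(index), detailed=detailed)

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        return f'{label}: {self._DESC[label]}' if detailed else self._DESC[label]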
timm/data/imagenet_info.py (new file, 92 lines)

@@ -0,0 +1,92 @@
+import csv
+import os
+import pkgutil
+import re
+from typing import Dict, List, Optional, Union
+
+from .dataset_info import DatasetInfo
+
+
+_NUM_CLASSES_TO_SUBSET = {
+    1000: 'imagenet-1k',
+    11821: 'imagenet-12k',
+    21841: 'imagenet-22k',
+    21843: 'imagenet-21k-goog',
+    11221: 'imagenet-21k-miil',
+}
+
+_SUBSETS = {
+    'imagenet1k': 'imagenet_synsets.txt',
+    'imagenet12k': 'imagenet12k_synsets.txt',
+    'imagenet22k': 'imagenet22k_synsets.txt',
+    'imagenet21k': 'imagenet21k_goog_synsets.txt',
+    'imagenet21kgoog': 'imagenet21k_goog_synsets.txt',
+    'imagenet21kmiil': 'imagenet21k_miil_synsets.txt',
+}
+_LEMMA_FILE = 'imagenet_synset_to_lemma.txt'
+_DEFINITION_FILE = 'imagenet_synset_to_definition.txt'
+
+
+def infer_imagenet_subset(model_or_cfg) -> Optional[str]:
+    if isinstance(model_or_cfg, dict):
+        num_classes = model_or_cfg.get('num_classes', None)
+    else:
+        num_classes = getattr(model_or_cfg, 'num_classes', None)
+        if not num_classes:
+            pretrained_cfg = getattr(model_or_cfg, 'pretrained_cfg', {})
+            # FIXME at some point pretrained_cfg should include dataset-tag,
+            # which will be more robust than a guess based on num_classes
+            num_classes = pretrained_cfg.get('num_classes', None)
+    if not num_classes or num_classes not in _NUM_CLASSES_TO_SUBSET:
+        return None
+    return _NUM_CLASSES_TO_SUBSET[num_classes]
+
+
+class ImageNetInfo(DatasetInfo):
+
+    def __init__(self, subset: str = 'imagenet-1k'):
+        super().__init__()
+        subset = re.sub(r'[-_\s]', '', subset.lower())
+        assert subset in _SUBSETS, f'Unknown imagenet subset {subset}.'
+
+        # WordNet synsets (part-of-speech + offset) are the unique class label names for ImageNet classifiers
+        synset_file = _SUBSETS[subset]
+        synset_data = pkgutil.get_data(__name__, os.path.join('_info', synset_file))
+        self._synsets = synset_data.decode('utf-8').splitlines()
+
+        # WordNet lemmas (canonical dictionary form of word) and definitions are used to build
+        # the class descriptions. If detailed=True both are used, otherwise just the lemmas.
+        lemma_data = pkgutil.get_data(__name__, os.path.join('_info', _LEMMA_FILE))
+        reader = csv.reader(lemma_data.decode('utf-8').splitlines(), delimiter='\t')
+        self._lemmas = dict(reader)
+        definition_data = pkgutil.get_data(__name__, os.path.join('_info', _DEFINITION_FILE))
+        reader = csv.reader(definition_data.decode('utf-8').splitlines(), delimiter='\t')
+        self._definitions = dict(reader)
+
+    def num_classes(self):
+        return len(self._synsets)
+
+    def label_names(self):
+        return self._synsets
+
+    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
+        if as_dict:
+            return {label: self.label_name_to_description(label, detailed=detailed) for label in self._synsets}
+        else:
+            return [self.label_name_to_description(label, detailed=detailed) for label in self._synsets]
+
+    def index_to_label_name(self, index) -> str:
+        assert 0 <= index < len(self._synsets), \
+            f'Index ({index}) out of range for dataset with {len(self._synsets)} classes.'
+        return self._synsets[index]
+
+    def index_to_description(self, index: int, detailed: bool = False) -> str:
+        label = self.index_to_label_name(index)
+        return self.label_name_to_description(label, detailed=detailed)
+
+    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
+        if detailed:
+            description = f'{self._lemmas[label]}: {self._definitions[label]}'
+        else:
+            description = f'{self._lemmas[label]}'
+        return description
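Usage is straightforward; subset names are normalized by the re.sub above, so 'imagenet-1k', 'imagenet_1k' and 'ImageNet 1k' all resolve to the same synset file. A short sketch (the index-0 outputs shown in comments are examples, not guaranteed values):

from timm.data import ImageNetInfo

info = ImageNetInfo('imagenet-1k')
print(info.num_classes())                            # 1000
print(info.index_to_label_name(0))                   # a synset id, e.g. 'n01440764'
print(info.index_to_description(0))                  # its lemma, e.g. 'tench, Tinca tinca'
print(info.index_to_description(0, detailed=True))   # 'lemma: definition'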
timm/data/real_labels.py

@@ -7,14 +7,19 @@ Hacked together by / Copyright 2020 Ross Wightman
 import os
 import json
 import numpy as np
+import pkgutil
 
 
 class RealLabelsImagenet:
 
-    def __init__(self, filenames, real_json='real.json', topk=(1, 5)):
-        with open(real_json) as real_labels:
-            real_labels = json.load(real_labels)
-        real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
+    def __init__(self, filenames, real_json=None, topk=(1, 5)):
+        if real_json is not None:
+            with open(real_json) as real_labels:
+                real_labels = json.load(real_labels)
+        else:
+            real_labels = json.loads(
+                pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8'))
+        real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
         self.real_labels = real_labels
         self.filenames = filenames
         assert len(self.filenames) == len(self.real_labels)
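The switch from open() to pkgutil.get_data() matters because the resource is resolved relative to the installed package rather than the current working directory, so the bundled JSON is found even for zip/wheel installs. A sketch of the call, assuming a timm install that ships _info/imagenet_real_labels.json as this PR arranges:

import json
import pkgutil

data = pkgutil.get_data('timm.data', '_info/imagenet_real_labels.json')
real_labels = json.loads(data.decode('utf-8'))   # list with one label set per validation image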
timm/models/_prune.py

@@ -1,4 +1,5 @@
 import os
+import pkgutil
 from copy import deepcopy
 
 from torch import nn as nn

@@ -108,6 +109,5 @@ def adapt_model_from_string(parent_module, model_string):
 
 
 def adapt_model_from_file(parent_module, model_variant):
-    adapt_file = os.path.join(os.path.dirname(__file__), '_pruned', model_variant + '.txt')
-    with open(adapt_file, 'r') as f:
-        return adapt_model_from_string(parent_module, f.read().strip())
+    adapt_data = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt'))
+    return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())