mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Support HF datasets and TFDS w/ a sub-path by fixing split, fix #1598 ... add class mapping support to HF datasets in case class label isn't in info.
This commit is contained in:
parent
35fb00c779
commit
d1bfa9a000
@ -151,7 +151,7 @@ def create_dataset(
|
||||
elif name.startswith('hfds/'):
|
||||
# NOTE right now, HF datasets default arrow format is a random-access Dataset,
|
||||
# There will be a IterableDataset variant too, TBD
|
||||
ds = ImageDataset(root, reader=name, split=split, **kwargs)
|
||||
ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
|
||||
elif name.startswith('tfds/'):
|
||||
ds = IterableImageDataset(
|
||||
root,
|
||||
|
@ -6,7 +6,7 @@ from .reader_image_in_tar import ReaderImageInTar
|
||||
|
||||
def create_reader(name, root, split='train', **kwargs):
|
||||
name = name.lower()
|
||||
name = name.split('/', 2)
|
||||
name = name.split('/', 1)
|
||||
prefix = ''
|
||||
if len(name) > 1:
|
||||
prefix = name[0]
|
||||
|
@ -13,13 +13,14 @@ try:
|
||||
except ImportError as e:
|
||||
print("Please install Hugging Face datasets package `pip install datasets`.")
|
||||
exit(1)
|
||||
from .class_map import load_class_map
|
||||
from .reader import Reader
|
||||
|
||||
|
||||
def get_class_labels(info, label_key='label'):
    """Build a class-name -> integer-index mapping from HF dataset info.

    Args:
        info: Hugging Face ``datasets`` DatasetInfo object (exposes ``.features``).
        label_key: Name of the label feature to read (default ``'label'``).

    Returns:
        Dict mapping each class name to its integer id, or ``{}`` when the
        requested label feature is not present in ``info.features``.
    """
    # Check the key actually being read. The diffed version kept the literal
    # 'label' in this guard while indexing with label_key below, so a
    # non-default label_key could raise KeyError (key missing from features)
    # or wrongly return {} (features has label_key but not 'label').
    if label_key not in info.features:
        return {}
    class_label = info.features[label_key]
    class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
    return class_to_idx
|
||||
|
||||
@ -32,6 +33,7 @@ class ReaderHfds(Reader):
|
||||
name,
|
||||
split='train',
|
||||
class_map=None,
|
||||
label_key='label',
|
||||
download=False,
|
||||
):
|
||||
"""
|
||||
@ -43,12 +45,17 @@ class ReaderHfds(Reader):
|
||||
name, # 'name' maps to path arg in hf datasets
|
||||
split=split,
|
||||
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
|
||||
#use_auth_token=True,
|
||||
)
|
||||
# leave decode for caller, plus we want easy access to original path names...
|
||||
self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
|
||||
|
||||
self.class_to_idx = get_class_labels(self.dataset.info)
|
||||
self.label_key = label_key
|
||||
self.remap_class = False
|
||||
if class_map:
|
||||
self.class_to_idx = load_class_map(class_map)
|
||||
self.remap_class = True
|
||||
else:
|
||||
self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
|
||||
self.split_info = self.dataset.info.splits[split]
|
||||
self.num_samples = self.split_info.num_examples
|
||||
|
||||
@ -60,7 +67,10 @@ class ReaderHfds(Reader):
|
||||
else:
|
||||
assert 'path' in image and image['path']
|
||||
image = open(image['path'], 'rb')
|
||||
return image, item['label']
|
||||
label = item[self.label_key]
|
||||
if self.remap_class:
|
||||
label = self.class_to_idx[label]
|
||||
return image, label
|
||||
|
||||
def __len__(self):
    """Return the number of samples in the wrapped HF dataset."""
    dataset_size = len(self.dataset)
    return dataset_size
|
||||
|
Loading…
x
Reference in New Issue
Block a user