mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Support HF datasets and TFDS w/ a sub-path by fixing split, fix #1598 ... add class mapping support to HF datasets in case class label isn't in info.
This commit is contained in:
parent
35fb00c779
commit
d1bfa9a000
@ -151,7 +151,7 @@ def create_dataset(
|
|||||||
elif name.startswith('hfds/'):
|
elif name.startswith('hfds/'):
|
||||||
# NOTE right now, HF datasets default arrow format is a random-access Dataset,
|
# NOTE right now, HF datasets default arrow format is a random-access Dataset,
|
||||||
# There will be a IterableDataset variant too, TBD
|
# There will be a IterableDataset variant too, TBD
|
||||||
ds = ImageDataset(root, reader=name, split=split, **kwargs)
|
ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
|
||||||
elif name.startswith('tfds/'):
|
elif name.startswith('tfds/'):
|
||||||
ds = IterableImageDataset(
|
ds = IterableImageDataset(
|
||||||
root,
|
root,
|
||||||
|
@ -6,7 +6,7 @@ from .reader_image_in_tar import ReaderImageInTar
|
|||||||
|
|
||||||
def create_reader(name, root, split='train', **kwargs):
|
def create_reader(name, root, split='train', **kwargs):
|
||||||
name = name.lower()
|
name = name.lower()
|
||||||
name = name.split('/', 2)
|
name = name.split('/', 1)
|
||||||
prefix = ''
|
prefix = ''
|
||||||
if len(name) > 1:
|
if len(name) > 1:
|
||||||
prefix = name[0]
|
prefix = name[0]
|
||||||
|
@ -13,13 +13,14 @@ try:
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print("Please install Hugging Face datasets package `pip install datasets`.")
|
print("Please install Hugging Face datasets package `pip install datasets`.")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
from .class_map import load_class_map
|
||||||
from .reader import Reader
|
from .reader import Reader
|
||||||
|
|
||||||
|
|
||||||
def get_class_labels(info):
|
def get_class_labels(info, label_key='label'):
|
||||||
if 'label' not in info.features:
|
if 'label' not in info.features:
|
||||||
return {}
|
return {}
|
||||||
class_label = info.features['label']
|
class_label = info.features[label_key]
|
||||||
class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
|
class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
|
||||||
return class_to_idx
|
return class_to_idx
|
||||||
|
|
||||||
@ -32,6 +33,7 @@ class ReaderHfds(Reader):
|
|||||||
name,
|
name,
|
||||||
split='train',
|
split='train',
|
||||||
class_map=None,
|
class_map=None,
|
||||||
|
label_key='label',
|
||||||
download=False,
|
download=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -43,12 +45,17 @@ class ReaderHfds(Reader):
|
|||||||
name, # 'name' maps to path arg in hf datasets
|
name, # 'name' maps to path arg in hf datasets
|
||||||
split=split,
|
split=split,
|
||||||
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
|
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
|
||||||
#use_auth_token=True,
|
|
||||||
)
|
)
|
||||||
# leave decode for caller, plus we want easy access to original path names...
|
# leave decode for caller, plus we want easy access to original path names...
|
||||||
self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
|
self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
|
||||||
|
|
||||||
self.class_to_idx = get_class_labels(self.dataset.info)
|
self.label_key = label_key
|
||||||
|
self.remap_class = False
|
||||||
|
if class_map:
|
||||||
|
self.class_to_idx = load_class_map(class_map)
|
||||||
|
self.remap_class = True
|
||||||
|
else:
|
||||||
|
self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
|
||||||
self.split_info = self.dataset.info.splits[split]
|
self.split_info = self.dataset.info.splits[split]
|
||||||
self.num_samples = self.split_info.num_examples
|
self.num_samples = self.split_info.num_examples
|
||||||
|
|
||||||
@ -60,7 +67,10 @@ class ReaderHfds(Reader):
|
|||||||
else:
|
else:
|
||||||
assert 'path' in image and image['path']
|
assert 'path' in image and image['path']
|
||||||
image = open(image['path'], 'rb')
|
image = open(image['path'], 'rb')
|
||||||
return image, item['label']
|
label = item[self.label_key]
|
||||||
|
if self.remap_class:
|
||||||
|
label = self.class_to_idx[label]
|
||||||
|
return image, label
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.dataset)
|
return len(self.dataset)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user