Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Allow using class_map functionality w/ IterableDataset (TFDS/WDS) to remap class labels

parent 01fdf44438
commit c061d5e401
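
For context, a minimal usage sketch (not part of the commit) of what this change enables: passing class_map through create_dataset to a TFDS-backed iterable dataset. The dataset name and the remapping dict below are illustrative, and dict support in load_class_map is an assumption here.

from timm.data import create_dataset

# Illustrative remap of raw target values to new indices; load_class_map is
# assumed to also accept a dict directly (a file path works as well, see below).
remap = {0: 10, 1: 11, 2: 12}

dataset = create_dataset(
    'tfds/oxford_iiit_pet',        # hypothetical TFDS dataset name
    root='~/tensorflow_datasets',  # hypothetical data root
    split='train',
    class_map=remap,
    is_training=True,
    batch_size=32,
)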
@@ -88,6 +88,7 @@ class IterableImageDataset(data.IterableDataset):
             root,
             reader=None,
             split='train',
+            class_map=None,
             is_training=False,
             batch_size=None,
             seed=42,
@@ -102,6 +103,7 @@ class IterableImageDataset(data.IterableDataset):
                 reader,
                 root=root,
                 split=split,
+                class_map=class_map,
                 is_training=is_training,
                 batch_size=batch_size,
                 seed=seed,
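
The same argument can be passed when constructing IterableImageDataset directly; per the two hunks above it is simply stored and forwarded into the reader the dataset builds internally. A brief sketch, with a hypothetical reader name, the transform omitted, and dict class maps assumed to be accepted:

from timm.data import IterableImageDataset

ds = IterableImageDataset(
    root='~/tensorflow_datasets',   # hypothetical data root
    reader='tfds/oxford_iiit_pet',  # hypothetical TFDS dataset name
    split='train',
    class_map={0: 10, 1: 11},       # assumption: dict class maps pass through unchanged
    is_training=True,
    batch_size=32,
)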
@@ -157,6 +157,7 @@ def create_dataset(
             root,
             reader=name,
             split=split,
+            class_map=class_map,
             is_training=is_training,
             download=download,
             batch_size=batch_size,
@@ -169,6 +170,7 @@ def create_dataset(
             root,
             reader=name,
             split=split,
+            class_map=class_map,
             is_training=is_training,
             batch_size=batch_size,
             repeats=repeats,
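
With these two hunks, create_dataset forwards class_map for both of its iterable dataset paths (TFDS and WDS). A sketch of the file-based form of the argument; the one-class-name-per-line format is an assumption about load_class_map, and the dataset name and paths are hypothetical:

from timm.data import create_dataset

# class_map.txt (assumed format: one class name per line, line order = new index)
#   cat
#   dog
#   horse
# expected to load as something like {'cat': 0, 'dog': 1, 'horse': 2}

dataset = create_dataset(
    'wds/my_dataset',         # hypothetical WebDataset name
    root='/data/my_dataset',  # hypothetical data root
    split='train',
    class_map='class_map.txt',
    is_training=True,
    batch_size=32,
)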
@@ -34,6 +34,7 @@ except ImportError as e:
     print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
     exit(1)
 
+from .class_map import load_class_map
 from .reader import Reader
 from .shared_count import SharedCount
 
@@ -94,6 +95,7 @@ class ReaderTfds(Reader):
             root,
             name,
             split='train',
+            class_map=None,
             is_training=False,
             batch_size=None,
             download=False,
@@ -151,7 +153,12 @@ class ReaderTfds(Reader):
         # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
         if download:
             self.builder.download_and_prepare()
-        self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
+        self.remap_class = False
+        if class_map:
+            self.class_to_idx = load_class_map(class_map)
+            self.remap_class = True
+        else:
+            self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
         self.split_info = self.builder.info.splits[split]
         self.num_samples = self.split_info.num_examples
@@ -299,6 +306,8 @@ class ReaderTfds(Reader):
             target_data = sample[self.target_name]
             if self.target_img_mode:
                 target_data = Image.fromarray(target_data, mode=self.target_img_mode)
+            elif self.remap_class:
+                target_data = self.class_to_idx[target_data]
             yield input_data, target_data
             sample_count += 1
             if self.is_training and sample_count >= target_sample_count:
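
Taken together, the two ReaderTfds hunks mean: when a class_map is supplied, class_to_idx comes from load_class_map and every yielded target is passed through it; otherwise the builder's own label mapping (or an empty dict) is used and targets are yielded as-is. The lookup in isolation, as a standalone sketch (not timm code):

def remap_target(target, class_to_idx, remap_class):
    # mirrors the elif branch added to ReaderTfds.__iter__
    return class_to_idx[target] if remap_class else target

assert remap_target(1, {0: 10, 1: 11}, remap_class=True) == 11
assert remap_target(1, {}, remap_class=False) == 1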
@@ -29,6 +29,7 @@ except ImportError:
     wds = None
     expand_urls = None
 
+from .class_map import load_class_map
 from .reader import Reader
 from .shared_count import SharedCount
 
@@ -42,13 +43,13 @@ def _load_info(root, basename='info'):
     info_yaml = os.path.join(root, basename + '.yaml')
     err_str = ''
     try:
-        with wds.gopen.gopen(info_json) as f:
+        with wds.gopen(info_json) as f:
             info_dict = json.load(f)
         return info_dict
     except Exception as e:
         err_str = str(e)
     try:
-        with wds.gopen.gopen(info_yaml) as f:
+        with wds.gopen(info_yaml) as f:
             info_dict = yaml.safe_load(f)
         return info_dict
     except Exception:
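
Separately, _load_info now calls wds.gopen directly instead of wds.gopen.gopen, matching how the function is exposed at the webdataset package level. A sketch of the corrected call pattern; the path is hypothetical, and the claim that gopen also accepts remote URLs is an assumption here:

import json
import webdataset as wds

with wds.gopen('/data/my_dataset/info.json') as f:  # hypothetical path
    info = json.load(f)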
@@ -110,8 +111,8 @@ def _parse_split_info(split: str, info: Dict):
             filenames=split_filenames,
         )
     else:
-        if split not in info['splits']:
-            raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
+        if 'splits' not in info or split not in info['splits']:
+            raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})")
         split = split
         split_info = info['splits'][split]
         split_info = _info_convert(split_info)
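
For reference, an illustrative shape of the info dict this guard protects (field names taken from the surrounding code; the exact schema is an assumption). With the old check, an info dict missing the 'splits' key raised a bare KeyError; the guarded version reports a readable RuntimeError either way.

info = {
    'splits': {
        'train': {
            'num_samples': 1000,                                   # illustrative count
            'filenames': ['shard-00000.tar', 'shard-00001.tar'],   # illustrative shards
        },
    },
}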
@@ -290,6 +291,7 @@ class ReaderWds(Reader):
             batch_size=None,
             repeats=0,
             seed=42,
+            class_map=None,
             input_name='jpg',
             input_image='RGB',
             target_name='cls',
@@ -320,6 +322,12 @@ class ReaderWds(Reader):
         self.num_samples = self.split_info.num_samples
         if not self.num_samples:
             raise RuntimeError(f'Invalid split definition, no samples found.')
+        self.remap_class = False
+        if class_map:
+            self.class_to_idx = load_class_map(class_map)
+            self.remap_class = True
+        else:
+            self.class_to_idx = {}
 
         # Distributed world state
         self.dist_rank = 0
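
The WDS reader follows the same pattern as the TFDS one, except that without a class_map its class_to_idx stays empty and raw 'cls' targets pass through untouched. The reader can also be built on its own, as the dataset changes above do internally; a sketch assuming create_reader is the reader factory exported from timm.data.readers, with hypothetical names and paths and dict class maps assumed to be accepted:

from timm.data.readers import create_reader

reader = create_reader(
    'wds/my_dataset',         # hypothetical WebDataset name
    root='/data/my_dataset',  # hypothetical data root
    split='train',
    class_map={0: 1, 1: 0},   # illustrative swap of two raw 'cls' values
    is_training=True,
    batch_size=32,
)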
@@ -431,7 +439,10 @@ class ReaderWds(Reader):
         i = 0
         # _logger.info(f'start {i}, {self.worker_id}')  # FIXME temporary debug
         for sample in ds:
-            yield sample[self.image_key], sample[self.target_key]
+            target = sample[self.target_key]
+            if self.remap_class:
+                target = self.class_to_idx[target]
+            yield sample[self.image_key], target
             i += 1
         # _logger.info(f'end {i}, {self.worker_id}')  # FIXME temporary debug
 
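
Putting it together: once the dataset is built (as in the create_dataset sketches above), iteration yields (image, target) pairs, and with a class_map supplied the targets are the remapped indices rather than the raw dataset labels.

# Continuing the create_dataset sketch above; transform handling is omitted.
for img, target in dataset:
    print(type(img), target)   # target has already been remapped by the reader
    break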