mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Allow using class_map functionality w/ IterableDataset (TFDS/WDS) to remap class labels
This commit is contained in:
parent
01fdf44438
commit
c061d5e401
@ -88,6 +88,7 @@ class IterableImageDataset(data.IterableDataset):
|
|||||||
root,
|
root,
|
||||||
reader=None,
|
reader=None,
|
||||||
split='train',
|
split='train',
|
||||||
|
class_map=None,
|
||||||
is_training=False,
|
is_training=False,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
seed=42,
|
seed=42,
|
||||||
@ -102,6 +103,7 @@ class IterableImageDataset(data.IterableDataset):
|
|||||||
reader,
|
reader,
|
||||||
root=root,
|
root=root,
|
||||||
split=split,
|
split=split,
|
||||||
|
class_map=class_map,
|
||||||
is_training=is_training,
|
is_training=is_training,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
@ -157,6 +157,7 @@ def create_dataset(
|
|||||||
root,
|
root,
|
||||||
reader=name,
|
reader=name,
|
||||||
split=split,
|
split=split,
|
||||||
|
class_map=class_map,
|
||||||
is_training=is_training,
|
is_training=is_training,
|
||||||
download=download,
|
download=download,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
@ -169,6 +170,7 @@ def create_dataset(
|
|||||||
root,
|
root,
|
||||||
reader=name,
|
reader=name,
|
||||||
split=split,
|
split=split,
|
||||||
|
class_map=class_map,
|
||||||
is_training=is_training,
|
is_training=is_training,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
repeats=repeats,
|
repeats=repeats,
|
||||||
|
@ -34,6 +34,7 @@ except ImportError as e:
|
|||||||
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
|
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
from .class_map import load_class_map
|
||||||
from .reader import Reader
|
from .reader import Reader
|
||||||
from .shared_count import SharedCount
|
from .shared_count import SharedCount
|
||||||
|
|
||||||
@ -94,6 +95,7 @@ class ReaderTfds(Reader):
|
|||||||
root,
|
root,
|
||||||
name,
|
name,
|
||||||
split='train',
|
split='train',
|
||||||
|
class_map=None,
|
||||||
is_training=False,
|
is_training=False,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
download=False,
|
download=False,
|
||||||
@ -151,6 +153,11 @@ class ReaderTfds(Reader):
|
|||||||
# NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
|
# NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
|
||||||
if download:
|
if download:
|
||||||
self.builder.download_and_prepare()
|
self.builder.download_and_prepare()
|
||||||
|
self.remap_class = False
|
||||||
|
if class_map:
|
||||||
|
self.class_to_idx = load_class_map(class_map)
|
||||||
|
self.remap_class = True
|
||||||
|
else:
|
||||||
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
|
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
|
||||||
self.split_info = self.builder.info.splits[split]
|
self.split_info = self.builder.info.splits[split]
|
||||||
self.num_samples = self.split_info.num_examples
|
self.num_samples = self.split_info.num_examples
|
||||||
@ -299,6 +306,8 @@ class ReaderTfds(Reader):
|
|||||||
target_data = sample[self.target_name]
|
target_data = sample[self.target_name]
|
||||||
if self.target_img_mode:
|
if self.target_img_mode:
|
||||||
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
|
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
|
||||||
|
elif self.remap_class:
|
||||||
|
target_data = self.class_to_idx[target_data]
|
||||||
yield input_data, target_data
|
yield input_data, target_data
|
||||||
sample_count += 1
|
sample_count += 1
|
||||||
if self.is_training and sample_count >= target_sample_count:
|
if self.is_training and sample_count >= target_sample_count:
|
||||||
|
@ -29,6 +29,7 @@ except ImportError:
|
|||||||
wds = None
|
wds = None
|
||||||
expand_urls = None
|
expand_urls = None
|
||||||
|
|
||||||
|
from .class_map import load_class_map
|
||||||
from .reader import Reader
|
from .reader import Reader
|
||||||
from .shared_count import SharedCount
|
from .shared_count import SharedCount
|
||||||
|
|
||||||
@ -42,13 +43,13 @@ def _load_info(root, basename='info'):
|
|||||||
info_yaml = os.path.join(root, basename + '.yaml')
|
info_yaml = os.path.join(root, basename + '.yaml')
|
||||||
err_str = ''
|
err_str = ''
|
||||||
try:
|
try:
|
||||||
with wds.gopen.gopen(info_json) as f:
|
with wds.gopen(info_json) as f:
|
||||||
info_dict = json.load(f)
|
info_dict = json.load(f)
|
||||||
return info_dict
|
return info_dict
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
err_str = str(e)
|
err_str = str(e)
|
||||||
try:
|
try:
|
||||||
with wds.gopen.gopen(info_yaml) as f:
|
with wds.gopen(info_yaml) as f:
|
||||||
info_dict = yaml.safe_load(f)
|
info_dict = yaml.safe_load(f)
|
||||||
return info_dict
|
return info_dict
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -110,8 +111,8 @@ def _parse_split_info(split: str, info: Dict):
|
|||||||
filenames=split_filenames,
|
filenames=split_filenames,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if split not in info['splits']:
|
if 'splits' not in info or split not in info['splits']:
|
||||||
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
|
raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})")
|
||||||
split = split
|
split = split
|
||||||
split_info = info['splits'][split]
|
split_info = info['splits'][split]
|
||||||
split_info = _info_convert(split_info)
|
split_info = _info_convert(split_info)
|
||||||
@ -290,6 +291,7 @@ class ReaderWds(Reader):
|
|||||||
batch_size=None,
|
batch_size=None,
|
||||||
repeats=0,
|
repeats=0,
|
||||||
seed=42,
|
seed=42,
|
||||||
|
class_map=None,
|
||||||
input_name='jpg',
|
input_name='jpg',
|
||||||
input_image='RGB',
|
input_image='RGB',
|
||||||
target_name='cls',
|
target_name='cls',
|
||||||
@ -320,6 +322,12 @@ class ReaderWds(Reader):
|
|||||||
self.num_samples = self.split_info.num_samples
|
self.num_samples = self.split_info.num_samples
|
||||||
if not self.num_samples:
|
if not self.num_samples:
|
||||||
raise RuntimeError(f'Invalid split definition, no samples found.')
|
raise RuntimeError(f'Invalid split definition, no samples found.')
|
||||||
|
self.remap_class = False
|
||||||
|
if class_map:
|
||||||
|
self.class_to_idx = load_class_map(class_map)
|
||||||
|
self.remap_class = True
|
||||||
|
else:
|
||||||
|
self.class_to_idx = {}
|
||||||
|
|
||||||
# Distributed world state
|
# Distributed world state
|
||||||
self.dist_rank = 0
|
self.dist_rank = 0
|
||||||
@ -431,7 +439,10 @@ class ReaderWds(Reader):
|
|||||||
i = 0
|
i = 0
|
||||||
# _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
|
# _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
|
||||||
for sample in ds:
|
for sample in ds:
|
||||||
yield sample[self.image_key], sample[self.target_key]
|
target = sample[self.target_key]
|
||||||
|
if self.remap_class:
|
||||||
|
target = self.class_to_idx[target]
|
||||||
|
yield sample[self.image_key], target
|
||||||
i += 1
|
i += 1
|
||||||
# _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug
|
# _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user