Fix pass-through of input / target keys to ImageDataset readers so the args work with hfds instead of just hfids (iterable)

This commit is contained in:
Ross Wightman 2024-07-17 10:11:46 -07:00
parent 3196d6b131
commit 34c9fee554
2 changed files with 6 additions and 4 deletions

View File

@ -30,13 +30,15 @@ class ImageDataset(data.Dataset):
input_img_mode='RGB', input_img_mode='RGB',
transform=None, transform=None,
target_transform=None, target_transform=None,
**kwargs,
): ):
if reader is None or isinstance(reader, str): if reader is None or isinstance(reader, str):
reader = create_reader( reader = create_reader(
reader or '', reader or '',
root=root, root=root,
split=split, split=split,
class_map=class_map class_map=class_map,
**kwargs,
) )
self.reader = reader self.reader = reader
self.load_bytes = load_bytes self.load_bytes = load_bytes

View File

@ -35,7 +35,7 @@ class ReaderHfds(Reader):
root: Optional[str] = None, root: Optional[str] = None,
split: str = 'train', split: str = 'train',
class_map: dict = None, class_map: dict = None,
image_key: str = 'image', input_key: str = 'image',
target_key: str = 'label', target_key: str = 'label',
download: bool = False, download: bool = False,
): ):
@ -50,9 +50,9 @@ class ReaderHfds(Reader):
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
) )
# leave decode for caller, plus we want easy access to original path names... # leave decode for caller, plus we want easy access to original path names...
self.dataset = self.dataset.cast_column(image_key, datasets.Image(decode=False)) self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False))
self.image_key = image_key self.image_key = input_key
self.label_key = target_key self.label_key = target_key
self.remap_class = False self.remap_class = False
if class_map: if class_map: