Fix pass through of input / target keys so ImageDataset readers so args work with hfds instead of just hfids (iterable)

pull/2236/head
Ross Wightman 2024-07-17 10:11:46 -07:00
parent 3196d6b131
commit 34c9fee554
2 changed files with 6 additions and 4 deletions

View File

@ -30,13 +30,15 @@ class ImageDataset(data.Dataset):
input_img_mode='RGB',
transform=None,
target_transform=None,
**kwargs,
):
if reader is None or isinstance(reader, str):
reader = create_reader(
reader or '',
root=root,
split=split,
class_map=class_map
class_map=class_map,
**kwargs,
)
self.reader = reader
self.load_bytes = load_bytes

View File

@ -35,7 +35,7 @@ class ReaderHfds(Reader):
root: Optional[str] = None,
split: str = 'train',
class_map: dict = None,
image_key: str = 'image',
input_key: str = 'image',
target_key: str = 'label',
download: bool = False,
):
@ -50,9 +50,9 @@ class ReaderHfds(Reader):
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
)
# leave decode for caller, plus we want easy access to original path names...
self.dataset = self.dataset.cast_column(image_key, datasets.Image(decode=False))
self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False))
self.image_key = image_key
self.image_key = input_key
self.label_key = target_key
self.remap_class = False
if class_map: