Fix pass through of input / target keys so ImageDataset readers so args work with hfds instead of just hfids (iterable)
parent
3196d6b131
commit
34c9fee554
|
@ -30,13 +30,15 @@ class ImageDataset(data.Dataset):
|
|||
input_img_mode='RGB',
|
||||
transform=None,
|
||||
target_transform=None,
|
||||
**kwargs,
|
||||
):
|
||||
if reader is None or isinstance(reader, str):
|
||||
reader = create_reader(
|
||||
reader or '',
|
||||
root=root,
|
||||
split=split,
|
||||
class_map=class_map
|
||||
class_map=class_map,
|
||||
**kwargs,
|
||||
)
|
||||
self.reader = reader
|
||||
self.load_bytes = load_bytes
|
||||
|
|
|
@ -35,7 +35,7 @@ class ReaderHfds(Reader):
|
|||
root: Optional[str] = None,
|
||||
split: str = 'train',
|
||||
class_map: dict = None,
|
||||
image_key: str = 'image',
|
||||
input_key: str = 'image',
|
||||
target_key: str = 'label',
|
||||
download: bool = False,
|
||||
):
|
||||
|
@ -50,9 +50,9 @@ class ReaderHfds(Reader):
|
|||
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
|
||||
)
|
||||
# leave decode for caller, plus we want easy access to original path names...
|
||||
self.dataset = self.dataset.cast_column(image_key, datasets.Image(decode=False))
|
||||
self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False))
|
||||
|
||||
self.image_key = image_key
|
||||
self.image_key = input_key
|
||||
self.label_key = target_key
|
||||
self.remap_class = False
|
||||
if class_map:
|
||||
|
|
Loading…
Reference in New Issue