mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix pass through of input / target keys so ImageDataset readers so args work with hfds instead of just hfids (iterable)
This commit is contained in:
parent
3196d6b131
commit
34c9fee554
@ -30,13 +30,15 @@ class ImageDataset(data.Dataset):
|
|||||||
input_img_mode='RGB',
|
input_img_mode='RGB',
|
||||||
transform=None,
|
transform=None,
|
||||||
target_transform=None,
|
target_transform=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if reader is None or isinstance(reader, str):
|
if reader is None or isinstance(reader, str):
|
||||||
reader = create_reader(
|
reader = create_reader(
|
||||||
reader or '',
|
reader or '',
|
||||||
root=root,
|
root=root,
|
||||||
split=split,
|
split=split,
|
||||||
class_map=class_map
|
class_map=class_map,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.reader = reader
|
self.reader = reader
|
||||||
self.load_bytes = load_bytes
|
self.load_bytes = load_bytes
|
||||||
|
@ -35,7 +35,7 @@ class ReaderHfds(Reader):
|
|||||||
root: Optional[str] = None,
|
root: Optional[str] = None,
|
||||||
split: str = 'train',
|
split: str = 'train',
|
||||||
class_map: dict = None,
|
class_map: dict = None,
|
||||||
image_key: str = 'image',
|
input_key: str = 'image',
|
||||||
target_key: str = 'label',
|
target_key: str = 'label',
|
||||||
download: bool = False,
|
download: bool = False,
|
||||||
):
|
):
|
||||||
@ -50,9 +50,9 @@ class ReaderHfds(Reader):
|
|||||||
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
|
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
|
||||||
)
|
)
|
||||||
# leave decode for caller, plus we want easy access to original path names...
|
# leave decode for caller, plus we want easy access to original path names...
|
||||||
self.dataset = self.dataset.cast_column(image_key, datasets.Image(decode=False))
|
self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False))
|
||||||
|
|
||||||
self.image_key = image_key
|
self.image_key = input_key
|
||||||
self.label_key = target_key
|
self.label_key = target_key
|
||||||
self.remap_class = False
|
self.remap_class = False
|
||||||
if class_map:
|
if class_map:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user