Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Merge pull request #2361 from huggingface/grodino-dataset_trust_remote

Dataset trust remote tweaks

Commit ea231079f5
@@ -103,6 +103,7 @@ class IterableImageDataset(data.IterableDataset):
             transform=None,
             target_transform=None,
             max_steps=None,
+            **kwargs,
     ):
         assert reader is not None
         if isinstance(reader, str):
@@ -121,6 +122,7 @@ class IterableImageDataset(data.IterableDataset):
                 input_key=input_key,
                 target_key=target_key,
                 max_steps=max_steps,
+                **kwargs,
             )
         else:
             self.reader = reader
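
The two hunks above let IterableImageDataset accept arbitrary extra keyword arguments and forward them to the reader it constructs, which is how a new reader-level option such as trust_remote_code can reach the reader without further signature changes. A minimal, self-contained sketch of that pass-through pattern (names here are illustrative, not timm's actual factory):

    # Stand-in for a reader factory; it just records what it was given.
    def make_reader(name, root=None, **kwargs):
        return {'name': name, 'root': root, **kwargs}

    class IterableDatasetSketch:
        def __init__(self, root, reader=None, **kwargs):
            assert reader is not None
            if isinstance(reader, str):
                # extra kwargs flow straight through to the reader
                self.reader = make_reader(reader, root=root, **kwargs)
            else:
                self.reader = reader

    ds = IterableDatasetSketch('/data', reader='hfids/some-dataset', trust_remote_code=True)
    print(ds.reader)  # {'name': 'hfids/some-dataset', 'root': '/data', 'trust_remote_code': True}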
@@ -74,34 +74,37 @@ def create_dataset(
         seed: int = 42,
         repeats: int = 0,
         input_img_mode: str = 'RGB',
+        trust_remote_code: bool = False,
         **kwargs,
 ):
     """ Dataset factory method
 
     In parentheses after each arg are the type of dataset supported for each arg, one of:
-      * folder - default, timm folder (or tar) based ImageDataset
-      * torch - torchvision based datasets
+      * Folder - default, timm folder (or tar) based ImageDataset
+      * Torch - torchvision based datasets
       * HFDS - Hugging Face Datasets
+      * HFIDS - Hugging Face Datasets Iterable (streaming mode, with IterableDataset)
       * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
       * WDS - Webdataset
-      * all - any of the above
+      * All - any of the above
 
     Args:
-        name: dataset name, empty is okay for folder based datasets
-        root: root folder of dataset (all)
-        split: dataset split (all)
-        search_split: search for split specific child fold from root so one can specify
-            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
-        class_map: specify class -> index mapping via text file or dict (folder)
-        load_bytes: load data, return images as undecoded bytes (folder)
-        download: download dataset if not present and supported (HFDS, TFDS, torch)
-        is_training: create dataset in train mode, this is different from the split.
-            For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS)
-        batch_size: batch size hint for (TFDS, WDS)
-        seed: seed for iterable datasets (TFDS, WDS)
-        repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
-        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS)
-        **kwargs: other args to pass to dataset
+        name: Dataset name, empty is okay for folder based datasets
+        root: Root folder of dataset (All)
+        split: Dataset split (All)
+        search_split: Search for split specific child fold from root so one can specify
+            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch)
+        class_map: Specify class -> index mapping via text file or dict (Folder)
+        load_bytes: Load data, return images as undecoded bytes (Folder)
+        download: Download dataset if not present and supported (HFIDS, TFDS, Torch)
+        is_training: Create dataset in train mode, this is different from the split.
+            For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS, HFIDS)
+        batch_size: Batch size hint for iterable datasets (TFDS, WDS, HFIDS)
+        seed: Seed for iterable datasets (TFDS, WDS, HFIDS)
+        repeats: Dataset repeats per iteration i.e. epoch (TFDS, WDS, HFIDS)
+        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS, HFIDS)
+        trust_remote_code: Trust remote code in Hugging Face Datasets if True (HFDS, HFIDS)
+        **kwargs: Other args to pass through to underlying Dataset and/or Reader classes
 
     Returns:
         Dataset object
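
With trust_remote_code exposed on the factory itself, callers using the Hugging Face backends can opt in per call. A hedged usage sketch (the dataset name and cache path are hypothetical; the 'hfds/' prefix selects the Hugging Face Datasets reader per the docstring above):

    from timm.data import create_dataset

    ds = create_dataset(
        'hfds/some-org/some-dataset',  # hypothetical HF Hub dataset
        root='/data/hf-cache',
        split='train',
        trust_remote_code=True,        # forwarded down to the HF reader
    )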
@@ -162,6 +165,7 @@ def create_dataset(
             split=split,
             class_map=class_map,
             input_img_mode=input_img_mode,
+            trust_remote_code=trust_remote_code,
             **kwargs,
         )
     elif name.startswith('hfids/'):
@@ -177,7 +181,8 @@ def create_dataset(
             repeats=repeats,
             seed=seed,
             input_img_mode=input_img_mode,
-            **kwargs
+            trust_remote_code=trust_remote_code,
+            **kwargs,
         )
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
@@ -48,7 +48,7 @@ class ReaderHfds(Reader):
         self.dataset = datasets.load_dataset(
             name,  # 'name' maps to path arg in hf datasets
             split=split,
-            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
+            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path if root set
             trust_remote_code=trust_remote_code
         )
         # leave decode for caller, plus we want easy access to original path names...
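
For context, trust_remote_code is the switch datasets.load_dataset uses to gate execution of a loading script shipped in a dataset's repo; recent releases of the datasets library will not run such scripts unless it is set. A minimal sketch outside timm, with a hypothetical dataset name:

    import datasets

    # Without trust_remote_code=True, newer `datasets` versions refuse to
    # execute the repo's custom loading script.
    ds = datasets.load_dataset(
        'some-org/script-based-dataset',  # hypothetical
        split='train',
        trust_remote_code=True,
    )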
@@ -44,6 +44,7 @@ class ReaderHfids(Reader):
             target_img_mode: str = '',
             shuffle_size: Optional[int] = None,
             num_samples: Optional[int] = None,
+            trust_remote_code: bool = False
     ):
         super().__init__()
         self.root = root
@@ -60,7 +61,11 @@ class ReaderHfids(Reader):
         self.target_key = target_key
         self.target_img_mode = target_img_mode
 
-        self.builder = datasets.load_dataset_builder(name, cache_dir=root)
+        self.builder = datasets.load_dataset_builder(
+            name,
+            cache_dir=root,
+            trust_remote_code=trust_remote_code,
+        )
         if download:
             self.builder.download_and_prepare()
 
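
The streaming reader goes through the builder API rather than load_dataset, so the flag has to be passed to load_dataset_builder as well. A rough standalone sketch of that sequence (dataset name and cache dir are hypothetical; the reader's own iteration details differ):

    import datasets

    builder = datasets.load_dataset_builder(
        'some-org/script-based-dataset',  # hypothetical
        cache_dir='/data/hf-cache',
        trust_remote_code=True,
    )
    builder.download_and_prepare()          # materialize the data locally
    ds = builder.as_dataset(split='train')  # map-style split, for this sketch only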
train.py (+4)
@@ -103,6 +103,8 @@ group.add_argument('--input-key', default=None, type=str,
                    help='Dataset key for input images.')
 group.add_argument('--target-key', default=None, type=str,
                    help='Dataset key for target labels.')
+group.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
+                   help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')
 
 # Model parameters
 group = parser.add_argument_group('Model parameters')
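
With the flag registered on the parser, opting in from the command line looks roughly like this (data path, dataset name, and model are illustrative):

    python train.py --data-dir /data --dataset hfds/some-org/some-dataset \
        --dataset-trust-remote-code --model resnet50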
@@ -653,6 +655,7 @@ def main():
         input_key=args.input_key,
         target_key=args.target_key,
         num_samples=args.train_num_samples,
+        trust_remote_code=args.dataset_trust_remote_code,
     )
 
     if args.val_split:
@@ -668,6 +671,7 @@ def main():
         input_key=args.input_key,
         target_key=args.target_key,
         num_samples=args.val_num_samples,
+        trust_remote_code=args.dataset_trust_remote_code,
     )
 
     # setup mixup / cutmix
@@ -66,6 +66,8 @@ parser.add_argument('--input-img-mode', default=None, type=str,
                     help='Dataset image conversion mode for input images.')
 parser.add_argument('--target-key', default=None, type=str,
                     help='Dataset key for target labels.')
+parser.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
+                    help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')
 
 parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
@@ -268,6 +270,7 @@ def validate(args):
         input_key=args.input_key,
         input_img_mode=input_img_mode,
         target_key=args.target_key,
+        trust_remote_code=args.dataset_trust_remote_code,
     )
 
     if args.valid_labels: