mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2361 from huggingface/grodino-dataset_trust_remote
Dataset trust remote tweaks
This commit is contained in:
commit
ea231079f5
@ -103,6 +103,7 @@ class IterableImageDataset(data.IterableDataset):
|
||||
transform=None,
|
||||
target_transform=None,
|
||||
max_steps=None,
|
||||
**kwargs,
|
||||
):
|
||||
assert reader is not None
|
||||
if isinstance(reader, str):
|
||||
@ -121,6 +122,7 @@ class IterableImageDataset(data.IterableDataset):
|
||||
input_key=input_key,
|
||||
target_key=target_key,
|
||||
max_steps=max_steps,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self.reader = reader
|
||||
|
@ -74,34 +74,37 @@ def create_dataset(
|
||||
seed: int = 42,
|
||||
repeats: int = 0,
|
||||
input_img_mode: str = 'RGB',
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
""" Dataset factory method
|
||||
|
||||
In parentheses after each arg are the type of dataset supported for each arg, one of:
|
||||
* folder - default, timm folder (or tar) based ImageDataset
|
||||
* torch - torchvision based datasets
|
||||
* Folder - default, timm folder (or tar) based ImageDataset
|
||||
* Torch - torchvision based datasets
|
||||
* HFDS - Hugging Face Datasets
|
||||
* HFIDS - Hugging Face Datasets Iterable (streaming mode, with IterableDataset)
|
||||
* TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
|
||||
* WDS - Webdataset
|
||||
* all - any of the above
|
||||
* All - any of the above
|
||||
|
||||
Args:
|
||||
name: dataset name, empty is okay for folder based datasets
|
||||
root: root folder of dataset (all)
|
||||
split: dataset split (all)
|
||||
search_split: search for split specific child fold from root so one can specify
|
||||
`imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
|
||||
class_map: specify class -> index mapping via text file or dict (folder)
|
||||
load_bytes: load data, return images as undecoded bytes (folder)
|
||||
download: download dataset if not present and supported (HFDS, TFDS, torch)
|
||||
is_training: create dataset in train mode, this is different from the split.
|
||||
For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS)
|
||||
batch_size: batch size hint for (TFDS, WDS)
|
||||
seed: seed for iterable datasets (TFDS, WDS)
|
||||
repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
|
||||
input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS)
|
||||
**kwargs: other args to pass to dataset
|
||||
name: Dataset name, empty is okay for folder based datasets
|
||||
root: Root folder of dataset (All)
|
||||
split: Dataset split (All)
|
||||
search_split: Search for split specific child fold from root so one can specify
|
||||
`imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch)
|
||||
class_map: Specify class -> index mapping via text file or dict (Folder)
|
||||
load_bytes: Load data, return images as undecoded bytes (Folder)
|
||||
download: Download dataset if not present and supported (HFIDS, TFDS, Torch)
|
||||
is_training: Create dataset in train mode, this is different from the split.
|
||||
For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS, HFIDS)
|
||||
batch_size: Batch size hint for iterable datasets (TFDS, WDS, HFIDS)
|
||||
seed: Seed for iterable datasets (TFDS, WDS, HFIDS)
|
||||
repeats: Dataset repeats per iteration i.e. epoch (TFDS, WDS, HFIDS)
|
||||
input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS, HFIDS)
|
||||
trust_remote_code: Trust remote code in Hugging Face Datasets if True (HFDS, HFIDS)
|
||||
**kwargs: Other args to pass through to underlying Dataset and/or Reader classes
|
||||
|
||||
Returns:
|
||||
Dataset object
|
||||
@ -162,6 +165,7 @@ def create_dataset(
|
||||
split=split,
|
||||
class_map=class_map,
|
||||
input_img_mode=input_img_mode,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
elif name.startswith('hfids/'):
|
||||
@ -177,7 +181,8 @@ def create_dataset(
|
||||
repeats=repeats,
|
||||
seed=seed,
|
||||
input_img_mode=input_img_mode,
|
||||
**kwargs
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
elif name.startswith('tfds/'):
|
||||
ds = IterableImageDataset(
|
||||
|
@ -48,7 +48,7 @@ class ReaderHfds(Reader):
|
||||
self.dataset = datasets.load_dataset(
|
||||
name, # 'name' maps to path arg in hf datasets
|
||||
split=split,
|
||||
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
|
||||
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path if root set
|
||||
trust_remote_code=trust_remote_code
|
||||
)
|
||||
# leave decode for caller, plus we want easy access to original path names...
|
||||
|
@ -44,6 +44,7 @@ class ReaderHfids(Reader):
|
||||
target_img_mode: str = '',
|
||||
shuffle_size: Optional[int] = None,
|
||||
num_samples: Optional[int] = None,
|
||||
trust_remote_code: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
self.root = root
|
||||
@ -60,7 +61,11 @@ class ReaderHfids(Reader):
|
||||
self.target_key = target_key
|
||||
self.target_img_mode = target_img_mode
|
||||
|
||||
self.builder = datasets.load_dataset_builder(name, cache_dir=root)
|
||||
self.builder = datasets.load_dataset_builder(
|
||||
name,
|
||||
cache_dir=root,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if download:
|
||||
self.builder.download_and_prepare()
|
||||
|
||||
|
4
train.py
4
train.py
@ -103,6 +103,8 @@ group.add_argument('--input-key', default=None, type=str,
|
||||
help='Dataset key for input images.')
|
||||
group.add_argument('--target-key', default=None, type=str,
|
||||
help='Dataset key for target labels.')
|
||||
group.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
|
||||
help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')
|
||||
|
||||
# Model parameters
|
||||
group = parser.add_argument_group('Model parameters')
|
||||
@ -653,6 +655,7 @@ def main():
|
||||
input_key=args.input_key,
|
||||
target_key=args.target_key,
|
||||
num_samples=args.train_num_samples,
|
||||
trust_remote_code=args.dataset_trust_remote_code,
|
||||
)
|
||||
|
||||
if args.val_split:
|
||||
@ -668,6 +671,7 @@ def main():
|
||||
input_key=args.input_key,
|
||||
target_key=args.target_key,
|
||||
num_samples=args.val_num_samples,
|
||||
trust_remote_code=args.dataset_trust_remote_code,
|
||||
)
|
||||
|
||||
# setup mixup / cutmix
|
||||
|
@ -66,6 +66,8 @@ parser.add_argument('--input-img-mode', default=None, type=str,
|
||||
help='Dataset image conversion mode for input images.')
|
||||
parser.add_argument('--target-key', default=None, type=str,
|
||||
help='Dataset key for target labels.')
|
||||
parser.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
|
||||
help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')
|
||||
|
||||
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
|
||||
help='model architecture (default: dpn92)')
|
||||
@ -268,6 +270,7 @@ def validate(args):
|
||||
input_key=args.input_key,
|
||||
input_img_mode=input_img_mode,
|
||||
target_key=args.target_key,
|
||||
trust_remote_code=args.dataset_trust_remote_code,
|
||||
)
|
||||
|
||||
if args.valid_labels:
|
||||
|
Loading…
x
Reference in New Issue
Block a user