Merge pull request #2361 from huggingface/grodino-dataset_trust_remote

Dataset trust remote tweaks
Ross Wightman 2024-12-06 12:06:56 -08:00 committed by GitHub
commit ea231079f5
6 changed files with 40 additions and 21 deletions


@@ -103,6 +103,7 @@ class IterableImageDataset(data.IterableDataset):
             transform=None,
             target_transform=None,
             max_steps=None,
+            **kwargs,
     ):
         assert reader is not None
         if isinstance(reader, str):
@@ -121,6 +122,7 @@ class IterableImageDataset(data.IterableDataset):
                 input_key=input_key,
                 target_key=target_key,
                 max_steps=max_steps,
+                **kwargs,
             )
         else:
             self.reader = reader
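With **kwargs now forwarded from the dataset to the reader factory, reader-specific options can ride along without a dedicated dataset argument. A minimal sketch, assuming a hypothetical HF dataset id:

    from timm.data import IterableImageDataset

    # trust_remote_code is not a named parameter of the dataset; it reaches
    # the Hugging Face reader through the new **kwargs pass-through
    ds = IterableImageDataset(
        root=None,
        reader='hfids/beans',     # hypothetical dataset id
        split='train',
        trust_remote_code=True,   # forwarded to the reader via **kwargs
    )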


@@ -74,34 +74,37 @@ def create_dataset(
         seed: int = 42,
         repeats: int = 0,
         input_img_mode: str = 'RGB',
+        trust_remote_code: bool = False,
         **kwargs,
 ):
     """ Dataset factory method
 
     In parentheses after each arg are the type of dataset supported for each arg, one of:
-      * folder - default, timm folder (or tar) based ImageDataset
-      * torch - torchvision based datasets
+      * Folder - default, timm folder (or tar) based ImageDataset
+      * Torch - torchvision based datasets
       * HFDS - Hugging Face Datasets
+      * HFIDS - Hugging Face Datasets Iterable (streaming mode, with IterableDataset)
       * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
       * WDS - Webdataset
-      * all - any of the above
+      * All - any of the above
 
     Args:
-        name: dataset name, empty is okay for folder based datasets
-        root: root folder of dataset (all)
-        split: dataset split (all)
-        search_split: search for split specific child folder from root so one can specify
-            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
-        class_map: specify class -> index mapping via text file or dict (folder)
-        load_bytes: load data, return images as undecoded bytes (folder)
-        download: download dataset if not present and supported (HFDS, TFDS, torch)
-        is_training: create dataset in train mode, this is different from the split.
-            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS, WDS)
-        batch_size: batch size hint for (TFDS, WDS)
-        seed: seed for iterable datasets (TFDS, WDS)
-        repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
-        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS)
-        **kwargs: other args to pass to dataset
+        name: Dataset name, empty is okay for folder based datasets
+        root: Root folder of dataset (All)
+        split: Dataset split (All)
+        search_split: Search for split specific child folder from root so one can specify
+            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch)
+        class_map: Specify class -> index mapping via text file or dict (Folder)
+        load_bytes: Load data, return images as undecoded bytes (Folder)
+        download: Download dataset if not present and supported (HFIDS, TFDS, Torch)
+        is_training: Create dataset in train mode, this is different from the split.
+            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS, WDS, HFIDS)
+        batch_size: Batch size hint for iterable datasets (TFDS, WDS, HFIDS)
+        seed: Seed for iterable datasets (TFDS, WDS, HFIDS)
+        repeats: Dataset repeats per iteration i.e. epoch (TFDS, WDS, HFIDS)
+        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (Folder, TFDS, WDS, HFDS, HFIDS)
+        trust_remote_code: Trust remote code in Hugging Face Datasets if True (HFDS, HFIDS)
+        **kwargs: Other args to pass through to underlying Dataset and/or Reader classes
 
     Returns:
         Dataset object
@@ -162,6 +165,7 @@ def create_dataset(
             split=split,
             class_map=class_map,
             input_img_mode=input_img_mode,
+            trust_remote_code=trust_remote_code,
             **kwargs,
         )
     elif name.startswith('hfids/'):
@@ -177,7 +181,8 @@ def create_dataset(
             repeats=repeats,
             seed=seed,
             input_img_mode=input_img_mode,
-            **kwargs
+            trust_remote_code=trust_remote_code,
+            **kwargs,
         )
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(


@@ -48,7 +48,7 @@ class ReaderHfds(Reader):
         self.dataset = datasets.load_dataset(
             name,  # 'name' maps to path arg in hf datasets
             split=split,
-            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
+            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path if root set
             trust_remote_code=trust_remote_code
         )
         # leave decode for caller, plus we want easy access to original path names...
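For reference, the flag maps directly onto the Hugging Face datasets API; a standalone sketch of the equivalent call (dataset id and cache path are hypothetical):

    import datasets

    ds = datasets.load_dataset(
        'beans',                  # hypothetical dataset id
        split='train',
        cache_dir='./hf_cache',   # explicit cache path, as timm prefers when root is set
        trust_remote_code=False,  # default stays off; enable only for trusted repos
    )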


@@ -44,6 +44,7 @@ class ReaderHfids(Reader):
             target_img_mode: str = '',
             shuffle_size: Optional[int] = None,
             num_samples: Optional[int] = None,
+            trust_remote_code: bool = False
     ):
         super().__init__()
         self.root = root
@@ -60,7 +61,11 @@ class ReaderHfids(Reader):
         self.target_key = target_key
         self.target_img_mode = target_img_mode
-        self.builder = datasets.load_dataset_builder(name, cache_dir=root)
+        self.builder = datasets.load_dataset_builder(
+            name,
+            cache_dir=root,
+            trust_remote_code=trust_remote_code,
+        )
         if download:
             self.builder.download_and_prepare()
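The streaming reader takes the builder route instead of load_dataset; a minimal sketch of that pattern with the same gate (dataset id is hypothetical):

    import datasets

    builder = datasets.load_dataset_builder(
        'beans',                  # hypothetical dataset id
        trust_remote_code=False,  # same opt-in as load_dataset
    )
    builder.download_and_prepare()          # fetch and cache, as done when download=True
    ds = builder.as_dataset(split='train')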


@@ -103,6 +103,8 @@ group.add_argument('--input-key', default=None, type=str,
                    help='Dataset key for input images.')
 group.add_argument('--target-key', default=None, type=str,
                    help='Dataset key for target labels.')
+group.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
+                   help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')
 
 # Model parameters
 group = parser.add_argument_group('Model parameters')
@@ -653,6 +655,7 @@ def main():
         input_key=args.input_key,
         target_key=args.target_key,
         num_samples=args.train_num_samples,
+        trust_remote_code=args.dataset_trust_remote_code,
     )
 
     if args.val_split:
@@ -668,6 +671,7 @@ def main():
             input_key=args.input_key,
             target_key=args.target_key,
             num_samples=args.val_num_samples,
+            trust_remote_code=args.dataset_trust_remote_code,
         )
 
     # setup mixup / cutmix
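From the command line the new flag is a simple opt-in switch applied to both train and eval dataset creation; a hypothetical invocation (dataset id and other arguments are placeholders):

    python train.py --data-dir ./data --dataset hfds/beans \
        --dataset-trust-remote-code --model resnet50 --num-classes 3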


@@ -66,6 +66,8 @@ parser.add_argument('--input-img-mode', default=None, type=str,
                     help='Dataset image conversion mode for input images.')
 parser.add_argument('--target-key', default=None, type=str,
                     help='Dataset key for target labels.')
+parser.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
+                    help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')
 
 parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
@@ -268,6 +270,7 @@ def validate(args):
         input_key=args.input_key,
         input_img_mode=input_img_mode,
         target_key=args.target_key,
+        trust_remote_code=args.dataset_trust_remote_code,
     )
 
     if args.valid_labels:
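validate.py gains the same switch, so standalone evaluation can opt in identically; a hypothetical invocation with placeholder arguments:

    python validate.py --data-dir ./data --dataset hfds/beans \
        --dataset-trust-remote-code --model resnet50 --pretrained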