diff --git a/train.py b/train.py index 7aab20a7..38fd3c6d 100755 --- a/train.py +++ b/train.py @@ -102,6 +102,8 @@ group.add_argument('--input-key', default=None, type=str, help='Dataset key for input images.') group.add_argument('--target-key', default=None, type=str, help='Dataset key for target labels.') +group.add_argument('--dataset-trust-remote-code', action='store_true', default=False, + help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.') # Model parameters group = parser.add_argument_group('Model parameters') @@ -641,6 +643,7 @@ def main(): input_key=args.input_key, target_key=args.target_key, num_samples=args.train_num_samples, + trust_remote_code=args.dataset_trust_remote_code, ) if args.val_split: @@ -656,6 +659,7 @@ def main(): input_key=args.input_key, target_key=args.target_key, num_samples=args.val_num_samples, + trust_remote_code=args.dataset_trust_remote_code, ) # setup mixup / cutmix diff --git a/validate.py b/validate.py index 602111bb..159bd0b1 100755 --- a/validate.py +++ b/validate.py @@ -66,6 +66,8 @@ parser.add_argument('--input-img-mode', default=None, type=str, help='Dataset image conversion mode for input images.') parser.add_argument('--target-key', default=None, type=str, help='Dataset key for target labels.') +parser.add_argument('--dataset-trust-remote-code', action='store_true', default=False, + help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.') parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', help='model architecture (default: dpn92)') @@ -268,6 +270,7 @@ def validate(args): input_key=args.input_key, input_img_mode=input_img_mode, target_key=args.target_key, + trust_remote_code=args.dataset_trust_remote_code, ) if args.valid_labels: