Fix #2097 a small typo in train.py

This commit is contained in:
Ross Wightman 2024-04-10 09:40:14 -07:00
parent 195d12fdff
commit e25bbfceec

View File

@ -684,7 +684,7 @@ def main():
num_classes=args.num_classes
)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
assert not num_aug_splits # collate conflict (need to support de-interleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)
@ -693,7 +693,7 @@ def main():
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
# create data loaders w/ augmentation pipeiine
# create data loaders w/ augmentation pipeline
train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation']