Fix #2097 a small typo in train.py

This commit is contained in:
Ross Wightman 2024-04-10 09:40:14 -07:00
parent 195d12fdff
commit e25bbfceec

View File

@ -684,7 +684,7 @@ def main():
num_classes=args.num_classes num_classes=args.num_classes
) )
if args.prefetcher: if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) assert not num_aug_splits # collate conflict (need to support de-interleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args) collate_fn = FastCollateMixup(**mixup_args)
else: else:
mixup_fn = Mixup(**mixup_args) mixup_fn = Mixup(**mixup_args)
@ -693,7 +693,7 @@ def main():
if num_aug_splits > 1: if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
# create data loaders w/ augmentation pipeiine # create data loaders w/ augmentation pipeline
train_interpolation = args.train_interpolation train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation: if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation'] train_interpolation = data_config['interpolation']