mirror of https://github.com/facebookresearch/deit
Update main.py
Adding the distributed parameter to the main function for the data loader to work as expected.pull/209/head
parent
ee8893c806
commit
e2c74433f9
3
main.py
3
main.py
|
@ -178,6 +178,7 @@ def get_args_parser():
|
||||||
parser.set_defaults(pin_mem=True)
|
parser.set_defaults(pin_mem=True)
|
||||||
|
|
||||||
# distributed training parameters
|
# distributed training parameters
|
||||||
|
parser.add_argument('--distributed', action='store_true', defaul=False, help='Enabling distributed training')
|
||||||
parser.add_argument('--world_size', default=1, type=int,
|
parser.add_argument('--world_size', default=1, type=int,
|
||||||
help='number of distributed processes')
|
help='number of distributed processes')
|
||||||
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
||||||
|
@ -205,7 +206,7 @@ def main(args):
|
||||||
dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
|
dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
|
||||||
dataset_val, _ = build_dataset(is_train=False, args=args)
|
dataset_val, _ = build_dataset(is_train=False, args=args)
|
||||||
|
|
||||||
if True: # args.distributed:
|
if args.distributed:
|
||||||
num_tasks = utils.get_world_size()
|
num_tasks = utils.get_world_size()
|
||||||
global_rank = utils.get_rank()
|
global_rank = utils.get_rank()
|
||||||
if args.repeated_aug:
|
if args.repeated_aug:
|
||||||
|
|
Loading…
Reference in New Issue