Update main.py

Adds a --distributed parameter to the argument parser so the data loader in main() works as expected.
pull/209/head
Mehdi Yazdani 2023-01-17 16:30:58 -05:00 committed by GitHub
parent ee8893c806
commit e2c74433f9
1 changed file with 2 additions and 1 deletion


@@ -178,6 +178,7 @@ def get_args_parser():
     parser.set_defaults(pin_mem=True)

     # distributed training parameters
+    parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
     parser.add_argument('--world_size', default=1, type=int,
                         help='number of distributed processes')
     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
@@ -205,7 +206,7 @@ def main(args):
     dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
     dataset_val, _ = build_dataset(is_train=False, args=args)

-    if True:  # args.distributed:
+    if args.distributed:
         num_tasks = utils.get_world_size()
         global_rank = utils.get_rank()
         if args.repeated_aug:
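
For context, a minimal sketch of how a flag like --distributed typically gates the data-loader setup that the second hunk touches. The repository's helpers (build_dataset, utils.get_world_size, utils.get_rank, RASampler) are replaced here with plain argparse/torch stand-ins, so this illustrates the pattern rather than the project's verbatim code.

    # Sketch only: stand-in names, not the repository's actual helpers.
    import argparse

    import torch
    import torch.distributed as dist
    from torch.utils.data import DistributedSampler, RandomSampler, TensorDataset


    def get_args_parser():
        parser = argparse.ArgumentParser('training script', add_help=False)
        # distributed training parameters (mirrors the first hunk above)
        parser.add_argument('--distributed', action='store_true', default=False,
                            help='Enabling distributed training')
        parser.add_argument('--world_size', default=1, type=int,
                            help='number of distributed processes')
        parser.add_argument('--dist_url', default='env://',
                            help='url used to set up distributed training')
        return parser


    def build_train_sampler(dataset_train, args):
        """Pick a sampler from args.distributed instead of the old hard-coded
        `if True:` branch."""
        if args.distributed:
            # Stand-ins for utils.get_world_size() / utils.get_rank();
            # requires an initialized process group (e.g. launched via torchrun).
            num_tasks = dist.get_world_size()
            global_rank = dist.get_rank()
            return DistributedSampler(dataset_train, num_replicas=num_tasks,
                                      rank=global_rank, shuffle=True)
        # Single-process run: plain random sampling.
        return RandomSampler(dataset_train)


    if __name__ == '__main__':
        args = get_args_parser().parse_args()
        # Toy dataset standing in for build_dataset(is_train=True, args=args).
        dataset_train = TensorDataset(torch.randn(8, 3), torch.zeros(8))
        sampler_train = build_train_sampler(dataset_train, args)
        print(type(sampler_train).__name__)

With this change, running the script without --distributed falls back to the single-process sampler instead of always taking the distributed branch that the old `if True:` forced; passing --distributed re-enables that branch, assuming the distributed process group has been set up.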