diff --git a/train.py b/train.py
index d95611ad..785b99e2 100755
--- a/train.py
+++ b/train.py
@@ -270,6 +270,8 @@ parser.add_argument('--apex-amp', action='store_true', default=False,
                     help='Use NVIDIA Apex AMP mixed precision')
 parser.add_argument('--native-amp', action='store_true', default=False,
                     help='Use Native Torch AMP mixed precision')
+parser.add_argument('--no-ddp-bb', action='store_true', default=False,
+                    help='Force broadcast buffers for native DDP to off.')
 parser.add_argument('--channels-last', action='store_true', default=False,
                     help='Use channels_last memory layout')
 parser.add_argument('--pin-mem', action='store_true', default=False,
@@ -463,7 +465,7 @@ def main():
         else:
             if args.local_rank == 0:
                 _logger.info("Using native Torch DistributedDataParallel.")
-            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.dist_bn)
+            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
         # NOTE: EMA model does not need to be wrapped by DDP
 
     # setup learning rate schedule and starting epoch
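
For context on what the new flag controls, below is a minimal standalone sketch of wrapping a model with native DDP the same way the patched line does. The argparse setup, toy model, and NCCL process-group initialization here are illustrative assumptions for a self-contained example, not part of this patch.

```python
# Minimal sketch, assuming the script is started by a distributed launcher
# (e.g. python -m torch.distributed.launch) that sets the usual env vars
# and passes --local_rank.
import argparse

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as NativeDDP

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')

# Toy model with a buffer-carrying layer (BatchNorm running mean/var).
model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64)).cuda()

# With broadcast_buffers=True (PyTorch's default), DDP re-broadcasts rank 0's
# buffers (e.g. BatchNorm running stats) to all ranks before every forward
# pass. Passing --no-ddp-bb sets broadcast_buffers=False, skipping that
# per-step synchronization.
model = NativeDDP(model, device_ids=[args.local_rank],
                  broadcast_buffers=not args.no_ddp_bb)
```

The flag is consumed like any other train.py argument, so it can simply be appended to an existing distributed launch command.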