Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Make broadcast_buffers disable its own flag for now (needs more testing on interaction with dist_bn)
This commit is contained in:
parent b1c2e3eb92
commit d9abfa48df
1 changed file: train.py (4 lines changed)
@@ -270,6 +270,8 @@ parser.add_argument('--apex-amp', action='store_true', default=False,
                     help='Use NVIDIA Apex AMP mixed precision')
 parser.add_argument('--native-amp', action='store_true', default=False,
                     help='Use Native Torch AMP mixed precision')
+parser.add_argument('--no-ddp-bb', action='store_true', default=False,
+                    help='Force broadcast buffers for native DDP to off.')
 parser.add_argument('--channels-last', action='store_true', default=False,
                     help='Use channels_last memory layout')
 parser.add_argument('--pin-mem', action='store_true', default=False,
@@ -463,7 +465,7 @@ def main():
     else:
         if args.local_rank == 0:
             _logger.info("Using native Torch DistributedDataParallel.")
-        model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.dist_bn)
+        model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
         # NOTE: EMA model does not need to be wrapped by DDP

     # setup learning rate schedule and starting epoch
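The change decouples buffer broadcasting from the --dist-bn option: previously, passing any --dist-bn value also turned DDP's broadcast_buffers off, while after this commit only the new --no-ddp-bb flag does. Below is a minimal, self-contained sketch of that behaviour change; it only mirrors the flag logic (the --dist-bn default shown here is an assumption, and the real train.py passes the result to NativeDDP rather than printing it):

import argparse

# Illustrative subset of train.py's flags; names match the diff,
# the --dist-bn default/help text is an assumption for this sketch.
parser = argparse.ArgumentParser()
parser.add_argument('--dist-bn', type=str, default='',
                    help='Distribute BatchNorm stats between nodes ("broadcast", "reduce", or "")')
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')

args = parser.parse_args(['--dist-bn', 'reduce'])  # example invocation

# Before this commit: broadcast_buffers=not args.dist_bn, so any non-empty
# --dist-bn value silently disabled buffer broadcasting in DDP.
old_broadcast_buffers = not args.dist_bn       # False here

# After this commit: broadcast_buffers is controlled only by --no-ddp-bb.
new_broadcast_buffers = not args.no_ddp_bb     # True unless --no-ddp-bb is passed

print(old_broadcast_buffers, new_broadcast_buffers)  # False True

This keeps buffer broadcasting on by default while the interaction with --dist-bn is tested separately, as noted in the commit message.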