Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00
Commit 3f9959cdd2
train.py (4 changed lines)
@@ -397,7 +397,7 @@ def main():
         # setup synchronized BatchNorm for distributed training
         if args.distributed and args.sync_bn:
             assert not args.split_bn
-            if has_apex and use_amp != 'native':
+            if has_apex and use_amp == 'apex':
                 # Apex SyncBN preferred unless native amp is activated
                 model = convert_syncbn_model(model)
             else:
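
Both hunks in this commit narrow the condition for preferring Apex from use_amp != 'native' to use_amp == 'apex', so the Apex code paths are taken only when Apex AMP is explicitly selected. Below is a minimal sketch of the SyncBN selection this first hunk sits in; args, use_amp, and model are assumed to come from the surrounding train.py scope, and the else branch is shown falling back to PyTorch's native SyncBatchNorm converter, as timm's train.py does.

    import torch

    # Apex availability probe, in the style of timm's top-of-file try/except
    try:
        from apex.parallel import convert_syncbn_model
        has_apex = True
    except ImportError:
        has_apex = False

    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN is preferred only when Apex AMP is actually in use
            model = convert_syncbn_model(model)
        else:
            # otherwise use PyTorch's native SyncBatchNorm converter
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)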
@@ -451,7 +451,7 @@ def main():
 
     # setup distributed training
     if args.distributed:
-        if has_apex and use_amp != 'native':
+        if has_apex and use_amp == 'apex':
             # Apex DDP preferred unless native amp is activated
             if args.local_rank == 0:
                 _logger.info("Using NVIDIA APEX DistributedDataParallel.")
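
The second hunk applies the same tightened condition to the distributed wrapper. A minimal sketch of that selection follows, again assuming the surrounding train.py scope; the ApexDDP alias, the delay_allreduce flag, and the native-DDP fallback follow timm's usual pattern but are illustrative here, not a verbatim excerpt of the file.

    import logging
    import torch

    _logger = logging.getLogger('train')

    # Apex availability probe, in the style of timm's top-of-file try/except
    try:
        from apex.parallel import DistributedDataParallel as ApexDDP
        has_apex = True
    except ImportError:
        has_apex = False

    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP is preferred only when Apex AMP is actually in use
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            # otherwise wrap with PyTorch's native DistributedDataParallel
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.local_rank])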