mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
commit
3f9959cdd2
4
train.py
4
train.py
@ -397,7 +397,7 @@ def main():
|
||||
# setup synchronized BatchNorm for distributed training
|
||||
if args.distributed and args.sync_bn:
|
||||
assert not args.split_bn
|
||||
if has_apex and use_amp != 'native':
|
||||
if has_apex and use_amp == 'apex':
|
||||
# Apex SyncBN preferred unless native amp is activated
|
||||
model = convert_syncbn_model(model)
|
||||
else:
|
||||
@ -451,7 +451,7 @@ def main():
|
||||
|
||||
# setup distributed training
|
||||
if args.distributed:
|
||||
if has_apex and use_amp != 'native':
|
||||
if has_apex and use_amp == 'apex':
|
||||
# Apex DDP preferred unless native amp is activated
|
||||
if args.local_rank == 0:
|
||||
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
|
||||
|
Loading…
x
Reference in New Issue
Block a user