Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Fix dtype log when default (None) is used w/o AMP
This commit is contained in:
parent 92f610c982
commit 1969528296
train.py (2 changed lines)
@@ -597,7 +597,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info(f'AMP not enabled. Training in {model_dtype}.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')

     # optionally resume from a checkpoint
     resume_epoch = None
@@ -192,7 +192,7 @@ def validate(args):
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
         _logger.info('Validating in mixed precision with native PyTorch AMP.')
     else:
-        _logger.info(f'Validating in {model_dtype}. AMP not enabled.')
+        _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')

     if args.fuser:
         set_jit_fuser(args.fuser)
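The fix is the same in both hunks: `model_dtype` defaults to None when no explicit model dtype is requested, so with AMP disabled the old message rendered as "Training in None." (or "Validating in None."). Using `model_dtype or torch.float32` makes the log fall back to the effective default dtype instead. Below is a minimal standalone sketch of the pattern; the `model_dtype` and `use_amp` values are placeholders for illustration, not timm's actual argument handling.

import logging

import torch

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

# Placeholder values: in timm these come from the script's command-line arguments.
model_dtype = None   # default: no explicit dtype override requested
use_amp = False      # AMP disabled, so the else-branch below runs

if use_amp:
    _logger.info('Using native Torch AMP. Training in mixed precision.')
else:
    # `None or torch.float32` evaluates to torch.float32, so the message now
    # reads "AMP not enabled. Training in torch.float32." instead of
    # "AMP not enabled. Training in None."
    _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')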