Fix dtype log when default (None) is used w/o AMP

half_prec_trainval
Ross Wightman 2025-01-07 11:47:22 -08:00
parent 92f610c982
commit 1969528296
2 changed files with 2 additions and 2 deletions


@@ -597,7 +597,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info(f'AMP not enabled. Training in {model_dtype}.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')
 
     # optionally resume from a checkpoint
     resume_epoch = None


@@ -192,7 +192,7 @@ def validate(args):
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
         _logger.info('Validating in mixed precision with native PyTorch AMP.')
     else:
-        _logger.info(f'Validating in {model_dtype}. AMP not enabled.')
+        _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')
 
     if args.fuser:
         set_jit_fuser(args.fuser)
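
The one-line fix in both files relies on Python's `or` short-circuit: `model_dtype` is `None` when no explicit dtype was requested, in which case the model stays in PyTorch's default `torch.float32`, so `model_dtype or torch.float32` logs the dtype actually in effect. A minimal standalone sketch of the before/after log output (only `model_dtype` and the message format come from the diff; the rest is illustrative):

import torch

model_dtype = None  # default: no explicit dtype requested, model stays in torch.float32

# Before the fix, the f-string interpolated the raw value:
print(f'AMP not enabled. Training in {model_dtype}.')
# -> AMP not enabled. Training in None.

# After the fix, `None or torch.float32` evaluates to torch.float32:
print(f'AMP not enabled. Training in {model_dtype or torch.float32}.')
# -> AMP not enabled. Training in torch.float32.

# With an explicit dtype, the fallback is a no-op:
model_dtype = torch.bfloat16
print(f'AMP not enabled. Training in {model_dtype or torch.float32}.')
# -> AMP not enabled. Training in torch.bfloat16.

An alternative would be to resolve the dtype once up front (e.g. `model_dtype = model_dtype or torch.float32`), but keeping the fallback inside the log message leaves `None` available elsewhere as the "use framework default" sentinel.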