From fc0609bcb63dfc95c3613f285665a6a07fd6fb51 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 14 Jan 2025 11:00:16 -0800 Subject: [PATCH] Add --model-dtype (pure bfloat16/float16) support to inference.py --- inference.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/inference.py b/inference.py index e18d6fc2..05f2d89a 100755 --- a/inference.py +++ b/inference.py @@ -105,6 +105,8 @@ parser.add_argument('--amp', action='store_true', default=False, help='use Native AMP for mixed precision training') parser.add_argument('--amp-dtype', default='float16', type=str, help='lower precision AMP dtype (default: float16)') +parser.add_argument('--model-dtype', default=None, type=str, + help='Model dtype override (non-AMP) (default: float32)') parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) @@ -161,9 +163,15 @@ def main(): device = torch.device(args.device) + model_dtype = None + if args.model_dtype: + assert args.model_dtype in ('float32', 'float16', 'bfloat16') + model_dtype = getattr(torch, args.model_dtype) + # resolve AMP arguments based on PyTorch / Apex availability amp_autocast = suppress if args.amp: + assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP' assert args.amp_dtype in ('float16', 'bfloat16') amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16 amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) @@ -201,7 +209,7 @@ def main(): if args.test_pool: model, test_time_pool = apply_test_time_pool(model, data_config) - model = model.to(device) + model = model.to(device=device, dtype=model_dtype) model.eval() if args.channels_last: model = model.to(memory_format=torch.channels_last) @@ -237,6 +245,7 @@ def main(): use_prefetcher=True, num_workers=workers, device=device, + img_dtype=model_dtype or torch.float32, **data_config, ) @@ -280,7 +289,7 @@ def main(): np_labels = to_label(np_indices) all_labels.append(np_labels) - all_outputs.append(output.cpu().numpy()) + all_outputs.append(output.float().cpu().numpy()) # measure elapsed time batch_time.update(time.time() - end)