Add --model-dtype (pure bfloat16/float16) support to inference.py

This commit is contained in:
Ross Wightman 2025-01-14 11:00:16 -08:00
parent 8ce197e33a
commit fc0609bcb6

View File

@ -105,6 +105,8 @@ parser.add_argument('--amp', action='store_true', default=False,
help='use Native AMP for mixed precision training')
parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
parser.add_argument('--model-dtype', default=None, type=str,
help='Model dtype override (non-AMP) (default: float32)')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
@ -161,9 +163,15 @@ def main():
device = torch.device(args.device)
model_dtype = None
if args.model_dtype:
assert args.model_dtype in ('float32', 'float16', 'bfloat16')
model_dtype = getattr(torch, args.model_dtype)
# resolve AMP arguments based on PyTorch / Apex availability
amp_autocast = suppress
if args.amp:
assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
assert args.amp_dtype in ('float16', 'bfloat16')
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
@ -201,7 +209,7 @@ def main():
if args.test_pool:
model, test_time_pool = apply_test_time_pool(model, data_config)
model = model.to(device)
model = model.to(device=device, dtype=model_dtype)
model.eval()
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
@ -237,6 +245,7 @@ def main():
use_prefetcher=True,
num_workers=workers,
device=device,
img_dtype=model_dtype or torch.float32,
**data_config,
)
@ -280,7 +289,7 @@ def main():
np_labels = to_label(np_indices)
all_labels.append(np_labels)
all_outputs.append(output.cpu().numpy())
all_outputs.append(output.float().cpu().numpy())
# measure elapsed time
batch_time.update(time.time() - end)