mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2402 from JosuaRieder/fix_inference_csv_export
disable abbreviating csv inference output with ellipses
This commit is contained in:
commit
1572769059
15
inference.py
15
inference.py
@ -12,6 +12,7 @@ import os
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
from sys import maxsize
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -104,6 +105,8 @@ parser.add_argument('--amp', action='store_true', default=False,
|
||||
help='use Native AMP for mixed precision training')
|
||||
parser.add_argument('--amp-dtype', default='float16', type=str,
|
||||
help='lower precision AMP dtype (default: float16)')
|
||||
parser.add_argument('--model-dtype', default=None, type=str,
|
||||
help='Model dtype override (non-AMP) (default: float32)')
|
||||
parser.add_argument('--fuser', default='', type=str,
|
||||
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
||||
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
|
||||
@ -160,9 +163,15 @@ def main():
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
||||
model_dtype = None
|
||||
if args.model_dtype:
|
||||
assert args.model_dtype in ('float32', 'float16', 'bfloat16')
|
||||
model_dtype = getattr(torch, args.model_dtype)
|
||||
|
||||
# resolve AMP arguments based on PyTorch / Apex availability
|
||||
amp_autocast = suppress
|
||||
if args.amp:
|
||||
assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
|
||||
assert args.amp_dtype in ('float16', 'bfloat16')
|
||||
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
|
||||
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
|
||||
@ -200,7 +209,7 @@ def main():
|
||||
if args.test_pool:
|
||||
model, test_time_pool = apply_test_time_pool(model, data_config)
|
||||
|
||||
model = model.to(device)
|
||||
model = model.to(device=device, dtype=model_dtype)
|
||||
model.eval()
|
||||
if args.channels_last:
|
||||
model = model.to(memory_format=torch.channels_last)
|
||||
@ -236,6 +245,7 @@ def main():
|
||||
use_prefetcher=True,
|
||||
num_workers=workers,
|
||||
device=device,
|
||||
img_dtype=model_dtype or torch.float32,
|
||||
**data_config,
|
||||
)
|
||||
|
||||
@ -279,7 +289,7 @@ def main():
|
||||
np_labels = to_label(np_indices)
|
||||
all_labels.append(np_labels)
|
||||
|
||||
all_outputs.append(output.cpu().numpy())
|
||||
all_outputs.append(output.float().cpu().numpy())
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
@ -343,6 +353,7 @@ def main():
|
||||
|
||||
|
||||
def save_results(df, results_filename, results_format='csv', filename_col='filename'):
|
||||
np.set_printoptions(threshold=maxsize)
|
||||
results_filename += _FMT_EXT[results_format]
|
||||
if results_format == 'parquet':
|
||||
df.set_index(filename_col).to_parquet(results_filename)
|
||||
|
Loading…
x
Reference in New Issue
Block a user