111 lines
4.4 KiB
Python
111 lines
4.4 KiB
Python
""" ONNX-runtime validation script
|
|
|
|
This script was created to verify accuracy and performance of exported ONNX
|
|
models running with the onnxruntime. It utilizes the PyTorch dataloader/processing
|
|
pipeline for a fair comparison against the originals.
|
|
|
|
Copyright 2020 Ross Wightman
|
|
"""
|
|
import argparse
|
|
import numpy as np
|
|
import onnxruntime
|
|
from timm.data import create_loader, resolve_data_config, create_dataset
|
|
from timm.utils import AverageMeter
|
|
import time
|
|
|
|
# Command-line interface of the validation script. `main()` parses it and also
# passes the parsed values, as a dict, to timm's `resolve_data_config`.
parser = argparse.ArgumentParser(description='ONNX Validation')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--onnx-input', default='', type=str, metavar='PATH',
                    help='path to onnx model/weights file')
parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH',
                    help='path to output optimized onnx graph')
parser.add_argument('--profile', action='store_true', default=False,
                    help='Enable profiler output.')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
# Help text fixed: the original read "std deviation of of dataset".
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
                    help='Override default crop pct of 0.875')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
|
|
|
|
|
|
def main():
    """Validate an exported ONNX model with onnxruntime and print the results.

    Parses the command line with the module-level ``parser``, creates an
    onnxruntime ``InferenceSession`` for ``--onnx-input``, feeds it batches
    from a timm loader built over the dataset at ``data``, and prints running
    and final top-1 / top-5 precision together with per-batch timing.

    Exits through ``parser.error`` when ``--onnx-input`` is not given.
    """
    args = parser.parse_args()
    # NOTE(review): gpu_id is not read anywhere in this function; it is kept in
    # case resolve_data_config() consumes it through vars(args) -- confirm.
    args.gpu_id = 0

    # Fail early with a clear message instead of letting onnxruntime try to
    # open an empty path (the option defaults to '').
    if not args.onnx_input:
        parser.error('--onnx-input is required')

    # Set graph optimization level
    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    if args.profile:
        sess_options.enable_profiling = True
    if args.onnx_output_opt:
        sess_options.optimized_model_filepath = args.onnx_output_opt

    session = onnxruntime.InferenceSession(args.onnx_input, sess_options)

    # Pre-processing settings are resolved from the command-line overrides.
    data_config = resolve_data_config(vars(args))
    loader = create_loader(
        create_dataset('', args.data),
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=data_config['crop_pct'],
    )

    input_name = session.get_inputs()[0].name

    batch_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()
    for i, (images, target) in enumerate(loader):
        num_samples = images.size(0)

        # run the net and return prediction; only the first returned output
        # is used as the per-class scores
        output = session.run([], {input_name: images.data.numpy()})
        output = output[0]

        # measure accuracy
        prec1, prec5 = accuracy_np(output, target.numpy())
        top1.update(prec1.item(), num_samples)
        top5.update(prec5.item(), num_samples)

        # measure elapsed time (seconds per batch, data loading included)
        batch_time.update(time.time() - end)
        end = time.time()

        # A print frequency of 0 disables progress output instead of raising
        # ZeroDivisionError in the modulo.
        if args.print_freq and i % args.print_freq == 0:
            # batch_time is in seconds, so milliseconds per sample is
            # 1000 * avg / batch size (the original multiplied by 100).
            print(
                f'Test: [{i}/{len(loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {num_samples / batch_time.avg:.3f}/s, '
                f'{1000 * batch_time.avg / num_samples:.3f} ms/sample) \t'
                f'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'
            )

    # Final summary: precision followed by the corresponding error rate.
    print(f' * Prec@1 {top1.avg:.3f} ({100-top1.avg:.3f}) Prec@5 {top5.avg:.3f} ({100.-top5.avg:.3f})')
|
|
|
|
|
|
def accuracy_np(output, target):
    """Return ``(top1, top5)`` precision in percent for one batch.

    ``output`` is a 2-D array of scores indexed ``[sample, class]`` and
    ``target`` holds one class index per sample. A sample counts towards
    top-1 when its highest-scoring class equals the target, and towards
    top-5 when the target is among its five highest-scoring classes (all
    classes when fewer than five exist).
    """
    # Class indices per sample, ordered from highest to lowest score.
    ranked = np.flip(np.argsort(output, axis=1), axis=1)

    # Each row has at most one match because the class indices are unique.
    top5_hits = (ranked[:, :5] == target[:, None]).sum(axis=1)
    top1_hits = ranked[:, 0] == target

    return 100 * top1_hits.mean(), 100 * top5_hits.mean()
|
|
|
|
|
|
# Script entry point: run the validation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
|