mmpretrain/mmcls/core/export/test.py

97 lines
3.4 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import onnxruntime as ort
import torch
from mmcls.models.classifiers import BaseClassifier
class ONNXRuntimeClassifier(BaseClassifier):
"""Wrapper for classifier's inference with ONNXRuntime."""
def __init__(self, onnx_file, class_names, device_id):
super(ONNXRuntimeClassifier, self).__init__()
sess = ort.InferenceSession(onnx_file)
providers = ['CPUExecutionProvider']
options = [{}]
is_cuda_available = ort.get_device() == 'GPU'
if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider')
options.insert(0, {'device_id': device_id})
sess.set_providers(providers, options)
self.sess = sess
self.CLASSES = class_names
self.device_id = device_id
self.io_binding = sess.io_binding()
self.output_names = [_.name for _ in sess.get_outputs()]
self.is_cuda_available = is_cuda_available
def simple_test(self, img, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def forward_train(self, imgs, **kwargs):
raise NotImplementedError('This method is not implemented.')
def forward_test(self, imgs, img_metas, **kwargs):
input_data = imgs
# set io binding for inputs/outputs
device_type = 'cuda' if self.is_cuda_available else 'cpu'
if not self.is_cuda_available:
input_data = input_data.cpu()
self.io_binding.bind_input(
name='input',
device_type=device_type,
device_id=self.device_id,
element_type=np.float32,
shape=input_data.shape,
buffer_ptr=input_data.data_ptr())
for name in self.output_names:
self.io_binding.bind_output(name)
# run session to get outputs
self.sess.run_with_iobinding(self.io_binding)
results = self.io_binding.copy_outputs_to_cpu()[0]
return list(results)
class TensorRTClassifier(BaseClassifier):
def __init__(self, trt_file, class_names, device_id):
super(TensorRTClassifier, self).__init__()
from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
try:
load_tensorrt_plugin()
except (ImportError, ModuleNotFoundError):
warnings.warn('If input model has custom op from mmcv, \
you may have to build mmcv with TensorRT from source.')
model = TRTWraper(
trt_file, input_names=['input'], output_names=['probs'])
self.model = model
self.device_id = device_id
self.CLASSES = class_names
def simple_test(self, img, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def forward_train(self, imgs, **kwargs):
raise NotImplementedError('This method is not implemented.')
def forward_test(self, imgs, img_metas, **kwargs):
input_data = imgs
with torch.cuda.device(self.device_id), torch.no_grad():
results = self.model({'input': input_data})['probs']
results = results.detach().cpu().numpy()
return list(results)