mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
139 lines
4.8 KiB
Python
139 lines
4.8 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import logging
|
|
import os
|
|
import sys
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import onnxruntime
|
|
import torch
|
|
|
|
from easycv.models import build_model
|
|
from easycv.utils.checkpoint import load_checkpoint
|
|
from easycv.utils.config_tools import mmcv_config_fromfile, rebuild_config
|
|
from easycv.utils.test_util import run_in_subprocess
|
|
|
|
# Put the directory containing this test file on the import path.
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
logging.basicConfig(level=logging.INFO)

# Relative scratch directory (created in setUp) that receives export outputs.
WORK_DIRECTORY = 'work_dir3'

# Default arguments for ``build_cmd``. ``ExportTest.run_test`` takes a shallow
# copy, sets ``config_file`` and applies per-test overrides (e.g. ``checkpoint``).
# 'dummy' is the checkpoint value sent to the exporter when a test supplies
# none; how tools/export.py interprets it is not visible here.
BASIC_EXPORT_CONFIGS = dict(
    config_file=None,
    checkpoint='dummy',
    output_filename=f'{WORK_DIRECTORY}/test_out.pth',
    user_config_params=['--export.export_type', 'onnx'],
)
|
|
|
|
|
|
def build_cmd(export_configs, MODEL_TYPE) -> str:
    """Assemble the command line that runs ``tools/export.py``.

    Args:
        export_configs: mapping providing ``config_file``, ``checkpoint`` and
            ``output_filename`` (emitted as positional arguments, in that
            order) plus ``user_config_params``, a sequence of strings appended
            after ``--user_config_params``.
        MODEL_TYPE: value passed through the ``--model_type`` option.

    Returns:
        The whole command as one string. Pieces are separated by single spaces
        and are NOT shell-quoted.
    """
    pieces = ['python tools/export.py']
    pieces.extend(
        f'{export_configs[key]}'
        for key in ('config_file', 'checkpoint', 'output_filename'))
    pieces.append(f'--model_type {MODEL_TYPE}')
    # Keeps a trailing space when there are no user params, same as before.
    pieces.append('--user_config_params ' +
                  ' '.join(export_configs['user_config_params']))
    return ' '.join(pieces)
|
|
|
|
|
|
class ExportTest(unittest.TestCase):
    """Check the ONNX export of several classification models.

    Each test exports a model via ``tools/export.py`` in a subprocess, rebuilds
    the same config in-process with PyTorch, feeds both one random input and
    asserts that the ONNX Runtime output matches the PyTorch ``'prob'`` output
    in shape and value.
    """

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
        os.makedirs(WORK_DIRECTORY, exist_ok=True)

    def tearDown(self):
        # Exported files in WORK_DIRECTORY are intentionally left in place.
        super().tearDown()

    def run_test(self,
                 CONFIG_FILE,
                 MODEL_TYPE,
                 img_size: int = 224,
                 **override_configs):
        """Export ``CONFIG_FILE`` to ONNX and compare it against PyTorch.

        Args:
            CONFIG_FILE: path of the EasyCV config describing the model.
            MODEL_TYPE: value forwarded to ``--model_type`` of the exporter.
            img_size: side length of the random input, whose shape is
                ``(1, 3, img_size, img_size)``.
            **override_configs: entries replacing the defaults from
                ``BASIC_EXPORT_CONFIGS``. When ``checkpoint`` is given, the
                same weights are also loaded into the in-process model.

        Raises:
            AssertionError: if the two outputs differ in shape, or are not
                equal within ``torch.allclose``'s default tolerances.
        """
        configs = BASIC_EXPORT_CONFIGS.copy()
        configs['config_file'] = CONFIG_FILE
        configs.update(override_configs)

        cmd = build_cmd(configs, MODEL_TYPE)
        logging.info('Export with commands: %s', cmd)
        run_in_subprocess(cmd)

        cfg = mmcv_config_fromfile(configs['config_file'])
        cfg = rebuild_config(cfg, configs['user_config_params'])
        if hasattr(cfg.model, 'pretrained'):
            cfg.model.pretrained = False

        torch_model = build_model(cfg.model).eval()
        # NOTE(review): without a ``checkpoint`` override the exporter receives
        # the 'dummy' placeholder while this model keeps its own independent
        # initialisation. Unless both sides end up with identical weights, the
        # value comparison below is not meaningful for such tests -- verify.
        if 'checkpoint' in override_configs:
            load_checkpoint(
                torch_model,
                override_configs['checkpoint'],
                strict=False,
                logger=logging.getLogger())

        # The exporter writes ``<output_filename>.onnx`` next to the .pth name.
        session = onnxruntime.InferenceSession(configs['output_filename'] +
                                               '.onnx')
        input_tensor = torch.randn((1, 3, img_size, img_size))

        # Pure inference: skip autograd so no graph is recorded and the
        # reference output does not carry requires_grad=True.
        with torch.no_grad():
            torch_output = torch_model(input_tensor, mode='test')['prob']

        # Only the first ONNX output is fetched and compared with 'prob'.
        onnx_output = session.run(
            [session.get_outputs()[0].name],
            {session.get_inputs()[0].name: input_tensor.numpy()})
        if isinstance(onnx_output, list):
            onnx_output = onnx_output[0]
        onnx_output = torch.tensor(onnx_output)

        # Check shapes first: allclose would raise on non-broadcastable shapes.
        is_same_shape = torch_output.shape == onnx_output.shape
        self.assertTrue(
            is_same_shape,
            f'The shapes of the two outputs are mismatch, got {torch_output.shape} and {onnx_output.shape}'
        )

        is_allclose = torch.allclose(torch_output, onnx_output)

        # Summary statistics to make a value mismatch easier to diagnose.
        torch_out_minmax = f'{float(torch_output.min())}~{float(torch_output.max())}'
        onnx_out_minmax = f'{float(onnx_output.min())}~{float(onnx_output.max())}'
        info_msg = f'got avg: {float(torch_output.mean())} and {float(onnx_output.mean())},'
        info_msg += f' and range: {torch_out_minmax} and {onnx_out_minmax}'
        self.assertTrue(
            is_allclose,
            f'The values between the two outputs are mismatch, {info_msg}')

    def test_inceptionv3(self):
        CONFIG_FILE = 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py'
        self.run_test(CONFIG_FILE, 'CLASSIFICATION_INCEPTIONV3', 299)

    def test_inceptionv4(self):
        CONFIG_FILE = 'configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py'
        self.run_test(CONFIG_FILE, 'CLASSIFICATION_INCEPTIONV4', 299)

    def test_resnext50(self):
        CONFIG_FILE = 'configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py'
        self.run_test(
            CONFIG_FILE,
            'CLASSIFICATION_RESNEXT',
            checkpoint=
            'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext50-32x4d/epoch_100.pth'
        )

    def test_mobilenetv2(self):
        CONFIG_FILE = 'configs/classification/imagenet/mobilenet/mobilenetv2.py'
        # NOTE(review): 'M0BILENET' is spelled with a digit zero; presumably it
        # must match the key expected by tools/export.py -- verify before
        # renaming.
        self.run_test(
            CONFIG_FILE,
            'CLASSIFICATION_M0BILENET',
            # Fetched over HTTPS (same host as the resnext test) because the
            # downloaded file is deserialized as a checkpoint.
            checkpoint=
            'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/mobilenetv2/mobilenet_v2.pth'
        )
|
|
|
|
|
|
if __name__ == '__main__':
    # Discover and run every TestCase defined in this script.
    unittest.main(module='__main__')
|