EasyCV/tests/test_tools/test_export.py
lostkevin 31897984d8
implement onnx export for inception3/4, resnext, mobilenetv2 (#346)
* add inceptionv4 backbone/training settings
* add converted backbone, top-1 acc 80.08
2024-07-18 16:52:56 +08:00

139 lines
4.8 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import sys
import unittest
import numpy as np
import onnxruntime
import torch
from easycv.models import build_model
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile, rebuild_config
from easycv.utils.test_util import run_in_subprocess
# Make modules that live next to this test file importable.
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
logging.basicConfig(level=logging.INFO)

# Scratch directory (relative to the current working directory) that receives
# the exported models.
WORK_DIRECTORY = 'work_dir3'

# Default arguments for ``build_cmd``. Tests replace ``config_file`` and may
# override other entries (e.g. ``checkpoint``) through ``run_test`` kwargs.
BASIC_EXPORT_CONFIGS = dict(
    config_file=None,
    checkpoint='dummy',
    output_filename=WORK_DIRECTORY + '/test_out.pth',
    user_config_params=['--export.export_type', 'onnx'],
)
def build_cmd(export_configs, MODEL_TYPE) -> str:
    """Build the command line that runs ``tools/export.py``.

    Args:
        export_configs: mapping with the keys ``config_file``, ``checkpoint``,
            ``output_filename`` and ``user_config_params`` (a sequence of
            command-line tokens forwarded after ``--user_config_params``).
        MODEL_TYPE: value passed to ``--model_type``.

    Returns:
        The command as one string. Every argument is shell-quoted, so paths
        containing spaces or shell metacharacters stay single arguments;
        arguments made only of shell-safe characters are emitted unchanged.
        Note that each entry of ``user_config_params`` is treated as exactly
        one token.
    """
    import shlex

    args = [
        'python',
        'tools/export.py',
        # str() mirrors the original f-string formatting (e.g. None -> 'None').
        str(export_configs['config_file']),
        str(export_configs['checkpoint']),
        str(export_configs['output_filename']),
        '--model_type',
        str(MODEL_TYPE),
        '--user_config_params',
        *(str(param) for param in export_configs['user_config_params']),
    ]
    return ' '.join(shlex.quote(arg) for arg in args)
class ExportTest(unittest.TestCase):
    """ONNX export tests for classification models.

    Each test exports a model with ``tools/export.py`` in a subprocess, builds
    the same model in PyTorch from the same config, feeds both one random
    input and checks that the outputs match in shape and value.
    """

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        os.makedirs(WORK_DIRECTORY, exist_ok=True)

    def tearDown(self):
        super().tearDown()

    def run_test(self,
                 CONFIG_FILE,
                 MODEL_TYPE,
                 img_size: int = 224,
                 *,
                 rtol: float = 1e-05,
                 atol: float = 1e-08,
                 **override_configs):
        """Export a model to ONNX and compare it with the PyTorch model.

        Args:
            CONFIG_FILE: path of the model config given to the export tool.
            MODEL_TYPE: value forwarded to ``--model_type``.
            img_size: height and width of the square random input.
            rtol: relative tolerance for ``torch.allclose`` (torch default).
            atol: absolute tolerance for ``torch.allclose`` (torch default).
            **override_configs: entries overriding ``BASIC_EXPORT_CONFIGS``.
                When ``checkpoint`` is given, it is also loaded (non-strictly)
                into the reference PyTorch model.
        """
        configs = BASIC_EXPORT_CONFIGS.copy()
        configs['config_file'] = CONFIG_FILE
        configs.update(override_configs)
        onnx_path = configs['output_filename'] + '.onnx'

        # Every test writes to the same path and WORK_DIRECTORY is never
        # cleaned, so remove any leftover file first. Otherwise a failed
        # export could be validated against a stale model from another test.
        try:
            os.remove(onnx_path)
        except FileNotFoundError:
            pass

        cmd = build_cmd(configs, MODEL_TYPE)
        logging.info(f'Export with commands: {cmd}')
        run_in_subprocess(cmd)
        self.assertTrue(
            os.path.isfile(onnx_path),
            f'Export did not produce {onnx_path}; command was: {cmd}')

        # Rebuild the reference model from the same (possibly overridden)
        # config; pretrained weights are disabled if the config has that key.
        cfg = mmcv_config_fromfile(configs['config_file'])
        cfg = rebuild_config(cfg, configs['user_config_params'])
        if hasattr(cfg.model, 'pretrained'):
            cfg.model.pretrained = False
        torch_model = build_model(cfg.model).eval()
        if 'checkpoint' in override_configs:
            load_checkpoint(
                torch_model,
                override_configs['checkpoint'],
                strict=False,
                logger=logging.getLogger())

        # Pin the CPU provider: ORT >= 1.9 requires an explicit provider list
        # on builds with several execution providers, and the torch reference
        # also runs on CPU.
        session = onnxruntime.InferenceSession(
            onnx_path, providers=['CPUExecutionProvider'])
        input_tensor = torch.randn((1, 3, img_size, img_size))
        with torch.no_grad():
            torch_output = torch_model(input_tensor, mode='test')['prob']
        onnx_output = session.run(
            [session.get_outputs()[0].name],
            {session.get_inputs()[0].name: np.array(input_tensor)})
        if isinstance(onnx_output, list):
            onnx_output = onnx_output[0]
        onnx_output = torch.tensor(onnx_output)

        is_same_shape = torch_output.shape == onnx_output.shape
        self.assertTrue(
            is_same_shape,
            f'The shapes of the two outputs are mismatch, got {torch_output.shape} and {onnx_output.shape}'
        )

        is_allclose = torch.allclose(
            torch_output, onnx_output, rtol=rtol, atol=atol)
        torch_out_minmax = f'{float(torch_output.min())}~{float(torch_output.max())}'
        onnx_out_minmax = f'{float(onnx_output.min())}~{float(onnx_output.max())}'
        max_abs_diff = float((torch_output - onnx_output).abs().max())
        info_msg = f'got avg: {float(torch_output.mean())} and {float(onnx_output.mean())},'
        info_msg += f' and range: {torch_out_minmax} and {onnx_out_minmax},'
        info_msg += f' max abs diff: {max_abs_diff} (rtol={rtol}, atol={atol})'
        self.assertTrue(
            is_allclose,
            f'The values between the two outputs are mismatch, {info_msg}')

    def test_inceptionv3(self):
        CONFIG_FILE = 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py'
        self.run_test(CONFIG_FILE, 'CLASSIFICATION_INCEPTIONV3', 299)

    def test_inceptionv4(self):
        CONFIG_FILE = 'configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py'
        self.run_test(CONFIG_FILE, 'CLASSIFICATION_INCEPTIONV4', 299)

    def test_resnext50(self):
        CONFIG_FILE = 'configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py'
        self.run_test(
            CONFIG_FILE,
            'CLASSIFICATION_RESNEXT',
            checkpoint=
            'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext50-32x4d/epoch_100.pth'
        )

    def test_mobilenetv2(self):
        CONFIG_FILE = 'configs/classification/imagenet/mobilenet/mobilenetv2.py'
        # NOTE(review): 'M0BILENET' is spelled with a zero, unlike the other
        # model types here; confirm what tools/export.py expects before
        # changing it.
        self.run_test(
            CONFIG_FILE,
            'CLASSIFICATION_M0BILENET',
            checkpoint=
            'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/mobilenetv2/mobilenet_v2.pth'
        )
# Allow running this file directly (python <this file>) in addition to
# discovery by a test runner.
if __name__ == '__main__':
    unittest.main()