mirror of https://github.com/alibaba/EasyCV.git
75 lines
2.8 KiB
Python
75 lines
2.8 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import json
|
|
import os
|
|
import subprocess
|
|
import tempfile
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import torch
|
|
from tests.ut_config import (PRETRAINED_MODEL_RESNET50,
|
|
PRETRAINED_MODEL_YOLOXS_EXPORT)
|
|
|
|
from easycv.apis.export import export
|
|
from easycv.utils.config_tools import mmcv_config_fromfile
|
|
from easycv.utils.test_util import clean_up, get_tmp_dir
|
|
|
|
|
|
@unittest.skipIf(torch.__version__ != '1.8.1+cu102',
                 'Blade need another environment')
class ModelExportTest(unittest.TestCase):
    """Blade export tests for the YOLOX-S detector.

    Each test exports the pretrained YOLOX-S checkpoint with
    ``export_blade=True`` into a fresh temporary directory and checks which
    artifact files were written next to the target path prefix. The whole
    class is skipped unless ``torch.__version__`` is exactly ``1.8.1+cu102``.
    """

    # Relative path: the tests assume they are run from the repository root.
    _YOLOX_CONFIG = 'configs/detection/yolox/yolox_s_8xb16_300e_coco.py'

    def setUp(self):
        print(f'Testing {type(self).__name__}.{self._testMethodName}')
        self.tmp_dir = get_tmp_dir()
        print(f'tmp dir {self.tmp_dir}')

    def tearDown(self):
        clean_up(self.tmp_dir)

    def _export_yolox_blade(self, target_name, *, use_jit, end2end):
        """Export YOLOX-S with Blade enabled and return the path prefix.

        Args:
            target_name: File-name prefix for the exported artifacts; it is
                placed inside ``self.tmp_dir``.
            use_jit: Value stored as ``cfg.export['use_jit']``.
            end2end: Value stored as ``cfg.export['end2end']``.

        Returns:
            The target path prefix handed to ``export``; artifacts are
            checked at this prefix plus suffixes such as ``.blade``.
        """
        cfg = mmcv_config_fromfile(self._YOLOX_CONFIG)
        cfg.export = dict(
            use_jit=use_jit, export_blade=True, end2end=end2end)
        target_path = os.path.join(self.tmp_dir, target_name)
        export(cfg, PRETRAINED_MODEL_YOLOXS_EXPORT, target_path)
        return target_path

    def _check_artifacts(self, target_path, *, expect_jit):
        """Assert which exported files exist for ``target_path``.

        The ``.jit`` / ``.jit.config.json`` files must exist iff
        ``expect_jit`` is true; the ``.blade`` / ``.blade.config.json``
        files must always exist. Failure messages name the offending path.
        """
        for suffix in ('.jit', '.jit.config.json'):
            path = target_path + suffix
            if expect_jit:
                self.assertTrue(
                    os.path.exists(path), f'missing artifact {path}')
            else:
                self.assertFalse(
                    os.path.exists(path), f'unexpected artifact {path}')
        for suffix in ('.blade', '.blade.config.json'):
            path = target_path + suffix
            self.assertTrue(os.path.exists(path), f'missing artifact {path}')

    def test_export_yolox_blade(self):
        target_path = self._export_yolox_blade(
            'export_yolox_s_epoch300_export', use_jit=True, end2end=False)
        self._check_artifacts(target_path, expect_jit=True)

    def test_export_yolox_blade_nojit(self):
        # With use_jit=False no .jit artifacts may be produced.
        target_path = self._export_yolox_blade(
            'export_yolox_s_epoch300_export', use_jit=False, end2end=False)
        self._check_artifacts(target_path, expect_jit=False)

    def test_export_yolox_blade_end2end(self):
        target_path = self._export_yolox_blade(
            'export_yolox_s_epoch300_end2end', use_jit=True, end2end=True)
        self._check_artifacts(target_path, expect_jit=True)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the test cases in this module when it is executed as a script;
    # module='__main__' is unittest.main's default, spelled out for clarity.
    unittest.main(module='__main__')
|