# EasyCV/tests/apis/test_export_blade.py
#
# 75 lines
# 2.8 KiB
# Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
import subprocess
import tempfile
import unittest
import numpy as np
import torch
from tests.ut_config import (PRETRAINED_MODEL_RESNET50,
PRETRAINED_MODEL_YOLOXS_EXPORT)
from easycv.apis.export import export
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.test_util import clean_up, get_tmp_dir
@unittest.skipIf(torch.__version__ != '1.8.1+cu102',
                 'Blade need another environment')
class ModelExportTest(unittest.TestCase):
    """Export tests for YOLOX-S with Blade optimization enabled.

    Every test exports ``PRETRAINED_MODEL_YOLOXS_EXPORT`` with
    ``export_blade=True`` and a different ``use_jit``/``end2end``
    combination. It then checks which artifact files were written next
    to the target path. The whole class is skipped unless the installed
    torch version is exactly ``1.8.1+cu102``.
    """

    # Detection config shared by every test in this class.
    YOLOX_CONFIG = 'configs/detection/yolox/yolox_s_8xb16_300e_coco.py'

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
        self.tmp_dir = get_tmp_dir()
        print('tmp dir %s' % self.tmp_dir)

    def tearDown(self):
        clean_up(self.tmp_dir)

    def _export_yolox(self, basename, *, use_jit, end2end):
        """Export the pretrained YOLOX-S checkpoint with Blade enabled.

        Args:
            basename: Target file name (without extension), placed under
                ``self.tmp_dir``.
            use_jit: Value stored as ``use_jit`` in ``cfg.export``.
            end2end: Value stored as ``end2end`` in ``cfg.export``.

        Returns:
            The target path prefix that was passed to ``export``.
        """
        cfg = mmcv_config_fromfile(self.YOLOX_CONFIG)
        cfg.export = dict(use_jit=use_jit, export_blade=True, end2end=end2end)
        target_path = os.path.join(self.tmp_dir, basename)
        export(cfg, PRETRAINED_MODEL_YOLOXS_EXPORT, target_path)
        return target_path

    def _assert_artifacts(self, target_path, *, expect_jit, expect_blade=True):
        """Assert presence (or absence) of the jit and blade artifact pairs.

        Each artifact kind is a pair: ``<target_path><ext>`` and its
        ``<target_path><ext>.config.json`` sidecar. Both files of a pair
        must match the expectation given for that kind.

        Raises:
            AssertionError: if any file's existence differs from the
                expectation.
        """
        for ext, expected in (('.jit', expect_jit), ('.blade', expect_blade)):
            for path in (target_path + ext, target_path + ext + '.config.json'):
                state = 'exist' if expected else 'be absent'
                self.assertEqual(
                    os.path.exists(path), expected,
                    msg=f'{path} should {state}')

    def test_export_yolox_blade(self):
        # TorchScript + Blade: both artifact pairs are expected.
        target_path = self._export_yolox(
            'export_yolox_s_epoch300_export', use_jit=True, end2end=False)
        self._assert_artifacts(target_path, expect_jit=True)

    def test_export_yolox_blade_nojit(self):
        # Blade without TorchScript: no .jit artifacts may be present.
        target_path = self._export_yolox(
            'export_yolox_s_epoch300_export', use_jit=False, end2end=False)
        self._assert_artifacts(target_path, expect_jit=False)

    def test_export_yolox_blade_end2end(self):
        # end2end=True export with TorchScript + Blade.
        target_path = self._export_yolox(
            'export_yolox_s_epoch300_end2end', use_jit=True, end2end=True)
        self._assert_artifacts(target_path, expect_jit=True)
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main(module='__main__')