EasyCV/tests/apis/test_export_blade.py
zouxinyi0625 9ad1a8f448
Fix YOLOX-Edge problem and update yolox-pai tutorial (#182)
* complete yolox introduction

* fix env problem with nni2.6.1

* fix yolox-edge problem

* add docker tutorial

* add blade test

* support old export yolox model

Co-authored-by: wuziheng <zhoulou.wzh@alibaba-inc.com>
2022-09-06 17:11:12 +08:00

67 lines
2.1 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
import subprocess
import tempfile
import unittest
import numpy as np
import torch
from tests.ut_config import (PRETRAINED_MODEL_RESNET50,
PRETRAINED_MODEL_YOLOXS_EXPORT)
from easycv.apis.export import export
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.test_util import clean_up, get_tmp_dir
@unittest.skipIf(torch.__version__ != '1.8.1+cu102',
                 'Blade need another environment')
class ModelExportTest(unittest.TestCase):
    """Smoke tests for exporting a YOLOX-S model with ``export_type='blade'``.

    Each test exports ``PRETRAINED_MODEL_YOLOXS_EXPORT`` into a per-test
    temporary directory and checks that the expected artifact files exist.
    The whole class is skipped unless ``torch.__version__`` is exactly
    ``'1.8.1+cu102'``.
    """

    # Config shared by every YOLOX blade export test in this class.
    _YOLOX_CONFIG = 'configs/detection/yolox/yolox_s_8xb16_300e_coco.py'
    # Path prefix (inside the temp dir) that export artifacts are written to.
    _TARGET_NAME = 'export_yolox_s_epoch300_export'

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = get_tmp_dir()
        print('tmp dir %s' % self.tmp_dir)

    def tearDown(self):
        clean_up(self.tmp_dir)

    def _export_yolox_blade(self, preprocess_jit):
        """Export the pretrained YOLOX-S checkpoint via blade.

        Args:
            preprocess_jit: value stored as ``cfg.export['preprocess_jit']``.
                TensorRT EfficientNMS is always disabled.

        Returns:
            The ``target_path`` prefix passed to ``export``; artifacts are
            expected at this prefix plus a file suffix.
        """
        cfg = mmcv_config_fromfile(self._YOLOX_CONFIG)
        cfg.export = dict(
            export_type='blade',
            preprocess_jit=preprocess_jit,
            use_trt_efficientnms=False)
        target_path = os.path.join(self.tmp_dir, self._TARGET_NAME)
        export(cfg, PRETRAINED_MODEL_YOLOXS_EXPORT, target_path)
        return target_path

    def _assert_artifacts_exist(self, target_path, suffixes):
        """Assert that ``target_path + suffix`` exists for every suffix."""
        for suffix in suffixes:
            path = target_path + suffix
            # Name the missing file so a failure is actionable.
            self.assertTrue(
                os.path.exists(path), f'missing export artifact: {path}')

    def test_export_yolox_blade_nopre_notrt(self):
        """Blade export without JIT preprocessing and without TRT NMS."""
        target_path = self._export_yolox_blade(preprocess_jit=False)
        self._assert_artifacts_exist(target_path,
                                     ('.blade', '.blade.config.json'))

    def test_export_yolox_blade_pre_notrt(self):
        """Blade export with JIT preprocessing and without TRT NMS.

        Enabling ``preprocess_jit`` is expected to also write a
        ``.preprocess`` artifact.
        """
        target_path = self._export_yolox_blade(preprocess_jit=True)
        self._assert_artifacts_exist(
            target_path, ('.blade', '.blade.config.json', '.preprocess'))
if __name__ == '__main__':
    # Run the tests in this module when the file is executed directly.
    unittest.main()