EasyCV/tests/apis/test_export.py

125 lines
4.8 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
import subprocess
import sys
import tempfile
import unittest

import numpy as np
import torch
from tests.ut_config import (IMAGENET_LABEL_TXT, PRETRAINED_MODEL_MOCO,
                             PRETRAINED_MODEL_RESNET50,
                             PRETRAINED_MODEL_YOLOXS_EXPORT)

from easycv.apis.export import export
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.test_util import clean_up, get_tmp_dir
class ModelExportTest(unittest.TestCase):
    """Tests for ``easycv.apis.export.export`` and the ``tools/export.py`` CLI.

    Each test gets a fresh temporary directory (``self.tmp_dir``) that is
    removed in ``tearDown``; every artifact a test writes should live there.
    """

    # Config files shared by several tests (paths relative to the repo root,
    # which is assumed to be the working directory when tests run).
    YOLOX_S_CONFIG = 'configs/detection/yolox/yolox_s_8xb16_300e_coco.py'
    RESNET50_CONFIG = (
        'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py')

    def setUp(self):
        print(f'Testing {type(self).__name__}.{self._testMethodName}')
        self.tmp_dir = get_tmp_dir()
        print(f'tmp dir {self.tmp_dir}')

    def tearDown(self):
        clean_up(self.tmp_dir)

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _run_export_cli(self, config_file, ori_ckpt, target_path):
        """Run ``tools/export.py`` in a subprocess and assert it exits with 0.

        ``sys.executable`` is used instead of a bare ``python`` so the child
        runs under the same interpreter/environment as the test runner, and
        arguments are passed as a list so no shell parsing is involved.
        stderr is merged into stdout and reported in the failure message.
        """
        proc = subprocess.run(
            [sys.executable, 'tools/export.py', config_file, ori_ckpt,
             target_path],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True)
        self.assertEqual(proc.returncode, 0,
                         f'export model failed:\n{proc.stdout}')

    def _assert_files_exist(self, *paths):
        """Assert every path in ``paths`` exists, naming the first missing."""
        for path in paths:
            self.assertTrue(
                os.path.exists(path), f'missing export file: {path}')

    def _export_yolox_jit(self, target_path, end2end):
        """Export YOLOX-S with JIT (no Blade) and check both output files."""
        cfg = mmcv_config_fromfile(self.YOLOX_S_CONFIG)
        cfg.export = dict(use_jit=True, export_blade=False, end2end=end2end)
        export(cfg, PRETRAINED_MODEL_YOLOXS_EXPORT, target_path)
        self._assert_files_exist(target_path + '.jit',
                                 target_path + '.jit.config.json')

    # ------------------------------------------------------------------
    # CLI export
    # ------------------------------------------------------------------
    def test_export_moco(self):
        config_file = (
            'configs/selfsup/mocov2/mocov2_rn50_8xb32_200e_tfrecord.py')
        ckpt_path = f'{self.tmp_dir}/moco_export.pth'
        self._run_export_cli(config_file, PRETRAINED_MODEL_MOCO, ckpt_path)

    def test_export_yolox(self):
        ckpt_path = f'{self.tmp_dir}/export_yolox_s_epoch300.pt'
        self._run_export_cli(self.YOLOX_S_CONFIG,
                             PRETRAINED_MODEL_YOLOXS_EXPORT, ckpt_path)

    # ------------------------------------------------------------------
    # Python API export
    # ------------------------------------------------------------------
    def test_export_yolox_jit(self):
        self._export_yolox_jit(
            f'{self.tmp_dir}/export_yolox_s_epoch300_export', end2end=False)

    def test_export_yolox_jit_end2end(self):
        self._export_yolox_jit(
            f'{self.tmp_dir}/export_yolox_s_epoch300_end2end', end2end=True)

    def test_export_classification_jit(self):
        cfg = mmcv_config_fromfile(self.RESNET50_CONFIG)
        cfg.model.pretrained = False
        # Swap in the JIT-scriptable ResNet variant so torch.jit can export it.
        cfg.model.backbone = dict(
            type='ResNetJIT',
            depth=50,
            out_indices=[4],
            norm_cfg=dict(type='BN'))
        cfg.export = dict(use_jit=True)
        target_ckpt = f'{self.tmp_dir}/classification.pth.jit'
        export(cfg, PRETRAINED_MODEL_RESNET50, target_ckpt)
        self._assert_files_exist(target_ckpt)

    def test_export_classification_and_inference(self):
        cfg = mmcv_config_fromfile(self.RESNET50_CONFIG)
        cfg.export = dict(use_jit=False)
        target_ckpt = f'{self.tmp_dir}/classification_export.pth'
        export(cfg, PRETRAINED_MODEL_RESNET50, target_ckpt)
        self._assert_files_exist(target_ckpt)

        # Imported lazily so the other tests do not depend on the predictor.
        from easycv.predictors.classifier import TorchClassifier
        classifier = TorchClassifier(
            target_ckpt, label_map_path=IMAGENET_LABEL_TXT)
        img = np.random.randint(0, 255, (256, 256, 3))
        results = classifier.predict([img])
        # Smoke check only: the exact output format of predict() is not
        # pinned here.
        self.assertIsNotNone(results)

    def test_export_cls_syncbn(self):
        """SyncBN in the training config must be exported as plain BN."""
        cfg = mmcv_config_fromfile(self.RESNET50_CONFIG)
        cfg.merge_from_dict({'model.backbone.norm_cfg.type': 'SyncBN'})
        # Dump the modified config inside tmp_dir so tearDown removes it;
        # previously it was written to an orphaned system temp path.
        cfg.dump(os.path.join(self.tmp_dir, 'syncbn_config.py'))

        # NOTE: despite the ``.jit`` suffix this is read back below as a
        # regular torch checkpoint dict.
        target_ckpt = f'{self.tmp_dir}/classification.pth.jit'
        export(cfg, PRETRAINED_MODEL_RESNET50, target_ckpt)

        # map_location='cpu' so loading works on hosts without a GPU even if
        # the checkpoint was saved from CUDA tensors.
        checkpoint = torch.load(target_ckpt, map_location='cpu')
        export_config = json.loads(checkpoint['meta']['config'])
        self.assertEqual(
            export_config['model']['backbone']['norm_cfg']['type'], 'BN')
if __name__ == '__main__':
    # Run every test case in this module when it is executed as a script.
    unittest.main()