# Copyright (c) Alibaba, Inc. and its affiliates.
"""Unit tests for model quantization driven through ``tools/quantize.py``."""
import copy
import logging
import os
import sys
import tempfile
import unittest

import torch
from mmcv import Config
from tests.ut_config import (COMPRESSION_TEST_DATA,
                             PRETRAINED_MODEL_YOLOX_COMPRESSION)

from easycv.file import io
from easycv.utils.test_util import run_in_subprocess

sys.path.append(os.path.dirname(os.path.realpath(__file__)))
logging.basicConfig(level=logging.INFO)

# Root of the compression test data, normalised to end with exactly one '/'.
SMALL_IMAGENET_DATA_ROOT = COMPRESSION_TEST_DATA.rstrip('/') + '/'

_QUANTIZE_OPTIONS = {
    'total_epochs': 1,
    'data.imgs_per_gpu': 16,
}

_PRUNE_OPTIONS = {
    'total_epochs': 11,
    'data.imgs_per_gpu': 1,
}

# Train and val both point at the same annotation file and image folder.
_DATA_OPTIONS = {
    'data.train.data_source.ann_file':
    SMALL_IMAGENET_DATA_ROOT + 'annotations/instances_train2017.json',
    'data.train.data_source.img_prefix':
    SMALL_IMAGENET_DATA_ROOT + 'images',
    'data.val.data_source.ann_file':
    SMALL_IMAGENET_DATA_ROOT + 'annotations/instances_train2017.json',
    'data.val.data_source.img_prefix':
    SMALL_IMAGENET_DATA_ROOT + 'images',
}

TRAIN_CONFIGS = [
    {
        'config_file': 'configs/edge_models/yolox_edge.py',
        'cfg_options': {
            **_QUANTIZE_OPTIONS,
            **_DATA_OPTIONS,
        },
    },
    {
        'config_file': 'configs/edge_models/yolox_edge.py',
        'cfg_options': {
            **_PRUNE_OPTIONS,
            'img_scale': (128, 128),
            **_DATA_OPTIONS,
        },
    },
]


class ModelQuantizeTest(unittest.TestCase):
    """Run ``tools/quantize.py`` in a subprocess and check its output."""

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def tearDown(self):
        super().tearDown()

    def _base_quantize(self, train_cfgs):
        """Quantize the pretrained YOLOX checkpoint with the given config.

        Args:
            train_cfgs (dict): Consumed (mutated) by this method. Recognised
                keys are ``config_file`` (required), ``cfg_options``
                (optional dict merged into the loaded config) and
                ``work_dir`` (optional; a fresh temporary directory is used
                when missing or empty). Every remaining key/value pair is
                forwarded to the tool as a ``key=value`` argument.

        Asserts that ``quantize_model.pt`` appears in ``work_dir``. The work
        dir and the temporary config file are removed afterwards, even when
        the subprocess or the assertion fails.
        """
        cfg_file = train_cfgs.pop('config_file')
        cfg_options = train_cfgs.pop('cfg_options', None)
        work_dir = train_cfgs.pop('work_dir', None)
        if not work_dir:
            # ``tempfile.TemporaryDirectory().name`` would delete the directory
            # as soon as the unreferenced object is garbage collected;
            # mkdtemp keeps it until we remove it explicitly below.
            work_dir = tempfile.mkdtemp()

        cfg = Config.fromfile(cfg_file)
        if cfg_options is not None:
            cfg.merge_from_dict(cfg_options)
        cfg.eval_pipelines[0].data = dict(**cfg.data.val)

        # mkstemp creates the file and hands back an open fd; only the path is
        # needed because Config.dump reopens it by name.
        fd, tmp_cfg_file = tempfile.mkstemp(suffix='.py')
        os.close(fd)
        try:
            cfg.dump(tmp_cfg_file)

            ckpt_path = PRETRAINED_MODEL_YOLOX_COMPRESSION
            args_str = ' '.join(
                '='.join((str(k), str(v))) for k, v in train_cfgs.items())
            cmd = 'python tools/quantize.py %s %s --work_dir=%s %s' % \
                (tmp_cfg_file, ckpt_path, work_dir, args_str)
            logging.info('run command: %s', cmd)
            run_in_subprocess(cmd)

            output_files = io.listdir(work_dir)
            self.assertIn('quantize_model.pt', output_files)
        finally:
            # A user-supplied work_dir may never have been created if the
            # tool failed early, so only remove it when it exists.
            if io.exists(work_dir):
                io.remove(work_dir)
            io.remove(tmp_cfg_file)

    # NOTE(review): when re-enabling, compare versions with
    # ``packaging.version`` -- ``torch.__version__ < '1.8.0'`` is a lexical
    # string comparison (e.g. '1.10.0' < '1.8.0' is True).
    # @unittest.skipIf(torch.__version__ < '1.8.0',
    #                  'model compression need pytorch version >= 1.8.0')
    # TODO: fix this unittest
    @unittest.skipIf(True, 'some bugs need to be fixed')
    def test_model_quantize(self):
        train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0])
        self._base_quantize(train_cfgs)


if __name__ == '__main__':
    unittest.main()