EasyCV/tests/tools/test_quantize.py

96 lines
2.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import logging
import os
import sys
import tempfile
import unittest
from distutils.version import LooseVersion
import torch
from mmcv import Config
from tests.ut_config import (COMPRESSION_TEST_DATA,
PRETRAINED_MODEL_YOLOX_COMPRESSION)
from easycv.file import io
from easycv.utils.test_util import run_in_subprocess
# Put this test file's own directory on ``sys.path`` (presumably so helper
# modules that live next to it can be imported).
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
logging.basicConfig(level=logging.INFO)

# Root of the compression test data, normalised to end in exactly one '/'
# so relative sub-paths can simply be appended with ``+``.
SMALL_IMAGENET_DATA_ROOT = COMPRESSION_TEST_DATA.rstrip('/') + '/'

# Overrides shared by every quantization run: a single epoch and a fixed
# per-GPU batch size keep the run short.
_QUANTIZE_OPTIONS = dict([
    ('total_epochs', 1),
    ('data.imgs_per_gpu', 16),
])

# One entry per quantization scenario. ``config_file`` and ``model_type``
# select the model; ``cfg_options`` is merged into the loaded config. The
# small dataset's train annotations/images are reused for validation.
TRAIN_CONFIGS = [
    dict(
        config_file='configs/edge_models/yolox_edge.py',
        model_type='YOLOX_EDGE',
        cfg_options={
            **_QUANTIZE_OPTIONS,
            'data.train.data_source.ann_file':
            SMALL_IMAGENET_DATA_ROOT + 'annotations/instances_train2017.json',
            'data.train.data_source.img_prefix':
            SMALL_IMAGENET_DATA_ROOT + 'images',
            'data.val.data_source.ann_file':
            SMALL_IMAGENET_DATA_ROOT + 'annotations/instances_train2017.json',
            'data.val.data_source.img_prefix':
            SMALL_IMAGENET_DATA_ROOT + 'images',
        },
    ),
]
class ModelQuantizeTest(unittest.TestCase):
    """End-to-end test for ``tools/quantize.py``.

    Each case dumps a (possibly overridden) config to a temporary ``.py``
    file, runs the quantize tool in a subprocess and checks that
    ``quantize_model.pt`` was written to the work directory. Temporary
    artifacts are removed whether or not the run succeeds.
    """

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def _base_quantize(self, train_cfgs):
        """Run ``tools/quantize.py`` with the settings in ``train_cfgs``.

        Args:
            train_cfgs (dict): Consumed (mutated) by this method. Recognised
                keys are ``config_file`` (required), ``cfg_options``,
                ``work_dir`` and ``model_type``. Any items left after those
                are popped are forwarded as ``key=value`` command-line tokens.

        Raises:
            AssertionError: if ``quantize_model.pt`` is not found in the
                work directory after the tool has run.
        """
        cfg_file = train_cfgs.pop('config_file')
        cfg_options = train_cfgs.pop('cfg_options', None)
        work_dir = train_cfgs.pop('work_dir', None)
        model_type = train_cfgs.pop('model_type', None)

        # Create real, persistent temp paths. Taking ``.name`` from an
        # unreferenced TemporaryDirectory/NamedTemporaryFile lets its
        # finalizer delete the path while it is still in use.
        if not work_dir:
            work_dir = tempfile.mkdtemp()
        fd, tmp_cfg_file = tempfile.mkstemp(suffix='.py')
        os.close(fd)

        try:
            cfg = Config.fromfile(cfg_file)
            if cfg_options is not None:
                cfg.merge_from_dict(cfg_options)
            # Point the first eval pipeline at a copy of the (possibly
            # overridden) ``data.val`` settings.
            cfg.eval_pipelines[0].data = dict(**cfg.data.val)
            cfg.dump(tmp_cfg_file)

            # Whatever is left in train_cfgs becomes extra CLI tokens.
            args_str = ' '.join(
                '%s=%s' % (k, v) for k, v in train_cfgs.items())
            cmd = ('python tools/quantize.py %s %s --model_type=%s '
                   '--work_dir=%s %s' %
                   (tmp_cfg_file, PRETRAINED_MODEL_YOLOX_COMPRESSION,
                    model_type, work_dir, args_str))
            logging.info('run command: %s', cmd)
            run_in_subprocess(cmd)

            self.assertIn('quantize_model.pt', io.listdir(work_dir))
        finally:
            # Always clean up; the existence check keeps a failed run (which
            # may never have created work_dir) from raising here and masking
            # the original error.
            for path in (work_dir, tmp_cfg_file):
                if io.exists(path):
                    io.remove(path)

    @unittest.skipIf(
        LooseVersion(torch.__version__) < LooseVersion('1.10.0'),
        'model compression need pytorch version >= 1.10.0')
    def test_model_quantize(self):
        """Quantize the first TRAIN_CONFIGS scenario (YOLOX edge model)."""
        # Deep copy: _base_quantize pops keys from the dict it receives.
        train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0])
        self._base_quantize(train_cfgs)
if __name__ == '__main__':
    # Executing this file directly (instead of through a test runner) runs
    # every TestCase defined in this module.
    unittest.main(module='__main__')