diff --git a/tests/tools/test_quantize.py b/tests/tools/test_quantize.py index d7cf38d2..dba9415b 100644 --- a/tests/tools/test_quantize.py +++ b/tests/tools/test_quantize.py @@ -32,6 +32,7 @@ _PRUNE_OPTIONS = { TRAIN_CONFIGS = [ { 'config_file': 'configs/edge_models/yolox_edge.py', + 'model_type': 'YOLOX_EDGE', 'cfg_options': { **_QUANTIZE_OPTIONS, 'data.train.data_source.ann_file': SMALL_IMAGENET_DATA_ROOT + 'annotations/instances_train2017.json', @@ -45,6 +46,7 @@ TRAIN_CONFIGS = [ }, { 'config_file': 'configs/edge_models/yolox_edge.py', + 'model_type': 'YOLOX_EDGE', 'cfg_options': { **_PRUNE_OPTIONS, 'img_scale': (128, 128), 'data.train.data_source.ann_file': @@ -72,6 +74,7 @@ class ModelQuantizeTest(unittest.TestCase): cfg_file = train_cfgs.pop('config_file') cfg_options = train_cfgs.pop('cfg_options', None) work_dir = train_cfgs.pop('work_dir', None) + model_type = train_cfgs.pop('model_type', None) if not work_dir: work_dir = tempfile.TemporaryDirectory().name @@ -88,8 +91,8 @@ class ModelQuantizeTest(unittest.TestCase): args_str = ' '.join( ['='.join((str(k), str(v))) for k, v in train_cfgs.items()]) - cmd = 'python tools/quantize.py %s %s --work_dir=%s %s' % \ - (tmp_cfg_file, ckpt_path, work_dir, args_str) + cmd = 'python tools/quantize.py %s %s --model_type=%s --work_dir=%s %s' % \ + (tmp_cfg_file, ckpt_path, model_type, work_dir, args_str) logging.info('run command: %s' % cmd) run_in_subprocess(cmd) @@ -100,10 +103,8 @@ class ModelQuantizeTest(unittest.TestCase): io.remove(work_dir) io.remove(tmp_cfg_file) - # @unittest.skipIf(torch.__version__ < '1.8.0', - # 'model compression need pytorch version >= 1.8.0') - # TODO: fix this unittest - @unittest.skipIf(True, 'some bugs need to be fixed') + @unittest.skipIf(torch.__version__ < '1.8.0', + 'model compression need pytorch version >= 1.8.0') def test_model_quantize(self): train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0]) diff --git a/tools/prune.py b/tools/prune.py index 55bc2f4f..588aff0d 100644 --- a/tools/prune.py +++ b/tools/prune.py @@ -76,7 +76,7 @@ def parse_args(): def main(): args = parse_args() - if args.model_type is not None: + if args.model_type is not None and args.config is None: assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % ( ', '.join(CONFIG_TEMPLATE_ZOO.keys())) print('model_type=%s, config file will be replaced by %s' % diff --git a/tools/quantize.py b/tools/quantize.py index 4aac88be..5b1dd85b 100644 --- a/tools/quantize.py +++ b/tools/quantize.py @@ -109,7 +109,7 @@ def quantize_eval(cfg, model, eval_mode): def main(): args = parse_args() - if args.model_type is not None: + if args.model_type is not None and args.config is None: assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % ( ', '.join(CONFIG_TEMPLATE_ZOO.keys())) print('model_type=%s, config file will be replaced by %s' %