ev quantize test update (#64)

This commit is contained in:
Xiaohe You 2022-05-16 18:40:36 +08:00 committed by GitHub
parent 14af96e21d
commit b97908e9cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 8 deletions

View File

@ -32,6 +32,7 @@ _PRUNE_OPTIONS = {
TRAIN_CONFIGS = [
{
'config_file': 'configs/edge_models/yolox_edge.py',
'model_type': 'YOLOX_EDGE',
'cfg_options': {
**_QUANTIZE_OPTIONS, 'data.train.data_source.ann_file':
SMALL_IMAGENET_DATA_ROOT + 'annotations/instances_train2017.json',
@ -45,6 +46,7 @@ TRAIN_CONFIGS = [
},
{
'config_file': 'configs/edge_models/yolox_edge.py',
'model_type': 'YOLOX_EDGE',
'cfg_options': {
**_PRUNE_OPTIONS, 'img_scale': (128, 128),
'data.train.data_source.ann_file':
@ -72,6 +74,7 @@ class ModelQuantizeTest(unittest.TestCase):
cfg_file = train_cfgs.pop('config_file')
cfg_options = train_cfgs.pop('cfg_options', None)
work_dir = train_cfgs.pop('work_dir', None)
model_type = train_cfgs.pop('model_type', None)
if not work_dir:
work_dir = tempfile.TemporaryDirectory().name
@ -88,8 +91,8 @@ class ModelQuantizeTest(unittest.TestCase):
args_str = ' '.join(
['='.join((str(k), str(v))) for k, v in train_cfgs.items()])
cmd = 'python tools/quantize.py %s %s --work_dir=%s %s' % \
(tmp_cfg_file, ckpt_path, work_dir, args_str)
cmd = 'python tools/quantize.py %s %s --model_type=%s --work_dir=%s %s' % \
(tmp_cfg_file, ckpt_path, model_type, work_dir, args_str)
logging.info('run command: %s' % cmd)
run_in_subprocess(cmd)
@ -100,10 +103,8 @@ class ModelQuantizeTest(unittest.TestCase):
io.remove(work_dir)
io.remove(tmp_cfg_file)
# @unittest.skipIf(torch.__version__ < '1.8.0',
# 'model compression need pytorch version >= 1.8.0')
# TODO: fix this unittest
@unittest.skipIf(True, 'some bugs need to be fixed')
@unittest.skipIf(torch.__version__ < '1.8.0',
'model compression need pytorch version >= 1.8.0')
def test_model_quantize(self):
train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0])

View File

@ -76,7 +76,7 @@ def parse_args():
def main():
args = parse_args()
if args.model_type is not None:
if args.model_type is not None and args.config is None:
assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % (
', '.join(CONFIG_TEMPLATE_ZOO.keys()))
print('model_type=%s, config file will be replaced by %s' %

View File

@ -109,7 +109,7 @@ def quantize_eval(cfg, model, eval_mode):
def main():
args = parse_args()
if args.model_type is not None:
if args.model_type is not None and args.config is None:
assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % (
', '.join(CONFIG_TEMPLATE_ZOO.keys()))
print('model_type=%s, config file will be replaced by %s' %