mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
ev quantize test update (#64)
This commit is contained in:
parent
14af96e21d
commit
b97908e9cb
@ -32,6 +32,7 @@ _PRUNE_OPTIONS = {
|
||||
TRAIN_CONFIGS = [
|
||||
{
|
||||
'config_file': 'configs/edge_models/yolox_edge.py',
|
||||
'model_type': 'YOLOX_EDGE',
|
||||
'cfg_options': {
|
||||
**_QUANTIZE_OPTIONS, 'data.train.data_source.ann_file':
|
||||
SMALL_IMAGENET_DATA_ROOT + 'annotations/instances_train2017.json',
|
||||
@ -45,6 +46,7 @@ TRAIN_CONFIGS = [
|
||||
},
|
||||
{
|
||||
'config_file': 'configs/edge_models/yolox_edge.py',
|
||||
'model_type': 'YOLOX_EDGE',
|
||||
'cfg_options': {
|
||||
**_PRUNE_OPTIONS, 'img_scale': (128, 128),
|
||||
'data.train.data_source.ann_file':
|
||||
@ -72,6 +74,7 @@ class ModelQuantizeTest(unittest.TestCase):
|
||||
cfg_file = train_cfgs.pop('config_file')
|
||||
cfg_options = train_cfgs.pop('cfg_options', None)
|
||||
work_dir = train_cfgs.pop('work_dir', None)
|
||||
model_type = train_cfgs.pop('model_type', None)
|
||||
if not work_dir:
|
||||
work_dir = tempfile.TemporaryDirectory().name
|
||||
|
||||
@ -88,8 +91,8 @@ class ModelQuantizeTest(unittest.TestCase):
|
||||
args_str = ' '.join(
|
||||
['='.join((str(k), str(v))) for k, v in train_cfgs.items()])
|
||||
|
||||
cmd = 'python tools/quantize.py %s %s --work_dir=%s %s' % \
|
||||
(tmp_cfg_file, ckpt_path, work_dir, args_str)
|
||||
cmd = 'python tools/quantize.py %s %s --model_type=%s --work_dir=%s %s' % \
|
||||
(tmp_cfg_file, ckpt_path, model_type, work_dir, args_str)
|
||||
|
||||
logging.info('run command: %s' % cmd)
|
||||
run_in_subprocess(cmd)
|
||||
@ -100,10 +103,8 @@ class ModelQuantizeTest(unittest.TestCase):
|
||||
io.remove(work_dir)
|
||||
io.remove(tmp_cfg_file)
|
||||
|
||||
# @unittest.skipIf(torch.__version__ < '1.8.0',
|
||||
# 'model compression need pytorch version >= 1.8.0')
|
||||
# TODO: fix this unittest
|
||||
@unittest.skipIf(True, 'some bugs need to be fixed')
|
||||
@unittest.skipIf(torch.__version__ < '1.8.0',
|
||||
'model compression need pytorch version >= 1.8.0')
|
||||
def test_model_quantize(self):
|
||||
train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0])
|
||||
|
||||
|
@ -76,7 +76,7 @@ def parse_args():
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.model_type is not None:
|
||||
if args.model_type is not None and args.config is None:
|
||||
assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % (
|
||||
', '.join(CONFIG_TEMPLATE_ZOO.keys()))
|
||||
print('model_type=%s, config file will be replaced by %s' %
|
||||
|
@ -109,7 +109,7 @@ def quantize_eval(cfg, model, eval_mode):
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.model_type is not None:
|
||||
if args.model_type is not None and args.config is None:
|
||||
assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % (
|
||||
', '.join(CONFIG_TEMPLATE_ZOO.keys()))
|
||||
print('model_type=%s, config file will be replaced by %s' %
|
||||
|
Loading…
x
Reference in New Issue
Block a user