mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
ev quantize test update (#64)
This commit is contained in:
parent
14af96e21d
commit
b97908e9cb
@ -32,6 +32,7 @@ _PRUNE_OPTIONS = {
|
|||||||
TRAIN_CONFIGS = [
|
TRAIN_CONFIGS = [
|
||||||
{
|
{
|
||||||
'config_file': 'configs/edge_models/yolox_edge.py',
|
'config_file': 'configs/edge_models/yolox_edge.py',
|
||||||
|
'model_type': 'YOLOX_EDGE',
|
||||||
'cfg_options': {
|
'cfg_options': {
|
||||||
**_QUANTIZE_OPTIONS, 'data.train.data_source.ann_file':
|
**_QUANTIZE_OPTIONS, 'data.train.data_source.ann_file':
|
||||||
SMALL_IMAGENET_DATA_ROOT + 'annotations/instances_train2017.json',
|
SMALL_IMAGENET_DATA_ROOT + 'annotations/instances_train2017.json',
|
||||||
@ -45,6 +46,7 @@ TRAIN_CONFIGS = [
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
'config_file': 'configs/edge_models/yolox_edge.py',
|
'config_file': 'configs/edge_models/yolox_edge.py',
|
||||||
|
'model_type': 'YOLOX_EDGE',
|
||||||
'cfg_options': {
|
'cfg_options': {
|
||||||
**_PRUNE_OPTIONS, 'img_scale': (128, 128),
|
**_PRUNE_OPTIONS, 'img_scale': (128, 128),
|
||||||
'data.train.data_source.ann_file':
|
'data.train.data_source.ann_file':
|
||||||
@ -72,6 +74,7 @@ class ModelQuantizeTest(unittest.TestCase):
|
|||||||
cfg_file = train_cfgs.pop('config_file')
|
cfg_file = train_cfgs.pop('config_file')
|
||||||
cfg_options = train_cfgs.pop('cfg_options', None)
|
cfg_options = train_cfgs.pop('cfg_options', None)
|
||||||
work_dir = train_cfgs.pop('work_dir', None)
|
work_dir = train_cfgs.pop('work_dir', None)
|
||||||
|
model_type = train_cfgs.pop('model_type', None)
|
||||||
if not work_dir:
|
if not work_dir:
|
||||||
work_dir = tempfile.TemporaryDirectory().name
|
work_dir = tempfile.TemporaryDirectory().name
|
||||||
|
|
||||||
@ -88,8 +91,8 @@ class ModelQuantizeTest(unittest.TestCase):
|
|||||||
args_str = ' '.join(
|
args_str = ' '.join(
|
||||||
['='.join((str(k), str(v))) for k, v in train_cfgs.items()])
|
['='.join((str(k), str(v))) for k, v in train_cfgs.items()])
|
||||||
|
|
||||||
cmd = 'python tools/quantize.py %s %s --work_dir=%s %s' % \
|
cmd = 'python tools/quantize.py %s %s --model_type=%s --work_dir=%s %s' % \
|
||||||
(tmp_cfg_file, ckpt_path, work_dir, args_str)
|
(tmp_cfg_file, ckpt_path, model_type, work_dir, args_str)
|
||||||
|
|
||||||
logging.info('run command: %s' % cmd)
|
logging.info('run command: %s' % cmd)
|
||||||
run_in_subprocess(cmd)
|
run_in_subprocess(cmd)
|
||||||
@ -100,10 +103,8 @@ class ModelQuantizeTest(unittest.TestCase):
|
|||||||
io.remove(work_dir)
|
io.remove(work_dir)
|
||||||
io.remove(tmp_cfg_file)
|
io.remove(tmp_cfg_file)
|
||||||
|
|
||||||
# @unittest.skipIf(torch.__version__ < '1.8.0',
|
@unittest.skipIf(torch.__version__ < '1.8.0',
|
||||||
# 'model compression need pytorch version >= 1.8.0')
|
'model compression need pytorch version >= 1.8.0')
|
||||||
# TODO: fix this unittest
|
|
||||||
@unittest.skipIf(True, 'some bugs need to be fixed')
|
|
||||||
def test_model_quantize(self):
|
def test_model_quantize(self):
|
||||||
train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0])
|
train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0])
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ def parse_args():
|
|||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
if args.model_type is not None:
|
if args.model_type is not None and args.config is None:
|
||||||
assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % (
|
assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % (
|
||||||
', '.join(CONFIG_TEMPLATE_ZOO.keys()))
|
', '.join(CONFIG_TEMPLATE_ZOO.keys()))
|
||||||
print('model_type=%s, config file will be replaced by %s' %
|
print('model_type=%s, config file will be replaced by %s' %
|
||||||
|
@ -109,7 +109,7 @@ def quantize_eval(cfg, model, eval_mode):
|
|||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
if args.model_type is not None:
|
if args.model_type is not None and args.config is None:
|
||||||
assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % (
|
assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % (
|
||||||
', '.join(CONFIG_TEMPLATE_ZOO.keys()))
|
', '.join(CONFIG_TEMPLATE_ZOO.keys()))
|
||||||
print('model_type=%s, config file will be replaced by %s' %
|
print('model_type=%s, config file will be replaced by %s' %
|
||||||
|
Loading…
x
Reference in New Issue
Block a user