EasyCV/tests/utils/test_ms_utils.py

48 lines
1.4 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import easycv
from easycv.utils.config_tools import Config
from easycv.utils.ms_utils import to_ms_config
class MsConfigTest(unittest.TestCase):
    """Tests for ``easycv.utils.ms_utils.to_ms_config``."""

    def setUp(self):
        """Log which test is running and create a fresh scratch directory."""
        super().setUp()
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
        # ``mkdtemp`` creates the directory and returns its path, and the
        # directory stays owned by this test until tearDown removes it. The
        # former ``tempfile.TemporaryDirectory().name`` threw away the owning
        # object, whose finalizer deleted the directory right away (emitting a
        # ResourceWarning) and left a window before it was re-created.
        self.tmp_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the scratch directory created in ``setUp``."""
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    def test_to_ms_config(self):
        """Convert the YOLOX-S COCO config and check the generated file.

        ``to_ms_config`` is asked to write a JSON config for the
        ``image-object-detection`` task under the model name ``yolox``. The
        file is then reloaded with ``Config.fromfile``, and its top-level
        sections and dataloader batch-size keys are checked.
        """
        # The repository root is two levels above the ``easycv`` package file;
        # the ``configs/`` tree lives there.
        config_path = os.path.join(
            os.path.dirname(os.path.dirname(easycv.__file__)),
            'configs/detection/yolox/yolox_s_8xb16_300e_coco.py')
        ms_cfg_file = os.path.join(self.tmp_dir,
                                   'ms_yolox_s_8xb16_300e_coco.json')

        to_ms_config(
            config_path,
            task='image-object-detection',
            ms_model_name='yolox',
            save_path=ms_cfg_file)

        cfg = Config.fromfile(ms_cfg_file)
        self.assertIn('task', cfg)
        self.assertIn('framework', cfg)
        self.assertEqual(cfg.model.type, 'yolox')
        self.assertIn('dataset', cfg)
        self.assertIn('batch_size_per_gpu', cfg.train.dataloader)
        self.assertIn('batch_size_per_gpu', cfg.evaluation.dataloader)
if __name__ == '__main__':
    # Running this file directly discovers and runs the TestCase classes
    # defined in this module (``module='__main__'`` is unittest's default).
    unittest.main(module='__main__')