From 0f74adb8482d8527d991fe1821696b5ca4dace83 Mon Sep 17 00:00:00 2001
From: "jiangnana.jnn"
Date: Tue, 23 Aug 2022 19:52:52 +0800
Subject: [PATCH] add predict pipeline

Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9828601

* add predict pipeline
---
 .../segformer/segformer_b0_coco.py            |   2 +
 .../segformer/segformer_b5_coco.py            |   2 +
 easycv/datasets/shared/pipelines/__init__.py  |   2 +-
 .../datasets/shared/pipelines/transforms.py   |  49 +++++
 easycv/predictors/__init__.py                 |   3 +-
 easycv/predictors/base.py                     | 176 +++++++++++++++++-
 easycv/predictors/builder.py                  |   4 +-
 easycv/predictors/segmentation.py             | 104 +++++++++++
 easycv/utils/ms_utils.py                      | 125 +++++++++++++
 requirements/optional.txt                     |   2 +
 requirements/runtime.txt                      |   8 +-
 tests/predictors/test_segmentation.py         | 107 +++++++++++
 tests/utils/test_ms_utils.py                  |  47 +++++
 13 files changed, 621 insertions(+), 10 deletions(-)
 create mode 100644 easycv/utils/ms_utils.py
 create mode 100644 tests/predictors/test_segmentation.py
 create mode 100644 tests/utils/test_ms_utils.py

diff --git a/configs/segmentation/segformer/segformer_b0_coco.py b/configs/segmentation/segformer/segformer_b0_coco.py
index 9c7b5580..75dbefe8 100644
--- a/configs/segmentation/segformer/segformer_b0_coco.py
+++ b/configs/segmentation/segformer/segformer_b0_coco.py
@@ -233,6 +233,8 @@ eval_pipelines = [
     )
 ]
 
+predict = dict(type='SegmentationPredictor')
+
 log_config = dict(
     interval=50,
     hooks=[
diff --git a/configs/segmentation/segformer/segformer_b5_coco.py b/configs/segmentation/segformer/segformer_b5_coco.py
index 4f0d4d23..9863cc14 100644
--- a/configs/segmentation/segformer/segformer_b5_coco.py
+++ b/configs/segmentation/segformer/segformer_b5_coco.py
@@ -233,6 +233,8 @@ eval_pipelines = [
     )
 ]
 
+predict = dict(type='SegmentationPredictor')
+
 log_config = dict(
     interval=50,
     hooks=[
diff --git a/easycv/datasets/shared/pipelines/__init__.py b/easycv/datasets/shared/pipelines/__init__.py
index 2d1f6332..7a285840 100644
--- a/easycv/datasets/shared/pipelines/__init__.py
+++ b/easycv/datasets/shared/pipelines/__init__.py
@@ -4,4 +4,4 @@ from .dali_transforms import (DaliColorTwist, DaliCropMirrorNormalize,
                               DaliImageDecoder, DaliRandomGrayscale,
                               DaliRandomResizedCrop, DaliResize)
 from .format import Collect, DefaultFormatBundle, ImageToTensor
-from .transforms import Compose
+from .transforms import Compose, LoadImage
diff --git a/easycv/datasets/shared/pipelines/transforms.py b/easycv/datasets/shared/pipelines/transforms.py
index d10a1330..59a4b99a 100644
--- a/easycv/datasets/shared/pipelines/transforms.py
+++ b/easycv/datasets/shared/pipelines/transforms.py
@@ -2,7 +2,10 @@
 import time
 from collections.abc import Sequence
 
+import numpy as np
+
 from easycv.datasets.registry import PIPELINES
+from easycv.file.image import load_image
 from easycv.utils.registry import build_from_cfg
 
 
@@ -48,3 +51,49 @@ class Compose(object):
             format_string += f'\n    {t}'
         format_string += '\n)'
         return format_string
+
+
+@PIPELINES.register_module()
+class LoadImage:
+    """Load an image from a file path, numpy array or PIL object.
+    Args:
+        to_float32 (bool): Whether to convert the loaded image to a float32
+            numpy array; if False, keep it as uint8. Defaults to False.
+        mode (str): Color mode of the loaded image, e.g. 'bgr' or 'rgb'.
+    """
+
+    def __init__(self, to_float32=False, mode='bgr'):
+        self.to_float32 = to_float32
+        self.mode = mode
+
+    def __call__(self, results):
+        """Call functions to load image and get image meta information.
+        Returns:
+            dict: The dict contains the loaded image and meta information.
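+
+        Example (illustrative; assumes a local image at 'data/demo.jpg'):
+            >>> load = LoadImage(mode='bgr')
+            >>> sample = load({'filename': 'data/demo.jpg'})
+            >>> sample['img'].dtype      # uint8 unless to_float32=True
+            >>> sample['img_shape']      # e.g. (480, 640, 3)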
+        """
+        filename = results.get('filename', None)
+        img = results.get('img', None)
+
+        if img is not None:
+            if not isinstance(img, np.ndarray):
+                img = np.asarray(img, dtype=np.uint8)
+        elif filename is None:
+            raise ValueError('Please provide "filename" or "img"!')
+        else:
+            img = load_image(filename, mode=self.mode)
+        if self.to_float32:
+            img = img.astype(np.float32)
+
+        results['filename'] = filename
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['ori_shape'] = img.shape
+        results['img_fields'] = ['img']
+        return results
+
+    def __repr__(self):
+        repr_str = (f'{self.__class__.__name__}('
+                    f'to_float32={self.to_float32}, '
+                    f"mode='{self.mode}')")
+
+        return repr_str
diff --git a/easycv/predictors/__init__.py b/easycv/predictors/__init__.py
index 7f1b2742..9577e75d 100644
--- a/easycv/predictors/__init__.py
+++ b/easycv/predictors/__init__.py
@@ -7,4 +7,5 @@ from .feature_extractor import (TorchFaceAttrExtractor,
                                 TorchFeatureExtractor)
 from .pose_predictor import (TorchPoseTopDownPredictor,
                              TorchPoseTopDownPredictorWithDetector)
-from .segmentation import Mask2formerPredictor, SegFormerPredictor
+from .segmentation import (Mask2formerPredictor, SegFormerPredictor,
+                           SegmentationPredictor)
diff --git a/easycv/predictors/base.py b/easycv/predictors/base.py
index b7143c94..9bc64bad 100644
--- a/easycv/predictors/base.py
+++ b/easycv/predictors/base.py
@@ -1,14 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
+import pickle
 
 import numpy as np
 import torch
+from mmcv.parallel import collate, scatter_kwargs
 from PIL import Image
 from torchvision.transforms import Compose
 
 from easycv.datasets.registry import PIPELINES
 from easycv.file import io
-from easycv.models import build_model
+from easycv.models.builder import build_model
 from easycv.utils.checkpoint import load_checkpoint
 from easycv.utils.config_tools import mmcv_config_fromfile
 from easycv.utils.constant import CACHE_DIR
@@ -91,3 +93,175 @@ class Predictor(object):
         output = self.model.forward(
             image_batch.to(self.device), **forward_kwargs)
         return output
+
+
+class PredictorV2(object):
+    """Base predict pipeline.
+    Args:
+        model_path (str): Path of model file.
+        config_file (Optional[str]): Config file path for model and processor to init. Defaults to None.
+        batch_size (int): Batch size for forward.
+        device (str): Support 'cuda' or 'cpu'; if None, the device is detected automatically.
+        save_results (bool): Whether to save predict results.
+        save_path (str): File path for saving results, only valid when `save_results` is True.
+    """
+
+    def __init__(self,
+                 model_path,
+                 config_file=None,
+                 batch_size=1,
+                 device=None,
+                 save_results=False,
+                 save_path=None,
+                 *args,
+                 **kwargs):
+        self.model_path = model_path
+        self.batch_size = batch_size
+        self.save_results = save_results
+        self.save_path = save_path
+        if self.save_results:
+            assert self.save_path is not None
+        self.device = device
+        if self.device is None:
+            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        self.cfg = None
+        if config_file is not None:
+            if isinstance(config_file, str):
+                self.cfg = mmcv_config_fromfile(config_file)
+            else:
+                self.cfg = config_file
+
+        self.model = self.prepare_model()
+        self.processor = self.build_processor()
+        self._load_op = None
+
+    def prepare_model(self):
+        """Build model from config file by default.
+        If the model is not loaded from a configuration file, e.g. a torch jit model, you need to reimplement this method.
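+
+        A minimal override sketch for that case (hypothetical, assumes
+        `self.model_path` points to a TorchScript file):
+
+            def prepare_model(self):
+                model = torch.jit.load(self.model_path, map_location=self.device)
+                model.eval()
+                return model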
+        """
+        model = self._build_model()
+        model.to(self.device)
+        model.eval()
+        load_checkpoint(model, self.model_path, map_location='cpu')
+        return model
+
+    def _build_model(self):
+        if self.cfg is None:
+            raise ValueError('Please provide "config_file"!')
+
+        model = build_model(self.cfg.model)
+        return model
+
+    def build_processor(self):
+        """Build processor to process loaded input.
+        If you need custom preprocessing ops, you need to reimplement this method.
+        """
+        if self.cfg is None:
+            pipeline = []
+        else:
+            pipeline = [
+                build_from_cfg(p, PIPELINES)
+                for p in self.cfg.get('test_pipeline', [])
+            ]
+
+        from easycv.datasets.shared.pipelines.transforms import Compose
+        processor = Compose(pipeline)
+        return processor
+
+    def _load_input(self, input):
+        """Load image from a file path, numpy array or PIL object.
+        Args:
+            input: File path or numpy array or PIL object.
+        Returns:
+            {
+                'filename': filename,
+                'img': img,
+                'img_shape': img_shape,
+                'img_fields': ['img']
+            }
+        """
+        if self._load_op is None:
+            load_cfg = dict(type='LoadImage', mode='rgb')
+            self._load_op = build_from_cfg(load_cfg, PIPELINES)
+
+        if not isinstance(input, str):
+            sample = self._load_op({'img': input})
+        else:
+            sample = self._load_op({'filename': input})
+
+        return sample
+
+    def preprocess_single(self, input):
+        """Preprocess a single input sample.
+        If you need custom ops to load or process a single sample, you need to reimplement this method.
+        """
+        input = self._load_input(input)
+        return self.processor(input)
+
+    def preprocess(self, inputs, *args, **kwargs):
+        """Process the list of inputs, collate them into a batch and move it to the target device.
+        If you need custom ops to load or process a batch of samples, you need to reimplement this method.
+        """
+        batch_outputs = []
+        for i in inputs:
+            batch_outputs.append(self.preprocess_single(i, *args, **kwargs))
+
+        batch_outputs = self._collate_fn(batch_outputs)
+        batch_outputs = self._to_device(batch_outputs)
+
+        return batch_outputs
+
+    def forward(self, inputs):
+        """Model forward.
+        If you need to refactor the model forward, you need to reimplement this method.
+        """
+        with torch.no_grad():
+            outputs = self.model(**inputs, mode='test')
+        return outputs
+
+    def postprocess(self, inputs, *args, **kwargs):
+        """Process model outputs.
+        If you need to add processing ops on the model outputs, you need to reimplement this method.
+        """
+        return inputs
+
+    def _collate_fn(self, inputs):
+        """Prepare the input just before the forward function.
+        Puts each data field into a tensor with outer dimension batch size.
+        """
+        return collate(inputs, samples_per_gpu=self.batch_size)
+
+    def _to_device(self, inputs):
+        target_gpus = [-1] if self.device == 'cpu' else [
+            torch.cuda.current_device()
+        ]
+        _, kwargs = scatter_kwargs(None, inputs, target_gpus=target_gpus)
+        return kwargs[0]
+
+    @staticmethod
+    def dump(obj, save_path, mode='wb'):
+        with open(save_path, mode) as f:
+            f.write(pickle.dumps(obj))
+
+    def __call__(self, inputs, keep_inputs=False):
+        # TODO: fault tolerance
+
+        if isinstance(inputs, str):
+            inputs = [inputs]
+
+        results_list = []
+        for i in range(0, len(inputs), self.batch_size):
+            batch = inputs[i:min(len(inputs), i + self.batch_size)]
+            batch_outputs = self.preprocess(batch)
+            batch_outputs = self.forward(batch_outputs)
+            results = self.postprocess(batch_outputs)
+            if keep_inputs:
+                results = {'inputs': batch, 'results': results}
+            # If results are dumped, they are not added to the return value to avoid taking up too much memory
+            if self.save_results:
+                self.dump([results], self.save_path, mode='ab+')
+            else:
+                results_list.append(results)
+
+        return results_list
diff --git a/easycv/predictors/builder.py b/easycv/predictors/builder.py
index f387471c..ac73bc52 100644
--- a/easycv/predictors/builder.py
+++ b/easycv/predictors/builder.py
@@ -4,5 +4,5 @@ from easycv.utils.registry import Registry, build_from_cfg
 PREDICTORS = Registry('predictor')
 
 
-def build_predictor(cfg):
-    return build_from_cfg(cfg, PREDICTORS, default_args=None)
+def build_predictor(cfg, default_args=None):
+    return build_from_cfg(cfg, PREDICTORS, default_args=default_args)
diff --git a/easycv/predictors/segmentation.py b/easycv/predictors/segmentation.py
index 8f780fed..6916817b 100644
--- a/easycv/predictors/segmentation.py
+++ b/easycv/predictors/segmentation.py
@@ -16,6 +16,110 @@ from easycv.predictors.interface import PredictorInterface
 from easycv.utils.checkpoint import load_checkpoint
 from easycv.utils.config_tools import mmcv_config_fromfile
 from easycv.utils.registry import build_from_cfg
+from .base import PredictorV2
+
+
+@PREDICTORS.register_module()
+class SegmentationPredictor(PredictorV2):
+
+    def __init__(self,
+                 model_path,
+                 config_file,
+                 batch_size=1,
+                 device=None,
+                 save_results=False,
+                 save_path=None):
+        """Predict pipeline for segmentation.
+
+        Args:
+            model_path (str): Path of model file.
+            config_file (str): Config file path for model and processor to init.
+        """
+        super(SegmentationPredictor, self).__init__(
+            model_path,
+            config_file,
+            batch_size=batch_size,
+            device=device,
+            save_results=save_results,
+            save_path=save_path)
+
+        self.CLASSES = self.cfg.CLASSES
+        self.PALETTE = self.cfg.PALETTE
+
+    def show_result(self,
+                    img,
+                    result,
+                    palette=None,
+                    win_name='',
+                    show=False,
+                    wait_time=0,
+                    out_file=None,
+                    opacity=0.5):
+        """Draw `result` over `img`.
+
+        Args:
+            img (str or Tensor): The image to be displayed.
+            result (list[ndarray]): The semantic segmentation results to draw over
+                `img`.
+            palette (list[list[int]] | np.ndarray | None): The palette of the
+                segmentation map. If None is given, a random palette will be
+                generated. Default: None
+            win_name (str): The window name.
+            wait_time (int): Value of waitKey param.
+                Default: 0.
+            show (bool): Whether to show the image.
+                Default: False.
+            out_file (str or None): The filename to write the image.
+                Default: None.
+            opacity(float): Opacity of painted segmentation map.
+                Default 0.5.
+                Must be in (0, 1] range.
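+
+        Example (illustrative paths; the predictor is assumed to be built
+        from a segmentation config as in this PR):
+            >>> predictor = SegmentationPredictor('seg.pth', 'seg_cfg.py')
+            >>> output = predictor(['demo.jpg'])[0]
+            >>> predictor.show_result(
+            ...     'demo.jpg', output['seg_pred'], out_file='vis.jpg')
+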
+        Returns:
+            img (ndarray): The image with segmentation drawn on it, returned
+                only when neither `show` nor `out_file` is set.
+        """
+
+        img = mmcv.imread(img)
+        img = img.copy()
+        seg = result[0]
+        if palette is None:
+            if self.PALETTE is None:
+                # Get random state before setting the seed,
+                # and restore the random state later.
+                # It will prevent loss of randomness, as the palette
+                # may be different in each iteration if not specified.
+                # See: https://github.com/open-mmlab/mmdetection/issues/5844
+                state = np.random.get_state()
+                np.random.seed(42)
+                # random palette
+                palette = np.random.randint(
+                    0, 255, size=(len(self.CLASSES), 3))
+                np.random.set_state(state)
+            else:
+                palette = self.PALETTE
+        palette = np.array(palette)
+        assert palette.shape[0] == len(self.CLASSES)
+        assert palette.shape[1] == 3
+        assert len(palette.shape) == 2
+        assert 0 < opacity <= 1.0
+        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+        for label, color in enumerate(palette):
+            color_seg[seg == label, :] = color
+        # convert to BGR
+        color_seg = color_seg[..., ::-1]
+
+        img = img * (1 - opacity) + color_seg * opacity
+        img = img.astype(np.uint8)
+        # if out_file specified, do not show image in window
+        if out_file is not None:
+            show = False
+
+        if show:
+            mmcv.imshow(img, win_name, wait_time)
+        if out_file is not None:
+            mmcv.imwrite(img, out_file)
+
+        if not (show or out_file):
+            return img
 
 
 @PREDICTORS.register_module()
diff --git a/easycv/utils/ms_utils.py b/easycv/utils/ms_utils.py
new file mode 100644
index 00000000..351f495b
--- /dev/null
+++ b/easycv/utils/ms_utils.py
@@ -0,0 +1,125 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+
+import jsonplus
+
+from easycv.file import io
+from easycv.utils.config_tools import Config
+
+MODELSCOPE_PREFIX = 'modelscope'
+EASYCV_ARCH = '__easycv_arch__'
+
+
+def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
+    """Convert an EasyCV config to ModelScope style.
+
+    Args:
+        cfg (str | Config): EasyCV config file path or Config object.
+        task (str): Task name in modelscope, refer to: modelscope.utils.constant.Tasks.
+        ms_model_name (str): Model name registered in modelscope; the model type will be replaced with `ms_model_name` for use in modelscope.
+        save_path (str): Save path for the generated modelscope configuration file. Only valid when `dump` is True.
+        dump (bool): Whether to dump the converted config to `save_path`.
+    """
+    # TODO: support multi eval_pipelines
+    # TODO: support for adding customized required keys to the configuration file
+
+    if isinstance(cfg, str):
+        easycv_cfg = Config.fromfile(cfg)
+        if dump and save_path is None:
+            save_dir = os.path.dirname(cfg)
+            save_name = MODELSCOPE_PREFIX + '_' + os.path.splitext(
+                os.path.basename(cfg))[0] + '.json'
+            save_path = os.path.join(save_dir, save_name)
+    else:
+        easycv_cfg = cfg
+        if dump and save_path is None:
+            raise ValueError('Please provide `save_path`!')
+
+    assert not dump or save_path.endswith('json'), 'Only json files are supported!'
+
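+    # The sections below map EasyCV settings onto the ModelScope layout:
+    # optimizer/lr_config/hooks go under `train`, the val dataloader and
+    # evaluators under `evaluation`, and the `predict` config under `pipeline`.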
+ optimizer_options = easycv_cfg.optimizer_config + optimizer_options.update({'loss_keys': 'total_loss'}) + + val_dataset_cfg = easycv_cfg.data.val + val_imgs_per_gpu = val_dataset_cfg.pop('imgs_per_gpu', + easycv_cfg.data.imgs_per_gpu) + val_workers_per_gpu = val_dataset_cfg.pop('workers_per_gpu', + easycv_cfg.data.workers_per_gpu) + + log_config = easycv_cfg.log_config + predict_config = easycv_cfg.get('predict', None) + + hooks = [{ + 'type': 'CheckpointHook', + 'interval': easycv_cfg.checkpoint_config.interval + }, { + 'type': 'EvaluationHook', + 'interval': easycv_cfg.eval_config.interval + }, { + 'type': 'AddLrLogHook' + }, { + 'type': 'IterTimerHook' + }] + + custom_hooks = easycv_cfg.get('custom_hooks', []) + hooks.extend(custom_hooks) + + for log_hook_i in log_config.hooks: + if log_hook_i['type'] == 'TensorboardLoggerHook': + # replace with modelscope api + hooks.append({ + 'type': 'TensorboardHook', + 'interval': log_config.interval + }) + elif log_hook_i['type'] == 'TextLoggerHook': + # use modelscope api + hooks.append({ + 'type': 'TextLoggerHook', + 'interval': log_config.interval + }) + else: + log_hook_i.update({'interval': log_config.interval}) + hooks.append(log_hook_i) + + ori_model_type = easycv_cfg.model.pop('type') + + ms_cfg = Config( + dict( + task=task, + framework='pytorch', + model={ + 'type': ms_model_name, + **easycv_cfg.model, EASYCV_ARCH: { + 'type': ori_model_type + } + }, + dataset=dict(train=easycv_cfg.data.train, val=val_dataset_cfg), + train=dict( + work_dir=easycv_cfg.get('work_dir', None), + max_epochs=easycv_cfg.total_epochs, + dataloader=dict( + batch_size_per_gpu=easycv_cfg.data.imgs_per_gpu, + workers_per_gpu=easycv_cfg.data.workers_per_gpu, + ), + optimizer=dict( + **easycv_cfg.optimizer, options=optimizer_options), + lr_scheduler=easycv_cfg.lr_config, + hooks=hooks), + evaluation=dict( + dataloader=dict( + batch_size_per_gpu=val_imgs_per_gpu, + workers_per_gpu=val_workers_per_gpu, + ), + metrics={ + 'type': 'EasyCVMetric', + 'evaluators': easycv_cfg.eval_pipelines[0].evaluators + }), + pipeline=dict(predictor_config=predict_config), + )) + + if dump: + with io.open(save_path, 'w') as f: + res = jsonplus.dumps( + ms_cfg._cfg_dict.to_dict(), indent=4, sort_keys=False) + f.write(res) + + return ms_cfg diff --git a/requirements/optional.txt b/requirements/optional.txt index 4b97d767..0fd17691 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,2 +1,4 @@ http://pai-nni.oss-cn-zhangjiakou.aliyuncs.com/release/2.6.1/pai_nni-2.6.1-py3-none-manylinux1_x86_64.whl +http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/panopticapi/panopticapi-0.1-py3-none-any.whl http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/blade_compression-0.0.2-py3-none-any.whl +https://developer.download.nvidia.com/compute/redist/nvidia-dali-cuda100/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 9deeae54..9c7fd6d7 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -3,21 +3,19 @@ dataclasses einops future h5py -http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/panopticapi/panopticapi-0.1-py3-none-any.whl -https://developer.download.nvidia.com/compute/redist/nvidia-dali-cuda100/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl json_tricks numpy -opencv-python-headless +opencv-python oss2 packaging Pillow prettytable pycocotools -pytorch_metric_learning==0.9.89 
+pytorch_metric_learning>=0.9.89 scikit-image sklearn tensorboard thop -timm==0.4.9 +timm>=0.4.9 xtcocotools yacs diff --git a/tests/predictors/test_segmentation.py b/tests/predictors/test_segmentation.py new file mode 100644 index 00000000..e84a3e1a --- /dev/null +++ b/tests/predictors/test_segmentation.py @@ -0,0 +1,107 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import pickle +import shutil +import tempfile +import unittest + +import numpy as np +from PIL import Image +from tests.ut_config import (MODEL_CONFIG_SEGFORMER, + PRETRAINED_MODEL_SEGFORMER, TEST_IMAGES_DIR) + +from easycv.predictors.segmentation import SegmentationPredictor + + +class SegmentationPredictorTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def test_single(self): + segmentation_model_path = PRETRAINED_MODEL_SEGFORMER + segmentation_model_config = MODEL_CONFIG_SEGFORMER + + img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg') + img = np.asarray(Image.open(img_path)) + + predict_pipeline = SegmentationPredictor( + model_path=segmentation_model_path, + config_file=segmentation_model_config) + + outputs = predict_pipeline(img_path, keep_inputs=True) + self.assertEqual(len(outputs), 1) + self.assertEqual(outputs[0]['inputs'], [img_path]) + + results = outputs[0]['results'] + self.assertListEqual( + list(img.shape)[:2], list(results['seg_pred'][0].shape)) + self.assertListEqual(results['seg_pred'][0][1, :10].tolist(), + [161 for i in range(10)]) + self.assertListEqual(results['seg_pred'][0][-1, -10:].tolist(), + [133 for i in range(10)]) + + def test_batch(self): + segmentation_model_path = PRETRAINED_MODEL_SEGFORMER + segmentation_model_config = MODEL_CONFIG_SEGFORMER + + img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg') + img = np.asarray(Image.open(img_path)) + + predict_pipeline = SegmentationPredictor( + model_path=segmentation_model_path, + config_file=segmentation_model_config, + batch_size=2) + + total_samples = 3 + outputs = predict_pipeline( + [img_path] * total_samples, keep_inputs=True) + self.assertEqual(len(outputs), 2) + + self.assertEqual(outputs[0]['inputs'], [img_path] * 2) + self.assertEqual(outputs[1]['inputs'], [img_path] * 1) + self.assertEqual(len(outputs[0]['results']['seg_pred']), 2) + self.assertEqual(len(outputs[1]['results']['seg_pred']), 1) + + for result in [outputs[0]['results'], outputs[1]['results']]: + self.assertListEqual( + list(img.shape)[:2], list(result['seg_pred'][0].shape)) + self.assertListEqual(result['seg_pred'][0][1, :10].tolist(), + [161 for i in range(10)]) + self.assertListEqual(result['seg_pred'][0][-1, -10:].tolist(), + [133 for i in range(10)]) + + def test_dump(self): + segmentation_model_path = PRETRAINED_MODEL_SEGFORMER + segmentation_model_config = MODEL_CONFIG_SEGFORMER + + img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg') + + temp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + tmp_path = os.path.join(temp_dir, 'results.pkl') + + predict_pipeline = SegmentationPredictor( + model_path=segmentation_model_path, + config_file=segmentation_model_config, + batch_size=2, + save_results=True, + save_path=tmp_path) + + total_samples = 3 + outputs = predict_pipeline( + [img_path] * total_samples, keep_inputs=True) + self.assertEqual(outputs, []) + + with open(tmp_path, 'rb') as f: + results = pickle.loads(f.read()) + + self.assertIn('inputs', results[0]) + self.assertIn('results', results[0]) + + 
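+        # Note: dump() appends one pickled list per batch (mode='ab+'), so the
+        # single pickle.loads above only recovers the first batch of results.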
shutil.rmtree(temp_dir, ignore_errors=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_ms_utils.py b/tests/utils/test_ms_utils.py new file mode 100644 index 00000000..88da7d27 --- /dev/null +++ b/tests/utils/test_ms_utils.py @@ -0,0 +1,47 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import easycv +from easycv.utils.config_tools import Config +from easycv.utils.ms_utils import to_ms_config + + +class MsConfigTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_to_ms_config(self): + + config_path = os.path.join( + os.path.dirname(os.path.dirname(easycv.__file__)), + 'configs/detection/yolox/yolox_s_8xb16_300e_coco.py') + + ms_cfg_file = os.path.join(self.tmp_dir, + 'ms_yolox_s_8xb16_300e_coco.json') + to_ms_config( + config_path, + task='image-object-detection', + ms_model_name='yolox', + save_path=ms_cfg_file) + cfg = Config.fromfile(ms_cfg_file) + self.assertIn('task', cfg) + self.assertIn('framework', cfg) + self.assertEqual(cfg.model.type, 'yolox') + self.assertIn('dataset', cfg) + self.assertIn('batch_size_per_gpu', cfg.train.dataloader) + self.assertIn('batch_size_per_gpu', cfg.evaluation.dataloader) + + +if __name__ == '__main__': + unittest.main()