add predict pipeline

Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9828601

    * add predict pipeline
pull/191/head
jiangnana.jnn 2022-08-23 19:52:52 +08:00
parent b3abdf507f
commit 0f74adb848
13 changed files with 621 additions and 10 deletions

View File

@ -233,6 +233,8 @@ eval_pipelines = [
) )
] ]
predict = dict(type='SegmentationPredictor')
log_config = dict( log_config = dict(
interval=50, interval=50,
hooks=[ hooks=[

View File

@ -233,6 +233,8 @@ eval_pipelines = [
) )
] ]
predict = dict(type='SegmentationPredictor')
log_config = dict( log_config = dict(
interval=50, interval=50,
hooks=[ hooks=[

View File

@ -4,4 +4,4 @@ from .dali_transforms import (DaliColorTwist, DaliCropMirrorNormalize,
DaliImageDecoder, DaliRandomGrayscale, DaliImageDecoder, DaliRandomGrayscale,
DaliRandomResizedCrop, DaliResize) DaliRandomResizedCrop, DaliResize)
from .format import Collect, DefaultFormatBundle, ImageToTensor from .format import Collect, DefaultFormatBundle, ImageToTensor
from .transforms import Compose from .transforms import Compose, LoadImage

View File

@ -2,7 +2,10 @@
import time import time
from collections.abc import Sequence from collections.abc import Sequence
import numpy as np
from easycv.datasets.registry import PIPELINES from easycv.datasets.registry import PIPELINES
from easycv.file.image import load_image
from easycv.utils.registry import build_from_cfg from easycv.utils.registry import build_from_cfg
@ -48,3 +51,49 @@ class Compose(object):
format_string += f'\n {t}' format_string += f'\n {t}'
format_string += '\n)' format_string += '\n)'
return format_string return format_string
@PIPELINES.register_module()
class LoadImage:
"""Load an image from file or numpy or PIL object.
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
"""
def __init__(self, to_float32=False, mode='bgr'):
self.to_float32 = to_float32
self.mode = mode
def __call__(self, results):
"""Call functions to load image and get image meta information.
Returns:
dict: The dict contains loaded image and meta information.
"""
filename = results.get('filename', None)
img = results.get('img', None)
if img is not None:
if not isinstance(img, np.ndarray):
img = np.asarray(img, dtype=np.uint8)
elif filename is None:
raise ValueError('Please provide "filename" or "img"!')
img = load_image(filename, mode=self.mode)
if self.to_float32:
img = img.astype(np.float32)
results['filename'] = filename
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
results['img_fields'] = ['img']
return results
def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'to_float32={self.to_float32}, '
f"mode='{self.mode}'")
return repr_str

View File

@ -7,4 +7,5 @@ from .feature_extractor import (TorchFaceAttrExtractor,
TorchFeatureExtractor) TorchFeatureExtractor)
from .pose_predictor import (TorchPoseTopDownPredictor, from .pose_predictor import (TorchPoseTopDownPredictor,
TorchPoseTopDownPredictorWithDetector) TorchPoseTopDownPredictorWithDetector)
from .segmentation import Mask2formerPredictor, SegFormerPredictor from .segmentation import (Mask2formerPredictor, SegFormerPredictor,
SegmentationPredictor)

View File

@ -1,14 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os import os
import pickle
import numpy as np import numpy as np
import torch import torch
from mmcv.parallel import collate, scatter_kwargs
from PIL import Image from PIL import Image
from torchvision.transforms import Compose from torchvision.transforms import Compose
from easycv.datasets.registry import PIPELINES from easycv.datasets.registry import PIPELINES
from easycv.file import io from easycv.file import io
from easycv.models import build_model from easycv.models.builder import build_model
from easycv.utils.checkpoint import load_checkpoint from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.constant import CACHE_DIR from easycv.utils.constant import CACHE_DIR
@ -91,3 +93,175 @@ class Predictor(object):
output = self.model.forward( output = self.model.forward(
image_batch.to(self.device), **forward_kwargs) image_batch.to(self.device), **forward_kwargs)
return output return output
class PredictorV2(object):
"""Base predict pipeline.
Args:
model_path (str): Path of model path.
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
*args,
**kwargs):
self.model_path = model_path
self.batch_size = batch_size
self.save_results = save_results
self.save_path = save_path
if self.save_results:
assert self.save_path is not None
self.device = device
if self.device is None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.cfg = None
if config_file is not None:
if isinstance(config_file, str):
self.cfg = mmcv_config_fromfile(config_file)
else:
self.cfg = config_file
self.model = self.prepare_model()
self.processor = self.build_processor()
self._load_op = None
def prepare_model(self):
"""Build model from config file by default.
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
"""
model = self._build_model()
model.to(self.device)
model.eval()
load_checkpoint(model, self.model_path, map_location='cpu')
return model
def _build_model(self):
if self.cfg is None:
raise ValueError('Please provide "config_file"!')
model = build_model(self.cfg.model)
return model
def build_processor(self):
"""Build processor to process loaded input.
If you need custom preprocessing ops, you need to reimplement it.
"""
if self.cfg is None:
pipeline = []
else:
pipeline = [
build_from_cfg(p, PIPELINES)
for p in self.cfg.get('test_pipeline', [])
]
from easycv.datasets.shared.pipelines.transforms import Compose
processor = Compose(pipeline)
return processor
def _load_input(self, input):
"""Load image from file or numpy or PIL object.
Args:
input: File path or numpy or PIL object.
Returns:
{
'filename': filename,
'img': img,
'img_shape': img_shape,
'img_fields': ['img']
}
"""
if self._load_op is None:
load_cfg = dict(type='LoadImage', mode='rgb')
self._load_op = build_from_cfg(load_cfg, PIPELINES)
if not isinstance(input, str):
sample = self._load_op({'img': input})
else:
sample = self._load_op({'filename': input})
return sample
def preprocess_single(self, input):
"""Preprocess single input sample.
If you need custom ops to load or process a single input sample, you need to reimplement it.
"""
input = self._load_input(input)
return self.processor(input)
def preprocess(self, inputs, *args, **kwargs):
"""Process all inputs list. And collate to batch and put to target device.
If you need custom ops to load or process a batch samples, you need to reimplement it.
"""
batch_outputs = []
for i in inputs:
batch_outputs.append(self.preprocess_single(i, *args, **kwargs))
batch_outputs = self._collate_fn(batch_outputs)
batch_outputs = self._to_device(batch_outputs)
return batch_outputs
def forward(self, inputs):
"""Model forward.
If you need refactor model forward, you need to reimplement it.
"""
with torch.no_grad():
outputs = self.model(**inputs, mode='test')
return outputs
def postprocess(self, inputs, *args, **kwargs):
"""Process model outputs.
If you need add some processing ops to process model outputs, you need to reimplement it.
"""
return inputs
def _collate_fn(self, inputs):
"""Prepare the input just before the forward function.
Puts each data field into a tensor with outer dimension batch size
"""
return collate(inputs, samples_per_gpu=self.batch_size)
def _to_device(self, inputs):
target_gpus = [-1] if self.device == 'cpu' else [
torch.cuda.current_device()
]
_, kwargs = scatter_kwargs(None, inputs, target_gpus=target_gpus)
return kwargs[0]
@staticmethod
def dump(obj, save_path, mode='wb'):
with open(save_path, mode) as f:
f.write(pickle.dumps(obj))
def __call__(self, inputs, keep_inputs=False):
# TODO: fault tolerance
if isinstance(inputs, str):
inputs = [inputs]
results_list = []
for i in range(0, len(inputs), self.batch_size):
batch = inputs[i:max(len(inputs) - 1, i + self.batch_size)]
batch_outputs = self.preprocess(batch)
batch_outputs = self.forward(batch_outputs)
results = self.postprocess(batch_outputs)
if keep_inputs:
results = {'inputs': batch, 'results': results}
# if dump, the outputs will not added to the return value to prevent taking up too much memory
if self.save_results:
self.dump([results], self.save_path, mode='ab+')
else:
results_list.append(results)
return results_list

View File

@ -4,5 +4,5 @@ from easycv.utils.registry import Registry, build_from_cfg
PREDICTORS = Registry('predictor') PREDICTORS = Registry('predictor')
def build_predictor(cfg): def build_predictor(cfg, default_args=None):
return build_from_cfg(cfg, PREDICTORS, default_args=None) return build_from_cfg(cfg, PREDICTORS, default_args=default_args)

View File

@ -16,6 +16,110 @@ from easycv.predictors.interface import PredictorInterface
from easycv.utils.checkpoint import load_checkpoint from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.registry import build_from_cfg from easycv.utils.registry import build_from_cfg
from .base import PredictorV2
@PREDICTORS.register_module()
class SegmentationPredictor(PredictorV2):
def __init__(self,
model_path,
config_file,
batch_size=1,
device=None,
save_results=False,
save_path=None):
"""Predict pipeline for Segmentation
Args:
model_path (str): Path of model path
config_file (str): config file path for model and processor to init. Defaults to None.
"""
super(SegmentationPredictor, self).__init__(
model_path,
config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path)
self.CLASSES = self.cfg.CLASSES
self.PALETTE = self.cfg.PALETTE
def show_result(self,
img,
result,
palette=None,
win_name='',
show=False,
wait_time=0,
out_file=None,
opacity=0.5):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
result (Tensor): The semantic segmentation results to draw over
`img`.
palette (list[list[int]]] | np.ndarray | None): The palette of
segmentation map. If None is given, random palette will be
generated. Default: None
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
Returns:
img (Tensor): Only if not `show` or `out_file`
"""
img = mmcv.imread(img)
img = img.copy()
seg = result[0]
if palette is None:
if self.PALETTE is None:
# Get random state before set seed,
# and restore random state later.
# It will prevent loss of randomness, as the palette
# may be different in each iteration if not specified.
# See: https://github.com/open-mmlab/mmdetection/issues/5844
state = np.random.get_state()
np.random.seed(42)
# random palette
palette = np.random.randint(
0, 255, size=(len(self.CLASSES), 3))
np.random.set_state(state)
else:
palette = self.PALETTE
palette = np.array(palette)
assert palette.shape[0] == len(self.CLASSES)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]
img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)
# if out_file specified, do not show image in window
if out_file is not None:
show = False
if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
if not (show or out_file):
return img
@PREDICTORS.register_module() @PREDICTORS.register_module()

View File

@ -0,0 +1,125 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import jsonplus
from easycv.file import io
from easycv.utils.config_tools import Config
MODELSCOPE_PREFIX = 'modelscope'
EASYCV_ARCH = '__easycv_arch__'
def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
"""Convert EasyCV config to ModelScope style.
Args:
cfg (str | Config): Easycv config file or Config object.
task (str): Task name in modelscope, refer to: modelscope.utils.constant.Tasks.
ms_model_name (str): Model name registered in modelscope, model type will be replaced with `ms_model_name`, used in modelscope.
save_path (str): Save path for saving the generated modelscope configuration file. Only valid when dump is True.
dump (bool): Whether dump the converted config to `save_path`.
"""
# TODO: support multi eval_pipelines
# TODO: support for adding customized required keys to the configuration file
if isinstance(cfg, str):
easycv_cfg = Config.fromfile(cfg)
if dump and save_path is None:
save_dir = os.path.dirname(cfg)
save_name = MODELSCOPE_PREFIX + '_' + os.path.splitext(
os.path.basename(cfg))[0] + '.json'
save_path = os.path.join(save_dir, save_name)
else:
easycv_cfg = cfg
if dump and save_path is None:
raise ValueError('Please provide `save_path`!')
assert save_path.endswith('json'), 'Only support json file!'
optimizer_options = easycv_cfg.optimizer_config
optimizer_options.update({'loss_keys': 'total_loss'})
val_dataset_cfg = easycv_cfg.data.val
val_imgs_per_gpu = val_dataset_cfg.pop('imgs_per_gpu',
easycv_cfg.data.imgs_per_gpu)
val_workers_per_gpu = val_dataset_cfg.pop('workers_per_gpu',
easycv_cfg.data.workers_per_gpu)
log_config = easycv_cfg.log_config
predict_config = easycv_cfg.get('predict', None)
hooks = [{
'type': 'CheckpointHook',
'interval': easycv_cfg.checkpoint_config.interval
}, {
'type': 'EvaluationHook',
'interval': easycv_cfg.eval_config.interval
}, {
'type': 'AddLrLogHook'
}, {
'type': 'IterTimerHook'
}]
custom_hooks = easycv_cfg.get('custom_hooks', [])
hooks.extend(custom_hooks)
for log_hook_i in log_config.hooks:
if log_hook_i['type'] == 'TensorboardLoggerHook':
# replace with modelscope api
hooks.append({
'type': 'TensorboardHook',
'interval': log_config.interval
})
elif log_hook_i['type'] == 'TextLoggerHook':
# use modelscope api
hooks.append({
'type': 'TextLoggerHook',
'interval': log_config.interval
})
else:
log_hook_i.update({'interval': log_config.interval})
hooks.append(log_hook_i)
ori_model_type = easycv_cfg.model.pop('type')
ms_cfg = Config(
dict(
task=task,
framework='pytorch',
model={
'type': ms_model_name,
**easycv_cfg.model, EASYCV_ARCH: {
'type': ori_model_type
}
},
dataset=dict(train=easycv_cfg.data.train, val=val_dataset_cfg),
train=dict(
work_dir=easycv_cfg.get('work_dir', None),
max_epochs=easycv_cfg.total_epochs,
dataloader=dict(
batch_size_per_gpu=easycv_cfg.data.imgs_per_gpu,
workers_per_gpu=easycv_cfg.data.workers_per_gpu,
),
optimizer=dict(
**easycv_cfg.optimizer, options=optimizer_options),
lr_scheduler=easycv_cfg.lr_config,
hooks=hooks),
evaluation=dict(
dataloader=dict(
batch_size_per_gpu=val_imgs_per_gpu,
workers_per_gpu=val_workers_per_gpu,
),
metrics={
'type': 'EasyCVMetric',
'evaluators': easycv_cfg.eval_pipelines[0].evaluators
}),
pipeline=dict(predictor_config=predict_config),
))
if dump:
with io.open(save_path, 'w') as f:
res = jsonplus.dumps(
ms_cfg._cfg_dict.to_dict(), indent=4, sort_keys=False)
f.write(res)
return ms_cfg

View File

@ -1,2 +1,4 @@
http://pai-nni.oss-cn-zhangjiakou.aliyuncs.com/release/2.6.1/pai_nni-2.6.1-py3-none-manylinux1_x86_64.whl http://pai-nni.oss-cn-zhangjiakou.aliyuncs.com/release/2.6.1/pai_nni-2.6.1-py3-none-manylinux1_x86_64.whl
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/panopticapi/panopticapi-0.1-py3-none-any.whl
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/blade_compression-0.0.2-py3-none-any.whl http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/blade_compression-0.0.2-py3-none-any.whl
https://developer.download.nvidia.com/compute/redist/nvidia-dali-cuda100/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl

View File

@ -3,21 +3,19 @@ dataclasses
einops einops
future future
h5py h5py
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/panopticapi/panopticapi-0.1-py3-none-any.whl
https://developer.download.nvidia.com/compute/redist/nvidia-dali-cuda100/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl
json_tricks json_tricks
numpy numpy
opencv-python-headless opencv-python
oss2 oss2
packaging packaging
Pillow Pillow
prettytable prettytable
pycocotools pycocotools
pytorch_metric_learning==0.9.89 pytorch_metric_learning>=0.9.89
scikit-image scikit-image
sklearn sklearn
tensorboard tensorboard
thop thop
timm==0.4.9 timm>=0.4.9
xtcocotools xtcocotools
yacs yacs

View File

@ -0,0 +1,107 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import pickle
import shutil
import tempfile
import unittest
import numpy as np
from PIL import Image
from tests.ut_config import (MODEL_CONFIG_SEGFORMER,
PRETRAINED_MODEL_SEGFORMER, TEST_IMAGES_DIR)
from easycv.predictors.segmentation import SegmentationPredictor
class SegmentationPredictorTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_single(self):
segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
segmentation_model_config = MODEL_CONFIG_SEGFORMER
img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
img = np.asarray(Image.open(img_path))
predict_pipeline = SegmentationPredictor(
model_path=segmentation_model_path,
config_file=segmentation_model_config)
outputs = predict_pipeline(img_path, keep_inputs=True)
self.assertEqual(len(outputs), 1)
self.assertEqual(outputs[0]['inputs'], [img_path])
results = outputs[0]['results']
self.assertListEqual(
list(img.shape)[:2], list(results['seg_pred'][0].shape))
self.assertListEqual(results['seg_pred'][0][1, :10].tolist(),
[161 for i in range(10)])
self.assertListEqual(results['seg_pred'][0][-1, -10:].tolist(),
[133 for i in range(10)])
def test_batch(self):
segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
segmentation_model_config = MODEL_CONFIG_SEGFORMER
img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
img = np.asarray(Image.open(img_path))
predict_pipeline = SegmentationPredictor(
model_path=segmentation_model_path,
config_file=segmentation_model_config,
batch_size=2)
total_samples = 3
outputs = predict_pipeline(
[img_path] * total_samples, keep_inputs=True)
self.assertEqual(len(outputs), 2)
self.assertEqual(outputs[0]['inputs'], [img_path] * 2)
self.assertEqual(outputs[1]['inputs'], [img_path] * 1)
self.assertEqual(len(outputs[0]['results']['seg_pred']), 2)
self.assertEqual(len(outputs[1]['results']['seg_pred']), 1)
for result in [outputs[0]['results'], outputs[1]['results']]:
self.assertListEqual(
list(img.shape)[:2], list(result['seg_pred'][0].shape))
self.assertListEqual(result['seg_pred'][0][1, :10].tolist(),
[161 for i in range(10)])
self.assertListEqual(result['seg_pred'][0][-1, -10:].tolist(),
[133 for i in range(10)])
def test_dump(self):
segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
segmentation_model_config = MODEL_CONFIG_SEGFORMER
img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
temp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(temp_dir):
os.makedirs(temp_dir)
tmp_path = os.path.join(temp_dir, 'results.pkl')
predict_pipeline = SegmentationPredictor(
model_path=segmentation_model_path,
config_file=segmentation_model_config,
batch_size=2,
save_results=True,
save_path=tmp_path)
total_samples = 3
outputs = predict_pipeline(
[img_path] * total_samples, keep_inputs=True)
self.assertEqual(outputs, [])
with open(tmp_path, 'rb') as f:
results = pickle.loads(f.read())
self.assertIn('inputs', results[0])
self.assertIn('results', results[0])
shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,47 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import easycv
from easycv.utils.config_tools import Config
from easycv.utils.ms_utils import to_ms_config
class MsConfigTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
def tearDown(self):
super().tearDown()
shutil.rmtree(self.tmp_dir)
def test_to_ms_config(self):
config_path = os.path.join(
os.path.dirname(os.path.dirname(easycv.__file__)),
'configs/detection/yolox/yolox_s_8xb16_300e_coco.py')
ms_cfg_file = os.path.join(self.tmp_dir,
'ms_yolox_s_8xb16_300e_coco.json')
to_ms_config(
config_path,
task='image-object-detection',
ms_model_name='yolox',
save_path=ms_cfg_file)
cfg = Config.fromfile(ms_cfg_file)
self.assertIn('task', cfg)
self.assertIn('framework', cfg)
self.assertEqual(cfg.model.type, 'yolox')
self.assertIn('dataset', cfg)
self.assertIn('batch_size_per_gpu', cfg.train.dataloader)
self.assertIn('batch_size_per_gpu', cfg.evaluation.dataloader)
if __name__ == '__main__':
unittest.main()