mirror of https://github.com/alibaba/EasyCV.git
add predict pipeline
Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9828601 * add predict pipelinepull/191/head
parent
b3abdf507f
commit
0f74adb848
|
@ -233,6 +233,8 @@ eval_pipelines = [
|
|||
)
|
||||
]
|
||||
|
||||
predict = dict(type='SegmentationPredictor')
|
||||
|
||||
log_config = dict(
|
||||
interval=50,
|
||||
hooks=[
|
||||
|
|
|
@ -233,6 +233,8 @@ eval_pipelines = [
|
|||
)
|
||||
]
|
||||
|
||||
predict = dict(type='SegmentationPredictor')
|
||||
|
||||
log_config = dict(
|
||||
interval=50,
|
||||
hooks=[
|
||||
|
|
|
@ -4,4 +4,4 @@ from .dali_transforms import (DaliColorTwist, DaliCropMirrorNormalize,
|
|||
DaliImageDecoder, DaliRandomGrayscale,
|
||||
DaliRandomResizedCrop, DaliResize)
|
||||
from .format import Collect, DefaultFormatBundle, ImageToTensor
|
||||
from .transforms import Compose
|
||||
from .transforms import Compose, LoadImage
|
||||
|
|
|
@ -2,7 +2,10 @@
|
|||
import time
|
||||
from collections.abc import Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from easycv.datasets.registry import PIPELINES
|
||||
from easycv.file.image import load_image
|
||||
from easycv.utils.registry import build_from_cfg
|
||||
|
||||
|
||||
|
@ -48,3 +51,49 @@ class Compose(object):
|
|||
format_string += f'\n {t}'
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class LoadImage:
|
||||
"""Load an image from file or numpy or PIL object.
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is an uint8 array.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self, to_float32=False, mode='bgr'):
|
||||
self.to_float32 = to_float32
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call functions to load image and get image meta information.
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
filename = results.get('filename', None)
|
||||
img = results.get('img', None)
|
||||
|
||||
if img is not None:
|
||||
if not isinstance(img, np.ndarray):
|
||||
img = np.asarray(img, dtype=np.uint8)
|
||||
elif filename is None:
|
||||
raise ValueError('Please provide "filename" or "img"!')
|
||||
|
||||
img = load_image(filename, mode=self.mode)
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
results['filename'] = filename
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
results['img_fields'] = ['img']
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f'to_float32={self.to_float32}, '
|
||||
f"mode='{self.mode}'")
|
||||
|
||||
return repr_str
|
||||
|
|
|
@ -7,4 +7,5 @@ from .feature_extractor import (TorchFaceAttrExtractor,
|
|||
TorchFeatureExtractor)
|
||||
from .pose_predictor import (TorchPoseTopDownPredictor,
|
||||
TorchPoseTopDownPredictorWithDetector)
|
||||
from .segmentation import Mask2formerPredictor, SegFormerPredictor
|
||||
from .segmentation import (Mask2formerPredictor, SegFormerPredictor,
|
||||
SegmentationPredictor)
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import collate, scatter_kwargs
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from easycv.datasets.registry import PIPELINES
|
||||
from easycv.file import io
|
||||
from easycv.models import build_model
|
||||
from easycv.models.builder import build_model
|
||||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||
from easycv.utils.constant import CACHE_DIR
|
||||
|
@ -91,3 +93,175 @@ class Predictor(object):
|
|||
output = self.model.forward(
|
||||
image_batch.to(self.device), **forward_kwargs)
|
||||
return output
|
||||
|
||||
|
||||
class PredictorV2(object):
|
||||
"""Base predict pipeline.
|
||||
Args:
|
||||
model_path (str): Path of model path.
|
||||
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
|
||||
batch_size (int): batch size for forward.
|
||||
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
|
||||
save_results (bool): Whether to save predict results.
|
||||
save_path (str): File path for saving results, only valid when `save_results` is True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_path,
|
||||
config_file=None,
|
||||
batch_size=1,
|
||||
device=None,
|
||||
save_results=False,
|
||||
save_path=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
self.model_path = model_path
|
||||
self.batch_size = batch_size
|
||||
self.save_results = save_results
|
||||
self.save_path = save_path
|
||||
if self.save_results:
|
||||
assert self.save_path is not None
|
||||
self.device = device
|
||||
if self.device is None:
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
self.cfg = None
|
||||
if config_file is not None:
|
||||
if isinstance(config_file, str):
|
||||
self.cfg = mmcv_config_fromfile(config_file)
|
||||
else:
|
||||
self.cfg = config_file
|
||||
|
||||
self.model = self.prepare_model()
|
||||
self.processor = self.build_processor()
|
||||
self._load_op = None
|
||||
|
||||
def prepare_model(self):
|
||||
"""Build model from config file by default.
|
||||
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
|
||||
"""
|
||||
model = self._build_model()
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
load_checkpoint(model, self.model_path, map_location='cpu')
|
||||
return model
|
||||
|
||||
def _build_model(self):
|
||||
if self.cfg is None:
|
||||
raise ValueError('Please provide "config_file"!')
|
||||
|
||||
model = build_model(self.cfg.model)
|
||||
return model
|
||||
|
||||
def build_processor(self):
|
||||
"""Build processor to process loaded input.
|
||||
If you need custom preprocessing ops, you need to reimplement it.
|
||||
"""
|
||||
if self.cfg is None:
|
||||
pipeline = []
|
||||
else:
|
||||
pipeline = [
|
||||
build_from_cfg(p, PIPELINES)
|
||||
for p in self.cfg.get('test_pipeline', [])
|
||||
]
|
||||
|
||||
from easycv.datasets.shared.pipelines.transforms import Compose
|
||||
processor = Compose(pipeline)
|
||||
return processor
|
||||
|
||||
def _load_input(self, input):
|
||||
"""Load image from file or numpy or PIL object.
|
||||
Args:
|
||||
input: File path or numpy or PIL object.
|
||||
Returns:
|
||||
{
|
||||
'filename': filename,
|
||||
'img': img,
|
||||
'img_shape': img_shape,
|
||||
'img_fields': ['img']
|
||||
}
|
||||
"""
|
||||
if self._load_op is None:
|
||||
load_cfg = dict(type='LoadImage', mode='rgb')
|
||||
self._load_op = build_from_cfg(load_cfg, PIPELINES)
|
||||
|
||||
if not isinstance(input, str):
|
||||
sample = self._load_op({'img': input})
|
||||
else:
|
||||
sample = self._load_op({'filename': input})
|
||||
|
||||
return sample
|
||||
|
||||
def preprocess_single(self, input):
|
||||
"""Preprocess single input sample.
|
||||
If you need custom ops to load or process a single input sample, you need to reimplement it.
|
||||
"""
|
||||
input = self._load_input(input)
|
||||
return self.processor(input)
|
||||
|
||||
def preprocess(self, inputs, *args, **kwargs):
|
||||
"""Process all inputs list. And collate to batch and put to target device.
|
||||
If you need custom ops to load or process a batch samples, you need to reimplement it.
|
||||
"""
|
||||
batch_outputs = []
|
||||
for i in inputs:
|
||||
batch_outputs.append(self.preprocess_single(i, *args, **kwargs))
|
||||
|
||||
batch_outputs = self._collate_fn(batch_outputs)
|
||||
batch_outputs = self._to_device(batch_outputs)
|
||||
|
||||
return batch_outputs
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Model forward.
|
||||
If you need refactor model forward, you need to reimplement it.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs, mode='test')
|
||||
return outputs
|
||||
|
||||
def postprocess(self, inputs, *args, **kwargs):
|
||||
"""Process model outputs.
|
||||
If you need add some processing ops to process model outputs, you need to reimplement it.
|
||||
"""
|
||||
return inputs
|
||||
|
||||
def _collate_fn(self, inputs):
|
||||
"""Prepare the input just before the forward function.
|
||||
Puts each data field into a tensor with outer dimension batch size
|
||||
"""
|
||||
return collate(inputs, samples_per_gpu=self.batch_size)
|
||||
|
||||
def _to_device(self, inputs):
|
||||
target_gpus = [-1] if self.device == 'cpu' else [
|
||||
torch.cuda.current_device()
|
||||
]
|
||||
_, kwargs = scatter_kwargs(None, inputs, target_gpus=target_gpus)
|
||||
return kwargs[0]
|
||||
|
||||
@staticmethod
|
||||
def dump(obj, save_path, mode='wb'):
|
||||
with open(save_path, mode) as f:
|
||||
f.write(pickle.dumps(obj))
|
||||
|
||||
def __call__(self, inputs, keep_inputs=False):
|
||||
# TODO: fault tolerance
|
||||
|
||||
if isinstance(inputs, str):
|
||||
inputs = [inputs]
|
||||
|
||||
results_list = []
|
||||
for i in range(0, len(inputs), self.batch_size):
|
||||
batch = inputs[i:max(len(inputs) - 1, i + self.batch_size)]
|
||||
batch_outputs = self.preprocess(batch)
|
||||
batch_outputs = self.forward(batch_outputs)
|
||||
results = self.postprocess(batch_outputs)
|
||||
if keep_inputs:
|
||||
results = {'inputs': batch, 'results': results}
|
||||
# if dump, the outputs will not added to the return value to prevent taking up too much memory
|
||||
if self.save_results:
|
||||
self.dump([results], self.save_path, mode='ab+')
|
||||
else:
|
||||
results_list.append(results)
|
||||
|
||||
return results_list
|
||||
|
|
|
@ -4,5 +4,5 @@ from easycv.utils.registry import Registry, build_from_cfg
|
|||
PREDICTORS = Registry('predictor')
|
||||
|
||||
|
||||
def build_predictor(cfg):
|
||||
return build_from_cfg(cfg, PREDICTORS, default_args=None)
|
||||
def build_predictor(cfg, default_args=None):
|
||||
return build_from_cfg(cfg, PREDICTORS, default_args=default_args)
|
||||
|
|
|
@ -16,6 +16,110 @@ from easycv.predictors.interface import PredictorInterface
|
|||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||
from easycv.utils.registry import build_from_cfg
|
||||
from .base import PredictorV2
|
||||
|
||||
|
||||
@PREDICTORS.register_module()
|
||||
class SegmentationPredictor(PredictorV2):
|
||||
|
||||
def __init__(self,
|
||||
model_path,
|
||||
config_file,
|
||||
batch_size=1,
|
||||
device=None,
|
||||
save_results=False,
|
||||
save_path=None):
|
||||
"""Predict pipeline for Segmentation
|
||||
|
||||
Args:
|
||||
model_path (str): Path of model path
|
||||
config_file (str): config file path for model and processor to init. Defaults to None.
|
||||
"""
|
||||
super(SegmentationPredictor, self).__init__(
|
||||
model_path,
|
||||
config_file,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
save_results=save_results,
|
||||
save_path=save_path)
|
||||
|
||||
self.CLASSES = self.cfg.CLASSES
|
||||
self.PALETTE = self.cfg.PALETTE
|
||||
|
||||
def show_result(self,
|
||||
img,
|
||||
result,
|
||||
palette=None,
|
||||
win_name='',
|
||||
show=False,
|
||||
wait_time=0,
|
||||
out_file=None,
|
||||
opacity=0.5):
|
||||
"""Draw `result` over `img`.
|
||||
|
||||
Args:
|
||||
img (str or Tensor): The image to be displayed.
|
||||
result (Tensor): The semantic segmentation results to draw over
|
||||
`img`.
|
||||
palette (list[list[int]]] | np.ndarray | None): The palette of
|
||||
segmentation map. If None is given, random palette will be
|
||||
generated. Default: None
|
||||
win_name (str): The window name.
|
||||
wait_time (int): Value of waitKey param.
|
||||
Default: 0.
|
||||
show (bool): Whether to show the image.
|
||||
Default: False.
|
||||
out_file (str or None): The filename to write the image.
|
||||
Default: None.
|
||||
opacity(float): Opacity of painted segmentation map.
|
||||
Default 0.5.
|
||||
Must be in (0, 1] range.
|
||||
Returns:
|
||||
img (Tensor): Only if not `show` or `out_file`
|
||||
"""
|
||||
|
||||
img = mmcv.imread(img)
|
||||
img = img.copy()
|
||||
seg = result[0]
|
||||
if palette is None:
|
||||
if self.PALETTE is None:
|
||||
# Get random state before set seed,
|
||||
# and restore random state later.
|
||||
# It will prevent loss of randomness, as the palette
|
||||
# may be different in each iteration if not specified.
|
||||
# See: https://github.com/open-mmlab/mmdetection/issues/5844
|
||||
state = np.random.get_state()
|
||||
np.random.seed(42)
|
||||
# random palette
|
||||
palette = np.random.randint(
|
||||
0, 255, size=(len(self.CLASSES), 3))
|
||||
np.random.set_state(state)
|
||||
else:
|
||||
palette = self.PALETTE
|
||||
palette = np.array(palette)
|
||||
assert palette.shape[0] == len(self.CLASSES)
|
||||
assert palette.shape[1] == 3
|
||||
assert len(palette.shape) == 2
|
||||
assert 0 < opacity <= 1.0
|
||||
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
||||
for label, color in enumerate(palette):
|
||||
color_seg[seg == label, :] = color
|
||||
# convert to BGR
|
||||
color_seg = color_seg[..., ::-1]
|
||||
|
||||
img = img * (1 - opacity) + color_seg * opacity
|
||||
img = img.astype(np.uint8)
|
||||
# if out_file specified, do not show image in window
|
||||
if out_file is not None:
|
||||
show = False
|
||||
|
||||
if show:
|
||||
mmcv.imshow(img, win_name, wait_time)
|
||||
if out_file is not None:
|
||||
mmcv.imwrite(img, out_file)
|
||||
|
||||
if not (show or out_file):
|
||||
return img
|
||||
|
||||
|
||||
@PREDICTORS.register_module()
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
|
||||
import jsonplus
|
||||
|
||||
from easycv.file import io
|
||||
from easycv.utils.config_tools import Config
|
||||
|
||||
MODELSCOPE_PREFIX = 'modelscope'
|
||||
EASYCV_ARCH = '__easycv_arch__'
|
||||
|
||||
|
||||
def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
|
||||
"""Convert EasyCV config to ModelScope style.
|
||||
|
||||
Args:
|
||||
cfg (str | Config): Easycv config file or Config object.
|
||||
task (str): Task name in modelscope, refer to: modelscope.utils.constant.Tasks.
|
||||
ms_model_name (str): Model name registered in modelscope, model type will be replaced with `ms_model_name`, used in modelscope.
|
||||
save_path (str): Save path for saving the generated modelscope configuration file. Only valid when dump is True.
|
||||
dump (bool): Whether dump the converted config to `save_path`.
|
||||
"""
|
||||
# TODO: support multi eval_pipelines
|
||||
# TODO: support for adding customized required keys to the configuration file
|
||||
|
||||
if isinstance(cfg, str):
|
||||
easycv_cfg = Config.fromfile(cfg)
|
||||
if dump and save_path is None:
|
||||
save_dir = os.path.dirname(cfg)
|
||||
save_name = MODELSCOPE_PREFIX + '_' + os.path.splitext(
|
||||
os.path.basename(cfg))[0] + '.json'
|
||||
save_path = os.path.join(save_dir, save_name)
|
||||
else:
|
||||
easycv_cfg = cfg
|
||||
if dump and save_path is None:
|
||||
raise ValueError('Please provide `save_path`!')
|
||||
|
||||
assert save_path.endswith('json'), 'Only support json file!'
|
||||
optimizer_options = easycv_cfg.optimizer_config
|
||||
optimizer_options.update({'loss_keys': 'total_loss'})
|
||||
|
||||
val_dataset_cfg = easycv_cfg.data.val
|
||||
val_imgs_per_gpu = val_dataset_cfg.pop('imgs_per_gpu',
|
||||
easycv_cfg.data.imgs_per_gpu)
|
||||
val_workers_per_gpu = val_dataset_cfg.pop('workers_per_gpu',
|
||||
easycv_cfg.data.workers_per_gpu)
|
||||
|
||||
log_config = easycv_cfg.log_config
|
||||
predict_config = easycv_cfg.get('predict', None)
|
||||
|
||||
hooks = [{
|
||||
'type': 'CheckpointHook',
|
||||
'interval': easycv_cfg.checkpoint_config.interval
|
||||
}, {
|
||||
'type': 'EvaluationHook',
|
||||
'interval': easycv_cfg.eval_config.interval
|
||||
}, {
|
||||
'type': 'AddLrLogHook'
|
||||
}, {
|
||||
'type': 'IterTimerHook'
|
||||
}]
|
||||
|
||||
custom_hooks = easycv_cfg.get('custom_hooks', [])
|
||||
hooks.extend(custom_hooks)
|
||||
|
||||
for log_hook_i in log_config.hooks:
|
||||
if log_hook_i['type'] == 'TensorboardLoggerHook':
|
||||
# replace with modelscope api
|
||||
hooks.append({
|
||||
'type': 'TensorboardHook',
|
||||
'interval': log_config.interval
|
||||
})
|
||||
elif log_hook_i['type'] == 'TextLoggerHook':
|
||||
# use modelscope api
|
||||
hooks.append({
|
||||
'type': 'TextLoggerHook',
|
||||
'interval': log_config.interval
|
||||
})
|
||||
else:
|
||||
log_hook_i.update({'interval': log_config.interval})
|
||||
hooks.append(log_hook_i)
|
||||
|
||||
ori_model_type = easycv_cfg.model.pop('type')
|
||||
|
||||
ms_cfg = Config(
|
||||
dict(
|
||||
task=task,
|
||||
framework='pytorch',
|
||||
model={
|
||||
'type': ms_model_name,
|
||||
**easycv_cfg.model, EASYCV_ARCH: {
|
||||
'type': ori_model_type
|
||||
}
|
||||
},
|
||||
dataset=dict(train=easycv_cfg.data.train, val=val_dataset_cfg),
|
||||
train=dict(
|
||||
work_dir=easycv_cfg.get('work_dir', None),
|
||||
max_epochs=easycv_cfg.total_epochs,
|
||||
dataloader=dict(
|
||||
batch_size_per_gpu=easycv_cfg.data.imgs_per_gpu,
|
||||
workers_per_gpu=easycv_cfg.data.workers_per_gpu,
|
||||
),
|
||||
optimizer=dict(
|
||||
**easycv_cfg.optimizer, options=optimizer_options),
|
||||
lr_scheduler=easycv_cfg.lr_config,
|
||||
hooks=hooks),
|
||||
evaluation=dict(
|
||||
dataloader=dict(
|
||||
batch_size_per_gpu=val_imgs_per_gpu,
|
||||
workers_per_gpu=val_workers_per_gpu,
|
||||
),
|
||||
metrics={
|
||||
'type': 'EasyCVMetric',
|
||||
'evaluators': easycv_cfg.eval_pipelines[0].evaluators
|
||||
}),
|
||||
pipeline=dict(predictor_config=predict_config),
|
||||
))
|
||||
|
||||
if dump:
|
||||
with io.open(save_path, 'w') as f:
|
||||
res = jsonplus.dumps(
|
||||
ms_cfg._cfg_dict.to_dict(), indent=4, sort_keys=False)
|
||||
f.write(res)
|
||||
|
||||
return ms_cfg
|
|
@ -1,2 +1,4 @@
|
|||
http://pai-nni.oss-cn-zhangjiakou.aliyuncs.com/release/2.6.1/pai_nni-2.6.1-py3-none-manylinux1_x86_64.whl
|
||||
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/panopticapi/panopticapi-0.1-py3-none-any.whl
|
||||
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/blade_compression-0.0.2-py3-none-any.whl
|
||||
https://developer.download.nvidia.com/compute/redist/nvidia-dali-cuda100/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl
|
||||
|
|
|
@ -3,21 +3,19 @@ dataclasses
|
|||
einops
|
||||
future
|
||||
h5py
|
||||
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/panopticapi/panopticapi-0.1-py3-none-any.whl
|
||||
https://developer.download.nvidia.com/compute/redist/nvidia-dali-cuda100/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl
|
||||
json_tricks
|
||||
numpy
|
||||
opencv-python-headless
|
||||
opencv-python
|
||||
oss2
|
||||
packaging
|
||||
Pillow
|
||||
prettytable
|
||||
pycocotools
|
||||
pytorch_metric_learning==0.9.89
|
||||
pytorch_metric_learning>=0.9.89
|
||||
scikit-image
|
||||
sklearn
|
||||
tensorboard
|
||||
thop
|
||||
timm==0.4.9
|
||||
timm>=0.4.9
|
||||
xtcocotools
|
||||
yacs
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tests.ut_config import (MODEL_CONFIG_SEGFORMER,
|
||||
PRETRAINED_MODEL_SEGFORMER, TEST_IMAGES_DIR)
|
||||
|
||||
from easycv.predictors.segmentation import SegmentationPredictor
|
||||
|
||||
|
||||
class SegmentationPredictorTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def test_single(self):
|
||||
segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
|
||||
segmentation_model_config = MODEL_CONFIG_SEGFORMER
|
||||
|
||||
img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
|
||||
img = np.asarray(Image.open(img_path))
|
||||
|
||||
predict_pipeline = SegmentationPredictor(
|
||||
model_path=segmentation_model_path,
|
||||
config_file=segmentation_model_config)
|
||||
|
||||
outputs = predict_pipeline(img_path, keep_inputs=True)
|
||||
self.assertEqual(len(outputs), 1)
|
||||
self.assertEqual(outputs[0]['inputs'], [img_path])
|
||||
|
||||
results = outputs[0]['results']
|
||||
self.assertListEqual(
|
||||
list(img.shape)[:2], list(results['seg_pred'][0].shape))
|
||||
self.assertListEqual(results['seg_pred'][0][1, :10].tolist(),
|
||||
[161 for i in range(10)])
|
||||
self.assertListEqual(results['seg_pred'][0][-1, -10:].tolist(),
|
||||
[133 for i in range(10)])
|
||||
|
||||
def test_batch(self):
|
||||
segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
|
||||
segmentation_model_config = MODEL_CONFIG_SEGFORMER
|
||||
|
||||
img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
|
||||
img = np.asarray(Image.open(img_path))
|
||||
|
||||
predict_pipeline = SegmentationPredictor(
|
||||
model_path=segmentation_model_path,
|
||||
config_file=segmentation_model_config,
|
||||
batch_size=2)
|
||||
|
||||
total_samples = 3
|
||||
outputs = predict_pipeline(
|
||||
[img_path] * total_samples, keep_inputs=True)
|
||||
self.assertEqual(len(outputs), 2)
|
||||
|
||||
self.assertEqual(outputs[0]['inputs'], [img_path] * 2)
|
||||
self.assertEqual(outputs[1]['inputs'], [img_path] * 1)
|
||||
self.assertEqual(len(outputs[0]['results']['seg_pred']), 2)
|
||||
self.assertEqual(len(outputs[1]['results']['seg_pred']), 1)
|
||||
|
||||
for result in [outputs[0]['results'], outputs[1]['results']]:
|
||||
self.assertListEqual(
|
||||
list(img.shape)[:2], list(result['seg_pred'][0].shape))
|
||||
self.assertListEqual(result['seg_pred'][0][1, :10].tolist(),
|
||||
[161 for i in range(10)])
|
||||
self.assertListEqual(result['seg_pred'][0][-1, -10:].tolist(),
|
||||
[133 for i in range(10)])
|
||||
|
||||
def test_dump(self):
|
||||
segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
|
||||
segmentation_model_config = MODEL_CONFIG_SEGFORMER
|
||||
|
||||
img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
|
||||
|
||||
temp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir)
|
||||
tmp_path = os.path.join(temp_dir, 'results.pkl')
|
||||
|
||||
predict_pipeline = SegmentationPredictor(
|
||||
model_path=segmentation_model_path,
|
||||
config_file=segmentation_model_config,
|
||||
batch_size=2,
|
||||
save_results=True,
|
||||
save_path=tmp_path)
|
||||
|
||||
total_samples = 3
|
||||
outputs = predict_pipeline(
|
||||
[img_path] * total_samples, keep_inputs=True)
|
||||
self.assertEqual(outputs, [])
|
||||
|
||||
with open(tmp_path, 'rb') as f:
|
||||
results = pickle.loads(f.read())
|
||||
|
||||
self.assertIn('inputs', results[0])
|
||||
self.assertIn('results', results[0])
|
||||
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import easycv
|
||||
from easycv.utils.config_tools import Config
|
||||
from easycv.utils.ms_utils import to_ms_config
|
||||
|
||||
|
||||
class MsConfigTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
self.tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(self.tmp_dir):
|
||||
os.makedirs(self.tmp_dir)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
|
||||
def test_to_ms_config(self):
|
||||
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(easycv.__file__)),
|
||||
'configs/detection/yolox/yolox_s_8xb16_300e_coco.py')
|
||||
|
||||
ms_cfg_file = os.path.join(self.tmp_dir,
|
||||
'ms_yolox_s_8xb16_300e_coco.json')
|
||||
to_ms_config(
|
||||
config_path,
|
||||
task='image-object-detection',
|
||||
ms_model_name='yolox',
|
||||
save_path=ms_cfg_file)
|
||||
cfg = Config.fromfile(ms_cfg_file)
|
||||
self.assertIn('task', cfg)
|
||||
self.assertIn('framework', cfg)
|
||||
self.assertEqual(cfg.model.type, 'yolox')
|
||||
self.assertIn('dataset', cfg)
|
||||
self.assertIn('batch_size_per_gpu', cfg.train.dataloader)
|
||||
self.assertIn('batch_size_per_gpu', cfg.evaluation.dataloader)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue