mirror of https://github.com/alibaba/EasyCV.git
add predict pipeline
Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9828601 * add predict pipelinepull/191/head
parent
b3abdf507f
commit
0f74adb848
|
@ -233,6 +233,8 @@ eval_pipelines = [
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
predict = dict(type='SegmentationPredictor')
|
||||||
|
|
||||||
log_config = dict(
|
log_config = dict(
|
||||||
interval=50,
|
interval=50,
|
||||||
hooks=[
|
hooks=[
|
||||||
|
|
|
@ -233,6 +233,8 @@ eval_pipelines = [
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
predict = dict(type='SegmentationPredictor')
|
||||||
|
|
||||||
log_config = dict(
|
log_config = dict(
|
||||||
interval=50,
|
interval=50,
|
||||||
hooks=[
|
hooks=[
|
||||||
|
|
|
@ -4,4 +4,4 @@ from .dali_transforms import (DaliColorTwist, DaliCropMirrorNormalize,
|
||||||
DaliImageDecoder, DaliRandomGrayscale,
|
DaliImageDecoder, DaliRandomGrayscale,
|
||||||
DaliRandomResizedCrop, DaliResize)
|
DaliRandomResizedCrop, DaliResize)
|
||||||
from .format import Collect, DefaultFormatBundle, ImageToTensor
|
from .format import Collect, DefaultFormatBundle, ImageToTensor
|
||||||
from .transforms import Compose
|
from .transforms import Compose, LoadImage
|
||||||
|
|
|
@ -2,7 +2,10 @@
|
||||||
import time
|
import time
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from easycv.datasets.registry import PIPELINES
|
from easycv.datasets.registry import PIPELINES
|
||||||
|
from easycv.file.image import load_image
|
||||||
from easycv.utils.registry import build_from_cfg
|
from easycv.utils.registry import build_from_cfg
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,3 +51,49 @@ class Compose(object):
|
||||||
format_string += f'\n {t}'
|
format_string += f'\n {t}'
|
||||||
format_string += '\n)'
|
format_string += '\n)'
|
||||||
return format_string
|
return format_string
|
||||||
|
|
||||||
|
|
||||||
|
@PIPELINES.register_module()
|
||||||
|
class LoadImage:
|
||||||
|
"""Load an image from file or numpy or PIL object.
|
||||||
|
Args:
|
||||||
|
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||||
|
numpy array. If set to False, the loaded image is an uint8 array.
|
||||||
|
Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, to_float32=False, mode='bgr'):
|
||||||
|
self.to_float32 = to_float32
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def __call__(self, results):
|
||||||
|
"""Call functions to load image and get image meta information.
|
||||||
|
Returns:
|
||||||
|
dict: The dict contains loaded image and meta information.
|
||||||
|
"""
|
||||||
|
filename = results.get('filename', None)
|
||||||
|
img = results.get('img', None)
|
||||||
|
|
||||||
|
if img is not None:
|
||||||
|
if not isinstance(img, np.ndarray):
|
||||||
|
img = np.asarray(img, dtype=np.uint8)
|
||||||
|
elif filename is None:
|
||||||
|
raise ValueError('Please provide "filename" or "img"!')
|
||||||
|
|
||||||
|
img = load_image(filename, mode=self.mode)
|
||||||
|
if self.to_float32:
|
||||||
|
img = img.astype(np.float32)
|
||||||
|
|
||||||
|
results['filename'] = filename
|
||||||
|
results['img'] = img
|
||||||
|
results['img_shape'] = img.shape
|
||||||
|
results['ori_shape'] = img.shape
|
||||||
|
results['img_fields'] = ['img']
|
||||||
|
return results
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
repr_str = (f'{self.__class__.__name__}('
|
||||||
|
f'to_float32={self.to_float32}, '
|
||||||
|
f"mode='{self.mode}'")
|
||||||
|
|
||||||
|
return repr_str
|
||||||
|
|
|
@ -7,4 +7,5 @@ from .feature_extractor import (TorchFaceAttrExtractor,
|
||||||
TorchFeatureExtractor)
|
TorchFeatureExtractor)
|
||||||
from .pose_predictor import (TorchPoseTopDownPredictor,
|
from .pose_predictor import (TorchPoseTopDownPredictor,
|
||||||
TorchPoseTopDownPredictorWithDetector)
|
TorchPoseTopDownPredictorWithDetector)
|
||||||
from .segmentation import Mask2formerPredictor, SegFormerPredictor
|
from .segmentation import (Mask2formerPredictor, SegFormerPredictor,
|
||||||
|
SegmentationPredictor)
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from mmcv.parallel import collate, scatter_kwargs
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.transforms import Compose
|
from torchvision.transforms import Compose
|
||||||
|
|
||||||
from easycv.datasets.registry import PIPELINES
|
from easycv.datasets.registry import PIPELINES
|
||||||
from easycv.file import io
|
from easycv.file import io
|
||||||
from easycv.models import build_model
|
from easycv.models.builder import build_model
|
||||||
from easycv.utils.checkpoint import load_checkpoint
|
from easycv.utils.checkpoint import load_checkpoint
|
||||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||||
from easycv.utils.constant import CACHE_DIR
|
from easycv.utils.constant import CACHE_DIR
|
||||||
|
@ -91,3 +93,175 @@ class Predictor(object):
|
||||||
output = self.model.forward(
|
output = self.model.forward(
|
||||||
image_batch.to(self.device), **forward_kwargs)
|
image_batch.to(self.device), **forward_kwargs)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class PredictorV2(object):
|
||||||
|
"""Base predict pipeline.
|
||||||
|
Args:
|
||||||
|
model_path (str): Path of model path.
|
||||||
|
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
|
||||||
|
batch_size (int): batch size for forward.
|
||||||
|
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
|
||||||
|
save_results (bool): Whether to save predict results.
|
||||||
|
save_path (str): File path for saving results, only valid when `save_results` is True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_path,
|
||||||
|
config_file=None,
|
||||||
|
batch_size=1,
|
||||||
|
device=None,
|
||||||
|
save_results=False,
|
||||||
|
save_path=None,
|
||||||
|
*args,
|
||||||
|
**kwargs):
|
||||||
|
self.model_path = model_path
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.save_results = save_results
|
||||||
|
self.save_path = save_path
|
||||||
|
if self.save_results:
|
||||||
|
assert self.save_path is not None
|
||||||
|
self.device = device
|
||||||
|
if self.device is None:
|
||||||
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
self.cfg = None
|
||||||
|
if config_file is not None:
|
||||||
|
if isinstance(config_file, str):
|
||||||
|
self.cfg = mmcv_config_fromfile(config_file)
|
||||||
|
else:
|
||||||
|
self.cfg = config_file
|
||||||
|
|
||||||
|
self.model = self.prepare_model()
|
||||||
|
self.processor = self.build_processor()
|
||||||
|
self._load_op = None
|
||||||
|
|
||||||
|
def prepare_model(self):
|
||||||
|
"""Build model from config file by default.
|
||||||
|
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
|
||||||
|
"""
|
||||||
|
model = self._build_model()
|
||||||
|
model.to(self.device)
|
||||||
|
model.eval()
|
||||||
|
load_checkpoint(model, self.model_path, map_location='cpu')
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _build_model(self):
|
||||||
|
if self.cfg is None:
|
||||||
|
raise ValueError('Please provide "config_file"!')
|
||||||
|
|
||||||
|
model = build_model(self.cfg.model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def build_processor(self):
|
||||||
|
"""Build processor to process loaded input.
|
||||||
|
If you need custom preprocessing ops, you need to reimplement it.
|
||||||
|
"""
|
||||||
|
if self.cfg is None:
|
||||||
|
pipeline = []
|
||||||
|
else:
|
||||||
|
pipeline = [
|
||||||
|
build_from_cfg(p, PIPELINES)
|
||||||
|
for p in self.cfg.get('test_pipeline', [])
|
||||||
|
]
|
||||||
|
|
||||||
|
from easycv.datasets.shared.pipelines.transforms import Compose
|
||||||
|
processor = Compose(pipeline)
|
||||||
|
return processor
|
||||||
|
|
||||||
|
def _load_input(self, input):
|
||||||
|
"""Load image from file or numpy or PIL object.
|
||||||
|
Args:
|
||||||
|
input: File path or numpy or PIL object.
|
||||||
|
Returns:
|
||||||
|
{
|
||||||
|
'filename': filename,
|
||||||
|
'img': img,
|
||||||
|
'img_shape': img_shape,
|
||||||
|
'img_fields': ['img']
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if self._load_op is None:
|
||||||
|
load_cfg = dict(type='LoadImage', mode='rgb')
|
||||||
|
self._load_op = build_from_cfg(load_cfg, PIPELINES)
|
||||||
|
|
||||||
|
if not isinstance(input, str):
|
||||||
|
sample = self._load_op({'img': input})
|
||||||
|
else:
|
||||||
|
sample = self._load_op({'filename': input})
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def preprocess_single(self, input):
|
||||||
|
"""Preprocess single input sample.
|
||||||
|
If you need custom ops to load or process a single input sample, you need to reimplement it.
|
||||||
|
"""
|
||||||
|
input = self._load_input(input)
|
||||||
|
return self.processor(input)
|
||||||
|
|
||||||
|
def preprocess(self, inputs, *args, **kwargs):
|
||||||
|
"""Process all inputs list. And collate to batch and put to target device.
|
||||||
|
If you need custom ops to load or process a batch samples, you need to reimplement it.
|
||||||
|
"""
|
||||||
|
batch_outputs = []
|
||||||
|
for i in inputs:
|
||||||
|
batch_outputs.append(self.preprocess_single(i, *args, **kwargs))
|
||||||
|
|
||||||
|
batch_outputs = self._collate_fn(batch_outputs)
|
||||||
|
batch_outputs = self._to_device(batch_outputs)
|
||||||
|
|
||||||
|
return batch_outputs
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
"""Model forward.
|
||||||
|
If you need refactor model forward, you need to reimplement it.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**inputs, mode='test')
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def postprocess(self, inputs, *args, **kwargs):
|
||||||
|
"""Process model outputs.
|
||||||
|
If you need add some processing ops to process model outputs, you need to reimplement it.
|
||||||
|
"""
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def _collate_fn(self, inputs):
|
||||||
|
"""Prepare the input just before the forward function.
|
||||||
|
Puts each data field into a tensor with outer dimension batch size
|
||||||
|
"""
|
||||||
|
return collate(inputs, samples_per_gpu=self.batch_size)
|
||||||
|
|
||||||
|
def _to_device(self, inputs):
|
||||||
|
target_gpus = [-1] if self.device == 'cpu' else [
|
||||||
|
torch.cuda.current_device()
|
||||||
|
]
|
||||||
|
_, kwargs = scatter_kwargs(None, inputs, target_gpus=target_gpus)
|
||||||
|
return kwargs[0]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(obj, save_path, mode='wb'):
|
||||||
|
with open(save_path, mode) as f:
|
||||||
|
f.write(pickle.dumps(obj))
|
||||||
|
|
||||||
|
def __call__(self, inputs, keep_inputs=False):
|
||||||
|
# TODO: fault tolerance
|
||||||
|
|
||||||
|
if isinstance(inputs, str):
|
||||||
|
inputs = [inputs]
|
||||||
|
|
||||||
|
results_list = []
|
||||||
|
for i in range(0, len(inputs), self.batch_size):
|
||||||
|
batch = inputs[i:max(len(inputs) - 1, i + self.batch_size)]
|
||||||
|
batch_outputs = self.preprocess(batch)
|
||||||
|
batch_outputs = self.forward(batch_outputs)
|
||||||
|
results = self.postprocess(batch_outputs)
|
||||||
|
if keep_inputs:
|
||||||
|
results = {'inputs': batch, 'results': results}
|
||||||
|
# if dump, the outputs will not added to the return value to prevent taking up too much memory
|
||||||
|
if self.save_results:
|
||||||
|
self.dump([results], self.save_path, mode='ab+')
|
||||||
|
else:
|
||||||
|
results_list.append(results)
|
||||||
|
|
||||||
|
return results_list
|
||||||
|
|
|
@ -4,5 +4,5 @@ from easycv.utils.registry import Registry, build_from_cfg
|
||||||
PREDICTORS = Registry('predictor')
|
PREDICTORS = Registry('predictor')
|
||||||
|
|
||||||
|
|
||||||
def build_predictor(cfg):
|
def build_predictor(cfg, default_args=None):
|
||||||
return build_from_cfg(cfg, PREDICTORS, default_args=None)
|
return build_from_cfg(cfg, PREDICTORS, default_args=default_args)
|
||||||
|
|
|
@ -16,6 +16,110 @@ from easycv.predictors.interface import PredictorInterface
|
||||||
from easycv.utils.checkpoint import load_checkpoint
|
from easycv.utils.checkpoint import load_checkpoint
|
||||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||||
from easycv.utils.registry import build_from_cfg
|
from easycv.utils.registry import build_from_cfg
|
||||||
|
from .base import PredictorV2
|
||||||
|
|
||||||
|
|
||||||
|
@PREDICTORS.register_module()
|
||||||
|
class SegmentationPredictor(PredictorV2):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_path,
|
||||||
|
config_file,
|
||||||
|
batch_size=1,
|
||||||
|
device=None,
|
||||||
|
save_results=False,
|
||||||
|
save_path=None):
|
||||||
|
"""Predict pipeline for Segmentation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str): Path of model path
|
||||||
|
config_file (str): config file path for model and processor to init. Defaults to None.
|
||||||
|
"""
|
||||||
|
super(SegmentationPredictor, self).__init__(
|
||||||
|
model_path,
|
||||||
|
config_file,
|
||||||
|
batch_size=batch_size,
|
||||||
|
device=device,
|
||||||
|
save_results=save_results,
|
||||||
|
save_path=save_path)
|
||||||
|
|
||||||
|
self.CLASSES = self.cfg.CLASSES
|
||||||
|
self.PALETTE = self.cfg.PALETTE
|
||||||
|
|
||||||
|
def show_result(self,
|
||||||
|
img,
|
||||||
|
result,
|
||||||
|
palette=None,
|
||||||
|
win_name='',
|
||||||
|
show=False,
|
||||||
|
wait_time=0,
|
||||||
|
out_file=None,
|
||||||
|
opacity=0.5):
|
||||||
|
"""Draw `result` over `img`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (str or Tensor): The image to be displayed.
|
||||||
|
result (Tensor): The semantic segmentation results to draw over
|
||||||
|
`img`.
|
||||||
|
palette (list[list[int]]] | np.ndarray | None): The palette of
|
||||||
|
segmentation map. If None is given, random palette will be
|
||||||
|
generated. Default: None
|
||||||
|
win_name (str): The window name.
|
||||||
|
wait_time (int): Value of waitKey param.
|
||||||
|
Default: 0.
|
||||||
|
show (bool): Whether to show the image.
|
||||||
|
Default: False.
|
||||||
|
out_file (str or None): The filename to write the image.
|
||||||
|
Default: None.
|
||||||
|
opacity(float): Opacity of painted segmentation map.
|
||||||
|
Default 0.5.
|
||||||
|
Must be in (0, 1] range.
|
||||||
|
Returns:
|
||||||
|
img (Tensor): Only if not `show` or `out_file`
|
||||||
|
"""
|
||||||
|
|
||||||
|
img = mmcv.imread(img)
|
||||||
|
img = img.copy()
|
||||||
|
seg = result[0]
|
||||||
|
if palette is None:
|
||||||
|
if self.PALETTE is None:
|
||||||
|
# Get random state before set seed,
|
||||||
|
# and restore random state later.
|
||||||
|
# It will prevent loss of randomness, as the palette
|
||||||
|
# may be different in each iteration if not specified.
|
||||||
|
# See: https://github.com/open-mmlab/mmdetection/issues/5844
|
||||||
|
state = np.random.get_state()
|
||||||
|
np.random.seed(42)
|
||||||
|
# random palette
|
||||||
|
palette = np.random.randint(
|
||||||
|
0, 255, size=(len(self.CLASSES), 3))
|
||||||
|
np.random.set_state(state)
|
||||||
|
else:
|
||||||
|
palette = self.PALETTE
|
||||||
|
palette = np.array(palette)
|
||||||
|
assert palette.shape[0] == len(self.CLASSES)
|
||||||
|
assert palette.shape[1] == 3
|
||||||
|
assert len(palette.shape) == 2
|
||||||
|
assert 0 < opacity <= 1.0
|
||||||
|
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
||||||
|
for label, color in enumerate(palette):
|
||||||
|
color_seg[seg == label, :] = color
|
||||||
|
# convert to BGR
|
||||||
|
color_seg = color_seg[..., ::-1]
|
||||||
|
|
||||||
|
img = img * (1 - opacity) + color_seg * opacity
|
||||||
|
img = img.astype(np.uint8)
|
||||||
|
# if out_file specified, do not show image in window
|
||||||
|
if out_file is not None:
|
||||||
|
show = False
|
||||||
|
|
||||||
|
if show:
|
||||||
|
mmcv.imshow(img, win_name, wait_time)
|
||||||
|
if out_file is not None:
|
||||||
|
mmcv.imwrite(img, out_file)
|
||||||
|
|
||||||
|
if not (show or out_file):
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
@PREDICTORS.register_module()
|
@PREDICTORS.register_module()
|
||||||
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import os
|
||||||
|
|
||||||
|
import jsonplus
|
||||||
|
|
||||||
|
from easycv.file import io
|
||||||
|
from easycv.utils.config_tools import Config
|
||||||
|
|
||||||
|
MODELSCOPE_PREFIX = 'modelscope'
|
||||||
|
EASYCV_ARCH = '__easycv_arch__'
|
||||||
|
|
||||||
|
|
||||||
|
def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
|
||||||
|
"""Convert EasyCV config to ModelScope style.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (str | Config): Easycv config file or Config object.
|
||||||
|
task (str): Task name in modelscope, refer to: modelscope.utils.constant.Tasks.
|
||||||
|
ms_model_name (str): Model name registered in modelscope, model type will be replaced with `ms_model_name`, used in modelscope.
|
||||||
|
save_path (str): Save path for saving the generated modelscope configuration file. Only valid when dump is True.
|
||||||
|
dump (bool): Whether dump the converted config to `save_path`.
|
||||||
|
"""
|
||||||
|
# TODO: support multi eval_pipelines
|
||||||
|
# TODO: support for adding customized required keys to the configuration file
|
||||||
|
|
||||||
|
if isinstance(cfg, str):
|
||||||
|
easycv_cfg = Config.fromfile(cfg)
|
||||||
|
if dump and save_path is None:
|
||||||
|
save_dir = os.path.dirname(cfg)
|
||||||
|
save_name = MODELSCOPE_PREFIX + '_' + os.path.splitext(
|
||||||
|
os.path.basename(cfg))[0] + '.json'
|
||||||
|
save_path = os.path.join(save_dir, save_name)
|
||||||
|
else:
|
||||||
|
easycv_cfg = cfg
|
||||||
|
if dump and save_path is None:
|
||||||
|
raise ValueError('Please provide `save_path`!')
|
||||||
|
|
||||||
|
assert save_path.endswith('json'), 'Only support json file!'
|
||||||
|
optimizer_options = easycv_cfg.optimizer_config
|
||||||
|
optimizer_options.update({'loss_keys': 'total_loss'})
|
||||||
|
|
||||||
|
val_dataset_cfg = easycv_cfg.data.val
|
||||||
|
val_imgs_per_gpu = val_dataset_cfg.pop('imgs_per_gpu',
|
||||||
|
easycv_cfg.data.imgs_per_gpu)
|
||||||
|
val_workers_per_gpu = val_dataset_cfg.pop('workers_per_gpu',
|
||||||
|
easycv_cfg.data.workers_per_gpu)
|
||||||
|
|
||||||
|
log_config = easycv_cfg.log_config
|
||||||
|
predict_config = easycv_cfg.get('predict', None)
|
||||||
|
|
||||||
|
hooks = [{
|
||||||
|
'type': 'CheckpointHook',
|
||||||
|
'interval': easycv_cfg.checkpoint_config.interval
|
||||||
|
}, {
|
||||||
|
'type': 'EvaluationHook',
|
||||||
|
'interval': easycv_cfg.eval_config.interval
|
||||||
|
}, {
|
||||||
|
'type': 'AddLrLogHook'
|
||||||
|
}, {
|
||||||
|
'type': 'IterTimerHook'
|
||||||
|
}]
|
||||||
|
|
||||||
|
custom_hooks = easycv_cfg.get('custom_hooks', [])
|
||||||
|
hooks.extend(custom_hooks)
|
||||||
|
|
||||||
|
for log_hook_i in log_config.hooks:
|
||||||
|
if log_hook_i['type'] == 'TensorboardLoggerHook':
|
||||||
|
# replace with modelscope api
|
||||||
|
hooks.append({
|
||||||
|
'type': 'TensorboardHook',
|
||||||
|
'interval': log_config.interval
|
||||||
|
})
|
||||||
|
elif log_hook_i['type'] == 'TextLoggerHook':
|
||||||
|
# use modelscope api
|
||||||
|
hooks.append({
|
||||||
|
'type': 'TextLoggerHook',
|
||||||
|
'interval': log_config.interval
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
log_hook_i.update({'interval': log_config.interval})
|
||||||
|
hooks.append(log_hook_i)
|
||||||
|
|
||||||
|
ori_model_type = easycv_cfg.model.pop('type')
|
||||||
|
|
||||||
|
ms_cfg = Config(
|
||||||
|
dict(
|
||||||
|
task=task,
|
||||||
|
framework='pytorch',
|
||||||
|
model={
|
||||||
|
'type': ms_model_name,
|
||||||
|
**easycv_cfg.model, EASYCV_ARCH: {
|
||||||
|
'type': ori_model_type
|
||||||
|
}
|
||||||
|
},
|
||||||
|
dataset=dict(train=easycv_cfg.data.train, val=val_dataset_cfg),
|
||||||
|
train=dict(
|
||||||
|
work_dir=easycv_cfg.get('work_dir', None),
|
||||||
|
max_epochs=easycv_cfg.total_epochs,
|
||||||
|
dataloader=dict(
|
||||||
|
batch_size_per_gpu=easycv_cfg.data.imgs_per_gpu,
|
||||||
|
workers_per_gpu=easycv_cfg.data.workers_per_gpu,
|
||||||
|
),
|
||||||
|
optimizer=dict(
|
||||||
|
**easycv_cfg.optimizer, options=optimizer_options),
|
||||||
|
lr_scheduler=easycv_cfg.lr_config,
|
||||||
|
hooks=hooks),
|
||||||
|
evaluation=dict(
|
||||||
|
dataloader=dict(
|
||||||
|
batch_size_per_gpu=val_imgs_per_gpu,
|
||||||
|
workers_per_gpu=val_workers_per_gpu,
|
||||||
|
),
|
||||||
|
metrics={
|
||||||
|
'type': 'EasyCVMetric',
|
||||||
|
'evaluators': easycv_cfg.eval_pipelines[0].evaluators
|
||||||
|
}),
|
||||||
|
pipeline=dict(predictor_config=predict_config),
|
||||||
|
))
|
||||||
|
|
||||||
|
if dump:
|
||||||
|
with io.open(save_path, 'w') as f:
|
||||||
|
res = jsonplus.dumps(
|
||||||
|
ms_cfg._cfg_dict.to_dict(), indent=4, sort_keys=False)
|
||||||
|
f.write(res)
|
||||||
|
|
||||||
|
return ms_cfg
|
|
@ -1,2 +1,4 @@
|
||||||
http://pai-nni.oss-cn-zhangjiakou.aliyuncs.com/release/2.6.1/pai_nni-2.6.1-py3-none-manylinux1_x86_64.whl
|
http://pai-nni.oss-cn-zhangjiakou.aliyuncs.com/release/2.6.1/pai_nni-2.6.1-py3-none-manylinux1_x86_64.whl
|
||||||
|
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/panopticapi/panopticapi-0.1-py3-none-any.whl
|
||||||
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/blade_compression-0.0.2-py3-none-any.whl
|
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/blade_compression-0.0.2-py3-none-any.whl
|
||||||
|
https://developer.download.nvidia.com/compute/redist/nvidia-dali-cuda100/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl
|
||||||
|
|
|
@ -3,21 +3,19 @@ dataclasses
|
||||||
einops
|
einops
|
||||||
future
|
future
|
||||||
h5py
|
h5py
|
||||||
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/panopticapi/panopticapi-0.1-py3-none-any.whl
|
|
||||||
https://developer.download.nvidia.com/compute/redist/nvidia-dali-cuda100/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl
|
|
||||||
json_tricks
|
json_tricks
|
||||||
numpy
|
numpy
|
||||||
opencv-python-headless
|
opencv-python
|
||||||
oss2
|
oss2
|
||||||
packaging
|
packaging
|
||||||
Pillow
|
Pillow
|
||||||
prettytable
|
prettytable
|
||||||
pycocotools
|
pycocotools
|
||||||
pytorch_metric_learning==0.9.89
|
pytorch_metric_learning>=0.9.89
|
||||||
scikit-image
|
scikit-image
|
||||||
sklearn
|
sklearn
|
||||||
tensorboard
|
tensorboard
|
||||||
thop
|
thop
|
||||||
timm==0.4.9
|
timm>=0.4.9
|
||||||
xtcocotools
|
xtcocotools
|
||||||
yacs
|
yacs
|
||||||
|
|
|
@ -0,0 +1,107 @@
|
||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from tests.ut_config import (MODEL_CONFIG_SEGFORMER,
|
||||||
|
PRETRAINED_MODEL_SEGFORMER, TEST_IMAGES_DIR)
|
||||||
|
|
||||||
|
from easycv.predictors.segmentation import SegmentationPredictor
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentationPredictorTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||||
|
|
||||||
|
def test_single(self):
|
||||||
|
segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
|
||||||
|
segmentation_model_config = MODEL_CONFIG_SEGFORMER
|
||||||
|
|
||||||
|
img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
|
||||||
|
img = np.asarray(Image.open(img_path))
|
||||||
|
|
||||||
|
predict_pipeline = SegmentationPredictor(
|
||||||
|
model_path=segmentation_model_path,
|
||||||
|
config_file=segmentation_model_config)
|
||||||
|
|
||||||
|
outputs = predict_pipeline(img_path, keep_inputs=True)
|
||||||
|
self.assertEqual(len(outputs), 1)
|
||||||
|
self.assertEqual(outputs[0]['inputs'], [img_path])
|
||||||
|
|
||||||
|
results = outputs[0]['results']
|
||||||
|
self.assertListEqual(
|
||||||
|
list(img.shape)[:2], list(results['seg_pred'][0].shape))
|
||||||
|
self.assertListEqual(results['seg_pred'][0][1, :10].tolist(),
|
||||||
|
[161 for i in range(10)])
|
||||||
|
self.assertListEqual(results['seg_pred'][0][-1, -10:].tolist(),
|
||||||
|
[133 for i in range(10)])
|
||||||
|
|
||||||
|
def test_batch(self):
|
||||||
|
segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
|
||||||
|
segmentation_model_config = MODEL_CONFIG_SEGFORMER
|
||||||
|
|
||||||
|
img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
|
||||||
|
img = np.asarray(Image.open(img_path))
|
||||||
|
|
||||||
|
predict_pipeline = SegmentationPredictor(
|
||||||
|
model_path=segmentation_model_path,
|
||||||
|
config_file=segmentation_model_config,
|
||||||
|
batch_size=2)
|
||||||
|
|
||||||
|
total_samples = 3
|
||||||
|
outputs = predict_pipeline(
|
||||||
|
[img_path] * total_samples, keep_inputs=True)
|
||||||
|
self.assertEqual(len(outputs), 2)
|
||||||
|
|
||||||
|
self.assertEqual(outputs[0]['inputs'], [img_path] * 2)
|
||||||
|
self.assertEqual(outputs[1]['inputs'], [img_path] * 1)
|
||||||
|
self.assertEqual(len(outputs[0]['results']['seg_pred']), 2)
|
||||||
|
self.assertEqual(len(outputs[1]['results']['seg_pred']), 1)
|
||||||
|
|
||||||
|
for result in [outputs[0]['results'], outputs[1]['results']]:
|
||||||
|
self.assertListEqual(
|
||||||
|
list(img.shape)[:2], list(result['seg_pred'][0].shape))
|
||||||
|
self.assertListEqual(result['seg_pred'][0][1, :10].tolist(),
|
||||||
|
[161 for i in range(10)])
|
||||||
|
self.assertListEqual(result['seg_pred'][0][-1, -10:].tolist(),
|
||||||
|
[133 for i in range(10)])
|
||||||
|
|
||||||
|
def test_dump(self):
|
||||||
|
segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
|
||||||
|
segmentation_model_config = MODEL_CONFIG_SEGFORMER
|
||||||
|
|
||||||
|
img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
|
||||||
|
|
||||||
|
temp_dir = tempfile.TemporaryDirectory().name
|
||||||
|
if not os.path.exists(temp_dir):
|
||||||
|
os.makedirs(temp_dir)
|
||||||
|
tmp_path = os.path.join(temp_dir, 'results.pkl')
|
||||||
|
|
||||||
|
predict_pipeline = SegmentationPredictor(
|
||||||
|
model_path=segmentation_model_path,
|
||||||
|
config_file=segmentation_model_config,
|
||||||
|
batch_size=2,
|
||||||
|
save_results=True,
|
||||||
|
save_path=tmp_path)
|
||||||
|
|
||||||
|
total_samples = 3
|
||||||
|
outputs = predict_pipeline(
|
||||||
|
[img_path] * total_samples, keep_inputs=True)
|
||||||
|
self.assertEqual(outputs, [])
|
||||||
|
|
||||||
|
with open(tmp_path, 'rb') as f:
|
||||||
|
results = pickle.loads(f.read())
|
||||||
|
|
||||||
|
self.assertIn('inputs', results[0])
|
||||||
|
self.assertIn('results', results[0])
|
||||||
|
|
||||||
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
|
@ -0,0 +1,47 @@
|
||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import easycv
|
||||||
|
from easycv.utils.config_tools import Config
|
||||||
|
from easycv.utils.ms_utils import to_ms_config
|
||||||
|
|
||||||
|
|
||||||
|
class MsConfigTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||||
|
self.tmp_dir = tempfile.TemporaryDirectory().name
|
||||||
|
if not os.path.exists(self.tmp_dir):
|
||||||
|
os.makedirs(self.tmp_dir)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
super().tearDown()
|
||||||
|
shutil.rmtree(self.tmp_dir)
|
||||||
|
|
||||||
|
def test_to_ms_config(self):
|
||||||
|
|
||||||
|
config_path = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(easycv.__file__)),
|
||||||
|
'configs/detection/yolox/yolox_s_8xb16_300e_coco.py')
|
||||||
|
|
||||||
|
ms_cfg_file = os.path.join(self.tmp_dir,
|
||||||
|
'ms_yolox_s_8xb16_300e_coco.json')
|
||||||
|
to_ms_config(
|
||||||
|
config_path,
|
||||||
|
task='image-object-detection',
|
||||||
|
ms_model_name='yolox',
|
||||||
|
save_path=ms_cfg_file)
|
||||||
|
cfg = Config.fromfile(ms_cfg_file)
|
||||||
|
self.assertIn('task', cfg)
|
||||||
|
self.assertIn('framework', cfg)
|
||||||
|
self.assertEqual(cfg.model.type, 'yolox')
|
||||||
|
self.assertIn('dataset', cfg)
|
||||||
|
self.assertIn('batch_size_per_gpu', cfg.train.dataloader)
|
||||||
|
self.assertIn('batch_size_per_gpu', cfg.evaluation.dataloader)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue