Support multi processes for predictor (#272)

* support multi processes for predictor
Cathy0908 2023-02-01 12:14:44 +08:00 committed by GitHub
parent c94f647ec6
commit 74ecd3d037
15 changed files with 1133 additions and 643 deletions
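In short, this commit splits each predictor's preprocessing and postprocessing into pluggable InputProcessor / OutputProcessor classes, and lets input preprocessing fan out across multiple worker processes. A minimal usage sketch of the new knobs (checkpoint, config, and module paths are assumptions, not taken from this diff):

from easycv.predictors.detector import DetectionPredictor  # assumed module path

predictor = DetectionPredictor(
    'work_dirs/det_model.pth',                   # hypothetical checkpoint
    config_file='configs/detection/yolox_s.py',  # hypothetical config
    batch_size=8,
    input_processor_threads=4,  # new: preprocess samples in 4 worker processes
    mode='BGR')                 # new: image mode fed to the model
results = predictor(['demo/1.jpg', 'demo/2.jpg'])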

View File

@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import logging
import os
import pickle
@ -103,6 +104,178 @@ class Predictor(object):
return output
class InputProcessor(object):
"""Base input processor for processing input samples.
Args:
cfg (Config): Config instance.
pipelines (list[dict]): Data pipeline configs.
batch_size (int): batch size for forward.
threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
def __init__(self,
cfg,
pipelines=None,
batch_size=1,
threads=8,
mode='BGR'):
self.cfg = cfg
self.pipelines = pipelines
self.batch_size = batch_size
if self.batch_size < threads:
logging.warning(
f'``batch_size`` is less than ``threads``, set ``threads`` to {self.batch_size}'
)
self.threads = min(self.batch_size, threads)
self.mode = mode
self.processor = self.build_processor()
self._load_op = None
def build_processor(self):
"""Build processor to process loaded input.
If you need custom preprocessing ops, you need to reimplement it.
"""
if self.pipelines is not None:
pipelines = self.pipelines
else:
pipelines = self.cfg.get('test_pipeline', [])
pipelines = [build_from_cfg(p, PIPELINES) for p in pipelines]
from easycv.datasets.shared.pipelines.transforms import Compose
processor = Compose(pipelines)
return processor
def _load_input(self, input):
"""Load image from file or numpy or PIL object.
Args:
input: File path or numpy or PIL object.
Returns:
{
'filename': filename,
'img': img,
'img_shape': img_shape,
'img_fields': ['img']
}
"""
if self._load_op is None:
load_cfg = dict(type='LoadImage', mode=self.mode)
self._load_op = build_from_cfg(load_cfg, PIPELINES)
if not isinstance(input, str):
if isinstance(input, np.ndarray):
# Only support RGB mode if input is np.ndarray.
input = cv2.cvtColor(input, cv2.COLOR_RGB2BGR)
sample = self._load_op({'img': input})
else:
sample = self._load_op({'filename': input})
return sample
def _collate_fn(self, inputs):
"""Prepare the input just before the forward function.
Puts each data field into a tensor with outer dimension batch size
"""
return collate(inputs, samples_per_gpu=self.batch_size)
def process_single(self, input):
"""Process single input sample.
If you need custom ops to load or process a single input sample, you need to reimplement it.
"""
input = self._load_input(input)
return self.processor(input)
def _process_single_for_parallel(self, i, *args, **kwargs):
# Fix hang issue with multiple processes, refer to: https://github.com/pytorch/vision/issues/7068.
# Torch dataloader also sets num_threads to 1 when num_workers > 0, refer to: torch.utils.data._utils.worker._worker_loop.
# set_num_threads is only valid in subprocesses; no need to reset it for the main process.
torch.set_num_threads(1)
return i, self.process_single(*args, **kwargs)
def __call__(self, inputs):
"""Process all inputs list. And collate to batch and put to target device.
If you need custom ops to load or process a batch samples, you need to reimplement it.
"""
batch_outputs = []
threads = min(self.threads, len(inputs))
if threads <= 1:
for inp in inputs:
batch_outputs.append(self.process_single(inp))
else:
import concurrent.futures
batch_outputs_with_idx = []
futures = []
with concurrent.futures.ProcessPoolExecutor(threads) as executor:
for i, inp in enumerate(inputs):
future = executor.submit(self._process_single_for_parallel,
i, inp)
futures.append(future)
for future in concurrent.futures.as_completed(futures):
batch_outputs_with_idx.append(future.result())
batch_outputs_with_idx = sorted(
batch_outputs_with_idx, key=lambda item: item[0])
batch_outputs = [out[1] for out in batch_outputs_with_idx]
return self._collate_fn(batch_outputs)
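The dispatch above tags each sample with its index, runs process_single in a ProcessPoolExecutor, and sorts the completed futures back into input order. A self-contained sketch of the same pattern (the doubling function is a stand-in for process_single):

import concurrent.futures

import torch


def _work(i, x):
    # Mirror _process_single_for_parallel: cap intra-op threads per worker to
    # avoid oversubscription (see pytorch/vision#7068 referenced above).
    torch.set_num_threads(1)
    return i, x * 2  # stand-in for self.process_single(x)


if __name__ == '__main__':  # required on spawn-based platforms
    inputs = list(range(10))
    with concurrent.futures.ProcessPoolExecutor(4) as executor:
        futures = [executor.submit(_work, i, x) for i, x in enumerate(inputs)]
        tagged = [f.result() for f in concurrent.futures.as_completed(futures)]
    # as_completed yields in completion order; the index restores input order.
    outputs = [v for _, v in sorted(tagged)]
    assert outputs == [x * 2 for x in inputs]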
class OutputProcessor(object):
"""Base output processor for processing model outputs.
"""
def __init__(self):
pass
def _get_batch_size(self, inputs):
for k, batch_v in inputs.items():
if isinstance(batch_v, dict):
batch_size = self._get_batch_size(batch_v)
elif batch_v is not None:
batch_size = len(batch_v)
break
else:
batch_size = 1
return batch_size
def _extract_ith_result(self, inputs, i, out_i):
for k, batch_v in inputs.items():
if isinstance(batch_v, dict):
out_i[k] = {}
self._extract_ith_result(batch_v, i, out_i[k])
elif batch_v is not None:
out_i[k] = batch_v[i]
else:
out_i[k] = None
return out_i
def process_single(self, inputs):
"""Process outputs of single sample.
If you need to add some processing ops, you need to reimplement it.
"""
return inputs
def __call__(self, inputs):
"""Process model batch outputs.
The "inputs" should be dict format as follows:
{
"key1": torch.Tensor or list, the first dimension should be batch_size,
"key2": torch.Tensor or list, the first dimension should be batch_size,
...
}
"""
outputs = []
batch_size = self._get_batch_size(inputs)
for i in range(batch_size):
out_i = self._extract_ith_result(inputs, i, {})
out_i = self.process_single(out_i)
outputs.append(out_i)
return outputs
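A toy check of what the base OutputProcessor does with a batched output dict (keys and shapes invented for illustration): batch size is inferred from the first non-dict value, nested dicts are split recursively, and None fields are carried through per sample:

import torch

from easycv.predictors.base import OutputProcessor  # name added in this commit

processor = OutputProcessor()
batched = {
    'prob': torch.rand(4, 10),                   # batch of 4 samples
    'feats': {'embedding': torch.rand(4, 256)},  # nested dicts are split too
    'aux': None,                                 # None is carried through
}
per_sample = processor(batched)
assert len(per_sample) == 4
assert per_sample[0]['prob'].shape == (10,)
assert per_sample[0]['feats']['embedding'].shape == (256,)
assert per_sample[0]['aux'] is None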
class PredictorV2(object):
"""Base predict pipeline.
Args:
@ -113,8 +286,9 @@ class PredictorV2(object):
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
input_processor_threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
INPUT_IMAGE_MODE = 'BGR'  # the image mode fed into the model
def __init__(self,
model_path,
@ -124,14 +298,17 @@ class PredictorV2(object):
save_results=False,
save_path=None,
pipelines=None,
*args,
**kwargs):
input_processor_threads=8,
mode='BGR'):
self.logger = get_root_logger()
self.model_path = model_path
self.batch_size = batch_size
self.save_results = save_results
self.save_path = save_path
self.config_file = config_file
self.pipelines = pipelines
self.input_processor_threads = input_processor_threads
self.mode = mode
if self.save_results:
assert self.save_path is not None
self.device = device
@ -149,8 +326,6 @@ class PredictorV2(object):
if self.cfg is None:
raise ValueError('Please provide "config_file"!')
self.pipelines = pipelines
if self.cfg.get('predict', None) is not None:
self._sync_cfg_predict(self.cfg.predict)
@ -159,8 +334,19 @@ class PredictorV2(object):
self.cfg.model.pretrained = None
self.model = self.prepare_model()
self.processor = self.build_processor()
self._load_op = None
self.input_processor = None
self.output_processor = None
def get_input_processor(self):
return InputProcessor(
self.cfg,
pipelines=self.pipelines,
batch_size=self.batch_size,
threads=self.input_processor_threads,
mode=self.mode)
def get_output_processor(self):
return OutputProcessor()
def _sync_cfg_predict(self, predict_cfg):
if predict_cfg.get('type', None) is not None:
@ -207,68 +393,7 @@ class PredictorV2(object):
remove_adapt_for_mmlab(self.cfg)
return model
def build_processor(self):
"""Build processor to process loaded input.
If you need custom preprocessing ops, you need to reimplement it.
"""
if self.pipelines is not None:
pipelines = self.pipelines
else:
pipelines = self.cfg.get('test_pipeline', [])
pipelines = [build_from_cfg(p, PIPELINES) for p in pipelines]
from easycv.datasets.shared.pipelines.transforms import Compose
processor = Compose(pipelines)
return processor
def _load_input(self, input):
"""Load image from file or numpy or PIL object.
Args:
input: File path or numpy or PIL object.
Returns:
{
'filename': filename,
'img': img,
'img_shape': img_shape,
'img_fields': ['img']
}
"""
if self._load_op is None:
load_cfg = dict(type='LoadImage', mode=self.INPUT_IMAGE_MODE)
self._load_op = build_from_cfg(load_cfg, PIPELINES)
if not isinstance(input, str):
if isinstance(input, np.ndarray):
# Only support RGB mode if input is np.ndarray.
input = cv2.cvtColor(input, cv2.COLOR_RGB2BGR)
sample = self._load_op({'img': input})
else:
sample = self._load_op({'filename': input})
return sample
def preprocess_single(self, input):
"""Preprocess single input sample.
If you need custom ops to load or process a single input sample, you need to reimplement it.
"""
input = self._load_input(input)
return self.processor(input)
def preprocess(self, inputs, *args, **kwargs):
"""Process all inputs list. And collate to batch and put to target device.
If you need custom ops to load or process a batch samples, you need to reimplement it.
"""
batch_outputs = []
for i in inputs:
batch_outputs.append(self.preprocess_single(i, *args, **kwargs))
batch_outputs = self._collate_fn(batch_outputs)
batch_outputs = self._to_device(batch_outputs)
return batch_outputs
def forward(self, inputs):
def model_forward(self, inputs):
"""Model forward.
If you need to refactor the model forward, you need to reimplement it.
"""
@ -276,92 +401,43 @@ class PredictorV2(object):
outputs = self.model(**inputs, mode='test')
return outputs
def _get_batch_size(self, inputs):
for k, batch_v in inputs.items():
if isinstance(batch_v, dict):
batch_size = self._get_batch_size(batch_v)
elif batch_v is not None:
batch_size = len(batch_v)
break
else:
batch_size = 1
return batch_size
def _extract_ith_result(self, inputs, i, out_i):
for k, batch_v in inputs.items():
if isinstance(batch_v, dict):
out_i[k] = {}
self._extract_ith_result(batch_v, i, out_i[k])
elif batch_v is not None:
out_i[k] = batch_v[i]
else:
out_i[k] = None
return out_i
def postprocess(self, inputs, *args, **kwargs):
"""Process model batch outputs.
The "inputs" should be dict format as follows:
{
"key1": torch.Tensor or list, the first dimension should be batch_size,
"key2": torch.Tensor or list, the first dimension should be batch_size,
...
}
"""
outputs = []
batch_size = self._get_batch_size(inputs)
for i in range(batch_size):
out_i = self._extract_ith_result(inputs, i, {})
out_i = self.postprocess_single(out_i, *args, **kwargs)
outputs.append(out_i)
return outputs
def postprocess_single(self, inputs, *args, **kwargs):
"""Process outputs of single sample.
If you need to add some processing ops, you need to reimplement it.
"""
return inputs
def _collate_fn(self, inputs):
"""Prepare the input just before the forward function.
Puts each data field into a tensor with outer dimension batch size
"""
return collate(inputs, samples_per_gpu=self.batch_size)
def _to_device(self, inputs):
target_gpus = [-1] if str(
self.device) == 'cpu' else [torch.cuda.current_device()]
_, kwargs = scatter_kwargs(None, inputs, target_gpus=target_gpus)
return kwargs[0]
@staticmethod
def dump(obj, save_path, mode='wb'):
def dump(self, obj, save_path, mode='wb'):
with open(save_path, mode) as f:
f.write(pickle.dumps(obj))
def __call__(self, inputs, keep_inputs=False):
# TODO: fault tolerance
if self.input_processor is None:
self.input_processor = self.get_input_processor()
if self.output_processor is None:
self.output_processor = self.get_output_processor()
# TODO: fault tolerance
if isinstance(inputs, (str, np.ndarray, ImageFile.ImageFile)):
inputs = [inputs]
results_list = []
for i in range(0, len(inputs), self.batch_size):
batch = inputs[i:min(len(inputs), i + self.batch_size)]
batch_outputs = self.preprocess(batch)
batch_outputs = self.forward(batch_outputs)
results = self.postprocess(batch_outputs)
# assert len(results) == len(
# batch), f'Mismatch size {len(results)} != {len(batch)}'
batch_inputs = inputs[i:min(len(inputs), i + self.batch_size)]
batch_outputs = self.input_processor(batch_inputs)
batch_outputs = self._to_device(batch_outputs)
batch_outputs = self.model_forward(batch_outputs)
results = self.output_processor(batch_outputs)
if keep_inputs:
for i in range(len(batch)):
results[i].update({'inputs': batch[i]})
# if dumped, the outputs will not be added to the return value to prevent taking up too much memory
if self.save_results:
self.dump(results, self.save_path, mode='ab+')
for i in range(len(batch_inputs)):
results[i].update({'inputs': batch_inputs[i]})
if isinstance(results, list):
results_list.extend(results)
else:
if isinstance(results, list):
results_list.extend(results)
else:
results_list.append(results)
results_list.append(results)
# TODO: support append to file
if self.save_results:
self.dump(results_list, self.save_path)
return results_list
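Net effect: __call__ is now a fixed three-stage pipeline. A hedged restatement of one loop iteration as a standalone helper (these are PredictorV2 internals, not a public API):

def predict_batch(predictor, batch_inputs):
    # predictor is an already-built PredictorV2 (sub)instance.
    batch = predictor.input_processor(batch_inputs)  # parallel preprocess + collate
    batch = predictor._to_device(batch)              # scatter to CPU/GPU
    outputs = predictor.model_forward(batch)         # model inference
    return predictor.output_processor(outputs)       # per-sample result dicts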

View File

@ -12,84 +12,38 @@ from easycv.datasets.registry import PIPELINES
from easycv.datasets.shared.pipelines.format import to_tensor
from easycv.datasets.shared.pipelines.transforms import Compose
from easycv.framework.errors import ValueError
from easycv.predictors.base import PredictorV2
from easycv.predictors.base import InputProcessor, PredictorV2
from easycv.predictors.builder import PREDICTORS
from easycv.utils.misc import encode_str_to_tensor
from easycv.utils.registry import build_from_cfg
@PREDICTORS.register_module()
class BEVFormerPredictor(PredictorV2):
"""Predictor for BEVFormer.
class BEVFormerInputProcessor(InputProcessor):
"""Process inputs for BEVFormer model.
Args:
model_path (str): Path of model path.
config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str | torch.device): Support str('cuda' or 'cpu') or torch.device, if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
cfg (Config): Config instance.
pipelines (list[dict]): Data pipeline configs.
box_type_3d (str): Box type.
batch_size (int): batch size for forward.
use_camera (bool): Whether to use camera data.
score_threshold (float): Score threshold to filter inference results.
box_type_3d (str): Box type.
threads (int): Number of processes to process inputs.
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
cfg,
pipelines=None,
box_type_3d='LiDAR',
batch_size=1,
use_camera=True,
score_threshold=0.1,
model_type=None,
*arg,
**kwargs):
if batch_size > 1:
raise ValueError(
f'Only support batch_size=1 now, but got batch_size={batch_size}'
)
self.model_type = model_type
if self.model_type is None:
if model_path.endswith('jit'):
self.model_type = 'jit'
elif model_path.endswith('blade'):
self.model_type = 'blade'
self.is_jit_model = self.model_type in ['jit', 'blade']
super(BEVFormerPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
*arg,
**kwargs)
self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)
self.CLASSES = self.cfg.get('CLASSES', None)
box_type_3d='LiDAR',
adapt_jit=False,
threads=8):
self.use_camera = use_camera
self.score_threshold = score_threshold
self.result_key = 'pts_bbox'
self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)
self.adapt_jit = adapt_jit
# The initial prev_bev should be the weight of self.model.pts_bbox_head.bev_embedding, but the weight cannot be taken out from the blade model.
# So we use dummy data as the initial value; it will not actually be used, it just adapts to the jit and blade models.
# init_prev_bev = self.model.pts_bbox_head.bev_embedding.weight.clone().detach()
# init_prev_bev = init_prev_bev[:, None, :], # [40000, 256] -> [40000, 1, 256]
dummy_prev_bev = torch.rand(
[self.cfg.bev_h * self.cfg.bev_w, 1,
self.cfg.embed_dim]).to(self.device)
self.prev_frame_info = {
'prev_bev': dummy_prev_bev.to(self.device),
'prev_scene_token': encode_str_to_tensor('dummy_prev_scene_token'),
'prev_pos': torch.tensor(0),
'prev_angle': torch.tensor(0),
}
super(BEVFormerInputProcessor, self).__init__(
cfg, pipelines=pipelines, batch_size=batch_size, threads=threads)
def _prepare_input_dict(self, data_info):
from nuscenes.eval.common.utils import Quaternion, quaternion_yaw
@ -161,8 +115,8 @@ class BEVFormerPredictor(PredictorV2):
result = load_pipelines(input_dict)
return result
def preprocess_single(self, input):
"""Preprocess single input sample.
def process_single(self, input):
"""Process single input sample.
Args:
input (str): Pickle file path, the content format is the same as the infos file of nuScenes.
"""
@ -170,7 +124,7 @@ class BEVFormerPredictor(PredictorV2):
result = self._prepare_input_dict(data_info)
result = self.processor(result)
if self.is_jit_model:
if self.adapt_jit:
result['can_bus'] = DC(
to_tensor(result['img_metas'][0]._data['can_bus']),
cpu_only=False)
@ -199,9 +153,96 @@ class BEVFormerPredictor(PredictorV2):
return result
def postprocess_single(self, inputs, *args, **kwargs):
# TODO: filter results by score_threshold
return super().postprocess_single(inputs, *args, **kwargs)
@PREDICTORS.register_module()
class BEVFormerPredictor(PredictorV2):
"""Predictor for BEVFormer.
Args:
model_path (str): Path of model path.
config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str | torch.device): Support str('cuda' or 'cpu') or torch.device, if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
box_type_3d (str): Box type.
use_camera (bool): Whether to use camera data.
score_threshold (float): Score threshold to filter inference results.
input_processor_threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=None,
box_type_3d='LiDAR',
use_camera=True,
score_threshold=0.1,
model_type=None,
input_processor_threads=8,
mode='BGR',
*arg,
**kwargs):
if batch_size > 1:
raise ValueError(
f'Only support batch_size=1 now, but got batch_size={batch_size}'
)
self.model_type = model_type
if self.model_type is None:
if model_path.endswith('jit'):
self.model_type = 'jit'
elif model_path.endswith('blade'):
self.model_type = 'blade'
self.is_jit_model = self.model_type in ['jit', 'blade']
self.use_camera = use_camera
self.score_threshold = score_threshold
self.result_key = 'pts_bbox'
self.box_type_3d_str = box_type_3d
self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)
super(BEVFormerPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=input_processor_threads,
mode=mode,
*arg,
**kwargs)
self.CLASSES = self.cfg.get('CLASSES', None)
# The initial prev_bev should be the weight of self.model.pts_bbox_head.bev_embedding, but the weight cannot be taken out from the blade model.
# So we use dummy data as the initial value; it will not actually be used, it just adapts to the jit and blade models.
# init_prev_bev = self.model.pts_bbox_head.bev_embedding.weight.clone().detach()
# init_prev_bev = init_prev_bev[:, None, :], # [40000, 256] -> [40000, 1, 256]
dummy_prev_bev = torch.rand(
[self.cfg.bev_h * self.cfg.bev_w, 1,
self.cfg.embed_dim]).to(self.device)
self.prev_frame_info = {
'prev_bev': dummy_prev_bev.to(self.device),
'prev_scene_token': encode_str_to_tensor('dummy_prev_scene_token'),
'prev_pos': torch.tensor(0),
'prev_angle': torch.tensor(0),
}
def get_input_processor(self):
return BEVFormerInputProcessor(
self.cfg,
pipelines=self.pipelines,
batch_size=self.batch_size,
use_camera=self.use_camera,
box_type_3d=self.box_type_3d_str,
adapt_jit=self.is_jit_model,
threads=self.input_processor_threads)
def prepare_model(self):
if self.is_jit_model:
@ -209,7 +250,7 @@ class BEVFormerPredictor(PredictorV2):
return model
return super().prepare_model()
def forward(self, inputs):
def model_forward(self, inputs):
if self.is_jit_model:
with torch.no_grad():
img = inputs['img'][0][0]
@ -244,7 +285,7 @@ class BEVFormerPredictor(PredictorV2):
}],
}
return outputs
return super().forward(inputs)
return super().model_forward(inputs)
def visualize(self, inputs, results, out_dir, show=False, pipeline=None):
raise NotImplementedError
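The refactor pattern in this file generalizes to any model-specific predictor: keep custom preprocessing in an InputProcessor subclass and expose it via get_input_processor(). A minimal hedged sketch (class names and the extra field are hypothetical):

from easycv.predictors.base import InputProcessor, PredictorV2


class MyInputProcessor(InputProcessor):

    def process_single(self, input):
        sample = super().process_single(input)
        sample['custom_flag'] = True  # hypothetical extra field
        return sample


class MyPredictor(PredictorV2):

    def get_input_processor(self):
        return MyInputProcessor(
            self.cfg,
            pipelines=self.pipelines,
            batch_size=self.batch_size,
            threads=self.input_processor_threads,
            mode=self.mode)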

View File

@ -3,73 +3,40 @@ import math
import numpy as np
import torch
from PIL import Image, ImageFile
from PIL import Image
from easycv.file import io
from easycv.framework.errors import ValueError
from easycv.utils.misc import deprecated
from .base import Predictor, PredictorV2
from .base import InputProcessor, OutputProcessor, Predictor, PredictorV2
from .builder import PREDICTORS
@PREDICTORS.register_module()
class ClassificationPredictor(PredictorV2):
"""Predictor for classification.
class ClsInputProcessor(InputProcessor):
"""Process inputs for classification models.
Args:
model_path (str): Path of model path.
config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
cfg (Config): Config instance.
pipelines (list[dict]): Data pipeline configs.
topk (int): Return top-k results. Default: 1.
batch_size (int): batch size for forward.
pil_input (bool): Whether to use PIL image. If the processor needs PIL input, set it to True; default False.
label_map_path (str): File path of saving labels list.
threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
cfg,
pipelines=None,
topk=1,
batch_size=1,
pil_input=True,
label_map_path=None,
*args,
**kwargs):
super(ClassificationPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
*args,
**kwargs)
self.topk = topk
threads=8,
mode='BGR'):
super(ClsInputProcessor, self).__init__(
cfg, pipelines=pipelines, batch_size=batch_size, threads=threads)
self.mode = mode
self.pil_input = pil_input
# Adapt to torchvision transforms which process PIL inputs.
if self.pil_input:
self.INPUT_IMAGE_MODE = 'RGB'
if label_map_path is None:
if 'CLASSES' in self.cfg:
class_list = self.cfg.get('CLASSES', [])
elif 'class_list' in self.cfg:
class_list = self.cfg.get('class_list', [])
else:
class_list = []
else:
with io.open(label_map_path, 'r') as f:
class_list = f.readlines()
self.label_map = [i.strip() for i in class_list]
def _load_input(self, input):
"""Load image from file or numpy or PIL object.
Args:
@ -86,8 +53,8 @@ class ClassificationPredictor(PredictorV2):
results = {}
if isinstance(input, str):
img = Image.open(input)
if img.mode.upper() != self.INPUT_IMAGE_MODE.upper():
img = img.convert(self.INPUT_IMAGE_MODE.upper())
if img.mode.upper() != self.mode.upper():
img = img.convert(self.mode.upper())
results['filename'] = input
else:
if isinstance(input, np.ndarray):
@ -103,7 +70,22 @@ class ClassificationPredictor(PredictorV2):
return super()._load_input(input)
def postprocess(self, inputs, *args, **kwargs):
class ClsOutputProcessor(OutputProcessor):
"""Output processor for processing classification model outputs.
Args:
topk (int): Return top-k results. Default: 1.
label_map (dict | list): Mapping from class id to class name.
"""
def __init__(self, topk=1, label_map={}):
self.topk = topk
self.label_map = label_map
super(ClsOutputProcessor, self).__init__()
def __call__(self, inputs):
"""Return top-k results."""
output_prob = inputs['prob'].data.cpu()
topk_class = torch.topk(output_prob, self.topk).indices.numpy()
@ -127,6 +109,84 @@ class ClassificationPredictor(PredictorV2):
return batch_results
@PREDICTORS.register_module()
class ClassificationPredictor(PredictorV2):
"""Predictor for classification.
Args:
model_path (str): Path of model path.
config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
topk (int): Return top-k results. Default: 1.
pil_input (bool): Whether to use PIL image. If the processor needs PIL input, set it to True; default False.
label_map_path (str): File path of saving labels list.
input_processor_threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=None,
topk=1,
pil_input=True,
label_map_path=None,
input_processor_threads=8,
mode='BGR',
*args,
**kwargs):
self.topk = topk
self.pil_input = pil_input
self.label_map_path = label_map_path
if self.pil_input:
mode = 'RGB'
super(ClassificationPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=input_processor_threads,
mode=mode,
*args,
**kwargs)
def get_input_processor(self):
return ClsInputProcessor(
self.cfg,
pipelines=self.pipelines,
batch_size=self.batch_size,
threads=self.input_processor_threads,
pil_input=self.pil_input,
mode=self.mode)
def get_output_processor(self):
# Adapt to torchvision transforms which process PIL inputs.
if self.label_map_path is None:
if 'CLASSES' in self.cfg:
class_list = self.cfg.get('CLASSES', [])
elif 'class_list' in self.cfg:
class_list = self.cfg.get('class_list', [])
else:
class_list = []
else:
with io.open(self.label_map_path, 'r') as f:
class_list = f.readlines()
self.label_map = [i.strip() for i in class_list]
return ClsOutputProcessor(topk=self.topk, label_map=self.label_map)
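A hedged usage sketch for the refactored classifier (all paths hypothetical; keyword names come from this diff):

predictor = ClassificationPredictor(
    'work_dirs/cls_model.pth',
    config_file='configs/classification/resnet50.py',
    topk=5,
    pil_input=True,            # forces mode='RGB' for torchvision-style pipelines
    label_map_path=None,       # falls back to cfg.CLASSES / cfg.class_list
    input_processor_threads=4)
results = predictor(['demo/cat.jpg'])  # one top-k result dict per image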
try:
from easy_vision.python.inference.predictor import PredictorInterface
except:

View File

@ -5,21 +5,15 @@ from glob import glob
import numpy as np
import torch
from torchvision.transforms import Compose
from easycv.core.visualization import imshow_bboxes
from easycv.datasets.registry import PIPELINES
from easycv.datasets.utils import replace_ImageToTensor
from easycv.file import io
from easycv.models import build_model
from easycv.models.detection.utils import postprocess
from easycv.thirdparty.mtcnn import FaceDetector
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.constant import CACHE_DIR
from easycv.utils.misc import deprecated
from easycv.utils.registry import build_from_cfg
from .base import PredictorV2
from .base import InputProcessor, OutputProcessor, PredictorV2
from .builder import PREDICTORS
from .classifier import TorchClassifier
@ -29,33 +23,7 @@ except Exception:
from .interface import PredictorInterface
@PREDICTORS.register_module()
class DetectionPredictor(PredictorV2):
"""Generic Detection Predictor, it will filter bbox results by ``score_threshold`` .
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=None,
score_threshold=0.5,
*arg,
**kwargs):
super(DetectionPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
)
self.score_thresh = score_threshold
self.CLASSES = self.cfg.get('CLASSES', None)
class DetInputProcessor(InputProcessor):
def build_processor(self):
if self.pipelines is not None:
@ -70,7 +38,15 @@ class DetectionPredictor(PredictorV2):
return super().build_processor()
def postprocess_single(self, inputs, *args, **kwargs):
class DetOutputProcessor(OutputProcessor):
def __init__(self, score_thresh, classes=None):
super(DetOutputProcessor, self).__init__()
self.score_thresh = score_thresh
self.classes = classes
def process_single(self, inputs):
if inputs['detection_scores'] is None or len(
inputs['detection_scores']) < 1:
return inputs
@ -87,8 +63,8 @@ class DetectionPredictor(PredictorV2):
for _, classes_id in enumerate(inputs['detection_classes']):
if classes_id is None:
class_names.append(None)
elif self.CLASSES is not None and len(self.CLASSES) > 0:
class_names.append(self.CLASSES[int(classes_id)])
elif self.classes is not None and len(self.classes) > 0:
class_names.append(self.classes[int(classes_id)])
else:
class_names.append(classes_id)
@ -96,11 +72,65 @@ class DetectionPredictor(PredictorV2):
return inputs
@PREDICTORS.register_module()
class DetectionPredictor(PredictorV2):
"""Generic Detection Predictor, it will filter bbox results by ``score_threshold`` .
Args:
model_path (str): Path of model path.
config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str | torch.device): Support str('cuda' or 'cpu') or torch.device, if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
input_processor_threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=None,
score_threshold=0.5,
input_processor_threads=8,
mode='BGR',
*arg,
**kwargs):
super(DetectionPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=input_processor_threads,
mode=mode)
self.score_thresh = score_threshold
self.CLASSES = self.cfg.get('CLASSES', None)
def get_input_processor(self):
return DetInputProcessor(
self.cfg,
pipelines=self.pipelines,
batch_size=self.batch_size,
threads=self.input_processor_threads,
mode=self.mode)
def get_output_processor(self):
return DetOutputProcessor(self.score_thresh, self.CLASSES)
def visualize(self, img, results, show=False, out_file=None):
"""Only support show one sample now."""
bboxes = results['detection_boxes']
labels = results['detection_class_names']
img = self._load_input(img)['img']
img = self.input_processor._load_input(img)['img']
imshow_bboxes(
img,
bboxes,
@ -135,77 +165,43 @@ class _JitProcessorWrapper:
return results
@PREDICTORS.register_module()
class YoloXPredictor(DetectionPredictor):
"""Detection predictor for Yolox."""
class YoloXInputProcessor(DetInputProcessor):
"""Input processor for yolox.
def __init__(self,
model_path,
config_file=None,
batch_size=1,
use_trt_efficientnms=False,
device=None,
save_results=False,
save_path=None,
pipelines=None,
max_det=100,
score_thresh=0.5,
nms_thresh=None,
test_conf=None,
*arg,
**kwargs):
self.max_det = max_det
self.use_trt_efficientnms = use_trt_efficientnms
Args:
cfg (Config): Config instance.
pipelines (list[dict]): Data pipeline configs.
batch_size (int): batch size for forward.
model_type (str): "raw" or "jit" or "blade"
jit_processor_path (str): File of the saved processing operator of torch jit type.
device (str | torch.device): Support str('cuda' or 'cpu') or torch.device, if is None, detect device automatically.
threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
if model_path.endswith('jit'):
self.model_type = 'jit'
elif model_path.endswith('blade'):
self.model_type = 'blade'
else:
self.model_type = 'raw'
def __init__(
self,
cfg,
pipelines=None,
batch_size=1,
model_type='raw',
jit_processor_path=None,
device=None,
threads=8,
mode='BGR',
):
self.model_type = model_type
self.jit_processor_path = jit_processor_path
self.device = device
if self.device is None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if self.model_type == 'blade' or self.use_trt_efficientnms:
import torch_blade
if self.model_type != 'raw' and config_file is None:
config_file = model_path + '.config.json'
super(YoloXPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
super().__init__(
cfg,
pipelines=pipelines,
score_threshold=score_thresh)
self.test_conf = test_conf or self.cfg['model'].get('test_conf', 0.01)
self.nms_thre = nms_thresh or self.cfg['model'].get('nms_thre', 0.65)
self.CLASSES = self.cfg.get('CLASSES', None) or self.cfg.get(
'classes', None)
assert self.CLASSES is not None
def _build_model(self):
if self.model_type != 'raw':
with io.open(self.model_path, 'rb') as infile:
model = torch.jit.load(infile, self.device)
else:
from easycv.utils.misc import reparameterize_models
model = super()._build_model()
model = reparameterize_models(model)
return model
def prepare_model(self):
"""Build model from config file by default.
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
"""
model = self._build_model()
model.to(self.device)
model.eval()
if self.model_type == 'raw':
load_checkpoint(model, self.model_path, map_location='cpu')
return model
batch_size=batch_size,
threads=threads,
mode=mode)
def build_processor(self):
self.jit_preprocess = False
@ -217,31 +213,32 @@ class YoloXPredictor(DetectionPredictor):
if self.model_type != 'raw' and self.jit_preprocess:
# jit or blade model
processor = None
preprocess_path = '.'.join(
self.model_path.split('.')[:-1] + ['preprocess'])
if os.path.exists(preprocess_path):
if os.path.exists(self.jit_processor_path):
if self.threads > 1:
raise ValueError(
'threads > 1 is not supported for the jit processor!')
# use a preprocess jit model to speed up
with io.open(preprocess_path, 'rb') as infile:
with io.open(self.jit_processor_path, 'rb') as infile:
processor = torch.jit.load(infile, self.device)
return _JitProcessorWrapper(processor, self.device)
else:
return super().build_processor()
def forward(self, inputs):
"""Model forward.
If you need to refactor the model forward, you need to reimplement it.
"""
if self.model_type != 'raw':
with torch.no_grad():
outputs = self.model(inputs['img'])
outputs = {'results': outputs} # convert to dict format
else:
outputs = super().forward(inputs)
if 'img_metas' not in outputs:
outputs['img_metas'] = inputs['img_metas']
class YoloXOutputProcessor(DetOutputProcessor):
return outputs
def __init__(self,
score_thresh=0.5,
model_type='raw',
test_conf=0.01,
nms_thre=0.65,
use_trt_efficientnms=False,
classes=None):
super().__init__(score_thresh, classes)
self.model_type = model_type
self.test_conf = test_conf
self.nms_thre = nms_thre
self.use_trt_efficientnms = use_trt_efficientnms
def post_assign(self, outputs, img_metas):
detection_boxes = []
@ -277,7 +274,7 @@ class YoloXPredictor(DetectionPredictor):
}
return test_outputs
def postprocess_single(self, inputs):
def process_single(self, inputs):
det_out = inputs
img_meta = det_out['img_metas']
@ -296,18 +293,148 @@ class YoloXPredictor(DetectionPredictor):
else:
det_out = self.post_assign(
postprocess(
results.unsqueeze(0), len(self.CLASSES),
results.unsqueeze(0), len(self.classes),
self.test_conf, self.nms_thre),
img_metas=[img_meta])
det_out['detection_scores'] = det_out['detection_scores'][0]
det_out['detection_boxes'] = det_out['detection_boxes'][0]
det_out['detection_classes'] = det_out['detection_classes'][0]
results = super().postprocess_single(det_out)
results = super().process_single(det_out)
results['ori_img_shape'] = list(img_meta['ori_img_shape'][:2])
return results
@PREDICTORS.register_module()
class YoloXPredictor(DetectionPredictor):
"""Detection predictor for Yolox.
Args:
model_path (str): Path of model path.
config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
use_trt_efficientnms (bool): Whether the saved model uses the TensorRT efficient NMS operation.
device (str | torch.device): Support str('cuda' or 'cpu') or torch.device, if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
max_det (int): Maximum number of detection output boxes.
score_thresh (float): Score threshold to filter boxes.
nms_thresh (float): NMS threshold to filter boxes.
input_processor_threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
use_trt_efficientnms=False,
device=None,
save_results=False,
save_path=None,
pipelines=None,
max_det=100,
score_thresh=0.5,
nms_thresh=None,
test_conf=None,
input_processor_threads=8,
mode='BGR'):
self.max_det = max_det
self.use_trt_efficientnms = use_trt_efficientnms
if model_path.endswith('jit'):
self.model_type = 'jit'
elif model_path.endswith('blade'):
self.model_type = 'blade'
else:
self.model_type = 'raw'
if self.model_type == 'blade' or self.use_trt_efficientnms:
import torch_blade
if self.model_type != 'raw' and config_file is None:
config_file = model_path + '.config.json'
super(YoloXPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
score_threshold=score_thresh,
input_processor_threads=input_processor_threads,
mode=mode)
self.test_conf = test_conf or self.cfg['model'].get('test_conf', 0.01)
self.nms_thre = nms_thresh or self.cfg['model'].get('nms_thre', 0.65)
self.CLASSES = self.cfg.get('CLASSES', None) or self.cfg.get(
'classes', None)
assert self.CLASSES is not None
self.jit_processor_path = '.'.join(
self.model_path.split('.')[:-1] + ['preprocess'])
def _build_model(self):
if self.model_type != 'raw':
with io.open(self.model_path, 'rb') as infile:
model = torch.jit.load(infile, self.device)
else:
from easycv.utils.misc import reparameterize_models
model = super()._build_model()
model = reparameterize_models(model)
return model
def prepare_model(self):
"""Build model from config file by default.
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
"""
model = self._build_model()
model.to(self.device)
model.eval()
if self.model_type == 'raw':
load_checkpoint(model, self.model_path, map_location='cpu')
return model
def model_forward(self, inputs):
"""Model forward.
If you need to refactor the model forward, you need to reimplement it.
"""
if self.model_type != 'raw':
with torch.no_grad():
outputs = self.model(inputs['img'])
outputs = {'results': outputs} # convert to dict format
else:
outputs = super().model_forward(inputs)
if 'img_metas' not in outputs:
outputs['img_metas'] = inputs['img_metas']
return outputs
def get_input_processor(self):
return YoloXInputProcessor(
self.cfg,
pipelines=self.pipelines,
batch_size=self.batch_size,
model_type=self.model_type,
jit_processor_path=self.jit_processor_path,
device=self.device,
threads=self.input_processor_threads,
mode=self.mode,
)
def get_output_processor(self):
return YoloXOutputProcessor(
score_thresh=self.score_thresh,
model_type=self.model_type,
test_conf=self.test_conf,
nms_thre=self.nms_thre,
use_trt_efficientnms=self.use_trt_efficientnms,
classes=self.CLASSES)
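A hedged usage sketch (model path hypothetical): a '.jit' or '.blade' suffix selects model_type automatically and config_file defaults to '<model_path>.config.json'; if a saved jit preprocessor is found next to the model, preprocessing must stay single-process:

predictor = YoloXPredictor(
    'work_dirs/yolox_s.jit',    # hypothetical exported jit model
    score_thresh=0.5,
    input_processor_threads=1,  # the jit processor raises on threads > 1
    mode='BGR')
results = predictor(['demo/street.jpg'])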
@deprecated(reason='Please use YoloXPredictor.')
@PREDICTORS.register_module()
class TorchYoloXPredictor(YoloXPredictor):
@ -317,7 +444,9 @@ class TorchYoloXPredictor(YoloXPredictor):
max_det=100,
score_thresh=0.5,
use_trt_efficientnms=False,
model_config=None):
model_config=None,
input_processor_threads=8,
mode='BGR'):
"""
Args:
model_path: model file path
@ -345,7 +474,9 @@ class TorchYoloXPredictor(YoloXPredictor):
max_det=max_det,
score_thresh=score_thresh,
nms_thresh=None,
test_conf=None)
test_conf=None,
input_processor_threads=input_processor_threads,
mode=mode)
def predict(self, input_data_list, batch_size=-1, to_numpy=True):
return super().__call__(input_data_list)

View File

@ -3,7 +3,7 @@
import cv2
from easycv.predictors.builder import PREDICTORS
from .base import PredictorV2
from .base import OutputProcessor, PredictorV2
face_contour_point_index = [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
@ -19,52 +19,26 @@ mouth_outer_point_index = [84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 84]
mouth_inter_point_index = [96, 97, 98, 99, 100, 101, 102, 103, 96]
@PREDICTORS.register_module()
class FaceKeypointsPredictor(PredictorV2):
"""Predict pipeline for face keypoint
class FaceKptsOutputProcessor(OutputProcessor):
"""Process the output of face keypoints models.
Args:
model_path (str): Path of model path
config_file (str): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
input_size (int): Target image size.
"""
def __init__(self,
model_path,
config_file,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=None):
super(FaceKeypointsPredictor, self).__init__(
model_path,
config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines)
def __init__(self, input_size, point_number):
self.input_size = input_size
self.point_number = point_number
self.input_size = self.cfg.IMAGE_SIZE
self.point_number = self.cfg.POINT_NUMBER
def preprocess(self, inputs, *args, **kwargs):
batch_outputs = super().preprocess(inputs, *args, **kwargs)
self.img_metas = batch_outputs['img_metas']
return batch_outputs
def postprocess(self, inputs, *args, **kwargs):
def __call__(self, inputs):
results = []
img_metas = inputs['img_metas']
points = inputs['point'].cpu().numpy()
poses = inputs['pose'].cpu().numpy()
for idx, point in enumerate(points):
h, w, c = self.img_metas[idx]['img_shape']
h, w, c = img_metas[idx]['img_shape']
scale_h = h / self.input_size
scale_w = w / self.input_size
@ -77,6 +51,57 @@ class FaceKeypointsPredictor(PredictorV2):
return results
@PREDICTORS.register_module()
class FaceKeypointsPredictor(PredictorV2):
"""Predict pipeline for face keypoint
Args:
model_path (str): Path of model path
config_file (str): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
input_processor_threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
def __init__(
self,
model_path,
config_file,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=None,
input_processor_threads=8,
mode='BGR',
):
super(FaceKeypointsPredictor, self).__init__(
model_path,
config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=input_processor_threads,
mode=mode)
self.input_size = self.cfg.IMAGE_SIZE
self.point_number = self.cfg.POINT_NUMBER
def model_forward(self, inputs):
outputs = super().model_forward(inputs)
outputs['img_metas'] = inputs['img_metas']
return outputs
def get_output_processor(self):
return FaceKptsOutputProcessor(
input_size=self.input_size, point_number=self.point_number)
def show_result(self, img, points, scale=4.0, save_path=None):
"""Draw `result` over `img`.

View File

@ -7,7 +7,7 @@ from easycv.predictors.builder import PREDICTORS, build_predictor
from ..datasets.pose.data_sources.hand.coco_hand import \
COCO_WHOLEBODY_HAND_DATASET_INFO
from ..datasets.pose.data_sources.top_down import DatasetInfo
from .base import PredictorV2
from .base import InputProcessor, OutputProcessor, PredictorV2
from .pose_predictor import _box2cs
HAND_SKELETON = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7],
@ -16,47 +16,24 @@ HAND_SKELETON = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7],
[9, 13], [13, 17]]
@PREDICTORS.register_module()
class HandKeypointsPredictor(PredictorV2):
"""HandKeypointsPredictor
Attributes:
model_path: path of keypoint model
config_file: path or ``Config`` of config file
detection_predictor_config: dict of hand detection model predictor config,
example like ``dict(type="", model_path="", config_file="", ......)``
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
"""
class HandkptsInputProcessor(InputProcessor):
def __init__(self,
model_path,
config_file=None,
detection_predictor_config=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
cfg,
detection_predictor_config,
pipelines=None,
*args,
**kwargs):
super(HandKeypointsPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
*args,
**kwargs)
self.dataset_info = DatasetInfo(COCO_WHOLEBODY_HAND_DATASET_INFO)
batch_size=1,
mode='BGR'):
assert detection_predictor_config is not None, f"{self.__class__.__name__} needs 'detection_predictor_config' " \
f'property to build hand detection model'
self.detection_predictor = build_predictor(detection_predictor_config)
self.dataset_info = DatasetInfo(COCO_WHOLEBODY_HAND_DATASET_INFO)
super().__init__(
cfg,
pipelines=pipelines,
batch_size=batch_size,
threads=1,
mode=mode)
def _load_input(self, input):
""" load img and convert detection result to topdown style
@ -77,7 +54,7 @@ class HandKeypointsPredictor(PredictorV2):
box_id = 0
det_bbox_result = input['detection_boxes']
det_bbox_scores = input['detection_scores']
img = mmcv.imread(image_path, 'color', self.INPUT_IMAGE_MODE)
img = mmcv.imread(image_path, 'color', self.mode)
for bbox, score in zip(det_bbox_result, det_bbox_scores):
center, scale = _box2cs(self.cfg.data_cfg['image_size'], bbox)
# prepare data
@ -115,14 +92,14 @@ class HandKeypointsPredictor(PredictorV2):
box_id += 1
return data_list
def preprocess_single(self, input):
def process_single(self, input):
results = []
outputs = self._load_input(input)
for output in outputs:
results.append(self.processor(output))
return results
def preprocess(self, inputs, *args, **kwargs):
def __call__(self, inputs):
"""Process all inputs list. And collate to batch and put to target device.
If you need custom ops to load or process a batch samples, you need to reimplement it.
"""
@ -130,14 +107,16 @@ class HandKeypointsPredictor(PredictorV2):
det_results = self.detection_predictor(inputs, keep_inputs=True)
batch_outputs = []
for i in det_results:
for res in self.preprocess_single(i, *args, **kwargs):
for inp in det_results:
for res in self.process_single(inp):
batch_outputs.append(res)
batch_outputs = self._collate_fn(batch_outputs)
batch_outputs = self._to_device(batch_outputs)
return batch_outputs
def postprocess(self, inputs, *args, **kwargs):
class HandkptsOutputProcessor(OutputProcessor):
def __call__(self, inputs):
keypoints = inputs['preds']
boxes = inputs['boxes']
for i, bbox in enumerate(boxes):
@ -158,6 +137,64 @@ class HandKeypointsPredictor(PredictorV2):
})
return batch_outputs
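Note that the hand keypoints pipeline pins preprocessing to a single process (threads=1 above, input_processor_threads=1 below), presumably because the input processor drives a nested detection predictor whose live model should not be forked into worker processes. A hedged construction sketch (paths and detector config hypothetical):

det_cfg = dict(
    type='DetectionPredictor',
    model_path='work_dirs/hand_det.pth',  # hypothetical hand-box detector
    score_threshold=0.5)
predictor = HandKeypointsPredictor(
    'work_dirs/hand_kpts.pth',            # hypothetical keypoint checkpoint
    config_file='configs/pose/hand_256x256.py',
    detection_predictor_config=det_cfg)
results = predictor(['demo/hand.jpg'])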
@PREDICTORS.register_module()
class HandKeypointsPredictor(PredictorV2):
"""HandKeypointsPredictor
Attributes:
model_path: path of keypoint model
config_file: path or ``Config`` of config file
detection_predictor_config: dict of hand detection model predictor config,
example like ``dict(type="", model_path="", config_file="", ......)``
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
input_processor_threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
def __init__(self,
model_path,
config_file=None,
detection_predictor_config=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=None,
mode='BGR',
*args,
**kwargs):
self.detection_predictor_config = detection_predictor_config
super(HandKeypointsPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=1,
mode=mode,
*args,
**kwargs)
def get_input_processor(self):
return HandkptsInputProcessor(
self.cfg,
self.detection_predictor_config,
pipelines=self.pipelines,
batch_size=self.batch_size,
mode=self.mode)
def get_output_processor(self):
return HandkptsOutputProcessor()
def show_result(self,
image_path,
keypoints,

View File

@ -5,17 +5,9 @@ import os
import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms import Compose
from easycv.datasets.registry import PIPELINES
from easycv.file import io
from easycv.models import build_model
from easycv.predictors.builder import PREDICTORS
from easycv.predictors.interface import PredictorInterface
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.registry import build_from_cfg
from .base import PredictorV2
@ -30,6 +22,8 @@ class OCRDetPredictor(PredictorV2):
save_results=False,
save_path=None,
pipelines=None,
input_processor_threads=8,
mode='BGR',
*args,
**kwargs):
@ -41,6 +35,8 @@ class OCRDetPredictor(PredictorV2):
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=input_processor_threads,
mode=mode,
*args,
**kwargs)
@ -63,6 +59,8 @@ class OCRRecPredictor(PredictorV2):
save_results=False,
save_path=None,
pipelines=None,
input_processor_threads=8,
mode='BGR',
*args,
**kwargs):
@ -74,6 +72,8 @@ class OCRRecPredictor(PredictorV2):
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=input_processor_threads,
mode=mode,
*args,
**kwargs)
@ -89,6 +89,8 @@ class OCRClsPredictor(PredictorV2):
save_results=False,
save_path=None,
pipelines=None,
input_processor_threads=8,
mode='BGR',
*args,
**kwargs):
@ -100,6 +102,8 @@ class OCRClsPredictor(PredictorV2):
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=input_processor_threads,
mode=mode,
*args,
**kwargs)
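All three OCR predictors simply thread the two new arguments through to PredictorV2; a hedged one-liner (paths hypothetical):

predictor = OCRDetPredictor(
    'work_dirs/ocr_det.pth',              # hypothetical checkpoint
    config_file='configs/ocr/det_model.py',
    input_processor_threads=4,
    mode='BGR')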

View File

@ -8,7 +8,7 @@ from matplotlib.patches import Polygon
from easycv.core.visualization.image import imshow_bboxes
from easycv.predictors.builder import PREDICTORS
from .base import PredictorV2
from .base import OutputProcessor, PredictorV2
@PREDICTORS.register_module()
@ -23,6 +23,8 @@ class SegmentationPredictor(PredictorV2):
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
input_processor_threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
def __init__(self,
@ -33,6 +35,8 @@ class SegmentationPredictor(PredictorV2):
save_results=False,
save_path=None,
pipelines=None,
input_processor_threads=8,
mode='BGR',
*args,
**kwargs):
@ -44,6 +48,8 @@ class SegmentationPredictor(PredictorV2):
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=input_processor_threads,
mode=mode,
*args,
**kwargs)
@ -126,64 +132,31 @@ class SegmentationPredictor(PredictorV2):
return img
@PREDICTORS.register_module()
class Mask2formerPredictor(SegmentationPredictor):
"""Predictor for Mask2former.
class Mask2formerOutputProcessor(OutputProcessor):
"""Process the output of Mask2former.
Args:
model_path (str): Path of model path.
config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
task_mode (str): Supported tasks in ['panoptic', 'instance', 'semantic'].
classes (list): Class name list.
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=None,
task_mode='panoptic',
*args,
**kwargs):
super(Mask2formerPredictor, self).__init__(
model_path,
config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
*args,
**kwargs)
def __init__(self, task_mode, classes):
super(Mask2formerOutputProcessor, self).__init__()
self.task_mode = task_mode
self.class_name = self.cfg.CLASSES
self.PALETTE = self.cfg.PALETTE
self.classes = classes
def forward(self, inputs):
"""Model forward.
"""
with torch.no_grad():
outputs = self.model.forward(**inputs, mode='test', encode=False)
return outputs
def postprocess_single(self, inputs, *args, **kwargs):
def process_single(self, inputs):
output = {}
if self.task_mode == 'panoptic':
pan_results = inputs['pan_results']
# keep objects ahead
ids = np.unique(pan_results)[::-1]
legal_indices = ids != len(self.CLASSES) # for VOID label
legal_indices = ids != len(self.classes) # for VOID label
ids = ids[legal_indices]
labels = np.array([id % 1000 for id in ids], dtype=np.int64)
segms = (pan_results[None] == ids[:, None, None])
masks = [it.astype(np.int) for it in segms]
labels_txt = np.array(self.CLASSES)[labels].tolist()
labels_txt = np.array(self.classes)[labels].tolist()
output['masks'] = masks
output['labels'] = labels_txt
@ -199,6 +172,62 @@ class Mask2formerPredictor(SegmentationPredictor):
raise ValueError(f'Unsupported task_mode {self.task_mode}')
return output
@PREDICTORS.register_module()
class Mask2formerPredictor(SegmentationPredictor):
"""Predictor for Mask2former.
Args:
model_path (str): Path of model path.
config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
input_processor_threads (int): Number of processes to process inputs.
mode (str): The image mode fed into the model.
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=None,
task_mode='panoptic',
input_processor_threads=8,
mode='BGR',
*args,
**kwargs):
super(Mask2formerPredictor, self).__init__(
model_path,
config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=input_processor_threads,
mode=mode,
*args,
**kwargs)
self.task_mode = task_mode
self.class_name = self.cfg.CLASSES
self.PALETTE = self.cfg.PALETTE
def get_output_processor(self):
return Mask2formerOutputProcessor(self.task_mode, self.CLASSES)
def model_forward(self, inputs):
"""Model forward.
"""
with torch.no_grad():
outputs = self.model.forward(**inputs, mode='test', encode=False)
return outputs
def show_panoptic(self, img, masks, labels):
palette = np.asarray(self.cfg.PALETTE)
palette = palette[labels % 1000]

View File

@ -1,80 +1,34 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import numpy as np
import torch
from PIL import Image, ImageFile
from easycv.datasets.registry import PIPELINES
from easycv.file import io
from easycv.framework.errors import ValueError
from easycv.models.builder import build_model
from easycv.utils.misc import deprecated
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
remove_adapt_for_mmlab)
from easycv.utils.registry import build_from_cfg
from .base import Predictor, PredictorV2
from .base import InputProcessor, OutputProcessor, PredictorV2
from .builder import PREDICTORS
@PREDICTORS.register_module()
class VideoClassificationPredictor(PredictorV2):
"""Predictor for classification.
Args:
model_path (str): Path of model path.
config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
topk (int): Return top-k results. Default: 1.
pil_input (bool): Whether to use PIL image. If the processor needs PIL input, set it to True; default False.
label_map_path (str): File path of saving labels list.
"""
class VideoClsInputProcessor(InputProcessor):
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
cfg,
pipelines=None,
multi_class=False,
with_text=False,
label_map_path=None,
topk=1,
*args,
**kwargs):
super(VideoClassificationPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
*args,
**kwargs)
self.topk = topk
self.multi_class = multi_class
batch_size=1,
threads=8,
mode='RGB'):
self.with_text = with_text
if label_map_path is None:
if 'CLASSES' in self.cfg:
class_list = self.cfg.get('CLASSES', [])
elif 'class_list' in self.cfg:
class_list = self.cfg.get('class_list', [])
elif 'num_classes' in self.cfg:
class_list = list(range(self.cfg.get('num_classes', 0)))
class_list = [str(i) for i in class_list]
else:
class_list = []
else:
with io.open(label_map_path, 'r') as f:
class_list = f.readlines()
self.label_map = [i.strip() for i in class_list]
super().__init__(
cfg,
pipelines=pipelines,
batch_size=batch_size,
threads=threads,
mode=mode)
def _load_input(self, input):
"""Load image from file or numpy or PIL object.
@@ -93,7 +47,7 @@ class VideoClassificationPredictor(PredictorV2):
if self.with_text and 'text' not in result:
result['text'] = ''
result['start_index'] = 0
        result['modality'] = self.mode
return result
@@ -121,20 +75,15 @@ class VideoClassificationPredictor(PredictorV2):
processor = Compose(pipelines)
return processor
class VideoClsOutputProcessor(OutputProcessor):

    def __init__(self, label_map, topk=1):
        super().__init__()
        self.label_map = label_map
        self.topk = topk

    def __call__(self, inputs):
        """Return top-k results."""
        output_prob = inputs['prob'].data.cpu()
        topk_class = torch.topk(output_prob, self.topk).indices.numpy()
@@ -156,3 +105,94 @@ class VideoClassificationPredictor(PredictorV2):
batch_results.append(result)
return batch_results
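A minimal sketch of the top-k mapping performed by VideoClsOutputProcessor, with a hypothetical four-class label map:

import torch

label_map = ['dance', 'swim', 'run', 'jump']      # hypothetical classes
prob = torch.tensor([[0.1, 0.6, 0.2, 0.1]])
topk_class = torch.topk(prob, 2).indices.numpy()  # [[1, 2]]
names = [label_map[i] for i in topk_class[0]]     # ['swim', 'run']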
@PREDICTORS.register_module()
class VideoClassificationPredictor(PredictorV2):
"""Predictor for classification.
Args:
model_path (str): Path of model path.
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
topk (int): Return top-k results. Default: 1.
pil_input (bool): Whether use PIL image. If processor need PIL input, set true, default false.
label_map_path (str): File path of saving labels list.
input_processor_threads (int): Number of processes to process inputs.
mode (str): The image mode into the model.
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=None,
multi_class=False,
with_text=False,
label_map_path=None,
topk=1,
input_processor_threads=8,
mode='RGB',
*args,
**kwargs):
super(VideoClassificationPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
input_processor_threads=input_processor_threads,
mode=mode,
*args,
**kwargs)
self.topk = topk
self.multi_class = multi_class
self.with_text = with_text
if label_map_path is None:
if 'CLASSES' in self.cfg:
class_list = self.cfg.get('CLASSES', [])
elif 'class_list' in self.cfg:
class_list = self.cfg.get('class_list', [])
elif 'num_classes' in self.cfg:
class_list = list(range(self.cfg.get('num_classes', 0)))
class_list = [str(i) for i in class_list]
else:
class_list = []
else:
with io.open(label_map_path, 'r') as f:
class_list = f.readlines()
self.label_map = [i.strip() for i in class_list]
def _build_model(self):
# Use mmdet model
dynamic_adapt_for_mmlab(self.cfg)
if 'vison_pretrained' in self.cfg.model:
self.cfg.model.vison_pretrained = None
if 'text_pretrained' in self.cfg.model:
self.cfg.model.text_pretrained = None
model = build_model(self.cfg.model)
# remove adapt for mmdet to avoid conflict using mmdet models
remove_adapt_for_mmlab(self.cfg)
return model
def get_input_processor(self):
return VideoClsInputProcessor(
self.cfg,
pipelines=self.pipelines,
with_text=self.with_text,
batch_size=self.batch_size,
threads=self.input_processor_threads,
mode=self.mode)
def get_output_processor(self):
return VideoClsOutputProcessor(self.label_map, self.topk)
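A hedged usage sketch (all paths are placeholders; the dict input with a 'filename' key is an assumption based on the _load_input handling above):

# Hypothetical usage; model/config/video paths are placeholders.
predictor = VideoClassificationPredictor(
    model_path='path/to/video_cls.pth',
    config_file='path/to/config.py',
    topk=1,
    input_processor_threads=4)
results = predictor([dict(filename='demo_video.mp4')])
print(results[0])  # top-k labels and probabilities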
@@ -7,9 +7,7 @@ from easycv.datasets.pose.data_sources.wholebody.wholebody_coco_source import \
WHOLEBODY_COCO_DATASET_INFO
from easycv.datasets.pose.pipelines.transforms import bbox_cs2xyxy
from easycv.predictors.builder import PREDICTORS, build_predictor
from .base import InputProcessor, OutputProcessor, PredictorV2
SKELETON = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12],
[5, 6], [5, 7], [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2],
@@ -25,44 +23,24 @@ SKELETON = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12],
[131, 132]]
class WholeBodyKptsInputProcessor(InputProcessor):

    def __init__(self,
                 cfg,
                 detection_predictor_config,
                 bbox_thr=None,
                 pipelines=None,
                 batch_size=1,
                 mode='BGR'):
        self.detection_predictor = build_predictor(detection_predictor_config)
        self.dataset_info = DatasetInfo(WHOLEBODY_COCO_DATASET_INFO)
        self.bbox_thr = bbox_thr
        super().__init__(
            cfg,
            pipelines=pipelines,
            batch_size=batch_size,
            threads=1,
            mode=mode)
def process_detection_results(self, det_results, cat_id=0):
"""Process det results, and return a list of bboxes.
@@ -165,7 +143,7 @@ class WholeBodyKeypointsPredictor(PredictorV2):
return output_person_info
    def process_single(self, input):
results = []
outputs = self._load_input(input)
@@ -173,13 +151,13 @@ class WholeBodyKeypointsPredictor(PredictorV2):
results.append(self.processor(output))
return results
    def __call__(self, inputs):
"""Process all inputs list. And collate to batch and put to target device.
If you need custom ops to load or process a batch samples, you need to reimplement it.
"""
batch_outputs = []
        for inp in inputs:
            for res in self.process_single(inp):
batch_outputs.append(res)
batch_outputs = self._collate_fn(batch_outputs)
@@ -187,10 +165,12 @@ class WholeBodyKeypointsPredictor(PredictorV2):
i[j] for i in batch_outputs['img_metas']._data
for j in range(len(i))
]]
return batch_outputs
class WholeBodyKptsOutputProcessor(OutputProcessor):

    def __call__(self, inputs):
output = {}
output['keypoints'] = inputs['preds'][:, :, :2]
output['boxes'] = inputs['boxes']
@@ -201,6 +181,61 @@ class WholeBodyKeypointsPredictor(PredictorV2):
output['boxes'] = output['boxes'][:, :4]
return output
@PREDICTORS.register_module()
class WholeBodyKeypointsPredictor(PredictorV2):
"""WholeBodyKeypointsPredictor
Attributes:
model_path: path of keypoint model
config_file: path or ``Config`` of config file
detection_model_config: dict of hand detection model predictor config,
example like ``dict(type="", model_path="", config_file="", ......)``
batch_size: batch_size to infer
save_results: bool
save_path: path of result image
bbox_thr: bounding box threshold
mode (str): the image mode into the model
"""
def __init__(self,
model_path,
config_file=None,
detection_predictor_config=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
bbox_thr=None,
mode='BGR',
*args,
**kwargs):
super(WholeBodyKeypointsPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
input_processor_threads=1,
mode=mode,
*args,
**kwargs)
self.bbox_thr = bbox_thr
self.detection_predictor_config = detection_predictor_config
def get_input_processor(self):
return WholeBodyKptsInputProcessor(
cfg=self.cfg,
detection_predictor_config=self.detection_predictor_config,
bbox_thr=self.bbox_thr,
pipelines=self.pipelines,
batch_size=self.batch_size,
mode=self.mode)
def get_output_processor(self):
return WholeBodyKptsOutputProcessor()
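A hedged usage sketch (the detection config and all paths are placeholders). Note that this predictor pins ``input_processor_threads=1``, presumably because the wrapped detection predictor cannot be safely shared across worker processes:

# Hypothetical usage; detection config and paths are placeholders.
predictor = WholeBodyKeypointsPredictor(
    model_path='path/to/wholebody_kpts.pth',
    detection_predictor_config=dict(
        type='YoloXPredictor',
        model_path='path/to/detector.pth',
        score_thresh=0.5),
    bbox_thr=0.3)
outputs = predictor(['person.jpg'])
keypoints = outputs[0]['keypoints']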
def show_result(self,
image_path,
keypoints,
@@ -48,11 +48,19 @@ class YoloXPredictorTest(unittest.TestCase):
output = outputs[0]
self._assert_results(output)
    def _base_test_batch(self,
                         model_path,
                         inputs,
                         num_samples,
                         batch_size,
                         input_processor_threads=8):
assert isinstance(inputs, list) and len(inputs) == 1
        predictor = YoloXPredictor(
            model_path=model_path,
            score_thresh=0.5,
            batch_size=batch_size,
            input_processor_threads=input_processor_threads)
outputs = predictor(inputs * num_samples)
self.assertEqual(len(outputs), num_samples)
@@ -85,7 +93,10 @@ class YoloXPredictorTest(unittest.TestCase):
def test_batch_jit_pre_trt(self):
jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT_B2
self._base_test_batch(
            jit_path, [self.img],
            num_samples=4,
            batch_size=2,
            input_processor_threads=1)
def test_single_raw_TorchYoloXPredictor(self):
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
@@ -3,7 +3,7 @@
import unittest
import cv2
from tests.ut_config import PRETRAINED_MODEL_FACE_2D_KEYPOINTS
from easycv.predictors.face_keypoints_predictor import FaceKeypointsPredictor
@@ -2,7 +2,7 @@
import unittest
from tests.ut_config import PRETRAINED_MODEL_HAND_KEYPOINTS
from easycv.predictors.hand_keypoints_predictor import HandKeypointsPredictor
from easycv.utils.config_tools import mmcv_config_fromfile
@@ -105,11 +105,12 @@ class SegmentationPredictorTest(unittest.TestCase):
total_samples = 3
outputs = predict_pipeline(
[self.img_path] * total_samples, keep_inputs=False)
self.assertEqual(outputs, [])
with open(tmp_path, 'rb') as f:
results = pickle.loads(f.read())
self.assertEqual(len(results), total_samples)
for res in results:
self.assertNotIn('inputs', res)
self.assertIn('seg_pred', res)
@@ -2,8 +2,8 @@
import unittest
from tests.ut_config import (PRETRAINED_MODEL_WHOLEBODY,
                             PRETRAINED_MODEL_WHOLEBODY_DETECTION)
from easycv.predictors.wholebody_keypoints_predictor import \
WholeBodyKeypointsPredictor