mirror of https://github.com/alibaba/EasyCV.git
Optimize TorchPoseTopDownPredictorWithDetector (#291)
* fix WholeBodyKeyPointEvaluatorpull/296/head
parent
c3def063b4
commit
3726ed9553
|
@ -15,7 +15,8 @@ from easycv.file import io
|
|||
from easycv.framework.errors import ModuleNotFoundError, TypeError, ValueError
|
||||
from easycv.models import build_model
|
||||
from easycv.predictors.builder import PREDICTORS
|
||||
from easycv.predictors.detector import TorchYoloXPredictor
|
||||
from easycv.predictors.detector import (DetectionPredictor,
|
||||
TorchYoloXPredictor, YoloXPredictor)
|
||||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||
from easycv.utils.registry import build_from_cfg
|
||||
|
@ -390,26 +391,54 @@ class TorchPoseTopDownPredictor(PredictorInterface):
|
|||
|
||||
return all_pose_results
|
||||
|
||||
def show_result(self,
|
||||
image_path,
|
||||
keypoints,
|
||||
radius=4,
|
||||
thickness=1,
|
||||
kpt_score_thr=0.3,
|
||||
bbox_color='green',
|
||||
show=False,
|
||||
save_path=None):
|
||||
vis_result = vis_pose_result(
|
||||
self.model,
|
||||
image_path,
|
||||
keypoints['pose_results'],
|
||||
dataset_info=self.dataset_info,
|
||||
radius=radius,
|
||||
thickness=thickness,
|
||||
kpt_score_thr=kpt_score_thr,
|
||||
bbox_color=bbox_color,
|
||||
show=show,
|
||||
out_file=save_path)
|
||||
|
||||
return vis_result
|
||||
|
||||
|
||||
@PREDICTORS.register_module()
|
||||
class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
||||
|
||||
SUPPORT_DETECTION_PREDICTORS = {'TorchYoloXPredictor': TorchYoloXPredictor}
|
||||
SUPPORT_DETECTION_PREDICTORS = {
|
||||
'TorchYoloXPredictor': TorchYoloXPredictor,
|
||||
'YoloXPredictor': YoloXPredictor,
|
||||
'DetectionPredictor': DetectionPredictor
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
model_config={
|
||||
'pose': {
|
||||
'bbox_thr': 0.3,
|
||||
'format': 'xywh'
|
||||
self,
|
||||
model_path,
|
||||
model_config={
|
||||
'pose': {
|
||||
'bbox_thr': 0.3,
|
||||
'format': 'xywh'
|
||||
},
|
||||
'detection': {
|
||||
'model_type': None,
|
||||
'reserved_classes': [],
|
||||
'score_thresh': 0.0,
|
||||
}
|
||||
},
|
||||
'detection': {
|
||||
'model_type': None,
|
||||
'reserved_classes': [],
|
||||
'score_thresh': 0.0,
|
||||
}
|
||||
}):
|
||||
return_vis_data=False):
|
||||
"""
|
||||
init model
|
||||
|
||||
|
@ -424,7 +453,7 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
|||
detection_model_type = model_config['detection'].pop('model_type')
|
||||
assert detection_model_type in self.SUPPORT_DETECTION_PREDICTORS
|
||||
|
||||
self.reserved_classes = model_config['detection'].get(
|
||||
self.reserved_classes = model_config['detection'].pop(
|
||||
'reserved_classes', [])
|
||||
|
||||
model_list = model_path.split(',')
|
||||
|
@ -433,10 +462,15 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
|||
pose_model_path, detection_model_path = model_list
|
||||
|
||||
detection_obj = self.SUPPORT_DETECTION_PREDICTORS[detection_model_type]
|
||||
self.detection_predictor = detection_obj(
|
||||
detection_model_path, model_config=model_config['detection'])
|
||||
if detection_model_type == 'TorchYoloXPredictor':
|
||||
self.detection_predictor = detection_obj(
|
||||
detection_model_path, model_config=model_config['detection'])
|
||||
else:
|
||||
self.detection_predictor = detection_obj(
|
||||
detection_model_path, **model_config['detection'])
|
||||
self.pose_predictor = TorchPoseTopDownPredictor(
|
||||
pose_model_path, model_config=model_config['pose'])
|
||||
self.return_vis_data = return_vis_data
|
||||
|
||||
def process_det_results(self,
|
||||
outputs,
|
||||
|
@ -454,12 +488,16 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
|||
for i in range(len(outputs)):
|
||||
output = outputs[i]
|
||||
cur_data = {'img': input_data_list[i], 'detection_results': []}
|
||||
for class_name in output['detection_class_names']:
|
||||
if output['detection_boxes'] is None or len(
|
||||
output['detection_boxes']) < 1:
|
||||
filter_outputs.append(cur_data)
|
||||
continue
|
||||
for j, class_name in enumerate(output['detection_class_names']):
|
||||
if class_name in reserved_classes:
|
||||
cur_data['detection_results'].append({
|
||||
'bbox':
|
||||
np.append(output['detection_boxes'][i],
|
||||
output['detection_scores'][i])
|
||||
np.append(output['detection_boxes'][j],
|
||||
output['detection_scores'][j])
|
||||
})
|
||||
filter_outputs.append(cur_data)
|
||||
|
||||
|
@ -481,14 +519,30 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
|||
|
||||
|
||||
"""
|
||||
detection_output = self.detection_predictor.predict(input_data_list)
|
||||
if hasattr(self.detection_predictor, 'predict') and callable(
|
||||
self.detection_predictor.predict):
|
||||
detection_output = self.detection_predictor.predict(
|
||||
input_data_list)
|
||||
else:
|
||||
detection_output = self.detection_predictor(input_data_list)
|
||||
output = self.process_det_results(detection_output, input_data_list,
|
||||
self.reserved_classes)
|
||||
pose_output = self.pose_predictor.predict(
|
||||
output, return_heatmap=return_heatmap)
|
||||
if self.return_vis_data:
|
||||
for i, img in enumerate(input_data_list):
|
||||
if len(pose_output[i]['pose_results']) > 0:
|
||||
vis_result = self.pose_predictor.show_result(
|
||||
image_path=img,
|
||||
keypoints=pose_output[i],
|
||||
)
|
||||
pose_output[i]['vis_result'] = vis_result
|
||||
|
||||
return pose_output
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.predict(*args, **kwargs)
|
||||
|
||||
|
||||
def vis_pose_result(model,
|
||||
img,
|
||||
|
|
Loading…
Reference in New Issue