diff --git a/easycv/predictors/pose_predictor.py b/easycv/predictors/pose_predictor.py
index 0cd18bed..295306d6 100644
--- a/easycv/predictors/pose_predictor.py
+++ b/easycv/predictors/pose_predictor.py
@@ -15,7 +15,8 @@ from easycv.file import io
 from easycv.framework.errors import ModuleNotFoundError, TypeError, ValueError
 from easycv.models import build_model
 from easycv.predictors.builder import PREDICTORS
-from easycv.predictors.detector import TorchYoloXPredictor
+from easycv.predictors.detector import (DetectionPredictor,
+                                        TorchYoloXPredictor, YoloXPredictor)
 from easycv.utils.checkpoint import load_checkpoint
 from easycv.utils.config_tools import mmcv_config_fromfile
 from easycv.utils.registry import build_from_cfg
@@ -390,26 +391,54 @@ class TorchPoseTopDownPredictor(PredictorInterface):
 
         return all_pose_results
 
+    def show_result(self,
+                    image_path,
+                    keypoints,
+                    radius=4,
+                    thickness=1,
+                    kpt_score_thr=0.3,
+                    bbox_color='green',
+                    show=False,
+                    save_path=None):
+        vis_result = vis_pose_result(
+            self.model,
+            image_path,
+            keypoints['pose_results'],
+            dataset_info=self.dataset_info,
+            radius=radius,
+            thickness=thickness,
+            kpt_score_thr=kpt_score_thr,
+            bbox_color=bbox_color,
+            show=show,
+            out_file=save_path)
+
+        return vis_result
+
 
 @PREDICTORS.register_module()
 class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
 
-    SUPPORT_DETECTION_PREDICTORS = {'TorchYoloXPredictor': TorchYoloXPredictor}
+    SUPPORT_DETECTION_PREDICTORS = {
+        'TorchYoloXPredictor': TorchYoloXPredictor,
+        'YoloXPredictor': YoloXPredictor,
+        'DetectionPredictor': DetectionPredictor
+    }
 
     def __init__(
-            self,
-            model_path,
-            model_config={
-                'pose': {
-                    'bbox_thr': 0.3,
-                    'format': 'xywh'
+        self,
+        model_path,
+        model_config={
+            'pose': {
+                'bbox_thr': 0.3,
+                'format': 'xywh'
+            },
+            'detection': {
+                'model_type': None,
+                'reserved_classes': [],
+                'score_thresh': 0.0,
+            }
+        },
-                },
-                'detection': {
-                    'model_type': None,
-                    'reserved_classes': [],
-                    'score_thresh': 0.0,
-                }
-            }):
+        return_vis_data=False):
         """
         init model
 
@@ -424,7 +453,7 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
         detection_model_type = model_config['detection'].pop('model_type')
         assert detection_model_type in self.SUPPORT_DETECTION_PREDICTORS
 
-        self.reserved_classes = model_config['detection'].get(
+        self.reserved_classes = model_config['detection'].pop(
             'reserved_classes', [])
 
         model_list = model_path.split(',')
@@ -433,10 +462,15 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
         pose_model_path, detection_model_path = model_list
 
         detection_obj = self.SUPPORT_DETECTION_PREDICTORS[detection_model_type]
-        self.detection_predictor = detection_obj(
-            detection_model_path, model_config=model_config['detection'])
+        if detection_model_type == 'TorchYoloXPredictor':
+            self.detection_predictor = detection_obj(
+                detection_model_path, model_config=model_config['detection'])
+        else:
+            self.detection_predictor = detection_obj(
+                detection_model_path, **model_config['detection'])
         self.pose_predictor = TorchPoseTopDownPredictor(
             pose_model_path, model_config=model_config['pose'])
+        self.return_vis_data = return_vis_data
 
     def process_det_results(self,
                             outputs,
@@ -454,12 +488,16 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
         for i in range(len(outputs)):
             output = outputs[i]
             cur_data = {'img': input_data_list[i], 'detection_results': []}
-            for class_name in output['detection_class_names']:
+            if output['detection_boxes'] is None or len(
+                    output['detection_boxes']) < 1:
+                filter_outputs.append(cur_data)
+                continue
+            for j, class_name in enumerate(output['detection_class_names']):
                 if class_name in reserved_classes:
                     cur_data['detection_results'].append({
                         'bbox':
-                        np.append(output['detection_boxes'][i],
-                                  output['detection_scores'][i])
+                        np.append(output['detection_boxes'][j],
+                                  output['detection_scores'][j])
                     })
             filter_outputs.append(cur_data)
 
@@ -481,14 +519,30 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
         """
-        detection_output = self.detection_predictor.predict(input_data_list)
+        if hasattr(self.detection_predictor, 'predict') and callable(
+                self.detection_predictor.predict):
+            detection_output = self.detection_predictor.predict(
+                input_data_list)
+        else:
+            detection_output = self.detection_predictor(input_data_list)
 
         output = self.process_det_results(detection_output, input_data_list,
                                           self.reserved_classes)
 
         pose_output = self.pose_predictor.predict(
             output, return_heatmap=return_heatmap)
+        if self.return_vis_data:
+            for i, img in enumerate(input_data_list):
+                if len(pose_output[i]['pose_results']) > 0:
+                    vis_result = self.pose_predictor.show_result(
+                        image_path=img,
+                        keypoints=pose_output[i],
+                    )
+                    pose_output[i]['vis_result'] = vis_result
 
         return pose_output
 
+    def __call__(self, *args, **kwargs):
+        return self.predict(*args, **kwargs)
+
 
 def vis_pose_result(model,
                     img,