Optimize TorchPoseTopDownPredictorWithDetector (#291)

* fix WholeBodyKeyPointEvaluator
pull/296/head
Cathy0908 2023-02-24 13:55:58 +08:00 committed by GitHub
parent c3def063b4
commit 3726ed9553
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 75 additions and 21 deletions

View File

@ -15,7 +15,8 @@ from easycv.file import io
from easycv.framework.errors import ModuleNotFoundError, TypeError, ValueError
from easycv.models import build_model
from easycv.predictors.builder import PREDICTORS
from easycv.predictors.detector import TorchYoloXPredictor
from easycv.predictors.detector import (DetectionPredictor,
TorchYoloXPredictor, YoloXPredictor)
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.registry import build_from_cfg
@ -390,26 +391,54 @@ class TorchPoseTopDownPredictor(PredictorInterface):
return all_pose_results
def show_result(self,
image_path,
keypoints,
radius=4,
thickness=1,
kpt_score_thr=0.3,
bbox_color='green',
show=False,
save_path=None):
vis_result = vis_pose_result(
self.model,
image_path,
keypoints['pose_results'],
dataset_info=self.dataset_info,
radius=radius,
thickness=thickness,
kpt_score_thr=kpt_score_thr,
bbox_color=bbox_color,
show=show,
out_file=save_path)
return vis_result
@PREDICTORS.register_module()
class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
SUPPORT_DETECTION_PREDICTORS = {'TorchYoloXPredictor': TorchYoloXPredictor}
SUPPORT_DETECTION_PREDICTORS = {
'TorchYoloXPredictor': TorchYoloXPredictor,
'YoloXPredictor': YoloXPredictor,
'DetectionPredictor': DetectionPredictor
}
def __init__(
self,
model_path,
model_config={
'pose': {
'bbox_thr': 0.3,
'format': 'xywh'
self,
model_path,
model_config={
'pose': {
'bbox_thr': 0.3,
'format': 'xywh'
},
'detection': {
'model_type': None,
'reserved_classes': [],
'score_thresh': 0.0,
}
},
'detection': {
'model_type': None,
'reserved_classes': [],
'score_thresh': 0.0,
}
}):
return_vis_data=False):
"""
init model
@ -424,7 +453,7 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
detection_model_type = model_config['detection'].pop('model_type')
assert detection_model_type in self.SUPPORT_DETECTION_PREDICTORS
self.reserved_classes = model_config['detection'].get(
self.reserved_classes = model_config['detection'].pop(
'reserved_classes', [])
model_list = model_path.split(',')
@ -433,10 +462,15 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
pose_model_path, detection_model_path = model_list
detection_obj = self.SUPPORT_DETECTION_PREDICTORS[detection_model_type]
self.detection_predictor = detection_obj(
detection_model_path, model_config=model_config['detection'])
if detection_model_type == 'TorchYoloXPredictor':
self.detection_predictor = detection_obj(
detection_model_path, model_config=model_config['detection'])
else:
self.detection_predictor = detection_obj(
detection_model_path, **model_config['detection'])
self.pose_predictor = TorchPoseTopDownPredictor(
pose_model_path, model_config=model_config['pose'])
self.return_vis_data = return_vis_data
def process_det_results(self,
outputs,
@ -454,12 +488,16 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
for i in range(len(outputs)):
output = outputs[i]
cur_data = {'img': input_data_list[i], 'detection_results': []}
for class_name in output['detection_class_names']:
if output['detection_boxes'] is None or len(
output['detection_boxes']) < 1:
filter_outputs.append(cur_data)
continue
for j, class_name in enumerate(output['detection_class_names']):
if class_name in reserved_classes:
cur_data['detection_results'].append({
'bbox':
np.append(output['detection_boxes'][i],
output['detection_scores'][i])
np.append(output['detection_boxes'][j],
output['detection_scores'][j])
})
filter_outputs.append(cur_data)
@ -481,14 +519,30 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
"""
detection_output = self.detection_predictor.predict(input_data_list)
if hasattr(self.detection_predictor, 'predict') and callable(
self.detection_predictor.predict):
detection_output = self.detection_predictor.predict(
input_data_list)
else:
detection_output = self.detection_predictor(input_data_list)
output = self.process_det_results(detection_output, input_data_list,
self.reserved_classes)
pose_output = self.pose_predictor.predict(
output, return_heatmap=return_heatmap)
if self.return_vis_data:
for i, img in enumerate(input_data_list):
if len(pose_output[i]['pose_results']) > 0:
vis_result = self.pose_predictor.show_result(
image_path=img,
keypoints=pose_output[i],
)
pose_output[i]['vis_result'] = vis_result
return pose_output
def __call__(self, *args, **kwargs):
return self.predict(*args, **kwargs)
def vis_pose_result(model,
img,