mirror of https://github.com/alibaba/EasyCV.git
Optimize TorchPoseTopDownPredictorWithDetector (#291)
* fix WholeBodyKeyPointEvaluator
pull/296/head
parent
c3def063b4
commit
3726ed9553
|
@@ -15,7 +15,8 @@ from easycv.file import io
|
||||||
from easycv.framework.errors import ModuleNotFoundError, TypeError, ValueError
|
from easycv.framework.errors import ModuleNotFoundError, TypeError, ValueError
|
||||||
from easycv.models import build_model
|
from easycv.models import build_model
|
||||||
from easycv.predictors.builder import PREDICTORS
|
from easycv.predictors.builder import PREDICTORS
|
||||||
from easycv.predictors.detector import TorchYoloXPredictor
|
from easycv.predictors.detector import (DetectionPredictor,
|
||||||
|
TorchYoloXPredictor, YoloXPredictor)
|
||||||
from easycv.utils.checkpoint import load_checkpoint
|
from easycv.utils.checkpoint import load_checkpoint
|
||||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||||
from easycv.utils.registry import build_from_cfg
|
from easycv.utils.registry import build_from_cfg
|
||||||
|
@@ -390,11 +391,38 @@ class TorchPoseTopDownPredictor(PredictorInterface):
|
||||||
|
|
||||||
return all_pose_results
|
return all_pose_results
|
||||||
|
|
||||||
|
def show_result(self,
|
||||||
|
image_path,
|
||||||
|
keypoints,
|
||||||
|
radius=4,
|
||||||
|
thickness=1,
|
||||||
|
kpt_score_thr=0.3,
|
||||||
|
bbox_color='green',
|
||||||
|
show=False,
|
||||||
|
save_path=None):
|
||||||
|
vis_result = vis_pose_result(
|
||||||
|
self.model,
|
||||||
|
image_path,
|
||||||
|
keypoints['pose_results'],
|
||||||
|
dataset_info=self.dataset_info,
|
||||||
|
radius=radius,
|
||||||
|
thickness=thickness,
|
||||||
|
kpt_score_thr=kpt_score_thr,
|
||||||
|
bbox_color=bbox_color,
|
||||||
|
show=show,
|
||||||
|
out_file=save_path)
|
||||||
|
|
||||||
|
return vis_result
|
||||||
|
|
||||||
|
|
||||||
@PREDICTORS.register_module()
|
@PREDICTORS.register_module()
|
||||||
class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
||||||
|
|
||||||
SUPPORT_DETECTION_PREDICTORS = {'TorchYoloXPredictor': TorchYoloXPredictor}
|
SUPPORT_DETECTION_PREDICTORS = {
|
||||||
|
'TorchYoloXPredictor': TorchYoloXPredictor,
|
||||||
|
'YoloXPredictor': YoloXPredictor,
|
||||||
|
'DetectionPredictor': DetectionPredictor
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@@ -409,7 +437,8 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
||||||
'reserved_classes': [],
|
'reserved_classes': [],
|
||||||
'score_thresh': 0.0,
|
'score_thresh': 0.0,
|
||||||
}
|
}
|
||||||
}):
|
},
|
||||||
|
return_vis_data=False):
|
||||||
"""
|
"""
|
||||||
init model
|
init model
|
||||||
|
|
||||||
|
@@ -424,7 +453,7 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
||||||
detection_model_type = model_config['detection'].pop('model_type')
|
detection_model_type = model_config['detection'].pop('model_type')
|
||||||
assert detection_model_type in self.SUPPORT_DETECTION_PREDICTORS
|
assert detection_model_type in self.SUPPORT_DETECTION_PREDICTORS
|
||||||
|
|
||||||
self.reserved_classes = model_config['detection'].get(
|
self.reserved_classes = model_config['detection'].pop(
|
||||||
'reserved_classes', [])
|
'reserved_classes', [])
|
||||||
|
|
||||||
model_list = model_path.split(',')
|
model_list = model_path.split(',')
|
||||||
|
@@ -433,10 +462,15 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
||||||
pose_model_path, detection_model_path = model_list
|
pose_model_path, detection_model_path = model_list
|
||||||
|
|
||||||
detection_obj = self.SUPPORT_DETECTION_PREDICTORS[detection_model_type]
|
detection_obj = self.SUPPORT_DETECTION_PREDICTORS[detection_model_type]
|
||||||
|
if detection_model_type == 'TorchYoloXPredictor':
|
||||||
self.detection_predictor = detection_obj(
|
self.detection_predictor = detection_obj(
|
||||||
detection_model_path, model_config=model_config['detection'])
|
detection_model_path, model_config=model_config['detection'])
|
||||||
|
else:
|
||||||
|
self.detection_predictor = detection_obj(
|
||||||
|
detection_model_path, **model_config['detection'])
|
||||||
self.pose_predictor = TorchPoseTopDownPredictor(
|
self.pose_predictor = TorchPoseTopDownPredictor(
|
||||||
pose_model_path, model_config=model_config['pose'])
|
pose_model_path, model_config=model_config['pose'])
|
||||||
|
self.return_vis_data = return_vis_data
|
||||||
|
|
||||||
def process_det_results(self,
|
def process_det_results(self,
|
||||||
outputs,
|
outputs,
|
||||||
|
@@ -454,12 +488,16 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
||||||
for i in range(len(outputs)):
|
for i in range(len(outputs)):
|
||||||
output = outputs[i]
|
output = outputs[i]
|
||||||
cur_data = {'img': input_data_list[i], 'detection_results': []}
|
cur_data = {'img': input_data_list[i], 'detection_results': []}
|
||||||
for class_name in output['detection_class_names']:
|
if output['detection_boxes'] is None or len(
|
||||||
|
output['detection_boxes']) < 1:
|
||||||
|
filter_outputs.append(cur_data)
|
||||||
|
continue
|
||||||
|
for j, class_name in enumerate(output['detection_class_names']):
|
||||||
if class_name in reserved_classes:
|
if class_name in reserved_classes:
|
||||||
cur_data['detection_results'].append({
|
cur_data['detection_results'].append({
|
||||||
'bbox':
|
'bbox':
|
||||||
np.append(output['detection_boxes'][i],
|
np.append(output['detection_boxes'][j],
|
||||||
output['detection_scores'][i])
|
output['detection_scores'][j])
|
||||||
})
|
})
|
||||||
filter_outputs.append(cur_data)
|
filter_outputs.append(cur_data)
|
||||||
|
|
||||||
|
@@ -481,14 +519,30 @@ class TorchPoseTopDownPredictorWithDetector(PredictorInterface):
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
detection_output = self.detection_predictor.predict(input_data_list)
|
if hasattr(self.detection_predictor, 'predict') and callable(
|
||||||
|
self.detection_predictor.predict):
|
||||||
|
detection_output = self.detection_predictor.predict(
|
||||||
|
input_data_list)
|
||||||
|
else:
|
||||||
|
detection_output = self.detection_predictor(input_data_list)
|
||||||
output = self.process_det_results(detection_output, input_data_list,
|
output = self.process_det_results(detection_output, input_data_list,
|
||||||
self.reserved_classes)
|
self.reserved_classes)
|
||||||
pose_output = self.pose_predictor.predict(
|
pose_output = self.pose_predictor.predict(
|
||||||
output, return_heatmap=return_heatmap)
|
output, return_heatmap=return_heatmap)
|
||||||
|
if self.return_vis_data:
|
||||||
|
for i, img in enumerate(input_data_list):
|
||||||
|
if len(pose_output[i]['pose_results']) > 0:
|
||||||
|
vis_result = self.pose_predictor.show_result(
|
||||||
|
image_path=img,
|
||||||
|
keypoints=pose_output[i],
|
||||||
|
)
|
||||||
|
pose_output[i]['vis_result'] = vis_result
|
||||||
|
|
||||||
return pose_output
|
return pose_output
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.predict(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def vis_pose_result(model,
|
def vis_pose_result(model,
|
||||||
img,
|
img,
|
||||||
|
|
Loading…
Reference in New Issue