fix pose prediction of result error (#337)

update pose predictor
pull/343/head
gulou 2024-04-09 16:27:23 +08:00 committed by GitHub
parent 5ba3057ff8
commit 1d1ac8aa5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 13 additions and 14 deletions

View File

@ -19,6 +19,8 @@ from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.misc import deprecated
from .base import InputProcessor, OutputProcessor, PredictorV2
np.set_printoptions(suppress=True)
def _box2cs(image_size, box):
"""This encodes bbox(x,y,w,h) into (center, scale)
@ -222,11 +224,12 @@ class PoseTopDownInputProcessor(InputProcessor):
bboxes = bboxes[valid_idx]
person_results = [person_results[i] for i in valid_idx]
output_person_info = []
results = []
for person_result in person_results:
box = person_result['bbox'] # x,y,x,y
box = [box[0], box[1], box[2] - box[0], box[3] - box[1]] # x,y,w,h
center, scale = _box2cs(self.cfg.data_cfg['image_size'], box)
box = person_result['bbox'] # x,y,x,y,s
boxc = [box[0], box[1], box[2] - box[0],
box[3] - box[1]] # x,y,w,h
center, scale = _box2cs(self.cfg.data_cfg['image_size'], boxc)
data = {
'image_id':
0,
@ -264,11 +267,10 @@ class PoseTopDownInputProcessor(InputProcessor):
output['img_fields'],
}
box_id += 1
output_person_info.append(data)
data_processor = self.processor(data)
data_processor['bbox'] = box
results.append(data_processor)
results = []
for output in output_person_info:
results.append(self.processor(output))
return results
def __call__(self, inputs):
@ -296,12 +298,7 @@ class PoseTopDownOutputProcessor(OutputProcessor):
def __call__(self, inputs):
output = {}
output['keypoints'] = inputs['preds']
output['bbox'] = inputs['boxes'] # c1, c2, s1, s2, area, core
for i, bbox in enumerate(output['bbox']):
center, scale = bbox[:2], bbox[2:4]
output['bbox'][i][:4] = bbox_cs2xyxy(center, scale)
output['bbox'] = output['bbox'][:, [0, 1, 2, 3, 5]]
output['bbox'] = np.array(inputs['boxes']) # x1, y1, x2, y2 score
return output
@ -403,6 +400,7 @@ class PoseTopDownPredictor(PredictorV2):
return model
def model_forward(self, inputs, return_heatmap=False):
boxes = inputs['bbox'].cpu().numpy()
if self.model_type == 'raw':
with torch.no_grad():
result = self.model(
@ -423,6 +421,7 @@ class PoseTopDownPredictor(PredictorV2):
result = decode_heatmap(output_heatmap, img_metas,
self.cfg.model.test_cfg)
result['boxes'] = np.array(boxes)
return result
def get_input_processor(self):