mirror of https://github.com/alibaba/EasyCV.git
68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
"""
|
|
isort:skip_file
|
|
"""
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from easycv.predictors.pose_predictor import TorchPoseTopDownPredictorWithDetector, vis_pose_result
|
|
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT,
|
|
PRETRAINED_MODEL_POSE_HRNET_EXPORT,
|
|
POSE_DATA_SMALL_COCO_LOCAL)
|
|
|
|
|
|
class TorchPoseTopDownPredictorWithDetectorTest(unittest.TestCase):
    """Smoke test for ``TorchPoseTopDownPredictorWithDetector``.

    Runs the combined detector + pose predictor on one image from the small
    COCO pose fixture, checks the structure of the returned pose results and
    verifies that the rendered visualization can be written to disk.
    """

    def setUp(self):
        """Print the name of the test that is about to run."""
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_pose_topdown_with_detector(self):
        """Predict poses for one image and validate the result and its rendering."""
        detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
        pose_model_path = PRETRAINED_MODEL_POSE_HRNET_EXPORT
        img = os.path.join(POSE_DATA_SMALL_COCO_LOCAL,
                           'images/000000067078.jpg')

        # Close the image file deterministically; ``np.array`` copies the
        # pixel data so the array stays valid after the handle is released.
        with Image.open(img) as pil_img:
            input_data_list = [np.array(pil_img)]

        # The predictor takes both checkpoints as a single comma-separated
        # string, pose model first and detection model second.
        model_path = ','.join((pose_model_path, detection_model_path))
        predictor = TorchPoseTopDownPredictorWithDetector(
            model_path=model_path,
            model_config={
                'pose': {
                    'bbox_thr': 0.3,
                    'format': 'xywh'
                },
                'detection': {
                    'model_type': 'TorchYoloXPredictor'
                }
            })

        all_pose_results = predictor.predict(input_data_list)
        one_result = all_pose_results[0]['pose_results']

        # The checks below inspect the second entry, so fail with a clear
        # message (instead of an IndexError) if fewer results came back.
        self.assertGreater(len(one_result), 1)
        self.assertIn('bbox', one_result[1])
        self.assertIn('keypoints', one_result[1])
        self.assertEqual(len(one_result[1]['bbox']), 5)
        self.assertEqual(one_result[1]['keypoints'].shape, (17, 3))

        vis_result = vis_pose_result(
            predictor.pose_predictor.model,
            img,
            all_pose_results[0]['pose_results'],
            dataset_info=predictor.pose_predictor.dataset_info,
            show=False)

        vis_result = cv2.resize(vis_result, dsize=None, fx=0.5, fy=0.5)

        # Write into a fresh temporary directory instead of over an open
        # NamedTemporaryFile: that file already exists before imwrite runs,
        # which would make the existence check vacuous, and re-opening it by
        # name while it is still open is not portable across platforms.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_save_path = os.path.join(tmp_dir, 'pose_vis_result.jpg')
            # cv2.imwrite reports failure through its boolean return value.
            self.assertTrue(cv2.imwrite(tmp_save_path, vis_result))
            self.assertTrue(os.path.exists(tmp_save_path))
            self.assertGreater(os.path.getsize(tmp_save_path), 0)
|
|
|
|
|
|
# Script entry point: run every test case in this module when the file is
# executed directly rather than imported by a test runner.
if __name__ == '__main__':
    unittest.main()
|