mirror of https://github.com/alibaba/EasyCV.git
60 lines
2.1 KiB
Python
60 lines
2.1 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import torch
|
|
|
|
from .base_evaluator import Evaluator
|
|
from .builder import EVALUATORS
|
|
from .metric_registry import METRICS
|
|
|
|
|
|
@EVALUATORS.register_module
class FaceKeypointEvaluator(Evaluator):
    """Evaluator reporting average pose accuracy and average NME for face keypoints."""

    def __init__(self, dataset_name=None, metric_names=None):
        """
        Args:
            dataset_name: dataset name, forwarded to the base ``Evaluator``.
            metric_names: list of metric names, forwarded to the base
                ``Evaluator``. Defaults to ``['ave_nme']``.
        """
        # Build a fresh list per instance instead of sharing one mutable
        # default list across every evaluator created without metric_names.
        if metric_names is None:
            metric_names = ['ave_nme']
        super(FaceKeypointEvaluator, self).__init__(dataset_name, metric_names)
        self.metric = metric_names
        self.dataset_name = dataset_name

    def _evaluate_impl(self, prediction_dict, groundtruth_dict, **kwargs):
        """Compute average pose accuracy and average NME over all samples.

        Args:
            prediction_dict: model forward output dict; its ``'point'`` and
                ``'pose'`` entries are iterated in lockstep with the ground
                truth, one element per sample.
            groundtruth_dict: iterated as a sequence of per-sample dicts, each
                providing ``'target_point'``, ``'target_point_mask'``,
                ``'target_pose'`` and ``'target_pose_mask'``.
            kwargs: other parameters (currently unused).

        Returns:
            dict with ``'ave_pose_acc'`` and ``'ave_nme'``: the mean of the
            per-sample ``'pose_acc'`` / ``'nme'`` values, or 0.0 for both
            when there are no samples.
        """
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle with easycv.models at module load -- TODO confirm).
        from easycv.models.utils.face_keypoint_utils import (
            get_keypoint_accuracy, get_pose_accuracy)

        total_pose_acc = 0
        total_nme = 0
        num_samples = 0

        # zip stops at the shortest input, so surplus predictions or
        # ground-truth entries are silently ignored.
        for predict_point, predict_pose, gt in zip(prediction_dict['point'],
                                                   prediction_dict['pose'],
                                                   groundtruth_dict):
            # Multiply targets by their masks before scoring (presumably
            # 0/1 validity masks -- confirm against the dataset).
            target_point = gt['target_point'] * gt['target_point_mask']
            target_pose = gt['target_pose'] * gt['target_pose_mask']

            keypoint_accuracy = get_keypoint_accuracy(predict_point,
                                                      target_point)
            pose_accuracy = get_pose_accuracy(predict_pose, target_pose)

            total_pose_acc += pose_accuracy['pose_acc']
            total_nme += keypoint_accuracy['nme']
            num_samples += 1

        # Divide by the true sample count; guard only the empty case rather
        # than adding an epsilon that slightly biases every average.
        denominator = max(num_samples, 1)
        return {
            'ave_pose_acc': total_pose_acc / denominator,
            'ave_nme': total_nme / denominator,
        }
|
|
|
|
|
|
# Register 'ave_nme' as the default best metric for FaceKeypointEvaluator,
# with 'min' as its selection mode (presumably: lower values are better).
METRICS.register_default_best_metric(
    FaceKeypointEvaluator,
    'ave_nme',
    'min',
)
|