mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Fix] fix hmean iou
This commit is contained in:
parent
71d1a445c9
commit
2a852f23b5
@ -171,15 +171,15 @@ class HmeanIOUMetric(BaseMetric):
|
|||||||
dataset_pred_num[i] += np.sum(~pred_ignore_flags)
|
dataset_pred_num[i] += np.sum(~pred_ignore_flags)
|
||||||
|
|
||||||
for i, pred_score_thr in enumerate(self.pred_score_thrs):
|
for i, pred_score_thr in enumerate(self.pred_score_thrs):
|
||||||
precision, recall, hmean = compute_hmean(
|
recall, precision, hmean = compute_hmean(
|
||||||
int(dataset_hit_num[i]), int(dataset_hit_num[i]),
|
int(dataset_hit_num[i]), int(dataset_hit_num[i]),
|
||||||
int(dataset_gt_num), int(dataset_pred_num[i]))
|
int(dataset_gt_num), int(dataset_pred_num[i]))
|
||||||
eval_results = dict(
|
eval_results = dict(
|
||||||
precision=precision, recall=recall, hmean=hmean)
|
precision=precision, recall=recall, hmean=hmean)
|
||||||
logger.info(f'prediction score threshold: {pred_score_thr}, '
|
logger.info(f'prediction score threshold: {pred_score_thr:.2f}, '
|
||||||
f'recall: {eval_results["recall"]:.3f}, '
|
f'recall: {eval_results["recall"]:.4f}, '
|
||||||
f'precision: {eval_results["precision"]:.3f}, '
|
f'precision: {eval_results["precision"]:.4f}, '
|
||||||
f'hmean: {eval_results["hmean"]:.3f}\n')
|
f'hmean: {eval_results["hmean"]:.4f}\n')
|
||||||
if eval_results['hmean'] > best_eval_results['hmean']:
|
if eval_results['hmean'] > best_eval_results['hmean']:
|
||||||
best_eval_results = eval_results
|
best_eval_results = eval_results
|
||||||
return best_eval_results
|
return best_eval_results
|
||||||
|
@ -68,9 +68,10 @@ class TestHmeanIOU(unittest.TestCase):
|
|||||||
pred_data_sample = TextDetDataSample()
|
pred_data_sample = TextDetDataSample()
|
||||||
pred_data_sample.pred_instances = InstanceData()
|
pred_data_sample.pred_instances = InstanceData()
|
||||||
pred_data_sample.pred_instances.polygons = [
|
pred_data_sample.pred_instances.polygons = [
|
||||||
|
torch.FloatTensor([0, 0, 1, 0, 1, 1, 0, 1]),
|
||||||
torch.FloatTensor([0, 0, 1, 0, 1, 1, 0, 1])
|
torch.FloatTensor([0, 0, 1, 0, 1, 1, 0, 1])
|
||||||
]
|
]
|
||||||
pred_data_sample.pred_instances.scores = torch.FloatTensor([0.8])
|
pred_data_sample.pred_instances.scores = torch.FloatTensor([1, 0.95])
|
||||||
predictions.append(pred_data_sample.to_dict())
|
predictions.append(pred_data_sample.to_dict())
|
||||||
|
|
||||||
self.predictions = predictions
|
self.predictions = predictions
|
||||||
@ -81,7 +82,7 @@ class TestHmeanIOU(unittest.TestCase):
|
|||||||
metric.process(self.gt, self.predictions)
|
metric.process(self.gt, self.predictions)
|
||||||
eval_results = metric.evaluate(size=2)
|
eval_results = metric.evaluate(size=2)
|
||||||
|
|
||||||
precision = 3 / 4
|
precision = 3 / 5
|
||||||
recall = 3 / 4
|
recall = 3 / 4
|
||||||
hmean = 2 * precision * recall / (precision + recall)
|
hmean = 2 * precision * recall / (precision + recall)
|
||||||
target_result = {
|
target_result = {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user