[Fix] Use zero as default value of thrs in metrics. (#341)

* Use zero as the default value of `thrs` in metrics. It now accepts any number
instead of only a float (see the sketch below).

* Fix unit test comment

* Don't pass `thrs` to the metric functions when it is not specified.
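A minimal standalone sketch of the new `thrs` handling described above, assuming only what this commit changes; `normalize_thrs` is a hypothetical helper for illustration, not part of mmcls:

```python
from numbers import Number


def normalize_thrs(thrs=0.):
    """Normalize `thrs` the way the patched metrics do (hypothetical helper)."""
    if isinstance(thrs, Number):
        # A single number (int or float) is wrapped into a 1-tuple.
        return (thrs, ), True
    elif isinstance(thrs, tuple):
        return thrs, False
    else:
        raise TypeError(
            f'thrs should be a number or tuple, but got {type(thrs)}.')


print(normalize_thrs())            # ((0.0,), True)  -> new default
print(normalize_thrs(0))           # ((0,), True)    -> plain int is accepted now
print(normalize_thrs((0.3, 0.5)))  # ((0.3, 0.5), False)
```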
Ma Zerun 2021-07-18 16:57:21 +08:00 committed by GitHub
parent 24e58ba26d
commit 15cd34bbef
4 changed files with 38 additions and 34 deletions

View File

@ -1,3 +1,5 @@
+from numbers import Number
import numpy as np
import torch
@ -36,7 +38,7 @@ def calculate_confusion_matrix(pred, target):
return confusion_matrix
-def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
+def precision_recall_f1(pred, target, average_mode='macro', thrs=0.):
"""Calculate precision, recall and f1 score according to the prediction and
target.
@ -49,8 +51,8 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
-thrs (float | tuple[float], optional): Predictions with scores under
-    the thresholds are considered negative. Default to None.
+thrs (Number | tuple[Number], optional): Predictions with scores under
+    the thresholds are considered negative. Default to 0.
Returns:
tuple: tuple containing precision, recall, f1 score.
@ -78,16 +80,14 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
(f'pred and target should be torch.Tensor or np.ndarray, '
f'but got {type(pred)} and {type(target)}.')
-if thrs is None:
-    thrs = 0.0
-if isinstance(thrs, float):
+if isinstance(thrs, Number):
thrs = (thrs, )
return_single = True
elif isinstance(thrs, tuple):
return_single = False
else:
raise TypeError(
-    f'thrs should be float or tuple, but got {type(thrs)}.')
+    f'thrs should be a number or tuple, but got {type(thrs)}.')
label = np.indices(pred.shape)[1]
pred_label = np.argsort(pred, axis=1)[:, -1]
@ -123,7 +123,7 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
return precisions, recalls, f1_scores
-def precision(pred, target, average_mode='macro', thrs=None):
+def precision(pred, target, average_mode='macro', thrs=0.):
"""Calculate precision according to the prediction and target.
Args:
@ -135,8 +135,8 @@ def precision(pred, target, average_mode='macro', thrs=None):
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
-thrs (float | tuple[float], optional): Predictions with scores under
-    the thresholds are considered negative. Default to None.
+thrs (Number | tuple[Number], optional): Predictions with scores under
+    the thresholds are considered negative. Default to 0.
Returns:
float | np.array | list[float | np.array]: Precision.
@ -153,7 +153,7 @@ def precision(pred, target, average_mode='macro', thrs=None):
return precisions
-def recall(pred, target, average_mode='macro', thrs=None):
+def recall(pred, target, average_mode='macro', thrs=0.):
"""Calculate recall according to the prediction and target.
Args:
@ -165,8 +165,8 @@ def recall(pred, target, average_mode='macro', thrs=None):
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
-thrs (float | tuple[float], optional): Predictions with scores under
-    the thresholds are considered negative. Default to None.
+thrs (Number | tuple[Number], optional): Predictions with scores under
+    the thresholds are considered negative. Default to 0.
Returns:
float | np.array | list[float | np.array]: Recall.
@ -183,7 +183,7 @@ def recall(pred, target, average_mode='macro', thrs=None):
return recalls
-def f1_score(pred, target, average_mode='macro', thrs=None):
+def f1_score(pred, target, average_mode='macro', thrs=0.):
"""Calculate F1 score according to the prediction and target.
Args:
@ -195,8 +195,8 @@ def f1_score(pred, target, average_mode='macro', thrs=None):
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
-thrs (float | tuple[float], optional): Predictions with scores under
-    the thresholds are considered negative. Default to None.
+thrs (Number | tuple[Number], optional): Predictions with scores under
+    the thresholds are considered negative. Default to 0.
Returns:
float | np.array | list[float | np.array]: F1 score.
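A hedged usage sketch for the updated signature above (not part of the commit). The import path is an assumption about where mmcls exposes the helper; the keyword arguments follow the diff:

```python
import numpy as np
from mmcls.core.evaluation import precision_recall_f1  # import path assumed

pred = np.array([[0.7, 0.2, 0.1],
                 [0.1, 0.8, 0.1],
                 [0.3, 0.3, 0.4]])
target = np.array([0, 1, 2])

# thrs now defaults to 0.; a plain int such as 0 is accepted as well.
p, r, f1 = precision_recall_f1(pred, target, average_mode='macro')

# A tuple of thresholds returns one value per threshold.
p_list, r_list, f1_list = precision_recall_f1(pred, target, thrs=(0., 0.5))
```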

View File

@ -157,7 +157,10 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
average_mode = metric_options.get('average_mode', 'macro')
if 'accuracy' in metrics:
-acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+if thrs is not None:
+    acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+else:
+    acc = accuracy(results, gt_labels, topk=topk)
if isinstance(topk, tuple):
eval_results_ = {
f'accuracy_top-{k}': a
@ -183,8 +186,12 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
precision_recall_f1_keys = ['precision', 'recall', 'f1_score']
if len(set(metrics) & set(precision_recall_f1_keys)) != 0:
-precision_recall_f1_values = precision_recall_f1(
-    results, gt_labels, average_mode=average_mode, thrs=thrs)
+if thrs is not None:
+    precision_recall_f1_values = precision_recall_f1(
+        results, gt_labels, average_mode=average_mode, thrs=thrs)
+else:
+    precision_recall_f1_values = precision_recall_f1(
+        results, gt_labels, average_mode=average_mode)
for key, values in zip(precision_recall_f1_keys,
precision_recall_f1_values):
if key in metrics:
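The hunk above only forwards `thrs` when the caller actually provided one, so the metric functions fall back to their own default of 0. A tiny standalone sketch of that pattern (`call_metric` is hypothetical, not mmcls code):

```python
def call_metric(metric_fn, *args, thrs=None, **kwargs):
    # Forward `thrs` only if it was explicitly given; otherwise let the
    # metric function use its own default (0. after this commit).
    if thrs is not None:
        return metric_fn(*args, thrs=thrs, **kwargs)
    return metric_fn(*args, **kwargs)
```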

View File

@ -1,19 +1,19 @@
+from numbers import Number
import numpy as np
import torch
import torch.nn as nn
-def accuracy_numpy(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_numpy(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
thrs = (thrs, )
res_single = True
elif isinstance(thrs, tuple):
res_single = False
else:
raise TypeError(
-    f'thrs should be float or tuple, but got {type(thrs)}.')
+    f'thrs should be a number or tuple, but got {type(thrs)}.')
res = []
maxk = max(topk)
@ -36,17 +36,15 @@ def accuracy_numpy(pred, target, topk=1, thrs=None):
return res
-def accuracy_torch(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_torch(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
thrs = (thrs, )
res_single = True
elif isinstance(thrs, tuple):
res_single = False
else:
raise TypeError(
-    f'thrs should be float or tuple, but got {type(thrs)}.')
+    f'thrs should be a number or tuple, but got {type(thrs)}.')
res = []
maxk = max(topk)
@ -68,7 +66,7 @@ def accuracy_torch(pred, target, topk=1, thrs=None):
return res
-def accuracy(pred, target, topk=1, thrs=None):
+def accuracy(pred, target, topk=1, thrs=0.):
"""Calculate accuracy according to the prediction and target.
Args:
@ -77,9 +75,8 @@ def accuracy(pred, target, topk=1, thrs=None):
topk (int | tuple[int]): If the predictions in ``topk``
matches the target, the predictions will be regarded as
correct ones. Defaults to 1.
-thrs (float, optional): thrs (float | tuple[float], optional):
-    Predictions with scores under the thresholds are considered
-    negative. Default to None.
+thrs (Number | tuple[Number], optional): Predictions with scores under
+    the thresholds are considered negative. Default to 0.
Returns:
float | list[float] | list[list[float]]: Accuracy
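A hedged usage sketch for `accuracy()` with the new default (not part of the commit). The import path is an assumption; the keyword arguments follow the signature in the diff:

```python
import torch
from mmcls.models.losses import accuracy  # import path assumed

pred = torch.tensor([[0.6, 0.3, 0.1],
                     [0.2, 0.5, 0.3]])
target = torch.tensor([0, 1])

top1 = accuracy(pred, target, topk=1)                 # thrs defaults to 0.
top1_thr = accuracy(pred, target, topk=1, thrs=0.55)  # scores under 0.55 count as negative
per_k = accuracy(pred, target, topk=(1, 2), thrs=(0., 0.55))
```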

View File

@ -132,7 +132,7 @@ def test_dataset_evaluation():
assert eval_results['f1_score'] == pytest.approx(
(1 / 2 + 0 + 1 / 2) / 3 * 100.0)
assert eval_results['accuracy'] == pytest.approx(2 / 6 * 100)
-# thrs must be a float, tuple or None
+# thrs must be a number or tuple
with pytest.raises(TypeError):
eval_results = dataset.evaluate(
fake_results,