[Fix] Use zero as default value of thrs in metrics. (#341)

* Use zero as the default value of `thrs` in metrics. It now accepts any
number (int or float) instead of only a float; see the sketch after the
changed-file summary below.

* Fix unit test comment

* Don't pass `thrs` down to the metric functions when it is not given.
Ma Zerun 2021-07-18 16:57:21 +08:00 committed by GitHub
parent 24e58ba26d
commit 15cd34bbef
4 changed files with 38 additions and 34 deletions
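For reference, the threshold handling this commit adds to each metric function boils down to the normalization sketched below. This is a standalone, illustrative snippet: `normalize_thrs` and the example calls are made up for the demo; only the `isinstance(thrs, Number)` logic and the error message come from the diff.

from numbers import Number


def normalize_thrs(thrs=0.):
    # Mirrors the check added in this commit: a single number (int or float)
    # is wrapped into a 1-tuple and a single result is returned later; a
    # tuple is used as-is; anything else raises TypeError.
    if isinstance(thrs, Number):
        return (thrs, ), True    # return_single / res_single = True
    elif isinstance(thrs, tuple):
        return thrs, False
    else:
        raise TypeError(
            f'thrs should be a number or tuple, but got {type(thrs)}.')


print(normalize_thrs(0))          # ((0,), True) -- an int is now accepted
print(normalize_thrs(0.3))        # ((0.3,), True)
print(normalize_thrs((0., 0.3)))  # ((0.0, 0.3), False)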

View File

@@ -1,3 +1,5 @@
+from numbers import Number
+
 import numpy as np
 import torch
@@ -36,7 +38,7 @@ def calculate_confusion_matrix(pred, target):
     return confusion_matrix


-def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
+def precision_recall_f1(pred, target, average_mode='macro', thrs=0.):
     """Calculate precision, recall and f1 score according to the prediction and
     target.

@@ -49,8 +51,8 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

     Returns:
         tuple: tuple containing precision, recall, f1 score.
@@ -78,16 +80,14 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
         (f'pred and target should be torch.Tensor or np.ndarray, '
          f'but got {type(pred)} and {type(target)}.')

-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         return_single = True
     elif isinstance(thrs, tuple):
         return_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')

     label = np.indices(pred.shape)[1]
     pred_label = np.argsort(pred, axis=1)[:, -1]
@@ -123,7 +123,7 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
     return precisions, recalls, f1_scores


-def precision(pred, target, average_mode='macro', thrs=None):
+def precision(pred, target, average_mode='macro', thrs=0.):
     """Calculate precision according to the prediction and target.

     Args:
@@ -135,8 +135,8 @@ def precision(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

     Returns:
         float | np.array | list[float | np.array]: Precision.
@@ -153,7 +153,7 @@ def precision(pred, target, average_mode='macro', thrs=None):
     return precisions


-def recall(pred, target, average_mode='macro', thrs=None):
+def recall(pred, target, average_mode='macro', thrs=0.):
     """Calculate recall according to the prediction and target.

     Args:
@@ -165,8 +165,8 @@ def recall(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

     Returns:
         float | np.array | list[float | np.array]: Recall.
@@ -183,7 +183,7 @@ def recall(pred, target, average_mode='macro', thrs=None):
     return recalls


-def f1_score(pred, target, average_mode='macro', thrs=None):
+def f1_score(pred, target, average_mode='macro', thrs=0.):
     """Calculate F1 score according to the prediction and target.

     Args:
@@ -195,8 +195,8 @@ def f1_score(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

     Returns:
         float | np.array | list[float | np.array]: F1 score.
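With the new default, callers of the functions above can pass an int, a float, or a tuple of thresholds. A hypothetical usage sketch follows; the import path is an assumption, only the signatures shown in the diff are confirmed.

import numpy as np

from mmcls.core.evaluation import precision_recall_f1  # assumed module path

pred = np.array([[0.1, 0.9],
                 [0.8, 0.2],
                 [0.4, 0.6]])
target = np.array([1, 0, 0])

# A single number (int works now, not only float) yields one value per metric.
p, r, f1 = precision_recall_f1(pred, target, average_mode='macro', thrs=0)

# A tuple of thresholds yields one value per threshold for each metric.
p_list, r_list, f1_list = precision_recall_f1(
    pred, target, average_mode='macro', thrs=(0., 0.7))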

View File

@@ -157,7 +157,10 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
         average_mode = metric_options.get('average_mode', 'macro')

         if 'accuracy' in metrics:
-            acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            if thrs is not None:
+                acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            else:
+                acc = accuracy(results, gt_labels, topk=topk)
             if isinstance(topk, tuple):
                 eval_results_ = {
                     f'accuracy_top-{k}': a
@@ -183,8 +186,12 @@ class BaseDataset(Dataset, metaclass=ABCMeta):

         precision_recall_f1_keys = ['precision', 'recall', 'f1_score']
         if len(set(metrics) & set(precision_recall_f1_keys)) != 0:
-            precision_recall_f1_values = precision_recall_f1(
-                results, gt_labels, average_mode=average_mode, thrs=thrs)
+            if thrs is not None:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode, thrs=thrs)
+            else:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode)
             for key, values in zip(precision_recall_f1_keys,
                                     precision_recall_f1_values):
                 if key in metrics:
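The same "don't pass thrs if no thrs" intent can also be expressed by building the keyword arguments conditionally. A sketch of that alternative (not the code merged above); the metric functions then simply fall back to their own default of 0:

def call_metric(metric_fn, results, gt_labels, thrs=None, **kwargs):
    # Forward thrs only when the caller actually supplied one.
    if thrs is not None:
        kwargs['thrs'] = thrs
    return metric_fn(results, gt_labels, **kwargs)

# e.g. call_metric(accuracy, results, gt_labels, topk=(1, 5))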

View File

@@ -1,19 +1,19 @@
+from numbers import Number
+
 import numpy as np
 import torch
 import torch.nn as nn


-def accuracy_numpy(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_numpy(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         res_single = True
     elif isinstance(thrs, tuple):
         res_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')

     res = []
     maxk = max(topk)
@@ -36,17 +36,15 @@ def accuracy_numpy(pred, target, topk=1, thrs=None):
     return res


-def accuracy_torch(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_torch(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         res_single = True
     elif isinstance(thrs, tuple):
         res_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')

     res = []
     maxk = max(topk)
@@ -68,7 +66,7 @@ def accuracy_torch(pred, target, topk=1, thrs=None):
     return res


-def accuracy(pred, target, topk=1, thrs=None):
+def accuracy(pred, target, topk=1, thrs=0.):
     """Calculate accuracy according to the prediction and target.

     Args:
@@ -77,9 +75,8 @@ def accuracy(pred, target, topk=1, thrs=None):
         topk (int | tuple[int]): If the predictions in ``topk``
             matches the target, the predictions will be regarded as
             correct ones. Defaults to 1.
-        thrs (float, optional): thrs (float | tuple[float], optional):
-            Predictions with scores under the thresholds are considered
-            negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

     Returns:
         float | list[float] | list[list[float]]: Accuracy
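To make the documented behaviour concrete, here is a standalone numpy sketch of thresholded top-1 accuracy (an illustration of the docstring above, not the library implementation): a top-1 prediction whose score does not exceed the threshold counts as negative, i.e. wrong.

import numpy as np


def top1_accuracy_with_thr(pred, target, thr=0.):
    # Top-1 prediction per sample and its score.
    top1_idx = pred.argmax(axis=1)
    top1_score = pred.max(axis=1)
    # Correct only if the class matches AND the score is above the threshold.
    correct = (top1_idx == target) & (top1_score > thr)
    return correct.mean() * 100.0


pred = np.array([[0.55, 0.45],
                 [0.35, 0.65],
                 [0.51, 0.49]])
target = np.array([0, 1, 1])

print(top1_accuracy_with_thr(pred, target, thr=0.))   # ~66.7: two of three hits count
print(top1_accuracy_with_thr(pred, target, thr=0.6))  # ~33.3: only the 0.65 hit survives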

View File

@@ -132,7 +132,7 @@ def test_dataset_evaluation():
     assert eval_results['f1_score'] == pytest.approx(
         (1 / 2 + 0 + 1 / 2) / 3 * 100.0)
     assert eval_results['accuracy'] == pytest.approx(2 / 6 * 100)
-    # thrs must be a float, tuple or None
+    # thrs must be a number or tuple
     with pytest.raises(TypeError):
         eval_results = dataset.evaluate(
             fake_results,
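The same type check can also be exercised directly against the metric functions. A hypothetical variant of the test above; the import path is an assumption, only the TypeError behaviour comes from the diff:

import numpy as np
import pytest

from mmcls.core.evaluation import precision_recall_f1  # assumed module path


def test_thrs_type_error():
    pred = np.random.rand(4, 3)
    target = np.random.randint(0, 3, (4, ))
    # Anything that is neither a numbers.Number nor a tuple should raise.
    with pytest.raises(TypeError):
        precision_recall_f1(pred, target, thrs='0.5')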