Mirror of https://github.com/open-mmlab/mmclassification.git (synced 2025-06-03 21:53:55 +08:00)
[Fix] Use zero as default value of `thrs` in metrics. (#341)

* Use zero as default value of `thrs` in metrics, and accept any number instead of only a float.
* Fix unit test comment.
* Don't pass `thrs` if no `thrs` is given.

parent 24e58ba26d
commit 15cd34bbef
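The behaviour change in a nutshell, as a minimal usage sketch (not part of the diff; the import paths and example values are assumptions, while the signatures match the functions changed below):

```python
import numpy as np

# Assumed import locations; only the signatures below are taken from the diff.
from mmcls.core.evaluation import precision_recall_f1
from mmcls.models.losses import accuracy

pred = np.array([[0.1, 0.9],    # class scores, shape (N, num_classes)
                 [0.8, 0.2],
                 [0.4, 0.6]])
target = np.array([1, 0, 0])    # ground-truth labels, shape (N, )

# `thrs` now defaults to 0. and accepts any Number (e.g. the int 0), not only a float.
acc_default = accuracy(pred, target, topk=1)           # same as thrs=0.
acc_strict = accuracy(pred, target, topk=1, thrs=0.5)  # scores under 0.5 count as negative
# A tuple of thresholds yields one result per threshold.
p, r, f1 = precision_recall_f1(pred, target, thrs=(0., 0.5))
```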
@@ -1,3 +1,5 @@
+from numbers import Number
+
 import numpy as np
 import torch

@@ -36,7 +38,7 @@ def calculate_confusion_matrix(pred, target):
     return confusion_matrix


-def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
+def precision_recall_f1(pred, target, average_mode='macro', thrs=0.):
     """Calculate precision, recall and f1 score according to the prediction and
     target.

@@ -49,8 +51,8 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

     Returns:
         tuple: tuple containing precision, recall, f1 score.
@@ -78,16 +80,14 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
         (f'pred and target should be torch.Tensor or np.ndarray, '
          f'but got {type(pred)} and {type(target)}.')

-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         return_single = True
     elif isinstance(thrs, tuple):
         return_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')

     label = np.indices(pred.shape)[1]
     pred_label = np.argsort(pred, axis=1)[:, -1]
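For reference, the new type check in isolation: `numbers.Number` covers int, float and other numeric scalar types, so an integer threshold such as `thrs=0` is now accepted where the old `isinstance(thrs, float)` test would have raised. A standalone sketch with a hypothetical helper name:

```python
from numbers import Number


def _normalize_thrs(thrs):
    """Hypothetical helper mirroring the check above: wrap a single number
    into a tuple and record whether a single result should be returned."""
    if isinstance(thrs, Number):   # int and float both pass this check
        return (thrs, ), True
    elif isinstance(thrs, tuple):
        return thrs, False
    else:
        raise TypeError(
            f'thrs should be a number or tuple, but got {type(thrs)}.')


print(_normalize_thrs(0))          # ((0,), True)   -- an int is fine now
print(_normalize_thrs((0., 0.5)))  # ((0.0, 0.5), False)
```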
@@ -123,7 +123,7 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
         return precisions, recalls, f1_scores


-def precision(pred, target, average_mode='macro', thrs=None):
+def precision(pred, target, average_mode='macro', thrs=0.):
     """Calculate precision according to the prediction and target.

     Args:
@@ -135,8 +135,8 @@ def precision(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

     Returns:
         float | np.array | list[float | np.array]: Precision.
@@ -153,7 +153,7 @@ def precision(pred, target, average_mode='macro', thrs=None):
     return precisions


-def recall(pred, target, average_mode='macro', thrs=None):
+def recall(pred, target, average_mode='macro', thrs=0.):
     """Calculate recall according to the prediction and target.

     Args:
@@ -165,8 +165,8 @@ def recall(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

     Returns:
         float | np.array | list[float | np.array]: Recall.
@@ -183,7 +183,7 @@ def recall(pred, target, average_mode='macro', thrs=None):
     return recalls


-def f1_score(pred, target, average_mode='macro', thrs=None):
+def f1_score(pred, target, average_mode='macro', thrs=0.):
     """Calculate F1 score according to the prediction and target.

     Args:
@@ -195,8 +195,8 @@ def f1_score(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

     Returns:
         float | np.array | list[float | np.array]: F1 score.
@@ -157,7 +157,10 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
         average_mode = metric_options.get('average_mode', 'macro')

         if 'accuracy' in metrics:
-            acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            if thrs is not None:
+                acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            else:
+                acc = accuracy(results, gt_labels, topk=topk)
             if isinstance(topk, tuple):
                 eval_results_ = {
                     f'accuracy_top-{k}': a
@@ -183,8 +186,12 @@ class BaseDataset(Dataset, metaclass=ABCMeta):

         precision_recall_f1_keys = ['precision', 'recall', 'f1_score']
         if len(set(metrics) & set(precision_recall_f1_keys)) != 0:
-            precision_recall_f1_values = precision_recall_f1(
-                results, gt_labels, average_mode=average_mode, thrs=thrs)
+            if thrs is not None:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode, thrs=thrs)
+            else:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode)
             for key, values in zip(precision_recall_f1_keys,
                                     precision_recall_f1_values):
                 if key in metrics:
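On the dataset side, `thrs` is only forwarded to the metric functions when it is actually given, so existing calls without a threshold keep their previous behaviour. A hedged usage sketch (`dataset` and `results` are placeholders; the `metric_options` keys are assumptions based on the surrounding code):

```python
# Evaluate with an explicit threshold: thrs is forwarded to the metrics.
eval_with_thr = dataset.evaluate(
    results,
    metric=['accuracy', 'precision'],
    metric_options={'topk': (1, 5), 'thrs': 0.5})

# Evaluate without a threshold: the metrics are called without `thrs`.
eval_default = dataset.evaluate(
    results,
    metric=['accuracy', 'precision'],
    metric_options={'topk': (1, 5)})
```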
@@ -1,19 +1,19 @@
+from numbers import Number
+
 import numpy as np
 import torch
 import torch.nn as nn


-def accuracy_numpy(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_numpy(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         res_single = True
     elif isinstance(thrs, tuple):
         res_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')

     res = []
     maxk = max(topk)
@@ -36,17 +36,15 @@ def accuracy_numpy(pred, target, topk=1, thrs=None):
     return res


-def accuracy_torch(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_torch(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         res_single = True
     elif isinstance(thrs, tuple):
         res_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')

     res = []
     maxk = max(topk)
@@ -68,7 +66,7 @@ def accuracy_torch(pred, target, topk=1, thrs=None):
     return res


-def accuracy(pred, target, topk=1, thrs=None):
+def accuracy(pred, target, topk=1, thrs=0.):
     """Calculate accuracy according to the prediction and target.

     Args:
@@ -77,9 +75,8 @@ def accuracy(pred, target, topk=1, thrs=None):
         topk (int | tuple[int]): If the predictions in ``topk``
             matches the target, the predictions will be regarded as
             correct ones. Defaults to 1.
-        thrs (float, optional): thrs (float | tuple[float], optional):
-            Predictions with scores under the thresholds are considered
-            negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

     Returns:
         float | list[float] | list[list[float]]: Accuracy
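To make the documented return structure concrete, a small sketch (random inputs; the import path is an assumption): with both `topk` and `thrs` given as tuples, the result nests per `topk` entry and then per threshold.

```python
import torch

from mmcls.models.losses import accuracy  # assumed import path

pred = torch.rand(8, 10)               # class scores, shape (N, num_classes)
target = torch.randint(0, 10, (8, ))   # ground-truth labels, shape (N, )

accuracy(pred, target, topk=1)                        # single value, thrs defaults to 0.
accuracy(pred, target, topk=(1, 5))                   # one result per k
accuracy(pred, target, topk=(1, 5), thrs=(0., 0.3))   # nested: per k, then per threshold
```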
@@ -132,7 +132,7 @@ def test_dataset_evaluation():
     assert eval_results['f1_score'] == pytest.approx(
         (1 / 2 + 0 + 1 / 2) / 3 * 100.0)
     assert eval_results['accuracy'] == pytest.approx(2 / 6 * 100)
-    # thrs must be a float, tuple or None
+    # thrs must be a number or tuple
     with pytest.raises(TypeError):
         eval_results = dataset.evaluate(
             fake_results,
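The updated comment reflects the new validation: anything that is neither a number nor a tuple now raises `TypeError`, including `None`, which is why the dataset no longer passes `thrs` when it is unset. A minimal sketch of the check being exercised, calling the metric directly with hypothetical inputs (import path assumed):

```python
import numpy as np
import pytest

from mmcls.core.evaluation import precision_recall_f1  # assumed import path

pred = np.random.rand(6, 3)
target = np.random.randint(0, 3, (6, ))

with pytest.raises(TypeError):
    precision_recall_f1(pred, target, thrs='0.5')  # a str is neither a number nor a tuple
with pytest.raises(TypeError):
    precision_recall_f1(pred, target, thrs=None)   # None is no longer accepted either
```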