Mirror of https://github.com/open-mmlab/mmclassification.git (synced 2025-06-03 21:53:55 +08:00)
[Fix] Use zero as default value of thrs in metrics. (#341)

* Use zero as the default value of `thrs` in metrics; it now accepts any number instead of only a float.
* Fix unit test comment.
* Don't pass `thrs` if no `thrs` is given.
This commit is contained in:
parent 24e58ba26d
commit 15cd34bbef
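Before the hunks below, a minimal usage sketch of the behaviour this commit targets; the import path and example values are assumptions, not taken from the diff:

import numpy as np
from mmcls.core.evaluation import accuracy  # assumed public import path

# Hypothetical softmax scores for 2 samples over 2 classes.
pred = np.array([[0.1, 0.9],
                 [0.8, 0.2]])
target = np.array([1, 0])

# `thrs` now defaults to 0., so existing calls keep their behaviour.
accuracy(pred, target, topk=1)

# A plain number (int or float) is accepted as a single threshold...
accuracy(pred, target, topk=1, thrs=0.5)

# ...and a tuple of numbers yields one result per threshold.
accuracy(pred, target, topk=1, thrs=(0., 0.5))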
@@ -1,3 +1,5 @@
+from numbers import Number
+
 import numpy as np
 import torch
 
@@ -36,7 +38,7 @@ def calculate_confusion_matrix(pred, target):
     return confusion_matrix
 
 
-def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
+def precision_recall_f1(pred, target, average_mode='macro', thrs=0.):
     """Calculate precision, recall and f1 score according to the prediction and
     target.
 
@@ -49,8 +51,8 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.
 
     Returns:
         tuple: tuple containing precision, recall, f1 score.
@@ -78,16 +80,14 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
         (f'pred and target should be torch.Tensor or np.ndarray, '
          f'but got {type(pred)} and {type(target)}.')
 
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         return_single = True
     elif isinstance(thrs, tuple):
         return_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')
 
     label = np.indices(pred.shape)[1]
     pred_label = np.argsort(pred, axis=1)[:, -1]
@@ -123,7 +123,7 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
     return precisions, recalls, f1_scores
 
 
-def precision(pred, target, average_mode='macro', thrs=None):
+def precision(pred, target, average_mode='macro', thrs=0.):
     """Calculate precision according to the prediction and target.
 
     Args:
@@ -135,8 +135,8 @@ def precision(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.
 
     Returns:
         float | np.array | list[float | np.array]: Precision.
@@ -153,7 +153,7 @@ def precision(pred, target, average_mode='macro', thrs=None):
     return precisions
 
 
-def recall(pred, target, average_mode='macro', thrs=None):
+def recall(pred, target, average_mode='macro', thrs=0.):
     """Calculate recall according to the prediction and target.
 
     Args:
@@ -165,8 +165,8 @@ def recall(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.
 
     Returns:
         float | np.array | list[float | np.array]: Recall.
@@ -183,7 +183,7 @@ def recall(pred, target, average_mode='macro', thrs=None):
     return recalls
 
 
-def f1_score(pred, target, average_mode='macro', thrs=None):
+def f1_score(pred, target, average_mode='macro', thrs=0.):
     """Calculate F1 score according to the prediction and target.
 
     Args:
@@ -195,8 +195,8 @@ def f1_score(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.
 
     Returns:
         float | np.array | list[float | np.array]: F1 score.
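The key change in these hunks is the `isinstance(thrs, Number)` check, which accepts ints as well as floats and removes the `None` special case. A self-contained sketch of just that normalization step, a simplification rather than the full mmcls function:

from numbers import Number


def normalize_thrs(thrs=0.):
    """Return (thrs_tuple, return_single) the way the metric functions do."""
    if isinstance(thrs, Number):
        # A single int/float becomes a one-element tuple; the caller
        # unwraps the single result afterwards.
        return (thrs, ), True
    elif isinstance(thrs, tuple):
        return thrs, False
    else:
        raise TypeError(
            f'thrs should be a number or tuple, but got {type(thrs)}.')


# normalize_thrs(0.3)       -> ((0.3,), True)
# normalize_thrs((0., 0.3)) -> ((0., 0.3), False)
# normalize_thrs('0.3')     -> raises TypeError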
@@ -157,7 +157,10 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
         average_mode = metric_options.get('average_mode', 'macro')
 
         if 'accuracy' in metrics:
-            acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            if thrs is not None:
+                acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            else:
+                acc = accuracy(results, gt_labels, topk=topk)
             if isinstance(topk, tuple):
                 eval_results_ = {
                     f'accuracy_top-{k}': a
@@ -183,8 +186,12 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
 
         precision_recall_f1_keys = ['precision', 'recall', 'f1_score']
         if len(set(metrics) & set(precision_recall_f1_keys)) != 0:
-            precision_recall_f1_values = precision_recall_f1(
-                results, gt_labels, average_mode=average_mode, thrs=thrs)
+            if thrs is not None:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode, thrs=thrs)
+            else:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode)
             for key, values in zip(precision_recall_f1_keys,
                                    precision_recall_f1_values):
                 if key in metrics:
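On the dataset side, `thrs=None` keeps meaning "not specified by the user", and the argument is only forwarded when it was actually set, so the metric-level default of 0. applies otherwise. A hedged sketch of that call pattern; `_eval_accuracy` is a hypothetical helper, not part of mmcls:

from mmcls.core.evaluation import accuracy  # assumed import path


def _eval_accuracy(results, gt_labels, topk=(1, ), thrs=None):
    # Mirror of the evaluate() logic above: only pass `thrs` through
    # when the caller explicitly provided one via metric_options.
    if thrs is not None:
        return accuracy(results, gt_labels, topk=topk, thrs=thrs)
    return accuracy(results, gt_labels, topk=topk)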
@@ -1,19 +1,19 @@
+from numbers import Number
+
 import numpy as np
 import torch
 import torch.nn as nn
 
 
-def accuracy_numpy(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_numpy(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         res_single = True
     elif isinstance(thrs, tuple):
         res_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')
 
     res = []
     maxk = max(topk)
@@ -36,17 +36,15 @@ def accuracy_numpy(pred, target, topk=1, thrs=None):
     return res
 
 
-def accuracy_torch(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_torch(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         res_single = True
     elif isinstance(thrs, tuple):
         res_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')
 
     res = []
     maxk = max(topk)
@@ -68,7 +66,7 @@ def accuracy_torch(pred, target, topk=1, thrs=None):
     return res
 
 
-def accuracy(pred, target, topk=1, thrs=None):
+def accuracy(pred, target, topk=1, thrs=0.):
     """Calculate accuracy according to the prediction and target.
 
     Args:
@@ -77,9 +75,8 @@ def accuracy(pred, target, topk=1, thrs=None):
         topk (int | tuple[int]): If the predictions in ``topk``
             matches the target, the predictions will be regarded as
            correct ones. Defaults to 1.
-        thrs (float, optional): thrs (float | tuple[float], optional):
-            Predictions with scores under the thresholds are considered
-            negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.
 
     Returns:
         float | list[float] | list[list[float]]: Accuracy
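With the numpy and torch paths aligned, `topk` and `thrs` combine as before; only the accepted types for `thrs` changed. A small example of the expected return structure, with illustrative values and an assumed import path:

import torch
from mmcls.core.evaluation import accuracy  # assumed import path

pred = torch.tensor([[0.7, 0.3],
                     [0.4, 0.6],
                     [0.9, 0.1]])
target = torch.tensor([0, 1, 1])

accuracy(pred, target, topk=1)                       # single value, thrs defaults to 0.
accuracy(pred, target, topk=(1, 2))                  # one value per k
accuracy(pred, target, topk=1, thrs=(0., 0.5))       # one value per threshold
accuracy(pred, target, topk=(1, 2), thrs=(0., 0.5))  # nested: per k, then per threshold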
@@ -132,7 +132,7 @@ def test_dataset_evaluation():
     assert eval_results['f1_score'] == pytest.approx(
         (1 / 2 + 0 + 1 / 2) / 3 * 100.0)
     assert eval_results['accuracy'] == pytest.approx(2 / 6 * 100)
-    # thrs must be a float, tuple or None
+    # thrs must be a number or tuple
     with pytest.raises(TypeError):
         eval_results = dataset.evaluate(
             fake_results,
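Since `None` is no longer a special value at the metric level, only numbers and tuples pass the type check. A hedged pytest sketch of that check, written against the metric function directly rather than the dataset fixture used above:

import numpy as np
import pytest

from mmcls.core.evaluation import precision_recall_f1  # assumed import path


def test_thrs_type_check():
    pred = np.random.rand(4, 3)
    target = np.random.randint(0, 3, size=(4, ))
    # Anything that is neither a number nor a tuple should raise TypeError.
    with pytest.raises(TypeError):
        precision_recall_f1(pred, target, thrs='0.5')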