Mirror of https://github.com/open-mmlab/mmclassification.git (synced 2025-06-03 21:53:55 +08:00)

[Docs] Add custom evaluation docs (#1130)

* [Docs] Add evaluation docs
* minor fix
* Fix docs

Co-authored-by: mzr1996 <mzr1996@163.com>

Commit 50aaa711ea (parent 280e916979)
```diff
@@ -25,3 +25,9 @@ article.pytorch-article section table code {
 table.autosummary td {
     width: 50%
 }
+
+img.align-center {
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+}
```
BIN  docs/en/_static/image/confusion-matrix.png (Executable file)
Binary file not shown. Size: 51 KiB.
@@ -1 +1,103 @@
The stub heading `# Custom evaluation metrics (TODO)` is replaced by the following guide.

# Customize Evaluation Metrics

## Use metrics in MMClassification

In MMClassification, we provide multiple metrics for both single-label classification and multi-label classification:

**Single-label Classification**:

- [`Accuracy`](mmcls.evaluation.Accuracy)
- [`SingleLabelMetric`](mmcls.evaluation.SingleLabelMetric), including precision, recall, f1-score and support.

**Multi-label Classification**:

- [`AveragePrecision`](mmcls.evaluation.AveragePrecision), or AP (mAP).
- [`MultiLabelMetric`](mmcls.evaluation.MultiLabelMetric), including precision, recall, f1-score and support.

To use these metrics during validation and testing, modify the `val_evaluator` and `test_evaluator` fields in the config file.

Here are several examples:

1. Calculate top-1 and top-5 accuracy during both validation and test.

   ```python
   val_evaluator = dict(type='Accuracy', topk=(1, 5))
   test_evaluator = val_evaluator
   ```

2. Calculate top-1 accuracy, top-5 accuracy, precision and recall during both validation and test.

   ```python
   val_evaluator = [
       dict(type='Accuracy', topk=(1, 5)),
       dict(type='SingleLabelMetric', items=['precision', 'recall']),
   ]
   test_evaluator = val_evaluator
   ```

3. Calculate mAP (mean AveragePrecision), CP (Class-wise mean Precision), CR (Class-wise mean Recall), CF (Class-wise mean F1-score), OP (Overall mean Precision), OR (Overall mean Recall) and OF1 (Overall mean F1-score).

   ```python
   val_evaluator = [
       dict(type='AveragePrecision'),
       dict(type='MultiLabelMetric', average='macro'),  # class-wise mean
       dict(type='MultiLabelMetric', average='micro'),  # overall mean
   ]
   test_evaluator = val_evaluator
   ```
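
4. A further sketch: choose how multi-label predictions are binarized through the `thr` or `topk` options documented on `MultiLabelMetric` later in this commit (the values below are illustrative only).

   ```python
   val_evaluator = [
       dict(type='AveragePrecision'),
       # count a class as positive when its score exceeds 0.5
       dict(type='MultiLabelMetric', thr=0.5),
       # or take the 3 highest-scoring classes as positives instead
       # dict(type='MultiLabelMetric', topk=3),
   ]
   test_evaluator = val_evaluator
   ```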

## Add new metrics

MMClassification lets you implement customized evaluation metrics when the built-in ones do not cover your needs.

Create a new file under `mmcls/evaluation/metrics`, for example `mmcls/evaluation/metrics/my_metric.py`, and implement the new metric there as a class `MyMetric` that inherits from [`BaseMetric` in MMEngine](mmengine.evaluator.metrics.BaseMetric).

Override the data-format processing method `process` and the metric calculation method `compute_metrics`, and add the class to the `METRICS` registry. A skeleton looks like this, and a more concrete sketch follows it:

```python
from typing import Dict, List, Sequence

from mmengine.evaluator import BaseMetric

from mmcls.registry import METRICS


@METRICS.register_module()
class MyMetric(BaseMetric):

    def process(self, data_batch: Sequence[Dict], data_samples: Sequence[Dict]):
        """The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        ``data_batch`` stores the batch data from the dataloader, and
        ``data_samples`` stores the batch outputs from the model.
        """
        ...

    def compute_metrics(self, results: List):
        """Compute the metrics from the processed results and return the
        evaluation results."""
        ...
```
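
For concreteness, here is a minimal sketch of what the two methods might look like for a toy error-rate metric. The `'pred_label'`, `'gt_label'`, `'score'` and `'label'` keys below are assumptions modelled on the built-in metrics; adjust them to whatever your data samples actually carry.

```python
from typing import Dict, List, Sequence

from mmengine.evaluator import BaseMetric

from mmcls.registry import METRICS


@METRICS.register_module()
class ErrorRate(BaseMetric):
    """Toy metric sketch: the fraction of wrongly classified samples."""

    default_prefix = 'toy'

    def process(self, data_batch: Sequence[Dict], data_samples: Sequence[Dict]):
        for data_sample in data_samples:
            # Assumption: predictions expose per-class scores under
            # 'pred_label', and the ground truth sits under 'gt_label'.
            pred = data_sample['pred_label']['score'].argmax().item()
            gt = data_sample['gt_label']['label'].item()
            # Store only what compute_metrics needs; results from all ranks
            # are gathered before compute_metrics is called.
            self.results.append({'correct': int(pred == gt)})

    def compute_metrics(self, results: List) -> Dict:
        correct = sum(r['correct'] for r in results)
        return {'error-rate': 1 - correct / len(results)}
```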

Then, import it in `mmcls/evaluation/metrics/__init__.py` to add it to the `mmcls.evaluation` package.

```python
# In mmcls/evaluation/metrics/__init__.py
...
from .my_metric import MyMetric

__all__ = [..., 'MyMetric']
```

Finally, use `MyMetric` in the `val_evaluator` and `test_evaluator` fields of config files.

```python
val_evaluator = dict(type='MyMetric', ...)
test_evaluator = val_evaluator
```

```{note}
More details can be found in {external+mmengine:doc}`MMEngine Documentation: Evaluation <design/evaluation>`.
```
```diff
@@ -25,3 +25,9 @@ article.pytorch-article section table code {
 table.autosummary td {
     width: 50%
 }
+
+img.align-center {
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+}
```
docs/zh_CN/_static/image/confusion-matrix.png (Symbolic link)

```diff
@@ -0,0 +1 @@
+../../../en/_static/image/confusion-matrix.png
```
```diff
@@ -1 +1,4 @@
 # Customize Evaluation Metrics (to be updated)
+
+Please refer to the [English documentation](https://mmclassification.readthedocs.io/en/dev-1.x/advanced_guides/evaluation.html) for now. If you are interested in
+translating the Chinese documentation, please sign up in the [discussion thread](https://github.com/open-mmlab/mmclassification/discussions/1027).
```
```diff
@@ -13,38 +13,58 @@ from .single_label import _precision_recall_f1_support, to_tensor
 
 
 @METRICS.register_module()
 class MultiLabelMetric(BaseMetric):
-    """A collection of metrics for multi-label multi-class classification task
-    based on confusion matrix.
+    r"""A collection of precision, recall, f1-score and support for
+    multi-label tasks.
 
-    It includes precision, recall, f1-score and support.
+    The collection of metrics is for multi-label multi-class classification.
+    And all these metrics are based on the confusion matrix of every category:
+
+    .. image:: ../../_static/image/confusion-matrix.png
+       :width: 60%
+       :align: center
+
+    All metrics can be formulated using the variables above:
+
+    **Precision** is the fraction of correct predictions in all predictions:
+
+    .. math::
+        \text{Precision} = \frac{TP}{TP+FP}
+
+    **Recall** is the fraction of correct predictions in all targets:
+
+    .. math::
+        \text{Recall} = \frac{TP}{TP+FN}
+
+    **F1-score** is the harmonic mean of the precision and recall:
+
+    .. math::
+        \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}}
+
+    **Support** is the number of samples:
+
+    .. math::
+        \text{Support} = TP + TN + FN + FP
 
     Args:
-        thr (float, optional): Predictions with scores under the thresholds
-            are considered as negative. Defaults to None.
+        thr (float, optional): Predictions with scores under the threshold
+            are considered as negative. If None, the ``topk`` predictions will
+            be considered as positive. If ``topk`` is also None, use
+            ``thr=0.5`` as default. Defaults to None.
         topk (int, optional): Predictions with the k-th highest scores are
-            considered as positive. Defaults to None.
-        items (Sequence[str]): The detailed metric items to evaluate. Here is
-            the available options:
-
-            - `"precision"`: The ratio tp / (tp + fp) where tp is the
-              number of true positives and fp the number of false
-              positives.
-            - `"recall"`: The ratio tp / (tp + fn) where tp is the number
-              of true positives and fn the number of false negatives.
-            - `"f1-score"`: The f1-score is the harmonic mean of the
-              precision and recall.
-            - `"support"`: The total number of positive of each category
-              in the target.
-
-            Defaults to ('precision', 'recall', 'f1-score').
-        average (str | None): The average method. It supports three average
-            modes:
-
-            - `"macro"`: Calculate metrics for each category, and calculate
-              the mean value over all categories.
-            - `"micro"`: Calculate metrics globally by counting the total
-              true positives, false negatives and false positives.
-            - `None`: Return scores of all categories.
+            considered as positive. If None, use ``thr`` to determine positive
+            predictions. If both ``thr`` and ``topk`` are not None, use
+            ``thr``. Defaults to None.
+        items (Sequence[str]): The detailed metric items to evaluate, select
+            from "precision", "recall", "f1-score" and "support".
+            Defaults to ``('precision', 'recall', 'f1-score')``.
+        average (str | None): How to calculate the final metrics from the
+            confusion matrix of every category. It supports three modes:
+
+            - `"macro"`: Calculate metrics for each category, and calculate
+              the mean value over all categories.
+            - `"micro"`: Average the confusion matrix over all categories and
+              calculate metrics on the mean confusion matrix.
+            - `None`: Calculate metrics of every category and output directly.
 
             Defaults to "macro".
         collect_device (str): Device name used for collecting results from
```
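
A quick illustrative sketch (not code from this commit) of the macro/micro distinction the updated docstring describes, starting from per-class confusion-matrix counts:

```python
import torch


def precision_recall(tp: torch.Tensor, fp: torch.Tensor, fn: torch.Tensor,
                     average: str = 'macro'):
    """tp/fp/fn are per-class counts, each with shape (num_classes, )."""
    tp, fp, fn = tp.float(), fp.float(), fn.float()
    if average == 'micro':
        # "micro": pool the per-class confusion matrices first (summing and
        # averaging give the same ratios), then compute the metrics once.
        tp, fp, fn = tp.sum(), fp.sum(), fn.sum()
    precision = tp / (tp + fp).clamp(min=1)  # TP / (TP + FP)
    recall = tp / (tp + fn).clamp(min=1)     # TP / (TP + FN)
    if average == 'macro':
        # "macro": compute per-class metrics, then average over classes.
        precision, recall = precision.mean(), recall.mean()
    return precision, recall
```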
```diff
@@ -261,15 +281,16 @@ class MultiLabelMetric(BaseMetric):
             target_indices (bool): Whether the ``target`` is a sequence of
                 category index labels. If True, ``num_classes`` must be set.
                 Defaults to False.
-            average (str | None): The average method. It supports three average
-                modes:
+            average (str | None): How to calculate the final metrics from
+                the confusion matrix of every category. It supports three
+                modes:
 
-                - `"macro"`: Calculate metrics for each category, and
-                  calculate the mean value over all categories.
-                - `"micro"`: Calculate metrics globally by counting the
-                  total true positives, false negatives and false
-                  positives.
-                - `None`: Return scores of all categories.
+                - `"macro"`: Calculate metrics for each category, and calculate
+                  the mean value over all categories.
+                - `"micro"`: Average the confusion matrix over all categories
+                  and calculate metrics on the mean confusion matrix.
+                - `None`: Calculate metrics of every category and output
+                  directly.
 
                 Defaults to "macro".
             thr (float, optional): Predictions with scores under the thresholds
```
```diff
@@ -402,14 +423,25 @@ def _average_precision(pred: torch.Tensor,
 
 
 @METRICS.register_module()
 class AveragePrecision(BaseMetric):
-    """Calculate the average precision with respect of classes.
+    r"""Calculate the average precision with respect to classes.
+
+    AveragePrecision (AP) summarizes a precision-recall curve as the weighted
+    mean of maximum precisions obtained for any r'>r, where r is the recall:
+
+    .. math::
+        \text{AP} = \sum_n (R_n - R_{n-1}) P_n
+
+    Note that no approximation is involved since the curve is piecewise
+    constant.
 
     Args:
-        average (str | None): The average method. It supports two modes:
+        average (str | None): How to calculate the final metrics from
+            every category. It supports two modes:
 
             - `"macro"`: Calculate metrics for each category, and calculate
-              the mean value over all categories.
-            - `None`: Return scores of all categories.
+              the mean value over all categories. The result of this mode
+              is also called **mAP**.
+            - `None`: Calculate metrics of every category and output directly.
 
             Defaults to "macro".
         collect_device (str): Device name used for collecting results from
```
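
To make the AP formula concrete, a small standalone sketch for a single class (illustrative only, not the `_average_precision` helper referenced in the hunk header):

```python
import torch


def average_precision(scores: torch.Tensor, target: torch.Tensor) -> float:
    """AP = sum_n (R_n - R_{n-1}) * P_n for one class.

    ``scores``: (N, ) predicted scores; ``target``: (N, ) 0/1 ground truth.
    """
    order = scores.argsort(descending=True)
    hits = target[order].float()
    tp = hits.cumsum(0)                              # true positives at each cut-off
    precision = tp / torch.arange(1, len(hits) + 1)  # P_n
    recall = tp / hits.sum().clamp(min=1)            # R_n
    # The curve is piecewise constant, so weight each precision value by the
    # recall increment at that cut-off (zero wherever no new hit occurs).
    delta_recall = torch.cat([recall[:1], recall[1:] - recall[:-1]])
    return float((delta_recall * precision).sum())
```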
```diff
@@ -529,15 +561,6 @@ class AveragePrecision(BaseMetric):
                 average: Optional[str] = 'macro') -> torch.Tensor:
         r"""Calculate the average precision for a single class.
 
-        AP summarizes a precision-recall curve as the weighted mean of maximum
-        precisions obtained for any r'>r, where r is the recall:
-
-        .. math::
-            \text{AP} = \sum_n (R_n - R_{n-1}) P_n
-
-        Note that no approximation is involved since the curve is piecewise
-        constant.
-
         Args:
             pred (torch.Tensor | np.ndarray): The model predictions with
                 shape ``(N, num_classes)``.
@@ -545,9 +568,11 @@ class AveragePrecision(BaseMetric):
                 with shape ``(N, num_classes)``.
             average (str | None): The average method. It supports two modes:
 
-                - `"macro"`: Calculate metrics for each category, and
-                  calculate the mean value over all categories.
-                - `None`: Return scores of all categories.
+                - `"macro"`: Calculate metrics for each category, and calculate
+                  the mean value over all categories. The result of this mode
+                  is also called mAP.
+                - `None`: Calculate metrics of every category and output
+                  directly.
 
                 Defaults to "macro".
 
```
```diff
@@ -54,15 +54,25 @@ def _precision_recall_f1_support(pred_positive, gt_positive, average):
 
 
 @METRICS.register_module()
 class Accuracy(BaseMetric):
-    """Top-k accuracy evaluation metric.
+    r"""Accuracy evaluation metric.
+
+    For either binary classification or multi-class classification, the
+    accuracy is the fraction of correct predictions in all predictions:
+
+    .. math::
+
+        \text{Accuracy} = \frac{N_{\text{correct}}}{N_{\text{all}}}
 
     Args:
-        topk (int | Sequence[int]): If the predictions in ``topk``
-            matches the target, the predictions will be regarded as
-            correct ones. Defaults to 1.
-        thrs (Sequence[float | None] | float | None): Predictions with scores
-            under the thresholds are considered negative. None means no
-            thresholds. Defaults to 0.
+        topk (int | Sequence[int]): If the ground truth label matches one of
+            the best **k** predictions, the sample will be regarded as a
+            positive prediction. If the parameter is a tuple, all of the
+            top-k accuracies will be calculated and output together.
+            Defaults to 1.
+        thrs (Sequence[float | None] | float | None): If a float, predictions
+            with a score lower than the threshold will be regarded as
+            negative. If None, no threshold is applied. If the parameter is a
+            tuple, accuracy based on all thresholds will be calculated and
+            output together. Defaults to 0.
         collect_device (str): Device name used for collecting results from
             different ranks during distributed training. Must be 'cpu' or
             'gpu'. Defaults to 'cpu'.
```
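
As a quick illustration of how ``topk`` and a score threshold interact (a standalone sketch, not the library's `Accuracy` implementation):

```python
from typing import Optional

import torch


def topk_accuracy(scores: torch.Tensor, target: torch.Tensor,
                  topk: int = 1, thr: Optional[float] = None) -> float:
    """Fraction of samples whose ground truth is among the top-k predictions.

    ``scores``: (N, num_classes) class scores; ``target``: (N, ) label indices.
    """
    topk_scores, topk_pred = scores.topk(topk, dim=1)
    hit = topk_pred.eq(target.view(-1, 1))  # (N, k) hit matrix
    if thr is not None:
        # predictions scoring at or below the threshold count as negative
        hit = hit & (topk_scores > thr)
    return hit.any(dim=1).float().mean().item()
```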
```diff
@@ -262,41 +272,59 @@ class Accuracy(BaseMetric):
 
 
 @METRICS.register_module()
 class SingleLabelMetric(BaseMetric):
-    """A collection of metrics for single-label multi-class classification task
-    based on confusion matrix.
+    r"""A collection of precision, recall, f1-score and support for
+    single-label tasks.
 
-    It includes precision, recall, f1-score and support. Comparing with
-    :class:`Accuracy`, these metrics doesn't support topk, but supports
-    various average mode.
+    The collection of metrics is for single-label multi-class classification.
+    And all these metrics are based on the confusion matrix of every category:
+
+    .. image:: ../../_static/image/confusion-matrix.png
+       :width: 60%
+       :align: center
+
+    All metrics can be formulated using the variables above:
+
+    **Precision** is the fraction of correct predictions in all predictions:
+
+    .. math::
+        \text{Precision} = \frac{TP}{TP+FP}
+
+    **Recall** is the fraction of correct predictions in all targets:
+
+    .. math::
+        \text{Recall} = \frac{TP}{TP+FN}
+
+    **F1-score** is the harmonic mean of the precision and recall:
+
+    .. math::
+        \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}}
+
+    **Support** is the number of samples:
+
+    .. math::
+        \text{Support} = TP + TN + FN + FP
 
     Args:
-        thrs (Sequence[float | None] | float | None): Predictions with scores
-            under the thresholds are considered negative. None means no
-            thresholds. Defaults to 0.
-        items (Sequence[str]): The detailed metric items to evaluate. Here is
-            the available options:
-
-            - `"precision"`: The ratio tp / (tp + fp) where tp is the
-              number of true positives and fp the number of false
-              positives.
-            - `"recall"`: The ratio tp / (tp + fn) where tp is the number
-              of true positives and fn the number of false negatives.
-            - `"f1-score"`: The f1-score is the harmonic mean of the
-              precision and recall.
-            - `"support"`: The total number of occurrences of each category
-              in the target.
-
-            Defaults to ('precision', 'recall', 'f1-score').
-        average (str, optional): The average method. If None, the scores
-            for each class are returned. And it supports two average modes:
-
-            - `"macro"`: Calculate metrics for each category, and calculate
-              the mean value over all categories.
-            - `"micro"`: Calculate metrics globally by counting the total
-              true positives, false negatives and false positives.
+        thrs (Sequence[float | None] | float | None): If a float, predictions
+            with a score lower than the threshold will be regarded as
+            negative. If None, only the top-1 prediction will be regarded as
+            positive. If the parameter is a tuple, metrics based on all
+            thresholds will be calculated and output together. Defaults to 0.
+        items (Sequence[str]): The detailed metric items to evaluate, select
+            from "precision", "recall", "f1-score" and "support".
+            Defaults to ``('precision', 'recall', 'f1-score')``.
+        average (str | None): How to calculate the final metrics from the
+            confusion matrix of every category. It supports three modes:
+
+            - `"macro"`: Calculate metrics for each category, and calculate
+              the mean value over all categories.
+            - `"micro"`: Average the confusion matrix over all categories and
+              calculate metrics on the mean confusion matrix.
+            - `None`: Calculate metrics of every category and output directly.
 
             Defaults to "macro".
-        num_classes (Optional, int): The number of classes. Defaults to None.
+        num_classes (int, optional): The number of classes. Defaults to None.
         collect_device (str): Device name used for collecting results from
             different ranks during distributed training. Must be 'cpu' or
             'gpu'. Defaults to 'cpu'.
```
```diff
@@ -343,7 +371,7 @@ class SingleLabelMetric(BaseMetric):
             'single-label/recall_classwise': [18.5, 18.5, 17.0, 20.0, 18.0],
             'single-label/f1-score_classwise': [19.7, 18.6, 17.1, 19.7, 17.0]
         }
-    """
+    """  # noqa: E501
     default_prefix: Optional[str] = 'single-label'
 
     def __init__(self,
```
```diff
@@ -483,14 +511,16 @@ class SingleLabelMetric(BaseMetric):
                 the thresholds are considered negative. It's only used
                 when ``pred`` is scores. None means no thresholds.
                 Defaults to (0., ).
-            average (str, optional): The average method. If None, the scores
-                for each class are returned. And it supports two average modes:
+            average (str | None): How to calculate the final metrics from
+                the confusion matrix of every category. It supports three
+                modes:
 
-                - `"macro"`: Calculate metrics for each category, and
-                  calculate the mean value over all categories.
-                - `"micro"`: Calculate metrics globally by counting the
-                  total true positives, false negatives and false
-                  positives.
+                - `"macro"`: Calculate metrics for each category, and calculate
+                  the mean value over all categories.
+                - `"micro"`: Average the confusion matrix over all categories
+                  and calculate metrics on the mean confusion matrix.
+                - `None`: Calculate metrics of every category and output
+                  directly.
 
                 Defaults to "macro".
             num_classes (Optional, int): The number of classes. If the ``pred``
```