[Docs] Add custom evaluation docs (#1130)
* [Docs] Add evaluation docs
* minor fix
* Fix docs.

Co-authored-by: mzr1996 <mzr1996@163.com>

pull/1162/head
parent 280e916979
commit 50aaa711ea

@@ -25,3 +25,9 @@ article.pytorch-article section table code {
table.autosummary td {
  width: 50%
}

img.align-center {
  display: block;
  margin-left: auto;
  margin-right: auto;
}

Binary file not shown. (New image added: 51 KiB)

@@ -1 +1,103 @@
# Custom evaluation metrics (TODO)
# Customize Evaluation Metrics

## Use metrics in MMClassification

In MMClassification, we have provided multiple metrics for both single-label classification and multi-label
classification:

**Single-label Classification**:

- [`Accuracy`](mmcls.evaluation.Accuracy)
- [`SingleLabelMetric`](mmcls.evaluation.SingleLabelMetric), including precision, recall, f1-score and
  support.

**Multi-label Classification**:

- [`AveragePrecision`](mmcls.evaluation.AveragePrecision), or AP (mAP).
- [`MultiLabelMetric`](mmcls.evaluation.MultiLabelMetric), including precision, recall, f1-score and
  support.

To use these metrics during validation and testing, we need to modify the `val_evaluator` and `test_evaluator`
fields in the config file.

Here are several examples:

1. Calculate top-1 and top-5 accuracy during both validation and testing.

   ```python
   val_evaluator = dict(type='Accuracy', topk=(1, 5))
   test_evaluator = val_evaluator
   ```

2. Calculate top-1 accuracy, top-5 accuracy, precision and recall during both validation and testing.

   ```python
   val_evaluator = [
       dict(type='Accuracy', topk=(1, 5)),
       dict(type='SingleLabelMetric', items=['precision', 'recall']),
   ]
   test_evaluator = val_evaluator
   ```

3. Calculate mAP (mean AveragePrecision), CP (Class-wise mean Precision), CR (Class-wise mean Recall), CF
   (Class-wise mean F1-score), OP (Overall mean Precision), OR (Overall mean Recall) and OF1 (Overall mean
   F1-score).

   ```python
   val_evaluator = [
       dict(type='AveragePrecision'),
       dict(type='MultiLabelMetric', average='macro'),  # class-wise mean
       dict(type='MultiLabelMetric', average='micro'),  # overall mean
   ]
   test_evaluator = val_evaluator
   ```
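
The metrics also accept the constructor arguments documented in their docstrings. For instance, here is a sketch of a multi-label evaluator that sets the score threshold explicitly; `thr` is the argument described in the `MultiLabelMetric` docstring (with `topk` as an alternative way of selecting positive predictions):

```python
# Sketch: count scores above 0.5 as positive predictions for the
# precision/recall/f1-score metrics, instead of relying on the defaults.
val_evaluator = [
    dict(type='AveragePrecision'),
    dict(type='MultiLabelMetric', thr=0.5, average='macro'),
]
test_evaluator = val_evaluator
```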

## Add new metrics

MMClassification supports implementing customized evaluation metrics for users who pursue deeper customization.

You need to create a new file under `mmcls/evaluation/metrics`, and implement the new metric in that file, for example, in `mmcls/evaluation/metrics/my_metric.py`. There, define a customized evaluation metric class `MyMetric` which inherits from [`BaseMetric` in MMEngine](mmengine.evaluator.metrics.BaseMetric).

Override the data format processing method `process` and the metric calculation method `compute_metrics`, and add the class to the `METRICS` registry so it can be referenced from config files.

```python
from typing import Dict, List, Sequence

from mmengine.evaluator import BaseMetric

from mmcls.registry import METRICS


@METRICS.register_module()
class MyMetric(BaseMetric):

    def process(self, data_batch: Sequence[Dict], data_samples: Sequence[Dict]):
        """The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        `data_batch` stores the batch data from the dataloader,
        and `data_samples` stores the batch outputs from the model.
        """
        ...

    def compute_metrics(self, results: List):
        """Compute the metrics from the processed results and return the
        evaluation results.
        """
        ...
```
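
As a more concrete illustration, here is a minimal sketch of a custom metric that reports the top-1 error rate. The data-sample keys used in `process` (`pred_score` and `gt_label`) are assumptions for illustration only; adapt them to the data sample format actually produced by your model.

```python
from typing import Dict, List, Sequence

import torch
from mmengine.evaluator import BaseMetric

from mmcls.registry import METRICS


@METRICS.register_module()
class MyTop1Error(BaseMetric):
    """Top-1 error rate, i.e. 100 minus the top-1 accuracy."""

    default_prefix = 'my-metric'

    def process(self, data_batch: Sequence[Dict], data_samples: Sequence[Dict]):
        for data_sample in data_samples:
            # Keep only what `compute_metrics` needs. The keys below are
            # assumed for illustration; check your data sample structure.
            self.results.append({
                'pred': int(torch.argmax(data_sample['pred_score'])),
                'gt': int(data_sample['gt_label']),
            })

    def compute_metrics(self, results: List) -> Dict:
        correct = sum(res['pred'] == res['gt'] for res in results)
        return {'error-rate': 100 * (1 - correct / len(results))}
```

Once the class is registered and imported (next step), `dict(type='MyTop1Error')` can be used in `val_evaluator` just like the built-in metrics.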

Then, import it in `mmcls/evaluation/metrics/__init__.py` to add it to the `mmcls.evaluation` package.

```python
# In mmcls/evaluation/metrics/__init__.py
...
from .my_metric import MyMetric

__all__ = [..., 'MyMetric']
```

Finally, use `MyMetric` in the `val_evaluator` and `test_evaluator` fields of config files.

```python
val_evaluator = dict(type='MyMetric', ...)
test_evaluator = val_evaluator
```

```{note}
More details can be found in {external+mmengine:doc}`MMEngine Documentation: Evaluation <design/evaluation>`.
```

@@ -25,3 +25,9 @@ article.pytorch-article section table code {
table.autosummary td {
  width: 50%
}

img.align-center {
  display: block;
  margin-left: auto;
  margin-right: auto;
}

@@ -0,0 +1 @@
../../../en/_static/image/confusion-matrix.png

@@ -1 +1,4 @@
# Custom Evaluation Metrics (to be updated)

Please refer to the [English documentation](https://mmclassification.readthedocs.io/en/dev-1.x/advanced_guides/evaluation.html). If you are interested in contributing to the Chinese translation, you are welcome to sign up in the [discussion board](https://github.com/open-mmlab/mmclassification/discussions/1027).

@@ -13,38 +13,58 @@ from .single_label import _precision_recall_f1_support, to_tensor

@METRICS.register_module()
class MultiLabelMetric(BaseMetric):
    """A collection of metrics for multi-label multi-class classification task
    based on confusion matrix.
    r"""A collection of precision, recall, f1-score and support for
    multi-label tasks.

    It includes precision, recall, f1-score and support.
    The collection of metrics is for multi-label classification.
    And all these metrics are based on the confusion matrix of every category:

    .. image:: ../../_static/image/confusion-matrix.png
       :width: 60%
       :align: center

    All metrics can be formulated using the variables above:

    **Precision** is the fraction of correct predictions in all predictions:

    .. math::
        \text{Precision} = \frac{TP}{TP+FP}

    **Recall** is the fraction of correct predictions in all targets:

    .. math::
        \text{Recall} = \frac{TP}{TP+FN}

    **F1-score** is the harmonic mean of the precision and recall:

    .. math::
        \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}}

    **Support** is the number of samples:

    .. math::
        \text{Support} = TP + TN + FN + FP

    Args:
        thr (float, optional): Predictions with scores under the thresholds
            are considered as negative. Defaults to None.
        thr (float, optional): Predictions with scores under the threshold
            are considered as negative. If None, the ``topk`` predictions will
            be considered as positive. If the ``topk`` is also None, use
            ``thr=0.5`` as default. Defaults to None.
        topk (int, optional): Predictions with the k-th highest scores are
            considered as positive. Defaults to None.
        items (Sequence[str]): The detailed metric items to evaluate. Here is
            the available options:
            considered as positive. If None, use ``thr`` to determine positive
            predictions. If both ``thr`` and ``topk`` are not None, use
            ``thr``. Defaults to None.
        items (Sequence[str]): The detailed metric items to evaluate, select
            from "precision", "recall", "f1-score" and "support".
            Defaults to ``('precision', 'recall', 'f1-score')``.
        average (str | None): How to calculate the final metrics from the
            confusion matrix of every category. It supports three modes:

            - `"precision"`: The ratio tp / (tp + fp) where tp is the
              number of true positives and fp the number of false
              positives.
            - `"recall"`: The ratio tp / (tp + fn) where tp is the number
              of true positives and fn the number of false negatives.
            - `"f1-score"`: The f1-score is the harmonic mean of the
              precision and recall.
            - `"support"`: The total number of positive of each category
              in the target.

            Defaults to ('precision', 'recall', 'f1-score').
        average (str | None): The average method. It supports three average
            modes:

            - `"macro"`: Calculate metrics for each category, and calculate
              the mean value over all categories.
            - `"micro"`: Calculate metrics globally by counting the total
              true positives, false negatives and false positives.
            - `None`: Return scores of all categories.
            - `"macro"`: Calculate metrics for each category, and calculate
              the mean value over all categories.
            - `"micro"`: Average the confusion matrix over all categories and
              calculate metrics on the mean confusion matrix.
            - `None`: Calculate metrics of every category and output directly.

            Defaults to "macro".
        collect_device (str): Device name used for collecting results from
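
To make the difference between the two averaging modes concrete, here is a minimal self-contained sketch in plain PyTorch (an illustration only, not the MMClassification implementation) that computes precision under both modes from per-category true-positive and false-positive counts:

```python
import torch

tp = torch.tensor([8., 2., 5.])  # true positives per category
fp = torch.tensor([2., 8., 0.])  # false positives per category

# "macro": compute the metric per category, then average the scores.
macro_precision = (tp / (tp + fp)).mean()     # (0.8 + 0.2 + 1.0) / 3 ≈ 0.667

# "micro": pool the per-category confusion matrices first, then compute
# the metric once on the pooled counts.
micro_precision = tp.sum() / (tp + fp).sum()  # 15 / 25 = 0.6

print(macro_precision.item(), micro_precision.item())
```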

@@ -261,15 +281,16 @@ class MultiLabelMetric(BaseMetric):
            target_indices (bool): Whether the ``target`` is a sequence of
                category index labels. If True, ``num_classes`` must be set.
                Defaults to False.
            average (str | None): The average method. It supports three average
            average (str | None): How to calculate the final metrics from
                the confusion matrix of every category. It supports three
                modes:

                - `"macro"`: Calculate metrics for each category, and
                  calculate the mean value over all categories.
                - `"micro"`: Calculate metrics globally by counting the
                  total true positives, false negatives and false
                  positives.
                - `None`: Return scores of all categories.
                - `"macro"`: Calculate metrics for each category, and calculate
                  the mean value over all categories.
                - `"micro"`: Average the confusion matrix over all categories
                  and calculate metrics on the mean confusion matrix.
                - `None`: Calculate metrics of every category and output
                  directly.

                Defaults to "macro".
            thr (float, optional): Predictions with scores under the thresholds

@@ -402,14 +423,25 @@ def _average_precision(pred: torch.Tensor,

@METRICS.register_module()
class AveragePrecision(BaseMetric):
    """Calculate the average precision with respect to classes.
    r"""Calculate the average precision with respect to classes.

    AveragePrecision (AP) summarizes a precision-recall curve as the weighted
    mean of maximum precisions obtained for any r'>r, where r is the recall:

    .. math::
        \text{AP} = \sum_n (R_n - R_{n-1}) P_n

    Note that no approximation is involved since the curve is piecewise
    constant.

    Args:
        average (str | None): The average method. It supports two modes:
        average (str | None): How to calculate the final metrics from
            every category. It supports two modes:

            - `"macro"`: Calculate metrics for each category, and calculate
              the mean value over all categories.
            - `None`: Return scores of all categories.
            - `"macro"`: Calculate metrics for each category, and calculate
              the mean value over all categories. The result of this mode
              is also called **mAP**.
            - `None`: Calculate metrics of every category and output directly.

            Defaults to "macro".
        collect_device (str): Device name used for collecting results from
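
As a rough, self-contained illustration of the formula above (not the MMClassification implementation), the average precision of a single class can be computed from prediction scores and binary targets as follows:

```python
import torch

scores = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5])  # predicted scores for one class
target = torch.tensor([1., 0., 1., 1., 0.])       # 1 = positive sample

order = scores.argsort(descending=True)
hits = target[order]

tp = hits.cumsum(0)                              # true positives at each cut-off
precision = tp / torch.arange(1, len(hits) + 1)  # P_n
recall = tp / hits.sum()                         # R_n

# AP = sum_n (R_n - R_{n-1}) * P_n, with R_0 = 0.
delta_r = recall - torch.cat([torch.zeros(1), recall[:-1]])
ap = (delta_r * precision).sum()
print(ap.item())  # ~0.806 for this toy example
```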

@@ -529,15 +561,6 @@ class AveragePrecision(BaseMetric):
                  average: Optional[str] = 'macro') -> torch.Tensor:
        r"""Calculate the average precision for a single class.

        AP summarizes a precision-recall curve as the weighted mean of maximum
        precisions obtained for any r'>r, where r is the recall:

        .. math::
            \text{AP} = \sum_n (R_n - R_{n-1}) P_n

        Note that no approximation is involved since the curve is piecewise
        constant.

        Args:
            pred (torch.Tensor | np.ndarray): The model predictions with
                shape ``(N, num_classes)``.

@@ -545,9 +568,11 @@ class AveragePrecision(BaseMetric):
                with shape ``(N, num_classes)``.
            average (str | None): The average method. It supports two modes:

                - `"macro"`: Calculate metrics for each category, and
                  calculate the mean value over all categories.
                - `None`: Return scores of all categories.
                - `"macro"`: Calculate metrics for each category, and calculate
                  the mean value over all categories. The result of this mode
                  is also called mAP.
                - `None`: Calculate metrics of every category and output
                  directly.

                Defaults to "macro".

@@ -54,15 +54,25 @@ def _precision_recall_f1_support(pred_positive, gt_positive, average):

@METRICS.register_module()
class Accuracy(BaseMetric):
    """Top-k accuracy evaluation metric.
    r"""Accuracy evaluation metric.

    For either binary classification or multi-class classification, the
    accuracy is the fraction of correct predictions in all predictions:

    .. math::

        \text{Accuracy} = \frac{N_{\text{correct}}}{N_{\text{all}}}

    Args:
        topk (int | Sequence[int]): If the predictions in ``topk``
            matches the target, the predictions will be regarded as
            correct ones. Defaults to 1.
        thrs (Sequence[float | None] | float | None): Predictions with scores
            under the thresholds are considered negative. None means no
            thresholds. Defaults to 0.
        topk (int | Sequence[int]): If the ground truth label matches one of
            the best **k** predictions, the sample will be regarded as a
            positive prediction. If the parameter is a tuple, all of the
            top-k accuracies will be calculated and output together.
            Defaults to 1.
        thrs (Sequence[float | None] | float | None): If a float, predictions
            with a score lower than the threshold will be regarded as negative
            predictions. If None, no threshold is applied. If the parameter is
            a tuple, accuracy based on all thresholds will be calculated and
            output together. Defaults to 0.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
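
For intuition, here is a minimal sketch of the top-k accuracy described above in plain PyTorch (an illustration only, not the MMClassification implementation; the real `Accuracy` metric additionally supports the score thresholds controlled by `thrs`):

```python
import torch

pred = torch.tensor([[0.1, 0.7, 0.2],
                     [0.6, 0.3, 0.1],
                     [0.2, 0.3, 0.5]])
target = torch.tensor([1, 1, 2])

def topk_accuracy(pred: torch.Tensor, target: torch.Tensor, k: int = 1) -> float:
    # A sample counts as correct if the true label is among the k highest scores.
    topk_indices = pred.topk(k, dim=1).indices
    correct = (topk_indices == target.unsqueeze(1)).any(dim=1)
    return 100 * correct.float().mean().item()

print(topk_accuracy(pred, target, k=1))  # 66.7: samples 0 and 2 are correct
print(topk_accuracy(pred, target, k=2))  # 100.0: sample 1 is also correct at k=2
```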

@@ -262,41 +272,59 @@ class Accuracy(BaseMetric):

@METRICS.register_module()
class SingleLabelMetric(BaseMetric):
    """A collection of metrics for single-label multi-class classification task
    based on confusion matrix.
    r"""A collection of precision, recall, f1-score and support for
    single-label tasks.

    It includes precision, recall, f1-score and support. Comparing with
    :class:`Accuracy`, these metrics doesn't support topk, but supports
    various average mode.
    The collection of metrics is for single-label multi-class classification.
    And all these metrics are based on the confusion matrix of every category:

    .. image:: ../../_static/image/confusion-matrix.png
       :width: 60%
       :align: center

    All metrics can be formulated using the variables above:

    **Precision** is the fraction of correct predictions in all predictions:

    .. math::
        \text{Precision} = \frac{TP}{TP+FP}

    **Recall** is the fraction of correct predictions in all targets:

    .. math::
        \text{Recall} = \frac{TP}{TP+FN}

    **F1-score** is the harmonic mean of the precision and recall:

    .. math::
        \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}}

    **Support** is the number of samples:

    .. math::
        \text{Support} = TP + TN + FN + FP

    Args:
        thrs (Sequence[float | None] | float | None): Predictions with scores
            under the thresholds are considered negative. None means no
            thresholds. Defaults to 0.
        items (Sequence[str]): The detailed metric items to evaluate. Here is
            the available options:
        thrs (Sequence[float | None] | float | None): If a float, predictions
            with a score lower than the threshold will be regarded as negative
            predictions. If None, only the top-1 prediction will be regarded
            as the positive prediction. If the parameter is a tuple, accuracy
            based on all thresholds will be calculated and output together.
            Defaults to 0.
        items (Sequence[str]): The detailed metric items to evaluate, select
            from "precision", "recall", "f1-score" and "support".
            Defaults to ``('precision', 'recall', 'f1-score')``.
        average (str | None): How to calculate the final metrics from the
            confusion matrix of every category. It supports three modes:

            - `"precision"`: The ratio tp / (tp + fp) where tp is the
              number of true positives and fp the number of false
              positives.
            - `"recall"`: The ratio tp / (tp + fn) where tp is the number
              of true positives and fn the number of false negatives.
            - `"f1-score"`: The f1-score is the harmonic mean of the
              precision and recall.
            - `"support"`: The total number of occurrences of each category
              in the target.

            Defaults to ('precision', 'recall', 'f1-score').
        average (str, optional): The average method. If None, the scores
            for each class are returned. And it supports two average modes:

            - `"macro"`: Calculate metrics for each category, and calculate
              the mean value over all categories.
            - `"micro"`: Calculate metrics globally by counting the total
              true positives, false negatives and false positives.
            - `"macro"`: Calculate metrics for each category, and calculate
              the mean value over all categories.
            - `"micro"`: Average the confusion matrix over all categories and
              calculate metrics on the mean confusion matrix.
            - `None`: Calculate metrics of every category and output directly.

            Defaults to "macro".
        num_classes (Optional, int): The number of classes. Defaults to None.
        num_classes (int, optional): The number of classes. Defaults to None.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.

@@ -343,7 +371,7 @@ class SingleLabelMetric(BaseMetric):
            'single-label/recall_classwise': [18.5, 18.5, 17.0, 20.0, 18.0],
            'single-label/f1-score_classwise': [19.7, 18.6, 17.1, 19.7, 17.0]
        }
    """
    """ # noqa: E501
    default_prefix: Optional[str] = 'single-label'

    def __init__(self,

@@ -483,14 +511,16 @@ class SingleLabelMetric(BaseMetric):
                the thresholds are considered negative. It's only used
                when ``pred`` is scores. None means no thresholds.
                Defaults to (0., ).
            average (str, optional): The average method. If None, the scores
                for each class are returned. And it supports two average modes:
            average (str | None): How to calculate the final metrics from
                the confusion matrix of every category. It supports three
                modes:

                - `"macro"`: Calculate metrics for each category, and
                  calculate the mean value over all categories.
                - `"micro"`: Calculate metrics globally by counting the
                  total true positives, false negatives and false
                  positives.
                - `"macro"`: Calculate metrics for each category, and calculate
                  the mean value over all categories.
                - `"micro"`: Average the confusion matrix over all categories
                  and calculate metrics on the mean confusion matrix.
                - `None`: Calculate metrics of every category and output
                  directly.

                Defaults to "macro".
            num_classes (Optional, int): The number of classes. If the ``pred``