[Docs] Add custom evaluation docs (#1130)

* [Docs] Add evaluation docs

* minor fix

* Fix docs.

Co-authored-by: mzr1996 <mzr1996@163.com>
Authored by Hubert on 2022-11-01 18:54:06 +08:00, committed by GitHub.
parent 280e916979
commit 50aaa711ea
8 changed files with 269 additions and 96 deletions


@@ -25,3 +25,9 @@ article.pytorch-article section table code {
 table.autosummary td {
     width: 50%
 }
+
+img.align-center {
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+}

(New binary image file, 51 KiB: the confusion-matrix figure referenced by the new documentation and docstrings.)


@@ -1 +1,103 @@
-# Custom evaluation metrics (TODO)
+# Customize Evaluation Metrics

## Use metrics in MMClassification

In MMClassification, we provide multiple metrics for both single-label classification and multi-label
classification:

**Single-label Classification**:

- [`Accuracy`](mmcls.evaluation.Accuracy)
- [`SingleLabelMetric`](mmcls.evaluation.SingleLabelMetric), including precision, recall, f1-score and
  support.

**Multi-label Classification**:

- [`AveragePrecision`](mmcls.evaluation.AveragePrecision), or AP (mAP).
- [`MultiLabelMetric`](mmcls.evaluation.MultiLabelMetric), including precision, recall, f1-score and
  support.

To use these metrics during validation and testing, modify the `val_evaluator` and `test_evaluator`
fields in the config file.

Here are several examples:
1. Calculate top-1 and top-5 accuracy during both validation and test.
```python
val_evaluator = dict(type='Accuracy', topk=(1, 5))
test_evaluator = val_evaluator
```
2. Calculate top-1 accuracy, top-5 accuracy, precision and recall during both validation and test.
```python
val_evaluator = [
dict(type='Accuracy', topk=(1, 5)),
dict(type='SingleLabelMetric', items=['precision', 'recall']),
]
test_evaluator = val_evaluator
```
3. Calculate mAP (mean AveragePrecision), CP (Class-wise mean Precision), CR (Class-wise mean Recall), CF
(Class-wise mean F1-score), OP (Overall mean Precision), OR (Overall mean Recall) and OF1 (Overall mean
F1-score).
```python
val_evaluator = [
dict(type='AveragePrecision'),
dict(type='MultiLabelMetric', average='macro'), # class-wise mean
dict(type='MultiLabelMetric', average='micro'), # overall mean
]
test_evaluator = val_evaluator
```
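
Metric-specific arguments can be passed in the same dictionary as `type`. As a sketch (the threshold
value here is only illustrative), `MultiLabelMetric` accepts a `thr` argument to change the score
threshold for positive predictions:

```python
val_evaluator = [
    dict(type='AveragePrecision'),
    # Count a class as a positive prediction when its score exceeds 0.6
    # instead of the default 0.5 (`topk` could be used instead of `thr`).
    dict(type='MultiLabelMetric', thr=0.6, average='macro'),
]
test_evaluator = val_evaluator
```
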
## Add new metrics

MMClassification supports customized evaluation metrics for users who need more than the built-in ones.

To add a new metric, create a file under `mmcls/evaluation/metrics`, for example
`mmcls/evaluation/metrics/my_metric.py`, and implement a metric class `MyMetric` which inherits
[`BaseMetric`](mmengine.evaluator.metrics.BaseMetric) from MMEngine. Override the data processing method
`process` and the metric calculation method `compute_metrics`, and add the class to the `METRICS`
registry so it can be referenced from configs.
```python
from typing import Dict, List, Sequence

from mmengine.evaluator import BaseMetric

from mmcls.registry import METRICS


@METRICS.register_module()
class MyMetric(BaseMetric):

    def process(self, data_batch: Sequence[Dict], data_samples: Sequence[Dict]):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
        ``data_batch`` stores the batch data from the dataloader, and
        ``data_samples`` stores the batch outputs from the model.
        """
        ...

    def compute_metrics(self, results: List) -> Dict:
        """Compute the metrics from the processed results and return the
        evaluation results."""
        ...
```
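
For illustration, a minimal sketch of a hypothetical `ErrorRate` metric is shown below. It assumes each
data sample provides `pred_label.score` and `gt_label.label`, which is a common layout for
MMClassification 1.x classification heads; verify these keys against your model's actual outputs.

```python
from typing import Dict, List, Sequence

from mmengine.evaluator import BaseMetric

from mmcls.registry import METRICS


@METRICS.register_module()
class ErrorRate(BaseMetric):
    """A hypothetical metric: the percentage of wrongly classified samples."""

    default_prefix = 'my-metric'

    def process(self, data_batch: Sequence[Dict], data_samples: Sequence[Dict]):
        for sample in data_samples:
            # Assumed keys: `pred_label.score` holds the class scores and
            # `gt_label.label` holds the ground-truth label index.
            pred = sample['pred_label']['score'].argmax().item()
            gt = int(sample['gt_label']['label'])
            self.results.append({'pred': pred, 'gt': gt})

    def compute_metrics(self, results: List) -> Dict:
        wrong = sum(res['pred'] != res['gt'] for res in results)
        # Reported as `my-metric/error-rate` because of `default_prefix`.
        return {'error-rate': 100.0 * wrong / max(len(results), 1)}
```

The import and registration steps below apply to `ErrorRate` in the same way as to `MyMetric`.
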
Then, import it in `mmcls/evaluation/metrics/__init__.py` to add it to the `mmcls.evaluation` package.
```python
# In mmcls/evaluation/metrics/__init__.py
...
from .my_metric import MyMetric
__all__ = [..., 'MyMetric']
```
Finally, use `MyMetric` in the `val_evaluator` and `test_evaluator` fields of config files.
```python
val_evaluator = dict(type='MyMetric', ...)
test_evaluator = val_evaluator
```
```{note}
More details can be found in {external+mmengine:doc}`MMEngine Documentation: Evaluation <design/evaluation>`.
```


@@ -25,3 +25,9 @@ article.pytorch-article section table code {
 table.autosummary td {
     width: 50%
 }
+
+img.align-center {
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+}


@@ -0,0 +1 @@
+../../../en/_static/image/confusion-matrix.png


@@ -1 +1,4 @@
 # Customize Evaluation Metrics (to be updated)
+
+Please refer to the [English documentation](https://mmclassification.readthedocs.io/en/dev-1.x/advanced_guides/evaluation.html). If you are interested in
+helping to translate the Chinese documentation, please sign up in the [discussion](https://github.com/open-mmlab/mmclassification/discussions/1027).


@@ -13,38 +13,58 @@ from .single_label import _precision_recall_f1_support, to_tensor
 @METRICS.register_module()
 class MultiLabelMetric(BaseMetric):
-    """A collection of metrics for multi-label multi-class classification task
-    based on confusion matrix.
-
-    It includes precision, recall, f1-score and support.
+    r"""A collection of precision, recall, f1-score and support for
+    multi-label tasks.
+
+    The collection of metrics is for multi-label classification, and all these
+    metrics are based on the confusion matrix of every category:
+
+    .. image:: ../../_static/image/confusion-matrix.png
+       :width: 60%
+       :align: center
+
+    All metrics can be formulated using the variables above:
+
+    **Precision** is the fraction of correct predictions in all predictions:
+
+    .. math::
+        \text{Precision} = \frac{TP}{TP+FP}
+
+    **Recall** is the fraction of correct predictions in all targets:
+
+    .. math::
+        \text{Recall} = \frac{TP}{TP+FN}
+
+    **F1-score** is the harmonic mean of the precision and recall:
+
+    .. math::
+        \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}}
+
+    **Support** is the number of samples:
+
+    .. math::
+        \text{Support} = TP + TN + FN + FP
 
     Args:
-        thr (float, optional): Predictions with scores under the thresholds
-            are considered as negative. Defaults to None.
-        topk (int, optional): Predictions with the k-th highest scores are
-            considered as positive. Defaults to None.
-        items (Sequence[str]): The detailed metric items to evaluate. Here is
-            the available options:
-
-            - `"precision"`: The ratio tp / (tp + fp) where tp is the
-              number of true positives and fp the number of false
-              positives.
-            - `"recall"`: The ratio tp / (tp + fn) where tp is the number
-              of true positives and fn the number of false negatives.
-            - `"f1-score"`: The f1-score is the harmonic mean of the
-              precision and recall.
-            - `"support"`: The total number of positive of each category
-              in the target.
-
-            Defaults to ('precision', 'recall', 'f1-score').
-        average (str | None): The average method. It supports three average
-            modes:
-
-            - `"macro"`: Calculate metrics for each category, and calculate
-              the mean value over all categories.
-            - `"micro"`: Calculate metrics globally by counting the total
-              true positives, false negatives and false positives.
-            - `None`: Return scores of all categories.
+        thr (float, optional): Predictions with scores under the threshold
+            are considered as negative. If None, the ``topk`` predictions will
+            be considered as positive. If the ``topk`` is also None, use
+            ``thr=0.5`` as default. Defaults to None.
+        topk (int, optional): Predictions with the k-th highest scores are
+            considered as positive. If None, use ``thr`` to determine positive
+            predictions. If both ``thr`` and ``topk`` are not None, use
+            ``thr``. Defaults to None.
+        items (Sequence[str]): The detailed metric items to evaluate, select
+            from "precision", "recall", "f1-score" and "support".
+            Defaults to ``('precision', 'recall', 'f1-score')``.
+        average (str | None): How to calculate the final metrics from the
+            confusion matrix of every category. It supports three modes:
+
+            - `"macro"`: Calculate metrics for each category, and calculate
+              the mean value over all categories.
+            - `"micro"`: Average the confusion matrix over all categories and
+              calculate metrics on the mean confusion matrix.
+            - `None`: Calculate metrics of every category and output directly.
 
             Defaults to "macro".
         collect_device (str): Device name used for collecting results from
@@ -261,15 +281,16 @@ class MultiLabelMetric(BaseMetric):
             target_indices (bool): Whether the ``target`` is a sequence of
                 category index labels. If True, ``num_classes`` must be set.
                 Defaults to False.
-            average (str | None): The average method. It supports three average
-                modes:
-
-                - `"macro"`: Calculate metrics for each category, and
-                  calculate the mean value over all categories.
-                - `"micro"`: Calculate metrics globally by counting the
-                  total true positives, false negatives and false
-                  positives.
-                - `None`: Return scores of all categories.
+            average (str | None): How to calculate the final metrics from
+                the confusion matrix of every category. It supports three
+                modes:
+
+                - `"macro"`: Calculate metrics for each category, and calculate
+                  the mean value over all categories.
+                - `"micro"`: Average the confusion matrix over all categories
+                  and calculate metrics on the mean confusion matrix.
+                - `None`: Calculate metrics of every category and output
+                  directly.
 
                 Defaults to "macro".
             thr (float, optional): Predictions with scores under the thresholds
@@ -402,14 +423,25 @@ def _average_precision(pred: torch.Tensor,
 @METRICS.register_module()
 class AveragePrecision(BaseMetric):
-    """Calculate the average precision with respect of classes.
+    r"""Calculate the average precision with respect of classes.
+
+    AveragePrecision (AP) summarizes a precision-recall curve as the weighted
+    mean of maximum precisions obtained for any r'>r, where r is the recall:
+
+    .. math::
+        \text{AP} = \sum_n (R_n - R_{n-1}) P_n
+
+    Note that no approximation is involved since the curve is piecewise
+    constant.
 
     Args:
-        average (str | None): The average method. It supports two modes:
-
-            - `"macro"`: Calculate metrics for each category, and calculate
-              the mean value over all categories.
-            - `None`: Return scores of all categories.
+        average (str | None): How to calculate the final metrics from
+            every category. It supports two modes:
+
+            - `"macro"`: Calculate metrics for each category, and calculate
+              the mean value over all categories. The result of this mode
+              is also called **mAP**.
+            - `None`: Calculate metrics of every category and output directly.
 
             Defaults to "macro".
         collect_device (str): Device name used for collecting results from
@@ -529,15 +561,6 @@ class AveragePrecision(BaseMetric):
                   average: Optional[str] = 'macro') -> torch.Tensor:
         r"""Calculate the average precision for a single class.
 
-        AP summarizes a precision-recall curve as the weighted mean of maximum
-        precisions obtained for any r'>r, where r is the recall:
-
-        .. math::
-            \text{AP} = \sum_n (R_n - R_{n-1}) P_n
-
-        Note that no approximation is involved since the curve is piecewise
-        constant.
-
         Args:
             pred (torch.Tensor | np.ndarray): The model predictions with
                 shape ``(N, num_classes)``.
@@ -545,9 +568,11 @@ class AveragePrecision(BaseMetric):
                 with shape ``(N, num_classes)``.
             average (str | None): The average method. It supports two modes:
 
-                - `"macro"`: Calculate metrics for each category, and
-                  calculate the mean value over all categories.
-                - `None`: Return scores of all categories.
+                - `"macro"`: Calculate metrics for each category, and calculate
+                  the mean value over all categories. The result of this mode
+                  is also called mAP.
+                - `None`: Calculate metrics of every category and output
+                  directly.
 
                 Defaults to "macro".
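
The `"micro"` mode described above averages the confusion matrix over all categories before computing
the metric, which generally differs from the `"macro"` mode. A small sketch with made-up per-class
counts illustrates the difference:

```python
import torch

# Made-up per-class confusion-matrix entries for 3 categories.
tp = torch.tensor([30., 5., 10.])
fp = torch.tensor([10., 5., 40.])

# "macro": compute precision per category, then average the per-class values.
macro_precision = (tp / (tp + fp)).mean()           # ~0.483

# "micro": pool (average) the confusion matrices first, then compute a
# single precision from the pooled counts.
micro_precision = tp.sum() / (tp.sum() + fp.sum())  # 0.45

print(macro_precision.item(), micro_precision.item())
```
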


@@ -54,15 +54,25 @@ def _precision_recall_f1_support(pred_positive, gt_positive, average):
 @METRICS.register_module()
 class Accuracy(BaseMetric):
-    """Top-k accuracy evaluation metric.
+    r"""Accuracy evaluation metric.
+
+    For either binary classification or multi-class classification, the
+    accuracy is the fraction of correct predictions in all predictions:
+
+    .. math::
+        \text{Accuracy} = \frac{N_{\text{correct}}}{N_{\text{all}}}
 
     Args:
-        topk (int | Sequence[int]): If the predictions in ``topk``
-            matches the target, the predictions will be regarded as
-            correct ones. Defaults to 1.
-        thrs (Sequence[float | None] | float | None): Predictions with scores
-            under the thresholds are considered negative. None means no
-            thresholds. Defaults to 0.
+        topk (int | Sequence[int]): If the ground-truth label matches one of
+            the best **k** predictions, the sample will be regarded as a
+            positive prediction. If the parameter is a tuple, the top-k
+            accuracy for every value will be calculated and output together.
+            Defaults to 1.
+        thrs (Sequence[float | None] | float | None): If a float, predictions
+            with a score lower than the threshold will be regarded as
+            negative. If None, no threshold is applied. If the parameter is a
+            tuple, the accuracy for every threshold will be calculated and
+            output together. Defaults to 0.
         collect_device (str): Device name used for collecting results from
             different ranks during distributed training. Must be 'cpu' or
             'gpu'. Defaults to 'cpu'.
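
A small sketch of the `topk`/`thrs` behaviour described above, using plain PyTorch rather than the
`Accuracy` class itself (the helper below is illustrative only): a sample counts as correct when its
ground-truth class appears among the k highest-scoring predictions whose scores also exceed the
threshold.

```python
import torch


def topk_accuracy(scores: torch.Tensor, labels: torch.Tensor,
                  k: int = 1, thr: float = 0.0) -> float:
    """Illustrative top-k accuracy with a score threshold."""
    topk_scores, topk_indices = scores.topk(k, dim=1)
    # A hit needs the ground-truth class in the top-k *and* above `thr`.
    hits = (topk_indices == labels.view(-1, 1)) & (topk_scores > thr)
    return hits.any(dim=1).float().mean().item()


scores = torch.softmax(torch.randn(8, 5), dim=1)  # 8 samples, 5 classes
labels = torch.randint(0, 5, (8, ))
print(topk_accuracy(scores, labels, k=1), topk_accuracy(scores, labels, k=5))
```
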
@@ -262,41 +272,59 @@ class Accuracy(BaseMetric):
 @METRICS.register_module()
 class SingleLabelMetric(BaseMetric):
-    """A collection of metrics for single-label multi-class classification task
-    based on confusion matrix.
-
-    It includes precision, recall, f1-score and support. Comparing with
-    :class:`Accuracy`, these metrics doesn't support topk, but supports
-    various average mode.
+    r"""A collection of precision, recall, f1-score and support for
+    single-label tasks.
+
+    The collection of metrics is for single-label multi-class classification,
+    and all these metrics are based on the confusion matrix of every category:
+
+    .. image:: ../../_static/image/confusion-matrix.png
+       :width: 60%
+       :align: center
+
+    All metrics can be formulated using the variables above:
+
+    **Precision** is the fraction of correct predictions in all predictions:
+
+    .. math::
+        \text{Precision} = \frac{TP}{TP+FP}
+
+    **Recall** is the fraction of correct predictions in all targets:
+
+    .. math::
+        \text{Recall} = \frac{TP}{TP+FN}
+
+    **F1-score** is the harmonic mean of the precision and recall:
+
+    .. math::
+        \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}}
+
+    **Support** is the number of samples:
+
+    .. math::
+        \text{Support} = TP + TN + FN + FP
 
     Args:
-        thrs (Sequence[float | None] | float | None): Predictions with scores
-            under the thresholds are considered negative. None means no
-            thresholds. Defaults to 0.
-        items (Sequence[str]): The detailed metric items to evaluate. Here is
-            the available options:
-
-            - `"precision"`: The ratio tp / (tp + fp) where tp is the
-              number of true positives and fp the number of false
-              positives.
-            - `"recall"`: The ratio tp / (tp + fn) where tp is the number
-              of true positives and fn the number of false negatives.
-            - `"f1-score"`: The f1-score is the harmonic mean of the
-              precision and recall.
-            - `"support"`: The total number of occurrences of each category
-              in the target.
-
-            Defaults to ('precision', 'recall', 'f1-score').
-        average (str, optional): The average method. If None, the scores
-            for each class are returned. And it supports two average modes:
-
-            - `"macro"`: Calculate metrics for each category, and calculate
-              the mean value over all categories.
-            - `"micro"`: Calculate metrics globally by counting the total
-              true positives, false negatives and false positives.
-
-            Defaults to "macro".
-        num_classes (Optional, int): The number of classes. Defaults to None.
+        thrs (Sequence[float | None] | float | None): If a float, predictions
+            with a score lower than the threshold will be regarded as
+            negative. If None, only the top-1 prediction will be regarded as
+            positive. If the parameter is a tuple, the metrics for every
+            threshold will be calculated and output together. Defaults to 0.
+        items (Sequence[str]): The detailed metric items to evaluate, select
+            from "precision", "recall", "f1-score" and "support".
+            Defaults to ``('precision', 'recall', 'f1-score')``.
+        average (str | None): How to calculate the final metrics from the
+            confusion matrix of every category. It supports three modes:
+
+            - `"macro"`: Calculate metrics for each category, and calculate
+              the mean value over all categories.
+            - `"micro"`: Average the confusion matrix over all categories and
+              calculate metrics on the mean confusion matrix.
+            - `None`: Calculate metrics of every category and output directly.
+
+            Defaults to "macro".
+        num_classes (int, optional): The number of classes. Defaults to None.
         collect_device (str): Device name used for collecting results from
             different ranks during distributed training. Must be 'cpu' or
             'gpu'. Defaults to 'cpu'.
@@ -343,7 +371,7 @@ class SingleLabelMetric(BaseMetric):
             'single-label/recall_classwise': [18.5, 18.5, 17.0, 20.0, 18.0],
             'single-label/f1-score_classwise': [19.7, 18.6, 17.1, 19.7, 17.0]
         }
-    """
+    """  # noqa: E501

     default_prefix: Optional[str] = 'single-label'

     def __init__(self,
@@ -483,14 +511,16 @@ class SingleLabelMetric(BaseMetric):
                 the thresholds are considered negative. It's only used
                 when ``pred`` is scores. None means no thresholds.
                 Defaults to (0., ).
-            average (str, optional): The average method. If None, the scores
-                for each class are returned. And it supports two average modes:
-
-                - `"macro"`: Calculate metrics for each category, and
-                  calculate the mean value over all categories.
-                - `"micro"`: Calculate metrics globally by counting the
-                  total true positives, false negatives and false
-                  positives.
+            average (str | None): How to calculate the final metrics from
+                the confusion matrix of every category. It supports three
+                modes:
+
+                - `"macro"`: Calculate metrics for each category, and calculate
+                  the mean value over all categories.
+                - `"micro"`: Average the confusion matrix over all categories
+                  and calculate metrics on the mean confusion matrix.
+                - `None`: Calculate metrics of every category and output
+                  directly.
 
                 Defaults to "macro".
             num_classes (Optional, int): The number of classes. If the ``pred``