diff --git a/docs/en/advanced_guides/evaluation.md b/docs/en/advanced_guides/evaluation.md
index b394c7690..55728281a 100644
--- a/docs/en/advanced_guides/evaluation.md
+++ b/docs/en/advanced_guides/evaluation.md
@@ -1 +1,158 @@
# Evaluation
+
+The evaluation procedure would be executed at [ValLoop](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/loops.py#L300) and [TestLoop](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/loops.py#L373), users can evaluate model performance during training or using the test script with simple settings in the configuration file. The `ValLoop` and `TestLoop` are properties of [Runner](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/runner.py#L59), they will be built the first time they are called. To build the `ValLoop` successfully, the `val_dataloader` and `val_evaluator` must be set when building `Runner` since `dataloder` and `evaluator` are required parameters, and the same goes for `TestLoop`. For more information about the Runner's design, please refer to the [documentoation](https://github.com/open-mmlab/mmengine/blob/main/docs/en/design/runner.md) of [MMEngine](https://github.com/open-mmlab/mmengine).
+
+
+
+ test_step/val_step dataflow
+
+
+In MMSegmentation, we write the settings of dataloader and metrics in the config files of datasets and the configuration of the evaluation loop in the `schedule_x` config files by default.
+
+For example, in the ADE20K config file `configs/_base_/dataset/ade20k.py`, on lines 37 to 48, we configured the `val_dataloader`, on line 51, we select `IoUMetric` as the evaluator and set `mIoU` as the metric:
+
+```python
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/validation',
+ seg_map_path='annotations/validation'),
+ pipeline=test_pipeline))
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+```
+
+To be able to evaluate the model during training, for example, we add the evaluation configuration to the file `configs/schedules/schedule_40k.py` on lines 15 to 16:
+
+```python
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=40000, val_interval=4000)
+val_cfg = dict(type='ValLoop')
+```
+
+With the above two settings, MMSegmentation evaluates the **mIoU** metric of the model once every 4000 iterations during the training of 40K iterations.
+
+If we would like to test the model after training, we need to add the `test_dataloader`, `test_evaluator` and `test_cfg` configs to the config file.
+
+```python
+test_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/validation',
+ seg_map_path='annotations/validation'),
+ pipeline=test_pipeline))
+
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_cfg = dict(type='TestLoop')
+```
+
+In MMSegmentation, the settings of `test_dataloader` and `test_evaluator` are the same as the `ValLoop`'s dataloader and evaluator by default, we can modify these settings to meet our needs.
+
+## IoUMetric
+
+MMSegmentation implements [IoUMetric](https://github.com/open-mmlab/mmsegmentation/blob/1.x/mmseg/evaluation/metrics/iou_metric.py) and [CitysMetric](https://github.com/open-mmlab/mmsegmentation/blob/1.x/mmseg/evaluation/metrics/citys_metric.py) for evaluating the performance of models, based on the [BaseMetric](https://github.com/open-mmlab/mmengine/blob/main/mmengine/evaluator/metric.py) provided by [MMEngine](https://github.com/open-mmlab/mmengine). Please refer to [the documentation](https://mmengine.readthedocs.io/en/latest/tutorials/evaluation.html) for more details about the unified evaluation interface.
+
+Here we briefly describe the arguments and the two main methods of `IoUMetric`.
+
+The constructor of `IoUMetric` has some additional parameters besides the base `collect_device` and `prefix`.
+
+The arguments of the constructor:
+
+- ignore_index (int) - Index that will be ignored in evaluation. Default: 255.
+- iou_metrics (list\[str\] | str) - Metrics to be calculated, the options includes 'mIoU', 'mDice' and 'mFscore'.
+- nan_to_num (int, optional) - If specified, NaN values will be replaced by the numbers defined by the user. Default: None.
+- beta (int) - Determines the weight of recall in the combined score. Default: 1.
+- collect_device (str) - Device name used for collecting results from different ranks during distributed training. Must be 'cpu' or 'gpu'. Defaults to 'cpu'.
+- prefix (str, optional) - The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If the prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None.
+
+`IoUMetric` implements the IoU metric calculation, the core two methods of `IoUMetric` are `process` and `compute_metrics`.
+
+- `process` method processes one batch of data and data_samples.
+- `compute_metrics` method computes the metrics from processed results.
+
+#### IoUMetric.process
+
+Parameters:
+
+- data_batch (Any) - A batch of data from the dataloader.
+- data_samples (Sequence\[dict\]) - A batch of outputs from the model.
+
+Returns:
+
+This method doesn't have returns since the processed results would be stored in `self.results`, which will be used to compute the metrics when all batches have been processed.
+
+#### IoUMetric.compute_metrics
+
+Parameters:
+
+- results (list) - The processed results of each batch.
+
+Returns:
+
+- Dict\[str, float\] - The computed metrics. The keys are the names of the metrics, and the values are corresponding results. The key mainly includes **aAcc**, **mIoU**, **mAcc**, **mDice**, **mFscore**, **mPrecision**, **mRecall**.
+
+## CitysMetric
+
+`CitysMetric` uses the official [CityscapesScripts](https://github.com/mcordts/cityscapesScripts) provided by Cityscapes to evaluate model performance.
+
+### Usage
+
+Before using it, please install the `cityscapesscripts` package first:
+
+```shell
+pip install cityscapesscripts
+```
+
+Since the `IoUMetric` is used as the default evaluator in MMSegmentation, if you would like to use `CitysMetric`, customizing the config file is required. In your customized config file, you should overwrite the default evaluator as follows.
+
+```python
+val_evaluator = dict(type='CitysMetric', citys_metrics=['cityscapes'])
+test_evaluator = val_evaluator
+```
+
+### Interface
+
+The arguments of the constructor:
+
+- ignore_index (int) - Index that will be ignored in evaluation. Default: 255.
+- citys_metrics (list\[str\] | str) - Metrics to be evaluated, Default: \['cityscapes'\].
+- to_label_id (bool) - whether convert output to label_id for submission. Default: True.
+- suffix (str): The filename prefix of the png files. If the prefix is "somepath/xxx", the png files will be named "somepath/xxx.png". Default: '.format_cityscapes'.
+- collect_device (str): Device name used for collecting results from different ranks during distributed training. Must be 'cpu' or 'gpu'. Defaults to 'cpu'.
+- prefix (str, optional): The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If the prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None.
+
+#### CitysMetric.process
+
+This method would draw the masks on images and save the painted images to `work_dir`.
+
+Parameters:
+
+- data_batch (Any) - A batch of data from the dataloader.
+- data_samples (Sequence\[dict\]) - A batch of outputs from the model.
+
+Returns:
+
+This method doesn't have returns, the annotations' path would be stored in `self.results`, which will be used to compute the metrics when all batches have been processed.
+
+#### CitysMetric.compute_metrics
+
+This method would call `cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling` tool to calculate metrics.
+
+Parameters:
+
+- results (list) - Testing results of the dataset.
+
+Returns:
+
+- dict\[str: float\] - Cityscapes evaluation results.
diff --git a/mmseg/evaluation/metrics/citys_metric.py b/mmseg/evaluation/metrics/citys_metric.py
index af6e8b00d..50e9ea68a 100644
--- a/mmseg/evaluation/metrics/citys_metric.py
+++ b/mmseg/evaluation/metrics/citys_metric.py
@@ -76,9 +76,8 @@ class CitysMetric(BaseMetric):
output.putpalette(palette)
output.save(png_filename)
- ann_dir = osp.join(
- data_batch[0]['data_sample']['seg_map_path'].split('val')[0],
- 'val')
+ ann_dir = osp.join(data_samples[0]['seg_map_path'].split('val')[0],
+ 'val')
self.results.append(ann_dir)
def compute_metrics(self, results: list) -> Dict[str, float]:
@@ -86,9 +85,6 @@ class CitysMetric(BaseMetric):
Args:
results (list): Testing results of the dataset.
- logger (logging.Logger | str | None): Logger used for printing
- related information during evaluation. Default: None.
- imgfile_prefix (str | None): The prefix of output image file
Returns:
dict[str: float]: Cityscapes evaluation results.
diff --git a/mmseg/evaluation/metrics/iou_metric.py b/mmseg/evaluation/metrics/iou_metric.py
index 5a6958b70..a152ef9dd 100644
--- a/mmseg/evaluation/metrics/iou_metric.py
+++ b/mmseg/evaluation/metrics/iou_metric.py
@@ -51,7 +51,7 @@ class IoUMetric(BaseMetric):
"""Process one batch of data and data_samples.
The processed results should be stored in ``self.results``, which will
- be used to computed the metrics when all batches have been processed.
+ be used to compute the metrics when all batches have been processed.
Args:
data_batch (dict): A batch of data from the dataloader.
diff --git a/tests/test_evaluation/test_metrics/test_citys_metric.py b/tests/test_evaluation/test_metrics/test_citys_metric.py
index 34c0c9a5e..a6d6db5ca 100644
--- a/tests/test_evaluation/test_metrics/test_citys_metric.py
+++ b/tests/test_evaluation/test_metrics/test_citys_metric.py
@@ -46,6 +46,8 @@ class TestCitysMetric(TestCase):
'tests/data/pseudo_cityscapes_dataset/gtFine/val/\
frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png'
+ mm_inputs['seg_map_path'] = mm_inputs['data_sample'][
+ 'seg_map_path']
packed_inputs.append(mm_inputs)
return packed_inputs
@@ -96,16 +98,20 @@ class TestCitysMetric(TestCase):
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
- data_batch = self._demo_mm_inputs()
- predictions = self._demo_mm_model_output()
+ data_batch = self._demo_mm_inputs(2)
+ predictions = self._demo_mm_model_output(2)
+ data_samples = [
+ dict(**data, **result)
+ for data, result in zip(data_batch, predictions)
+ ]
iou_metric = CitysMetric(citys_metrics=['cityscapes'])
- iou_metric.process(data_batch, predictions)
+ iou_metric.process(data_batch, data_samples)
res = iou_metric.evaluate(6)
self.assertIsInstance(res, dict)
# test to_label_id = True
iou_metric = CitysMetric(
citys_metrics=['cityscapes'], to_label_id=True)
- iou_metric.process(data_batch, predictions)
+ iou_metric.process(data_batch, data_samples)
res = iou_metric.evaluate(6)
self.assertIsInstance(res, dict)
import shutil