[Refactor] Update analysis tools and documentations. (#1359)

* [Refactor] Update analysis tools and documentations.

* Update migration.md and add unit test.

* Fix print_config.py
Ma Zerun 2023-02-15 10:28:08 +08:00 committed by GitHub
parent b4ee9d2848
commit bedf4e9f64
13 changed files with 285 additions and 142 deletions


@@ -576,7 +576,7 @@ Changes in [heads](mmcls.models.heads):
| :--------------------------: | :-------------------------------------------------------------------------------------------------------------- |
| `collect_env` | No changes |
| `get_root_logger` | Removed, use [`mmengine.logging.MMLogger.get_current_instance`](mmengine.logging.MMLogger.get_current_instance) |
| `load_json_log` | Waiting for support |
| `load_json_log` | The output format changed. |
| `setup_multi_processes` | Removed, use [`mmengine.utils.dl_utils.set_multi_processing`](mmengine.utils.dl_utils.set_multi_processing). |
| `wrap_non_distributed_model` | Removed, we auto wrap the model in the runner. |
| `wrap_distributed_model` | Removed, we auto wrap the model in the runner. |
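
For reference, a brief sketch of how this change affects downstream code. The old structure is inferred from the previous implementation removed in this commit, and the path is only illustrative:

```python
from mmcls.utils import load_json_log

log = load_json_log('work_dirs/your_exp/20230206_181002/vis_data/scalars.json')

# Old format: a dict keyed by epoch, e.g. log[1]['iter'], log[1]['mode'], log[1]['loss'].
# New format: two flat lists of per-record dicts.
train_records = log['train']  # e.g. {"lr": 0.1, "time": 0.02, "epoch": 1, "step": 100}
val_records = log['val']      # e.g. {"accuracy/top1": 32.1, "step": 1}
```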


@@ -48,7 +48,7 @@ python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys loss
#### Plot the top-1 accuracy and top-5 accuracy curves, and save the figure to results.jpg.
```shell
python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys accuracy_top-1 accuracy_top-5 --legend top1 top5 --out results.jpg
python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys accuracy/top1 accuracy/top5 --legend top1 top5 --out results.jpg
```
#### Compare the top-1 accuracy of two log files in the same figure.
@@ -57,11 +57,6 @@ python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys accu
python tools/analysis_tools/analyze_logs.py plot_curve log1.json log2.json --keys accuracy_top-1 --legend exp1 exp2
```
```{note}
The tool will automatically select to find keys in training logs or validation logs according to the keys.
Therefore, if you add a custom evaluation metric, please also add the key to `TEST_METRICS` in this tool.
```
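
With the new log format, the keys accepted by `--keys` are simply the fields recorded in the training or validation log entries. A minimal sketch of how to list them yourself (mirroring what the refactored tool does; the path follows the examples above):

```python
from mmcls.utils import load_json_log

log = load_json_log('work_dirs/your_exp/20230206_181002/vis_data/scalars.json')
train_keys = set(log['train'][0]) - {'step', 'epoch'} if log['train'] else set()
val_keys = set(log['val'][0]) - {'step'} if log['val'] else set()
print('train keys:', train_keys)  # e.g. {'loss', 'lr', 'time', 'data_time'}
print('val keys:', val_keys)      # e.g. {'accuracy/top1', 'accuracy/top5'}
```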
### How to calculate training time
`tools/analysis_tools/analyze_logs.py` can also calculate the training time according to the log files.
@@ -75,18 +70,18 @@ python tools/analysis_tools/analyze_logs.py cal_train_time \
**Description of all arguments**:
- `json_logs` : The paths of the log files, separate multiple files by spaces.
- `--include-outliers` : If set, include the first iteration in each epoch (Sometimes the time of first iterations is longer).
- `--include-outliers` : If set, include the first time record in each epoch (sometimes the first iteration takes longer).
Example:
```shell
python tools/analysis_tools/analyze_logs.py cal_train_time work_dirs/some_exp/20200422_153324.log.json
python tools/analysis_tools/analyze_logs.py cal_train_time work_dirs/your_exp/20230206_181002/vis_data/scalars.json
```
The output is expected to be like the following.
```text
-----Analyze train time of work_dirs/some_exp/20200422_153324.log.json-----
-----Analyze train time of work_dirs/your_exp/20230206_181002/vis_data/scalars.json-----
slowest epoch 68, average time is 0.3818
fastest epoch 1, average time is 0.3694
time std over epochs is 0.0020
@@ -104,34 +99,42 @@ We provide `tools/analysis_tools/eval_metric.py` to enable the user to evaluate the
```shell
python tools/analysis_tools/eval_metric.py \
${CONFIG} \
${RESULT} \
[--metrics ${METRICS}] \
[--cfg-options ${CFG_OPTIONS}] \
[--metric-options ${METRIC_OPTIONS}]
[--metric ${METRIC_OPTIONS} ...]
```
Description of all arguments:
- `config` : The path of the model config file.
- `result`: The Output result file in json/pickle format from `tools/test.py`.
- `--metrics` : Evaluation metrics, the acceptable values depend on the dataset.
- `--cfg-options`: If specified, the key-value pair config will be merged into the config file, for more details please refer to [Learn about Configs](../user_guides/config.md)
- `--metric-options`: If specified, the key-value pair arguments will be passed to the `metric_options` argument of dataset's `evaluate` function.
- `result`: The output result file in pickle format from `tools/test.py`.
- `--metric`: The metric and its options used to evaluate the results. You need to specify at least one metric, and
you can specify multiple `--metric` options to use multiple metrics.
Please refer to the [Metric Documentation](mmcls.evaluation) to find the available metrics and their options.
```{note}
In `tools/test.py`, we support using `--out-items` option to select which kind of results will be saved. Please ensure the result file includes "class_scores" to use this tool.
In `tools/test.py`, we support using the `--out-item` option to select which kind of results will be saved.
Please ensure `--out-item` is not specified, or is set to `--out-item=pred`, to use this tool.
```
**Examples**:
```shell
python tools/analysis_tools/eval_metric.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py your_result.pkl --metrics accuracy --metric-options "topk=(1,5)"
# Get the prediction results
python tools/test.py configs/resnet/resnet18_8xb16_cifar10.py \
https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth \
--out results.pkl
# Eval the top-1 and top-5 accuracy
python tools/analysis_tools/eval_metric.py results.pkl --metric type=Accuracy topk=1,5
# Eval accuracy, precision, recall and f1-score
python tools/analysis_tools/eval_metric.py results.pkl --metric type=Accuracy \
--metric type=SingleLabelMetric items=precision,recall,f1-score
```
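
If you prefer to run the evaluation from Python rather than the CLI, a rough sketch of what the tool does internally (the same APIs the script calls; the metric settings mirror the shell examples above):

```python
import mmengine
from mmengine.evaluator import Evaluator
from mmengine.registry import init_default_scope

from mmcls.registry import METRICS

init_default_scope('mmcls')  # resolve metric names in the mmcls registry

# Equivalent to `--metric type=Accuracy topk=1,5`.
metrics = [METRICS.build(dict(type='Accuracy', topk=(1, 5)))]

predictions = mmengine.load('results.pkl')  # the file saved by `tools/test.py --out`
print(Evaluator(metrics).offline_evaluate(predictions, None))
```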
### How to visualize the prediction results
We can also use this tool `tools/analysis_tools/analyze_results.py` to save the images with the highest scores in successful or failed prediction.
We can use `tools/analysis_tools/analyze_results.py` to save the images with the highest scores among successful or failed predictions.
```shell
python tools/analysis_tools/analyze_results.py \
@@ -139,6 +142,7 @@ python tools/analysis_tools/analyze_results.py \
${RESULT} \
[--out-dir ${OUT_DIR}] \
[--topk ${TOPK}] \
[--rescale-factor ${RESCALE_FACTOR}] \
[--cfg-options ${CFG_OPTIONS}]
```
@@ -148,18 +152,28 @@ python tools/analysis_tools/analyze_results.py \
- `result`: Output result file in json/pickle format from `tools/test.py`.
- `--out_dir`: Directory to store output files.
- `--topk`: The number of images with the highest scores to save for both successful and failed predictions. If not specified, it will be set to 20.
- `--rescale-factor`: Image rescale factor, which is useful if the output is too large or too small (too small
images may make the drawn prediction labels too blurry); see the sketch after this list.
- `--cfg-options`: If specified, the key-value pair config will be merged into the config file, for more details please refer to [Learn about Configs](../user_guides/config.md)
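
As a quick illustration of what the rescale factor does, a small sketch using `mmcv.imrescale`, the same helper the tool calls internally (the image path is hypothetical):

```python
import mmcv

img = mmcv.imread('path/to/sample.jpg')  # hypothetical input image
enlarged = mmcv.imrescale(img, 10)       # what `--rescale-factor 10` applies before drawing
shrunk = mmcv.imrescale(img, 0.5)        # useful when the original images are very large
print(img.shape, enlarged.shape, shrunk.shape)
```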
```{note}
In `tools/test.py`, we support using `--out-items` option to select which kind of results will be saved. Please ensure the result file includes "pred_score", "pred_label" and "pred_class" to use this tool.
In `tools/test.py`, we support using the `--out-item` option to select which kind of results will be saved.
Please ensure `--out-item` is not specified, or is set to `--out-item=pred`, to use this tool.
```
**Examples**:
```shell
# Get the prediction results
python tools/test.py configs/resnet/resnet18_8xb16_cifar10.py \
https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth \
--out results.pkl
# Save the top-10 successful and failed predictions, and enlarge the sample images by 10 times.
python tools/analysis_tools/analyze_results.py \
configs/resnet/resnet50_b32x8_imagenet.py \
result.pkl \
--out_dir results \
--topk 50
configs/resnet/resnet18_8xb16_cifar10.py \
results.pkl \
--out-dir output \
--topk 10 \
--rescale-factor 10
```
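
To inspect the result file itself, you can load it with `mmengine.load`; a small sketch (the exact fields of each record depend on the model, so the printed keys are only indicative):

```python
import mmengine

predictions = mmengine.load('results.pkl')  # saved by `tools/test.py --out`
print(len(predictions))                     # number of test samples
print(sorted(predictions[0].keys()))        # per-sample prediction fields
```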


@@ -18,8 +18,10 @@ Description of all arguments:
## Examples
Print the complete config of `configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py`
```shell
# Print a complete config
python tools/misc/print_config.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py
# Save the complete config to an independent config file.
python tools/misc/print_config.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py > final_config.py
```
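
The same result can also be obtained in a few lines of Python, which is handy inside notebooks (a sketch using the `mmengine.Config` APIs the tool itself relies on):

```python
from mmengine import Config

cfg = Config.fromfile('configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py')
print(cfg.pretty_text)  # what tools/misc/print_config.py prints

# Save the resolved config to an independent file, like the shell redirection above.
with open('final_config.py', 'w') as f:
    f.write(cfg.pretty_text)
```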


@@ -564,7 +564,7 @@ visualizer = dict(
| :--------------------------: | :------------------------------------------------------------------------------------------------------------ |
| `collect_env` | No changes |
| `get_root_logger` | Removed, use [`mmengine.logging.MMLogger.get_current_instance`](mmengine.logging.MMLogger.get_current_instance) |
| `load_json_log` | Waiting for support |
| `load_json_log` | The output format changed. |
| `setup_multi_processes` | Removed, use [`mmengine.utils.dl_utils.set_multi_processing`](mmengine.utils.dl_utils.set_multi_processing) |
| `wrap_non_distributed_model` | Removed, the runner now wraps the model automatically. |
| `wrap_distributed_model` | Removed, the runner now wraps the model automatically. |


@@ -48,7 +48,7 @@ python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys loss
#### Plot the top-1 and top-5 accuracy curves of a log file, and export the figure to results.jpg.
```shell
python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys accuracy_top-1 accuracy_top-5 --legend top1 top5 --out results.jpg
python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys accuracy/top1 accuracy/top5 --legend top1 top5 --out results.jpg
```
#### Plot the top-1 accuracy curves of two log files in the same figure.
@@ -57,11 +57,6 @@ python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys accu
python tools/analysis_tools/analyze_logs.py plot_curve log1.json log2.json --keys accuracy_top-1 --legend exp1 exp2
```
```{note}
The tool will automatically select to find keys in training logs or validation logs according to the keys.
Therefore, if you add a custom evaluation metric, please also add the key to `TEST_METRICS` in this tool.
```
### How to calculate training time
`tools/analysis_tools/analyze_logs.py` can also calculate the training time according to the log files.
@@ -74,19 +69,19 @@ python tools/analysis_tools/analyze_logs.py cal_train_time \
**Description of all arguments:**
- `json_logs` : The paths of the log files (multiple files can be passed, separated by spaces).
- `--include-outliers` : If set, the record of the first iteration in each epoch will not be excluded (sometimes the first iteration takes longer).
- `json_logs`: The paths of the log files (multiple files can be passed, separated by spaces).
- `--include-outliers`: If set, the first time record in each epoch will not be excluded (sometimes the first iteration takes longer).
**Example:**
```shell
python tools/analysis_tools/analyze_logs.py cal_train_time work_dirs/some_exp/20200422_153324.log.json
python tools/analysis_tools/analyze_logs.py cal_train_time work_dirs/your_exp/20230206_181002/vis_data/scalars.json
```
The output is expected to be like the following:
```text
-----Analyze train time of work_dirs/some_exp/20200422_153324.log.json-----
-----Analyze train time of work_dirs/your_exp/20230206_181002/vis_data/scalars.json-----
slowest epoch 68, average time is 0.3818
fastest epoch 1, average time is 0.3694
time std over epochs is 0.0020
@@ -95,7 +90,7 @@ average iter time: 0.3777 s/iter
## Result analysis
With the `--out` argument of `tools/test.py`, we can save the inference results of all samples into an output file. With this file, we can do further analysis.
### How to conduct offline metric evaluation
@@ -103,29 +98,36 @@ average iter time: 0.3777 s/iter
```shell
python tools/analysis_tools/eval_metric.py \
${CONFIG} \
${RESULT} \
[--metrics ${METRICS}] \
[--cfg-options ${CFG_OPTIONS}] \
[--metric-options ${METRIC_OPTIONS}]
[--metric ${METRIC_OPTIONS} ...]
```
**Description of all arguments**:
- `config` : The path of the config file.
- `result`: The output result file of `tools/test.py`.
- `--metrics` : The evaluation metrics; the acceptable values depend on the dataset class.
- `--cfg-options`: Extra config options that will be merged into the config file, see [Tutorial 1: How to write a config file](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html).
- `--metric-options`: If specified, these options will be passed to the `metric_options` argument of the dataset's `evaluate` function.
- `result`: The output result file of `tools/test.py`.
- `--metric`: The metrics used to evaluate the results. Please specify at least one metric, and you can compute multiple metrics at once by specifying multiple `--metric`.
Please refer to the [Metric Documentation](mmcls.evaluation) for the available metrics and their options.
```{note}
In `tools/test.py`, we support using `--out-items` option to select which kind of results will be saved. Please ensure the result file includes "class_scores" to use this tool.
`tools/test.py` 中,我们支持使用 `--out-item` 选项来选择保存何种结果至输出文件。
请确保没有额外指定 `--out-item`,或指定了 `--out-item=pred`
```
**Examples**:
```shell
python tools/analysis_tools/eval_metric.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py your_result.pkl --metrics accuracy --metric-options "topk=(1,5)"
# Get the result file
python tools/test.py configs/resnet/resnet18_8xb16_cifar10.py \
https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth \
--out results.pkl
# Calculate the top-1 and top-5 accuracy
python tools/analysis_tools/eval_metric.py results.pkl --metric type=Accuracy topk=1,5
# Calculate accuracy, precision, recall and f1-score
python tools/analysis_tools/eval_metric.py results.pkl --metric type=Accuracy \
--metric type=SingleLabelMetric items=precision,recall,f1-score
```
### How to visualize the prediction results
@@ -138,27 +140,37 @@ python tools/analysis_tools/analyze_results.py \
${RESULT} \
[--out-dir ${OUT_DIR}] \
[--topk ${TOPK}] \
[--rescale-factor ${RESCALE_FACTOR}] \
[--cfg-options ${CFG_OPTIONS}]
```
**Description of all arguments:**
- `config` : The path of the config file.
- `result`: The output result file of `tools/test.py`.
- `--out_dir`: The directory to save the analysis results.
- `--topk`: How many images with successful/failed predictions to save, respectively. Defaults to `20` if not specified.
- `--cfg-options`: Extra config options that will be merged into the config file, see [Tutorial 1: How to write a config file](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html).
- `config`: The path of the config file.
- `result`: The output result file of `tools/test.py`.
- `--out_dir`: The directory to save the analysis results.
- `--topk`: How many images with successful/failed predictions to save, respectively. Defaults to `20` if not specified.
- `--rescale-factor`: The image rescale factor, which can be used if the sample images are too large or too small (too small images may make the prediction labels very blurry).
- `--cfg-options`: Extra config options that will be merged into the config file, see [Learn about Configs](../user_guides/config.md).
```{note}
In `tools/test.py`, we support using `--out-items` option to select which kind of results will be saved. Please ensure the result file includes "pred_score", "pred_label" and "pred_class" to use this tool.
`tools/test.py` 中,我们支持使用 `--out-item` 选项来选择保存何种结果至输出文件。
请确保没有额外指定 `--out-item`,或指定了 `--out-item=pred`
```
**Examples**:
```shell
# Get the prediction result file
python tools/test.py configs/resnet/resnet18_8xb16_cifar10.py \
https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth \
--out results.pkl
# Save the top-10 images with the highest scores among successful/failed predictions, and enlarge the output images by 10 times during visualization.
python tools/analysis_tools/analyze_results.py \
configs/resnet/resnet50_b32x8_imagenet.py \
result.pkl \
--out_dir results \
--topk 50
configs/resnet/resnet18_8xb16_cifar10.py \
results.pkl \
--out-dir output \
--topk 10 \
--rescale-factor 10
```


@@ -13,12 +13,16 @@ python tools/misc/print_config.py ${CONFIG} [--cfg-options ${CFG_OPTIONS}]
Description of all arguments:
- `config` : The path of the model config file.
- `--cfg-options` : Extra config options that will be merged into the config file, see [Tutorial 1: How to write a config file](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html).
- `--cfg-options` : Extra config options that will be merged into the config file, see [Learn about Configs](../user_guides/config.md).
## Examples
Print the complete config of `configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py`
```shell
# Print the complete config
python tools/misc/print_config.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py
# Save the complete config to an independent config file
python tools/misc/print_config.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py > final_config.py
```


@@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .analyze import load_json_log
from .collect_env import collect_env
from .progress import track_on_main_process
from .setup_env import register_all_modules
__all__ = ['collect_env', 'register_all_modules', 'track_on_main_process']
__all__ = [
'collect_env', 'register_all_modules', 'track_on_main_process',
'load_json_log'
]


@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
def load_json_log(json_log):
"""load and convert json_logs to log_dicts.
Args:
json_log (str): The path of the json log file.
Returns:
dict: The result dict contains two items, "train" and "val", for
the training log and validation log.
Example:
An example output:
.. code-block:: python
{
'train': [
{"lr": 0.1, "time": 0.02, "epoch": 1, "step": 100},
{"lr": 0.1, "time": 0.02, "epoch": 1, "step": 200},
{"lr": 0.1, "time": 0.02, "epoch": 1, "step": 300},
...
],
'val': [
{"accuracy/top1": 32.1, "step": 1},
{"accuracy/top1": 50.2, "step": 2},
{"accuracy/top1": 60.3, "step": 2},
...
]
}
"""
log_dict = dict(train=[], val=[])
with open(json_log, 'r') as log_file:
for line in log_file:
log = json.loads(line.strip())
# A simple heuristic to determine whether the line is a training log.
mode = 'train' if 'lr' in log else 'val'
log_dict[mode].append(log)
return log_dict
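
A small usage sketch for the new helper, plotting the validation curve much like the refactored `plot_curve` command does (the path and keys follow the documentation examples in this commit):

```python
import matplotlib.pyplot as plt

from mmcls.utils import load_json_log

log = load_json_log('work_dirs/your_exp/20230206_181002/vis_data/scalars.json')
steps = [record['step'] for record in log['val']]
top1 = [record['accuracy/top1'] for record in log['val']]
plt.plot(steps, top1, label='accuracy/top1')
plt.xlabel('Steps')
plt.legend()
plt.savefig('top1_curve.png')
```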


@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from mmcls.utils import load_json_log
def test_load_json_log():
demo_log = """\
{"lr": 0.0001, "data_time": 0.003, "loss": 2.29, "time": 0.010, "epoch": 1, "step": 150}
{"lr": 0.0001, "data_time": 0.002, "loss": 2.28, "time": 0.007, "epoch": 1, "step": 300}
{"lr": 0.0001, "data_time": 0.001, "loss": 2.27, "time": 0.008, "epoch": 1, "step": 450}
{"accuracy/top1": 23.98, "accuracy/top5": 66.05, "step": 1}
{"lr": 0.0001, "data_time": 0.001, "loss": 2.25, "time": 0.014, "epoch": 2, "step": 619}
{"lr": 0.0001, "data_time": 0.000, "loss": 2.24, "time": 0.012, "epoch": 2, "step": 769}
{"lr": 0.0001, "data_time": 0.003, "loss": 2.23, "time": 0.009, "epoch": 2, "step": 919}
{"accuracy/top1": 41.82, "accuracy/top5": 81.26, "step": 2}
{"lr": 0.0001, "data_time": 0.002, "loss": 2.21, "time": 0.007, "epoch": 3, "step": 1088}
{"lr": 0.0001, "data_time": 0.005, "loss": 2.18, "time": 0.009, "epoch": 3, "step": 1238}
{"lr": 0.0001, "data_time": 0.002, "loss": 2.16, "time": 0.008, "epoch": 3, "step": 1388}
{"accuracy/top1": 54.07, "accuracy/top5": 89.80, "step": 3}
""" # noqa: E501
with tempfile.TemporaryDirectory() as tmpdir:
json_log = osp.join(tmpdir, 'scalars.json')
with open(json_log, 'w') as f:
f.write(demo_log)
log_dict = load_json_log(json_log)
assert log_dict.keys() == {'train', 'val'}
assert log_dict['train'][3] == {
'lr': 0.0001,
'data_time': 0.001,
'loss': 2.25,
'time': 0.014,
'epoch': 2,
'step': 619
}
assert log_dict['val'][2] == {
'accuracy/top1': 54.07,
'accuracy/top5': 89.80,
'step': 3
}


@@ -2,37 +2,40 @@
import argparse
import os
import re
from itertools import groupby
import matplotlib.pyplot as plt
import numpy as np
from mmcls.utils import load_json_log
TEST_METRICS = ('precision', 'recall', 'f1_score', 'support', 'mAP', 'CP',
'CR', 'CF1', 'OP', 'OR', 'OF1', 'accuracy')
def cal_train_time(log_dicts, args):
"""Compute the average time per training iteration."""
for i, log_dict in enumerate(log_dicts):
print(f'{"-" * 5}Analyze train time of {args.json_logs[i]}{"-" * 5}')
all_times = []
for epoch in log_dict.keys():
if args.include_outliers:
all_times.append(log_dict[epoch]['time'])
else:
all_times.append(log_dict[epoch]['time'][1:])
all_times = np.array(all_times)
epoch_ave_time = all_times.mean(-1)
slowest_epoch = epoch_ave_time.argmax()
fastest_epoch = epoch_ave_time.argmin()
std_over_epoch = epoch_ave_time.std()
print(f'slowest epoch {slowest_epoch + 1}, '
f'average time is {epoch_ave_time[slowest_epoch]:.4f}')
print(f'fastest epoch {fastest_epoch + 1}, '
f'average time is {epoch_ave_time[fastest_epoch]:.4f}')
print(f'time std over epochs is {std_over_epoch:.4f}')
print(f'average iter time: {np.mean(all_times):.4f} s/iter')
train_logs = log_dict['train']
if 'epoch' in train_logs[0]:
epoch_ave_times = []
for _, logs in groupby(train_logs, lambda log: log['epoch']):
if args.include_outliers:
all_time = np.array([log['time'] for log in logs])
else:
all_time = np.array([log['time'] for log in logs])[1:]
epoch_ave_times.append(all_time.mean())
epoch_ave_times = np.array(epoch_ave_times)
slowest_epoch = epoch_ave_times.argmax()
fastest_epoch = epoch_ave_times.argmin()
std_over_epoch = epoch_ave_times.std()
print(f'slowest epoch {slowest_epoch + 1}, '
f'average time is {epoch_ave_times[slowest_epoch]:.4f}')
print(f'fastest epoch {fastest_epoch + 1}, '
f'average time is {epoch_ave_times[fastest_epoch]:.4f}')
print(f'time std over epochs is {std_over_epoch:.4f}')
avg_iter_time = np.array([log['time'] for log in train_logs]).mean()
print(f'average iter time: {avg_iter_time:.4f} s/iter')
print()
@@ -52,35 +55,27 @@ def get_legends(args):
return legend
def plot_phase_train(metric, log_dict, epochs, curve_label, json_log):
"""plot phase of train cruve."""
if metric not in log_dict[epochs[0]]:
raise KeyError(f'{json_log} does not contain metric {metric}'
f' in train mode')
xs, ys = [], []
for epoch in epochs:
iters = log_dict[epoch]['iter']
if log_dict[epoch]['mode'][-1] == 'val':
iters = iters[:-1]
num_iters_per_epoch = iters[-1]
assert len(iters) > 0, (
'The training log is empty, please try to reduce the '
'interval of log in config file.')
xs.append(np.array(iters) / num_iters_per_epoch + (epoch - 1))
ys.append(np.array(log_dict[epoch][metric][:len(iters)]))
xs = np.concatenate(xs)
ys = np.concatenate(ys)
plt.xlabel('Epochs')
def plot_phase_train(metric, train_logs, curve_label):
"""plot phase of train curve."""
xs = np.array([log['step'] for log in train_logs])
ys = np.array([log[metric] for log in train_logs])
if 'epoch' in train_logs[0]:
scale_factor = train_logs[-1]['step'] / train_logs[-1]['epoch']
xs = xs / scale_factor
plt.xlabel('Epochs')
else:
plt.xlabel('Iters')
plt.plot(xs, ys, label=curve_label, linewidth=0.75)
def plot_phase_val(metric, log_dict, epochs, curve_label, json_log):
"""plot phase of val cruves."""
# some epoch may not have evaluation. as [(train, 5),(val, 1)]
xs = [e for e in epochs if metric in log_dict[e]]
ys = [log_dict[e][metric] for e in xs if metric in log_dict[e]]
assert len(xs) > 0, (f'{json_log} does not contain metric {metric}')
plt.xlabel('Epochs')
def plot_phase_val(metric, val_logs, curve_label):
"""plot phase of val curve."""
xs = np.array([log['step'] for log in val_logs])
ys = np.array([log[metric] for log in val_logs])
plt.xlabel('Steps')
plt.plot(xs, ys, label=curve_label, linewidth=0.75)
@@ -88,16 +83,23 @@ def plot_curve_helper(log_dicts, metrics, args, legend):
"""plot curves from log_dicts by metrics."""
num_metrics = len(metrics)
for i, log_dict in enumerate(log_dicts):
epochs = list(log_dict.keys())
for j, metric in enumerate(metrics):
for j, key in enumerate(metrics):
json_log = args.json_logs[i]
print(f'plot curve of {json_log}, metric is {metric}')
print(f'plot curve of {json_log}, metric is {key}')
curve_label = legend[i * num_metrics + j]
if any(m in metric for m in TEST_METRICS):
plot_phase_val(metric, log_dict, epochs, curve_label, json_log)
train_keys = {} if len(log_dict['train']) == 0 else set(
log_dict['train'][0].keys()) - {'step', 'epoch'}
val_keys = {} if len(log_dict['val']) == 0 else set(
log_dict['val'][0].keys()) - {'step'}
if key in val_keys:
plot_phase_val(key, log_dict['val'], curve_label)
elif key in train_keys:
plot_phase_train(key, log_dict['train'], curve_label)
else:
plot_phase_train(metric, log_dict, epochs, curve_label,
json_log)
raise ValueError(f'Invalid key "{key}", please choose from '
f'{set.union(train_keys, val_keys)}.')
plt.legend()
@@ -208,7 +210,10 @@ def main():
log_dicts = [load_json_log(json_log) for json_log in json_logs]
eval(args.task)(log_dicts, args)
if args.task == 'cal_train_time':
cal_train_time(log_dicts, args)
elif args.task == 'plot_curve':
plot_curve(log_dicts, args)
if __name__ == '__main__':


@@ -47,8 +47,7 @@ def parse_args():
def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None):
full_dir = osp.join(result_dir, folder_name)
vis = ClsVisualizer(
save_dir=full_dir, vis_backends=[dict(type='LocalVisBackend')])
vis = ClsVisualizer()
vis.dataset_meta = {'classes': dataset.CLASSES}
# save imgs
@@ -68,7 +67,8 @@ def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None):
raise ValueError('Cannot load images from the dataset infos.')
if rescale_factor is not None:
img = mmcv.imrescale(img, rescale_factor)
vis.add_datasample(name, img, data_sample)
vis.add_datasample(
name, img, data_sample, out_file=osp.join(full_dir, name + '.png'))
for k, v in result.items():
if isinstance(v, torch.Tensor):


@@ -3,26 +3,35 @@ import argparse
import mmengine
import rich
from mmengine import Config, DictAction
from mmengine import DictAction
from mmengine.evaluator import Evaluator
from mmengine.registry import init_default_scope
from mmcls.registry import METRICS
HELP_URL = (
'https://mmclassification.readthedocs.io/en/dev-1.x/useful_tools/'
'log_result_analysis.html#how-to-conduct-offline-metric-evaluation')
prog_description = f"""\
Evaluate metric of the results saved in pkl format.
The detailed usage can be found in {HELP_URL}
"""
def parse_args():
parser = argparse.ArgumentParser(description='Evaluate metric of the '
'results saved in pkl format')
parser.add_argument('config', help='Config of the model')
parser = argparse.ArgumentParser(description=prog_description)
parser.add_argument('pkl_results', help='Results in pickle format')
parser.add_argument(
'--cfg-options',
'--metric',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
action='append',
dest='metric_options',
help='The metric config, the key-value pair in xxx=yyy format will be '
'parsed as the metric config items. You can specify multiple metrics '
'by using multiple `--metric` options. For list type values, you can use '
'"key=[a,b]" or "key=a,b", and it also allows nested list/tuple '
'values, e.g. "key=[(a,b),(c,d)]".')
args = parser.parse_args()
return args
@@ -30,16 +39,21 @@ def parse_args():
def main():
args = parse_args()
# load config
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
if args.metric_options is None:
raise ValueError('Please specify at least one `--metric`. '
f'The detailed usage can be found in {HELP_URL}')
init_default_scope('mmcls') # Use mmcls as default scope.
test_metrics = []
for metric_option in args.metric_options:
metric_cfg = {}
for kv in metric_option:
k, v = kv.split('=', maxsplit=1)
metric_cfg[k] = DictAction._parse_iterable(v)
test_metrics.append(METRICS.build(metric_cfg))
predictions = mmengine.load(args.pkl_results)
evaluator = Evaluator(cfg.test_evaluator)
evaluator = Evaluator(test_metrics)
eval_results = evaluator.offline_evaluate(predictions, None)
rich.print(eval_results)
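
For clarity, this is roughly how one `--metric type=Accuracy topk=1,5` group is turned into a metric config by the loop above (a sketch assuming `DictAction._parse_iterable` keeps its current mmengine behaviour of parsing `1,5` into a tuple):

```python
from mmengine import DictAction

metric_option = ['type=Accuracy', 'topk=1,5']  # one `--metric` group from argparse
metric_cfg = {}
for kv in metric_option:
    key, value = kv.split('=', maxsplit=1)
    metric_cfg[key] = DictAction._parse_iterable(value)

print(metric_cfg)  # expected: {'type': 'Accuracy', 'topk': (1, 5)}
```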


@@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import rich
import rich.console
from mmengine import Config, DictAction
console = rich.console.Console()
def parse_args():
parser = argparse.ArgumentParser(description='Print the whole config')
@@ -29,7 +31,7 @@ def main():
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
rich.print(cfg.pretty_text)
console.print(cfg.pretty_text, markup=False)
if __name__ == '__main__':