[Docs] Add analysis&misc docs (#525)

* refactor analysis_log and add docs

* refactor analysis_log and fix lint

* improve docs

* improve docs

* improve docs

* fix bugs and refactor analysis_log

* rename analysis folder to analysis_tools

* fix broken link

* add result analysis docs

* add eval-metrics docs

* add misc doc

* fix lint

* improve docs

* improve misc docs

* fix docs

* Change the `eval_options` in `tools/analysis_tools/eval_metric.py` to
`metric_options`

* Improve tutorials

Co-authored-by: mzr1996 <mzr1996@163.com>
Ezra-Yu 2021-12-07 11:27:34 +08:00 committed by GitHub
parent b3a6cae522
commit fdb178303b
12 changed files with 550 additions and 6 deletions

(New binary image file added, 66 KiB; not shown.)


@@ -184,7 +184,7 @@ We provide lots of useful tools under `tools/` directory.
We provide a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch) to compute the FLOPs and params of a given model.
```shell
-python tools/get_flops.py ${CONFIG_FILE} [--shape ${INPUT_SHAPE}]
+python tools/analysis_tools/get_flops.py ${CONFIG_FILE} [--shape ${INPUT_SHAPE}]
```
You will get the result like this.


@@ -38,6 +38,8 @@ You can switch between Chinese and English documentation in the lower-left corne
tools/pytorch2torchscript.md
tools/model_serving.md
tools/visualization.md
+tools/analysis.md
+tools/miscellaneous.md
.. toctree::


@@ -0,0 +1,211 @@
# Analysis
<!-- TOC -->
- [Log Analysis](#log-analysis)
  - [Plot Curves](#plot-curves)
  - [Calculate Training Time](#calculate-training-time)
- [Result Analysis](#result-analysis)
  - [Evaluate Results](#evaluate-results)
  - [View Typical Results](#view-typical-results)
- [Model Complexity](#model-complexity)
- [FAQs](#faqs)
<!-- TOC -->
## Log Analysis
### Plot Curves
`tools/analysis_tools/analyze_logs.py` plots curves of given keys according to the log files.
<div align=center><img src="../_static/image/tools/analysis/analyze_log.jpg" style=" width: 75%; height: 30%; "></div>
```shell
python tools/analysis_tools/analyze_logs.py plot_curve \
${JSON_LOGS} \
[--keys ${KEYS}] \
[--title ${TITLE}] \
[--legend ${LEGEND}] \
[--backend ${BACKEND}] \
[--style ${STYLE}] \
[--out ${OUT_FILE}] \
[--window-size ${WINDOW_SIZE}]
```
**Description of all arguments**
- `json_logs` : The paths of the log files, separated by spaces if there are multiple.
- `--keys` : The fields of the logs to analyze, separated by spaces if there are multiple. Defaults to 'loss'.
- `--title` : The title of the figure. Defaults to the filename of the log.
- `--legend` : The names of the legend entries; the number of names must equal `len(${JSON_LOGS}) * len(${KEYS})`. Defaults to `"${JSON_LOG}-${KEYS}"`.
- `--backend` : The backend of matplotlib. Defaults to the backend auto-selected by matplotlib.
- `--style` : The style of the figure. Defaults to `whitegrid`.
- `--out` : The path of the output picture. If not set, the figure won't be saved.
- `--window-size` : The shape of the display window, in the format `'W*H'`. Defaults to `'12*7'`.
```{note}
The `--style` option depends on the `seaborn` package; please install it before using this option.
```
Examples:
- Plot the loss curve in training.
```shell
python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys loss --legend loss
```
- Plot the top-1 accuracy and top-5 accuracy curves, and save the figure to results.jpg.
```shell
python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys accuracy_top-1 accuracy_top-5 --legend top1 top5 --out results.jpg
```
- Compare the top-1 accuracy of two log files in the same figure.
```shell
python tools/analysis_tools/analyze_logs.py plot_curve log1.json log2.json --keys accuracy_top-1 --legend exp1 exp2
```
```{note}
The tool automatically determines whether to look up a key in the training logs or the validation logs.
Therefore, if you add a custom evaluation metric, please also add its key to `TEST_METRICS` in this tool.
```
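For reference, each line of these JSON logs is a standalone JSON record written by the text logger during training. A minimal sketch of plotting a loss curve by hand, assuming the usual training-log keys such as `mode`, `iter` and `loss` (the log path is illustrative):
```python
import json

import matplotlib.pyplot as plt

xs, ys = [], []
with open('work_dirs/some_exp/20200422_153324.log.json') as f:
    for line in f:
        record = json.loads(line)
        # Skip metadata lines and validation records.
        if record.get('mode') == 'train' and 'loss' in record:
            xs.append(record['iter'])
            ys.append(record['loss'])

plt.plot(xs, ys)
plt.xlabel('iter')
plt.ylabel('loss')
plt.savefig('loss_curve.jpg')
```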
### Calculate Training Time
`tools/analysis_tools/analyze_logs.py` can also calculate the training time according to the log files.
```shell
python tools/analysis_tools/analyze_logs.py cal_train_time \
${JSON_LOGS} \
[--include-outliers]
```
**Description of all arguments**:
- `json_logs` : The paths of the log files, separated by spaces if there are multiple.
- `--include-outliers` : If set, include the first iteration of each epoch (the first iteration is sometimes much slower).
Example:
```shell
python tools/analysis_tools/analyze_logs.py cal_train_time work_dirs/some_exp/20200422_153324.log.json
```
The output is expected to look like the following.
```text
-----Analyze train time of work_dirs/some_exp/20200422_153324.log.json-----
slowest epoch 68, average time is 0.3818
fastest epoch 1, average time is 0.3694
time std over epochs is 0.0020
average iter time: 0.3777 s/iter
```
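Under the hood, this subcommand averages the per-iteration `time` values recorded in the training log. A rough sketch of the same computation, assuming the standard `time` key in training records (the script's exact numbers may differ since it drops the first iteration of each epoch by default):
```python
import json
from collections import defaultdict

epoch_times = defaultdict(list)
with open('work_dirs/some_exp/20200422_153324.log.json') as f:
    for line in f:
        record = json.loads(line)
        if record.get('mode') == 'train' and 'time' in record:
            epoch_times[record['epoch']].append(record['time'])

for epoch, times in sorted(epoch_times.items()):
    print(f'epoch {epoch}: average iter time {sum(times) / len(times):.4f} s/iter')
```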
## Result Analysis
With the `--out` argument of `tools/test.py`, we can save the inference results of all samples to a file.
With this result file, we can do further analysis.
### Evaluate Results
`tools/analysis_tools/eval_metric.py` re-evaluates metrics from a saved result file, without running inference again.
```shell
python tools/analysis_tools/eval_metric.py \
${CONFIG} \
${RESULT} \
[--metrics ${METRICS}] \
[--cfg-options ${CFG_OPTIONS}] \
[--metric-options ${METRIC_OPTIONS}]
```
Description of all arguments:
- `config` : The path of the model config file.
- `result` : The output result file in json/pickle format from `tools/test.py`.
- `--metrics` : Evaluation metrics; the acceptable values depend on the dataset.
- `--cfg-options` : If specified, the key-value pair config will be merged into the config file. For more details, please refer to [Tutorial 1: Learn about Configs](../tutorials/config.md).
- `--metric-options` : If specified, the key-value pair arguments will be passed to the `metric_options` argument of the dataset's `evaluate` function.
```{note}
In `tools/test.py`, the `--out-items` option selects which kinds of results are saved. Please ensure the result file includes "class_scores" before using this tool.
```
**Examples**:
```shell
python tools/analysis_tools/eval_metric.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py your_result.pkl --metrics accuracy --metric-options "topk=(1,5)"
```
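If you are unsure whether a saved result file is usable with this tool, a quick sketch to inspect it before running the script (`mmcv.load` selects the loader by file extension):
```python
import mmcv

# Result file saved by `tools/test.py --out your_result.pkl`
results = mmcv.load('your_result.pkl')
print(results.keys())               # must contain 'class_scores' for this tool
print(len(results['class_scores']))  # one score vector per test sample
```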
### View Typical Results
`tools/analysis_tools/analyze_results.py` can save the images with the highest scores in successful or failed prediction.
```shell
python tools/analysis_tools/analyze_results.py \
${CONFIG} \
${RESULT} \
[--out-dir ${OUT_DIR}] \
[--topk ${TOPK}] \
[--cfg-options ${CFG_OPTIONS}]
```
**Description of all arguments**:
- `config` : The path of the model config file.
- `result` : The output result file in json/pickle format from `tools/test.py`.
- `--out-dir` : The directory to store output files.
- `--topk` : The number of images with the highest scores to save, for both successful and failed predictions. Defaults to 20.
- `--cfg-options` : If specified, the key-value pair config will be merged into the config file. For more details, please refer to [Tutorial 1: Learn about Configs](../tutorials/config.md).
```{note}
In `tools/test.py`, the `--out-items` option selects which kinds of results are saved. Please ensure the result file includes "pred_score", "pred_label" and "pred_class" before using this tool.
```
**Examples**:
```shell
python tools/analysis_tools/analyze_results.py \
configs/resnet/resnet50_b32x8_imagenet.py \
result.pkl \
--out-dir results \
--topk 50
```
## Model Complexity
### Get the FLOPs and params (experimental)
We provide a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch) to compute the FLOPs and params of a given model.
```shell
python tools/analysis_tools/get_flops.py ${CONFIG_FILE} [--shape ${INPUT_SHAPE}]
```
Description of all arguments:
- `config` : The path of the model config file.
- `--shape` : The input size, given as one or two values, e.g. `--shape 256` or `--shape 224 256`. Defaults to `224 224`.
You will get a result like this.
```text
==============================
Input shape: (3, 224, 224)
Flops: 4.12 GFLOPs
Params: 25.56 M
==============================
```
```{warning}
This tool is still experimental and we do not guarantee that the number is correct. You can use the result for simple comparisons, but double-check it before adopting it in technical reports or papers.
- FLOPs are related to the input shape while parameters are not. The default input shape is (1, 3, 224, 224).
- Some operators are not counted into FLOPs like GN and custom operators. Refer to [`mmcv.cnn.get_model_complexity_info()`](https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/utils/flops_counter.py) for details.
```
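For reference, a minimal sketch of what the script does under the hood, based on the `mmcv.cnn.get_model_complexity_info()` helper linked above; the `forward_dummy` redirection mirrors the script's approach, but treat the details as illustrative:
```python
from mmcv import Config
from mmcv.cnn import get_model_complexity_info

from mmcls.models import build_classifier

cfg = Config.fromfile('configs/resnet/resnet50_b32x8_imagenet.py')
model = build_classifier(cfg.model)
model.eval()

# Classifiers expose `forward_dummy` for a plain forward pass without losses,
# which is what the complexity counter needs.
if hasattr(model, 'forward_dummy'):
    model.forward = model.forward_dummy

flops, params = get_model_complexity_info(model, (3, 224, 224))
print(f'Flops: {flops}\nParams: {params}')
```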
## FAQs
- None


@@ -0,0 +1,59 @@
# MISCELLANEOUS
<!-- TOC -->
- [Print the entire config](#print-the-entire-config)
- [Verify Dataset](#verify-dataset)
- [FAQs](#faqs)
<!-- TOC -->
## Print the entire config
`tools/misc/print_config.py` prints the whole config verbatim, expanding all its imports.
```shell
python tools/misc/print_config.py ${CONFIG} [--cfg-options ${CFG_OPTIONS}]
```
Description of all arguments:
- `config` : The path of the model config file.
- `--cfg-options` : If specified, the key-value pair config will be merged into the config file. For more details, please refer to [Tutorial 1: Learn about Configs](../tutorials/config.md).
**Examples**:
```shell
python tools/misc/print_config.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py
```
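The same expansion is also available programmatically through mmcv's `Config`; a minimal sketch:
```python
from mmcv import Config

cfg = Config.fromfile('configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py')
print(cfg.pretty_text)  # the full config with all `_base_` files merged in
```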
## Verify Dataset
`tools/misc/verify_dataset.py` verifies a dataset by checking whether there are broken images in it.
```shell
python tools/misc/verify_dataset.py \
${CONFIG} \
[--out-path ${OUT_PATH}] \
[--phase ${PHASE}] \
[--num-process ${NUM_PROCESS}] \
[--cfg-options ${CFG_OPTIONS}]
```
**Description of all arguments**:
- `config` : The path of the model config file.
- `--out-path` : The path to save the verification result. Defaults to 'brokenfiles.log'.
- `--phase` : The phase of the dataset to verify; accepts "train", "test" and "val". Defaults to "train".
- `--num-process` : The number of processes to use. Defaults to 1.
- `--cfg-options` : If specified, the key-value pair config will be merged into the config file. For more details, please refer to [Tutorial 1: Learn about Configs](../tutorials/config.md).
**Examples**:
```shell
python tools/misc/verify_dataset.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py --out-path broken_imgs.log --phase val --num-process 8
```
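The core of such a check is simply trying to decode every image and recording the failures. A standalone sketch of the idea, using Pillow directly instead of the dataset pipeline (the data root below is hypothetical):
```python
import os

from PIL import Image


def find_broken_images(root):
    """Return paths of images that fail a cheap integrity check."""
    broken = []
    for dirpath, _, files in os.walk(root):
        for name in files:
            path = os.path.join(dirpath, name)
            try:
                with Image.open(path) as img:
                    img.verify()  # does not fully decode, but catches corruption
            except Exception:
                broken.append(path)
    return broken


print('\n'.join(find_broken_images('data/imagenet/val')))
```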
## FAQs
- None


@@ -2,7 +2,7 @@
MMClassification mainly uses python files as configs. The design of our configuration file system integrates modularity and inheritance, facilitating users to conduct various experiments. All configuration files are placed in the `configs` folder, which mainly contains the primitive configuration folder of `_base_` and many algorithm folders such as `resnet`, `swin_transformer`, `vision_transformer`, etc.
-If you wish to inspect the config file, you may run `python tools/analysis/print_config.py /PATH/TO/CONFIG` to see the complete config.
+If you wish to inspect the config file, you may run `python tools/misc/print_config.py /PATH/TO/CONFIG` to see the complete config.
<!-- TOC -->

(New binary image file added, 66 KiB; not shown.)


@@ -180,7 +180,7 @@ CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NA
We provide a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch) to compute the FLOPs and params of a given model.
```shell
-python tools/get_flops.py ${CONFIG_FILE} [--shape ${INPUT_SHAPE}]
+python tools/analysis_tools/get_flops.py ${CONFIG_FILE} [--shape ${INPUT_SHAPE}]
```
You will get a result like this:


@@ -38,6 +38,8 @@ You can switch between Chinese and English documentation in the lower-left corne
tools/pytorch2torchscript.md
tools/model_serving.md
tools/visualization.md
+tools/analysis.md
+tools/miscellaneous.md
.. toctree::


@@ -0,0 +1,211 @@
# Analysis
<!-- TOC -->
- [Log Analysis](#log-analysis)
  - [Plot Curves](#plot-curves)
  - [Calculate Training Time](#calculate-training-time)
- [Result Analysis](#result-analysis)
  - [Evaluate Results](#evaluate-results)
  - [View Typical Results](#view-typical-results)
- [Model Complexity](#model-complexity)
- [FAQs](#faqs)
<!-- TOC -->
## Log Analysis
### Plot Curves
Given a training log file, `tools/analysis_tools/analyze_logs.py` plots the curves of the specified keys.
<div align=center><img src="../_static/image/tools/analysis/analyze_log.jpg" style=" width: 75%; height: 30%; "></div>
```shell
python tools/analysis_tools/analyze_logs.py plot_curve \
${JSON_LOGS} \
[--keys ${KEYS}] \
[--title ${TITLE}] \
[--legend ${LEGEND}] \
[--backend ${BACKEND}] \
[--style ${STYLE}] \
[--out ${OUT_FILE}] \
[--window-size ${WINDOW_SIZE}]
```
**Description of all arguments**
- `json_logs` : The paths of the log files, separated by spaces if there are multiple.
- `--keys` : The fields of the logs to analyze, separated by spaces if there are multiple. Defaults to 'loss'.
- `--title` : The title of the figure. Defaults to the filename of the log.
- `--legend` : The names of the legend entries; the number of names must equal `len(${JSON_LOGS}) * len(${KEYS})`. Defaults to `"${JSON_LOG}-${KEYS}"`.
- `--backend` : The backend of matplotlib. Defaults to the backend auto-selected by matplotlib.
- `--style` : The style of the figure. Defaults to `whitegrid`.
- `--out` : The path of the output picture. If not set, the figure won't be saved.
- `--window-size` : The shape of the display window, in the format `'W*H'`. Defaults to `'12*7'`.
```{note}
The `--style` option depends on the `seaborn` package; please install it before using this option.
```
Examples:
- Plot the loss curve of a given log file.
```shell
python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys loss --legend loss
```
- Plot the top-1 and top-5 accuracy curves of a log file, and save the figure to results.jpg.
```shell
python tools/analysis_tools/analyze_logs.py plot_curve your_log_json --keys accuracy_top-1 accuracy_top-5 --legend top1 top5 --out results.jpg
```
- Plot the top-1 accuracy curves of two log files in the same figure.
```shell
python tools/analysis_tools/analyze_logs.py plot_curve log1.json log2.json --keys accuracy_top-1 --legend run1 run2
```
```{note}
The tool automatically determines whether to read a key from the training part or the validation part of the logs.
Therefore, if you add a custom evaluation metric, please also add its key to the `TEST_METRICS` variable in this tool.
```
### Calculate Training Time
`tools/analysis_tools/analyze_logs.py` can also calculate the training time according to the log files.
```shell
python tools/analysis_tools/analyze_logs.py cal_train_time \
${JSON_LOGS} \
[--include-outliers]
```
**Description of all arguments**
- `json_logs` : The paths of the log files, separated by spaces if there are multiple.
- `--include-outliers` : If set, include the first iteration of each epoch (the first iteration is sometimes much slower).
**Example**:
```shell
python tools/analysis_tools/analyze_logs.py cal_train_time work_dirs/some_exp/20200422_153324.log.json
```
The output is expected to look like the following:
```text
-----Analyze train time of work_dirs/some_exp/20200422_153324.log.json-----
slowest epoch 68, average time is 0.3818
fastest epoch 1, average time is 0.3694
time std over epochs is 0.0020
average iter time: 0.3777 s/iter
```
## Result Analysis
With the `--out` argument of `tools/test.py`, we can save the inference results of all samples to a file.
With this result file, we can do further analysis.
### Evaluate Results
`tools/analysis_tools/eval_metric.py` re-evaluates metrics from a saved result file.
```shell
python tools/analysis_tools/eval_metric.py \
${CONFIG} \
${RESULT} \
[--metrics ${METRICS}] \
[--cfg-options ${CFG_OPTIONS}] \
[--metric-options ${METRIC_OPTIONS}]
```
**Description of all arguments**
- `config` : The path of the model config file.
- `result` : The output result file from `tools/test.py`.
- `--metrics` : Evaluation metrics; the acceptable values depend on the dataset class.
- `--cfg-options` : If specified, the key-value pair config will be merged into the config file. For more details, please refer to [Tutorial 1: Learn about Configs](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html).
- `--metric-options` : If specified, these options will be passed to the `metric_options` argument of the dataset's `evaluate` function.
```{note}
`tools/test.py` 中,我们支持使用 `--out-items` 选项来选择保存哪些结果。为了使用本工具,请确保结果文件中包含 "class_scores"。
```
**Example**:
```shell
python tools/analysis_tools/eval_metric.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py ./result.pkl --metrics accuracy --metric-options "topk=(1,5)"
```
### View Typical Results
`tools/analysis_tools/analyze_results.py` saves the images with the highest scores among successful and failed predictions.
```shell
python tools/analysis_tools/analyze_results.py \
${CONFIG} \
${RESULT} \
[--out-dir ${OUT_DIR}] \
[--topk ${TOPK}] \
[--cfg-options ${CFG_OPTIONS}]
```
**Description of all arguments**
- `config` : The path of the model config file.
- `result` : The output result file from `tools/test.py`.
- `--out-dir` : The directory to store the analysis results.
- `--topk` : The number of images with the highest scores to save, for both successful and failed predictions. Defaults to `20`.
- `--cfg-options` : If specified, the key-value pair config will be merged into the config file. For more details, please refer to [Tutorial 1: Learn about Configs](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html).
```{note}
`tools/test.py` 中,我们支持使用 `--out-items` 选项来选择保存哪些结果。为了使用本工具,请确保结果文件中包含 "pred_score"、"pred_label" 和 "pred_class"。
```
**Example**:
```shell
python tools/analysis_tools/analyze_results.py \
configs/resnet/resnet50_xxxx.py \
result.pkl \
--out-dir results \
--topk 50
```
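To get a feeling for what the script selects, you can sort the saved predictions by confidence yourself. A rough sketch, assuming the result file contains the "pred_score" item mentioned in the note above:
```python
import numpy as np

import mmcv

res = mmcv.load('result.pkl')
scores = np.asarray(res['pred_score'])

order = np.argsort(scores)[::-1]  # most confident predictions first
for idx in order[:5]:
    print(idx, scores[idx])
```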
## Model Complexity
### Get the FLOPs and params (experimental)
We provide a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch) to compute the FLOPs and params of a given model.
```shell
python tools/analysis_tools/get_flops.py ${CONFIG_FILE} [--shape ${INPUT_SHAPE}]
```
**Description of all arguments**
- `config` : The path of the model config file.
- `--shape` : The input size, given as one or two values, e.g. `--shape 256` or `--shape 224 256`. Defaults to `224 224`.
You will get a result like this:
```text
==============================
Input shape: (3, 224, 224)
Flops: 4.12 GFLOPs
Params: 25.56 M
==============================
```
```{warning}
This tool is still experimental and we do not guarantee that the number is correct. You can use the result for simple comparisons, but double-check it before adopting it in technical reports or papers.
- FLOPs are related to the input shape while parameters are not. The default input shape is (1, 3, 224, 224).
- Some operators are not counted into FLOPs, like GN and custom operators. Refer to [`mmcv.cnn.get_model_complexity_info()`](https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/utils/flops_counter.py) for details.
```
## FAQs
- None


@@ -0,0 +1,59 @@
# Miscellaneous Tools
<!-- TOC -->
- [Print the entire config](#print-the-entire-config)
- [Verify Dataset](#verify-dataset)
- [FAQs](#faqs)
<!-- TOC -->
## Print the entire config
`tools/misc/print_config.py` resolves all imported variables and prints the entire config.
```shell
python tools/misc/print_config.py ${CONFIG} [--cfg-options ${CFG_OPTIONS}]
```
**Description of all arguments**
- `config` : The path of the model config file.
- `--cfg-options` : If specified, the key-value pair config will be merged into the config file. For more details, please refer to [Tutorial 1: Learn about Configs](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html).
**Example**:
```shell
python tools/misc/print_config.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py
```
## Verify Dataset
`tools/misc/verify_dataset.py` checks every image of a dataset to find broken ones.
```shell
python tools/misc/verify_dataset.py \
${CONFIG} \
[--out-path ${OUT_PATH}] \
[--phase ${PHASE}] \
[--num-process ${NUM_PROCESS}] \
[--cfg-options ${CFG_OPTIONS}]
```
**Description of all arguments**:
- `config` : The path of the model config file.
- `--out-path` : The path to save the verification result. Defaults to 'brokenfiles.log'.
- `--phase` : The phase of the dataset to verify; accepts "train", "test" and "val". Defaults to "train".
- `--num-process` : The number of processes to use. Defaults to 1.
- `--cfg-options` : If specified, the key-value pair config will be merged into the config file. For more details, please refer to [Tutorial 1: Learn about Configs](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html).
**Example**:
```shell
python tools/misc/verify_dataset.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py --out-path broken_imgs.log --phase val --num-process 8
```
## FAQs
- None


@@ -29,7 +29,7 @@ def parse_args():
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
-        '--eval-options',
+        '--metric-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
@@ -56,14 +56,14 @@ def main():
    dataset = build_dataset(cfg.data.test)
    pred_score = outputs['class_scores']
-    kwargs = {} if args.eval_options is None else args.eval_options
    eval_kwargs = cfg.get('evaluation', {}).copy()
    # hard-code way to remove EvalHook args
    for key in [
            'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 'rule'
    ]:
        eval_kwargs.pop(key, None)
-    eval_kwargs.update(dict(metric=args.metrics, **kwargs))
+    eval_kwargs.update(
+        dict(metric=args.metrics, metric_options=args.metric_options))
    print(dataset.evaluate(pred_score, **eval_kwargs))
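
For reference, a small sketch of how the renamed option is parsed by mmcv's `DictAction`; the expected output assumes a DictAction version that understands tuple literals, as the docs above rely on:
```python
import argparse

from mmcv import DictAction

parser = argparse.ArgumentParser()
parser.add_argument('--metric-options', nargs='+', action=DictAction)

# Same form as in the docs: `--metric-options "topk=(1,5)"`
args = parser.parse_args(['--metric-options', 'topk=(1,5)'])
print(args.metric_options)  # expected: {'topk': (1, 5)}
```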