[Feature] Add scheduler visualization from mmpretrain to mmocr (#1866)

* 2023/04/18 vis_scheduler_transportation_v1

* lint_fix

* Update docs/zh_cn/user_guides/useful_tools.md

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>

* Update docs/zh_cn/user_guides/useful_tools.md

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>

* Update docs/en/user_guides/useful_tools.md

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>

* 2023/04/25 add -d 100

---------

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>

docs/en/user_guides/useful_tools.md

@@ -1,10 +1,10 @@
# Useful Tools
## Analysis Tools
## Visualization Tools
### Dataset Visualization Tool
MMOCR provides a dataset visualization tool `tools/analysis_tools/browse_dataset.py` to help users troubleshoot possible dataset-related problems. You just need to specify the path to the training config (usually stored in `configs/textdet/dbnet/xxx.py`) or the dataset config (usually stored in `configs/textdet/_base_/datasets/xxx.py`), and the tool will automatically plot the transformed (or original) images and labels.
MMOCR provides a dataset visualization tool `tools/visualizations/browse_dataset.py` to help users troubleshoot possible dataset-related problems. You just need to specify the path to the training config (usually stored in `configs/textdet/dbnet/xxx.py`) or the dataset config (usually stored in `configs/textdet/_base_/datasets/xxx.py`), and the tool will automatically plot the transformed (or original) images and labels.
#### Usage
@@ -25,11 +25,11 @@ python tools/visualizations/browse_dataset.py \
| config | str | (required) Path to the config. |
| -o, --output-dir | str | If GUI is not available, specify an output path to save the visualization results. |
| -p, --phase | str | Phase of the dataset to visualize. Use "train", "test" or "val" if you just want to visualize the default split. It can also be a dataset variable name, which might be useful when a dataset split has multiple variants in the config. |
| -m, --mode | `original`, `transformed`, `pipeline` | Display mode: display original pictures or transformed pictures or comparison pictures. `original` only visualizes the original dataset & annotations; `transformed` shows the resulting images processed through all the transforms; `pipeline` shows all the intermediate images. Defaults to "transformed". |
| -t, --task | `auto`, `textdet`, `textrecog` | Specify the task type of the dataset. If `auto`, the task type will be inferred from the config. If the script is unable to infer the task type, you need to specify it manually. Defaults to `auto`. |
| -n, --show-number | int | The number of samples to visualize. If not specified, display all images in the dataset. |
| -i, --show-interval | float | Interval of visualization (s), defaults to 2. |
| --cfg-options | str | Override configs. [Example](./config.md#command-line-modification) |
#### Examples
@@ -37,7 +37,7 @@ The following example demonstrates how to use the tool to visualize the training
```Bash
# Example: Visualizing the training data used by the dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015 model
python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py
python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py
```
By default, the visualization mode is "transformed", and you will see the images & annotations being transformed by the pipeline:
@@ -49,7 +49,7 @@ By default, the visualization mode is "transformed", and you will see the images
If you just want to visualize the original dataset, simply set the mode to "original":
```Bash
python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m original
python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m original
```
<div align=center><img src="https://user-images.githubusercontent.com/22607038/206646570-382d0f26-908a-4ab4-b1a7-5cc31fa70c5f.jpg" style=" width: auto; height: 40%; "></div>
@@ -57,7 +57,7 @@ python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet
Or, to visualize the entire pipeline:
```Bash
python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m pipeline
python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m pipeline
```
<div align=center><img src="https://user-images.githubusercontent.com/22607038/206637571-287640c0-1f55-453f-a2fc-9f9734b9593f.jpg" style=" width: auto; height: 40%; "></div>
@@ -65,7 +65,7 @@ python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet
In addition, users can also visualize the original images and their corresponding labels of the dataset by specifying the path to the dataset config file, for example:
```Bash
python tools/analysis_tools/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py
python tools/visualizations/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py
```
Some datasets might have multiple variants. For example, the test split of `icdar2015` textrecog dataset has two variants, which the [base dataset config](/configs/textrecog/_base_/datasets/icdar2015.py) defines as follows:
@@ -85,11 +85,58 @@ icdar2015_1811_textrecog_test = dict(
In this case, you can specify the variant name to visualize the corresponding dataset:
```Bash
python tools/analysis_tools/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py -p icdar2015_1811_textrecog_test
python tools/visualizations/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py -p icdar2015_1811_textrecog_test
```
Based on this tool, users can easily verify if the annotation of a custom dataset is correct.
### Hyper-parameter Scheduler Visualization
This tool aims to help users check the hyper-parameter scheduler of the optimizer (without training). It supports visualizing the "learning rate", "momentum" and "weight decay" curves.
#### Usage
```bash
python tools/visualizations/vis_scheduler.py \
${CONFIG_FILE} \
[-p, --parameter ${PARAMETER_NAME}] \
[-d, --dataset-size ${DATASET_SIZE}] \
[-n, --ngpus ${NUM_GPUs}] \
[-s, --save-path ${SAVE_PATH}] \
[--title ${TITLE}] \
[--style ${STYLE}] \
[--window-size ${WINDOW_SIZE}] \
[--cfg-options]
```
**Description of all arguments**
- `config`: The path of a model config file.
- **`-p, --parameter`**: The parameter whose change curve is visualized; choose from "lr", "momentum" and "wd". Defaults to "lr".
- **`-d, --dataset-size`**: The size of the dataset. If set, `build_dataset` will be skipped and `${DATASET_SIZE}` will be used as the dataset size. Defaults to using the function `build_dataset`.
- **`-n, --ngpus`**: The number of GPUs used in training. Defaults to 1.
- **`-s, --save-path`**: The path to save the parameter curve plot. By default, the plot is not saved.
- `--title`: Title of the figure. Defaults to the config file name.
- `--style`: Style of the plot. Defaults to `whitegrid`.
- `--window-size`: The shape of the display window. Defaults to `12*7`. If specified, it must be in the format `'W*H'`.
- `--cfg-options`: Modifications to the configuration file; refer to [Learn about Configs](../user_guides/config.md).
```{note}
Loading annotations may be time-consuming; you can directly specify the dataset size with `-d, --dataset-size` to save time.
```
#### How to plot the learning rate curve without training
You can use the following command to plot the learning rate schedule used in the config `configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py`:
```bash
python tools/visualizations/vis_scheduler.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -d 100
```
<div align=center><img src="https://user-images.githubusercontent.com/43344034/232757392-b29b8e3a-77af-451c-8786-d3b4259ab388.png" style=" width: auto; height: 40%; "></div>
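The flags documented above can be combined freely. For instance, the following sketch plots the momentum curve instead of the learning rate, and saves the figure to disk without opening a window via the script's `--not-show` flag (the output filename `momentum_curve.png` is only illustrative):
```bash
python tools/visualizations/vis_scheduler.py \
    configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py \
    -p momentum \
    -d 100 \
    -s momentum_curve.png \
    --not-show
```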
## Analysis Tools
### Offline Evaluation Tool
For saved prediction results, we provide an offline evaluation script `tools/analysis_tools/offline_eval.py`. The following example demonstrates how to use this tool to evaluate the output of the "PSENet" model offline.
@@ -111,10 +158,10 @@ python tools/analysis_tools/offline_eval.py configs/textdet/psenet/psenet_r50_fp
In addition, based on this tool, users can also convert predictions obtained from other libraries into MMOCR-supported formats, then use MMOCR's built-in metrics to evaluate them.
| ARGS | Type | Description |
| ------------- | ----- | ------------------------------------------------------------------ |
| config | str | (required) Path to the config. |
| pkl_results | str | (required) The saved predictions. |
| --cfg-options | str | Override configs. [Example](./config.md#command-line-modification) |
### Calculate FLOPs and the Number of Parameters

docs/zh_cn/user_guides/useful_tools.md

@@ -1,10 +1,10 @@
# Useful Tools
## Analysis Tools
## Visualization Tools
### Dataset Visualization Tool
MMOCR provides a dataset visualization tool `tools/analysis_tools/browse_dataset.py` to help users troubleshoot possible dataset-related problems. You just need to specify the path to the training config (usually stored in `configs/textdet/dbnet/xxx.py`) or the dataset config (usually stored in `configs/textdet/_base_/datasets/xxx.py`), and the tool will automatically plot the images processed by the data pipeline (or the original images) together with their corresponding labels, based on the type of the given config.
MMOCR provides a dataset visualization tool `tools/visualizations/browse_dataset.py` to help users troubleshoot possible dataset-related problems. You just need to specify the path to the training config (usually stored in `configs/textdet/dbnet/xxx.py`) or the dataset config (usually stored in `configs/textdet/_base_/datasets/xxx.py`), and the tool will automatically plot the images processed by the data pipeline (or the original images) together with their corresponding labels, based on the type of the given config.
#### Supported Arguments
@@ -37,7 +37,7 @@ python tools/visualizations/browse_dataset.py \
```Bash
# Visualize the training data of the "dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015" model with default arguments
python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py
python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py
```
By default, the visualization mode is "transformed", and you will see the images and annotations transformed by the data pipeline:
@@ -49,7 +49,7 @@ python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet
If you only want to visualize the original dataset, simply set the mode to "original":
```Bash
python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m original
python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m original
```
<div align=center><img src="https://user-images.githubusercontent.com/22607038/206646570-382d0f26-908a-4ab4-b1a7-5cc31fa70c5f.jpg" style=" width: auto; height: 40%; "></div>
@@ -57,7 +57,7 @@ python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet
Alternatively, you can use the "pipeline" mode to visualize the intermediate results of the entire data pipeline:
```Bash
python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m pipeline
python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m pipeline
```
<div align=center><img src="https://user-images.githubusercontent.com/22607038/206637571-287640c0-1f55-453f-a2fc-9f9734b9593f.jpg" style=" width: auto; height: 40%; "></div>
@@ -65,7 +65,7 @@ python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet
In addition, users can visualize the original images of a dataset and their corresponding annotations by specifying the path to the dataset config file, for example:
```Bash
python tools/analysis_tools/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py
python tools/visualizations/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py
```
Some datasets may have multiple variants. For example, the [config file](/configs/textrecog/_base_/datasets/icdar2015.py) of the `icdar2015` text recognition dataset contains two test set variants, `icdar2015_textrecog_test` and `icdar2015_1811_textrecog_test`, as shown below:
@@ -85,11 +85,58 @@ icdar2015_1811_textrecog_test = dict(
In this case, you can specify the `-p` argument to visualize a particular variant. For example, use the following command to visualize the `icdar2015_1811_textrecog_test` variant:
```Bash
python tools/analysis_tools/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py -p icdar2015_1811_textrecog_test
python tools/visualizations/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py -p icdar2015_1811_textrecog_test
```
With this tool, users can easily browse the original images of a dataset and their corresponding annotations to verify whether the annotations are correct.
### Hyper-parameter Scheduler Visualization Tool
MMOCR provides a hyper-parameter scheduler visualization tool `tools/visualizations/vis_scheduler.py` to help users check the hyper-parameter scheduler of the optimizer (without training). It supports the learning rate, momentum and weight decay curves.
#### Usage
```bash
python tools/visualizations/vis_scheduler.py \
${CONFIG_FILE} \
[-p, --parameter ${PARAMETER_NAME}] \
[-d, --dataset-size ${DATASET_SIZE}] \
[-n, --ngpus ${NUM_GPUs}] \
[-s, --save-path ${SAVE_PATH}] \
[--title ${TITLE}] \
[--style ${STYLE}] \
[--window-size ${WINDOW_SIZE}] \
[--cfg-options]
```
**Description of all arguments**
- `config`: The path of a model config file.
- **`-p, --parameter`**: The name of the parameter to visualize; one of `"lr"`, `"momentum"` and `"wd"`. Defaults to `"lr"`.
- **`-d, --dataset-size`**: The size of the dataset. If specified, `build_dataset` will be skipped and this value will be used as the dataset size; by default, the size of the dataset obtained from `build_dataset` is used.
- **`-n, --ngpus`**: The number of GPUs used in training. Defaults to 1.
- **`-s, --save-path`**: The path to save the visualization plot. By default, the plot is not saved.
- `--title`: Title of the figure. Defaults to the config file name.
- `--style`: Style of the figure. Defaults to `whitegrid`.
- `--window-size`: The size of the display window. Defaults to `12*7`. If specified, it must follow the format `'W*H'`.
- `--cfg-options`: Modifications to the config file; refer to [Learn about Configs](../user_guides/config.md).
```{note}
Parsing annotations can be time-consuming for some datasets; you can directly specify the dataset size with `-d, --dataset-size` to save time.
```
#### How to plot the learning rate curve before training
You can use the following command to plot the learning rate curve that the config `configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py` will use:
```bash
python tools/visualizations/vis_scheduler.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -d 100
```
<div align=center><img src="https://user-images.githubusercontent.com/43344034/232755081-cad8fe62-349d-400a-bc38-7f5d17824011.png" style=" width: auto; height: 40%; "></div>
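Since `--cfg-options` is merged into the config before the simulation runs, the tool can also preview a modified schedule. A sketch, assuming the config uses an epoch-based training loop so that `train_cfg.max_epochs` is a valid key to override:
```bash
python tools/visualizations/vis_scheduler.py \
    configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py \
    -d 100 \
    --cfg-options train_cfg.max_epochs=600
```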
## Analysis Tools
### Offline Evaluation Tool
For saved prediction results, we provide an offline evaluation script `tools/analysis_tools/offline_eval.py`. For example, the following code demonstrates how to use this tool to evaluate the output of the "PSENet" model offline:

tools/visualizations/vis_scheduler.py

@@ -0,0 +1,286 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import os.path as osp
import re
from pathlib import Path
from unittest.mock import MagicMock
import matplotlib.pyplot as plt
import rich
import torch.nn as nn
from mmengine.config import Config, DictAction
from mmengine.hooks import Hook
from mmengine.model import BaseModel
from mmengine.registry import init_default_scope
from mmengine.runner import Runner
from mmengine.visualization import Visualizer
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn
from mmocr.registry import DATASETS
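
# The Runner requires a model instance, but no real computation is needed:
# this stub's train_step does nothing, so a simulated "training" run only
# advances the hooks and parameter schedulers.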
class SimpleModel(BaseModel):
    """Simple model that does nothing in train_step."""

    def __init__(self):
        super().__init__()
        self.data_preprocessor = nn.Identity()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, inputs, data_samples, mode='tensor'):
        pass

    def train_step(self, data, optim_wrapper):
        pass
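
# Hook that records the learning rate, momentum and weight decay of the first
# parameter group after every training iteration, showing a progress bar.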
class ParamRecordHook(Hook):

    def __init__(self, by_epoch):
        super().__init__()
        self.by_epoch = by_epoch
        self.lr_list = []
        self.momentum_list = []
        self.wd_list = []
        self.task_id = 0
        self.progress = Progress(BarColumn(), MofNCompleteColumn(),
                                 TextColumn('{task.description}'))

    def before_train(self, runner):
        if self.by_epoch:
            total = runner.train_loop.max_epochs
            self.task_id = self.progress.add_task(
                'epochs', start=True, total=total)
        else:
            total = runner.train_loop.max_iters
            self.task_id = self.progress.add_task(
                'iters', start=True, total=total)
        self.progress.start()

    def after_train_epoch(self, runner):
        if self.by_epoch:
            self.progress.update(self.task_id, advance=1)

    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        if not self.by_epoch:
            self.progress.update(self.task_id, advance=1)
        self.lr_list.append(runner.optim_wrapper.get_lr()['lr'][0])
        self.momentum_list.append(
            runner.optim_wrapper.get_momentum()['momentum'][0])
        self.wd_list.append(
            runner.optim_wrapper.param_groups[0]['weight_decay'])

    def after_train(self, runner):
        self.progress.stop()

def parse_args():
    parser = argparse.ArgumentParser(
        description='Visualize a hyper-parameter scheduler')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '-p',
        '--parameter',
        type=str,
        default='lr',
        choices=['lr', 'momentum', 'wd'],
        help='The parameter to visualize its change curve, choose from '
        '"lr", "wd" and "momentum". Defaults to "lr".')
    parser.add_argument(
        '-d',
        '--dataset-size',
        type=int,
        help='The size of the dataset. If specified, `build_dataset` will '
        'be skipped and this size will be used as the dataset size.')
    parser.add_argument(
        '-n',
        '--ngpus',
        type=int,
        default=1,
        help='The number of GPUs used in training.')
    parser.add_argument(
        '-s',
        '--save-path',
        type=Path,
        help='The path to save the parameter curve plot')
    parser.add_argument(
        '--log-level',
        default='WARNING',
        help='The log level of the handler and logger. Defaults to '
        'WARNING.')
    parser.add_argument('--title', type=str, help='title of figure')
    parser.add_argument(
        '--style', type=str, default='whitegrid', help='style of plt')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--window-size',
        default='12*7',
        help='Size of the window to display images, in format of "$W*$H".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()

    if args.window_size != '':
        assert re.match(r'\d+\*\d+', args.window_size), \
            "'window-size' must be in format 'W*H'."

    return args

def plot_curve(lr_list, args, param_name, iters_per_epoch, by_epoch=True):
    """Plot the parameter curve against training iterations."""
    try:
        import seaborn as sns
        sns.set_style(args.style)
    except ImportError:
        pass

    wind_w, wind_h = args.window_size.split('*')
    wind_w, wind_h = int(wind_w), int(wind_h)
    plt.figure(figsize=(wind_w, wind_h))

    ax: plt.Axes = plt.subplot()
    ax.plot(lr_list, linewidth=1)

    if by_epoch:
        ax.xaxis.tick_top()
        ax.set_xlabel('Iters')
        ax.xaxis.set_label_position('top')
        sec_ax = ax.secondary_xaxis(
            'bottom',
            functions=(lambda x: x / iters_per_epoch,
                       lambda y: y * iters_per_epoch))
        sec_ax.set_xlabel('Epochs')
    else:
        plt.xlabel('Iters')
    plt.ylabel(param_name)

    if args.title is None:
        plt.title(f'{osp.basename(args.config)} {param_name} curve')
    else:
        plt.title(args.title)
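
# Run a simulated training loop: all default hooks except the param scheduler
# are disabled, so the run is cheap and merely steps the schedulers while
# ParamRecordHook collects the values.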
def simulate_train(data_loader, cfg, by_epoch):
    model = SimpleModel()
    param_record_hook = ParamRecordHook(by_epoch=by_epoch)
    default_hooks = dict(
        param_scheduler=cfg.default_hooks['param_scheduler'],
        runtime_info=None,
        timer=None,
        logger=None,
        checkpoint=None,
        sampler_seed=None,
        param_record=param_record_hook)

    runner = Runner(
        model=model,
        work_dir=cfg.work_dir,
        train_dataloader=data_loader,
        train_cfg=cfg.train_cfg,
        log_level=cfg.log_level,
        optim_wrapper=cfg.optim_wrapper,
        param_scheduler=cfg.param_scheduler,
        default_scope=cfg.default_scope,
        default_hooks=default_hooks,
        visualizer=MagicMock(spec=Visualizer),
        custom_hooks=cfg.get('custom_hooks', None))

    runner.train()

    param_dict = dict(
        lr=param_record_hook.lr_list,
        momentum=param_record_hook.momentum_list,
        wd=param_record_hook.wd_list)

    return param_dict

def build_dataset(cfg):
    return DATASETS.build(cfg)

def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    init_default_scope(cfg.get('default_scope', 'mmocr'))

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    if cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    cfg.log_level = args.log_level

    # make sure the directory of the save path exists
    if args.save_path and not args.save_path.parent.exists():
        raise FileNotFoundError(
            f'The save path is {args.save_path}, and the directory '
            f"'{args.save_path.parent}' does not exist.")

    # print the parameter scheduler config
    print('Param_scheduler:')
    rich.print_json(json.dumps(cfg.param_scheduler))

    # prepare the data loader
    batch_size = cfg.train_dataloader.batch_size * args.ngpus

    if 'by_epoch' in cfg.train_cfg:
        by_epoch = cfg.train_cfg.get('by_epoch')
    elif 'type' in cfg.train_cfg:
        by_epoch = cfg.train_cfg.get('type') == 'EpochBasedTrainLoop'
    else:
        raise ValueError('Please set `train_cfg`.')

    if args.dataset_size is None and by_epoch:
        dataset_size = len(build_dataset(cfg.train_dataloader.dataset))
    else:
        dataset_size = args.dataset_size or batch_size
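
    # A minimal stand-in for a real dataloader: the Runner only needs its
    # length (iterations per epoch) and a `dataset.metainfo` attribute.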
    class FakeDataloader(list):
        dataset = MagicMock(metainfo=None)

    data_loader = FakeDataloader(range(dataset_size // batch_size))
    dataset_info = (
        f'\nDataset infos:'
        f'\n - Dataset size: {dataset_size}'
        f'\n - Batch size per GPU: {cfg.train_dataloader.batch_size}'
        f'\n - Number of GPUs: {args.ngpus}'
        f'\n - Total batch size: {batch_size}')
    if by_epoch:
        dataset_info += f'\n - Iterations per epoch: {len(data_loader)}'
    rich.print(dataset_info + '\n')

    # simulate the training process and record the parameter values
    param_dict = simulate_train(data_loader, cfg, by_epoch)
    param_list = param_dict[args.parameter]

    if args.parameter == 'lr':
        param_name = 'Learning Rate'
    elif args.parameter == 'momentum':
        param_name = 'Momentum'
    else:
        param_name = 'Weight Decay'
    plot_curve(param_list, args, param_name, len(data_loader), by_epoch)

    if args.save_path:
        plt.savefig(args.save_path)
        print(f'\nThe {param_name} graph is saved at {args.save_path}')

    if not args.not_show:
        plt.show()

if __name__ == '__main__':
    main()