diff --git a/docs/en/user_guides/useful_tools.md b/docs/en/user_guides/useful_tools.md
index 9ba6a49d..9828198f 100644
--- a/docs/en/user_guides/useful_tools.md
+++ b/docs/en/user_guides/useful_tools.md
@@ -1,10 +1,10 @@
# Useful Tools
-## Analysis Tools
+## Visualization Tools
### Dataset Visualization Tool
-MMOCR provides a dataset visualization tool `tools/analysis_tools/browse_datasets.py` to help users troubleshoot possible dataset-related problems. You just need to specify the path to the training config (usually stored in `configs/textdet/dbnet/xxx.py`) or the dataset config (usually stored in `configs/textdet/_base_/datasets/xxx.py`), and the tool will automatically plots the transformed (or original) images and labels.
+MMOCR provides a dataset visualization tool `tools/visualizations/browse_datasets.py` to help users troubleshoot possible dataset-related problems. You just need to specify the path to the training config (usually stored in `configs/textdet/dbnet/xxx.py`) or the dataset config (usually stored in `configs/textdet/_base_/datasets/xxx.py`), and the tool will automatically plots the transformed (or original) images and labels.
#### Usage
@@ -25,11 +25,11 @@ python tools/visualizations/browse_dataset.py \
| config | str | (required) Path to the config. |
| -o, --output-dir | str | If GUI is not available, specifying an output path to save the visualization results. |
| -p, --phase | str | Phase of dataset to visualize. Use "train", "test" or "val" if you just want to visualize the default split. It's also possible to be a dataset variable name, which might be useful when a dataset split has multiple variants in the config. |
-| -m, --mode | `original`, `transformed`, `pipeline` | Display mode: display original pictures or transformed pictures or comparison pictures. `original` only visualizes the original dataset & annotations; `transformed` shows the resulting images processed through all the transforms; `pipeline` shows all the intermediate images. Defaults to "transformed". |
+| -m, --mode | `original`, `transformed`, `pipeline` | Display mode: display original pictures or transformed pictures or comparison pictures.`original` only visualizes the original dataset & annotations; `transformed` shows the resulting images processed through all the transforms; `pipeline` shows all the intermediate images. Defaults to "transformed". |
| -t, --task | `auto`, `textdet`, `textrecog` | Specify the task type of the dataset. If `auto`, the task type will be inferred from the config. If the script is unable to infer the task type, you need to specify it manually. Defaults to `auto`. |
| -n, --show-number | int | The number of samples to visualized. If not specified, display all images in the dataset. |
| -i, --show-interval | float | Interval of visualization (s), defaults to 2. |
-| --cfg-options | float | Override configs. [Example](./config.md#command-line-modification) |
+| --cfg-options | float | Override configs.[Example](./config.md#command-line-modification) |
#### Examples
@@ -37,7 +37,7 @@ The following example demonstrates how to use the tool to visualize the training
```Bash
# Example: Visualizing the training data used by dbnet_r50dcn_v2_fpnc_1200e_icadr2015 model
-python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py
+python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py
```
By default, the visualization mode is "transformed", and you will see the images & annotations being transformed by the pipeline:
@@ -49,7 +49,7 @@ By default, the visualization mode is "transformed", and you will see the images
If you just want to visualize the original dataset, simply set the mode to "original":
```Bash
-python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m original
+python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m original
```

@@ -57,7 +57,7 @@ python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet
Or, to visualize the entire pipeline:
```Bash
-python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m pipeline
+python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m pipeline
```

@@ -65,7 +65,7 @@ python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet
In addition, users can also visualize the original images and their corresponding labels of the dataset by specifying the path to the dataset config file, for example:
```Bash
-python tools/analysis_tools/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py
+python tools/visualizations/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py
```
Some datasets might have multiple variants. For example, the test split of `icdar2015` textrecog dataset has two variants, which the [base dataset config](/configs/textrecog/_base_/datasets/icdar2015.py) defines as follows:
@@ -85,11 +85,58 @@ icdar2015_1811_textrecog_test = dict(
In this case, you can specify the variant name to visualize the corresponding dataset:
```Bash
-python tools/analysis_tools/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py -p icdar2015_1811_textrecog_test
+python tools/visualizations/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py -p icdar2015_1811_textrecog_test
```
Based on this tool, users can easily verify if the annotation of a custom dataset is correct.
+### Hyper-parameter Scheduler Visualization
+
+This tool aims to help the user to check the hyper-parameter scheduler of the optimizer (without training), which support the "learning rate" or "momentum"
+
+#### Introduce the scheduler visualization tool
+
+```bash
+python tools/visualizations/vis_scheduler.py \
+ ${CONFIG_FILE} \
+ [-p, --parameter ${PARAMETER_NAME}] \
+ [-d, --dataset-size ${DATASET_SIZE}] \
+ [-n, --ngpus ${NUM_GPUs}] \
+ [-s, --save-path ${SAVE_PATH}] \
+ [--title ${TITLE}] \
+ [--style ${STYLE}] \
+ [--window-size ${WINDOW_SIZE}] \
+ [--cfg-options]
+```
+
+**Description of all arguments**:
+
+- `config`: The path of a model config file.
+- **`-p, --parameter`**: The param to visualize its change curve, choose from "lr" and "momentum". Default to use "lr".
+- **`-d, --dataset-size`**: The size of the datasets. If set,`build_dataset` will be skipped and `${DATASET_SIZE}` will be used as the size. Default to use the function `build_dataset`.
+- **`-n, --ngpus`**: The number of GPUs used in training, default to be 1.
+- **`-s, --save-path`**: The learning rate curve plot save path, default not to save.
+- `--title`: Title of figure. If not set, default to be config file name.
+- `--style`: Style of plt. If not set, default to be `whitegrid`.
+- `--window-size`: The shape of the display window. If not specified, it will be set to `12*7`. If used, it must be in the format `'W*H'`.
+- `--cfg-options`: Modifications to the configuration file, refer to [Learn about Configs](../user_guides/config.md).
+
+```{note}
+Loading annotations maybe consume much time, you can directly specify the size of the dataset with `-d, dataset-size` to save time.
+```
+
+#### How to plot the learning rate curve without training
+
+You can use the following command to plot the step learning rate schedule used in the config `configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py`:
+
+```bash
+python tools/visualizations/vis_scheduler.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -d 100
+```
+
+
+
+## Analysis Tools
+
### Offline Evaluation Tool
For saved prediction results, we provide an offline evaluation script `tools/analysis_tools/offline_eval.py`. The following example demonstrates how to use this tool to evaluate the output of the "PSENet" model offline.
@@ -110,11 +157,11 @@ python tools/analysis_tools/offline_eval.py configs/textdet/psenet/psenet_r50_fp
In addition, based on this tool, users can also convert predictions obtained from other libraries into MMOCR-supported formats, then use MMOCR's built-in metrics to evaluate them.
-| ARGS | Type | Description |
-| ------------- | ----- | ------------------------------------------------------------------ |
-| config | str | (required) Path to the config. |
-| pkl_results | str | (required) The saved predictions. |
-| --cfg-options | float | Override configs. [Example](./config.md#command-line-modification) |
+| ARGS | Type | Description |
+| ------------- | ----- | ----------------------------------------------------------------- |
+| config | str | (required) Path to the config. |
+| pkl_results | str | (required) The saved predictions. |
+| --cfg-options | float | Override configs.[Example](./config.md#command-line-modification) |
### Calculate FLOPs and the Number of Parameters
diff --git a/docs/zh_cn/user_guides/useful_tools.md b/docs/zh_cn/user_guides/useful_tools.md
index f72322ce..2a607245 100644
--- a/docs/zh_cn/user_guides/useful_tools.md
+++ b/docs/zh_cn/user_guides/useful_tools.md
@@ -1,10 +1,10 @@
# 常用工具
-## 分析工具
+## 可视化工具
### 数据集可视化工具
-MMOCR 提供了数据集可视化工具 `tools/analysis_tools/browse_datasets.py` 以辅助用户排查可能遇到的数据集相关的问题。用户只需要指定所使用的训练配置文件(通常存放在如 `configs/textdet/dbnet/xxx.py` 文件中)或数据集配置(通常存放在 `configs/textdet/_base_/datasets/xxx.py` 文件中)路径。该工具将依据输入的配置文件类型自动将经过数据流水线(data pipeline)处理过的图像及其对应的标签,或原始图片及其对应的标签绘制出来。
+MMOCR 提供了数据集可视化工具 `tools/visualizations/browse_datasets.py` 以辅助用户排查可能遇到的数据集相关的问题。用户只需要指定所使用的训练配置文件(通常存放在如 `configs/textdet/dbnet/xxx.py` 文件中)或数据集配置(通常存放在 `configs/textdet/_base_/datasets/xxx.py` 文件中)路径。该工具将依据输入的配置文件类型自动将经过数据流水线(data pipeline)处理过的图像及其对应的标签,或原始图片及其对应的标签绘制出来。
#### 支持参数
@@ -37,7 +37,7 @@ python tools/visualizations/browse_dataset.py \
```Bash
# 使用默认参数可视化 "dbnet_r50dcn_v2_fpnc_1200e_icadr2015" 模型的训练数据
-python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py
+python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py
```
默认情况下,可视化模式为 "transformed",您将看到经由数据流水线变换过后的图像和标注:
@@ -49,7 +49,7 @@ python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet
如果您只想可视化原始数据集,只需将模式设置为 "original":
```Bash
-python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m original
+python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m original
```

@@ -57,7 +57,7 @@ python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet
或者,您也可以使用 "pipeline" 模式来可视化整个数据流水线的中间结果:
```Bash
-python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m pipeline
+python tools/visualizations/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m pipeline
```

@@ -65,7 +65,7 @@ python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet
另外,用户还可以通过指定数据集配置文件的路径来可视化数据集的原始图像及其对应的标注,例如:
```Bash
-python tools/analysis_tools/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py
+python tools/visualizations/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py
```
部分数据集可能有多个变体。例如,`icdar2015` 文本识别数据集的[配置文件](/configs/textrecog/_base_/datasets/icdar2015.py)中包含两个测试集变体,分别为 `icdar2015_textrecog_test` 和 `icdar2015_1811_textrecog_test`,如下所示:
@@ -85,11 +85,58 @@ icdar2015_1811_textrecog_test = dict(
在这种情况下,用户可以通过指定 `-p` 参数来可视化不同的变体,例如,使用以下命令可视化 `icdar2015_1811_textrecog_test` 变体:
```Bash
-python tools/analysis_tools/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py -p icdar2015_1811_textrecog_test
+python tools/visualizations/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py -p icdar2015_1811_textrecog_test
```
基于该工具,用户可以轻松地查看数据集的原始图像及其对应的标注,以便于检查数据集的标注是否正确。
+### 优化器参数策略可视化工具
+
+MMOCR提供了优化器参数可视化工具 `tools/visualizations/vis_scheduler.py` 以辅助用户排查优化器的超参数调度器(无需训练),支持学习率(learning rate)和动量(momentum)。
+
+#### 工具简介
+
+```bash
+python tools/visualizations/vis_scheduler.py \
+ ${CONFIG_FILE} \
+ [-p, --parameter ${PARAMETER_NAME}] \
+ [-d, --dataset-size ${DATASET_SIZE}] \
+ [-n, --ngpus ${NUM_GPUs}] \
+ [-s, --save-path ${SAVE_PATH}] \
+ [--title ${TITLE}] \
+ [--style ${STYLE}] \
+ [--window-size ${WINDOW_SIZE}] \
+ [--cfg-options]
+```
+
+**所有参数的说明**:
+
+- `config` : 模型配置文件的路径。
+- **`-p, parameter`**: 可视化参数名,只能为 `["lr", "momentum"]` 之一, 默认为 `"lr"`.
+- **`-d, --dataset-size`**: 数据集的大小。如果指定,`build_dataset` 将被跳过并使用这个大小作为数据集大小,默认使用 `build_dataset` 所得数据集的大小。
+- **`-n, --ngpus`**: 使用 GPU 的数量, 默认为1。
+- **`-s, --save-path`**: 保存的可视化图片的路径,默认不保存。
+- `--title`: 可视化图片的标题,默认为配置文件名。
+- `--style`: 可视化图片的风格,默认为 `whitegrid`。
+- `--window-size`: 可视化窗口大小,如果没有指定,默认为 `12*7`。如果需要指定,按照格式 \`W\*H'。
+- `--cfg-options`: 对配置文件的修改,参考[学习配置文件](../user_guides/config.md)。
+
+```{note}
+部分数据集在解析标注阶段比较耗时,可直接将 `-d, dataset-size` 指定数据集的大小,以节约时间。
+```
+
+#### 如何在开始训练前可视化学习率曲线
+
+你可以使用如下命令来绘制配置文件 `configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py` 将会使用的变化率曲线:
+
+```bash
+python tools/visualizations/vis_scheduler.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -d 100
+```
+
+
+
+## 分析工具
+
### 离线评测工具
对于已保存的预测结果,我们提供了离线评测脚本 `tools/analysis_tools/offline_eval.py`。例如,以下代码演示了如何使用该工具对 "PSENet" 模型的输出结果进行离线评估:
diff --git a/tools/analysis_tools/browse_dataset.py b/tools/visualizations/browse_dataset.py
similarity index 100%
rename from tools/analysis_tools/browse_dataset.py
rename to tools/visualizations/browse_dataset.py
diff --git a/tools/visualizations/vis_scheduler.py b/tools/visualizations/vis_scheduler.py
new file mode 100644
index 00000000..4a2d4a3c
--- /dev/null
+++ b/tools/visualizations/vis_scheduler.py
@@ -0,0 +1,286 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import json
+import os.path as osp
+import re
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import matplotlib.pyplot as plt
+import rich
+import torch.nn as nn
+from mmengine.config import Config, DictAction
+from mmengine.hooks import Hook
+from mmengine.model import BaseModel
+from mmengine.registry import init_default_scope
+from mmengine.runner import Runner
+from mmengine.visualization import Visualizer
+from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn
+
+from mmocr.registry import DATASETS
+
+
+class SimpleModel(BaseModel):
+ """simple model that do nothing in train_step."""
+
+ def __init__(self):
+ super(SimpleModel, self).__init__()
+ self.data_preprocessor = nn.Identity()
+ self.conv = nn.Conv2d(1, 1, 1)
+
+ def forward(self, inputs, data_samples, mode='tensor'):
+ pass
+
+ def train_step(self, data, optim_wrapper):
+ pass
+
+
+class ParamRecordHook(Hook):
+
+ def __init__(self, by_epoch):
+ super().__init__()
+ self.by_epoch = by_epoch
+ self.lr_list = []
+ self.momentum_list = []
+ self.wd_list = []
+ self.task_id = 0
+ self.progress = Progress(BarColumn(), MofNCompleteColumn(),
+ TextColumn('{task.description}'))
+
+ def before_train(self, runner):
+ if self.by_epoch:
+ total = runner.train_loop.max_epochs
+ self.task_id = self.progress.add_task(
+ 'epochs', start=True, total=total)
+ else:
+ total = runner.train_loop.max_iters
+ self.task_id = self.progress.add_task(
+ 'iters', start=True, total=total)
+ self.progress.start()
+
+ def after_train_epoch(self, runner):
+ if self.by_epoch:
+ self.progress.update(self.task_id, advance=1)
+
+ def after_train_iter(self, runner, batch_idx, data_batch, outputs):
+ if not self.by_epoch:
+ self.progress.update(self.task_id, advance=1)
+ self.lr_list.append(runner.optim_wrapper.get_lr()['lr'][0])
+ self.momentum_list.append(
+ runner.optim_wrapper.get_momentum()['momentum'][0])
+ self.wd_list.append(
+ runner.optim_wrapper.param_groups[0]['weight_decay'])
+
+ def after_train(self, runner):
+ self.progress.stop()
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Visualize a Dataset Pipeline')
+ parser.add_argument('config', help='config file path')
+ parser.add_argument(
+ '-p',
+ '--parameter',
+ type=str,
+ default='lr',
+ choices=['lr', 'momentum', 'wd'],
+ help='The parameter to visualize its change curve, choose from'
+ '"lr", "wd" and "momentum". Defaults to "lr".')
+ parser.add_argument(
+ '-d',
+ '--dataset-size',
+ type=int,
+ help='The size of the dataset. If specify, `build_dataset` will '
+ 'be skipped and use this size as the dataset size.')
+ parser.add_argument(
+ '-n',
+ '--ngpus',
+ type=int,
+ default=1,
+ help='The number of GPUs used in training.')
+ parser.add_argument(
+ '-s',
+ '--save-path',
+ type=Path,
+ help='The learning rate curve plot save path')
+ parser.add_argument(
+ '--log-level',
+ default='WARNING',
+ help='The log level of the handler and logger. Defaults to '
+ 'WARNING.')
+ parser.add_argument('--title', type=str, help='title of figure')
+ parser.add_argument(
+ '--style', type=str, default='whitegrid', help='style of plt')
+ parser.add_argument('--not-show', default=False, action='store_true')
+ parser.add_argument(
+ '--window-size',
+ default='12*7',
+ help='Size of the window to display images, in format of "$W*$H".')
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file. If the value to '
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+ 'Note that the quotation marks are necessary and that no white space '
+ 'is allowed.')
+ args = parser.parse_args()
+ if args.window_size != '':
+ assert re.match(r'\d+\*\d+', args.window_size), \
+ "'window-size' must be in format 'W*H'."
+
+ return args
+
+
+def plot_curve(lr_list, args, param_name, iters_per_epoch, by_epoch=True):
+ """Plot learning rate vs iter graph."""
+ try:
+ import seaborn as sns
+ sns.set_style(args.style)
+ except ImportError:
+ pass
+
+ wind_w, wind_h = args.window_size.split('*')
+ wind_w, wind_h = int(wind_w), int(wind_h)
+ plt.figure(figsize=(wind_w, wind_h))
+
+ ax: plt.Axes = plt.subplot()
+ ax.plot(lr_list, linewidth=1)
+
+ if by_epoch:
+ ax.xaxis.tick_top()
+ ax.set_xlabel('Iters')
+ ax.xaxis.set_label_position('top')
+ sec_ax = ax.secondary_xaxis(
+ 'bottom',
+ functions=(lambda x: x / iters_per_epoch,
+ lambda y: y * iters_per_epoch))
+ sec_ax.set_xlabel('Epochs')
+ else:
+ plt.xlabel('Iters')
+ plt.ylabel(param_name)
+
+ if args.title is None:
+ plt.title(f'{osp.basename(args.config)} {param_name} curve')
+ else:
+ plt.title(args.title)
+
+
+def simulate_train(data_loader, cfg, by_epoch):
+ model = SimpleModel()
+ param_record_hook = ParamRecordHook(by_epoch=by_epoch)
+ default_hooks = dict(
+ param_scheduler=cfg.default_hooks['param_scheduler'],
+ runtime_info=None,
+ timer=None,
+ logger=None,
+ checkpoint=None,
+ sampler_seed=None,
+ param_record=param_record_hook)
+
+ runner = Runner(
+ model=model,
+ work_dir=cfg.work_dir,
+ train_dataloader=data_loader,
+ train_cfg=cfg.train_cfg,
+ log_level=cfg.log_level,
+ optim_wrapper=cfg.optim_wrapper,
+ param_scheduler=cfg.param_scheduler,
+ default_scope=cfg.default_scope,
+ default_hooks=default_hooks,
+ visualizer=MagicMock(spec=Visualizer),
+ custom_hooks=cfg.get('custom_hooks', None))
+
+ runner.train()
+
+ param_dict = dict(
+ lr=param_record_hook.lr_list,
+ momentum=param_record_hook.momentum_list,
+ wd=param_record_hook.wd_list)
+
+ return param_dict
+
+
+def build_dataset(cfg):
+ return DATASETS.build(cfg)
+
+
+def main():
+ args = parse_args()
+ cfg = Config.fromfile(args.config)
+
+ init_default_scope(cfg.get('default_scope', 'mmocr'))
+
+ if args.cfg_options is not None:
+ cfg.merge_from_dict(args.cfg_options)
+ if cfg.get('work_dir', None) is None:
+ # use config filename as default work_dir if cfg.work_dir is None
+ cfg.work_dir = osp.join('./work_dirs',
+ osp.splitext(osp.basename(args.config))[0])
+
+ cfg.log_level = args.log_level
+
+ # make sure save_root exists
+ if args.save_path and not args.save_path.parent.exists():
+ raise FileNotFoundError(
+ f'The save path is {args.save_path}, and directory '
+ f"'{args.save_path.parent}' do not exist.")
+
+ # init logger
+ print('Param_scheduler :')
+ rich.print_json(json.dumps(cfg.param_scheduler))
+
+ # prepare data loader
+ batch_size = cfg.train_dataloader.batch_size * args.ngpus
+
+ if 'by_epoch' in cfg.train_cfg:
+ by_epoch = cfg.train_cfg.get('by_epoch')
+ elif 'type' in cfg.train_cfg:
+ by_epoch = cfg.train_cfg.get('type') == 'EpochBasedTrainLoop'
+ else:
+ raise ValueError('please set `train_cfg`.')
+
+ if args.dataset_size is None and by_epoch:
+ dataset_size = len(build_dataset(cfg.train_dataloader.dataset))
+ else:
+ dataset_size = args.dataset_size or batch_size
+
+ class FakeDataloader(list):
+ dataset = MagicMock(metainfo=None)
+
+ data_loader = FakeDataloader(range(dataset_size // batch_size))
+ dataset_info = (
+ f'\nDataset infos:'
+ f'\n - Dataset size: {dataset_size}'
+ f'\n - Batch size per GPU: {cfg.train_dataloader.batch_size}'
+ f'\n - Number of GPUs: {args.ngpus}'
+ f'\n - Total batch size: {batch_size}')
+ if by_epoch:
+ dataset_info += f'\n - Iterations per epoch: {len(data_loader)}'
+ rich.print(dataset_info + '\n')
+
+ # simulation training process
+ param_dict = simulate_train(data_loader, cfg, by_epoch)
+ param_list = param_dict[args.parameter]
+
+ if args.parameter == 'lr':
+ param_name = 'Learning Rate'
+ elif args.parameter == 'momentum':
+ param_name = 'Momentum'
+ else:
+ param_name = 'Weight Decay'
+ plot_curve(param_list, args, param_name, len(data_loader), by_epoch)
+
+ if args.save_path:
+ plt.savefig(args.save_path)
+ print(f'\nThe {param_name} graph is saved at {args.save_path}')
+
+ if not args.not_show:
+ plt.show()
+
+
+if __name__ == '__main__':
+ main()