[Docs] Training & Testing Tutorials (#1331)

* zh-cn train & test tutorial

* add En

* fix comments

* Update docs/en/user_guides/train_test.md

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>
pull/1352/head
Xinyu Wang 2022-08-30 19:54:04 +08:00 committed by GitHub
parent 8c904127a8
commit 8b32ea6fa9
5 changed files with 618 additions and 1 deletions

Binary files not shown: 3 image files (192 KiB, 147 KiB, 6.7 KiB).


@ -1 +1,310 @@
# Train and Test
# Training and Testing
To meet diverse requirements, MMOCR supports training and testing models on various devices, including PCs, workstations, and computation clusters.
## Single GPU Training and Testing
### Training
`tools/train.py` provides the basic training service. MMOCR recommends using GPUs for model training and testing, but CPU-only training and testing are also supported by setting `CUDA_VISIBLE_DEVICES=-1`. For example, the following commands demonstrate how to train a DBNet model using a single GPU or CPU.
```bash
# Train the specified MMOCR model by calling tools/train.py
CUDA_VISIBLE_DEVICES= python tools/train.py ${CONFIG_FILE} [PY_ARGS]
# Training
# Example 1: Training DBNet with CPU
CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py
# Example 2: Specify to train DBNet with gpu:0, specify the working directory as dbnet/, and turn on mixed precision (amp) training
CUDA_VISIBLE_DEVICES=0 python tools/train.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py --work-dir dbnet/ --amp
```
```{note}
If multiple GPUs are available, you can select a specific one by its device ID, e.g. GPU 3, by setting `CUDA_VISIBLE_DEVICES=3`. Device IDs start at 0, so ID 3 refers to the fourth GPU.
```
The following table lists all the arguments supported by `train.py`. Args without the `--` prefix are mandatory, while others are optional.
| ARGS | Type | Description |
| --------------- | ---- | --------------------------------------------------------------------------- |
| config | str | (required) Path to config. |
| --work-dir | str | Specify the working directory for the training logs and model checkpoints. |
| --resume | bool | Whether to resume training from the latest checkpoint. |
| --amp | bool | Whether to use automatic mixed precision for training. |
| --auto-scale-lr | bool | Whether to use automatic learning rate scaling. |
| --cfg-options | str | Override some settings in the configs. [Example](<>) |
| --launcher | str | Launcher option, one of \['none', 'pytorch', 'slurm', 'mpi'\]. |
| --local_rank | int | Rank of the local machine, used for distributed training. Defaults to 0. |
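As a rough illustration of what overriding config settings means, the sketch below applies dotted `key=value` style overrides to a nested dictionary. It is a minimal sketch, not MMOCR's or MMEngine's actual implementation: the helper `apply_overrides` and the config keys used here are hypothetical and chosen only for illustration.

```python
from typing import Any, Dict


def apply_overrides(cfg: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Apply dotted-key overrides to a nested dict in place (hypothetical helper)."""
    for dotted_key, value in overrides.items():
        node = cfg
        *parents, leaf = dotted_key.split(".")
        for name in parents:
            # Create intermediate dicts on demand so new keys can be added.
            node = node.setdefault(name, {})
        node[leaf] = value
    return cfg


# Hypothetical config fragment, for illustration only.
cfg = {"optim_wrapper": {"optimizer": {"type": "SGD", "lr": 0.007}}}
apply_overrides(cfg, {"optim_wrapper.optimizer.lr": 0.001, "train_cfg.max_epochs": 10})
print(cfg)
```

Only the addressed leaves change; sibling keys such as `type` are left untouched, which is the behaviour one usually wants from a command-line override.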
### Testing
`tools/test.py` provides the basic testing service, which is used in a similar way to the training script. For example, the following commands demonstrate how to test a DBNet model on a single GPU or CPU.
```bash
# Test a pretrained MMOCR model by calling tools/test.py
CUDA_VISIBLE_DEVICES= python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [PY_ARGS]
# Test
# Example 1: Testing DBNet with CPU
CUDA_VISIBLE_DEVICES=-1 python tools/test.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth
# Example 2: Testing DBNet on gpu:0
CUDA_VISIBLE_DEVICES=0 python tools/test.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth
```
The following table lists all the arguments supported by `test.py`. Args without the `--` prefix are mandatory, while others are optional.
| ARGS | Type | Description |
| ------------- | ----- | -------------------------------------------------------------------- |
| config | str | (required) Path to config. |
| checkpoint | str | (required) The model to be tested. |
| --work-dir | str | Specify the working directory for the logs. |
| --save-preds | bool | Whether to save the predictions to a pkl file. |
| --show | bool | Whether to visualize the predictions. |
| --show-dir | str | Path to save the visualization results. |
| --wait-time | float | Interval of visualization (s), defaults to 2. |
| --cfg-options | str | Override some settings in the configs. [Example](<>) |
| --launcher | str | Option for launcher\['none', 'pytorch', 'slurm', 'mpi'\]. |
| --local_rank | int | Rank of the local machine, used for distributed training. Defaults to 0. |
## Training and Testing with Multiple GPUs
For large models, distributed training or testing significantly improves the efficiency. For this purpose, MMOCR provides distributed scripts `tools/dist_train.sh` and `tools/dist_test.sh` implemented based on [MMDistributedDataParallel](mmengine.model.wrappers.MMDistributedDataParallel).
```bash
# Training
NNODES=${NNODES} NODE_RANK=${NODE_RANK} PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} ./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [PY_ARGS]
# Testing
NNODES=${NNODES} NODE_RANK=${NODE_RANK} PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} ./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [PY_ARGS]
```
The following table lists the arguments supported by `dist_*.sh`.
| ARGS | Type | Description |
| --------------- | ---- | --------------------------------------------------------------------------------------------- |
| NNODES | int | The number of nodes. Defaults to 1. |
| NODE_RANK | int | The rank of current node. Defaults to 0. |
| PORT | int | The master port that will be used by rank 0 node, ranging from 0 to 65535. Defaults to 29500. |
| MASTER_ADDR | str | The address of rank 0 node. Defaults to "127.0.0.1". |
| CONFIG_FILE | str | (required) The path to config. |
| CHECKPOINT_FILE | str | (required, only used in dist_test.sh) The path to the checkpoint to be tested. |
| GPU_NUM | int | (required) The number of GPUs to be used per node. |
| \[PY_ARGS\] | str | Arguments to be parsed by tools/train.py and tools/test.py. |
These two scripts enable training and testing on **single-machine multi-GPU** or **multi-machine multi-GPU** setups. See the following examples for usage.
### Single-machine Multi-GPU
The following commands demonstrate how to train and test with a specified number of GPUs on a **single machine** with multiple GPUs.
1. **Training**
Training DBNet using 4 GPUs on a single machine.
```bash
tools/dist_train.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py 4
```
2. **Testing**
Testing DBNet using 4 GPUs on a single machine.
```bash
tools/dist_test.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth 4
```
### Launching Multiple Tasks on Single Machine
For a workstation equipped with multiple GPUs, the user can launch multiple tasks simultaneously by specifying the GPU IDs. For example, the following command demonstrates how to test DBNet with GPU `[0, 1, 2, 3]` and train CRNN on GPU `[4, 5, 6, 7]`.
```bash
# Specify gpu:0,1,2,3 for testing and assign port number 29500
CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_test.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth 4
# Specify gpu:4,5,6,7 for training and assign port number 29501
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh configs/textrecog/crnn/crnn_academic_dataset.py 4
```
```{note}
`dist_train.sh` sets the master port (`PORT`) to `29500` by default. If another process already occupies this port, the program fails with `RuntimeError: Address already in use`. In that case, set `PORT` to another free port number in the range `0~65535`.
```
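One way to pick a free port before launching is to ask the operating system for one. The snippet below is a minimal sketch using only the Python standard library; the helper name `find_free_port` is hypothetical, and the port could in principle be taken by another process between this check and the actual launch.

```python
import socket


def find_free_port(host: str = "127.0.0.1") -> int:
    """Ask the OS for a currently unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]


port = find_free_port()
# Print a command line that passes the port through the PORT variable.
print(f"PORT={port} ./tools/dist_train.sh ${{CONFIG_FILE}} ${{GPU_NUM}}")
```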
### Multi-machine Multi-GPU Training and Testing
You can launch a task on multiple machines connected to the same network. MMOCR relies on the `torch.distributed` package for distributed training. Find more information at PyTorch's [launch utility](https://pytorch.org/docs/stable/distributed.html#launch-utility).
1. **Training**
The following command demonstrates how to train DBNet on two machines with a total of 4 GPUs.
```bash
# Say that you want to launch the training job on two machines
# On the first machine:
NNODES=2 NODE_RANK=0 PORT=29500 MASTER_ADDR=10.140.0.169 tools/dist_train.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py 2
# On the second machine:
NNODES=2 NODE_RANK=1 PORT=29500 MASTER_ADDR=10.140.0.169 tools/dist_train.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py 2
```
2. **Testing**
The following command demonstrates how to test DBNet on two machines with a total of 4 GPUs.
```bash
# Say that you want to launch the testing job on two machines
# On the first machine:
NNODES=2 NODE_RANK=0 PORT=29500 MASTER_ADDR=10.140.0.169 tools/dist_test.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth 2
# On the second machine:
NNODES=2 NODE_RANK=1 PORT=29500 MASTER_ADDR=10.140.0.169 tools/dist_test.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth 2
```
```{note}
The speed of the network could be the bottleneck of training.
```
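To see how `NNODES`, `NODE_RANK` and `GPU_NUM` fit together, the sketch below computes the total number of processes and the global rank of each process. It assumes the usual convention of PyTorch's launch utility (one process per GPU, ranks numbered node by node); the function names are hypothetical.

```python
def world_size(nnodes: int, gpus_per_node: int) -> int:
    """Total number of processes when one process is started per GPU."""
    return nnodes * gpus_per_node


def global_rank(node_rank: int, local_rank: int, gpus_per_node: int) -> int:
    """Global rank of a process, assuming ranks are assigned node by node."""
    return node_rank * gpus_per_node + local_rank


# The two-machine example above: NNODES=2 and GPU_NUM=2 on each machine.
total = world_size(nnodes=2, gpus_per_node=2)
ranks = [
    global_rank(node_rank, local_rank, gpus_per_node=2)
    for node_rank in range(2)
    for local_rank in range(2)
]
print(total, ranks)
```

All processes must agree on the same master address and port, since that is where the process with global rank 0 listens for the others.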
## Training and Testing with Slurm Cluster
If you run MMOCR on a cluster managed with [Slurm](https://slurm.schedmd.com/), you can use the scripts `tools/slurm_train.sh` and `tools/slurm_test.sh`.
```bash
# tools/slurm_train.sh submits training tasks on clusters managed by Slurm
GPUS=${GPUS} GPUS_PER_NODE=${GPUS_PER_NODE} CPUS_PER_TASK=${CPUS_PER_TASK} SRUN_ARGS=${SRUN_ARGS} ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR} [PY_ARGS]
# tools/slurm_test.sh submits testing tasks on clusters managed by Slurm
GPUS=${GPUS} GPUS_PER_NODE=${GPUS_PER_NODE} CPUS_PER_TASK=${CPUS_PER_TASK} SRUN_ARGS=${SRUN_ARGS} ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${CHECKPOINT_FILE} ${WORK_DIR} [PY_ARGS]
```
| ARGS | Type | Description |
| --------------- | ---- | ----------------------------------------------------------------------------------------------------------- |
| GPUS | int | The number of GPUs to be used by this task. Defaults to 8. |
| GPUS_PER_NODE | int | The number of GPUs to be allocated per node. Defaults to 8. |
| CPUS_PER_TASK | int | The number of CPUs to be allocated per task. Defaults to 5. |
| SRUN_ARGS | str | Arguments to be parsed by srun. Available options can be found [here](https://slurm.schedmd.com/srun.html). |
| PARTITION | str | (required) Specify the partition on cluster. |
| JOB_NAME | str | (required) Name of the submitted job. |
| WORK_DIR | str | (required) Specify the working directory for saving the logs and checkpoints. |
| CHECKPOINT_FILE | str | (required, only used in slurm_test.sh) Path to the checkpoint to be tested. |
| PY_ARGS | str | Arguments to be parsed by `tools/train.py` and `tools/test.py`. |
These scripts enable training and testing on Slurm clusters. See the following examples.
1. Training
Here is an example of using 1 GPU to train a DBNet model on the `dev` partition.
```bash
# Example: Request 1 GPU resource on dev partition for DBNet training task
GPUS=1 GPUS_PER_NODE=1 CPUS_PER_TASK=5 tools/slurm_train.sh dev db_r50 configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py work_dir
```
2. Testing
Similarly, the following example requests 1 GPU for testing.
```bash
# Example: Request 1 GPU resource on dev partition for DBNet testing task
GPUS=1 GPUS_PER_NODE=1 CPUS_PER_TASK=5 tools/slurm_test.sh dev db_r50 configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth work_dir
```
## Advanced Tips
### Resume Training from a Checkpoint
`tools/train.py` allows users to resume training from a checkpoint by passing `--resume`, in which case training automatically resumes from the latest saved checkpoint.
```bash
# Example: Resuming training from the latest checkpoint
python tools/train.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py --resume
```
By default, the program resumes training from the last successfully saved checkpoint of the previous training session, i.e. `latest.pth`. However, if you want to resume from a specific checkpoint, set its path in the model's config file as follows.
```python
# Example: Set the path of the checkpoint you want to load in the configuration file
load_from = 'work_dir/dbnet/models/epoch_10000.pth'
```
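When choosing a path for `load_from` by hand, note that sorting checkpoint names as strings places `epoch_9.pth` after `epoch_10.pth`. The sketch below picks the checkpoint with the largest epoch number instead. It assumes checkpoints are named `epoch_<N>.pth`, as in the example path above; the helper `newest_epoch_checkpoint` is hypothetical and not part of MMOCR.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

_EPOCH_PATTERN = re.compile(r"^epoch_(\d+)\.pth$")


def newest_epoch_checkpoint(work_dir: str) -> Optional[Path]:
    """Return the epoch_<N>.pth file with the largest N, or None if there is none."""
    candidates = []
    for path in Path(work_dir).glob("epoch_*.pth"):
        match = _EPOCH_PATTERN.match(path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    # Compare by the integer epoch, not by the file name string.
    return max(candidates, key=lambda item: item[0])[1]


# Demo on a throwaway directory with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for epoch in (9, 10, 100):
        (Path(tmp) / f"epoch_{epoch}.pth").touch()
    newest_name = newest_epoch_checkpoint(tmp).name
    print(f"load_from = '{newest_name}'")
```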
### Mixed Precision Training
Mixed precision training offers significant computational speedup by performing operations in half-precision format, while keeping a minimal amount of data in single precision to preserve accuracy in critical parts of the network. In MMOCR, users can enable automatic mixed precision training by simply adding `--amp`.
```bash
# Example: Using automatic mixed precision training
python tools/train.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py --amp
```
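To illustrate why some values are kept in higher precision, the sketch below rounds numbers to IEEE 754 half precision using only the Python standard library. It is an illustration of the number format, not of how MMOCR or PyTorch implement mixed precision; the weight and update values are made up.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


weight = 1.0
update = 1e-4  # a small, made-up gradient step

# In half precision the spacing between representable values near 1.0 is
# about 0.001, so the small update is rounded away entirely.
fp16_result = to_fp16(to_fp16(weight) + to_fp16(update))

# Python floats are double precision; they stand in for a higher-precision copy here.
high_precision_result = weight + update

print(fp16_result, high_precision_result)
```

Accumulating such small updates in a higher-precision copy of the weights keeps them from being lost, which is the motivation for storing part of the state in single precision.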
The following table shows the support of each algorithm in MMOCR for automatic mixed precision training.
| Algorithm | AMP Supported | Description |
| ------------- | :-----------------: | :-------------------------------------: |
| | Text Detection | |
| DBNet | Y | |
| DBNetpp | Y | |
| DRRG | N | roi_align_rotated does not support fp16 |
| FCENet | N | BCELoss does not support fp16 |
| Mask R-CNN | Y | |
| PANet | Y | |
| PSENet | Y | |
| TextSnake | N | |
| | Text Recognition | |
| ABINet | Y | |
| CRNN | Y | |
| MASTER | Y | |
| NRTR | Y | |
| RobustScanner | Y | |
| SAR | Y | |
| SATRN | Y | |
### Automatic Learning Rate Scaling
MMOCR sets a default initial learning rate for each model in its configuration file. However, these initial learning rates may not be applicable when the user trains with a `batch_size` different from the preset `base_batch_size`. Therefore, we provide a tool to automatically scale the learning rate, which can be enabled by adding `--auto-scale-lr`.
```bash
# Example: Using automatic learning rate scaling
python tools/train.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py --auto-scale-lr
```
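As a worked example, the sketch below scales a learning rate in proportion to the effective batch size. It assumes the common linear scaling rule and is not taken from MMOCR's code; the function name and all numbers are hypothetical.

```python
def scale_lr(base_lr: float, base_batch_size: int, num_gpus: int, samples_per_gpu: int) -> float:
    """Scale the learning rate linearly with the effective batch size (assumed rule)."""
    effective_batch_size = num_gpus * samples_per_gpu
    return base_lr * effective_batch_size / base_batch_size


# Hypothetical numbers: the config was tuned for a total batch size of 16,
# but training runs on 4 GPUs with 8 samples each (total batch size 32).
new_lr = scale_lr(base_lr=0.007, base_batch_size=16, num_gpus=4, samples_per_gpu=8)
print(new_lr)
```

Doubling the effective batch size doubles the learning rate under this rule, and training with the preset `base_batch_size` leaves it unchanged.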
### Visualize the Predictions
`tools/test.py` provides the visualization interface to facilitate the qualitative analysis of the OCR models.
<div align="center">
![Detection](../../../demo/resources/det_vis.png)
(Green boxes are GTs, while red boxes are predictions)
</div>
<div align="center">
![Recognition](../../../demo/resources/rec_vis.png)
(Green font is the GT, red font is the prediction)
</div>
<div align="center">
![KIE](../../../demo/resources/kie_vis.png)
(From left to right: original image, text detection and recognition result, text classification result, relationship)
</div>
```bash
# Example 1: Show the visualization results, one every 2 seconds
python tools/test.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth --show --wait-time 2
# Example 2: For systems that do not support graphical interfaces (such as computing clusters, etc.), the visualization results can be dumped in the specified path
python tools/test.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth --show-dir ./vis_results
```
The visualization-related parameters in `tools/test.py` are described as follows.
| ARGS | Type | Description |
| ----------- | ----- | --------------------------------------------- |
| --show | bool | Whether to show the visualization results. |
| --show-dir | str | Path to save the visualization results. |
| --wait-time | float | Interval of visualization (s), defaults to 2. |


@ -1 +1,309 @@
# 训练与测试
为了适配多样化的用户需求,MMOCR 实现了多种不同操作系统及设备上的模型训练及测试。无论是使用本地机器进行单机单卡训练测试,还是在部署了 slurm 系统的大规模集群上进行训练测试,MMOCR 都提供了便捷的解决方案。
## 单卡机器训练及测试
### 训练
`tools/train.py` 实现了基础的训练服务。MMOCR 推荐用户使用 GPU 进行模型训练和测试,但是,用户也可以通过指定 `CUDA_VISIBLE_DEVICES=-1` 来使用 CPU 设备进行模型训练及测试。例如,以下命令演示了如何使用 CPU 或单卡 GPU 来训练 DBNet 文本检测器。
```bash
# 通过调用 tools/train.py 来训练指定的 MMOCR 模型
CUDA_VISIBLE_DEVICES= python tools/train.py ${CONFIG_FILE} [PY_ARGS]
# 训练
# 示例 1:使用 CPU 训练 DBNet
CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py
# 示例 2:指定使用 gpu:0 训练 DBNet,指定工作目录为 dbnet/,并打开混合精度(amp)训练
CUDA_VISIBLE_DEVICES=0 python tools/train.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py --work-dir dbnet/ --amp
```
```{note}
此外,如需使用指定编号的 GPU 进行训练或测试(例如使用 3 号 GPU),则可以通过设定 CUDA_VISIBLE_DEVICES=3 来实现。
```
下表列出了 `train.py` 支持的所有参数。其中,不带 `--` 前缀的参数为必须的位置参数,带 `--` 前缀的参数为可选参数。
| 参数 | 类型 | 说明 |
| --------------- | ---- | -------------------------------------------------------------- |
| config | str | (必须)配置文件路径。 |
| --work-dir | str | 指定工作目录,用于存放训练日志以及模型 checkpoints。 |
| --resume | bool | 是否从断点处恢复训练。 |
| --amp | bool | 是否使用混合精度。 |
| --auto-scale-lr | bool | 是否使用学习率自动缩放。 |
| --cfg-options | str | 用于覆写配置文件中的指定参数。[示例](#添加示例) |
| --launcher | str | 启动器选项,可选项目为 \['none', 'pytorch', 'slurm', 'mpi'\]。 |
| --local_rank | int | 本地机器编号,用于多机多卡分布式训练,默认为 0。 |
### 测试
`tools/test.py` 提供了基础的测试服务,其使用原理和训练脚本类似。例如,以下命令演示了 CPU 或 GPU 单卡测试 DBNet 模型。
```bash
# 通过调用 tools/test.py 来测试指定的 MMOCR 模型
CUDA_VISIBLE_DEVICES= python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [PY_ARGS]
# 测试
# 示例 1:使用 CPU 测试 DBNet
CUDA_VISIBLE_DEVICES=-1 python tools/test.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth
# 示例 2:使用 gpu:0 测试 DBNet
CUDA_VISIBLE_DEVICES=0 python tools/test.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth
```
下表列出了 `test.py` 支持的所有参数。其中,不带 `--` 前缀的参数为必须的位置参数,带 `--` 前缀的参数为可选参数。
| 参数 | 类型 | 说明 |
| ------------- | ----- | -------------------------------------------------------------- |
| config | str | (必须)配置文件路径。 |
| checkpoint | str | (必须)待测试模型路径。 |
| --work-dir | str | 工作目录,用于存放训练日志以及模型 checkpoints。 |
| --save-preds | bool | 是否将预测结果写入 pkl 文件并保存。 |
| --show | bool | 是否可视化预测结果。 |
| --show-dir | str | 将可视化的预测结果保存至指定路径。 |
| --wait-time | float | 可视化间隔时间(秒),默认为 2 秒。 |
| --cfg-options | str | 用于覆写配置文件中的指定参数。[示例](#添加示例) |
| --launcher | str | 启动器选项,可选项目为 \['none', 'pytorch', 'slurm', 'mpi'\]。 |
| --local_rank | int | 本地机器编号,用于多机多卡分布式训练,默认为 0。 |
## 多卡机器训练及测试
对于大规模模型,采用多 GPU 训练和测试可以极大地提升操作的效率。为此,MMOCR 提供了基于 [MMDistributedDataParallel](mmengine.model.wrappers.MMDistributedDataParallel) 实现的分布式脚本 `tools/dist_train.sh` 和 `tools/dist_test.sh`。
```bash
# 训练
NNODES=${NNODES} NODE_RANK=${NODE_RANK} PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} ./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [PY_ARGS]
# 测试
NNODES=${NNODES} NODE_RANK=${NODE_RANK} PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} ./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [PY_ARGS]
```
下表列出了 `dist_*.sh` 支持的参数:
| 参数 | 类型 | 说明 |
| --------------- | ---- | ---------------------------------------------------------------------------------- |
| NNODES | int | 总共使用的机器节点个数,默认为 1。 |
| NODE_RANK | int | 节点编号,默认为 0。 |
| PORT | int | 在 RANK 0 机器上使用的 MASTER_PORT 端口号,取值范围是 0 至 65535,默认值为 29500。 |
| MASTER_ADDR | str | RANK 0 机器的 IP 地址,默认值为 127.0.0.1。 |
| CONFIG_FILE | str | (必须)指定配置文件的地址。 |
| CHECKPOINT_FILE | str | (必须,仅在 dist_test.sh 中适用)指定模型权重的地址。 |
| GPU_NUM | int | (必须)指定 GPU 的数量。 |
| \[PY_ARGS\] | str | 该部分一切的参数都会被直接传入 tools/train.py 或 tools/test.py 中。 |
这两个脚本可以实现**单机多卡**或**多机多卡**的训练和测试,下面演示了它们在不同场景下的用法。
### 单机多卡
以下命令演示了如何在搭载多块 GPU 的**单台机器**上使用指定数目的 GPU 进行训练及测试:
1. **训练**
使用单台机器上的 4 块 GPU 训练 DBNet。
```bash
# 单机 4 卡训练 DBNet
tools/dist_train.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py 4
```
2. **测试**
使用单台机器上的 4 块 GPU 测试 DBNet。
```bash
# 单机 4 卡测试 DBNet
tools/dist_test.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth 4
```
### 单机多任务训练及测试
对于搭载多块 GPU 的单台服务器而言,用户可以通过指定 GPU 的形式来同时执行不同的训练任务。例如,以下命令演示了如何在一台 8 卡 GPU 服务器上分别使用 `[0, 1, 2, 3]` 卡测试 DBNet 及 `[4, 5, 6, 7]` 卡训练 CRNN:
```bash
# 指定使用 gpu:0,1,2,3 测试 DBNet,并分配端口号 29500
CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_test.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth 4
# 指定使用 gpu:4,5,6,7 训练 CRNN,并分配端口号 29501
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh configs/textrecog/crnn/crnn_academic_dataset.py 4
```
```{note}
`dist_train.sh` 默认将主端口(`PORT`)设置为 `29500`,当单台机器上有其它进程已占用该端口时,程序则会出现运行时错误 `RuntimeError: Address already in use`。此时,用户需要将 `PORT` 设置为 `0~65535` 范围内的其它空闲端口号。
```
### 多机多卡训练及测试
MMOCR 基于[torch.distributed](https://pytorch.org/docs/stable/distributed.html#launch-utility) 提供了相同局域网下的多台机器间的多卡分布式训练。
1. **训练**
以下命令演示了如何在两台机器上分别使用 2 张 GPU 合计 4 卡训练 DBNet:
```bash
# 示例:在两台机器上分别使用 2 张 GPU 合计 4 卡训练 DBNet
# 在 “机器1” 上运行以下命令
NNODES=2 NODE_RANK=0 PORT=29501 MASTER_ADDR=10.140.0.169 tools/dist_train.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py 2
# 在 “机器2” 上运行以下命令
NNODES=2 NODE_RANK=1 PORT=29501 MASTER_ADDR=10.140.0.169 tools/dist_train.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py 2
```
2. **测试**
以下命令演示了如何在两台机器上分别使用 2 张 GPU 合计 4 卡测试:
```bash
# 示例:在两台机器上分别使用 2 张 GPU 合计 4 卡测试
# 在 “机器1” 上运行以下命令
NNODES=2 NODE_RANK=0 PORT=29500 MASTER_ADDR=10.140.0.169 tools/dist_test.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth 2
# 在 “机器2” 上运行以下命令
NNODES=2 NODE_RANK=1 PORT=29500 MASTER_ADDR=10.140.0.169 tools/dist_test.sh configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth 2
```
```{note}
需要注意的是,采用多机多卡训练时,机器间的网络传输速度可能成为训练速度的瓶颈。
```
## 集群训练及测试
针对 [Slurm](https://slurm.schedmd.com/) 调度系统管理的计算集群,MMOCR 提供了对应的训练和测试任务提交脚本 `tools/slurm_train.sh` 和 `tools/slurm_test.sh`。
```bash
# tools/slurm_train.sh 提供基于 slurm 调度系统管理的计算集群上提交训练任务的脚本
GPUS=${GPUS} GPUS_PER_NODE=${GPUS_PER_NODE} CPUS_PER_TASK=${CPUS_PER_TASK} SRUN_ARGS=${SRUN_ARGS} ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR} [PY_ARGS]
# tools/slurm_test.sh 提供基于 slurm 调度系统管理的计算集群上提交测试任务的脚本
GPUS=${GPUS} GPUS_PER_NODE=${GPUS_PER_NODE} CPUS_PER_TASK=${CPUS_PER_TASK} SRUN_ARGS=${SRUN_ARGS} ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${CHECKPOINT_FILE} ${WORK_DIR} [PY_ARGS]
```
| 参数 | 类型 | 说明 |
| --------------- | ---- | ------------------------------------------------------------------------- |
| GPUS | int | 使用的 GPU 数目,默认为 8。 |
| GPUS_PER_NODE | int | 每台节点机器上搭载的 GPU 数目,默认为 8。 |
| CPUS_PER_TASK | int | 任务使用的 CPU 个数,默认为 5。 |
| SRUN_ARGS | str | 其他 srun 支持的参数。详见[这里](https://slurm.schedmd.com/srun.html)。 |
| PARTITION | str | (必须)指定使用的集群分区。 |
| JOB_NAME | str | (必须)提交任务的名称。 |
| WORK_DIR | str | (必须)任务的工作目录,训练日志以及模型的 checkpoints 将被保存至该目录。 |
| CHECKPOINT_FILE | str | (必须,仅在 slurm_test.sh 中适用)指向模型权重的地址。 |
| \[PY_ARGS\] | str | tools/train.py 以及 tools/test.py 支持的参数。 |
这两个脚本可以实现 slurm 集群上的训练和测试,下面演示了它们在不同场景下的用法。
1. 训练
以下示例为在 slurm 集群 dev 分区申请 1 块 GPU 进行 DBNet 训练。
```bash
# 示例:在 slurm 集群 dev 分区申请 1块 GPU 资源进行 DBNet 训练任务
GPUS=1 GPUS_PER_NODE=1 CPUS_PER_TASK=5 tools/slurm_train.sh dev db_r50 configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py work_dir
```
2. 测试
同理,`tools/slurm_test.sh` 提供了测试任务提交脚本。以下示例为在 slurm 集群 dev 分区申请 1 块 GPU 资源进行 DBNet 测试。
```bash
# 示例:在 slurm 集群 dev 分区申请 1块 GPU 资源进行 DBNet 测试任务
GPUS=1 GPUS_PER_NODE=1 CPUS_PER_TASK=5 tools/slurm_test.sh dev db_r50 configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth work_dir
```
## 进阶技巧
### 从断点恢复训练
`tools/train.py` 提供了从断点恢复训练的功能,用户仅需在命令中指定 `--resume` 参数,即可自动从断点恢复训练。
```bash
# 示例:从断点恢复训练
python tools/train.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py --resume
```
默认地,程序将自动从上次训练过程中最后成功保存的断点,即 `latest.pth` 处开始继续训练。如果用户希望指定从特定的断点处开始恢复训练,则可以按如下格式在模型的配置文件中设定该断点的路径。
```python
# 示例:在配置文件中设置想要加载的断点路径
load_from = 'work_dir/dbnet/models/epoch_10000.pth'
```
### 混合精度训练
混合精度训练可以在缩减内存占用的同时提升训练速度。为此,MMOCR 提供了一键式的混合精度训练方案,仅需在训练时添加 `--amp` 参数即可。
```bash
# 示例:使用自动混合精度训练
python tools/train.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py --amp
```
下表列出了 MMOCR 中各算法对自动混合精度训练的支持情况:
| | 是否支持混合精度训练 | 备注 |
| ------------- | :------------------: | :---------------------------: |
| | 文本检测 | |
| DBNet | 是 | |
| DBNetpp | 是 | |
| DRRG | 否 | roi_align_rotated 不支持 fp16 |
| FCENet | 否 | BCELoss 不支持 fp16 |
| Mask R-CNN | 是 | |
| PANet | 是 | |
| PSENet | 是 | |
| TextSnake | 否 | |
| | 文本识别 | |
| ABINet | 是 | |
| CRNN | 是 | |
| MASTER | 是 | |
| NRTR | 是 | |
| RobustScanner | 是 | |
| SAR | 是 | |
| SATRN | 是 | |
### 自动学习率缩放
MMOCR 在配置文件中为每一个模型设置了默认的初始学习率,然而,当用户使用的 `batch_size` 不同于我们预设的 `base_batch_size` 时,这些初始学习率可能不再完全适用。因此,我们提供了自动学习率缩放工具。当使用不同于 MMOCR 预设的 `base_batch_size` 进行训练时,用户仅需添加 `--auto-scale-lr` 参数即可自动依据新的 `batch_size` 将学习率缩放至对应尺度。
```bash
# 示例:使用自动学习率缩放
python tools/train.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py --auto-scale-lr
```
### 可视化模型测试结果
`tools/test.py` 提供了可视化接口,以方便用户对模型进行定性分析。
<div align="center">
![可视化文本检测模型](../../../demo/resources/det_vis.png)
(绿色框为真实标注,红色框为预测结果)
</div>
<div align="center">
![可视化文本识别模型](../../../demo/resources/rec_vis.png)
(绿色字体为真实标注,红色字体为预测结果)
</div>
<div align="center">
![可视化关键信息抽取模型结果](../../../demo/resources/kie_vis.png)
(从左至右分别为:原图,文本检测和识别结果,文本分类结果,关系图)
</div>
```bash
# 示例 1:每间隔 2 秒显示一次可视化结果
python tools/test.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth --show --wait-time 2
# 示例 2:对于不支持图形化界面的系统(如计算集群等),可以将可视化结果存入指定路径
python tools/test.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py dbnet_r50.pth --show-dir ./vis_results
```
`tools/test.py` 中可视化相关参数说明:
| 参数 | 类型 | 说明 |
| ----------- | ----- | -------------------------------- |
| --show | bool | 是否绘制可视化结果。 |
| --show-dir | str | 可视化图片存储路径。 |
| --wait-time | float | 可视化间隔时间(秒),默认为 2。 |