mirror of https://github.com/open-mmlab/mmocr.git
[Feature] Add args and docs for multi-machine training/testing (#849)
* add args and docs for multi-machine training/testing
* update docs
* update docs
* update docs
* update docs
* revert commit
parent a851fadcb0
commit 838aa47f9f
@ -2,7 +2,7 @@

We introduce how to test pretrained models on datasets here.

-## Testing with Single GPU
+## Testing on a Single GPU

You can use `tools/test.py` to perform single CPU/GPU inference. For example, to evaluate DBNet on IC15 (you can download pretrained models from [Model Zoo](modelzoo.md)):
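
A concrete invocation of the single-GPU case might look like the following sketch. The config path follows the repository's `configs/textdet/dbnet/` naming, while the checkpoint filename is only a placeholder for whichever file you downloaded from the Model Zoo:

```shell
# Evaluate DBNet on ICDAR 2015 with the hmean-iou metric on one GPU.
python tools/test.py \
    configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py \
    dbnet_r18_fpnc_1200e_icdar2015.pth \
    --eval hmean-iou
```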
@ -25,8 +25,6 @@ CUDA_VISIBLE_DEVICES= python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [AR
:::

| ARGS | Type | Description |
| ---- | ---- | ----------- |
| `--out` | str | Output result file in pickle format. |
@ -43,8 +41,7 @@ CUDA_VISIBLE_DEVICES= python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [AR
| `--eval-options` | str | Custom options for evaluation; key-value pairs in `xxx=yyy` format will be passed as kwargs to the `dataset.evaluate()` function. |
| `--launcher` | 'none', 'pytorch', 'slurm', 'mpi' | Options for job launcher. |
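
As a sketch of the `--eval-options` syntax, the snippet below passes one key-value pair through to `dataset.evaluate()`. The `iou_thr` key is only an illustrative placeholder; the keys that are actually accepted depend on the `evaluate()` signature of the dataset being tested:

```shell
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
    --eval hmean-iou \
    --eval-options iou_thr=0.6
```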

-## Testing with Multiple GPUs
+## Testing on Multiple GPUs

MMOCR implements **distributed** testing with `MMDistributedDataParallel`.
@ -54,24 +51,63 @@ You can use the following command to test a dataset with multiple GPUs.

```shell
[PORT={PORT}] ./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [PY_ARGS]
```

| Arguments | Type | Description |
| --------- | ---- | ----------- |
| `PORT` | int | The master port that will be used by the machine with rank 0. Defaults to 29500. |
| `CONFIG_FILE` | str | The path to the config file. |
| `CHECKPOINT_FILE` | str | The path to the checkpoint file. |
| `GPU_NUM` | int | The number of GPUs to be used per node. Defaults to 8. |
| `PY_ARGS` | str | Arguments to be parsed by `tools/test.py`. |

For example,

```shell
./tools/dist_test.sh configs/example_config.py work_dirs/example_exp/example_model_20200202.pth 1 --eval hmean-iou
```

## Testing on Multiple Machines

You can launch a task on multiple machines connected to the same network.

```shell
NNODES=${NNODES} NODE_RANK=${NODE_RANK} PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} ./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [PY_ARGS]
```

| Arguments | Type | Description |
| --------- | ---- | ----------- |
| `NNODES` | int | The number of nodes. |
| `NODE_RANK` | int | The rank of the current node. |
| `PORT` | int | The master port that will be used by the rank 0 node. Defaults to 29500. |
| `MASTER_ADDR` | str | The address of the rank 0 node. Defaults to "127.0.0.1". |
| `CONFIG_FILE` | str | The path to the config file. |
| `CHECKPOINT_FILE` | str | The path to the checkpoint file. |
| `GPU_NUM` | int | The number of GPUs to be used per node. Defaults to 8. |
| `PY_ARGS` | str | Arguments to be parsed by `tools/test.py`. |

:::{note}
MMOCR relies on the `torch.distributed` package for distributed testing. Find more information at PyTorch’s [launch utility](https://pytorch.org/docs/stable/distributed.html#launch-utility).
:::

Say that you want to launch a job on two machines. On the first machine:

```shell
NNODES=2 NODE_RANK=0 PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} ./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [PY_ARGS]
```

On the second machine:

```shell
NNODES=2 NODE_RANK=1 PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} ./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [PY_ARGS]
```
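
For instance, a fully spelled-out sketch of the same pair of commands, assuming the rank 0 machine is reachable at 10.1.1.10, port 29500 is free, and each machine has 8 GPUs (the config, checkpoint, and address are placeholders):

```shell
# On the first machine (rank 0):
NNODES=2 NODE_RANK=0 PORT=29500 MASTER_ADDR=10.1.1.10 ./tools/dist_test.sh configs/example_config.py work_dirs/example_model.pth 8 --eval hmean-iou

# On the second machine (rank 1):
NNODES=2 NODE_RANK=1 PORT=29500 MASTER_ADDR=10.1.1.10 ./tools/dist_test.sh configs/example_config.py work_dirs/example_model.pth 8 --eval hmean-iou
```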

:::{note}
The speed of the network could be the bottleneck of testing.
:::

## Testing with Slurm

If you run MMOCR on a cluster managed with [Slurm](https://slurm.schedmd.com/), you can use the script `tools/slurm_test.sh`.

```shell
[GPUS=${GPUS}] [GPUS_PER_NODE=${GPUS_PER_NODE}] [SRUN_ARGS=${SRUN_ARGS}] ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${CHECKPOINT_FILE} [PY_ARGS]
```
@ -83,7 +119,6 @@ If you run MMOCR on a cluster managed with [Slurm](https://slurm.schedmd.com/),
| `SRUN_ARGS` | str | Arguments to be parsed by srun. Available options can be found [here](https://slurm.schedmd.com/srun.html). |
| `PY_ARGS` | str | Arguments to be parsed by `tools/test.py`. |
Here is an example of using 8 GPUs to test an example model on the 'dev' partition with job name 'test_job'.
```shell
@ -94,7 +129,7 @@ GPUS=8 ./tools/slurm_test.sh dev test_job configs/example_config.py work_dirs/ex
```

By default, MMOCR tests the model image by image. For faster inference, you may change `data.val_dataloader.samples_per_gpu` and `data.test_dataloader.samples_per_gpu` in the config. For example,

```python
data = dict(
    ...
    val_dataloader=dict(samples_per_gpu=16),
@ -102,6 +137,7 @@ data = dict(
    ...
)
```

will test the model with 16 images in a batch.

:::{warning}
@ -47,11 +47,49 @@ MMOCR implements **distributed** training with `MMDistributedDataParallel`. (Ple

| Arguments | Type | Description |
| --------- | ---- | ----------- |
| `PORT` | int | The master port that will be used by the machine with rank 0. Defaults to 29500. **Note:** If you are launching multiple distributed training jobs on a single machine, you need to specify different ports for each job to avoid port conflicts. |
| `CONFIG_FILE` | str | The path to the config file. |
| `WORK_DIR` | str | The path to the working directory. |
| `GPU_NUM` | int | The number of GPUs to be used per node. Defaults to 8. |
| `PY_ARGS` | str | Arguments to be parsed by `tools/train.py`. |
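
For example, a sketch of a single-machine run that trains the (illustrative) `configs/example_config.py` on 8 GPUs and writes checkpoints and logs to `work_dirs/example_exp`:

```shell
./tools/dist_train.sh configs/example_config.py work_dirs/example_exp 8
```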

## Training on Multiple Machines

MMOCR relies on the `torch.distributed` package for distributed training. Thus, as a basic usage, one can launch distributed training via PyTorch’s [launch utility](https://pytorch.org/docs/stable/distributed.html#launch-utility).
You can launch a task on multiple machines connected to the same network.

```shell
NNODES=${NNODES} NODE_RANK=${NODE_RANK} PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} ./tools/dist_train.sh ${CONFIG_FILE} ${WORK_DIR} ${GPU_NUM} [PY_ARGS]
```

| Arguments | Type | Description |
| --------- | ---- | ----------- |
| `NNODES` | int | The number of nodes. |
| `NODE_RANK` | int | The rank of the current node. |
| `PORT` | int | The master port that will be used by the rank 0 node. Defaults to 29500. |
| `MASTER_ADDR` | str | The address of the rank 0 node. Defaults to "127.0.0.1". |
| `CONFIG_FILE` | str | The path to the config file. |
| `WORK_DIR` | str | The path to the working directory. |
| `GPU_NUM` | int | The number of GPUs to be used per node. Defaults to 8. |
| `PY_ARGS` | str | Arguments to be parsed by `tools/train.py`. |

:::{note}
MMOCR relies on the `torch.distributed` package for distributed training. Find more information at PyTorch’s [launch utility](https://pytorch.org/docs/stable/distributed.html#launch-utility).
:::

Say that you want to launch a job on two machines. On the first machine:

```shell
NNODES=2 NODE_RANK=0 PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} ./tools/dist_train.sh ${CONFIG_FILE} ${WORK_DIR} ${GPU_NUM} [PY_ARGS]
```

On the second machine:

```shell
NNODES=2 NODE_RANK=1 PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} ./tools/dist_train.sh ${CONFIG_FILE} ${WORK_DIR} ${GPU_NUM} [PY_ARGS]
```

:::{note}
The speed of the network could be the bottleneck of training.
:::

## Training with Slurm
@ -9,8 +9,20 @@ fi
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
-python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
-    $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
+python -m torch.distributed.launch \
+    --nnodes=$NNODES \
+    --node_rank=$NODE_RANK \
+    --master_addr=$MASTER_ADDR \
+    --nproc_per_node=$GPUS \
+    --master_port=$PORT \
+    $(dirname "$0")/test.py \
+    $CONFIG \
+    $CHECKPOINT \
+    --launcher pytorch \
+    ${@:4}
@ -9,14 +9,25 @@ fi
CONFIG=$1
WORK_DIR=$2
GPUS=$3

+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \

if [ ${GPUS} == 1 ]; then
    python $(dirname "$0")/train.py $CONFIG --work-dir=${WORK_DIR} ${@:4}
else
-    python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
-        $(dirname "$0")/train.py $CONFIG --work-dir=${WORK_DIR} --launcher pytorch ${@:4}
+    python -m torch.distributed.launch \
+        --nnodes=$NNODES \
+        --node_rank=$NODE_RANK \
+        --master_addr=$MASTER_ADDR \
+        --nproc_per_node=$GPUS \
+        --master_port=$PORT \
+        $(dirname "$0")/train.py \
+        $CONFIG \
+        --seed 0 \
+        --work-dir=${WORK_DIR} \
+        --launcher pytorch ${@:4}
fi