mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Enhance] New-style CPU training and inference. (#1251)
* [Enhance] New-style CPU training and inference. * assert mmcv version * SyncBN to BN in training and testing * SyncBN to BN in training and testing * upload untracked files to this branch * delete gpu_ids * fix bugs * assert args.gpu_id in train.py * use cfg.gpu_ids = [args.gpu_id] * use cfg.gpu_ids = [args.gpu_id] * fix typo * fix typo * fix typos
This commit is contained in:
parent
46898b8e8a
commit
574b195be1
@ -6,6 +6,7 @@ and also some high-level apis for easier integration to other projects.
|
|||||||
### Test a dataset
|
### Test a dataset
|
||||||
|
|
||||||
- single GPU
|
- single GPU
|
||||||
|
- CPU
|
||||||
- single node multiple GPU
|
- single node multiple GPU
|
||||||
- multiple node
|
- multiple node
|
||||||
|
|
||||||
@ -15,6 +16,10 @@ You can use the following commands to test a dataset.
|
|||||||
# single-gpu testing
|
# single-gpu testing
|
||||||
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]
|
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]
|
||||||
|
|
||||||
|
# CPU: disable GPUs and run single-gpu testing script
|
||||||
|
export CUDA_VISIBLE_DEVICES=-1
|
||||||
|
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]
|
||||||
|
|
||||||
# multi-gpu testing
|
# multi-gpu testing
|
||||||
./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]
|
./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]
|
||||||
```
|
```
|
||||||
|
@ -33,6 +33,20 @@ python tools/train.py ${CONFIG_FILE} [optional arguments]
|
|||||||
|
|
||||||
If you want to specify the working directory in the command, you can add an argument `--work-dir ${YOUR_WORK_DIR}`.
|
If you want to specify the working directory in the command, you can add an argument `--work-dir ${YOUR_WORK_DIR}`.
|
||||||
|
|
||||||
|
### Train with CPU
|
||||||
|
|
||||||
|
The process of training on the CPU is consistent with single GPU training. We just need to disable GPUs before the training process.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export CUDA_VISIBLE_DEVICES=-1
|
||||||
|
```
|
||||||
|
|
||||||
|
And then run the script [above](#train-with-a-single-gpu).
|
||||||
|
|
||||||
|
```{warning}
|
||||||
|
The process of training on the CPU is consistent with single GPU training. We just need to disable GPUs before the training process.
|
||||||
|
```
|
||||||
|
|
||||||
### Train with multiple GPUs
|
### Train with multiple GPUs
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
### 测试一个数据集
|
### 测试一个数据集
|
||||||
|
|
||||||
- 单卡 GPU
|
- 单卡 GPU
|
||||||
|
- CPU
|
||||||
- 单节点多卡 GPU
|
- 单节点多卡 GPU
|
||||||
- 多节点
|
- 多节点
|
||||||
|
|
||||||
@ -14,6 +15,10 @@
|
|||||||
# 单卡 GPU 测试
|
# 单卡 GPU 测试
|
||||||
python tools/test.py ${配置文件} ${检查点文件} [--out ${结果文件}] [--eval ${评估指标}] [--show]
|
python tools/test.py ${配置文件} ${检查点文件} [--out ${结果文件}] [--eval ${评估指标}] [--show]
|
||||||
|
|
||||||
|
# CPU: 禁用 GPU 并运行单 GPU 测试脚本
|
||||||
|
export CUDA_VISIBLE_DEVICES=-1
|
||||||
|
python tools/test.py ${配置文件} ${检查点文件} [--out ${结果文件}] [--eval ${评估指标}] [--show]
|
||||||
|
|
||||||
# 多卡GPU 测试
|
# 多卡GPU 测试
|
||||||
./tools/dist_test.sh ${配置文件} ${检查点文件} ${GPU数目} [--out ${结果文件}] [--eval ${评估指标}]
|
./tools/dist_test.sh ${配置文件} ${检查点文件} ${GPU数目} [--out ${结果文件}] [--eval ${评估指标}]
|
||||||
```
|
```
|
||||||
|
@ -23,6 +23,20 @@ python tools/train.py ${配置文件} [可选参数]
|
|||||||
|
|
||||||
如果您想在命令里定义工作文件夹路径,您可以添加一个参数`--work-dir ${YOUR_WORK_DIR}`。
|
如果您想在命令里定义工作文件夹路径,您可以添加一个参数`--work-dir ${YOUR_WORK_DIR}`。
|
||||||
|
|
||||||
|
### 使用 CPU 训练
|
||||||
|
|
||||||
|
使用 CPU 训练的流程和使用单 GPU 训练的流程一致,我们仅需要在训练流程开始前禁用 GPU。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export CUDA_VISIBLE_DEVICES=-1
|
||||||
|
```
|
||||||
|
|
||||||
|
之后运行单 GPU 训练脚本即可。
|
||||||
|
|
||||||
|
```{warning}
|
||||||
|
我们不推荐用户使用 CPU 进行训练,这太过缓慢。我们支持这个功能是为了方便用户在没有 GPU 的机器上进行调试。
|
||||||
|
```
|
||||||
|
|
||||||
### 使用多卡 GPU 训练
|
### 使用多卡 GPU 训练
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@ -9,6 +10,7 @@ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
|||||||
from mmcv.runner import HOOKS, build_optimizer, build_runner, get_dist_info
|
from mmcv.runner import HOOKS, build_optimizer, build_runner, get_dist_info
|
||||||
from mmcv.utils import build_from_cfg
|
from mmcv.utils import build_from_cfg
|
||||||
|
|
||||||
|
from mmseg import digit_version
|
||||||
from mmseg.core import DistEvalHook, EvalHook
|
from mmseg.core import DistEvalHook, EvalHook
|
||||||
from mmseg.datasets import build_dataloader, build_dataset
|
from mmseg.datasets import build_dataloader, build_dataset
|
||||||
from mmseg.utils import find_latest_checkpoint, get_root_logger
|
from mmseg.utils import find_latest_checkpoint, get_root_logger
|
||||||
@ -99,9 +101,10 @@ def train_segmentor(model,
|
|||||||
broadcast_buffers=False,
|
broadcast_buffers=False,
|
||||||
find_unused_parameters=find_unused_parameters)
|
find_unused_parameters=find_unused_parameters)
|
||||||
else:
|
else:
|
||||||
model = MMDataParallel(
|
if not torch.cuda.is_available():
|
||||||
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
|
assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
|
||||||
|
'Please use MMCV >= 1.4.4 for CPU training!'
|
||||||
|
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
|
||||||
# build runner
|
# build runner
|
||||||
optimizer = build_optimizer(model, cfg.optimizer)
|
optimizer = build_optimizer(model, cfg.optimizer)
|
||||||
|
|
||||||
|
@ -8,11 +8,13 @@ import warnings
|
|||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import torch
|
import torch
|
||||||
|
from mmcv.cnn.utils import revert_sync_batchnorm
|
||||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||||
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
|
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
|
||||||
wrap_fp16_model)
|
wrap_fp16_model)
|
||||||
from mmcv.utils import DictAction
|
from mmcv.utils import DictAction
|
||||||
|
|
||||||
|
from mmseg import digit_version
|
||||||
from mmseg.apis import multi_gpu_test, single_gpu_test
|
from mmseg.apis import multi_gpu_test, single_gpu_test
|
||||||
from mmseg.datasets import build_dataloader, build_dataset
|
from mmseg.datasets import build_dataloader, build_dataset
|
||||||
from mmseg.models import build_segmentor
|
from mmseg.models import build_segmentor
|
||||||
@ -147,11 +149,18 @@ def main():
|
|||||||
cfg.model.pretrained = None
|
cfg.model.pretrained = None
|
||||||
cfg.data.test.test_mode = True
|
cfg.data.test.test_mode = True
|
||||||
|
|
||||||
|
if args.gpu_id is not None:
|
||||||
cfg.gpu_ids = [args.gpu_id]
|
cfg.gpu_ids = [args.gpu_id]
|
||||||
|
|
||||||
# init distributed env first, since logger depends on the dist info.
|
# init distributed env first, since logger depends on the dist info.
|
||||||
if args.launcher == 'none':
|
if args.launcher == 'none':
|
||||||
|
cfg.gpu_ids = [args.gpu_id]
|
||||||
distributed = False
|
distributed = False
|
||||||
|
if len(cfg.gpu_ids) > 1:
|
||||||
|
warnings.warn(f'The gpu-ids is reset from {cfg.gpu_ids} to '
|
||||||
|
f'{cfg.gpu_ids[0:1]} to avoid potential error in '
|
||||||
|
'non-distribute testing time.')
|
||||||
|
cfg.gpu_ids = cfg.gpu_ids[0:1]
|
||||||
else:
|
else:
|
||||||
distributed = True
|
distributed = True
|
||||||
init_dist(args.launcher, **cfg.dist_params)
|
init_dist(args.launcher, **cfg.dist_params)
|
||||||
@ -236,7 +245,15 @@ def main():
|
|||||||
tmpdir = None
|
tmpdir = None
|
||||||
|
|
||||||
if not distributed:
|
if not distributed:
|
||||||
model = MMDataParallel(model, device_ids=[0])
|
warnings.warn(
|
||||||
|
'SyncBN is only supported with DDP. To be compatible with DP, '
|
||||||
|
'we convert SyncBN to BN. Please use dist_train.sh which can '
|
||||||
|
'avoid this error.')
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
|
||||||
|
'Please use MMCV >= 1.4.4 for CPU training!'
|
||||||
|
model = revert_sync_batchnorm(model)
|
||||||
|
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
|
||||||
results = single_gpu_test(
|
results = single_gpu_test(
|
||||||
model,
|
model,
|
||||||
data_loader,
|
data_loader,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user