[Fix] Convert SyncBN to BN when training on DP (#772)

* [Fix] Convert SyncBN to BN when training on DP.

* Modify SyncBN2BN.

* Add SyncBN2BN unit test.

* Resolve some comments.

* Use mmcv official revert_sync_batchnorm.

* Remove local syncbn2bn unit tests.

* Update mmcv version.

* Fix bugs of gather model tools.

* Modify warnings.

* Modify docker mmcv version.

* Update mmcv version table.
Author: sennnnn, 2021-09-16 00:39:37 +08:00 (committed by GitHub)
Parent: 5a7996db26
Commit: cae715a4b6
9 changed files with 26 additions and 30 deletions
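
The core of the change is replacing the project's hand-rolled SyncBN-to-BN conversion with `mmcv.cnn.utils.revert_sync_batchnorm` and calling it whenever training is not distributed. A minimal sketch of the idea follows (toy model, not code from this PR): `SyncBatchNorm` layers need an initialized process group, so they must be reverted to plain batch norm before the model is wrapped for `nn.DataParallel` training.

```python
# Minimal sketch, not taken from this PR: why SyncBN must be reverted before DP.
import torch
import torch.nn as nn
from mmcv.cnn.utils import revert_sync_batchnorm  # needs mmcv >= 1.3.13, hence the version bumps below

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8), nn.ReLU())

# SyncBatchNorm.forward() requires torch.distributed to be initialized, so a
# single-process DP run would fail; revert the layers to regular batch norm.
model = revert_sync_batchnorm(model)

if torch.cuda.is_available():
    model = nn.DataParallel(model.cuda())  # now safe to train with DataParallel
```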


@@ -75,7 +75,7 @@ def get_final_results(log_json_path, iter_num):
 def parse_args():
     parser = argparse.ArgumentParser(description='Gather benchmarked models')
     parser.add_argument(
-        '-c', '--config-name', type=str, help='Process the selected config.')
+        '-f', '--config-name', type=str, help='Process the selected config.')
     parser.add_argument(
         '-w',
         '--work-dir',


@ -1,7 +1,7 @@
ARG PYTORCH="1.6.0" ARG PYTORCH="1.6.0"
ARG CUDA="10.1" ARG CUDA="10.1"
ARG CUDNN="7" ARG CUDNN="7"
ARG MMCV="1.3.12" ARG MMCV="1.3.13"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel


@ -3,7 +3,7 @@ ARG CUDA="10.1"
ARG CUDNN="7" ARG CUDNN="7"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
ARG MMCV="1.3.12" ARG MMCV="1.3.13"
ARG MMSEG="0.17.0" ARG MMSEG="0.17.0"
ENV PYTHONUNBUFFERED TRUE ENV PYTHONUNBUFFERED TRUE


@@ -11,7 +11,7 @@ The compatible MMSegmentation and MMCV versions are as below. Please install the
 | MMSegmentation version | MMCV version |
 |:-------------------:|:-------------------:|
-| master | mmcv-full>=1.3.7, <1.4.0 |
+| master | mmcv-full>=1.3.13, <1.4.0 |
 | 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
 | 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
 | 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |


@@ -19,6 +19,14 @@ To trade speed with GPU memory, you may pass in `--options model.backbone.with_c
 ### Train with a single GPU
+official support:
+
+```shell
+./tools/dist_train.sh ${CONFIG_FILE} 1 [optional arguments]
+```
+
+experimental support (Convert SyncBN to BN):
+
 ```shell
 python tools/train.py ${CONFIG_FILE} [optional arguments]
 ```


@@ -11,7 +11,7 @@
 | MMSegmentation 版本 | MMCV 版本 |
 |:-------------------:|:-------------------:|
-| master | mmcv-full>=1.3.7, <1.4.0 |
+| master | mmcv-full>=1.3.13, <1.4.0 |
 | 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
 | 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
 | 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |


@@ -6,7 +6,7 @@ from packaging.version import parse
 from .version import __version__, version_info
-MMCV_MIN = '1.3.7'
+MMCV_MIN = '1.3.13'
 MMCV_MAX = '1.4.0'
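
Since `revert_sync_batchnorm` only exists in newer mmcv releases, the minimum supported version is raised. A hedged sketch of the kind of guard this constant feeds (mmseg's actual check in `mmseg/__init__.py` may differ in detail):

```python
# Hedged sketch of a version guard; mmseg's real check may differ in detail.
import mmcv
from packaging.version import parse

MMCV_MIN = '1.3.13'
MMCV_MAX = '1.4.0'

mmcv_version = parse(mmcv.__version__)
assert parse(MMCV_MIN) <= mmcv_version < parse(MMCV_MAX), (
    f'mmcv=={mmcv.__version__} is incompatible with this MMSegmentation; '
    f'please install mmcv-full>={MMCV_MIN}, <{MMCV_MAX}.')
```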


@@ -8,7 +8,7 @@ import numpy as np
 import pytest
 import torch
 import torch.nn as nn
-from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm
+from mmcv.cnn.utils import revert_sync_batchnorm
 def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10):
@@ -189,28 +189,6 @@ def _check_input_dim(self, inputs):
     pass
-def _convert_batchnorm(module):
-    module_output = module
-    if isinstance(module, SyncBatchNorm):
-        # to be consistent with SyncBN, we hack dim check function in BN
-        module_output = _BatchNorm(module.num_features, module.eps,
-                                   module.momentum, module.affine,
-                                   module.track_running_stats)
-        if module.affine:
-            module_output.weight.data = module.weight.data.clone().detach()
-            module_output.bias.data = module.bias.data.clone().detach()
-            # keep requires_grad unchanged
-            module_output.weight.requires_grad = module.weight.requires_grad
-            module_output.bias.requires_grad = module.bias.requires_grad
-        module_output.running_mean = module.running_mean
-        module_output.running_var = module.running_var
-        module_output.num_batches_tracked = module.num_batches_tracked
-    for name, child in module.named_children():
-        module_output.add_module(name, _convert_batchnorm(child))
-    del module
-    return module_output
 @patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
        _check_input_dim)
 @patch('torch.distributed.get_world_size', get_world_size)
@@ -241,7 +219,7 @@ def _test_encoder_decoder_forward(cfg_file):
         imgs = imgs.cuda()
         gt_semantic_seg = gt_semantic_seg.cuda()
     else:
-        segmentor = _convert_batchnorm(segmentor)
+        segmentor = revert_sync_batchnorm(segmentor)
     # Test forward train
     losses = segmentor.forward(
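
The removed `_convert_batchnorm` helper duplicated functionality that mmcv now ships. A standalone check like the following (an assumed illustration, not part of this PR) shows `revert_sync_batchnorm` performing the same recursive SyncBN-to-BN conversion the local helper did, which is why the helper and its local unit test could be dropped:

```python
# Assumed illustration, not part of this PR: mmcv's revert_sync_batchnorm
# recursively replaces SyncBatchNorm layers, so the local helper is redundant.
import torch.nn as nn
from mmcv.cnn.utils import revert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # BN -> SyncBN
reverted = revert_sync_batchnorm(sync_model)                 # SyncBN -> BN

assert not any(isinstance(m, nn.SyncBatchNorm) for m in reverted.modules())
assert any(
    isinstance(m, nn.modules.batchnorm._BatchNorm) for m in reverted.modules())
```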


@@ -4,9 +4,11 @@ import copy
 import os
 import os.path as osp
 import time
+import warnings
 import mmcv
 import torch
+from mmcv.cnn.utils import revert_sync_batchnorm
 from mmcv.runner import get_dist_info, init_dist
 from mmcv.utils import Config, DictAction, get_git_hash
@@ -137,6 +139,14 @@ def main():
         test_cfg=cfg.get('test_cfg'))
     model.init_weights()
+    # SyncBN is not support for DP
+    if not distributed:
+        warnings.warn(
+            'SyncBN is only supported with DDP. To be compatible with DP, '
+            'we convert SyncBN to BN. Please use dist_train.sh which can '
+            'avoid this error.')
+        model = revert_sync_batchnorm(model)
     logger.info(model)
     datasets = [build_dataset(cfg.data.train)]
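
After the conversion, the non-distributed path proceeds as a normal single-process run. Below is a hedged sketch of how the reverted model might then be wrapped for DP-style training; the helper is hypothetical, and mmseg's real wrapping happens inside `mmseg/apis/train.py` and may differ in detail.

```python
# Hypothetical helper, not from this PR: sketch of the non-distributed branch.
import warnings

import torch
import torch.nn as nn
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.parallel import MMDataParallel


def prepare_model(model: nn.Module, distributed: bool) -> nn.Module:
    """Revert SyncBN and wrap the model for single-process (DP) training."""
    if not distributed:
        warnings.warn('SyncBN only works with DDP; converting SyncBN to BN. '
                      'Use tools/dist_train.sh to keep SyncBN.')
        model = revert_sync_batchnorm(model)
        if torch.cuda.is_available():
            # mmcv's DataParallel wrapper, as used in OpenMMLab training loops
            model = MMDataParallel(model.cuda(), device_ids=[0])
    return model
```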