[Fix] Convert SyncBN to BN when training on DP (#772)

* [Fix] Convert SyncBN to BN when training on DP.

* Modify SyncBN2BN.

* Add SyncBN2BN unit test.

* Resolve some comments.

* use mmcv official revert_sync_batchnorm

* Remove local syncbn2bn unit tests.

* Update mmcv version.

* Fix bugs of gather model tools.

* Modify warnings.

* Modify docker mmcv version.

* Update mmcv version table.
pull/863/head
sennnnn 2021-09-16 00:39:37 +08:00 committed by GitHub
parent 5a7996db26
commit cae715a4b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 26 additions and 30 deletions

View File

@ -75,7 +75,7 @@ def get_final_results(log_json_path, iter_num):
def parse_args():
parser = argparse.ArgumentParser(description='Gather benchmarked models')
parser.add_argument(
'-c', '--config-name', type=str, help='Process the selected config.')
'-f', '--config-name', type=str, help='Process the selected config.')
parser.add_argument(
'-w',
'--work-dir',

View File

@ -1,7 +1,7 @@
ARG PYTORCH="1.6.0"
ARG CUDA="10.1"
ARG CUDNN="7"
ARG MMCV="1.3.12"
ARG MMCV="1.3.13"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel

View File

@ -3,7 +3,7 @@ ARG CUDA="10.1"
ARG CUDNN="7"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
ARG MMCV="1.3.12"
ARG MMCV="1.3.13"
ARG MMSEG="0.17.0"
ENV PYTHONUNBUFFERED TRUE

View File

@ -11,7 +11,7 @@ The compatible MMSegmentation and MMCV versions are as below. Please install the
| MMSegmentation version | MMCV version |
|:-------------------:|:-------------------:|
| master | mmcv-full>=1.3.7, <1.4.0 |
| master | mmcv-full>=1.3.13, <1.4.0 |
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |

View File

@ -19,6 +19,14 @@ To trade speed with GPU memory, you may pass in `--options model.backbone.with_c
### Train with a single GPU
official support:
```shell
./tools/dist_train.sh ${CONFIG_FILE} 1 [optional arguments]
```
experimental support (Convert SyncBN to BN):
```shell
python tools/train.py ${CONFIG_FILE} [optional arguments]
```

View File

@ -11,7 +11,7 @@
| MMSegmentation 版本 | MMCV 版本 |
|:-------------------:|:-------------------:|
| master | mmcv-full>=1.3.7, <1.4.0 |
| master | mmcv-full>=1.3.13, <1.4.0 |
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |

View File

@ -6,7 +6,7 @@ from packaging.version import parse
from .version import __version__, version_info
MMCV_MIN = '1.3.7'
MMCV_MIN = '1.3.13'
MMCV_MAX = '1.4.0'

View File

@ -8,7 +8,7 @@ import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm
from mmcv.cnn.utils import revert_sync_batchnorm
def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10):
@ -189,28 +189,6 @@ def _check_input_dim(self, inputs):
pass
def _convert_batchnorm(module):
module_output = module
if isinstance(module, SyncBatchNorm):
# to be consistent with SyncBN, we hack dim check function in BN
module_output = _BatchNorm(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
module_output.weight.data = module.weight.data.clone().detach()
module_output.bias.data = module.bias.data.clone().detach()
# keep requires_grad unchanged
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, _convert_batchnorm(child))
del module
return module_output
@patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
_check_input_dim)
@patch('torch.distributed.get_world_size', get_world_size)
@ -241,7 +219,7 @@ def _test_encoder_decoder_forward(cfg_file):
imgs = imgs.cuda()
gt_semantic_seg = gt_semantic_seg.cuda()
else:
segmentor = _convert_batchnorm(segmentor)
segmentor = revert_sync_batchnorm(segmentor)
# Test forward train
losses = segmentor.forward(

View File

@ -4,9 +4,11 @@ import copy
import os
import os.path as osp
import time
import warnings
import mmcv
import torch
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import Config, DictAction, get_git_hash
@ -137,6 +139,14 @@ def main():
test_cfg=cfg.get('test_cfg'))
model.init_weights()
# SyncBN is not support for DP
if not distributed:
warnings.warn(
'SyncBN is only supported with DDP. To be compatible with DP, '
'we convert SyncBN to BN. Please use dist_train.sh which can '
'avoid this error.')
model = revert_sync_batchnorm(model)
logger.info(model)
datasets = [build_dataset(cfg.data.train)]