[Fix] Convert SyncBN to BN when training on DP (#772)
* [Fix] Convert SyncBN to BN when training on DP. * Modify SyncBN2BN. * Add SyncBN2BN unit test. * Resolve some comments. * use mmcv official revert_sync_batchnorm * Remove local syncbn2bn unit tests. * Update mmcv version. * Fix bugs of gather model tools. * Modify warnings. * Modify docker mmcv version. * Update mmcv version table.pull/863/head
parent
5a7996db26
commit
cae715a4b6
|
@ -75,7 +75,7 @@ def get_final_results(log_json_path, iter_num):
|
|||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Gather benchmarked models')
|
||||
parser.add_argument(
|
||||
'-c', '--config-name', type=str, help='Process the selected config.')
|
||||
'-f', '--config-name', type=str, help='Process the selected config.')
|
||||
parser.add_argument(
|
||||
'-w',
|
||||
'--work-dir',
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
ARG PYTORCH="1.6.0"
|
||||
ARG CUDA="10.1"
|
||||
ARG CUDNN="7"
|
||||
ARG MMCV="1.3.12"
|
||||
ARG MMCV="1.3.13"
|
||||
|
||||
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ ARG CUDA="10.1"
|
|||
ARG CUDNN="7"
|
||||
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
|
||||
|
||||
ARG MMCV="1.3.12"
|
||||
ARG MMCV="1.3.13"
|
||||
ARG MMSEG="0.17.0"
|
||||
|
||||
ENV PYTHONUNBUFFERED TRUE
|
||||
|
|
|
@ -11,7 +11,7 @@ The compatible MMSegmentation and MMCV versions are as below. Please install the
|
|||
|
||||
| MMSegmentation version | MMCV version |
|
||||
|:-------------------:|:-------------------:|
|
||||
| master | mmcv-full>=1.3.7, <1.4.0 |
|
||||
| master | mmcv-full>=1.3.13, <1.4.0 |
|
||||
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||
|
|
|
@ -19,6 +19,14 @@ To trade speed with GPU memory, you may pass in `--options model.backbone.with_c
|
|||
|
||||
### Train with a single GPU
|
||||
|
||||
official support:
|
||||
|
||||
```shell
|
||||
./tools/dist_train.sh ${CONFIG_FILE} 1 [optional arguments]
|
||||
```
|
||||
|
||||
experimental support (Convert SyncBN to BN):
|
||||
|
||||
```shell
|
||||
python tools/train.py ${CONFIG_FILE} [optional arguments]
|
||||
```
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
| MMSegmentation 版本 | MMCV 版本 |
|
||||
|:-------------------:|:-------------------:|
|
||||
| master | mmcv-full>=1.3.7, <1.4.0 |
|
||||
| master | mmcv-full>=1.3.13, <1.4.0 |
|
||||
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||
|
|
|
@ -6,7 +6,7 @@ from packaging.version import parse
|
|||
|
||||
from .version import __version__, version_info
|
||||
|
||||
MMCV_MIN = '1.3.7'
|
||||
MMCV_MIN = '1.3.13'
|
||||
MMCV_MAX = '1.4.0'
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm
|
||||
from mmcv.cnn.utils import revert_sync_batchnorm
|
||||
|
||||
|
||||
def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10):
|
||||
|
@ -189,28 +189,6 @@ def _check_input_dim(self, inputs):
|
|||
pass
|
||||
|
||||
|
||||
def _convert_batchnorm(module):
|
||||
module_output = module
|
||||
if isinstance(module, SyncBatchNorm):
|
||||
# to be consistent with SyncBN, we hack dim check function in BN
|
||||
module_output = _BatchNorm(module.num_features, module.eps,
|
||||
module.momentum, module.affine,
|
||||
module.track_running_stats)
|
||||
if module.affine:
|
||||
module_output.weight.data = module.weight.data.clone().detach()
|
||||
module_output.bias.data = module.bias.data.clone().detach()
|
||||
# keep requires_grad unchanged
|
||||
module_output.weight.requires_grad = module.weight.requires_grad
|
||||
module_output.bias.requires_grad = module.bias.requires_grad
|
||||
module_output.running_mean = module.running_mean
|
||||
module_output.running_var = module.running_var
|
||||
module_output.num_batches_tracked = module.num_batches_tracked
|
||||
for name, child in module.named_children():
|
||||
module_output.add_module(name, _convert_batchnorm(child))
|
||||
del module
|
||||
return module_output
|
||||
|
||||
|
||||
@patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
|
||||
_check_input_dim)
|
||||
@patch('torch.distributed.get_world_size', get_world_size)
|
||||
|
@ -241,7 +219,7 @@ def _test_encoder_decoder_forward(cfg_file):
|
|||
imgs = imgs.cuda()
|
||||
gt_semantic_seg = gt_semantic_seg.cuda()
|
||||
else:
|
||||
segmentor = _convert_batchnorm(segmentor)
|
||||
segmentor = revert_sync_batchnorm(segmentor)
|
||||
|
||||
# Test forward train
|
||||
losses = segmentor.forward(
|
||||
|
|
|
@ -4,9 +4,11 @@ import copy
|
|||
import os
|
||||
import os.path as osp
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from mmcv.cnn.utils import revert_sync_batchnorm
|
||||
from mmcv.runner import get_dist_info, init_dist
|
||||
from mmcv.utils import Config, DictAction, get_git_hash
|
||||
|
||||
|
@ -137,6 +139,14 @@ def main():
|
|||
test_cfg=cfg.get('test_cfg'))
|
||||
model.init_weights()
|
||||
|
||||
# SyncBN is not support for DP
|
||||
if not distributed:
|
||||
warnings.warn(
|
||||
'SyncBN is only supported with DDP. To be compatible with DP, '
|
||||
'we convert SyncBN to BN. Please use dist_train.sh which can '
|
||||
'avoid this error.')
|
||||
model = revert_sync_batchnorm(model)
|
||||
|
||||
logger.info(model)
|
||||
|
||||
datasets = [build_dataset(cfg.data.train)]
|
||||
|
|
Loading…
Reference in New Issue