mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Convert SyncBN to BN when training on DP (#772)
* [Fix] Convert SyncBN to BN when training on DP. * Modify SyncBN2BN. * Add SyncBN2BN unit test. * Resolve some comments. * use mmcv official revert_sync_batchnorm * Remove local syncbn2bn unit tests. * Update mmcv version. * Fix bugs of gather model tools. * Modify warnings. * Modify docker mmcv version. * Update mmcv version table.
This commit is contained in:
parent
5a7996db26
commit
cae715a4b6
@ -75,7 +75,7 @@ def get_final_results(log_json_path, iter_num):
|
|||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Gather benchmarked models')
|
parser = argparse.ArgumentParser(description='Gather benchmarked models')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-c', '--config-name', type=str, help='Process the selected config.')
|
'-f', '--config-name', type=str, help='Process the selected config.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-w',
|
'-w',
|
||||||
'--work-dir',
|
'--work-dir',
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
ARG PYTORCH="1.6.0"
|
ARG PYTORCH="1.6.0"
|
||||||
ARG CUDA="10.1"
|
ARG CUDA="10.1"
|
||||||
ARG CUDNN="7"
|
ARG CUDNN="7"
|
||||||
ARG MMCV="1.3.12"
|
ARG MMCV="1.3.13"
|
||||||
|
|
||||||
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
|
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ ARG CUDA="10.1"
|
|||||||
ARG CUDNN="7"
|
ARG CUDNN="7"
|
||||||
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
|
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
|
||||||
|
|
||||||
ARG MMCV="1.3.12"
|
ARG MMCV="1.3.13"
|
||||||
ARG MMSEG="0.17.0"
|
ARG MMSEG="0.17.0"
|
||||||
|
|
||||||
ENV PYTHONUNBUFFERED TRUE
|
ENV PYTHONUNBUFFERED TRUE
|
||||||
|
@ -11,7 +11,7 @@ The compatible MMSegmentation and MMCV versions are as below. Please install the
|
|||||||
|
|
||||||
| MMSegmentation version | MMCV version |
|
| MMSegmentation version | MMCV version |
|
||||||
|:-------------------:|:-------------------:|
|
|:-------------------:|:-------------------:|
|
||||||
| master | mmcv-full>=1.3.7, <1.4.0 |
|
| master | mmcv-full>=1.3.13, <1.4.0 |
|
||||||
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
|
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||||
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
|
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||||
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |
|
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||||
|
@ -19,6 +19,14 @@ To trade speed with GPU memory, you may pass in `--options model.backbone.with_c
|
|||||||
|
|
||||||
### Train with a single GPU
|
### Train with a single GPU
|
||||||
|
|
||||||
|
official support:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
./tools/dist_train.sh ${CONFIG_FILE} 1 [optional arguments]
|
||||||
|
```
|
||||||
|
|
||||||
|
experimental support (Convert SyncBN to BN):
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
python tools/train.py ${CONFIG_FILE} [optional arguments]
|
python tools/train.py ${CONFIG_FILE} [optional arguments]
|
||||||
```
|
```
|
||||||
|
@ -11,7 +11,7 @@
|
|||||||
|
|
||||||
| MMSegmentation 版本 | MMCV 版本 |
|
| MMSegmentation 版本 | MMCV 版本 |
|
||||||
|:-------------------:|:-------------------:|
|
|:-------------------:|:-------------------:|
|
||||||
| master | mmcv-full>=1.3.7, <1.4.0 |
|
| master | mmcv-full>=1.3.13, <1.4.0 |
|
||||||
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
|
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||||
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
|
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||||
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |
|
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |
|
||||||
|
@ -6,7 +6,7 @@ from packaging.version import parse
|
|||||||
|
|
||||||
from .version import __version__, version_info
|
from .version import __version__, version_info
|
||||||
|
|
||||||
MMCV_MIN = '1.3.7'
|
MMCV_MIN = '1.3.13'
|
||||||
MMCV_MAX = '1.4.0'
|
MMCV_MAX = '1.4.0'
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm
|
from mmcv.cnn.utils import revert_sync_batchnorm
|
||||||
|
|
||||||
|
|
||||||
def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10):
|
def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10):
|
||||||
@ -189,28 +189,6 @@ def _check_input_dim(self, inputs):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _convert_batchnorm(module):
|
|
||||||
module_output = module
|
|
||||||
if isinstance(module, SyncBatchNorm):
|
|
||||||
# to be consistent with SyncBN, we hack dim check function in BN
|
|
||||||
module_output = _BatchNorm(module.num_features, module.eps,
|
|
||||||
module.momentum, module.affine,
|
|
||||||
module.track_running_stats)
|
|
||||||
if module.affine:
|
|
||||||
module_output.weight.data = module.weight.data.clone().detach()
|
|
||||||
module_output.bias.data = module.bias.data.clone().detach()
|
|
||||||
# keep requires_grad unchanged
|
|
||||||
module_output.weight.requires_grad = module.weight.requires_grad
|
|
||||||
module_output.bias.requires_grad = module.bias.requires_grad
|
|
||||||
module_output.running_mean = module.running_mean
|
|
||||||
module_output.running_var = module.running_var
|
|
||||||
module_output.num_batches_tracked = module.num_batches_tracked
|
|
||||||
for name, child in module.named_children():
|
|
||||||
module_output.add_module(name, _convert_batchnorm(child))
|
|
||||||
del module
|
|
||||||
return module_output
|
|
||||||
|
|
||||||
|
|
||||||
@patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
|
@patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
|
||||||
_check_input_dim)
|
_check_input_dim)
|
||||||
@patch('torch.distributed.get_world_size', get_world_size)
|
@patch('torch.distributed.get_world_size', get_world_size)
|
||||||
@ -241,7 +219,7 @@ def _test_encoder_decoder_forward(cfg_file):
|
|||||||
imgs = imgs.cuda()
|
imgs = imgs.cuda()
|
||||||
gt_semantic_seg = gt_semantic_seg.cuda()
|
gt_semantic_seg = gt_semantic_seg.cuda()
|
||||||
else:
|
else:
|
||||||
segmentor = _convert_batchnorm(segmentor)
|
segmentor = revert_sync_batchnorm(segmentor)
|
||||||
|
|
||||||
# Test forward train
|
# Test forward train
|
||||||
losses = segmentor.forward(
|
losses = segmentor.forward(
|
||||||
|
@ -4,9 +4,11 @@ import copy
|
|||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import time
|
import time
|
||||||
|
import warnings
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import torch
|
import torch
|
||||||
|
from mmcv.cnn.utils import revert_sync_batchnorm
|
||||||
from mmcv.runner import get_dist_info, init_dist
|
from mmcv.runner import get_dist_info, init_dist
|
||||||
from mmcv.utils import Config, DictAction, get_git_hash
|
from mmcv.utils import Config, DictAction, get_git_hash
|
||||||
|
|
||||||
@ -137,6 +139,14 @@ def main():
|
|||||||
test_cfg=cfg.get('test_cfg'))
|
test_cfg=cfg.get('test_cfg'))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
|
|
||||||
|
# SyncBN is not support for DP
|
||||||
|
if not distributed:
|
||||||
|
warnings.warn(
|
||||||
|
'SyncBN is only supported with DDP. To be compatible with DP, '
|
||||||
|
'we convert SyncBN to BN. Please use dist_train.sh which can '
|
||||||
|
'avoid this error.')
|
||||||
|
model = revert_sync_batchnorm(model)
|
||||||
|
|
||||||
logger.info(model)
|
logger.info(model)
|
||||||
|
|
||||||
datasets = [build_dataset(cfg.data.train)]
|
datasets = [build_dataset(cfg.data.train)]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user