mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Fix] Replace SyncBN with BN for inference (#420)
* add revert_sync_batchnorm * replace SyncBN in inference and test scripts * add tests * hide BatchNormXd
This commit is contained in:
parent
532e8f808d
commit
7bbb14f0d1
@ -9,6 +9,7 @@ from .fileio import list_from_file, list_to_file
|
|||||||
from .img_util import drop_orientation, is_not_png
|
from .img_util import drop_orientation, is_not_png
|
||||||
from .lmdb_util import lmdb_converter
|
from .lmdb_util import lmdb_converter
|
||||||
from .logger import get_root_logger
|
from .logger import get_root_logger
|
||||||
|
from .model import revert_sync_batchnorm
|
||||||
from .string_util import StringStrip
|
from .string_util import StringStrip
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -17,5 +18,5 @@ __all__ = [
|
|||||||
'equal_len', 'is_2dlist', 'valid_boundary', 'lmdb_converter',
|
'equal_len', 'is_2dlist', 'valid_boundary', 'lmdb_converter',
|
||||||
'drop_orientation', 'convert_annotations', 'is_not_png', 'list_to_file',
|
'drop_orientation', 'convert_annotations', 'is_not_png', 'list_to_file',
|
||||||
'list_from_file', 'is_on_same_line', 'stitch_boxes_into_lines',
|
'list_from_file', 'is_on_same_line', 'stitch_boxes_into_lines',
|
||||||
'StringStrip'
|
'StringStrip', 'revert_sync_batchnorm'
|
||||||
]
|
]
|
||||||
|
49
mmocr/utils/model.py
Normal file
49
mmocr/utils/model.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
|
||||||
|
"""A general BatchNorm layer without input dimension check.
|
||||||
|
|
||||||
|
Reproduced from @kapily's work:
|
||||||
|
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
|
||||||
|
The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
|
||||||
|
is `_check_input_dim` that is designed for tensor sanity checks.
|
||||||
|
The check has been bypassed in this class for the convenience of converting
|
||||||
|
SyncBatchNorm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _check_input_dim(self, input):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def revert_sync_batchnorm(module):
|
||||||
|
"""Helper function to convert all `SyncBatchNorm` layers in the model to
|
||||||
|
`BatchNormXd` layers.
|
||||||
|
|
||||||
|
Reproduced from @kapily's work:
|
||||||
|
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (nn.Module): The module containing `SyncBatchNorm` layers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
module_output: The converted module with `BatchNormXd` layers.
|
||||||
|
"""
|
||||||
|
module_output = module
|
||||||
|
if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
|
||||||
|
module_output = _BatchNormXd(module.num_features, module.eps,
|
||||||
|
module.momentum, module.affine,
|
||||||
|
module.track_running_stats)
|
||||||
|
if module.affine:
|
||||||
|
with torch.no_grad():
|
||||||
|
module_output.weight = module.weight
|
||||||
|
module_output.bias = module.bias
|
||||||
|
module_output.running_mean = module.running_mean
|
||||||
|
module_output.running_var = module.running_var
|
||||||
|
module_output.num_batches_tracked = module.num_batches_tracked
|
||||||
|
if hasattr(module, 'qconfig'):
|
||||||
|
module_output.qconfig = module.qconfig
|
||||||
|
for name, child in module.named_children():
|
||||||
|
module_output.add_module(name, revert_sync_batchnorm(child))
|
||||||
|
del module
|
||||||
|
return module_output
|
@ -19,6 +19,7 @@ from mmocr.datasets.pipelines.crop import crop_img
|
|||||||
from mmocr.models import build_detector
|
from mmocr.models import build_detector
|
||||||
from mmocr.utils.box_util import stitch_boxes_into_lines
|
from mmocr.utils.box_util import stitch_boxes_into_lines
|
||||||
from mmocr.utils.fileio import list_from_file
|
from mmocr.utils.fileio import list_from_file
|
||||||
|
from mmocr.utils.model import revert_sync_batchnorm
|
||||||
|
|
||||||
|
|
||||||
# Parse CLI arguments
|
# Parse CLI arguments
|
||||||
@ -324,6 +325,7 @@ class MMOCR:
|
|||||||
|
|
||||||
self.detect_model = init_detector(
|
self.detect_model = init_detector(
|
||||||
det_config, det_ckpt, device=self.device)
|
det_config, det_ckpt, device=self.device)
|
||||||
|
self.detect_model = revert_sync_batchnorm(self.detect_model)
|
||||||
|
|
||||||
self.recog_model = None
|
self.recog_model = None
|
||||||
if self.tr:
|
if self.tr:
|
||||||
@ -338,6 +340,7 @@ class MMOCR:
|
|||||||
|
|
||||||
self.recog_model = init_detector(
|
self.recog_model = init_detector(
|
||||||
recog_config, recog_ckpt, device=self.device)
|
recog_config, recog_ckpt, device=self.device)
|
||||||
|
self.recog_model = revert_sync_batchnorm(self.recog_model)
|
||||||
|
|
||||||
self.kie_model = None
|
self.kie_model = None
|
||||||
if self.kie:
|
if self.kie:
|
||||||
@ -352,6 +355,7 @@ class MMOCR:
|
|||||||
kie_cfg = Config.fromfile(kie_config)
|
kie_cfg = Config.fromfile(kie_config)
|
||||||
self.kie_model = build_detector(
|
self.kie_model = build_detector(
|
||||||
kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
|
kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
|
||||||
|
self.kie_model = revert_sync_batchnorm(self.kie_model)
|
||||||
self.kie_model.cfg = kie_cfg
|
self.kie_model.cfg = kie_cfg
|
||||||
load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)
|
load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import mmocr.core.evaluation.utils as utils
|
import mmocr.core.evaluation.utils as utils
|
||||||
|
from mmocr.utils import revert_sync_batchnorm
|
||||||
|
|
||||||
|
|
||||||
def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300),
|
def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300),
|
||||||
@ -192,10 +193,10 @@ def test_ocr_mask_rcnn(cfg_file):
|
|||||||
def test_panet(cfg_file):
|
def test_panet(cfg_file):
|
||||||
model = _get_detector_cfg(cfg_file)
|
model = _get_detector_cfg(cfg_file)
|
||||||
model['pretrained'] = None
|
model['pretrained'] = None
|
||||||
model['backbone']['norm_cfg']['type'] = 'BN'
|
|
||||||
|
|
||||||
from mmocr.models import build_detector
|
from mmocr.models import build_detector
|
||||||
detector = build_detector(model)
|
detector = build_detector(model)
|
||||||
|
detector = revert_sync_batchnorm(detector)
|
||||||
|
|
||||||
input_shape = (1, 3, 224, 224)
|
input_shape = (1, 3, 224, 224)
|
||||||
num_kernels = 2
|
num_kernels = 2
|
||||||
@ -247,10 +248,10 @@ def test_panet(cfg_file):
|
|||||||
def test_psenet(cfg_file):
|
def test_psenet(cfg_file):
|
||||||
model = _get_detector_cfg(cfg_file)
|
model = _get_detector_cfg(cfg_file)
|
||||||
model['pretrained'] = None
|
model['pretrained'] = None
|
||||||
model['backbone']['norm_cfg']['type'] = 'BN'
|
|
||||||
|
|
||||||
from mmocr.models import build_detector
|
from mmocr.models import build_detector
|
||||||
detector = build_detector(model)
|
detector = build_detector(model)
|
||||||
|
detector = revert_sync_batchnorm(detector)
|
||||||
|
|
||||||
input_shape = (1, 3, 224, 224)
|
input_shape = (1, 3, 224, 224)
|
||||||
num_kernels = 7
|
num_kernels = 7
|
||||||
@ -289,10 +290,10 @@ def test_psenet(cfg_file):
|
|||||||
def test_dbnet(cfg_file):
|
def test_dbnet(cfg_file):
|
||||||
model = _get_detector_cfg(cfg_file)
|
model = _get_detector_cfg(cfg_file)
|
||||||
model['pretrained'] = None
|
model['pretrained'] = None
|
||||||
model['backbone']['norm_cfg']['type'] = 'BN'
|
|
||||||
|
|
||||||
from mmocr.models import build_detector
|
from mmocr.models import build_detector
|
||||||
detector = build_detector(model)
|
detector = build_detector(model)
|
||||||
|
detector = revert_sync_batchnorm(detector)
|
||||||
detector = detector.cuda()
|
detector = detector.cuda()
|
||||||
input_shape = (1, 3, 224, 224)
|
input_shape = (1, 3, 224, 224)
|
||||||
num_kernels = 7
|
num_kernels = 7
|
||||||
@ -338,10 +339,10 @@ def test_dbnet(cfg_file):
|
|||||||
def test_textsnake(cfg_file):
|
def test_textsnake(cfg_file):
|
||||||
model = _get_detector_cfg(cfg_file)
|
model = _get_detector_cfg(cfg_file)
|
||||||
model['pretrained'] = None
|
model['pretrained'] = None
|
||||||
model['backbone']['norm_cfg']['type'] = 'BN'
|
|
||||||
|
|
||||||
from mmocr.models import build_detector
|
from mmocr.models import build_detector
|
||||||
detector = build_detector(model)
|
detector = build_detector(model)
|
||||||
|
detector = revert_sync_batchnorm(detector)
|
||||||
input_shape = (1, 3, 224, 224)
|
input_shape = (1, 3, 224, 224)
|
||||||
num_kernels = 1
|
num_kernels = 1
|
||||||
mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
|
mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
|
||||||
@ -394,10 +395,10 @@ def test_textsnake(cfg_file):
|
|||||||
def test_fcenet(cfg_file):
|
def test_fcenet(cfg_file):
|
||||||
model = _get_detector_cfg(cfg_file)
|
model = _get_detector_cfg(cfg_file)
|
||||||
model['pretrained'] = None
|
model['pretrained'] = None
|
||||||
model['backbone']['norm_cfg']['type'] = 'BN'
|
|
||||||
|
|
||||||
from mmocr.models import build_detector
|
from mmocr.models import build_detector
|
||||||
detector = build_detector(model)
|
detector = build_detector(model)
|
||||||
|
detector = revert_sync_batchnorm(detector)
|
||||||
detector = detector.cuda()
|
detector = detector.cuda()
|
||||||
|
|
||||||
fourier_degree = 5
|
fourier_degree = 5
|
||||||
@ -451,10 +452,10 @@ def test_fcenet(cfg_file):
|
|||||||
def test_drrg(cfg_file):
|
def test_drrg(cfg_file):
|
||||||
model = _get_detector_cfg(cfg_file)
|
model = _get_detector_cfg(cfg_file)
|
||||||
model['pretrained'] = None
|
model['pretrained'] = None
|
||||||
model['backbone']['norm_cfg']['type'] = 'BN'
|
|
||||||
|
|
||||||
from mmocr.models import build_detector
|
from mmocr.models import build_detector
|
||||||
detector = build_detector(model)
|
detector = build_detector(model)
|
||||||
|
detector = revert_sync_batchnorm(detector)
|
||||||
|
|
||||||
input_shape = (1, 3, 224, 224)
|
input_shape = (1, 3, 224, 224)
|
||||||
num_kernels = 1
|
num_kernels = 1
|
||||||
|
15
tests/test_utils/test_model.py
Normal file
15
tests/test_utils/test_model.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from mmcv.cnn.bricks import ConvModule
|
||||||
|
|
||||||
|
from mmocr.utils import revert_sync_batchnorm
|
||||||
|
|
||||||
|
|
||||||
|
def test_revert_sync_batchnorm():
|
||||||
|
conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN'))
|
||||||
|
x = torch.randn(1, 3, 10, 10)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
y = conv(x)
|
||||||
|
conv = revert_sync_batchnorm(conv)
|
||||||
|
y = conv(x)
|
||||||
|
assert y.shape == (1, 8, 9, 9)
|
@ -16,6 +16,7 @@ from mmdet.datasets import replace_ImageToTensor
|
|||||||
from mmocr.apis.inference import disable_text_recog_aug_test
|
from mmocr.apis.inference import disable_text_recog_aug_test
|
||||||
from mmocr.datasets import build_dataloader, build_dataset
|
from mmocr.datasets import build_dataloader, build_dataset
|
||||||
from mmocr.models import build_detector
|
from mmocr.models import build_detector
|
||||||
|
from mmocr.utils import revert_sync_batchnorm
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -196,6 +197,7 @@ def main():
|
|||||||
# build the model and load checkpoint
|
# build the model and load checkpoint
|
||||||
cfg.model.train_cfg = None
|
cfg.model.train_cfg = None
|
||||||
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
|
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
|
||||||
|
model = revert_sync_batchnorm(model)
|
||||||
fp16_cfg = cfg.get('fp16', None)
|
fp16_cfg = cfg.get('fp16', None)
|
||||||
if fp16_cfg is not None:
|
if fp16_cfg is not None:
|
||||||
wrap_fp16_model(model)
|
wrap_fp16_model(model)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user