diff --git a/mmocr/utils/__init__.py b/mmocr/utils/__init__.py
index b295d078..ef2290f8 100644
--- a/mmocr/utils/__init__.py
+++ b/mmocr/utils/__init__.py
@@ -9,6 +9,7 @@ from .fileio import list_from_file, list_to_file
 from .img_util import drop_orientation, is_not_png
 from .lmdb_util import lmdb_converter
 from .logger import get_root_logger
+from .model import revert_sync_batchnorm
 from .string_util import StringStrip
 
 __all__ = [
@@ -17,5 +18,5 @@ __all__ = [
     'equal_len', 'is_2dlist', 'valid_boundary', 'lmdb_converter',
     'drop_orientation', 'convert_annotations', 'is_not_png', 'list_to_file',
     'list_from_file', 'is_on_same_line', 'stitch_boxes_into_lines',
-    'StringStrip'
+    'StringStrip', 'revert_sync_batchnorm'
 ]
diff --git a/mmocr/utils/model.py b/mmocr/utils/model.py
new file mode 100644
index 00000000..66978fa3
--- /dev/null
+++ b/mmocr/utils/model.py
@@ -0,0 +1,56 @@
+import torch
+
+
+class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
+    """A general BatchNorm layer without input dimension check.
+
+    Reproduced from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
+    is `_check_input_dim` that is designed for tensor sanity checks.
+    The check has been bypassed in this class for the convenience of converting
+    SyncBatchNorm.
+    """
+
+    def _check_input_dim(self, input):
+        return
+
+
+def revert_sync_batchnorm(module):
+    """Helper function to convert all `SyncBatchNorm` layers in the model to
+    `BatchNormXd` layers.
+
+    Reproduced from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+    Children are converted recursively. Every replacement layer reuses the
+    parameters and buffers of the layer it replaces and keeps its train/eval
+    mode.
+
+    Args:
+        module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+    Returns:
+        module_output: The converted module with `BatchNormXd` layers.
+    """
+    module_output = module
+    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
+        module_output = _BatchNormXd(module.num_features, module.eps,
+                                     module.momentum, module.affine,
+                                     module.track_running_stats)
+        if module.affine:
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        # A newly constructed layer defaults to training mode; copy the mode
+        # so a model already in eval() keeps using its running statistics.
+        module_output.training = module.training
+        if hasattr(module, 'qconfig'):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, revert_sync_batchnorm(child))
+    del module
+    return module_output
diff --git a/mmocr/utils/ocr.py b/mmocr/utils/ocr.py
index 84ef5f89..796e3496 100644
--- a/mmocr/utils/ocr.py
+++ b/mmocr/utils/ocr.py
@@ -19,6 +19,7 @@ from mmocr.datasets.pipelines.crop import crop_img
 from mmocr.models import build_detector
 from mmocr.utils.box_util import stitch_boxes_into_lines
 from mmocr.utils.fileio import list_from_file
+from mmocr.utils.model import revert_sync_batchnorm
 
 
 # Parse CLI arguments
@@ -324,6 +325,7 @@ class MMOCR:
 
             self.detect_model = init_detector(
                 det_config, det_ckpt, device=self.device)
+            self.detect_model = revert_sync_batchnorm(self.detect_model)
 
         self.recog_model = None
         if self.tr:
@@ -338,6 +340,7 @@ class MMOCR:
 
             self.recog_model = init_detector(
                 recog_config, recog_ckpt, device=self.device)
+            self.recog_model = revert_sync_batchnorm(self.recog_model)
 
         self.kie_model = None
         if self.kie:
@@ -352,6 +355,7 @@ class MMOCR:
             kie_cfg = Config.fromfile(kie_config)
             self.kie_model = build_detector(
                 kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
+            self.kie_model = revert_sync_batchnorm(self.kie_model)
             self.kie_model.cfg = kie_cfg
             load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)
 
diff --git a/tests/test_models/test_detector.py b/tests/test_models/test_detector.py
index b397cc02..3f538ffa 100644
--- a/tests/test_models/test_detector.py
+++ b/tests/test_models/test_detector.py
@@ -9,6 +9,7 @@ import pytest
 import torch
 
 import mmocr.core.evaluation.utils as utils
+from mmocr.utils import revert_sync_batchnorm
 
 
 def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300),
@@ -192,10 +193,10 @@ def test_ocr_mask_rcnn(cfg_file):
 def test_panet(cfg_file):
     model = _get_detector_cfg(cfg_file)
     model['pretrained'] = None
-    model['backbone']['norm_cfg']['type'] = 'BN'
 
     from mmocr.models import build_detector
     detector = build_detector(model)
+    detector = revert_sync_batchnorm(detector)
 
     input_shape = (1, 3, 224, 224)
     num_kernels = 2
@@ -247,10 +248,10 @@ def test_panet(cfg_file):
 def test_psenet(cfg_file):
     model = _get_detector_cfg(cfg_file)
     model['pretrained'] = None
-    model['backbone']['norm_cfg']['type'] = 'BN'
 
     from mmocr.models import build_detector
     detector = build_detector(model)
+    detector = revert_sync_batchnorm(detector)
 
     input_shape = (1, 3, 224, 224)
     num_kernels = 7
@@ -289,10 +290,10 @@ def test_psenet(cfg_file):
 def test_dbnet(cfg_file):
     model = _get_detector_cfg(cfg_file)
     model['pretrained'] = None
-    model['backbone']['norm_cfg']['type'] = 'BN'
 
     from mmocr.models import build_detector
     detector = build_detector(model)
+    detector = revert_sync_batchnorm(detector)
     detector = detector.cuda()
     input_shape = (1, 3, 224, 224)
     num_kernels = 7
@@ -338,10 +339,10 @@ def test_dbnet(cfg_file):
 def test_textsnake(cfg_file):
     model = _get_detector_cfg(cfg_file)
     model['pretrained'] = None
-    model['backbone']['norm_cfg']['type'] = 'BN'
 
     from mmocr.models import build_detector
     detector = build_detector(model)
+    detector = revert_sync_batchnorm(detector)
     input_shape = (1, 3, 224, 224)
     num_kernels = 1
     mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
@@ -394,10 +395,10 @@ def test_textsnake(cfg_file):
 def test_fcenet(cfg_file):
     model = _get_detector_cfg(cfg_file)
     model['pretrained'] = None
-    model['backbone']['norm_cfg']['type'] = 'BN'
 
     from mmocr.models import build_detector
     detector = build_detector(model)
+    detector = revert_sync_batchnorm(detector)
     detector = detector.cuda()
 
     fourier_degree = 5
@@ -451,10 +452,10 @@ def test_fcenet(cfg_file):
 def test_drrg(cfg_file):
     model = _get_detector_cfg(cfg_file)
     model['pretrained'] = None
-    model['backbone']['norm_cfg']['type'] = 'BN'
 
     from mmocr.models import build_detector
     detector = build_detector(model)
+    detector = revert_sync_batchnorm(detector)
 
     input_shape = (1, 3, 224, 224)
     num_kernels = 1
diff --git a/tests/test_utils/test_model.py b/tests/test_utils/test_model.py
new file mode 100644
index 00000000..39d6df13
--- /dev/null
+++ b/tests/test_utils/test_model.py
@@ -0,0 +1,22 @@
+import pytest
+import torch
+from mmcv.cnn.bricks import ConvModule
+
+from mmocr.utils import revert_sync_batchnorm
+
+
+def test_revert_sync_batchnorm():
+    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN'))
+    x = torch.randn(1, 3, 10, 10)
+    with pytest.raises(ValueError):
+        y = conv(x)
+    conv = revert_sync_batchnorm(conv)
+    y = conv(x)
+    assert y.shape == (1, 8, 9, 9)
+
+
+def test_revert_sync_batchnorm_keeps_mode():
+    for mode in (True, False):
+        model = torch.nn.Sequential(torch.nn.SyncBatchNorm(4)).train(mode)
+        model = revert_sync_batchnorm(model)
+        assert model[0].training is mode
diff --git a/tools/test.py b/tools/test.py
index 93b02d68..66084256 100755
--- a/tools/test.py
+++ b/tools/test.py
@@ -16,6 +16,7 @@ from mmdet.datasets import replace_ImageToTensor
 from mmocr.apis.inference import disable_text_recog_aug_test
 from mmocr.datasets import build_dataloader, build_dataset
 from mmocr.models import build_detector
+from mmocr.utils import revert_sync_batchnorm
 
 
 def parse_args():
@@ -196,6 +197,7 @@ def main():
     # build the model and load checkpoint
     cfg.model.train_cfg = None
     model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
+    model = revert_sync_batchnorm(model)
     fp16_cfg = cfg.get('fp16', None)
     if fp16_cfg is not None:
         wrap_fp16_model(model)