[Fix] Fix the bug that mit cannot process init_cfg (#1102)

* [Fix] Fix the bug that mit cannot process init_cfg

* fix error
Rockey 2021-12-09 15:17:43 +08:00 committed by GitHub
parent 0d3ee4b520
commit a370777e3b
2 changed files with 67 additions and 19 deletions
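In short: before this patch, MixVisionTransformer passed init_cfg to BaseModule but its overridden init_weights only consulted the deprecated pretrained argument, so a Pretrained init_cfg was never applied. The patch makes the two arguments mutually exclusive, folds a string pretrained into an equivalent init_cfg, and defers checkpoint loading to mmcv's BaseModule. Below is a minimal standalone sketch of the constructor contract the patch enforces (a stub mirroring the patched __init__, not the real mmseg class; 'ckpt.pth' is a placeholder filename):

    import warnings


    class BackboneStub:
        # Mirrors the argument handling this patch adds to
        # MixVisionTransformer.__init__ (illustrative stub only).

        def __init__(self, pretrained=None, init_cfg=None):
            self.init_cfg = init_cfg
            # After this patch the two options are mutually exclusive.
            assert not (init_cfg and pretrained), \
                'init_cfg and pretrained cannot be set at the same time'
            if isinstance(pretrained, str):
                # Legacy path: fold `pretrained` into an equivalent init_cfg.
                warnings.warn('DeprecationWarning: pretrained is deprecated, '
                              'please use "init_cfg" instead')
                self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
            elif pretrained is not None:
                raise TypeError('pretrained must be a str or None')


    # A string `pretrained` is rewritten into an equivalent `init_cfg`.
    stub = BackboneStub(pretrained='ckpt.pth')
    assert stub.init_cfg == dict(type='Pretrained', checkpoint='ckpt.pth')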

mmseg/models/backbones/mit.py

@@ -9,9 +9,8 @@ from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
                                         trunc_normal_init)
-from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
+from mmcv.runner import BaseModule, ModuleList, Sequential
 
-from ...utils import get_root_logger
 from ..builder import BACKBONES
 from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
@@ -341,16 +340,18 @@ class MixVisionTransformer(BaseModule):
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
                  init_cfg=None):
-        super().__init__(init_cfg=init_cfg)
+        super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)
 
-        if isinstance(pretrained, str) or pretrained is None:
-            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                           'please use "init_cfg" instead')
-        else:
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is not None:
             raise TypeError('pretrained must be a str or None')
 
         self.embed_dims = embed_dims
         self.num_stages = num_stages
         self.num_layers = num_layers
         self.num_heads = num_heads
@@ -362,7 +363,6 @@ class MixVisionTransformer(BaseModule):
         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
-        self.pretrained = pretrained
 
         # transformer encoder
         dpr = [
@@ -401,7 +401,7 @@ class MixVisionTransformer(BaseModule):
             cur += num_layer
 
     def init_weights(self):
-        if self.pretrained is None:
+        if self.init_cfg is None:
             for m in self.modules():
                 if isinstance(m, nn.Linear):
                     trunc_normal_init(m, std=.02, bias=0.)
@@ -413,16 +413,8 @@ class MixVisionTransformer(BaseModule):
                     fan_out //= m.groups
                     normal_init(
                         m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
-        elif isinstance(self.pretrained, str):
-            logger = get_root_logger()
-            checkpoint = _load_checkpoint(
-                self.pretrained, logger=logger, map_location='cpu')
-            if 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-            else:
-                state_dict = checkpoint
-            self.load_state_dict(state_dict, False)
+        else:
+            super(MixVisionTransformer, self).init_weights()
 
     def forward(self, x):
         outs = []

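With the manual checkpoint-loading branch gone, init_weights either runs the built-in random initialisation (trunc-normal for Linear, constant for LayerNorm, fan-out-scaled normal for Conv2d) or defers to BaseModule.init_weights, which applies whatever init_cfg describes. A hedged usage sketch, assuming mmsegmentation with this patch is installed ('mit_b0.pth' is a placeholder filename):

    from mmseg.models.backbones import MixVisionTransformer

    # No init_cfg: the `if self.init_cfg is None` branch runs the
    # built-in random initialisation.
    model = MixVisionTransformer()
    model.init_weights()

    # With init_cfg: init_weights() now defers to BaseModule.init_weights(),
    # which loads the named checkpoint (OSError if the file is missing).
    model = MixVisionTransformer(
        init_cfg=dict(type='Pretrained', checkpoint='mit_b0.pth'))
    model.init_weights()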
tests/test_models/test_backbones/test_mit.py

@@ -55,3 +55,59 @@ def test_mit():
     # Out identity
     out = MHA(temp, hw_shape, temp)
     assert out.shape == (1, token_len, 64)
+
+
+def test_mit_init():
+    path = 'PATH_THAT_DO_NOT_EXIST'
+    # Test all combinations of pretrained and init_cfg
+    # pretrained=None, init_cfg=None
+    model = MixVisionTransformer(pretrained=None, init_cfg=None)
+    assert model.init_cfg is None
+    model.init_weights()
+
+    # pretrained=None
+    # init_cfg loads pretrained weights from a non-existent file
+    model = MixVisionTransformer(
+        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from a non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained=None
+    # init_cfg=123, whose type is unsupported
+    model = MixVisionTransformer(pretrained=None, init_cfg=123)
+    with pytest.raises(TypeError):
+        model.init_weights()
+
+    # pretrained loads pretrained weights from a non-existent file
+    # init_cfg=None
+    model = MixVisionTransformer(pretrained=path, init_cfg=None)
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from a non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained loads pretrained weights from a non-existent file
+    # init_cfg loads pretrained weights from a non-existent file
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(
+            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(pretrained=path, init_cfg=123)
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg=None
+    with pytest.raises(TypeError):
+        MixVisionTransformer(pretrained=123, init_cfg=None)
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg loads pretrained weights from a non-existent file
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(
+            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg=123, whose type is unsupported
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(pretrained=123, init_cfg=123)
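Read together, the nine cases reduce to the matrix below, a summary of the assertions in the test above (AssertionError and the pretrained=123 TypeError fire at construction; the rest fire in init_weights):

    # pretrained \ init_cfg:  None         dict(Pretrained)  123
    # None                    random init  OSError           TypeError
    # non-existent path       OSError      AssertionError    AssertionError
    # 123                     TypeError    AssertionError    AssertionError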