[Fix] Fix the bug that mit cannot process init_cfg (#1102)
* [Fix] Fix the bug that mit cannot process init_cfg
* fix error

branch: pull/1801/head
parent: 0d3ee4b520
commit: a370777e3b
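For context, a minimal sketch of what the fixed constructor now does with init_cfg and pretrained, assuming the usual mmsegmentation import path for MixVisionTransformer; the checkpoint filename below is a placeholder, not a file shipped with this change:

from mmseg.models.backbones import MixVisionTransformer

# Preferred after this fix: pass init_cfg directly. init_weights() then
# delegates to mmcv's BaseModule, which loads the 'Pretrained' checkpoint.
model = MixVisionTransformer(
    init_cfg=dict(type='Pretrained', checkpoint='mit_b0.pth'))

# Deprecated path: pretrained=<str> still works but only warns, and is
# converted into an equivalent init_cfg instead of being handled by a
# custom loading branch inside init_weights().
model = MixVisionTransformer(pretrained='mit_b0.pth')
assert model.init_cfg == dict(type='Pretrained', checkpoint='mit_b0.pth')

# Passing both pretrained and init_cfg at once is rejected with an AssertionError.

With init_cfg handled by BaseModule, mit.py no longer needs its own checkpoint-loading branch, which is why _load_checkpoint and get_root_logger are dropped from the imports in the diff below.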
@@ -9,9 +9,8 @@ from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
                                         trunc_normal_init)
-from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
+from mmcv.runner import BaseModule, ModuleList, Sequential

-from ...utils import get_root_logger
 from ..builder import BACKBONES
 from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw

@@ -341,16 +340,18 @@ class MixVisionTransformer(BaseModule):
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
                  init_cfg=None):
-        super().__init__(init_cfg=init_cfg)
+        super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)

-        if isinstance(pretrained, str) or pretrained is None:
-            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                           'please use "init_cfg" instead')
-        else:
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is not None:
             raise TypeError('pretrained must be a str or None')

         self.embed_dims = embed_dims
-
         self.num_stages = num_stages
         self.num_layers = num_layers
         self.num_heads = num_heads
@@ -362,7 +363,6 @@ class MixVisionTransformer(BaseModule):

         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
-        self.pretrained = pretrained

         # transformer encoder
         dpr = [
@@ -401,7 +401,7 @@ class MixVisionTransformer(BaseModule):
             cur += num_layer

     def init_weights(self):
-        if self.pretrained is None:
+        if self.init_cfg is None:
             for m in self.modules():
                 if isinstance(m, nn.Linear):
                     trunc_normal_init(m, std=.02, bias=0.)
@@ -413,16 +413,8 @@ class MixVisionTransformer(BaseModule):
                     fan_out //= m.groups
                     normal_init(
                         m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
-        elif isinstance(self.pretrained, str):
-            logger = get_root_logger()
-            checkpoint = _load_checkpoint(
-                self.pretrained, logger=logger, map_location='cpu')
-            if 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-            else:
-                state_dict = checkpoint
-
-            self.load_state_dict(state_dict, False)
+        else:
+            super(MixVisionTransformer, self).init_weights()

     def forward(self, x):
         outs = []

@@ -55,3 +55,59 @@ def test_mit():
     # Out identity
     outs = MHA(temp, hw_shape, temp)
     assert out.shape == (1, token_len, 64)
+
+
+def test_mit_init():
+    path = 'PATH_THAT_DO_NOT_EXIST'
+    # Test all combinations of pretrained and init_cfg
+    # pretrained=None, init_cfg=None
+    model = MixVisionTransformer(pretrained=None, init_cfg=None)
+    assert model.init_cfg is None
+    model.init_weights()
+
+    # pretrained=None
+    # init_cfg loads pretrain from an non-existent file
+    model = MixVisionTransformer(
+        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from an non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained=None
+    # init_cfg=123, whose type is unsupported
+    model = MixVisionTransformer(pretrained=None, init_cfg=123)
+    with pytest.raises(TypeError):
+        model.init_weights()
+
+    # pretrained loads pretrain from an non-existent file
+    # init_cfg=None
+    model = MixVisionTransformer(pretrained=path, init_cfg=None)
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from an non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained loads pretrain from an non-existent file
+    # init_cfg loads pretrain from an non-existent file
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(
+            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(pretrained=path, init_cfg=123)
+
+    # pretrain=123, whose type is unsupported
+    # init_cfg=None
+    with pytest.raises(TypeError):
+        MixVisionTransformer(pretrained=123, init_cfg=None)
+
+    # pretrain=123, whose type is unsupported
+    # init_cfg loads pretrain from an non-existent file
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(
+            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
+
+    # pretrain=123, whose type is unsupported
+    # init_cfg=123, whose type is unsupported
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(pretrained=123, init_cfg=123)
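The practical upshot for downstream configs: a backbone section along the following lines (a sketch only; 'pretrain/mit_b0.pth' is a placeholder path and the keys follow the usual mmsegmentation config conventions, not anything added by this PR) now actually has its checkpoint loaded when model.init_weights() runs, instead of init_cfg being ignored by the old pretrained-based branch.

# Config-style usage; the checkpoint path is a placeholder.
backbone = dict(
    type='MixVisionTransformer',
    init_cfg=dict(type='Pretrained', checkpoint='pretrain/mit_b0.pth'))

The new test_mit_init test above can be run on its own with pytest's -k filter, e.g. pytest -k test_mit_init.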