[Fix] Fix the bug that mit cannot process init_cfg (#1102)
* [Fix] Fix the bug that mit cannot process init_cfg
* fix error
parent 0d3ee4b520
commit a370777e3b
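The diff below makes the deprecated pretrained argument and init_cfg mutually exclusive, converts a string pretrained into a Pretrained init_cfg, and leaves checkpoint loading to BaseModule.init_weights(). A minimal usage sketch of the fixed behaviour (the checkpoint path is a placeholder, and the import assumes MixVisionTransformer is exported from mmseg.models.backbones):

    from mmseg.models.backbones import MixVisionTransformer

    # Preferred route: pass init_cfg; init_weights() then loads the checkpoint.
    model = MixVisionTransformer(
        init_cfg=dict(type='Pretrained', checkpoint='work_dirs/mit_b0.pth'))

    # Deprecated route: a string pretrained is converted into the same
    # init_cfg and emits a deprecation warning.
    model = MixVisionTransformer(pretrained='work_dirs/mit_b0.pth')
    assert model.init_cfg == dict(
        type='Pretrained', checkpoint='work_dirs/mit_b0.pth')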
@@ -9,9 +9,8 @@ from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
                                         trunc_normal_init)
-from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
+from mmcv.runner import BaseModule, ModuleList, Sequential

-from ...utils import get_root_logger
 from ..builder import BACKBONES
 from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw

@@ -341,16 +340,18 @@ class MixVisionTransformer(BaseModule):
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
                  init_cfg=None):
-        super().__init__(init_cfg=init_cfg)
+        super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)

-        if isinstance(pretrained, str) or pretrained is None:
-            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                           'please use "init_cfg" instead')
-        else:
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is not None:
             raise TypeError('pretrained must be a str or None')

         self.embed_dims = embed_dims
-
         self.num_stages = num_stages
         self.num_layers = num_layers
         self.num_heads = num_heads
@@ -362,7 +363,6 @@ class MixVisionTransformer(BaseModule):

         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
-        self.pretrained = pretrained

         # transformer encoder
         dpr = [
@@ -401,7 +401,7 @@ class MixVisionTransformer(BaseModule):
             cur += num_layer

     def init_weights(self):
-        if self.pretrained is None:
+        if self.init_cfg is None:
             for m in self.modules():
                 if isinstance(m, nn.Linear):
                     trunc_normal_init(m, std=.02, bias=0.)
@@ -413,16 +413,8 @@ class MixVisionTransformer(BaseModule):
                     fan_out //= m.groups
                     normal_init(
                         m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
-        elif isinstance(self.pretrained, str):
-            logger = get_root_logger()
-            checkpoint = _load_checkpoint(
-                self.pretrained, logger=logger, map_location='cpu')
-            if 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-            else:
-                state_dict = checkpoint
-
-            self.load_state_dict(state_dict, False)
+        else:
+            super(MixVisionTransformer, self).init_weights()

     def forward(self, x):
         outs = []
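With the hunks above, init_weights() no longer loads checkpoints by hand: whenever init_cfg is set it falls back to BaseModule.init_weights(), which resolves the init_cfg through mmcv's weight-initialization machinery. A rough sketch of what that fallback amounts to for a Pretrained init_cfg, assuming mmcv's load_checkpoint is the underlying loader (the helper name is hypothetical, not part of this change):

    from mmcv.runner import load_checkpoint

    def _load_like_base_module(model, init_cfg):
        # Hand-rolled approximation of the removed branch: resolve a
        # Pretrained init_cfg by loading its checkpoint non-strictly.
        if isinstance(init_cfg, dict) and init_cfg.get('type') == 'Pretrained':
            load_checkpoint(model, init_cfg['checkpoint'],
                            map_location='cpu', strict=False)

The remaining hunk extends the mit backbone tests to cover every pretrained/init_cfg combination.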
@@ -55,3 +55,59 @@ def test_mit():
     # Out identity
     outs = MHA(temp, hw_shape, temp)
     assert out.shape == (1, token_len, 64)
+
+
+def test_mit_init():
+    path = 'PATH_THAT_DO_NOT_EXIST'
+    # Test all combinations of pretrained and init_cfg
+    # pretrained=None, init_cfg=None
+    model = MixVisionTransformer(pretrained=None, init_cfg=None)
+    assert model.init_cfg is None
+    model.init_weights()
+
+    # pretrained=None
+    # init_cfg loads pretrain from an non-existent file
+    model = MixVisionTransformer(
+        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from an non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained=None
+    # init_cfg=123, whose type is unsupported
+    model = MixVisionTransformer(pretrained=None, init_cfg=123)
+    with pytest.raises(TypeError):
+        model.init_weights()
+
+    # pretrained loads pretrain from an non-existent file
+    # init_cfg=None
+    model = MixVisionTransformer(pretrained=path, init_cfg=None)
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from an non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained loads pretrain from an non-existent file
+    # init_cfg loads pretrain from an non-existent file
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(
+            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(pretrained=path, init_cfg=123)
+
+    # pretrain=123, whose type is unsupported
+    # init_cfg=None
+    with pytest.raises(TypeError):
+        MixVisionTransformer(pretrained=123, init_cfg=None)
+
+    # pretrain=123, whose type is unsupported
+    # init_cfg loads pretrain from an non-existent file
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(
+            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
+
+    # pretrain=123, whose type is unsupported
+    # init_cfg=123, whose type is unsupported
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(pretrained=123, init_cfg=123)
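In practice the init_cfg this commit fixes is usually supplied from a config dict rather than by constructing the backbone directly. A hedged sketch of such a config entry (the embed_dims value and checkpoint path are placeholders, not part of this change):

    # Backbone section of an mmsegmentation-style config dict.
    backbone = dict(
        type='MixVisionTransformer',
        embed_dims=64,
        init_cfg=dict(type='Pretrained', checkpoint='pretrain/mit_b0.pth'))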