[Fix] Fix the bug that vit cannot load pretrain properly when using i… (#999)
* [Fix] Fix the bug that vit cannot load pretrain properly when using init_cfg to specify the pretrain scheme * [Fix] fix the coverage problem * Update mmseg/models/backbones/vit.py Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn> * [Fix] make the predicate more concise and clearer * [Fix] Modified the judgement logic * Update tests/test_models/test_backbones/test_vit.py Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn> * add comments Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>pull/1037/head
parent
14dc00af62
commit
7a1c9a5499
|
@ -170,7 +170,7 @@ class VisionTransformer(BaseModule):
|
|||
with_cp=False,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super(VisionTransformer, self).__init__()
|
||||
super(VisionTransformer, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(img_size, int):
|
||||
img_size = to_2tuple(img_size)
|
||||
|
@ -185,10 +185,13 @@ class VisionTransformer(BaseModule):
|
|||
assert with_cls_token is True, f'with_cls_token must be True if' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
|
||||
if isinstance(pretrained, str) or pretrained is None:
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
else:
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.img_size = img_size
|
||||
|
@ -197,7 +200,6 @@ class VisionTransformer(BaseModule):
|
|||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.pretrained = pretrained
|
||||
self.init_cfg = init_cfg
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_channels=in_channels,
|
||||
|
@ -260,10 +262,12 @@ class VisionTransformer(BaseModule):
|
|||
return getattr(self, self.norm1_name)
|
||||
|
||||
def init_weights(self):
|
||||
if isinstance(self.pretrained, str):
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
logger = get_root_logger()
|
||||
checkpoint = _load_checkpoint(
|
||||
self.pretrained, logger=logger, map_location='cpu')
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
|
@ -283,9 +287,9 @@ class VisionTransformer(BaseModule):
|
|||
(pos_size, pos_size), self.interpolate_mode)
|
||||
|
||||
self.load_state_dict(state_dict, False)
|
||||
|
||||
elif self.pretrained is None:
|
||||
elif self.init_cfg is not None:
|
||||
super(VisionTransformer, self).init_weights()
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
trunc_normal_init(self.pos_embed, std=.02)
|
||||
|
|
|
@ -118,3 +118,59 @@ def test_vit_backbone():
|
|||
feat = model(imgs)
|
||||
assert feat[0][0].shape == (1, 768, 14, 14)
|
||||
assert feat[0][1].shape == (1, 768)
|
||||
|
||||
|
||||
def test_vit_init():
|
||||
path = 'PATH_THAT_DO_NOT_EXIST'
|
||||
# Test all combinations of pretrained and init_cfg
|
||||
# pretrained=None, init_cfg=None
|
||||
model = VisionTransformer(pretrained=None, init_cfg=None)
|
||||
assert model.init_cfg is None
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
model = VisionTransformer(
|
||||
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg=123, whose type is unsupported
|
||||
model = VisionTransformer(pretrained=None, init_cfg=123)
|
||||
with pytest.raises(TypeError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg=None
|
||||
model = VisionTransformer(pretrained=path, init_cfg=None)
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = VisionTransformer(
|
||||
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
with pytest.raises(AssertionError):
|
||||
model = VisionTransformer(pretrained=path, init_cfg=123)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=None
|
||||
with pytest.raises(TypeError):
|
||||
model = VisionTransformer(pretrained=123, init_cfg=None)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = VisionTransformer(
|
||||
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=123, whose type is unsupported
|
||||
with pytest.raises(AssertionError):
|
||||
model = VisionTransformer(pretrained=123, init_cfg=123)
|
||||
|
|
Loading…
Reference in New Issue