[Improve] Use MMCV load_state_dict func in ViT/Swin (#1272)

* [Improve] Use MMCV load_state_dict func in ViT/Swin

* use CheckpointLoader instead
Author: Jerry Jiarui XU, committed by GitHub on 2022-02-09 00:52:42 -05:00
parent b4314f98c1
commit 66b778c064
2 changed files with 8 additions and 6 deletions
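
For context, the pattern this commit adopts looks roughly like the sketch below. It assumes mmcv 1.x; the logger name, model, and checkpoint path are placeholders, not part of this commit. CheckpointLoader.load_checkpoint replaces the private _load_checkpoint helper, and mmcv's load_state_dict reports missing and unexpected keys through the logger, whereas nn.Module.load_state_dict(state_dict, False) drops that information silently.

    # Minimal sketch of the new loading pattern (mmcv 1.x assumed).
    # 'demo' and 'checkpoint.pth' are illustrative placeholders.
    import torch.nn as nn
    from mmcv.runner import CheckpointLoader, load_state_dict
    from mmcv.utils import get_logger

    logger = get_logger('demo')
    model = nn.Linear(4, 2)  # stand-in for a ViT/Swin backbone

    # Resolves local paths, URLs, etc. via registered checkpoint schemes.
    ckpt = CheckpointLoader.load_checkpoint(
        'checkpoint.pth', map_location='cpu', logger=logger)
    state_dict = ckpt.get('state_dict', ckpt)

    # Logs missing/unexpected keys through `logger` instead of discarding
    # them like nn.Module.load_state_dict(state_dict, False) does.
    load_state_dict(model, state_dict, strict=False, logger=logger)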

mmseg/models/backbones/swin.py

@@ -11,7 +11,8 @@ from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, build_dropout
 from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
                                         trunc_normal_init)
-from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
+                         load_state_dict)
 from mmcv.utils import to_2tuple
 
 from ...utils import get_root_logger
@@ -678,7 +679,7 @@ class SwinTransformer(BaseModule):
                 f'specify `Pretrained` in ' \
                 f'`init_cfg` in ' \
                 f'{self.__class__.__name__} '
-            ckpt = _load_checkpoint(
+            ckpt = CheckpointLoader.load_checkpoint(
                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
             if 'state_dict' in ckpt:
                 _state_dict = ckpt['state_dict']
@@ -732,7 +733,7 @@ class SwinTransformer(BaseModule):
                             nH2, L2).permute(1, 0).contiguous()
 
             # load state_dict
-            self.load_state_dict(state_dict, False)
+            load_state_dict(self, state_dict, strict=False, logger=logger)
 
     def forward(self, x):
         x, hw_shape = self.patch_embed(x)

mmseg/models/backbones/vit.py

@@ -8,7 +8,8 @@ from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
                                         trunc_normal_)
-from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
+                         load_state_dict)
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
@@ -266,7 +267,7 @@ class VisionTransformer(BaseModule):
         if (isinstance(self.init_cfg, dict)
                 and self.init_cfg.get('type') == 'Pretrained'):
             logger = get_root_logger()
-            checkpoint = _load_checkpoint(
+            checkpoint = CheckpointLoader.load_checkpoint(
                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
 
             if 'state_dict' in checkpoint:
@@ -287,7 +288,7 @@ class VisionTransformer(BaseModule):
                     (h // self.patch_size, w // self.patch_size),
                     (pos_size, pos_size), self.interpolate_mode)
 
-            self.load_state_dict(state_dict, False)
+            load_state_dict(self, state_dict, strict=False, logger=logger)
         elif self.init_cfg is not None:
             super(VisionTransformer, self).init_weights()
         else:
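
The changed branch runs when a backbone's init_cfg requests pretrained weights. A minimal, illustrative config fragment that exercises it (the checkpoint path is a placeholder):

    # Illustrative mmsegmentation config fragment; the checkpoint path is
    # a placeholder. init_cfg with type='Pretrained' routes init_weights()
    # through the CheckpointLoader branch changed above.
    model = dict(
        type='EncoderDecoder',
        backbone=dict(
            type='SwinTransformer',
            init_cfg=dict(
                type='Pretrained',
                checkpoint='pretrain/swin_tiny_patch4_window7_224.pth')))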