Mirror of https://github.com/open-mmlab/mmsegmentation.git, synced 2025-06-03 22:03:48 +08:00
[Improve] Use MMCV load_state_dict func in ViT/Swin (#1272)

* [Improve] Use MMCV load_state_dict func in ViT/Swin
* use CheckpointLoader instead
Parent: b4314f98c1
Commit: 66b778c064
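The commit makes the same two substitutions in both backbones: the private `_load_checkpoint` helper becomes the public `CheckpointLoader.load_checkpoint`, and PyTorch's `nn.Module.load_state_dict(state_dict, False)` becomes MMCV's `load_state_dict`, which reports missing and unexpected keys through the logger instead of silently discarding them. A minimal sketch of the adopted pattern (the toy module and checkpoint file name are illustrative, not taken from this commit):

```python
import torch
import torch.nn as nn
from mmcv.runner import CheckpointLoader, load_state_dict
from mmcv.utils import get_logger

logger = get_logger('mmseg')  # mmseg's get_root_logger() wraps this

# Fake a saved checkpoint so the snippet runs end to end.
torch.save({'state_dict': nn.Linear(4, 2).state_dict()}, 'toy_ckpt.pth')

model = nn.Linear(4, 2)  # stand-in for a ViT/Swin backbone

# Public replacement for the private _load_checkpoint helper; it also
# resolves URLs and open-mmlab:// style prefixes via registered schemes.
ckpt = CheckpointLoader.load_checkpoint(
    'toy_ckpt.pth', logger=logger, map_location='cpu')
state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt

# Unlike nn.Module.load_state_dict(state_dict, False), MMCV's helper
# logs missing and unexpected keys via the given logger when strict=False.
load_state_dict(model, state_dict, strict=False, logger=logger)
```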
mmseg/models/backbones/swin.py
@@ -11,7 +11,8 @@ from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, build_dropout
 from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
                                         trunc_normal_init)
-from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
+                         load_state_dict)
 from mmcv.utils import to_2tuple

 from ...utils import get_root_logger
@@ -678,7 +679,7 @@ class SwinTransformer(BaseModule):
                 f'specify `Pretrained` in ' \
                 f'`init_cfg` in ' \
                 f'{self.__class__.__name__} '
-            ckpt = _load_checkpoint(
+            ckpt = CheckpointLoader.load_checkpoint(
                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
             if 'state_dict' in ckpt:
                 _state_dict = ckpt['state_dict']
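For context, `SwinTransformer.init_weights` continues past this hunk by unwrapping the loaded checkpoint and stripping a possible `backbone.` prefix. A paraphrased sketch (the helper name is mine, not the verbatim mmseg source):

```python
from collections import OrderedDict

def extract_backbone_state_dict(ckpt):
    """Unwrap a loaded checkpoint dict the way Swin's init_weights does."""
    if 'state_dict' in ckpt:
        _state_dict = ckpt['state_dict']
    elif 'model' in ckpt:  # e.g. checkpoints from the official Swin repo
        _state_dict = ckpt['model']
    else:
        _state_dict = ckpt
    state_dict = OrderedDict()
    for k, v in _state_dict.items():
        # Checkpoints saved from a full segmentor prefix backbone weights.
        state_dict[k[9:] if k.startswith('backbone.') else k] = v
    return state_dict
```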
@@ -732,7 +733,7 @@ class SwinTransformer(BaseModule):
                     nH2, L2).permute(1, 0).contiguous()

             # load state_dict
-            self.load_state_dict(state_dict, False)
+            load_state_dict(self, state_dict, strict=False, logger=logger)

     def forward(self, x):
         x, hw_shape = self.patch_embed(x)
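The `nH2, L2).permute(1, 0).contiguous()` fragment in the context above is the tail of Swin's relative-position bias-table interpolation, which runs just before the patched `load_state_dict` call. A simplified sketch of that resize (the function name and the head-count fallback are assumptions):

```python
import torch
import torch.nn.functional as F

def resize_rel_pos_bias(table_pretrained, table_current):
    """Interpolate a pretrained relative_position_bias_table of shape
    (L1, nH1) to the current window's (L2, nH2)."""
    L1, nH1 = table_pretrained.size()
    L2, nH2 = table_current.size()
    if nH1 != nH2:
        return table_current  # head-count mismatch: keep current init
    S1, S2 = int(L1 ** 0.5), int(L2 ** 0.5)
    resized = F.interpolate(
        table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
        size=(S2, S2), mode='bicubic', align_corners=False)
    # Back to (L2, nH2): the permute(1, 0).contiguous() seen in the hunk.
    return resized.view(nH2, L2).permute(1, 0).contiguous()
```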
mmseg/models/backbones/vit.py
@@ -8,7 +8,8 @@ from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
                                         trunc_normal_)
-from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
+                         load_state_dict)
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
@@ -266,7 +267,7 @@ class VisionTransformer(BaseModule):
         if (isinstance(self.init_cfg, dict)
                 and self.init_cfg.get('type') == 'Pretrained'):
             logger = get_root_logger()
-            checkpoint = _load_checkpoint(
+            checkpoint = CheckpointLoader.load_checkpoint(
                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')

             if 'state_dict' in checkpoint:
@@ -287,7 +288,7 @@ class VisionTransformer(BaseModule):
                     (h // self.patch_size, w // self.patch_size),
                     (pos_size, pos_size), self.interpolate_mode)

-            self.load_state_dict(state_dict, False)
+            load_state_dict(self, state_dict, strict=False, logger=logger)
         elif self.init_cfg is not None:
             super(VisionTransformer, self).init_weights()
         else:
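Similarly, the `(pos_size, pos_size), self.interpolate_mode` context above belongs to ViT's positional-embedding resize, executed just before the patched `load_state_dict` call when the pretraining and fine-tuning resolutions differ. A simplified sketch (signature and cls-token handling paraphrased, so treat the details as assumptions):

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, input_hw, pos_hw, mode='bicubic'):
    """pos_embed: (1, 1 + pos_h * pos_w, C), with a leading cls token."""
    cls_token, grid = pos_embed[:, :1], pos_embed[:, 1:]
    pos_h, pos_w = pos_hw
    # (1, N, C) -> (1, C, pos_h, pos_w) so F.interpolate can act on it.
    grid = grid.reshape(1, pos_h, pos_w, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=input_hw, mode=mode, align_corners=False)
    # Back to token layout and reattach the cls token.
    grid = grid.flatten(2).transpose(1, 2)
    return torch.cat((cls_token, grid), dim=1)
```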