Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)

[Improve] Use MMCV load_state_dict func in ViT/Swin (#1272)

* [Improve] Use MMCV load_state_dict func in ViT/Swin
* use CheckpointLoader instead

parent b4314f98c1
commit 66b778c064
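In short, the commit replaces mmcv's private `_load_checkpoint` helper with the public `CheckpointLoader.load_checkpoint` classmethod, and swaps the bare `self.load_state_dict(state_dict, False)` for mmcv's `load_state_dict`, which routes missing/unexpected-key warnings through the given logger. A minimal sketch of the new calling pattern (assumes mmcv is installed; the module and checkpoint path are placeholders):

import torch.nn as nn
from mmcv.runner import CheckpointLoader, load_state_dict
from mmcv.utils import get_logger

logger = get_logger('mmseg')   # mmseg's get_root_logger wraps this
model = nn.Linear(4, 4)        # placeholder module

# Resolve a local path, URL, or model-zoo key and load it onto CPU.
ckpt = CheckpointLoader.load_checkpoint(
    'path/to/checkpoint.pth',  # placeholder path
    logger=logger, map_location='cpu')
state_dict = ckpt.get('state_dict', ckpt)

# Non-strict load; key mismatches are reported through `logger`
# instead of only being returned as with nn.Module.load_state_dict.
load_state_dict(model, state_dict, strict=False, logger=logger)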
mmseg/models/backbones/swin.py

@@ -11,7 +11,8 @@ from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, build_dropout
 from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
                                         trunc_normal_init)
-from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
+                         load_state_dict)
 from mmcv.utils import to_2tuple
 
 from ...utils import get_root_logger
@@ -678,7 +679,7 @@ class SwinTransformer(BaseModule):
                                                   f'specify `Pretrained` in ' \
                                                   f'`init_cfg` in ' \
                                                   f'{self.__class__.__name__} '
-            ckpt = _load_checkpoint(
+            ckpt = CheckpointLoader.load_checkpoint(
                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
             if 'state_dict' in ckpt:
                 _state_dict = ckpt['state_dict']
@@ -732,7 +733,7 @@ class SwinTransformer(BaseModule):
                         nH2, L2).permute(1, 0).contiguous()
 
             # load state_dict
-            self.load_state_dict(state_dict, False)
+            load_state_dict(self, state_dict, strict=False, logger=logger)
 
     def forward(self, x):
         x, hw_shape = self.patch_embed(x)
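The practical difference in the hunk above: `nn.Module.load_state_dict(state_dict, False)` only returns the incompatible keys, so mismatches were easy to miss, whereas mmcv's `load_state_dict` also writes them to the logger. A small self-contained demonstration (assumes torch and mmcv are installed):

import torch
import torch.nn as nn
from mmcv.runner import load_state_dict
from mmcv.utils import get_logger

model = nn.Sequential(nn.Linear(2, 2))
partial = {'0.weight': torch.zeros(2, 2)}   # '0.bias' is deliberately missing

# Old pattern: the mismatch is only in the return value.
result = model.load_state_dict(partial, False)
print(result.missing_keys)                  # ['0.bias']

# New pattern: the same mismatch is also logged as a warning.
load_state_dict(model, partial, strict=False, logger=get_logger('demo'))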
mmseg/models/backbones/vit.py

@@ -8,7 +8,8 @@ from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
                                         trunc_normal_)
-from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
+                         load_state_dict)
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
@@ -266,7 +267,7 @@ class VisionTransformer(BaseModule):
         if (isinstance(self.init_cfg, dict)
                 and self.init_cfg.get('type') == 'Pretrained'):
             logger = get_root_logger()
-            checkpoint = _load_checkpoint(
+            checkpoint = CheckpointLoader.load_checkpoint(
                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
 
             if 'state_dict' in checkpoint:
@@ -287,7 +288,7 @@ class VisionTransformer(BaseModule):
                 (h // self.patch_size, w // self.patch_size),
                 (pos_size, pos_size), self.interpolate_mode)
 
-            self.load_state_dict(state_dict, False)
+            load_state_dict(self, state_dict, strict=False, logger=logger)
         elif self.init_cfg is not None:
             super(VisionTransformer, self).init_weights()
         else:
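For context, the patched branch runs when a backbone is built with a `Pretrained` init_cfg; a hedged usage sketch (assumes mmsegmentation is installed; the checkpoint path is a placeholder):

from mmseg.models.backbones import VisionTransformer

backbone = VisionTransformer(
    img_size=224,
    init_cfg=dict(type='Pretrained',
                  checkpoint='path/to/vit.pth'))  # placeholder checkpoint

# Takes the `Pretrained` branch above: CheckpointLoader resolves the file
# and load_state_dict reports any key mismatches through the root logger.
backbone.init_weights()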