diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py
index a360ab018..d5d11ac83 100644
--- a/mmseg/models/backbones/swin.py
+++ b/mmseg/models/backbones/swin.py
@@ -11,7 +11,8 @@ from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, build_dropout
 from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
                                         trunc_normal_init)
-from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
+                         load_state_dict)
 from mmcv.utils import to_2tuple
 
 from ...utils import get_root_logger
@@ -678,7 +679,7 @@ class SwinTransformer(BaseModule):
                 f'specify `Pretrained` in ' \
                 f'`init_cfg` in ' \
                 f'{self.__class__.__name__} '
-            ckpt = _load_checkpoint(
+            ckpt = CheckpointLoader.load_checkpoint(
                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
             if 'state_dict' in ckpt:
                 _state_dict = ckpt['state_dict']
@@ -732,7 +733,7 @@ class SwinTransformer(BaseModule):
                         nH2, L2).permute(1, 0).contiguous()
 
             # load state_dict
-            self.load_state_dict(state_dict, False)
+            load_state_dict(self, state_dict, strict=False, logger=logger)
 
     def forward(self, x):
         x, hw_shape = self.patch_embed(x)
diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 965652503..9c920baa6 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -8,7 +8,8 @@ from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
                                         trunc_normal_)
-from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
+                         load_state_dict)
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
@@ -266,7 +267,7 @@ class VisionTransformer(BaseModule):
         if (isinstance(self.init_cfg, dict)
                 and self.init_cfg.get('type') == 'Pretrained'):
             logger = get_root_logger()
-            checkpoint = _load_checkpoint(
+            checkpoint = CheckpointLoader.load_checkpoint(
                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
 
             if 'state_dict' in checkpoint:
@@ -287,7 +288,7 @@ class VisionTransformer(BaseModule):
                         (h // self.patch_size, w // self.patch_size),
                         (pos_size, pos_size), self.interpolate_mode)
 
-            self.load_state_dict(state_dict, False)
+            load_state_dict(self, state_dict, strict=False, logger=logger)
         elif self.init_cfg is not None:
             super(VisionTransformer, self).init_weights()
         else:
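
For context, a minimal sketch of the new loading path (not part of the diff), assuming mmcv>=1.3.16 where CheckpointLoader and load_state_dict are exported from mmcv.runner; the backbone instance and checkpoint path below are placeholders:

    from mmcv.runner import CheckpointLoader, load_state_dict

    from mmseg.models.backbones import SwinTransformer
    from mmseg.utils import get_root_logger

    logger = get_root_logger()
    backbone = SwinTransformer()

    # CheckpointLoader resolves the appropriate backend for the given URI
    # (local path, http(s), torchvision://, ...) and logs which one it used.
    ckpt = CheckpointLoader.load_checkpoint(
        'checkpoints/swin_tiny.pth', logger=logger, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)

    # Unlike the old self.load_state_dict(state_dict, False) call, mmcv's
    # load_state_dict reports missing and unexpected keys through the logger
    # instead of silently dropping that information.
    load_state_dict(backbone, state_dict, strict=False, logger=logger)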