fix load ckpt bug in swin (#928)

pull/1801/head
Junjun2016 2021-09-30 22:50:44 +08:00 committed by GitHub
parent 0cf56d48a4
commit 84edf6c190
2 changed files with 4 additions and 2 deletions

View File

@ -680,7 +680,7 @@ class SwinTransformer(BaseModule):
f'`init_cfg` in ' \ f'`init_cfg` in ' \
f'{self.__class__.__name__} ' f'{self.__class__.__name__} '
ckpt = _load_checkpoint( ckpt = _load_checkpoint(
self.init_cfg.checkpoint, logger=logger, map_location='cpu') self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in ckpt: if 'state_dict' in ckpt:
_state_dict = ckpt['state_dict'] _state_dict = ckpt['state_dict']
elif 'model' in ckpt: elif 'model' in ckpt:
@ -692,6 +692,8 @@ class SwinTransformer(BaseModule):
for k, v in _state_dict.items(): for k, v in _state_dict.items():
if k.startswith('backbone.'): if k.startswith('backbone.'):
state_dict[k[9:]] = v state_dict[k[9:]] = v
else:
state_dict[k] = v
# strip prefix of state_dict # strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'): if list(state_dict.keys())[0].startswith('module.'):

View File

@ -96,7 +96,7 @@ def main():
else: else:
distributed = True distributed = True
init_dist(args.launcher, **cfg.dist_params) init_dist(args.launcher, **cfg.dist_params)
# gpu_ids is used to calculate iter when resuming checkpoint, # gpu_ids is used to calculate iter when resuming checkpoint
_, world_size = get_dist_info() _, world_size = get_dist_info()
cfg.gpu_ids = range(world_size) cfg.gpu_ids = range(world_size)