fix load ckpt bug in swin (#928)

pull/1801/head
Junjun2016 2021-09-30 22:50:44 +08:00 committed by GitHub
parent 0cf56d48a4
commit 84edf6c190
2 changed files with 4 additions and 2 deletions

View File

@ -680,7 +680,7 @@ class SwinTransformer(BaseModule):
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
ckpt = _load_checkpoint(
self.init_cfg.checkpoint, logger=logger, map_location='cpu')
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
_state_dict = ckpt['state_dict']
elif 'model' in ckpt:
@ -692,6 +692,8 @@ class SwinTransformer(BaseModule):
for k, v in _state_dict.items():
if k.startswith('backbone.'):
state_dict[k[9:]] = v
else:
state_dict[k] = v
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):

View File

@ -96,7 +96,7 @@ def main():
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# gpu_ids is used to calculate iter when resuming checkpoint,
# gpu_ids is used to calculate iter when resuming checkpoint
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)