fix load ckpt bug in swin (#928)
parent
0cf56d48a4
commit
84edf6c190
|
@ -680,7 +680,7 @@ class SwinTransformer(BaseModule):
|
|||
f'`init_cfg` in ' \
|
||||
f'{self.__class__.__name__} '
|
||||
ckpt = _load_checkpoint(
|
||||
self.init_cfg.checkpoint, logger=logger, map_location='cpu')
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
if 'state_dict' in ckpt:
|
||||
_state_dict = ckpt['state_dict']
|
||||
elif 'model' in ckpt:
|
||||
|
@ -692,6 +692,8 @@ class SwinTransformer(BaseModule):
|
|||
for k, v in _state_dict.items():
|
||||
if k.startswith('backbone.'):
|
||||
state_dict[k[9:]] = v
|
||||
else:
|
||||
state_dict[k] = v
|
||||
|
||||
# strip prefix of state_dict
|
||||
if list(state_dict.keys())[0].startswith('module.'):
|
||||
|
|
|
@ -96,7 +96,7 @@ def main():
|
|||
else:
|
||||
distributed = True
|
||||
init_dist(args.launcher, **cfg.dist_params)
|
||||
# gpu_ids is used to calculate iter when resuming checkpoint,
|
||||
# gpu_ids is used to calculate iter when resuming checkpoint
|
||||
_, world_size = get_dist_info()
|
||||
cfg.gpu_ids = range(world_size)
|
||||
|
||||
|
|
Loading…
Reference in New Issue