fix load ckpt bug in swin (#928)
parent
0cf56d48a4
commit
84edf6c190
|
@ -680,7 +680,7 @@ class SwinTransformer(BaseModule):
|
||||||
f'`init_cfg` in ' \
|
f'`init_cfg` in ' \
|
||||||
f'{self.__class__.__name__} '
|
f'{self.__class__.__name__} '
|
||||||
ckpt = _load_checkpoint(
|
ckpt = _load_checkpoint(
|
||||||
self.init_cfg.checkpoint, logger=logger, map_location='cpu')
|
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||||
if 'state_dict' in ckpt:
|
if 'state_dict' in ckpt:
|
||||||
_state_dict = ckpt['state_dict']
|
_state_dict = ckpt['state_dict']
|
||||||
elif 'model' in ckpt:
|
elif 'model' in ckpt:
|
||||||
|
@ -692,6 +692,8 @@ class SwinTransformer(BaseModule):
|
||||||
for k, v in _state_dict.items():
|
for k, v in _state_dict.items():
|
||||||
if k.startswith('backbone.'):
|
if k.startswith('backbone.'):
|
||||||
state_dict[k[9:]] = v
|
state_dict[k[9:]] = v
|
||||||
|
else:
|
||||||
|
state_dict[k] = v
|
||||||
|
|
||||||
# strip prefix of state_dict
|
# strip prefix of state_dict
|
||||||
if list(state_dict.keys())[0].startswith('module.'):
|
if list(state_dict.keys())[0].startswith('module.'):
|
||||||
|
|
|
@ -96,7 +96,7 @@ def main():
|
||||||
else:
|
else:
|
||||||
distributed = True
|
distributed = True
|
||||||
init_dist(args.launcher, **cfg.dist_params)
|
init_dist(args.launcher, **cfg.dist_params)
|
||||||
# gpu_ids is used to calculate iter when resuming checkpoint,
|
# gpu_ids is used to calculate iter when resuming checkpoint
|
||||||
_, world_size = get_dist_info()
|
_, world_size = get_dist_info()
|
||||||
cfg.gpu_ids = range(world_size)
|
cfg.gpu_ids = range(world_size)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue