mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
fix load ckpt bug in swin (#928)
This commit is contained in:
parent
c1dcf91c1a
commit
10886b00f0
@ -680,7 +680,7 @@ class SwinTransformer(BaseModule):
|
||||
f'`init_cfg` in ' \
|
||||
f'{self.__class__.__name__} '
|
||||
ckpt = _load_checkpoint(
|
||||
self.init_cfg.checkpoint, logger=logger, map_location='cpu')
|
||||
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
|
||||
if 'state_dict' in ckpt:
|
||||
_state_dict = ckpt['state_dict']
|
||||
elif 'model' in ckpt:
|
||||
@ -692,6 +692,8 @@ class SwinTransformer(BaseModule):
|
||||
for k, v in _state_dict.items():
|
||||
if k.startswith('backbone.'):
|
||||
state_dict[k[9:]] = v
|
||||
else:
|
||||
state_dict[k] = v
|
||||
|
||||
# strip prefix of state_dict
|
||||
if list(state_dict.keys())[0].startswith('module.'):
|
||||
|
@ -96,7 +96,7 @@ def main():
|
||||
else:
|
||||
distributed = True
|
||||
init_dist(args.launcher, **cfg.dist_params)
|
||||
# gpu_ids is used to calculate iter when resuming checkpoint,
|
||||
# gpu_ids is used to calculate iter when resuming checkpoint
|
||||
_, world_size = get_dist_info()
|
||||
cfg.gpu_ids = range(world_size)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user