commit
f937e2a600
|
@ -7,7 +7,7 @@ from collections import OrderedDict
|
|||
|
||||
def load_checkpoint(model, checkpoint_path, use_ema=False):
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
state_dict_key = ''
|
||||
if isinstance(checkpoint, dict):
|
||||
state_dict_key = 'state_dict'
|
||||
|
@ -32,7 +32,7 @@ def resume_checkpoint(model, checkpoint_path):
|
|||
optimizer_state = None
|
||||
resume_epoch = None
|
||||
if os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in checkpoint['state_dict'].items():
|
||||
|
|
Loading…
Reference in New Issue