mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix load_ckpt from pavi and s3 (#1020)
* fix load_ckpt * revised according to comments * revise according to comments * fix typo (pull/1024/head)
parent
b36c4de157
commit
d7f8355011
|
@ -293,7 +293,8 @@ def load_from_http(filename, map_location=None, model_dir=None):
|
|||
@CheckpointLoader.register_scheme(prefixes='pavi://')
|
||||
def load_from_pavi(filename, map_location=None):
|
||||
"""load checkpoint through the file path prefixed with pavi. In distributed
|
||||
setting, this function only download checkpoint at local rank 0.
|
||||
setting, this function download ckpt at all ranks to different temporary
|
||||
directories.
|
||||
|
||||
Args:
|
||||
filename (str): checkpoint file path with pavi prefix
|
||||
|
@ -312,30 +313,20 @@ def load_from_pavi(filename, map_location=None):
|
|||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please install pavi to load checkpoint from modelcloud.')
|
||||
rank, world_size = get_dist_info()
|
||||
rank = int(os.environ.get('LOCAL_RANK', rank))
|
||||
if rank == 0:
|
||||
model = modelcloud.get(model_path)
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
downloaded_file = osp.join(tmp_dir, model.name)
|
||||
model.download(downloaded_file)
|
||||
checkpoint = torch.load(downloaded_file, map_location=map_location)
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
if rank > 0:
|
||||
model = modelcloud.get(model_path)
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
downloaded_file = osp.join(tmp_dir, model.name)
|
||||
model.download(downloaded_file)
|
||||
checkpoint = torch.load(
|
||||
downloaded_file, map_location=map_location)
|
||||
|
||||
model = modelcloud.get(model_path)
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
downloaded_file = osp.join(tmp_dir, model.name)
|
||||
model.download(downloaded_file)
|
||||
checkpoint = torch.load(downloaded_file, map_location=map_location)
|
||||
return checkpoint
|
||||
|
||||
|
||||
@CheckpointLoader.register_scheme(prefixes='s3://')
|
||||
def load_from_ceph(filename, map_location=None, backend='ceph'):
|
||||
"""load checkpoint through the file path prefixed with s3. In distributed
|
||||
setting, this function only download checkpoint at local rank 0.
|
||||
"""load checkpoint through the file path prefixed with s3. In distributed
|
||||
setting, this function download ckpt at all ranks to different temporary
|
||||
directories.
|
||||
|
||||
Args:
|
||||
filename (str): checkpoint file path with s3 prefix
|
||||
|
@ -346,21 +337,14 @@ def load_from_ceph(filename, map_location=None, backend='ceph'):
|
|||
Returns:
|
||||
dict or OrderedDict: The loaded checkpoint.
|
||||
"""
|
||||
rank, world_size = get_dist_info()
|
||||
rank = int(os.environ.get('LOCAL_RANK', rank))
|
||||
|
||||
allowed_backends = ['ceph']
|
||||
if backend not in allowed_backends:
|
||||
raise ValueError(f'Load from Backend {backend} is not supported.')
|
||||
if rank == 0:
|
||||
fileclient = FileClient(backend=backend)
|
||||
buffer = io.BytesIO(fileclient.get(filename))
|
||||
checkpoint = torch.load(buffer, map_location=map_location)
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
if rank > 0:
|
||||
fileclient = FileClient(backend=backend)
|
||||
buffer = io.BytesIO(fileclient.get(filename))
|
||||
checkpoint = torch.load(buffer, map_location=map_location)
|
||||
|
||||
fileclient = FileClient(backend=backend)
|
||||
buffer = io.BytesIO(fileclient.get(filename))
|
||||
checkpoint = torch.load(buffer, map_location=map_location)
|
||||
return checkpoint
|
||||
|
||||
|
||||
|
@ -663,7 +647,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
|
|||
if filename.startswith('pavi://'):
|
||||
try:
|
||||
from pavi import modelcloud
|
||||
from pavi.exception import NodeNotFoundError
|
||||
from pavi import exception
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please install pavi to load checkpoint from modelcloud.')
|
||||
|
@ -672,7 +656,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
|
|||
model_dir, model_name = osp.split(model_path)
|
||||
try:
|
||||
model = modelcloud.get(model_dir)
|
||||
except NodeNotFoundError:
|
||||
except exception.NodeNotFoundError:
|
||||
model = root.create_training_model(model_dir)
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
checkpoint_file = osp.join(tmp_dir, model_name)
|
||||
|
|
Loading…
Reference in New Issue