[Fix] Fix load_ckpt from pavi and s3 (#1020)

* fix load_ckpt

* revised according to comments

* revise according to comments

* fix typo
pull/1024/head
LXXXXR 2021-05-13 20:55:51 +08:00 committed by GitHub
parent b36c4de157
commit d7f8355011
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 34 deletions

View File

@ -293,7 +293,8 @@ def load_from_http(filename, map_location=None, model_dir=None):
@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None):
"""load checkpoint through the file path prefixed with pavi. In distributed
setting, this function only download checkpoint at local rank 0.
setting, this function download ckpt at all ranks to different temporary
directories.
Args:
filename (str): checkpoint file path with pavi prefix
@ -312,30 +313,20 @@ def load_from_pavi(filename, map_location=None):
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(
downloaded_file, map_location=map_location)
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
return checkpoint
@CheckpointLoader.register_scheme(prefixes='s3://')
def load_from_ceph(filename, map_location=None, backend='ceph'):
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function only download checkpoint at local rank 0.
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
Args:
filename (str): checkpoint file path with s3 prefix
@ -346,21 +337,14 @@ def load_from_ceph(filename, map_location=None, backend='ceph'):
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
allowed_backends = ['ceph']
if backend not in allowed_backends:
raise ValueError(f'Load from Backend {backend} is not supported.')
if rank == 0:
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
return checkpoint
@ -663,7 +647,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
if filename.startswith('pavi://'):
try:
from pavi import modelcloud
from pavi.exception import NodeNotFoundError
from pavi import exception
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
@ -672,7 +656,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
model_dir, model_name = osp.split(model_path)
try:
model = modelcloud.get(model_dir)
except NodeNotFoundError:
except exception.NodeNotFoundError:
model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)