map_location for all (#826)

* map_location for all

* format

* hmm

* map_location

* back

* doc

* same
pull/835/head
lizz 2021-02-10 12:23:39 +08:00 committed by GitHub
parent 999f2d08b4
commit 1c2e665ad6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 11 deletions

View File

@ -269,22 +269,23 @@ def load_from_http(filename, map_location=None, model_dir=None):
Args:
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional): it's not use.
map_location (str, optional): Same as :func:`torch.load`.
model_dir (string, optional): directory in which to save the object,
Default: None
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:
checkpoint = model_zoo.load_url(filename, model_dir=model_dir)
checkpoint = model_zoo.load_url(
filename, model_dir=model_dir, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
checkpoint = model_zoo.load_url(filename, model_dir=model_dir)
checkpoint = model_zoo.load_url(
filename, model_dir=model_dir, map_location=map_location)
return checkpoint
@ -370,7 +371,7 @@ def load_from_torchvision(filename, map_location=None):
Args:
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional): it's not use.
map_location (str, optional): Same as :func:`torch.load`.
Returns:
dict or OrderedDict: The loaded checkpoint.
@ -382,7 +383,7 @@ def load_from_torchvision(filename, map_location=None):
model_name = filename[11:]
else:
model_name = filename[14:]
return load_from_http(model_urls[model_name])
return load_from_http(model_urls[model_name], map_location=map_location)
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
@ -416,7 +417,7 @@ def load_from_openmmlab(filename, map_location=None):
model_url = model_urls[model_name]
# check if is url
if model_url.startswith(('http://', 'https://')):
checkpoint = load_from_http(model_url)
checkpoint = load_from_http(model_url, map_location=map_location)
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
@ -431,7 +432,7 @@ def load_from_mmcls(filename, map_location=None):
Args:
filename (str): checkpoint file path with mmcls prefix
map_location (str, optional): it's not use.
map_location (str, optional): Same as :func:`torch.load`.
Returns:
dict or OrderedDict: The loaded checkpoint.
@ -439,7 +440,8 @@ def load_from_mmcls(filename, map_location=None):
model_urls = get_mmcls_models()
model_name = filename[8:]
checkpoint = load_from_http(model_urls[model_name])
checkpoint = load_from_http(
model_urls[model_name], map_location=map_location)
checkpoint = _process_mmcls_checkpoint(checkpoint)
return checkpoint

View File

@ -58,11 +58,11 @@ def test_get_deprecated_models():
}
def load_from_http(url):
def load_from_http(url, map_location=None):
return 'url:' + url
def load_url(url, model_dir=None):
def load_url(url, map_location=None, model_dir=None):
return load_from_http(url)