mirror of https://github.com/open-mmlab/mmcv.git
map_location for all (#826)
* map_location for all * format * hmm * map_location * back * doc * samepull/835/head
parent
999f2d08b4
commit
1c2e665ad6
|
@ -269,22 +269,23 @@ def load_from_http(filename, map_location=None, model_dir=None):
|
|||
Args:
|
||||
filename (str): checkpoint file path with modelzoo or
|
||||
torchvision prefix
|
||||
map_location (str, optional): it's not use.
|
||||
map_location (str, optional): Same as :func:`torch.load`.
|
||||
model_dir (string, optional): directory in which to save the object,
|
||||
Default: None
|
||||
|
||||
Returns:
|
||||
dict or OrderedDict: The loaded checkpoint.
|
||||
"""
|
||||
|
||||
rank, world_size = get_dist_info()
|
||||
rank = int(os.environ.get('LOCAL_RANK', rank))
|
||||
if rank == 0:
|
||||
checkpoint = model_zoo.load_url(filename, model_dir=model_dir)
|
||||
checkpoint = model_zoo.load_url(
|
||||
filename, model_dir=model_dir, map_location=map_location)
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
if rank > 0:
|
||||
checkpoint = model_zoo.load_url(filename, model_dir=model_dir)
|
||||
checkpoint = model_zoo.load_url(
|
||||
filename, model_dir=model_dir, map_location=map_location)
|
||||
return checkpoint
|
||||
|
||||
|
||||
|
@ -370,7 +371,7 @@ def load_from_torchvision(filename, map_location=None):
|
|||
Args:
|
||||
filename (str): checkpoint file path with modelzoo or
|
||||
torchvision prefix
|
||||
map_location (str, optional): it's not use.
|
||||
map_location (str, optional): Same as :func:`torch.load`.
|
||||
|
||||
Returns:
|
||||
dict or OrderedDict: The loaded checkpoint.
|
||||
|
@ -382,7 +383,7 @@ def load_from_torchvision(filename, map_location=None):
|
|||
model_name = filename[11:]
|
||||
else:
|
||||
model_name = filename[14:]
|
||||
return load_from_http(model_urls[model_name])
|
||||
return load_from_http(model_urls[model_name], map_location=map_location)
|
||||
|
||||
|
||||
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
|
||||
|
@ -416,7 +417,7 @@ def load_from_openmmlab(filename, map_location=None):
|
|||
model_url = model_urls[model_name]
|
||||
# check if is url
|
||||
if model_url.startswith(('http://', 'https://')):
|
||||
checkpoint = load_from_http(model_url)
|
||||
checkpoint = load_from_http(model_url, map_location=map_location)
|
||||
else:
|
||||
filename = osp.join(_get_mmcv_home(), model_url)
|
||||
if not osp.isfile(filename):
|
||||
|
@ -431,7 +432,7 @@ def load_from_mmcls(filename, map_location=None):
|
|||
|
||||
Args:
|
||||
filename (str): checkpoint file path with mmcls prefix
|
||||
map_location (str, optional): it's not use.
|
||||
map_location (str, optional): Same as :func:`torch.load`.
|
||||
|
||||
Returns:
|
||||
dict or OrderedDict: The loaded checkpoint.
|
||||
|
@ -439,7 +440,8 @@ def load_from_mmcls(filename, map_location=None):
|
|||
|
||||
model_urls = get_mmcls_models()
|
||||
model_name = filename[8:]
|
||||
checkpoint = load_from_http(model_urls[model_name])
|
||||
checkpoint = load_from_http(
|
||||
model_urls[model_name], map_location=map_location)
|
||||
checkpoint = _process_mmcls_checkpoint(checkpoint)
|
||||
return checkpoint
|
||||
|
||||
|
|
|
@ -58,11 +58,11 @@ def test_get_deprecated_models():
|
|||
}
|
||||
|
||||
|
||||
def load_from_http(url):
|
||||
def load_from_http(url, map_location=None):
|
||||
return 'url:' + url
|
||||
|
||||
|
||||
def load_url(url, model_dir=None):
|
||||
def load_url(url, map_location=None, model_dir=None):
|
||||
return load_from_http(url)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue