mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
map_location for all (#826)
* map_location for all * format * hmm * map_location * back * doc * same
This commit is contained in:
parent
999f2d08b4
commit
1c2e665ad6
@ -269,22 +269,23 @@ def load_from_http(filename, map_location=None, model_dir=None):
|
|||||||
Args:
|
Args:
|
||||||
filename (str): checkpoint file path with modelzoo or
|
filename (str): checkpoint file path with modelzoo or
|
||||||
torchvision prefix
|
torchvision prefix
|
||||||
map_location (str, optional): it's not use.
|
map_location (str, optional): Same as :func:`torch.load`.
|
||||||
model_dir (string, optional): directory in which to save the object,
|
model_dir (string, optional): directory in which to save the object,
|
||||||
Default: None
|
Default: None
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict or OrderedDict: The loaded checkpoint.
|
dict or OrderedDict: The loaded checkpoint.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rank, world_size = get_dist_info()
|
rank, world_size = get_dist_info()
|
||||||
rank = int(os.environ.get('LOCAL_RANK', rank))
|
rank = int(os.environ.get('LOCAL_RANK', rank))
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
checkpoint = model_zoo.load_url(filename, model_dir=model_dir)
|
checkpoint = model_zoo.load_url(
|
||||||
|
filename, model_dir=model_dir, map_location=map_location)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
if rank > 0:
|
if rank > 0:
|
||||||
checkpoint = model_zoo.load_url(filename, model_dir=model_dir)
|
checkpoint = model_zoo.load_url(
|
||||||
|
filename, model_dir=model_dir, map_location=map_location)
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
@ -370,7 +371,7 @@ def load_from_torchvision(filename, map_location=None):
|
|||||||
Args:
|
Args:
|
||||||
filename (str): checkpoint file path with modelzoo or
|
filename (str): checkpoint file path with modelzoo or
|
||||||
torchvision prefix
|
torchvision prefix
|
||||||
map_location (str, optional): it's not use.
|
map_location (str, optional): Same as :func:`torch.load`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict or OrderedDict: The loaded checkpoint.
|
dict or OrderedDict: The loaded checkpoint.
|
||||||
@ -382,7 +383,7 @@ def load_from_torchvision(filename, map_location=None):
|
|||||||
model_name = filename[11:]
|
model_name = filename[11:]
|
||||||
else:
|
else:
|
||||||
model_name = filename[14:]
|
model_name = filename[14:]
|
||||||
return load_from_http(model_urls[model_name])
|
return load_from_http(model_urls[model_name], map_location=map_location)
|
||||||
|
|
||||||
|
|
||||||
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
|
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
|
||||||
@ -416,7 +417,7 @@ def load_from_openmmlab(filename, map_location=None):
|
|||||||
model_url = model_urls[model_name]
|
model_url = model_urls[model_name]
|
||||||
# check if is url
|
# check if is url
|
||||||
if model_url.startswith(('http://', 'https://')):
|
if model_url.startswith(('http://', 'https://')):
|
||||||
checkpoint = load_from_http(model_url)
|
checkpoint = load_from_http(model_url, map_location=map_location)
|
||||||
else:
|
else:
|
||||||
filename = osp.join(_get_mmcv_home(), model_url)
|
filename = osp.join(_get_mmcv_home(), model_url)
|
||||||
if not osp.isfile(filename):
|
if not osp.isfile(filename):
|
||||||
@ -431,7 +432,7 @@ def load_from_mmcls(filename, map_location=None):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename (str): checkpoint file path with mmcls prefix
|
filename (str): checkpoint file path with mmcls prefix
|
||||||
map_location (str, optional): it's not use.
|
map_location (str, optional): Same as :func:`torch.load`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict or OrderedDict: The loaded checkpoint.
|
dict or OrderedDict: The loaded checkpoint.
|
||||||
@ -439,7 +440,8 @@ def load_from_mmcls(filename, map_location=None):
|
|||||||
|
|
||||||
model_urls = get_mmcls_models()
|
model_urls = get_mmcls_models()
|
||||||
model_name = filename[8:]
|
model_name = filename[8:]
|
||||||
checkpoint = load_from_http(model_urls[model_name])
|
checkpoint = load_from_http(
|
||||||
|
model_urls[model_name], map_location=map_location)
|
||||||
checkpoint = _process_mmcls_checkpoint(checkpoint)
|
checkpoint = _process_mmcls_checkpoint(checkpoint)
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
@ -58,11 +58,11 @@ def test_get_deprecated_models():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_from_http(url):
|
def load_from_http(url, map_location=None):
|
||||||
return 'url:' + url
|
return 'url:' + url
|
||||||
|
|
||||||
|
|
||||||
def load_url(url, model_dir=None):
|
def load_url(url, map_location=None, model_dir=None):
|
||||||
return load_from_http(url)
|
return load_from_http(url)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user