map_location for all (#826)

* map_location for all

* format

* hmm

* map_location

* back

* doc

* same
This commit is contained in:
lizz 2021-02-10 12:23:39 +08:00 committed by GitHub
parent 999f2d08b4
commit 1c2e665ad6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 11 deletions

View File

@ -269,22 +269,23 @@ def load_from_http(filename, map_location=None, model_dir=None):
Args: Args:
filename (str): checkpoint file path with modelzoo or filename (str): checkpoint file path with modelzoo or
torchvision prefix torchvision prefix
map_location (str, optional): it's not use. map_location (str, optional): Same as :func:`torch.load`.
model_dir (string, optional): directory in which to save the object, model_dir (string, optional): directory in which to save the object,
Default: None Default: None
Returns: Returns:
dict or OrderedDict: The loaded checkpoint. dict or OrderedDict: The loaded checkpoint.
""" """
rank, world_size = get_dist_info() rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank)) rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0: if rank == 0:
checkpoint = model_zoo.load_url(filename, model_dir=model_dir) checkpoint = model_zoo.load_url(
filename, model_dir=model_dir, map_location=map_location)
if world_size > 1: if world_size > 1:
torch.distributed.barrier() torch.distributed.barrier()
if rank > 0: if rank > 0:
checkpoint = model_zoo.load_url(filename, model_dir=model_dir) checkpoint = model_zoo.load_url(
filename, model_dir=model_dir, map_location=map_location)
return checkpoint return checkpoint
@ -370,7 +371,7 @@ def load_from_torchvision(filename, map_location=None):
Args: Args:
filename (str): checkpoint file path with modelzoo or filename (str): checkpoint file path with modelzoo or
torchvision prefix torchvision prefix
map_location (str, optional): it's not use. map_location (str, optional): Same as :func:`torch.load`.
Returns: Returns:
dict or OrderedDict: The loaded checkpoint. dict or OrderedDict: The loaded checkpoint.
@ -382,7 +383,7 @@ def load_from_torchvision(filename, map_location=None):
model_name = filename[11:] model_name = filename[11:]
else: else:
model_name = filename[14:] model_name = filename[14:]
return load_from_http(model_urls[model_name]) return load_from_http(model_urls[model_name], map_location=map_location)
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://')) @CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
@ -416,7 +417,7 @@ def load_from_openmmlab(filename, map_location=None):
model_url = model_urls[model_name] model_url = model_urls[model_name]
# check if is url # check if is url
if model_url.startswith(('http://', 'https://')): if model_url.startswith(('http://', 'https://')):
checkpoint = load_from_http(model_url) checkpoint = load_from_http(model_url, map_location=map_location)
else: else:
filename = osp.join(_get_mmcv_home(), model_url) filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename): if not osp.isfile(filename):
@ -431,7 +432,7 @@ def load_from_mmcls(filename, map_location=None):
Args: Args:
filename (str): checkpoint file path with mmcls prefix filename (str): checkpoint file path with mmcls prefix
map_location (str, optional): it's not use. map_location (str, optional): Same as :func:`torch.load`.
Returns: Returns:
dict or OrderedDict: The loaded checkpoint. dict or OrderedDict: The loaded checkpoint.
@ -439,7 +440,8 @@ def load_from_mmcls(filename, map_location=None):
model_urls = get_mmcls_models() model_urls = get_mmcls_models()
model_name = filename[8:] model_name = filename[8:]
checkpoint = load_from_http(model_urls[model_name]) checkpoint = load_from_http(
model_urls[model_name], map_location=map_location)
checkpoint = _process_mmcls_checkpoint(checkpoint) checkpoint = _process_mmcls_checkpoint(checkpoint)
return checkpoint return checkpoint

View File

@ -58,11 +58,11 @@ def test_get_deprecated_models():
} }
def load_from_http(url): def load_from_http(url, map_location=None):
return 'url:' + url return 'url:' + url
def load_url(url, model_dir=None): def load_url(url, map_location=None, model_dir=None):
return load_from_http(url) return load_from_http(url)