mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Add progress argument in load_from_http (#770)
This commit is contained in:
parent
e83ac944b6
commit
eb803f8702
@ -320,7 +320,10 @@ def load_from_local(filename, map_location):
|
||||
|
||||
|
||||
@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
|
||||
def load_from_http(filename, map_location=None, model_dir=None):
|
||||
def load_from_http(filename,
|
||||
map_location=None,
|
||||
model_dir=None,
|
||||
progress=os.isatty(0)):
|
||||
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
|
||||
setting, this function only download checkpoint at local rank 0.
|
||||
|
||||
@ -337,12 +340,18 @@ def load_from_http(filename, map_location=None, model_dir=None):
|
||||
rank, world_size = get_dist_info()
|
||||
if rank == 0:
|
||||
checkpoint = load_url(
|
||||
filename, model_dir=model_dir, map_location=map_location)
|
||||
filename,
|
||||
model_dir=model_dir,
|
||||
map_location=map_location,
|
||||
progress=progress)
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
if rank > 0:
|
||||
checkpoint = load_url(
|
||||
filename, model_dir=model_dir, map_location=map_location)
|
||||
filename,
|
||||
model_dir=model_dir,
|
||||
map_location=map_location,
|
||||
progress=progress)
|
||||
return checkpoint
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user