[Enhance] Add progress argument in load_from_http (#770)

This commit is contained in:
Austin Welch 2022-12-26 04:07:22 -05:00 committed by GitHub
parent e83ac944b6
commit eb803f8702
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -320,7 +320,10 @@ def load_from_local(filename, map_location):
@CheckpointLoader.register_scheme(prefixes=('http://', 'https://')) @CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
def load_from_http(filename, map_location=None, model_dir=None): def load_from_http(filename,
map_location=None,
model_dir=None,
progress=os.isatty(0)):
"""load checkpoint through HTTP or HTTPS scheme path. In distributed """load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0. setting, this function only download checkpoint at local rank 0.
@ -337,12 +340,18 @@ def load_from_http(filename, map_location=None, model_dir=None):
rank, world_size = get_dist_info() rank, world_size = get_dist_info()
if rank == 0: if rank == 0:
checkpoint = load_url( checkpoint = load_url(
filename, model_dir=model_dir, map_location=map_location) filename,
model_dir=model_dir,
map_location=map_location,
progress=progress)
if world_size > 1: if world_size > 1:
torch.distributed.barrier() torch.distributed.barrier()
if rank > 0: if rank > 0:
checkpoint = load_url( checkpoint = load_url(
filename, model_dir=model_dir, map_location=map_location) filename,
model_dir=model_dir,
map_location=map_location,
progress=progress)
return checkpoint return checkpoint