Pass LOCAL_RANK to torch_distributed_zero_first() (#5114)

Co-authored-by: qiningonline <qiningonline@gmail.com>
Author: Nan · 2021-10-09 20:41:50 -05:00 · committed by GitHub
parent 97b6b14abe
commit 4a6dfffdaa
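
Why LOCAL_RANK: torch_distributed_zero_first() lets the local master process
(local_rank 0 on each node) run a block such as a dataset check or a weight
download while the remaining processes wait at a dist.barrier(). Because
dist.barrier(device_ids=[...]) expects a node-local CUDA device index, passing
the global RANK only happens to work on a single machine; multi-node DDP needs
LOCAL_RANK. The rank passed to create_dataloader() feeds the same context
manager while the dataset is cached, so it is switched as well. A minimal
sketch of such a context manager (illustrative only, not the repository's
exact implementation):

from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    # Make every non-master process on a node wait until the local master
    # (local_rank 0) has executed the body, e.g. a one-time dataset download.
    if local_rank not in (-1, 0):
        dist.barrier(device_ids=[local_rank])  # wait here for the local master
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])  # local master done; release the waiters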

train.py

@@ -99,7 +99,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     plots = not evolve  # create plots
     cuda = device.type != 'cpu'
     init_seeds(1 + RANK)
-    with torch_distributed_zero_first(RANK):
+    with torch_distributed_zero_first(LOCAL_RANK):
         data_dict = data_dict or check_dataset(data)  # check if None
     train_path, val_path = data_dict['train'], data_dict['val']
     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
@@ -111,7 +111,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     check_suffix(weights, '.pt')  # check weights
     pretrained = weights.endswith('.pt')
     if pretrained:
-        with torch_distributed_zero_first(RANK):
+        with torch_distributed_zero_first(LOCAL_RANK):
             weights = attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
         model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
@@ -208,7 +208,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Trainloader
     train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
-                                              hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK,
+                                              hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
                                               workers=workers, image_weights=opt.image_weights, quad=opt.quad,
                                               prefix=colorstr('train: '))
     mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max())  # max label class
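
For reference, the three process-group globals used above are read from the
environment variables set by the PyTorch launcher (torch.distributed.run /
torch.distributed.launch); in this file they look roughly like:

import os

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # GPU index on the current node
RANK = int(os.getenv('RANK', -1))  # global process index across all nodes
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))  # total number of processes

On a single machine RANK equals LOCAL_RANK, which is why the old code appeared
to work; across several nodes only LOCAL_RANK is a valid local device index for
the barrier, hence this change.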