DDP `torch.jit.trace()` `--sync-bn` fix (#4615)
* Remove assert * debug0 * trace=not opt.sync * sync to sync_bn fix * Cleanuppull/4618/head
parent
bb5ebc290e
commit
50a9828679
3
train.py
3
train.py
|
@ -333,7 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
||||
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
|
||||
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
|
||||
callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots)
|
||||
callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots, opt.sync_bn)
|
||||
# end batch ------------------------------------------------------------------------------------------------
|
||||
|
||||
# Scheduler
|
||||
|
@ -499,7 +499,6 @@ def main(opt):
|
|||
assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
|
||||
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
|
||||
assert not opt.evolve, '--evolve argument is not compatible with DDP training'
|
||||
assert not opt.sync_bn, '--sync-bn known training issue, see https://github.com/ultralytics/yolov5/issues/3998'
|
||||
torch.cuda.set_device(LOCAL_RANK)
|
||||
device = torch.device('cuda', LOCAL_RANK)
|
||||
dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
|
||||
|
|
|
@ -69,13 +69,14 @@ class Loggers():
|
|||
if self.wandb:
|
||||
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
|
||||
|
||||
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
|
||||
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots, sync_bn):
|
||||
# Callback runs on train batch end
|
||||
if plots:
|
||||
if ni == 0:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress jit trace warning
|
||||
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
|
||||
if not sync_bn: # tb.add_graph() --sync known issue https://github.com/ultralytics/yolov5/issues/3754
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress jit trace warning
|
||||
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
|
||||
if ni < 3:
|
||||
f = self.save_dir / f'train_batch{ni}.jpg' # filename
|
||||
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
|
||||
|
|
Loading…
Reference in New Issue