Threaded TensorBoard graph logging (#9070)

* Log TensorBoard graph on pretrain_routine_end

* fix
pull/9071/head
Glenn Jocher 2022-08-21 16:51:50 +02:00 committed by GitHub
parent 0b8639a40a
commit 8665d557c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 19 deletions

View File

@ -219,7 +219,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor
model.half().float() # pre-reduce anchor precision
callbacks.run('on_pretrain_routine_end', labels, names, plots)
callbacks.run('on_pretrain_routine_end', labels, names)
# DDP mode
if cuda and RANK != -1:
@ -328,7 +328,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
pbar.set_description(('%11s' * 2 + '%11.4g' * 5) %
(f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots)
callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths)
if callbacks.stop_training:
return
# end batch ------------------------------------------------------------------------------------------------
@ -420,7 +420,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
if is_coco:
callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
callbacks.run('on_train_end', last, best, plots, epoch, results)
callbacks.run('on_train_end', last, best, epoch, results)
torch.cuda.empty_cache()
return results

View File

@ -49,6 +49,7 @@ class Loggers():
self.weights = weights
self.opt = opt
self.hyp = hyp
self.plots = not opt.noplots # plot results
self.logger = logger # for printing results to console
self.include = include
self.keys = [
@ -110,26 +111,26 @@ class Loggers():
# Callback runs on train start
pass
def on_pretrain_routine_end(self, labels, names, plots):
def on_pretrain_routine_end(self, labels, names):
# Callback runs on pre-train routine end
if plots:
if self.plots:
plot_labels(labels, names, self.save_dir)
paths = self.save_dir.glob('*labels*.jpg') # training labels
if self.wandb:
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
# if self.clearml:
# pass # ClearML saves these images automatically using hooks
paths = self.save_dir.glob('*labels*.jpg') # training labels
if self.wandb:
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
# if self.clearml:
# pass # ClearML saves these images automatically using hooks
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
def on_train_batch_end(self, model, ni, imgs, targets, paths):
# Callback runs on train batch end
# ni: number integrated batches (since train start)
if plots:
if ni == 0 and not self.opt.sync_bn and self.tb:
log_tensorboard_graph(self.tb, model, imgsz=list(imgs.shape[2:4]))
if self.plots:
if ni < 3:
f = self.save_dir / f'train_batch{ni}.jpg' # filename
plot_images(imgs, targets, paths, f)
if (self.wandb or self.clearml) and ni == 10:
if ni == 0 and self.tb and not self.opt.sync_bn:
log_tensorboard_graph(self.tb, model, imgsz=(self.opt.imgsz, self.opt.imgsz))
if ni == 10 and (self.wandb or self.clearml):
files = sorted(self.save_dir.glob('train*.jpg'))
if self.wandb:
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
@ -197,9 +198,9 @@ class Loggers():
model_name='Latest Model',
auto_delete_file=False)
def on_train_end(self, last, best, plots, epoch, results):
def on_train_end(self, last, best, epoch, results):
# Callback runs on training end, i.e. saving best model
if plots:
if self.plots:
plot_results(file=self.save_dir / 'results.csv') # save results.png
files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
@ -291,6 +292,7 @@ class GenericLogger:
wandb.log_artifact(art)
@threaded
def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
# Log model graph to TensorBoard
try:
@ -300,5 +302,5 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress jit trace warning
tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
except Exception:
print('WARNING: TensorBoard graph visualization failure')
except Exception as e:
print(f'WARNING: TensorBoard graph visualization failure {e}')