Threaded TensorBoard graph logging (#9070)
* Log TensorBoard graph on pretrain_routine_end * fixpull/9071/head
parent
0b8639a40a
commit
8665d557c1
6
train.py
6
train.py
|
@ -219,7 +219,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|||
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor
|
||||
model.half().float() # pre-reduce anchor precision
|
||||
|
||||
callbacks.run('on_pretrain_routine_end', labels, names, plots)
|
||||
callbacks.run('on_pretrain_routine_end', labels, names)
|
||||
|
||||
# DDP mode
|
||||
if cuda and RANK != -1:
|
||||
|
@ -328,7 +328,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
||||
pbar.set_description(('%11s' * 2 + '%11.4g' * 5) %
|
||||
(f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
|
||||
callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots)
|
||||
callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths)
|
||||
if callbacks.stop_training:
|
||||
return
|
||||
# end batch ------------------------------------------------------------------------------------------------
|
||||
|
@ -420,7 +420,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|||
if is_coco:
|
||||
callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
|
||||
|
||||
callbacks.run('on_train_end', last, best, plots, epoch, results)
|
||||
callbacks.run('on_train_end', last, best, epoch, results)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
return results
|
||||
|
|
|
@ -49,6 +49,7 @@ class Loggers():
|
|||
self.weights = weights
|
||||
self.opt = opt
|
||||
self.hyp = hyp
|
||||
self.plots = not opt.noplots # plot results
|
||||
self.logger = logger # for printing results to console
|
||||
self.include = include
|
||||
self.keys = [
|
||||
|
@ -110,26 +111,26 @@ class Loggers():
|
|||
# Callback runs on train start
|
||||
pass
|
||||
|
||||
def on_pretrain_routine_end(self, labels, names, plots):
|
||||
def on_pretrain_routine_end(self, labels, names):
|
||||
# Callback runs on pre-train routine end
|
||||
if plots:
|
||||
if self.plots:
|
||||
plot_labels(labels, names, self.save_dir)
|
||||
paths = self.save_dir.glob('*labels*.jpg') # training labels
|
||||
if self.wandb:
|
||||
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
|
||||
# if self.clearml:
|
||||
# pass # ClearML saves these images automatically using hooks
|
||||
paths = self.save_dir.glob('*labels*.jpg') # training labels
|
||||
if self.wandb:
|
||||
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
|
||||
# if self.clearml:
|
||||
# pass # ClearML saves these images automatically using hooks
|
||||
|
||||
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
|
||||
def on_train_batch_end(self, model, ni, imgs, targets, paths):
|
||||
# Callback runs on train batch end
|
||||
# ni: number integrated batches (since train start)
|
||||
if plots:
|
||||
if ni == 0 and not self.opt.sync_bn and self.tb:
|
||||
log_tensorboard_graph(self.tb, model, imgsz=list(imgs.shape[2:4]))
|
||||
if self.plots:
|
||||
if ni < 3:
|
||||
f = self.save_dir / f'train_batch{ni}.jpg' # filename
|
||||
plot_images(imgs, targets, paths, f)
|
||||
if (self.wandb or self.clearml) and ni == 10:
|
||||
if ni == 0 and self.tb and not self.opt.sync_bn:
|
||||
log_tensorboard_graph(self.tb, model, imgsz=(self.opt.imgsz, self.opt.imgsz))
|
||||
if ni == 10 and (self.wandb or self.clearml):
|
||||
files = sorted(self.save_dir.glob('train*.jpg'))
|
||||
if self.wandb:
|
||||
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
|
||||
|
@ -197,9 +198,9 @@ class Loggers():
|
|||
model_name='Latest Model',
|
||||
auto_delete_file=False)
|
||||
|
||||
def on_train_end(self, last, best, plots, epoch, results):
|
||||
def on_train_end(self, last, best, epoch, results):
|
||||
# Callback runs on training end, i.e. saving best model
|
||||
if plots:
|
||||
if self.plots:
|
||||
plot_results(file=self.save_dir / 'results.csv') # save results.png
|
||||
files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
|
||||
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
|
||||
|
@ -291,6 +292,7 @@ class GenericLogger:
|
|||
wandb.log_artifact(art)
|
||||
|
||||
|
||||
@threaded
|
||||
def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
|
||||
# Log model graph to TensorBoard
|
||||
try:
|
||||
|
@ -300,5 +302,5 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
|
|||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress jit trace warning
|
||||
tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
|
||||
except Exception:
|
||||
print('WARNING: TensorBoard graph visualization failure')
|
||||
except Exception as e:
|
||||
print(f'WARNING: TensorBoard graph visualization failure {e}')
|
||||
|
|
Loading…
Reference in New Issue