diff --git a/train.py b/train.py
index ac38d04db..e4c9b6ae6 100644
--- a/train.py
+++ b/train.py
@@ -219,7 +219,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                 check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)  # run AutoAnchor
             model.half().float()  # pre-reduce anchor precision
 
-        callbacks.run('on_pretrain_routine_end', labels, names, plots)
+        callbacks.run('on_pretrain_routine_end', labels, names)
 
     # DDP mode
     if cuda and RANK != -1:
@@ -328,7 +328,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                 pbar.set_description(('%11s' * 2 + '%11.4g' * 5) %
                                      (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
-                callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots)
+                callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths)
                 if callbacks.stop_training:
                     return
             # end batch ------------------------------------------------------------------------------------------------
@@ -420,7 +420,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                         if is_coco:
                             callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
 
-        callbacks.run('on_train_end', last, best, plots, epoch, results)
+        callbacks.run('on_train_end', last, best, epoch, results)
 
     torch.cuda.empty_cache()
     return results
diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py
index b9869df26..98a123eee 100644
--- a/utils/loggers/__init__.py
+++ b/utils/loggers/__init__.py
@@ -49,6 +49,7 @@ class Loggers():
         self.weights = weights
         self.opt = opt
         self.hyp = hyp
+        self.plots = not opt.noplots  # plot results
         self.logger = logger  # for printing results to console
         self.include = include
         self.keys = [
@@ -110,26 +111,26 @@ class Loggers():
         # Callback runs on train start
         pass
 
-    def on_pretrain_routine_end(self, labels, names, plots):
+    def on_pretrain_routine_end(self, labels, names):
         # Callback runs on pre-train routine end
-        if plots:
+        if self.plots:
             plot_labels(labels, names, self.save_dir)
-        paths = self.save_dir.glob('*labels*.jpg')  # training labels
-        if self.wandb:
-            self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
-        # if self.clearml:
-        #    pass  # ClearML saves these images automatically using hooks
+            paths = self.save_dir.glob('*labels*.jpg')  # training labels
+            if self.wandb:
+                self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
+            # if self.clearml:
+            #    pass  # ClearML saves these images automatically using hooks
 
-    def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
+    def on_train_batch_end(self, model, ni, imgs, targets, paths):
         # Callback runs on train batch end
         # ni: number integrated batches (since train start)
-        if plots:
-            if ni == 0 and not self.opt.sync_bn and self.tb:
-                log_tensorboard_graph(self.tb, model, imgsz=list(imgs.shape[2:4]))
+        if self.plots:
             if ni < 3:
                 f = self.save_dir / f'train_batch{ni}.jpg'  # filename
                 plot_images(imgs, targets, paths, f)
-            if (self.wandb or self.clearml) and ni == 10:
+            if ni == 0 and self.tb and not self.opt.sync_bn:
+                log_tensorboard_graph(self.tb, model, imgsz=(self.opt.imgsz, self.opt.imgsz))
+            if ni == 10 and (self.wandb or self.clearml):
                 files = sorted(self.save_dir.glob('train*.jpg'))
                 if self.wandb:
                     self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
@@ -197,9 +198,9 @@ class Loggers():
                                       model_name='Latest Model',
                                       auto_delete_file=False)
 
-    def on_train_end(self, last, best, plots, epoch, results):
+    def on_train_end(self, last, best, epoch, results):
         # Callback runs on training end, i.e. saving best model
-        if plots:
+        if self.plots:
             plot_results(file=self.save_dir / 'results.csv')  # save results.png
         files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
         files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()]  # filter
@@ -291,6 +292,7 @@ class GenericLogger:
             wandb.log_artifact(art)
 
 
+@threaded
 def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
     # Log model graph to TensorBoard
     try:
@@ -300,5 +302,5 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
         with warnings.catch_warnings():
             warnings.simplefilter('ignore')  # suppress jit trace warning
             tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
-    except Exception:
-        print('WARNING: TensorBoard graph visualization failure')
+    except Exception as e:
+        print(f'WARNING: TensorBoard graph visualization failure {e}')