Suppress jit trace warning + graph once (#3454)

* Suppress jit trace warning + graph once

Suppress harmless jit trace warning on TensorBoard add_graph call. Also fix multiple add_graph() calls bug, now only on batch 0.

* Update train.py
pull/3455/head
Glenn Jocher 2021-06-04 12:37:41 +02:00 committed by GitHub
parent af2bc3a1c3
commit 4aa2959101
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 5 deletions

View File

@ -4,6 +4,7 @@ import math
import os import os
import random import random
import time import time
import warnings
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
@ -323,18 +324,19 @@ def train(hyp, opt, device, tb_writer=None):
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB) mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % ( s = ('%10s' * 2 + '%10.4g' * 6) % (
'%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1]) f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
pbar.set_description(s) pbar.set_description(s)
# Plot # Plot
if plots and ni < 3: if plots and ni < 3:
f = save_dir / f'train_batch{ni}.jpg' # filename f = save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
if tb_writer: if tb_writer and ni == 0:
tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # model graph with warnings.catch_warnings():
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) warnings.simplefilter('ignore') # suppress jit trace warning
tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # graph
elif plots and ni == 10 and wandb_logger.wandb: elif plots and ni == 10 and wandb_logger.wandb:
wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
save_dir.glob('train*.jpg') if x.exists()]}) save_dir.glob('train*.jpg') if x.exists()]})
# end batch ------------------------------------------------------------------------------------------------ # end batch ------------------------------------------------------------------------------------------------