Add CSV logging to GenericLogger (#9128)
Enable CSV logging for Classify training. Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/9150/head
parent
f8816f58b7
commit
f0e5a608f5
|
@ -242,9 +242,10 @@ class GenericLogger:
|
||||||
|
|
||||||
def __init__(self, opt, console_logger, include=('tb', 'wandb')):
|
def __init__(self, opt, console_logger, include=('tb', 'wandb')):
|
||||||
# init default loggers
|
# init default loggers
|
||||||
self.save_dir = opt.save_dir
|
self.save_dir = Path(opt.save_dir)
|
||||||
self.include = include
|
self.include = include
|
||||||
self.console_logger = console_logger
|
self.console_logger = console_logger
|
||||||
|
self.csv = self.save_dir / 'results.csv' # CSV logger
|
||||||
if 'tb' in self.include:
|
if 'tb' in self.include:
|
||||||
prefix = colorstr('TensorBoard: ')
|
prefix = colorstr('TensorBoard: ')
|
||||||
self.console_logger.info(
|
self.console_logger.info(
|
||||||
|
@ -258,14 +259,21 @@ class GenericLogger:
|
||||||
else:
|
else:
|
||||||
self.wandb = None
|
self.wandb = None
|
||||||
|
|
||||||
def log_metrics(self, metrics_dict, epoch):
|
def log_metrics(self, metrics, epoch):
|
||||||
# Log metrics dictionary to all loggers
|
# Log metrics dictionary to all loggers
|
||||||
|
if self.csv:
|
||||||
|
keys, vals = list(metrics.keys()), list(metrics.values())
|
||||||
|
n = len(metrics) + 1 # number of cols
|
||||||
|
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
|
||||||
|
with open(self.csv, 'a') as f:
|
||||||
|
f.write(s + ('%23.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
|
||||||
|
|
||||||
if self.tb:
|
if self.tb:
|
||||||
for k, v in metrics_dict.items():
|
for k, v in metrics.items():
|
||||||
self.tb.add_scalar(k, v, epoch)
|
self.tb.add_scalar(k, v, epoch)
|
||||||
|
|
||||||
if self.wandb:
|
if self.wandb:
|
||||||
self.wandb.log(metrics_dict, step=epoch)
|
self.wandb.log(metrics, step=epoch)
|
||||||
|
|
||||||
def log_images(self, files, name='Images', epoch=0):
|
def log_images(self, files, name='Images', epoch=0):
|
||||||
# Log images to all loggers
|
# Log images to all loggers
|
||||||
|
@ -291,6 +299,11 @@ class GenericLogger:
|
||||||
art.add_file(str(model_path))
|
art.add_file(str(model_path))
|
||||||
wandb.log_artifact(art)
|
wandb.log_artifact(art)
|
||||||
|
|
||||||
|
def update_params(self, params):
|
||||||
|
# Update the paramters logged
|
||||||
|
if self.wandb:
|
||||||
|
wandb.run.config.update(params, allow_val_change=True)
|
||||||
|
|
||||||
|
|
||||||
def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
|
def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
|
||||||
# Log model graph to TensorBoard
|
# Log model graph to TensorBoard
|
||||||
|
|
Loading…
Reference in New Issue