New CSV Logger (#4148)

* New CSV Logger

* cleanup

* move batch plots into Logger

* rename comment

* Remove total loss from progress bar

* mloss :-1 bug fix

* Update plot_results()

* Update plot_results()

* plot_results bug fix
Glenn Jocher 2021-07-25 19:06:37 +02:00 committed by GitHub
parent 3764277f95
commit 96e36a7c91
6 changed files with 68 additions and 109 deletions

.gitignore

@@ -31,6 +31,7 @@ data/*
 !data/*.sh

 results*.txt
+results*.csv

 # Datasets -------------------------------------------------------------------------------------------------------------
 coco/

train.py

@@ -12,7 +12,6 @@ import sys
 import time
 from copy import deepcopy
 from pathlib import Path
-from threading import Thread

 import math
 import numpy as np
@@ -38,7 +37,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
     check_requirements, print_mutation, set_logging, one_cycle, colorstr
 from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
-from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
+from utils.plots import plot_labels, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.metrics import fitness
@@ -61,7 +60,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Directories
     w = save_dir / 'weights'  # weights dir
     w.mkdir(parents=True, exist_ok=True)  # make dir
-    last, best, results_file = w / 'last.pt', w / 'best.pt', save_dir / 'results.txt'
+    last, best = w / 'last.pt', w / 'best.pt'

     # Hyperparameters
     if isinstance(hyp, str):
@@ -88,7 +87,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Loggers
     if RANK in [-1, 0]:
-        loggers = Loggers(save_dir, results_file, weights, opt, hyp, data_dict, LOGGER).start()  # loggers dict
+        loggers = Loggers(save_dir, weights, opt, hyp, data_dict, LOGGER).start()  # loggers dict
         if loggers.wandb and resume:
             weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict
@@ -167,10 +166,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
             ema.updates = ckpt['updates']

-        # Results
-        if ckpt.get('training_results') is not None:
-            results_file.write_text(ckpt['training_results'])  # write results.txt
-
         # Epochs
         start_epoch = ckpt['epoch'] + 1
         if resume:
@@ -275,11 +270,11 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
         # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

-        mloss = torch.zeros(4, device=device)  # mean losses
+        mloss = torch.zeros(3, device=device)  # mean losses
         if RANK != -1:
             train_loader.sampler.set_epoch(epoch)
         pbar = enumerate(train_loader)
-        LOGGER.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
+        LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
         if RANK in [-1, 0]:
             pbar = tqdm(pbar, total=nb)  # progress bar
         optimizer.zero_grad()
@@ -327,20 +322,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     ema.update(model)
                     last_opt_step = ni

-            # Print
+            # Log
             if RANK in [-1, 0]:
                 mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
-                s = ('%10s' * 2 + '%10.4g' * 6) % (
-                    f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
-                pbar.set_description(s)
-
-                # Plot
-                if plots:
-                    if ni < 3:
-                        f = save_dir / f'train_batch{ni}.jpg'  # filename
-                        Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
-
-                loggers.on_train_batch_end(ni, model, imgs)
+                pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
+                    f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
+                loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)

         # end batch ------------------------------------------------------------------------------------------------
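Note: with the total-loss column gone, the console header and each progress-bar row carry seven fields instead of eight. A standalone sketch of the new format strings from this hunk (the epoch, memory, and loss values are invented for illustration):

# Illustrative sketch of the new 7-column progress-bar layout: total loss is no longer shown.
header = ('%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size')
row = ('%10s' * 2 + '%10.4g' * 5) % ('0/299', '3.21G', 0.0459, 0.0646, 0.0172, 22, 640)
print(header)
print(row)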
@@ -371,13 +359,12 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
             if fi > best_fitness:
                 best_fitness = fi
-            loggers.on_train_val_end(mloss, results, lr, epoch, s, best_fitness, fi)
+            loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi)

             # Save model
             if (not nosave) or (final_epoch and not evolve):  # if save
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
-                        'training_results': results_file.read_text(),
                         'model': deepcopy(de_parallel(model)).half(),
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
@@ -395,9 +382,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # end training -----------------------------------------------------------------------------------------------------
     if RANK in [-1, 0]:
         LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
-        if plots:
-            plot_results(save_dir=save_dir)  # save as results.png
         if not evolve:
             if is_coco:  # COCO dataset
                 for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
@@ -411,13 +395,11 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                         save_dir=save_dir,
                         save_json=True,
                         plots=False)

         # Strip optimizers
         for f in last, best:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
-
-        loggers.on_train_end(last, best)
+        loggers.on_train_end(last, best, plots)

     torch.cuda.empty_cache()
     return results

utils/loggers/__init__.py

@@ -1,15 +1,17 @@
 # YOLOv5 experiment logging utils

 import warnings
+from threading import Thread

 import torch
 from torch.utils.tensorboard import SummaryWriter

 from utils.general import colorstr, emojis
 from utils.loggers.wandb.wandb_utils import WandbLogger
+from utils.plots import plot_images, plot_results
 from utils.torch_utils import de_parallel

-LOGGERS = ('txt', 'tb', 'wandb')  # text-file, TensorBoard, Weights & Biases
+LOGGERS = ('csv', 'tb', 'wandb')  # text-file, TensorBoard, Weights & Biases

 try:
     import wandb
@@ -21,10 +23,8 @@ except (ImportError, AssertionError):

 class Loggers():
     # YOLOv5 Loggers class
-    def __init__(self, save_dir=None, results_file=None, weights=None, opt=None, hyp=None,
-                 data_dict=None, logger=None, include=LOGGERS):
+    def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, data_dict=None, logger=None, include=LOGGERS):
         self.save_dir = save_dir
-        self.results_file = results_file
         self.weights = weights
         self.opt = opt
         self.hyp = hyp
@@ -35,7 +35,7 @@ class Loggers():
             setattr(self, k, None)  # init empty logger dictionary

     def start(self):
-        self.txt = True  # always log to txt
+        self.csv = True  # always log to csv

         # Message
         try:
@@ -63,15 +63,19 @@
         return self

-    def on_train_batch_end(self, ni, model, imgs):
+    def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
         # Callback runs on train batch end
-        if ni == 0:
-            with warnings.catch_warnings():
-                warnings.simplefilter('ignore')  # suppress jit trace warning
-                self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
-        if self.wandb and ni == 10:
-            files = sorted(self.save_dir.glob('train*.jpg'))
-            self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
+        if plots:
+            if ni == 0:
+                with warnings.catch_warnings():
+                    warnings.simplefilter('ignore')  # suppress jit trace warning
+                    self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
+            if ni < 3:
+                f = self.save_dir / f'train_batch{ni}.jpg'  # filename
+                Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
+            if self.wandb and ni == 10:
+                files = sorted(self.save_dir.glob('train*.jpg'))
+                self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})

     def on_train_epoch_end(self, epoch):
         # Callback runs on train epoch end
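The per-batch mosaic plotting that train.py used to run inline now lives in this callback, still as a fire-and-forget daemon thread so image writing never blocks the training loop. A minimal standalone sketch of that pattern (plot_stub is a hypothetical stand-in for plot_images; the arguments are illustrative):

from threading import Thread

def plot_stub(imgs, path):  # stand-in for utils.plots.plot_images
    print(f'would plot {len(imgs)} images to {path}')

# daemon=True: an in-flight plot cannot keep the process alive after training exits
Thread(target=plot_stub, args=([0, 1, 2], 'train_batch0.jpg'), daemon=True).start()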
@@ -89,21 +93,28 @@
             files = sorted(self.save_dir.glob('val*.jpg'))
             self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})

-    def on_train_val_end(self, mloss, results, lr, epoch, s, best_fitness, fi):
-        # Callback runs on validation end during training
-        vals = list(mloss[:-1]) + list(results) + lr
-        tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
-                'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
-                'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
-                'x/lr0', 'x/lr1', 'x/lr2']  # params
-        if self.txt:
-            with open(self.results_file, 'a') as f:
-                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
+    def on_train_val_end(self, mloss, results, lr, epoch, best_fitness, fi):
+        # Callback runs on val end during training
+        vals = list(mloss) + list(results) + lr
+        keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
+                'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',  # metrics
+                'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
+                'x/lr0', 'x/lr1', 'x/lr2']  # params
+        x = {k: v for k, v in zip(keys, vals)}  # dict
+
+        if self.csv:
+            file = self.save_dir / 'results.csv'
+            n = len(x) + 1  # number of cols
+            s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # add header
+            with open(file, 'a') as f:
+                f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
+
         if self.tb:
-            for x, tag in zip(vals, tags):
-                self.tb.add_scalar(tag, x, epoch)  # TensorBoard
+            for k, v in x.items():
+                self.tb.add_scalar(k, v, epoch)  # TensorBoard
+
         if self.wandb:
-            self.wandb.log({k: v for k, v in zip(tags, vals)})
+            self.wandb.log(x)
             self.wandb.end_epoch(best_result=best_fitness == fi)

     def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
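The header-once/append-per-epoch logic above is self-contained enough to demonstrate in isolation. A runnable sketch using the column names from this diff (the loss and metric values are invented for illustration):

from pathlib import Path

keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',
        'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
        'val/box_loss', 'val/obj_loss', 'val/cls_loss',
        'x/lr0', 'x/lr1', 'x/lr2']
vals = [0.0459, 0.0646, 0.0172, 0.71, 0.62, 0.66, 0.41, 0.041, 0.062, 0.015, 0.01, 0.01, 0.01]
epoch = 0

file = Path('results.csv')
n = len(keys) + 1  # number of cols, including the leading 'epoch'
header = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')
with open(file, 'a') as f:
    f.write(header + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')

Each column is padded to 20 characters, so the output is simultaneously valid CSV and readable as fixed-width text.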
@@ -112,8 +123,10 @@
         if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
             self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

-    def on_train_end(self, last, best):
+    def on_train_end(self, last, best, plots):
         # Callback runs on training end
+        if plots:
+            plot_results(dir=self.save_dir)  # save results.png
         files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
         files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()]  # filter
         if self.wandb:

utils/loss.py

@@ -162,8 +162,7 @@ class ComputeLoss:
         lcls *= self.hyp['cls']
         bs = tobj.shape[0]  # batch size

-        loss = lbox + lobj + lcls
-        return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
+        return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()

     def build_targets(self, p, targets):
         # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
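Since loss_items now carries only the three components, any caller that still wants the scalar total (which used to ride along as a fourth element) can recover it by summing. A trivial sketch with invented values:

import torch

loss_items = torch.tensor([0.0459, 0.0646, 0.0172])  # box, obj, cls (illustrative)
total = loss_items.sum()  # what the removed fourth element used to hold
print(f'total loss per image: {total.item():.4g}')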

utils/plots.py

@@ -1,7 +1,5 @@
 # Plotting utils

-import glob
-import os
 from copy import copy
 from pathlib import Path
@@ -387,63 +385,29 @@ def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
     plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)


-def plot_results_overlay(start=0, stop=0):  # from utils.plots import *; plot_results_overlay()
-    # Plot training 'results*.txt', overlaying train and val losses
-    s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
-    t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
-    for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
-        results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
-        n = results.shape[1]  # number of rows
-        x = range(start, min(stop, n) if stop else n)
-        fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
-        ax = ax.ravel()
-        for i in range(5):
-            for j in [i, i + 5]:
-                y = results[j, x]
-                ax[i].plot(x, y, marker='.', label=s[j])
-                # y_smooth = butter_lowpass_filtfilt(y)
-                # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
-            ax[i].set_title(t[i])
-            ax[i].legend()
-            ax[i].set_ylabel(f) if i == 0 else None  # add filename
-        fig.savefig(f.replace('.txt', '.png'), dpi=200)
-
-
-def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
-    # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
+def plot_results(file='', dir=''):
+    # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
+    save_dir = Path(file).parent if file else Path(dir)
     fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
     ax = ax.ravel()
-    s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
-         'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
-    if bucket:
-        # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
-        files = ['results%g.txt' % x for x in id]
-        c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
-        os.system(c)
-    else:
-        files = list(Path(save_dir).glob('results*.txt'))
-        assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
+    files = list(save_dir.glob('results*.csv'))
+    assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
     for fi, f in enumerate(files):
         try:
-            results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
-            n = results.shape[1]  # number of rows
-            x = range(start, min(stop, n) if stop else n)
-            for i in range(10):
-                y = results[i, x]
-                if i in [0, 1, 2, 5, 6, 7]:
-                    y[y == 0] = np.nan  # don't show zero loss values
-                    # y /= y[0]  # normalize
-                label = labels[fi] if len(labels) else f.stem
-                ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
-                ax[i].set_title(s[i])
-                # if i in [5, 6, 7]:  # share train and val loss y axes
+            data = pd.read_csv(f)
+            s = [x.strip() for x in data.columns]
+            x = data.values[:, 0]
+            for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
+                y = data.values[:, j]
+                # y[y == 0] = np.nan  # don't show zero values
+                ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
+                ax[i].set_title(s[j], fontsize=12)
+                # if j in [8, 9, 10]:  # share train and val loss y axes
                 #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
         except Exception as e:
-            print('Warning: Plotting error for %s; %s' % (f, e))
+            print(f'Warning: Plotting error for {f}: {e}')
     ax[1].legend()
-    fig.savefig(Path(save_dir) / 'results.png', dpi=200)
+    fig.savefig(save_dir / 'results.png', dpi=200)


 def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
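Usage of the rewritten function, per its docstring: either argument form resolves to a directory that is then scanned for results*.csv files (the paths below are illustrative):

from utils.plots import plot_results

plot_results(file='runs/train/exp/results.csv')  # plot the run containing this CSV
plot_results(dir='runs/train/exp')               # or point directly at the run directory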

val.py

@@ -171,7 +171,7 @@ def run(data,
             # Compute loss
             if compute_loss:
-                loss += compute_loss([x.float() for x in train_out], targets)[1][:3]  # box, obj, cls
+                loss += compute_loss([x.float() for x in train_out], targets)[1]  # box, obj, cls

             # Run NMS
             targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels