Daemon thread plotting (#1561)

* Daemon thread plotting

* remove process_batch

* plot after print
pull/1566/head
Glenn Jocher 2020-11-30 16:44:14 +01:00 committed by GitHub
parent 68211f72c9
commit b6ed1104a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 19 deletions

23
test.py
View File

@ -3,6 +3,7 @@ import glob
import json import json
import os import os
from pathlib import Path from pathlib import Path
from threading import Thread
import numpy as np import numpy as np
import torch import torch
@ -206,10 +207,10 @@ def test(data,
# Plot images # Plot images
if plots and batch_i < 3: if plots and batch_i < 3:
f = save_dir / f'test_batch{batch_i}_labels.jpg' # filename f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
plot_images(img, targets, paths, f, names) # labels Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
f = save_dir / f'test_batch{batch_i}_pred.jpg' f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
plot_images(img, output_to_target(output), paths, f, names) # predictions Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start()
# Compute statistics # Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
@ -221,13 +222,6 @@ def test(data,
else: else:
nt = torch.zeros(1) nt = torch.zeros(1)
# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb and wandb.run:
wandb.log({"Images": wandb_images})
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})
# Print results # Print results
pf = '%20s' + '%12.3g' * 6 # print format pf = '%20s' + '%12.3g' * 6 # print format
print(pf % ('all', seen, nt.sum(), mp, mr, map50, map)) print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
@ -242,6 +236,13 @@ def test(data,
if not training: if not training:
print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t) print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)
# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb and wandb.run:
wandb.log({"Images": wandb_images})
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})
# Save JSON # Save JSON
if save_json and len(jdict): if save_json and len(jdict):
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights

View File

@ -1,12 +1,13 @@
import argparse import argparse
import logging import logging
import math
import os import os
import random import random
import time import time
from pathlib import Path from pathlib import Path
from threading import Thread
from warnings import warn from warnings import warn
import math
import numpy as np import numpy as np
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
@ -134,6 +135,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem, project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
name=save_dir.stem, name=save_dir.stem,
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None) id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
loggers = {'wandb': wandb} # loggers dict
# Resume # Resume
start_epoch, best_fitness = 0, 0.0 start_epoch, best_fitness = 0, 0.0
@ -201,11 +203,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
# model._initialize_biases(cf.to(device)) # model._initialize_biases(cf.to(device))
if plots: if plots:
plot_labels(labels, save_dir=save_dir) Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start()
if tb_writer: if tb_writer:
tb_writer.add_histogram('classes', c, 0) tb_writer.add_histogram('classes', c, 0)
if wandb:
wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})
# Anchors # Anchors
if not opt.noautoanchor: if not opt.noautoanchor:
@ -311,7 +311,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Plot # Plot
if plots and ni < 3: if plots and ni < 3:
f = save_dir / f'train_batch{ni}.jpg' # filename f = save_dir / f'train_batch{ni}.jpg' # filename
plot_images(images=imgs, targets=targets, paths=paths, fname=f) Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
# if tb_writer: # if tb_writer:
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
# tb_writer.add_graph(model, imgs) # add model to tensorboard # tb_writer.add_graph(model, imgs) # add model to tensorboard

View File

@ -250,7 +250,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx
plt.savefig('test_study.png', dpi=300) plt.savefig('test_study.png', dpi=300)
def plot_labels(labels, save_dir=''): def plot_labels(labels, save_dir=Path(''), loggers=None):
# plot dataset labels # plot dataset labels
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
nc = int(c.max() + 1) # number of classes nc = int(c.max() + 1) # number of classes
@ -264,7 +264,7 @@ def plot_labels(labels, save_dir=''):
sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o', sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02), plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
diag_kws=dict(bins=50)) diag_kws=dict(bins=50))
plt.savefig(Path(save_dir) / 'labels_correlogram.png', dpi=200) plt.savefig(save_dir / 'labels_correlogram.png', dpi=200)
plt.close() plt.close()
except Exception as e: except Exception as e:
pass pass
@ -292,9 +292,14 @@ def plot_labels(labels, save_dir=''):
for a in [0, 1, 2, 3]: for a in [0, 1, 2, 3]:
for s in ['top', 'right', 'left', 'bottom']: for s in ['top', 'right', 'left', 'bottom']:
ax[a].spines[s].set_visible(False) ax[a].spines[s].set_visible(False)
plt.savefig(Path(save_dir) / 'labels.png', dpi=200) plt.savefig(save_dir / 'labels.png', dpi=200)
plt.close() plt.close()
# loggers
for k, v in loggers.items() or {}:
if k == 'wandb' and v:
v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})
def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
# Plot hyperparameter evolution results in evolve.txt # Plot hyperparameter evolution results in evolve.txt