From 08d4918d7f49055158b1cceb27ea0d1990251afc Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 12 Mar 2021 22:15:41 -0800 Subject: [PATCH] labels.jpg class names (#2454) * labels.png class names * fontsize=10 --- train.py | 2 +- utils/plots.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index dcb89a3c1..005fdf60c 100644 --- a/train.py +++ b/train.py @@ -203,7 +203,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency # model._initialize_biases(cf.to(device)) if plots: - plot_labels(labels, save_dir, loggers) + plot_labels(labels, names, save_dir, loggers) if tb_writer: tb_writer.add_histogram('classes', c, 0) diff --git a/utils/plots.py b/utils/plots.py index aa9a1cab8..47e7b7b74 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -269,7 +269,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx plt.savefig(str(Path(path).name) + '.png', dpi=300) -def plot_labels(labels, save_dir=Path(''), loggers=None): +def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): # plot dataset labels print('Plotting labels... ') c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes @@ -286,7 +286,12 @@ def plot_labels(labels, save_dir=Path(''), loggers=None): matplotlib.use('svg') # faster ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) - ax[0].set_xlabel('classes') + ax[0].set_ylabel('instances') + if 0 < len(names) < 30: + ax[0].set_xticks(range(len(names))) + ax[0].set_xticklabels(names, rotation=90, fontsize=10) + else: + ax[0].set_xlabel('classes') sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)