From f0101475788590720a3a4b1152a89d531f311dff Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 2 Dec 2020 15:53:16 +0100 Subject: [PATCH] Update matplotlib.use('Agg') tight (#1583) * Update matplotlib tight_layout=True * udpate * udpate * update * png to ps * update * update --- utils/autoanchor.py | 3 +-- utils/metrics.py | 6 ++---- utils/plots.py | 20 ++++++++++---------- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/utils/autoanchor.py b/utils/autoanchor.py index 63fac5497..0c33dcbc3 100644 --- a/utils/autoanchor.py +++ b/utils/autoanchor.py @@ -124,13 +124,12 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10 # k, d = [None] * 20, [None] * 20 # for i in tqdm(range(1, 21)): # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance - # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) + # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True) # ax = ax.ravel() # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.') # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh # ax[0].hist(wh[wh[:, 0]<100, 0],400) # ax[1].hist(wh[wh[:, 1]<100, 1],400) - # fig.tight_layout() # fig.savefig('wh.png', dpi=200) # Evolve diff --git a/utils/metrics.py b/utils/metrics.py index 79f18cff3..af32ddc5b 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -163,7 +163,7 @@ class ConfusionMatrix: array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) - fig = plt.figure(figsize=(12, 9)) + fig = plt.figure(figsize=(12, 9), tight_layout=True) sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True, @@ -171,7 +171,6 @@ class ConfusionMatrix: yticklabels=names + ['background FP'] if labels else "auto").set_facecolor((1, 1, 1)) fig.axes[0].set_xlabel('True') fig.axes[0].set_ylabel('Predicted') - fig.tight_layout() fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) except Exception as e: pass @@ -184,7 +183,7 @@ class ConfusionMatrix: # Plots ---------------------------------------------------------------------------------------------------------------- def plot_pr_curve(px, py, ap, save_dir='.', names=()): - fig, ax = plt.subplots(1, 1, figsize=(9, 6)) + fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) py = np.stack(py, axis=1) if 0 < len(names) < 21: # show mAP in legend if < 10 classes @@ -199,5 +198,4 @@ def plot_pr_curve(px, py, ap, save_dir='.', names=()): ax.set_xlim(0, 1) ax.set_ylim(0, 1) plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") - fig.tight_layout() fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250) diff --git a/utils/plots.py b/utils/plots.py index 3808ec19d..bdc78c0e3 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -21,7 +21,7 @@ from utils.metrics import fitness # Settings matplotlib.rc('font', **{'size': 11}) -matplotlib.use('svg') # for writing to files only +matplotlib.use('Agg') # for writing to files only def color_list(): @@ -73,7 +73,7 @@ def plot_wh_methods(): # from utils.plots import *; plot_wh_methods() ya = np.exp(x) yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2 - fig = plt.figure(figsize=(6, 3), dpi=150) + fig = plt.figure(figsize=(6, 3), tight_layout=True) plt.plot(x, ya, '.-', label='YOLOv3') plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2') plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6') @@ -83,7 +83,6 @@ def plot_wh_methods(): # from utils.plots import *; plot_wh_methods() plt.ylabel('output') plt.grid() plt.legend() - fig.tight_layout() fig.savefig('comparison.png', dpi=200) @@ -145,7 +144,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max if boxes.max() <= 1: # if normalized boxes[[0, 2]] *= w # scale to pixels boxes[[1, 3]] *= h - elif scale_factor < 1: # absolute coords need scale if image scales + elif scale_factor < 1: # absolute coords need scale if image scales boxes *= scale_factor boxes[[0, 2]] += block_x boxes[[1, 3]] += block_y @@ -188,7 +187,6 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''): plt.grid() plt.xlim(0, epochs) plt.ylim(0) - plt.tight_layout() plt.savefig(Path(save_dir) / 'LR.png', dpi=200) @@ -267,12 +265,13 @@ def plot_labels(labels, save_dir=Path(''), loggers=None): sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o', plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02), diag_kws=dict(bins=50)) - plt.savefig(save_dir / 'labels_correlogram.png', dpi=200) + plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) plt.close() except Exception as e: pass # matplotlib labels + matplotlib.use('svg') # faster ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) ax[0].set_xlabel('classes') @@ -295,13 +294,15 @@ def plot_labels(labels, save_dir=Path(''), loggers=None): for a in [0, 1, 2, 3]: for s in ['top', 'right', 'left', 'bottom']: ax[a].spines[s].set_visible(False) - plt.savefig(save_dir / 'labels.png', dpi=200) + + plt.savefig(save_dir / 'labels.jpg', dpi=200) + matplotlib.use('Agg') plt.close() # loggers for k, v in loggers.items() or {}: if k == 'wandb' and v: - v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]}) + v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}) def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() @@ -353,7 +354,7 @@ def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_re def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp') - fig, ax = plt.subplots(2, 5, figsize=(12, 6)) + fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) ax = ax.ravel() s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall', 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95'] @@ -383,6 +384,5 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): except Exception as e: print('Warning: Plotting error for %s; %s' % (f, e)) - fig.tight_layout() ax[1].legend() fig.savefig(Path(save_dir) / 'results.png', dpi=200)