Ignore Seaborn plot warnings (#3576)

* Ignore Seaborn plot warnings

* Update plots.py

* Update metrics.py
pull/3579/head
Glenn Jocher 2021-06-10 13:51:29 +02:00 committed by GitHub
parent 4695ca8314
commit 095197bd4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 7 deletions

View File

@ -1,5 +1,6 @@
# Model validation metrics
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
@ -167,9 +168,11 @@ class ConfusionMatrix:
fig = plt.figure(figsize=(12, 9), tight_layout=True)
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True')
fig.axes[0].set_ylabel('Predicted')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)

View File

@ -11,7 +11,7 @@ import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import seaborn as sn
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont
@ -291,7 +291,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
# seaborn correlogram
sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()
@ -306,8 +306,8 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
ax[0].set_xticklabels(names, rotation=90, fontsize=10)
else:
ax[0].set_xlabel('classes')
sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
# rectangles
labels[:, 1:3] = 0.5 # center