Add ConfusionMatrix `normalize=True` flag ()

pull/3587/head
Glenn Jocher 2021-06-11 11:37:08 +02:00 committed by GitHub
parent 46e1fdfbc6
commit ec2da4a82c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 3 deletions

View File

@ -158,11 +158,12 @@ class ConfusionMatrix:
def matrix(self):
return self.matrix
def plot(self, save_dir='', names=()):
def plot(self, normalize=True, save_dir='', names=()):
try:
import seaborn as sn
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
if normalize:
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
fig = plt.figure(figsize=(12, 9), tight_layout=True)