From 945307beba39fbe627c38eb9a9082b1a4c2c22e4 Mon Sep 17 00:00:00 2001 From: Alex Stoken Date: Wed, 17 Jun 2020 16:03:18 -0500 Subject: [PATCH] Add save_dir to plot_lr_scheduler and plot_labels Set save_dir = log_dir in train.py --- data/coco128.yaml | 4 ++-- train.py | 2 +- utils/utils.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/data/coco128.yaml b/data/coco128.yaml index 2b6184890..6f72f4ab9 100644 --- a/data/coco128.yaml +++ b/data/coco128.yaml @@ -8,8 +8,8 @@ # train and val datasets (image directory or *.txt file with image paths) -train: ../coco128/images/train2017/ -val: ../coco128/images/train2017/ +train: C:/Users/astoken/projects/yolov5/data/coco/images/train2017 +val: C:/Users/astoken/projects/yolov5/data/coco/images/train2017 # number of classes nc: 80 diff --git a/train.py b/train.py index e4fc6254a..500a65818 100644 --- a/train.py +++ b/train.py @@ -196,7 +196,7 @@ def train(hyp): c = torch.tensor(labels[:, 0]) # classes # cf = torch.bincount(c.long(), minlength=nc) + 1. # model._initialize_biases(cf.to(device)) - plot_labels(labels) + plot_labels(labels, save_dir=log_dir) tb_writer.add_histogram('classes', c, 0) # Check anchors diff --git a/utils/utils.py b/utils/utils.py index 5332beeb6..fb8d4877d 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -1025,7 +1025,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir='./'): plt.xlim(0, epochs) plt.ylim(0) plt.tight_layout() - plt.savefig('LR.png', dpi=200) + plt.savefig(os.path.join(save_dir, 'LR.png'), dpi=200) def plot_test_txt(): # from utils.utils import *; plot_test() @@ -1088,7 +1088,7 @@ def plot_study_txt(f='study.txt', x=None): # from utils.utils import *; plot_st plt.savefig(f.replace('.txt', '.png'), dpi=200) -def plot_labels(labels): +def plot_labels(labels, save_dir= '.'): # plot dataset labels c, b = labels[:, 0], labels[:, 1:].transpose() # classees, boxes @@ -1109,7 +1109,7 @@ def plot_labels(labels): ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet') ax[2].set_xlabel('width') ax[2].set_ylabel('height') - plt.savefig('labels.png', dpi=200) + plt.savefig(os.path.join(save_dir,'labels.png'), dpi=200) def plot_evolution_results(hyp): # from utils.utils import *; plot_evolution_results(hyp)