mirror of https://github.com/WongKinYiu/yolov7.git
Update validation set : Test36B 13/11
training-set Test36B 12/11 +missing png/s and other tiff from the final annotation list : tir_tiff_seq_png_3_class_fixed_whether_copied_dataset_label.xlsxpull/2071/head
parent
c90fa54f1e
commit
2cc7e05cc8
12
train.py
12
train.py
|
@ -388,9 +388,13 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
for file in test_dataset.img_files:
|
||||
f.write(f"{file}\n")
|
||||
|
||||
labels = np.concatenate(dataset.labels, 0)
|
||||
c = torch.tensor(labels[:, 0]) # classes
|
||||
|
||||
labels_test = np.concatenate(testloader.dataset.labels, 0)
|
||||
c_test = torch.tensor(labels_test[:, 0]) # classes
|
||||
|
||||
if not opt.resume:
|
||||
labels = np.concatenate(dataset.labels, 0)
|
||||
c = torch.tensor(labels[:, 0]) # classes
|
||||
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
|
||||
# model._initialize_biases(cf.to(device))
|
||||
if plots:
|
||||
|
@ -483,6 +487,10 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
|
||||
print(100 * '==')
|
||||
print('Training set labels {} count : {}'.format(names, torch.bincount(c.long(), minlength=nc) + 1))
|
||||
|
||||
print(100 * '==')
|
||||
print('Validation set labels {} count : {}'.format(names, torch.bincount(c_test.long(), minlength=nc) + 1))
|
||||
|
||||
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
||||
model.train()
|
||||
|
||||
|
|
Loading…
Reference in New Issue