assert best possible recall > 0.9 before training
parent
19e68e8a7b
commit
31f3310029
5
train.py
5
train.py
|
@ -191,7 +191,7 @@ def train(hyp):
|
||||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||||
model.names = data_dict['names']
|
model.names = data_dict['names']
|
||||||
|
|
||||||
# class frequency
|
# Class frequency
|
||||||
labels = np.concatenate(dataset.labels, 0)
|
labels = np.concatenate(dataset.labels, 0)
|
||||||
c = torch.tensor(labels[:, 0]) # classes
|
c = torch.tensor(labels[:, 0]) # classes
|
||||||
# cf = torch.bincount(c.long(), minlength=nc) + 1.
|
# cf = torch.bincount(c.long(), minlength=nc) + 1.
|
||||||
|
@ -199,6 +199,9 @@ def train(hyp):
|
||||||
plot_labels(labels)
|
plot_labels(labels)
|
||||||
tb_writer.add_histogram('classes', c, 0)
|
tb_writer.add_histogram('classes', c, 0)
|
||||||
|
|
||||||
|
# Check anchors
|
||||||
|
check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'])
|
||||||
|
|
||||||
# Exponential moving average
|
# Exponential moving average
|
||||||
ema = torch_utils.ModelEMA(model)
|
ema = torch_utils.ModelEMA(model)
|
||||||
|
|
||||||
|
|
|
@ -291,20 +291,22 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
|
self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
|
||||||
for x in self.img_files]
|
for x in self.img_files]
|
||||||
|
|
||||||
|
# Read image shapes (wh)
|
||||||
|
sp = path.replace('.txt', '') + '.shapes' # shapefile path
|
||||||
|
try:
|
||||||
|
with open(sp, 'r') as f: # read existing shapefile
|
||||||
|
s = [x.split() for x in f.read().splitlines()]
|
||||||
|
assert len(s) == n, 'Shapefile out of sync'
|
||||||
|
except:
|
||||||
|
s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
|
||||||
|
np.savetxt(sp, s, fmt='%g') # overwrites existing (if any)
|
||||||
|
|
||||||
|
self.shapes = np.array(s, dtype=np.float64)
|
||||||
|
|
||||||
# Rectangular Training https://github.com/ultralytics/yolov3/issues/232
|
# Rectangular Training https://github.com/ultralytics/yolov3/issues/232
|
||||||
if self.rect:
|
if self.rect:
|
||||||
# Read image shapes (wh)
|
|
||||||
sp = path.replace('.txt', '') + '.shapes' # shapefile path
|
|
||||||
try:
|
|
||||||
with open(sp, 'r') as f: # read existing shapefile
|
|
||||||
s = [x.split() for x in f.read().splitlines()]
|
|
||||||
assert len(s) == n, 'Shapefile out of sync'
|
|
||||||
except:
|
|
||||||
s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
|
|
||||||
np.savetxt(sp, s, fmt='%g') # overwrites existing (if any)
|
|
||||||
|
|
||||||
# Sort by aspect ratio
|
# Sort by aspect ratio
|
||||||
s = np.array(s, dtype=np.float64)
|
s = self.shapes # wh
|
||||||
ar = s[:, 1] / s[:, 0] # aspect ratio
|
ar = s[:, 1] / s[:, 0] # aspect ratio
|
||||||
irect = ar.argsort()
|
irect = ar.argsort()
|
||||||
self.img_files = [self.img_files[i] for i in irect]
|
self.img_files = [self.img_files[i] for i in irect]
|
||||||
|
|
|
@ -51,6 +51,19 @@ def check_img_size(img_size, s=32):
|
||||||
return make_divisible(img_size, s) # nearest gs-multiple
|
return make_divisible(img_size, s) # nearest gs-multiple
|
||||||
|
|
||||||
|
|
||||||
|
def check_best_possible_recall(dataset, anchors, thr):
|
||||||
|
# Check best possible recall of dataset with current anchors
|
||||||
|
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(dataset.shapes, dataset.labels)])) # width-height
|
||||||
|
ratio = wh[:, None] / anchors.view(-1, 2)[None] # ratio
|
||||||
|
m = torch.max(ratio, 1. / ratio).max(2)[0] # max ratio
|
||||||
|
bpr = (m.min(1)[0] < thr).float().mean() # best possible recall
|
||||||
|
mr = (m < thr).float().mean() # match ratio
|
||||||
|
print(('Label width-height:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
|
||||||
|
print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
|
||||||
|
assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
|
||||||
|
'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
|
||||||
|
|
||||||
|
|
||||||
def make_divisible(x, divisor):
|
def make_divisible(x, divisor):
|
||||||
# Returns x evenly divisble by divisor
|
# Returns x evenly divisble by divisor
|
||||||
return math.ceil(x / divisor) * divisor
|
return math.ceil(x / divisor) * divisor
|
||||||
|
|
Loading…
Reference in New Issue