AutoAnchor bug fix #72
parent
8fa3724072
commit
8b26e89006
3
train.py
3
train.py
|
@ -4,7 +4,6 @@ import torch.distributed as dist
|
|||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torch.optim.lr_scheduler as lr_scheduler
|
||||
import yaml
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
import test # import test.py to get mAP after each epoch
|
||||
|
@ -200,7 +199,7 @@ def train(hyp):
|
|||
tb_writer.add_histogram('classes', c, 0)
|
||||
|
||||
# Check anchors
|
||||
check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'])
|
||||
check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'], imgsz=imgsz)
|
||||
|
||||
# Exponential moving average
|
||||
ema = torch_utils.ModelEMA(model)
|
||||
|
|
|
@ -52,15 +52,17 @@ def check_img_size(img_size, s=32):
|
|||
return make_divisible(img_size, s) # nearest gs-multiple
|
||||
|
||||
|
||||
def check_best_possible_recall(dataset, anchors, thr):
|
||||
def check_best_possible_recall(dataset, anchors, thr=4.0, imgsz=640):
|
||||
# Check best possible recall of dataset with current anchors
|
||||
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(dataset.shapes, dataset.labels)])).float() # wh
|
||||
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
|
||||
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
|
||||
ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None] # ratio
|
||||
m = torch.max(ratio, 1. / ratio).max(2)[0] # max ratio
|
||||
bpr = (m.min(1)[0] < thr).float().mean() # best possible recall
|
||||
mr = (m < thr).float().mean() # match ratio
|
||||
print(('Label width-height:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
|
||||
print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
|
||||
print(('AutoAnchor labels:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
|
||||
print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
|
||||
|
||||
assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
|
||||
'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
|
||||
|
||||
|
|
Loading…
Reference in New Issue