add mosaic and warmup to hyperparameters (#931)
parent 806e75f2b1
commit f1c63e2784
data/hyp.finetune.yaml
@@ -12,6 +12,9 @@ lr0: 0.0032
 lrf: 0.12
 momentum: 0.843
 weight_decay: 0.00036
+warmup_epochs: 2.0
+warmup_momentum: 0.5
+warmup_bias_lr: 0.05
 giou: 0.0296
 cls: 0.243
 cls_pw: 0.631
@@ -31,4 +34,5 @@ shear: 0.602
 perspective: 0.0
 flipud: 0.00856
 fliplr: 0.5
+mosaic: 1.0
 mixup: 0.243
data/hyp.scratch.yaml
@@ -7,6 +7,9 @@ lr0: 0.01  # initial learning rate (SGD=1E-2, Adam=1E-3)
 lrf: 0.2  # final OneCycleLR learning rate (lr0 * lrf)
 momentum: 0.937  # SGD momentum/Adam beta1
 weight_decay: 0.0005  # optimizer weight decay 5e-4
+warmup_epochs: 3.0  # warmup epochs (fractions ok)
+warmup_momentum: 0.8  # warmup initial momentum
+warmup_bias_lr: 0.1  # warmup initial bias lr
 giou: 0.05  # box loss gain
 cls: 0.5  # cls loss gain
 cls_pw: 1.0  # cls BCELoss positive_weight
@@ -26,4 +29,5 @@ shear: 0.0  # image shear (+/- deg)
 perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
 flipud: 0.0  # image flip up-down (probability)
 fliplr: 0.5  # image flip left-right (probability)
+mosaic: 1.0  # image mosaic (probability)
 mixup: 0.0  # image mixup (probability)
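For context, both files above are plain YAML consumed by train.py. A minimal sketch, assuming PyYAML and the paths shown, of loading a hyp file and backfilling the keys this commit introduces so that older hyp files keep working (the backfill helper and defaults are illustrative, not part of the repo):

```python
import yaml

def load_hyp(path='data/hyp.scratch.yaml'):
    """Load a hyperparameter YAML; backfill keys added by this commit."""
    with open(path) as f:
        hyp = yaml.safe_load(f)  # dict of hyperparameter name -> value
    # Defaults mirror the values added to hyp.scratch.yaml above (assumption)
    defaults = {'warmup_epochs': 3.0, 'warmup_momentum': 0.8,
                'warmup_bias_lr': 0.1, 'mosaic': 1.0}
    for k, v in defaults.items():
        hyp.setdefault(k, v)  # only fills keys missing from older files
    return hyp

hyp = load_hyp()
print(hyp['warmup_epochs'], hyp['mosaic'])  # e.g. 3.0 1.0
```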
train.py
@@ -202,7 +202,7 @@ def train(hyp, opt, device, tb_writer=None):

     # Start training
     t0 = time.time()
-    nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
+    nw = max(round(hyp['warmup_epochs'] * nb), 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
     # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
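This hunk replaces the hard-coded 3-epoch warmup with hyp['warmup_epochs'] (note the trailing comment still says "3 epochs"). A quick worked example of the new expression, with nb (batches per epoch) assumed for illustration:

```python
# Worked example of the new warmup-iteration count; nb is an assumption.
for nb in (120, 5000):
    nw = max(round(2.0 * nb), 1e3)  # hyp['warmup_epochs'] = 2.0 (finetune value)
    print(nb, nw)  # 120 -> 1000.0 (the 1k-iteration floor), 5000 -> 10000
```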
@@ -250,9 +250,9 @@ def train(hyp, opt, device, tb_writer=None):
                 accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
-                    x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
+                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                     if 'momentum' in x:
-                        x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
+                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

             # Multi-scale
             if opt.multi_scale:
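The second hunk makes the warmup ramps configurable: the bias group (index 2) decays from hyp['warmup_bias_lr'] down to the scheduled lr while the other groups rise from 0, and momentum rises from hyp['warmup_momentum'] to hyp['momentum']. A self-contained sketch of those ramps, with nw, the hyp values, and a flat stand-in for the lf schedule all assumed:

```python
import numpy as np

# xi = [0, nw] is the warmup window; ni is the integrated batch counter.
nw = 1000
xi = [0, nw]
hyp = {'warmup_bias_lr': 0.1, 'warmup_momentum': 0.8, 'momentum': 0.937}
lr0, lf = 0.01, lambda epoch: 1.0  # flat schedule stand-in (assumption)

for ni in (0, 500, 1000):
    bias_lr = np.interp(ni, xi, [hyp['warmup_bias_lr'], lr0 * lf(0)])   # falls 0.1 -> lr0
    other_lr = np.interp(ni, xi, [0.0, lr0 * lf(0)])                    # rises 0.0 -> lr0
    mom = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])  # rises 0.8 -> 0.937
    print(f'ni={ni}: bias_lr={bias_lr:.4f} other_lr={other_lr:.4f} momentum={mom:.4f}')
```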
@@ -460,8 +460,11 @@ if __name__ == '__main__':
         # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
         meta = {'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                 'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
-                'momentum': (0.1, 0.6, 0.98),  # SGD momentum/Adam beta1
+                'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                 'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
+                'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
+                'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
+                'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                 'giou': (1, 0.02, 0.2),  # GIoU loss gain
                 'cls': (1, 0.2, 4.0),  # cls loss gain
                 'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
@@ -469,7 +472,7 @@ if __name__ == '__main__':
                 'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                 'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                 'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
-                'anchors': (1, 2.0, 10.0),  # anchors per output grid (0 to ignore)
+                'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                 'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                 'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                 'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
@@ -481,6 +484,7 @@ if __name__ == '__main__':
                 'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                 'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                 'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
+                'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
                 'mixup': (1, 0.0, 1.0)}  # image mixup (probability)

         assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
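Each meta entry is a (mutation scale, lower limit, upper limit) triplet: a scale of 0 freezes the key during evolution, and the limits bound mutated values. A small sketch of applying those limits, with the out-of-range inputs assumed for illustration:

```python
# Sketch: apply the (scale, lower, upper) limits from meta to a hyp dict.
# Mirrors the post-mutation clipping idea; exact call sites in the repo differ.
meta = {'momentum': (0.3, 0.6, 0.98), 'warmup_epochs': (1, 0.0, 5.0)}
hyp = {'momentum': 1.05, 'warmup_epochs': -0.3}  # out-of-range values (assumed)
for k, (scale, lo, hi) in meta.items():
    hyp[k] = max(lo, min(hyp[k], hi))  # clip to [lower, upper]
print(hyp)  # {'momentum': 0.98, 'warmup_epochs': 0.0}
```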
@@ -490,7 +494,7 @@ if __name__ == '__main__':
         if opt.bucket:
             os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists

-        for _ in range(1):  # generations to evolve
+        for _ in range(300):  # generations to evolve
             if os.path.exists('evolve.txt'):  # if evolve.txt exists: select best hyps and mutate
                 # Select parent(s)
                 parent = 'single'  # parent selection method: 'single' or 'weighted'
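Around this hunk, evolve.txt accumulates one result row per generation, and a parent is chosen either as the single best row ('single') or a fitness-weighted blend ('weighted'). A toy sketch of the two strategies; the row layout here (fitness in column 0) is an assumption, not the repo's actual column order:

```python
import numpy as np

# Toy 'single' vs 'weighted' parent selection over evolve.txt-style rows.
x = np.array([[0.50, 0.01, 0.937],
              [0.62, 0.02, 0.900],
              [0.58, 0.03, 0.920]])  # 3 generations: fitness, lr0, momentum (assumed)
n = len(x)
fitness = x[:, 0]
w = fitness - fitness.min() + 1e-6   # weights must be positive

parent = 'weighted'
if parent == 'single':
    chosen = x[fitness.argmax()]                     # best single row
else:
    chosen = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination
print(chosen)
```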
@@ -505,7 +509,7 @@ if __name__ == '__main__':
                     x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                 # Mutate
-                mp, s = 0.9, 0.2  # mutation probability, sigma
+                mp, s = 0.8, 0.2  # mutation probability, sigma
                 npr = np.random
                 npr.seed(int(time.time()))
                 g = np.array([x[0] for x in meta.values()])  # gains 0-1
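The mutate step multiplies each hyperparameter by noise gated by the mutation probability mp and scaled by sigma s and the per-key gain from meta; lowering mp to 0.8 means roughly 80% of keys mutate per generation. A hedged, self-contained sketch of the mechanics, simplified from the surrounding evolve loop (the meta/hyp subset is assumed):

```python
import time
import numpy as np

meta = {'lr0': (1, 1e-5, 1e-1), 'momentum': (0.3, 0.6, 0.98),
        'warmup_epochs': (1, 0.0, 5.0), 'iou_t': (0, 0.1, 0.7)}  # subset (assumed)
hyp = {'lr0': 0.01, 'momentum': 0.937, 'warmup_epochs': 3.0, 'iou_t': 0.2}

mp, s = 0.8, 0.2                 # mutation probability, sigma (this commit's values)
npr = np.random
npr.seed(int(time.time()))
g = np.array([m[0] for m in meta.values()])  # per-key gains; 0 freezes a key
ng = len(meta)
v = np.ones(ng)
while all(v == 1):               # resample until at least one key changes
    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
for i, k in enumerate(hyp):      # apply factors, then clip to the meta limits
    hyp[k] = float(np.clip(hyp[k] * v[i], meta[k][1], meta[k][2]))
print(hyp)  # 'iou_t' never changes because its gain is 0
```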
utils/datasets.py
@@ -516,7 +516,8 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         index = self.indices[index]

         hyp = self.hyp
-        if self.mosaic:
+        mosaic = self.mosaic and random.random() < hyp['mosaic']
+        if mosaic:
             # Load mosaic
             img, labels = load_mosaic(self, index)
             shapes = None
@@ -550,7 +551,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing

         if self.augment:
             # Augment imagespace
-            if not self.mosaic:
+            if not mosaic:
                 img, labels = random_perspective(img, labels,
                                                  degrees=hyp['degrees'],
                                                  translate=hyp['translate'],
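Net effect in the dataloader: mosaic is now drawn per item with probability hyp['mosaic'] instead of being always-on, and the later augment branch keys off the per-item draw (not mosaic) rather than the dataset-level flag (not self.mosaic), so a non-mosaic item still receives random_perspective. A tiny standalone sketch of the gate (the helper function is hypothetical):

```python
import random

def use_mosaic(dataset_mosaic_enabled, hyp):
    """Per-item mosaic decision: dataset switch AND a probability draw."""
    return dataset_mosaic_enabled and random.random() < hyp['mosaic']

hyp = {'mosaic': 1.0}   # probability 1.0 reproduces the old always-on behaviour
print(sum(use_mosaic(True, hyp) for _ in range(1000)))   # 1000
hyp = {'mosaic': 0.25}  # ~25% of items would come from load_mosaic
print(sum(use_mosaic(True, hyp) for _ in range(1000)))   # roughly 250
```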