mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Update hyperparameters to add lrf, anchors
This commit is contained in:
parent
9776e70988
commit
08e97a2f88
@ -1,27 +1,34 @@
|
|||||||
# Hyperparameters for VOC fine-tuning
|
# Hyperparameters for VOC finetuning
|
||||||
# python train.py --batch 64 --cfg '' --weights yolov5m.pt --data voc.yaml --img 512 --epochs 50
|
# python train.py --batch 64 --weights yolov5m.pt --data voc.yaml --img 512 --epochs 50
|
||||||
# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials
|
# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials
|
||||||
|
|
||||||
|
|
||||||
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
|
# Hyperparameter Evolution Results
|
||||||
momentum: 0.94 # SGD momentum/Adam beta1
|
# Generations: 51
|
||||||
weight_decay: 0.0005 # optimizer weight decay 5e-4
|
# P R mAP.5 mAP.5:.95 box obj cls
|
||||||
giou: 0.05 # GIoU loss gain
|
# Metrics: 0.625 0.926 0.89 0.677 0.0111 0.00849 0.00124
|
||||||
cls: 0.4 # cls loss gain
|
|
||||||
cls_pw: 1.0 # cls BCELoss positive_weight
|
lr0: 0.00447
|
||||||
obj: 0.5 # obj loss gain (scale with pixels)
|
lrf: 0.114
|
||||||
obj_pw: 1.0 # obj BCELoss positive_weight
|
momentum: 0.873
|
||||||
iou_t: 0.20 # IoU training threshold
|
weight_decay: 0.00047
|
||||||
anchor_t: 4.0 # anchor-multiple threshold
|
giou: 0.0306
|
||||||
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
|
cls: 0.211
|
||||||
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
|
cls_pw: 0.546
|
||||||
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
|
obj: 0.421
|
||||||
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
|
obj_pw: 0.972
|
||||||
degrees: 1.0 # image rotation (+/- deg)
|
iou_t: 0.2
|
||||||
translate: 0.1 # image translation (+/- fraction)
|
anchor_t: 2.26
|
||||||
scale: 0.6 # image scale (+/- gain)
|
# anchors: 5.07
|
||||||
shear: 1.0 # image shear (+/- deg)
|
fl_gamma: 0.0
|
||||||
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
|
hsv_h: 0.0154
|
||||||
flipud: 0.01 # image flip up-down (probability)
|
hsv_s: 0.9
|
||||||
fliplr: 0.5 # image flip left-right (probability)
|
hsv_v: 0.619
|
||||||
mixup: 0.2 # image mixup (probability)
|
degrees: 0.404
|
||||||
|
translate: 0.206
|
||||||
|
scale: 0.86
|
||||||
|
shear: 0.795
|
||||||
|
perspective: 0.0
|
||||||
|
flipud: 0.00756
|
||||||
|
fliplr: 0.5
|
||||||
|
mixup: 0.153
|
||||||
|
@ -4,15 +4,17 @@
|
|||||||
|
|
||||||
|
|
||||||
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
|
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
|
||||||
|
lrf: 0.2 # final OneCycleLR learning rate (lr0 * lrf)
|
||||||
momentum: 0.937 # SGD momentum/Adam beta1
|
momentum: 0.937 # SGD momentum/Adam beta1
|
||||||
weight_decay: 0.0005 # optimizer weight decay 5e-4
|
weight_decay: 0.0005 # optimizer weight decay 5e-4
|
||||||
giou: 0.05 # GIoU loss gain
|
giou: 0.05 # box loss gain
|
||||||
cls: 0.5 # cls loss gain
|
cls: 0.5 # cls loss gain
|
||||||
cls_pw: 1.0 # cls BCELoss positive_weight
|
cls_pw: 1.0 # cls BCELoss positive_weight
|
||||||
obj: 1.0 # obj loss gain (scale with pixels)
|
obj: 1.0 # obj loss gain (scale with pixels)
|
||||||
obj_pw: 1.0 # obj BCELoss positive_weight
|
obj_pw: 1.0 # obj BCELoss positive_weight
|
||||||
iou_t: 0.20 # IoU training threshold
|
iou_t: 0.20 # IoU training threshold
|
||||||
anchor_t: 4.0 # anchor-multiple threshold
|
anchor_t: 4.0 # anchor-multiple threshold
|
||||||
|
# anchors: 0 # anchors per output grid (0 to ignore)
|
||||||
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
|
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
|
||||||
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
|
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
|
||||||
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
|
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
|
||||||
|
14
train.py
14
train.py
@ -53,7 +53,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|||||||
cuda = device.type != 'cpu'
|
cuda = device.type != 'cpu'
|
||||||
init_seeds(2 + rank)
|
init_seeds(2 + rank)
|
||||||
with open(opt.data) as f:
|
with open(opt.data) as f:
|
||||||
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
data_dict = yaml.load(f, Loader=yaml.FullLoader) # data dict
|
||||||
with torch_distributed_zero_first(rank):
|
with torch_distributed_zero_first(rank):
|
||||||
check_dataset(data_dict) # check
|
check_dataset(data_dict) # check
|
||||||
train_path = data_dict['train']
|
train_path = data_dict['train']
|
||||||
@ -67,6 +67,8 @@ def train(hyp, opt, device, tb_writer=None):
|
|||||||
with torch_distributed_zero_first(rank):
|
with torch_distributed_zero_first(rank):
|
||||||
attempt_download(weights) # download if not found locally
|
attempt_download(weights) # download if not found locally
|
||||||
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
||||||
|
# if hyp['anchors']:
|
||||||
|
# ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
|
||||||
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
|
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
|
||||||
exclude = ['anchor'] if opt.cfg else [] # exclude keys
|
exclude = ['anchor'] if opt.cfg else [] # exclude keys
|
||||||
state_dict = ckpt['model'].float().state_dict() # to FP32
|
state_dict = ckpt['model'].float().state_dict() # to FP32
|
||||||
@ -111,7 +113,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|||||||
|
|
||||||
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
|
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
|
||||||
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
|
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
|
||||||
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
|
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf'] # cosine
|
||||||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
||||||
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
||||||
|
|
||||||
@ -459,6 +461,7 @@ if __name__ == '__main__':
|
|||||||
else:
|
else:
|
||||||
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
|
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
|
||||||
meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
|
meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
|
||||||
|
'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
|
||||||
'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
|
'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
|
||||||
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
|
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
|
||||||
'giou': (1, 0.02, 0.2), # GIoU loss gain
|
'giou': (1, 0.02, 0.2), # GIoU loss gain
|
||||||
@ -468,6 +471,7 @@ if __name__ == '__main__':
|
|||||||
'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
|
'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
|
||||||
'iou_t': (0, 0.1, 0.7), # IoU training threshold
|
'iou_t': (0, 0.1, 0.7), # IoU training threshold
|
||||||
'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
|
'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
|
||||||
|
# 'anchors': (1, 2.0, 10.0), # anchors per output grid (0 to ignore)
|
||||||
'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
|
'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
|
||||||
'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
|
'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
|
||||||
'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
|
'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
|
||||||
@ -476,9 +480,9 @@ if __name__ == '__main__':
|
|||||||
'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
|
'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
|
||||||
'scale': (1, 0.0, 0.9), # image scale (+/- gain)
|
'scale': (1, 0.0, 0.9), # image scale (+/- gain)
|
||||||
'shear': (1, 0.0, 10.0), # image shear (+/- deg)
|
'shear': (1, 0.0, 10.0), # image shear (+/- deg)
|
||||||
'perspective': (1, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
|
'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
|
||||||
'flipud': (0, 0.0, 1.0), # image flip up-down (probability)
|
'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
|
||||||
'fliplr': (1, 0.0, 1.0), # image flip left-right (probability)
|
'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
|
||||||
'mixup': (1, 0.0, 1.0)} # image mixup (probability)
|
'mixup': (1, 0.0, 1.0)} # image mixup (probability)
|
||||||
|
|
||||||
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
|
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user