Added param loss_ota to hyp.yaml, to allow disabling OTA for faster training

pull/461/head
AlexeyAB84 2022-08-09 07:48:28 +03:00
parent 469a4d0e8d
commit 711a16ba57
5 changed files with 12 additions and 5 deletions

@@ -27,4 +27,5 @@ fliplr: 0.5 # image flip left-right (probability)
 mosaic: 1.0 # image mosaic (probability)
 mixup: 0.0 # image mixup (probability)
 copy_paste: 0.0 # image copy paste (probability)
-paste_in: 0.0 # image copy paste (probability)
+paste_in: 0.0 # image copy paste (probability), use 0 for faster training
+loss_ota: 1 # use ComputeLossOTA, use 0 for faster training
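For reference, train.py reads these hyp files with PyYAML into a plain dict, so the new key becomes available as hyp['loss_ota']. A minimal sketch of loading a hyp file and reading the flag (the file path below is an illustrative assumption; the diff does not name the changed files):

import yaml

# Load a hyperparameter file into a dict, as train.py does with
# yaml.load(f, Loader=yaml.SafeLoader); safe_load is equivalent.
with open('data/hyp.scratch.custom.yaml') as f:  # assumed example path
    hyp = yaml.safe_load(f)

# loss_ota: 1 selects ComputeLossOTA (the previous default behavior);
# 0 falls back to the plain loss for faster training.
use_ota = hyp['loss_ota'] == 1
print(f'OTA loss enabled: {use_ota}')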

@@ -27,4 +27,5 @@ fliplr: 0.5 # image flip left-right (probability)
 mosaic: 1.0 # image mosaic (probability)
 mixup: 0.15 # image mixup (probability)
 copy_paste: 0.0 # image copy paste (probability)
-paste_in: 0.15 # image copy paste (probability)
+paste_in: 0.15 # image copy paste (probability), use 0 for faster training
+loss_ota: 1 # use ComputeLossOTA, use 0 for faster training

@@ -27,4 +27,5 @@ fliplr: 0.5 # image flip left-right (probability)
 mosaic: 1.0 # image mosaic (probability)
 mixup: 0.15 # image mixup (probability)
 copy_paste: 0.0 # image copy paste (probability)
-paste_in: 0.15 # image copy paste (probability)
+paste_in: 0.15 # image copy paste (probability), use 0 for faster training
+loss_ota: 1 # use ComputeLossOTA, use 0 for faster training

@@ -27,4 +27,5 @@ fliplr: 0.5 # image flip left-right (probability)
 mosaic: 1.0 # image mosaic (probability)
 mixup: 0.05 # image mixup (probability)
 copy_paste: 0.0 # image copy paste (probability)
-paste_in: 0.05 # image copy paste (probability)
+paste_in: 0.05 # image copy paste (probability), use 0 for faster training
+loss_ota: 1 # use ComputeLossOTA, use 0 for faster training

@@ -359,7 +359,10 @@ def train(hyp, opt, device, tb_writer=None):
             # Forward
             with amp.autocast(enabled=cuda):
                 pred = model(imgs)  # forward
-                loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)  # loss scaled by batch_size
+                if hyp['loss_ota'] == 1:
+                    loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)  # loss scaled by batch_size
+                else:
+                    loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                 if rank != -1:
                     loss *= opt.world_size  # gradient averaged between devices in DDP mode
                 if opt.quad:
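Note that hyp['loss_ota'] == 1 assumes the key is present; a hyp.yaml written before this commit would raise a KeyError at this line. A defensive variant (a hypothetical alternative, not what this commit does) would default to OTA when the key is missing:

# Hypothetical backward-compatible lookup: default to OTA (1) when an
# older hyp.yaml lacks the loss_ota key, preserving prior behavior.
if hyp.get('loss_ota', 1) == 1:
    loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)
else:
    loss, loss_items = compute_loss(pred, targets.to(device))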