mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
ONNX, BCEBlurWithLogitsLoss, plot_study updates
This commit is contained in:
parent
18b4da91e0
commit
3a5c5328c5
@ -1,6 +1,9 @@
|
||||
# Exports a pytorch *.pt model to *.onnx format
|
||||
# Example usage (run from ./yolov5 directory):
|
||||
# $ export PYTHONPATH="$PWD" && python models/onnx_export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
|
||||
"""Exports a pytorch *.pt model to *.onnx format
|
||||
|
||||
Usage:
|
||||
import torch
|
||||
$ export PYTHONPATH="$PWD" && python models/onnx_export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
|
@ -339,6 +339,23 @@ def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#iss
|
||||
return 1.0 - 0.5 * eps, 0.5 * eps
|
||||
|
||||
|
||||
class BCEBlurWithLogitsLoss(nn.Module):
|
||||
# BCEwithLogitLoss() with reduced missing label effects.
|
||||
def __init__(self, alpha=0.05):
|
||||
super(BCEBlurWithLogitsLoss, self).__init__()
|
||||
self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss()
|
||||
self.alpha = alpha
|
||||
|
||||
def forward(self, pred, true):
|
||||
loss = self.loss_fcn(pred, true)
|
||||
pred = torch.sigmoid(pred) # prob from logits
|
||||
dx = pred - true # reduce only missing label effects
|
||||
# dx = (pred - true).abs() # reduce missing label and false label effects
|
||||
alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
|
||||
loss *= alpha_factor
|
||||
return loss.mean()
|
||||
|
||||
|
||||
def compute_loss(p, targets, model): # predictions, targets, model
|
||||
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
||||
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
||||
@ -1009,7 +1026,7 @@ def plot_study_txt(f='study.txt', x=None): # from utils.utils import *; plot_st
|
||||
ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [33.5, 39.1, 42.5, 45.9, 49., 50.5],
|
||||
'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
|
||||
ax2.set_xlim(0, 30)
|
||||
ax2.set_ylim(23, 50)
|
||||
ax2.set_ylim(25, 50)
|
||||
ax2.set_xlabel('GPU Latency (ms)')
|
||||
ax2.set_ylabel('COCO AP val')
|
||||
ax2.legend(loc='lower right')
|
||||
|
Loading…
x
Reference in New Issue
Block a user