mirror of https://github.com/JDAI-CV/fast-reid.git
change cross_entroy_loss input name
Summary: change `pred_class_logits` to `pred_class_outputs` to prevent misleading. (#318) close #318pull/365/head
parent
a00e50d37f
commit
f496193f17
|
@ -28,17 +28,17 @@ def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
|
|||
storage.put_scalar("cls_accuracy", ret[0])
|
||||
|
||||
|
||||
def cross_entropy_loss(pred_class_logits, gt_classes, eps, alpha=0.2):
|
||||
num_classes = pred_class_logits.size(1)
|
||||
def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
|
||||
num_classes = pred_class_outputs.size(1)
|
||||
|
||||
if eps >= 0:
|
||||
smooth_param = eps
|
||||
else:
|
||||
# Adaptive label smooth regularization
|
||||
soft_label = F.softmax(pred_class_logits, dim=1)
|
||||
soft_label = F.softmax(pred_class_outputs, dim=1)
|
||||
smooth_param = alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)
|
||||
|
||||
log_probs = F.log_softmax(pred_class_logits, dim=1)
|
||||
log_probs = F.log_softmax(pred_class_outputs, dim=1)
|
||||
with torch.no_grad():
|
||||
targets = torch.ones_like(log_probs)
|
||||
targets *= smooth_param / (num_classes - 1)
|
||||
|
|
Loading…
Reference in New Issue