change cross_entropy_loss input name

Summary: rename `pred_class_logits` to `pred_class_outputs` to avoid a misleading name. (#318)

close #318
pull/365/head
liaoxingyu 2020-11-06 14:16:31 +08:00
parent a00e50d37f
commit f496193f17
1 changed file with 4 additions and 4 deletions


@@ -28,17 +28,17 @@ def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
     storage.put_scalar("cls_accuracy", ret[0])

-def cross_entropy_loss(pred_class_logits, gt_classes, eps, alpha=0.2):
-    num_classes = pred_class_logits.size(1)
+def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
+    num_classes = pred_class_outputs.size(1)

     if eps >= 0:
         smooth_param = eps
     else:
         # Adaptive label smooth regularization
-        soft_label = F.softmax(pred_class_logits, dim=1)
+        soft_label = F.softmax(pred_class_outputs, dim=1)
         smooth_param = alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)

-    log_probs = F.log_softmax(pred_class_logits, dim=1)
+    log_probs = F.log_softmax(pred_class_outputs, dim=1)

     with torch.no_grad():
         targets = torch.ones_like(log_probs)
         targets *= smooth_param / (num_classes - 1)
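For context, the hunk above cuts off mid-function. Below is a minimal, self-contained sketch of the renamed loss reconstructed from the visible lines; the final scatter of the true-class probability and the mean reduction are assumptions, since the diff does not show the function's tail.

import torch
import torch.nn.functional as F

def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
    """Cross-entropy with (adaptive) label smoothing.

    eps >= 0 applies fixed label smoothing of strength `eps`;
    eps < 0 switches to adaptive smoothing, where the smoothing mass
    scales with the model's own confidence on the ground-truth class.
    """
    num_classes = pred_class_outputs.size(1)

    if eps >= 0:
        smooth_param = eps
    else:
        # Adaptive label smooth regularization: smooth more for samples
        # the model is already confident about.
        soft_label = F.softmax(pred_class_outputs, dim=1)
        smooth_param = alpha * soft_label[
            torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)

    log_probs = F.log_softmax(pred_class_outputs, dim=1)
    with torch.no_grad():
        # Spread `smooth_param` mass uniformly over the non-target classes ...
        targets = torch.ones_like(log_probs)
        targets *= smooth_param / (num_classes - 1)
        # ... and put the remaining (1 - smooth_param) on the true class.
        # NOTE: assumption -- this scatter and the reduction below are not
        # visible in the hunk above.
        targets.scatter_(1, gt_classes.unsqueeze(1), 1 - smooth_param)

    return (-targets * log_probs).sum(dim=1).mean()

A quick usage check of both branches (class count is illustrative only):

logits = torch.randn(8, 751)
labels = torch.randint(0, 751, (8,))
loss_fixed = cross_entropy_loss(logits, labels, eps=0.1)     # fixed smoothing
loss_adaptive = cross_entropy_loss(logits, labels, eps=-1.0) # adaptive branch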