32 lines
1.1 KiB
Python
Raw Normal View History

2020-05-21 23:58:35 +08:00
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
2020-07-10 16:27:22 +08:00
from fastreid.modeling.losses import *
2020-05-21 23:58:35 +08:00
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class PartialBaseline(Baseline):
    """Baseline variant whose training losses supervise two branches.

    Only ``losses`` is overridden here; everything else is inherited from
    ``Baseline``. The two branches are tagged "avg" and "fore" in the loss
    keys (presumably a pooled global branch and a foreground branch --
    TODO confirm against the model head that produces ``outputs``).
    """

    def losses(self, outputs, gt_labels):
        """Compute the per-branch training losses.

        Args:
            outputs: 5-tuple unpacked as ``(cls_outputs, fore_cls_outputs,
                pred_class_logits, global_feat, fore_feat)``.
            gt_labels: ground-truth labels passed unchanged to every loss.

        Returns:
            dict mapping loss names to loss values. It contains
            ``loss_avg_branch_cls`` and ``loss_fore_branch_cls`` when
            "CrossEntropyLoss" is listed in ``cfg.MODEL.LOSSES.NAME``, and
            ``loss_avg_branch_triplet`` and ``loss_fore_branch_triplet`` when
            "TripletLoss" is listed there.
        """
        (avg_logits, fore_logits, pred_class_logits,
         avg_feat, fore_feat) = outputs

        losses = {}
        enabled = self._cfg.MODEL.LOSSES.NAME

        # Accuracy is logged from pred_class_logits unconditionally, even
        # when the cross-entropy loss itself is not enabled.
        CrossEntropyLoss.log_accuracy(pred_class_logits, gt_labels)

        if "CrossEntropyLoss" in enabled:
            # A separate loss instance is built for each branch; keys are
            # inserted avg-then-fore so the dict order is stable.
            for key, logits in (('loss_avg_branch_cls', avg_logits),
                                ('loss_fore_branch_cls', fore_logits)):
                losses[key] = CrossEntropyLoss(self._cfg)(logits, gt_labels)

        if "TripletLoss" in enabled:
            for key, feat in (('loss_avg_branch_triplet', avg_feat),
                              ('loss_fore_branch_triplet', fore_feat)):
                losses[key] = TripletLoss(self._cfg)(feat, gt_labels)

        return losses