2020-05-21 23:58:35 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
|
|
|
@author: liaoxingyu
|
|
|
|
@contact: sherlockliao01@gmail.com
|
|
|
|
"""
|
|
|
|
|
|
|
|
from fastreid.modeling.losses import reid_losses
|
|
|
|
from fastreid.modeling.meta_arch import Baseline
|
|
|
|
from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
|
|
|
|
|
|
|
|
|
|
|
|
@META_ARCH_REGISTRY.register()
class PartialBaseline(Baseline):
    """Baseline variant whose loss combines two branches.

    The class is registered in ``META_ARCH_REGISTRY`` and only overrides
    :meth:`losses`; everything else is inherited from ``Baseline``.
    """

    def losses(self, outputs):
        """Collect the losses of both branches into a single dict.

        Args:
            outputs: A 5-item iterable, unpacked in this order as the
                first branch's logits and features, the foreground
                branch's logits and features, and the targets.

        Returns:
            dict: The result of ``reid_losses`` for the first branch
            (called with prefix ``'avg_branch_'``) updated with the result
            for the foreground branch (prefix ``'fore_branch_'``).
        """
        avg_logits, avg_feat, fg_logits, fg_feat, labels = outputs

        # One row per branch: (prefix, logits, features). The order matters,
        # since later entries overwrite earlier ones on a key collision.
        # NOTE(review): presumably reid_losses prepends the prefix to its
        # keys so the two branches never collide -- confirm in reid_losses.
        per_branch = (
            ('avg_branch_', avg_logits, avg_feat),
            ('fore_branch_', fg_logits, fg_feat),
        )

        merged = {}
        for prefix, logits, feat in per_branch:
            branch_losses = reid_losses(self._cfg, logits, feat, labels, prefix)
            merged.update(branch_losses)
        return merged
|