fast-reid/projects/PartialReID/partialreid/partialbaseline.py

20 lines
675 B
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from fastreid.modeling.losses import reid_losses
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class PartialBaseline(Baseline):
    """``Baseline`` subclass whose loss is computed over two output branches.

    The class is registered in ``META_ARCH_REGISTRY`` and overrides only
    :meth:`losses`; everything else is inherited from ``Baseline``.
    """

    def losses(self, outputs):
        """Collect the loss terms of both output branches in one dictionary.

        Args:
            outputs: An iterable of exactly five items, unpacked in order as
                ``(pred_logits, global_feat, fore_pred_logits, fore_feat,
                targets)``.

        Returns:
            dict: The result of ``reid_losses`` for the
            ``(pred_logits, global_feat)`` pair, updated with the result of
            ``reid_losses`` for the ``(fore_pred_logits, fore_feat)`` pair.
            Both calls receive the same ``self._cfg`` and ``targets``.

        Raises:
            ValueError: If ``outputs`` does not unpack into five items.
        """
        pred_logits, global_feat, fore_pred_logits, fore_feat, targets = outputs
        loss_dict = {}
        # NOTE(review): the trailing string argument is presumably a key
        # prefix that keeps the two branches' entries from overwriting each
        # other in ``loss_dict`` -- confirm against ``reid_losses``.
        loss_dict.update(reid_losses(self._cfg, pred_logits, global_feat, targets, 'avg_branch_'))
        loss_dict.update(reid_losses(self._cfg, fore_pred_logits, fore_feat, targets, 'fore_branch_'))
        return loss_dict