fast-reid/projects/PartialReID/partialreid/partialbaseline.py

47 lines
1.4 KiB
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from fastreid.modeling.losses import *
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class PartialBaseline(Baseline):
    """Baseline variant that adds losses for the foreground branch.

    The parent class computes its usual loss dict; this subclass appends
    up to two extra entries, ``loss_fore_cls`` and ``loss_fore_triplet``,
    depending on which loss names are enabled in ``self.loss_kwargs``.
    """

    def losses(self, outputs, gt_labels):
        """Return the parent's loss dict extended with foreground losses.

        Args:
            outputs: mapping produced by the model forward pass. The keys
                ``"fore_cls_outputs"`` and ``"foreground_features"`` are
                read unconditionally, whichever losses are enabled.
            gt_labels: ground-truth labels forwarded to every loss function.

        Returns:
            The dict returned by ``super().losses`` with ``loss_fore_cls``
            added when ``'CrossEntropyLoss'`` is among the configured loss
            names and ``loss_fore_triplet`` added when ``'TripletLoss'`` is.
        """
        # Start from whatever the parent architecture already computes.
        all_losses = super().losses(outputs, gt_labels)

        # Foreground-branch tensors; fetched up front, before the loss
        # configuration is consulted.
        fg_logits = outputs["fore_cls_outputs"]
        fg_embedding = outputs["foreground_features"]

        enabled = self.loss_kwargs['loss_names']

        if 'CrossEntropyLoss' in enabled:
            ce_cfg = self.loss_kwargs.get('ce')
            ce_value = cross_entropy_loss(
                fg_logits, gt_labels, ce_cfg.get('eps'), ce_cfg.get('alpha')
            )
            # The configured scale weights this term in the total loss.
            all_losses['loss_fore_cls'] = ce_value * ce_cfg.get('scale')

        if 'TripletLoss' in enabled:
            tri_cfg = self.loss_kwargs.get('tri')
            tri_value = triplet_loss(
                fg_embedding,
                gt_labels,
                tri_cfg.get('margin'),
                tri_cfg.get('norm_feat'),
                tri_cfg.get('hard_mining'),
            )
            all_losses['loss_fore_triplet'] = tri_value * tri_cfg.get('scale')

        return all_losses