mirror of https://github.com/JDAI-CV/fast-reid.git
47 lines
1.4 KiB
Python
47 lines
1.4 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
from fastreid.modeling.losses import *
|
|
from fastreid.modeling.meta_arch import Baseline
|
|
from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
|
|
|
|
|
|
@META_ARCH_REGISTRY.register()
class PartialBaseline(Baseline):
    """Baseline variant that adds two foreground loss terms.

    On top of whatever ``Baseline.losses`` returns, this class contributes
    ``loss_fore_cls`` and ``loss_fore_triplet`` entries computed from the
    ``fore_cls_outputs`` and ``foreground_features`` model outputs.
    """

    def losses(self, outputs, gt_labels):
        """Return the parent's loss dict extended with foreground losses.

        Args:
            outputs: mapping of model outputs; must contain the keys
                ``"fore_cls_outputs"`` and ``"foreground_features"`` in
                addition to whatever the parent implementation reads.
            gt_labels: ground-truth labels, forwarded unchanged to the parent
                implementation and to each loss function.

        Returns:
            The dict produced by ``super().losses`` with ``loss_fore_cls``
            added when ``'CrossEntropyLoss'`` is among the configured loss
            names, and ``loss_fore_triplet`` added when ``'TripletLoss'`` is.
        """
        all_losses = super().losses(outputs, gt_labels)

        # Both lookups are unconditional: a missing key raises KeyError even
        # when neither loss below is enabled.
        fg_cls_out = outputs["fore_cls_outputs"]
        fg_embedding = outputs["foreground_features"]

        enabled = self.loss_kwargs['loss_names']

        if 'CrossEntropyLoss' in enabled:
            ce_cfg = self.loss_kwargs.get('ce')
            ce_term = cross_entropy_loss(
                fg_cls_out,
                gt_labels,
                ce_cfg.get('eps'),
                ce_cfg.get('alpha'),
            )
            # The scale is read after the loss is computed, then applied.
            all_losses['loss_fore_cls'] = ce_term * ce_cfg.get('scale')

        if 'TripletLoss' in enabled:
            tri_cfg = self.loss_kwargs.get('tri')
            tri_term = triplet_loss(
                fg_embedding,
                gt_labels,
                tri_cfg.get('margin'),
                tri_cfg.get('norm_feat'),
                tri_cfg.get('hard_mining'),
            )
            all_losses['loss_fore_triplet'] = tri_term * tri_cfg.get('scale')

        return all_losses
|