fast-reid/projects/FastAttr/fastattr/attr_baseline.py
2021-01-18 11:36:38 +08:00

41 lines
1.2 KiB
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from fastreid.modeling.meta_arch.baseline import Baseline
from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
from .bce_loss import cross_entropy_sigmoid_loss
@META_ARCH_REGISTRY.register()
class AttrBaseline(Baseline):
    """Attribute-recognition meta-architecture built on ``Baseline``.

    Only the loss computation is overridden: when requested in the config,
    a sigmoid cross-entropy loss is computed on the classifier outputs,
    optionally weighted by ``sample_weights``.
    """

    def __init__(self, cfg, sample_weights=None):
        """
        Args:
            cfg: config node. ``cfg.MODEL.LOSSES.BCE.WEIGHT_ENABLED`` decides
                whether ``sample_weights`` is used by the BCE loss.
            sample_weights: weights forwarded to ``cross_entropy_sigmoid_loss``
                (presumably one weight per attribute -- TODO confirm against
                the caller). Ignored when BCE weighting is disabled.
        """
        super(AttrBaseline, self).__init__(cfg)

        bce_weight_enabled = cfg.MODEL.LOSSES.BCE.WEIGHT_ENABLED
        if bce_weight_enabled:
            # Registered as a buffer so it moves with the module on
            # ``.to()``/``.cuda()`` and is included in the state dict.
            self.register_buffer("sample_weight", sample_weights)
        else:
            # Must use the same attribute name as the buffer above:
            # ``losses`` reads ``self.sample_weight`` in both cases.
            self.sample_weight = None

    def losses(self, outputs, gt_labels):
        r"""
        Compute losses from the model's outputs.

        The arguments must match what the model's forward pass produces.

        Args:
            outputs (dict): must contain ``'cls_outputs'``, the classifier
                outputs (presumably pre-sigmoid logits given the sigmoid
                loss -- confirm).
            gt_labels: ground-truth attribute labels, passed through to
                ``cross_entropy_sigmoid_loss``.

        Returns:
            dict: ``{"loss_bce": ...}`` scaled by ``cfg.MODEL.LOSSES.BCE.SCALE``
            when ``"BinaryCrossEntropyLoss"`` is in ``cfg.MODEL.LOSSES.NAME``;
            otherwise an empty dict.
        """
        # model predictions
        cls_outputs = outputs['cls_outputs']

        loss_dict = {}
        loss_names = self._cfg.MODEL.LOSSES.NAME

        if "BinaryCrossEntropyLoss" in loss_names:
            loss_dict["loss_bce"] = cross_entropy_sigmoid_loss(
                cls_outputs,
                gt_labels,
                self.sample_weight,  # None when weighting is disabled
            ) * self._cfg.MODEL.LOSSES.BCE.SCALE

        return loss_dict