45 lines
1.2 KiB
Python
Raw Normal View History

2020-09-23 19:45:13 +08:00
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from fastreid.modeling.meta_arch.baseline import Baseline
from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
from .bce_loss import cross_entropy_sigmoid_loss
@META_ARCH_REGISTRY.register()
class AttrBaseline(Baseline):
    """``Baseline`` subclass whose ``losses`` uses ``cross_entropy_sigmoid_loss``.

    ``from_config`` adds a ``'bce'`` section (with a ``'scale'`` factor) to the
    loss kwargs built by ``Baseline.from_config``. ``losses`` emits a single
    ``"loss_bce"`` term scaled by that factor.
    """

    @classmethod
    def from_config(cls, cfg):
        """Build the constructor kwargs for this model from ``cfg``.

        Starts from ``Baseline.from_config(cfg)``. It then registers a
        ``'bce'`` entry under ``"loss_kwargs"`` whose ``'scale'`` is read from
        ``cfg.MODEL.LOSSES.BCE.SCALE``.

        Returns:
            The kwargs dict produced by ``Baseline.from_config``, updated in
            place with the ``'bce'`` entry.
        """
        # The parent is invoked explicitly on Baseline (not via super()), so
        # the parent classmethod is bound to Baseline, not to this subclass.
        params = Baseline.from_config(cfg)
        loss_kwargs = params["loss_kwargs"]
        loss_kwargs.update({'bce': {'scale': cfg.MODEL.LOSSES.BCE.SCALE}})
        return params

    def losses(self, outputs, gt_labels):
        r"""Compute the training losses from the model's forward outputs.

        The arguments must match what the model's forward pass produces.

        Args:
            outputs: forward-pass result. Its ``"cls_outputs"`` entry is used
                as the predictions. It is indexed unconditionally, even when
                the BCE loss is not enabled.
            gt_labels: ground-truth targets, passed straight through to
                ``cross_entropy_sigmoid_loss``.

        Returns:
            dict: ``{"loss_bce": <scaled loss>}`` when
            ``"BinaryCrossEntropyLoss"`` appears in
            ``self.loss_kwargs["loss_names"]``, otherwise an empty dict.
        """
        logits = outputs["cls_outputs"]
        enabled = self.loss_kwargs["loss_names"]
        loss_dict = {}

        if "BinaryCrossEntropyLoss" not in enabled:
            return loss_dict

        # Normally populated by ``from_config``. If absent, ``.get`` yields
        # None and the ``.get('scale')`` below raises AttributeError.
        bce_cfg = self.loss_kwargs.get('bce')
        # NOTE(review): ``sample_weights`` is not defined in this class; it is
        # presumably attached to the instance elsewhere. Confirm before relying on it.
        raw_bce = cross_entropy_sigmoid_loss(logits, gt_labels, self.sample_weights)
        loss_dict["loss_bce"] = raw_bce * bce_cfg.get('scale')
        return loss_dict