mmclassification/tests/test_losses.py
LXXXXR 9578bfa0f1
[Feature] Add focal loss for multilabel task (#131)
* add focal loss

* apply class wise sum

* fix docstring

* do not apply sum over classes and fix docstring

* fix docstring

* fix weight shape

* fix weight shape
2021-01-08 20:44:23 +08:00

23 lines
615 B
Python

import torch
from mmcls.models import build_loss
def test_focal_loss():
    """Check the ``FocalLoss`` built from config against expected values.

    ``cls_score`` holds raw logits and ``label`` is a multi-hot target of the
    same (2 samples x 3 classes) shape. The first sample is mostly classified
    correctly; the second sample's confident logits (5 for a 0 label, -5 for a
    1 label) contradict its labels, so it should dominate the loss.
    """
    # Use torch.tensor (not the legacy torch.Tensor constructor) with explicit
    # float literals so dtypes match the weight tensors below.
    cls_score = torch.tensor([[5., -5., 0.], [5., -5., 0.]])
    label = torch.tensor([[1., 0., 1.], [0., 1., 0.]])
    loss_cfg = dict(
        type='FocalLoss',
        gamma=2.0,
        alpha=0.25,
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)

    # Expected value for gamma=2.0, alpha=0.25, reduction='mean', given to 4
    # decimals (within torch.allclose's default tolerance).
    expected = torch.tensor(0.8522)
    assert torch.allclose(loss(cls_score, label), expected)

    # A uniform per-sample weight of 0.5 should halve the mean loss.
    weight = torch.tensor([0.5, 0.5])
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(0.8522 / 2))

    # A uniform weight cannot distinguish per-sample weighting from a global
    # scale, so also check non-uniform weights. Assumption (confirm against the
    # loss implementation): the weight multiplies each sample's elementwise
    # loss before the 'mean' reduction, with no avg_factor. Under that
    # assumption, selecting each sample on its own is additive.
    only_first = loss(cls_score, label, weight=torch.tensor([1., 0.]))
    only_second = loss(cls_score, label, weight=torch.tensor([0., 1.]))
    assert torch.allclose(only_first + only_second, loss(cls_score, label))
    # The misclassified second sample must contribute more than the first.
    assert only_second > only_first