23 lines
699 B
Python
23 lines
699 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
|
|
from mmfewshot.classification.models import build_loss
|
|
|
|
|
|
def test_mse_loss():
    """Check the registry-built ``MSELoss`` with ``reduction='mean'``.

    Predictions and targets disagree (by exactly 1) at two of the six
    positions, so the mean squared error is ``2 / 6 == 1 / 3``. The expected
    value assumes ``loss_weight=1.0`` leaves the loss unscaled.
    """
    # Float data keeps both tensors in the default floating dtype, matching
    # what the legacy ``torch.Tensor`` constructor produced.
    cls_score = torch.tensor([1., 1., 1., 1., 1., 0.])
    label = torch.tensor([1., 0., 1., 0., 1., 0.])

    loss_cfg = dict(type='MSELoss', reduction='mean', loss_weight=1.0)
    loss = build_loss(loss_cfg)

    assert torch.allclose(loss(cls_score, label), torch.tensor(1 / 3))
|
|
|
|
|
|
def test_nll_loss():
    """Check the registry-built ``NLLLoss`` with ``reduction='mean'``.

    The scores are passed in as-is, so the loss picks
    ``-cls_score[i, label[i]]`` per sample and averages them:
    ``-(cls_score[0, 1] + cls_score[1, 0]) / 2 == -(0 + 1) / 2 == -0.5``.
    The expected value assumes ``loss_weight=1.0`` leaves the loss unscaled.
    """
    cls_score = torch.tensor([[1., 0., 0.], [1., 1., 0.]])
    # Class indices must be integer (long) for NLL loss.
    label = torch.tensor([1, 0], dtype=torch.long)

    loss_cfg = dict(type='NLLLoss', reduction='mean', loss_weight=1.0)
    loss = build_loss(loss_cfg)

    assert torch.allclose(loss(cls_score, label), torch.tensor(-0.5))
|