mmfewshot/tests/test_classification_models/test_classification_losses.py
Linyiqi 37cd8a1f7d
Fix docs and add test code (#45)
* fix init

* fix test api

fix test api bug

* add metarcnn fsdetview config

* add pr

* add metatestparallel comments

* add test code and fix typos

* add test code for model freezing

* update test det forward code

* update pr

* update docstrings
2021-11-01 16:33:21 +08:00

23 lines
699 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmfewshot.classification.models import build_loss
def test_mse_loss():
    """Check that ``MSELoss`` with ``reduction='mean'`` averages squared errors.

    Two of the six elements differ by exactly 1 and the rest match, so the
    expected loss is (0 + 1 + 0 + 1 + 0 + 0) / 6 == 1 / 3.
    """
    # Float literals give the default floating dtype, matching what the
    # legacy ``torch.Tensor(...)`` constructor produced.
    cls_score = torch.tensor([1., 1., 1., 1., 1., 0.])
    label = torch.tensor([1., 0., 1., 0., 1., 0.])
    loss_cfg = dict(type='MSELoss', reduction='mean', loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(1 / 3))
def test_nll_loss():
    """Check that ``NLLLoss`` with ``reduction='mean'`` averages picked scores.

    Assuming the loss follows ``torch.nn.functional.nll_loss`` semantics, each
    sample contributes the negated score at its target index:
    sample 0 picks ``cls_score[0, 1] == 0`` and sample 1 picks
    ``cls_score[1, 0] == 1``, so the expected mean is -(0 + 1) / 2 == -0.5.
    """
    # NOTE(review): NLL loss normally consumes log-probabilities; these raw
    # values are chosen only to make the expected result easy to compute.
    cls_score = torch.tensor([[1., 0., 0.], [1., 1., 0.]])
    # Class-index targets must be int64, as the legacy LongTensor was.
    label = torch.tensor([1, 0], dtype=torch.long)
    loss_cfg = dict(type='NLLLoss', reduction='mean', loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(-0.5))