mirror of
https://github.com/open-mmlab/mmfewshot.git
synced 2025-06-03 14:49:43 +08:00
* fix init * fix test api fix test api bug * add metarcnn fsdetview config * add pr * add metatestparallel comments * add test code and fix typos * add test code of model frozen * update test det forward code * update pr * update doc str
23 lines
699 B
Python
23 lines
699 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
|
|
from mmfewshot.classification.models import build_loss
|
|
|
|
|
|
def test_mse_loss():
    """Build ``MSELoss`` from a config dict and check its mean-reduced value.

    The prediction and target vectors differ by exactly 1 at two of their
    six positions, so the asserted value of 1/3 matches a plain mean of
    squared errors (2 / 6).
    """
    mse = build_loss(dict(type='MSELoss', reduction='mean', loss_weight=1.0))

    # Float literals give default-dtype float tensors.
    prediction = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 0.0])
    target = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0, 0.0])

    expected = torch.tensor(1 / 3)
    assert torch.allclose(mse(prediction, target), expected)
|
|
|
|
|
|
def test_nll_loss():
    """Build ``NLLLoss`` from a config dict and check its mean-reduced value.

    The labels select score 0 from the first row and score 1 from the
    second, so the asserted value of -0.5 matches the negated mean of the
    selected scores.
    """
    nll = build_loss(dict(type='NLLLoss', reduction='mean', loss_weight=1.0))

    # Two samples scored over three classes.
    scores = torch.tensor([[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
    # Integer literals are inferred as int64 (long) class indices.
    targets = torch.tensor([1, 0])

    expected = torch.tensor(-0.5000)
    assert torch.allclose(nll(scores, targets), expected)
|