43 lines
1.2 KiB
Python
43 lines
1.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
|
|
from mmfewshot.detection.models import build_loss
|
|
|
|
|
|
def test_contrastive_loss():
|
|
feat = torch.Tensor([[1, 0, 1, 1], [1, 1, 1, 1]] * 2)
|
|
label = torch.LongTensor([0, 1] * 2)
|
|
iou = torch.Tensor([0.9] * 4)
|
|
|
|
loss_cfg = dict(
|
|
type='SupervisedContrastiveLoss',
|
|
temperature=0.2,
|
|
iou_threshold=0.8,
|
|
loss_weight=0.5,
|
|
reweight_type='none')
|
|
loss = build_loss(loss_cfg)
|
|
assert torch.allclose(loss(feat, label, iou), torch.tensor(0.1390))
|
|
loss = build_loss(loss_cfg)
|
|
assert torch.allclose(
|
|
loss(feat, label, iou, decay_rate=0.5), torch.tensor(0.1390 * 0.5))
|
|
iou = torch.Tensor([0.7] * 4)
|
|
assert torch.allclose(loss(feat, label, iou), torch.tensor(0.))
|
|
|
|
loss_cfg = dict(
|
|
type='SupervisedContrastiveLoss',
|
|
temperature=0.2,
|
|
iou_threshold=0.5,
|
|
loss_weight=0.5,
|
|
reweight_type='linear')
|
|
loss = build_loss(loss_cfg)
|
|
assert loss(feat, label, iou)
|
|
|
|
loss_cfg = dict(
|
|
type='SupervisedContrastiveLoss',
|
|
temperature=0.2,
|
|
iou_threshold=0.5,
|
|
loss_weight=0.5,
|
|
reweight_type='exp')
|
|
loss = build_loss(loss_cfg)
|
|
assert loss(feat, label, iou)
|