mirror of https://github.com/open-mmlab/mmocr.git
134 lines
4.1 KiB
Python
134 lines
4.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmocr.models.common.losses import DiceLoss
|
|
from mmocr.models.textrecog.losses import (ABILoss, CELoss, CTCLoss, SARLoss,
|
|
TFLoss)
|
|
|
|
|
|
def test_ctc_loss():
|
|
with pytest.raises(AssertionError):
|
|
CTCLoss(flatten='flatten')
|
|
with pytest.raises(AssertionError):
|
|
CTCLoss(blank=None)
|
|
with pytest.raises(AssertionError):
|
|
CTCLoss(reduction=1)
|
|
with pytest.raises(AssertionError):
|
|
CTCLoss(zero_infinity='zero')
|
|
# test CTCLoss
|
|
ctc_loss = CTCLoss()
|
|
outputs = torch.zeros(2, 40, 37)
|
|
targets_dict = {
|
|
'flatten_targets': torch.IntTensor([1, 2, 3, 4, 5]),
|
|
'target_lengths': torch.LongTensor([2, 3])
|
|
}
|
|
|
|
losses = ctc_loss(outputs, targets_dict)
|
|
assert isinstance(losses, dict)
|
|
assert 'loss_ctc' in losses
|
|
assert torch.allclose(losses['loss_ctc'],
|
|
torch.tensor(losses['loss_ctc'].item()).float())
|
|
|
|
|
|
def test_ce_loss():
|
|
with pytest.raises(AssertionError):
|
|
CELoss(ignore_index='ignore')
|
|
with pytest.raises(AssertionError):
|
|
CELoss(reduction=1)
|
|
with pytest.raises(AssertionError):
|
|
CELoss(reduction='avg')
|
|
|
|
ce_loss = CELoss(ignore_index=0)
|
|
outputs = torch.rand(1, 10, 37)
|
|
targets_dict = {
|
|
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
|
|
}
|
|
losses = ce_loss(outputs, targets_dict)
|
|
assert isinstance(losses, dict)
|
|
assert 'loss_ce' in losses
|
|
assert losses['loss_ce'].size(1) == 10
|
|
|
|
ce_loss = CELoss(ignore_first_char=True)
|
|
outputs = torch.rand(1, 10, 37)
|
|
targets_dict = {
|
|
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
|
|
}
|
|
new_output, new_target = ce_loss.format(outputs, targets_dict)
|
|
assert new_output.shape == torch.Size([1, 37, 9])
|
|
assert new_target.shape == torch.Size([1, 9])
|
|
|
|
|
|
def test_sar_loss():
|
|
outputs = torch.rand(1, 10, 37)
|
|
targets_dict = {
|
|
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
|
|
}
|
|
sar_loss = SARLoss()
|
|
new_output, new_target = sar_loss.format(outputs, targets_dict)
|
|
assert new_output.shape == torch.Size([1, 37, 9])
|
|
assert new_target.shape == torch.Size([1, 9])
|
|
|
|
|
|
def test_tf_loss():
|
|
with pytest.raises(AssertionError):
|
|
TFLoss(flatten=1.0)
|
|
|
|
outputs = torch.rand(1, 10, 37)
|
|
targets_dict = {
|
|
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
|
|
}
|
|
tf_loss = TFLoss(flatten=False)
|
|
new_output, new_target = tf_loss.format(outputs, targets_dict)
|
|
assert new_output.shape == torch.Size([1, 37, 9])
|
|
assert new_target.shape == torch.Size([1, 9])
|
|
|
|
|
|
def test_dice_loss():
|
|
with pytest.raises(AssertionError):
|
|
DiceLoss(eps='1')
|
|
|
|
dice_loss = DiceLoss()
|
|
pred = torch.rand(1, 1, 32, 32)
|
|
gt = torch.rand(1, 1, 32, 32)
|
|
|
|
loss = dice_loss(pred, gt, None)
|
|
assert isinstance(loss, torch.Tensor)
|
|
|
|
mask = torch.rand(1, 1, 1, 1)
|
|
loss = dice_loss(pred, gt, mask)
|
|
assert isinstance(loss, torch.Tensor)
|
|
|
|
|
|
def test_abi_loss():
|
|
loss = ABILoss(num_classes=90)
|
|
outputs = dict(
|
|
out_enc=dict(logits=torch.randn(2, 10, 90)),
|
|
out_decs=[
|
|
dict(logits=torch.randn(2, 10, 90)),
|
|
dict(logits=torch.randn(2, 10, 90))
|
|
],
|
|
out_fusers=[
|
|
dict(logits=torch.randn(2, 10, 90)),
|
|
dict(logits=torch.randn(2, 10, 90))
|
|
])
|
|
targets_dict = {
|
|
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]]),
|
|
'targets':
|
|
[torch.LongTensor([1, 2, 3, 4]),
|
|
torch.LongTensor([1, 2, 3])]
|
|
}
|
|
result = loss(outputs, targets_dict)
|
|
assert isinstance(result, dict)
|
|
assert isinstance(result['loss_visual'], torch.Tensor)
|
|
assert isinstance(result['loss_lang'], torch.Tensor)
|
|
assert isinstance(result['loss_fusion'], torch.Tensor)
|
|
|
|
outputs.pop('out_enc')
|
|
loss(outputs, targets_dict)
|
|
outputs.pop('out_decs')
|
|
loss(outputs, targets_dict)
|
|
outputs.pop('out_fusers')
|
|
with pytest.raises(AssertionError):
|
|
loss(outputs, targets_dict)
|