mmocr/tests/test_models/test_ocr_loss.py
2021-04-02 23:54:57 +08:00

124 lines
3.9 KiB
Python

import numpy as np
import pytest
import torch
from mmdet.core import BitmapMasks
from mmocr.models.common.losses import DiceLoss
from mmocr.models.textrecog.losses import (CAFCNLoss, CELoss, CTCLoss, SARLoss,
TFLoss)
def test_ctc_loss():
    """Run CTCLoss on an all-zero prediction and check the returned dict.

    The loss receives a ``(2, 40, 37)`` zero tensor plus a target dict
    holding five flattened labels and the two target lengths ``2`` and
    ``3``.  The result must be a dict containing ``'loss_ctc'``.
    """
    criterion = CTCLoss()
    logits = torch.zeros(2, 40, 37)
    flat_labels = torch.IntTensor([1, 2, 3, 4, 5])
    label_lengths = torch.LongTensor([2, 3])

    result = criterion(
        logits,
        {'flatten_targets': flat_labels, 'target_lengths': label_lengths},
    )

    assert isinstance(result, dict)
    assert 'loss_ctc' in result

    # The reference value is the loss's own scalar, so this mainly verifies
    # that the loss is a one-element tensor (``.item()`` requires that) and
    # that it is not NaN (``allclose`` treats NaN as unequal by default).
    loss_value = result['loss_ctc']
    scalar_copy = torch.tensor(loss_value.item()).float()
    assert torch.allclose(loss_value, scalar_copy)
def test_ce_loss():
    """Check CELoss argument validation and its output for one sample.

    The constructor must raise ``AssertionError`` for a string
    ``ignore_index``, an integer ``reduction`` and the reduction name
    ``'avg'``.  A loss built with ``ignore_index=0`` is then applied to a
    random ``(1, 10, 37)`` prediction and a ``(1, 10)`` padded target, and
    the returned dict is inspected.
    """
    with pytest.raises(AssertionError):
        CELoss(ignore_index='ignore')
    with pytest.raises(AssertionError):
        CELoss(reduction=1)
    with pytest.raises(AssertionError):
        CELoss(reduction='avg')

    ce_loss = CELoss(ignore_index=0)
    outputs = torch.rand(1, 10, 37)
    targets_dict = {
        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
    }

    losses = ce_loss(outputs, targets_dict)
    assert isinstance(losses, dict)
    assert 'loss_ce' in losses
    # Dimension 1 of the loss must have size 10, matching the 10 target
    # columns (presumably one loss value per position -- confirm against
    # CELoss's default reduction).
    assert losses['loss_ce'].size(1) == 10
def test_sar_loss():
    """Check the shapes returned by ``SARLoss.format`` for one sample.

    A random ``(1, 10, 37)`` prediction and a ``(1, 10)`` padded target go
    in; the formatted prediction must be ``(1, 37, 9)`` and the formatted
    target ``(1, 9)``.
    """
    preds = torch.rand(1, 10, 37)
    padded = torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
    criterion = SARLoss()

    formatted_pred, formatted_tgt = criterion.format(
        preds, {'padded_targets': padded})

    expected_pred_shape = torch.Size([1, 37, 9])
    expected_tgt_shape = torch.Size([1, 9])
    assert formatted_pred.shape == expected_pred_shape
    assert formatted_tgt.shape == expected_tgt_shape
def test_tf_loss():
    """Check TFLoss argument validation and the shapes from ``format``.

    ``TFLoss(flatten=1.0)`` must raise ``AssertionError``.  With
    ``flatten=False``, formatting a random ``(1, 10, 37)`` prediction and a
    ``(1, 10)`` padded target must give a ``(1, 37, 9)`` prediction and a
    ``(1, 9)`` target.
    """
    with pytest.raises(AssertionError):
        TFLoss(flatten=1.0)

    preds = torch.rand(1, 10, 37)
    padded = torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
    criterion = TFLoss(flatten=False)

    formatted_pred, formatted_tgt = criterion.format(
        preds, {'padded_targets': padded})

    expected_pred_shape = torch.Size([1, 37, 9])
    expected_tgt_shape = torch.Size([1, 9])
    assert formatted_pred.shape == expected_pred_shape
    assert formatted_tgt.shape == expected_tgt_shape
def test_cafcn_loss():
    """Check CAFCNLoss argument validation and a forward pass.

    Each keyword set in ``rejected_kwargs`` must make the constructor raise
    ``AssertionError``.  A default-constructed loss is then run on constant
    neck outputs, a random head output and one ``BitmapMasks`` target; the
    result must be a dict containing ``'loss_seg'``.
    """
    # String-valued arguments first, then the numeric values that the
    # constructor is also expected to reject.
    rejected_kwargs = (
        {'alpha': '1'},
        {'attn_s2_downsample_ratio': '2'},
        {'attn_s3_downsample_ratio': '1.5'},
        {'seg_downsample_ratio': '1.5'},
        {'attn_s2_downsample_ratio': 2},
        {'attn_s3_downsample_ratio': 1.5},
        {'seg_downsample_ratio': 1.5},
    )
    for kwargs in rejected_kwargs:
        with pytest.raises(AssertionError):
            CAFCNLoss(**kwargs)

    batch = 1
    height = width = 64

    def _half_map(stride):
        # Single-channel map filled with 0.5 at 1/stride of the input size.
        return torch.ones(batch, 1, height // stride, width // stride) * 0.5

    neck_outputs = tuple(_half_map(stride) for stride in (4, 8, 8, 8, 2))
    head_output = torch.rand(batch, 37, height // 2, width // 2)

    # Three full-resolution float32 masks: two all-zero targets and an
    # all-one mask, in that order.
    attn_target = np.zeros((height, width), dtype=np.float32)
    seg_target = np.zeros((height, width), dtype=np.float32)
    valid_mask = np.ones((height, width), dtype=np.float32)
    kernels = BitmapMasks([attn_target, seg_target, valid_mask], height, width)

    criterion = CAFCNLoss()
    result = criterion(neck_outputs, head_output, [kernels])

    assert isinstance(result, dict)
    assert 'loss_seg' in result

    # The reference value is the loss's own scalar, so this mainly verifies
    # that the loss is a one-element tensor and is not NaN.
    seg_loss = result['loss_seg']
    scalar_copy = torch.tensor(seg_loss.item()).float()
    assert torch.allclose(seg_loss, scalar_copy)
def test_dice_loss():
    """Check DiceLoss argument validation and its return type.

    ``DiceLoss(eps='1')`` must raise ``AssertionError``.  A
    default-constructed loss is then called on random ``(1, 1, 32, 32)``
    prediction and target tensors, once with ``None`` as the mask and once
    with a random ``(1, 1, 1, 1)`` mask; both calls must return a
    ``torch.Tensor``.
    """
    with pytest.raises(AssertionError):
        DiceLoss(eps='1')

    criterion = DiceLoss()
    prediction = torch.rand(1, 1, 32, 32)
    target = torch.rand(1, 1, 32, 32)

    # First call: no mask supplied.
    unmasked_loss = criterion(prediction, target, None)
    assert isinstance(unmasked_loss, torch.Tensor)

    # Second call: a single-element mask tensor.
    tiny_mask = torch.rand(1, 1, 1, 1)
    masked_loss = criterion(prediction, target, tiny_mask)
    assert isinstance(masked_loss, torch.Tensor)