mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
136 lines
5.4 KiB
Python
136 lines
5.4 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
from unittest import TestCase
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from mmocr.models.textrecog.recognizers import EncoderDecoderRecognizer
|
||
|
from mmocr.registry import MODELS
|
||
|
|
||
|
|
||
|
class TestEncoderDecoderRecognizer(TestCase):
    """Unit tests for ``EncoderDecoderRecognizer`` built from toy modules.

    Every sub-module in these tests is a ``DummyModule`` whose methods are
    plain integer arithmetic, so each expected value can be worked out by
    hand from the ``value`` fields in the configs.
    """

    @MODELS.register_module()
    class DummyModule:
        """Toy stand-in for preprocessor / backbone / encoder / decoder.

        Registered in ``MODELS`` when this class body is executed, so that
        configs like ``dict(type='DummyModule', value=1)`` can be built by
        the recognizer under test.
        """

        def __init__(self, value):
            self.value = value

        def __call__(self, x, *args, **kwargs):
            # Forward pass: shift the first input by ``value``; any extra
            # positional/keyword arguments are ignored.
            return x + self.value

        def predict(self, x, y, *args, **kwargs):
            # Sum of the two inputs; ``y`` may be None (presumably when the
            # recognizer has no encoder output to pass along).
            if y is None:
                return x
            return x + y

        def loss(self, x, y, *args, **kwargs):
            # Product (not sum) so loss and predict yield distinguishable
            # results for the same configuration; ``y`` may be None.
            if y is None:
                return x
            return x * y

    def test_init(self):
        """Decoder is mandatory; optional modules exist only if configured."""
        # Decoder is not allowed to be None
        with self.assertRaises(AssertionError):
            EncoderDecoderRecognizer()

        optional_attrs = ['backbone', 'preprocessor', 'encoder']
        for attr in optional_attrs:
            recognizer = EncoderDecoderRecognizer(
                **{
                    attr: dict(type='DummyModule', value=1),
                    'decoder': dict(type='DummyModule', value=2)
                })
            # Only the optional module that was configured is attached.
            self.assertTrue(hasattr(recognizer, attr))
            self.assertFalse(
                any(
                    hasattr(recognizer, t) for t in optional_attrs
                    if t != attr))

    def test_extract_feat(self):
        """Feature extraction applies preprocessor and backbone, if present."""
        # Neither preprocessor nor backbone: the input passes through.
        model = EncoderDecoderRecognizer(
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.extract_feat(torch.tensor([1])), torch.Tensor([1]))
        # Backbone only: 1 + 1.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.extract_feat(torch.tensor([1])), torch.Tensor([2]))
        # Preprocessor only: 1 + 2.
        model = EncoderDecoderRecognizer(
            preprocessor=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.extract_feat(torch.tensor([1])), torch.Tensor([3]))
        # Preprocessor and backbone: 1 + 2 + 1.
        model = EncoderDecoderRecognizer(
            preprocessor=dict(type='DummyModule', value=2),
            backbone=dict(type='DummyModule', value=1),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.extract_feat(torch.tensor([1])), torch.Tensor([4]))

    def test_loss(self):
        """Check ``loss``; expected values follow ``DummyModule.loss`` (x*y).

        The arithmetic in the comments assumes the decoder receives the
        extracted feature and the encoder output (or None) as ``x``/``y``.
        """
        # Decoder only: no encoder output -> feature 1.
        model = EncoderDecoderRecognizer(
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.loss(torch.tensor([1]), None), torch.Tensor([1]))
        # Encoder (value=2): feature 1, encoder output 3 -> 1 * 3.
        model = EncoderDecoderRecognizer(
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.loss(torch.tensor([1]), None), torch.Tensor([3]))
        # Backbone + encoder: feature 2, encoder output 4 -> 2 * 4.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.loss(torch.tensor([1]), None), torch.Tensor([8]))
        # Backbone only: feature 2, no encoder output -> 2.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.loss(torch.tensor([1]), None), torch.Tensor([2]))

    def test_predict(self):
        """Check ``predict``; expected values follow ``DummyModule.predict``.

        The arithmetic in the comments assumes the decoder receives the
        extracted feature and the encoder output (or None) as ``x``/``y``.
        """
        # Decoder only: no encoder output -> feature 1.
        model = EncoderDecoderRecognizer(
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.predict(torch.tensor([1]), None), torch.Tensor([1]))
        # Encoder (value=2): feature 1, encoder output 3 -> 1 + 3.
        model = EncoderDecoderRecognizer(
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.predict(torch.tensor([1]), None), torch.Tensor([4]))
        # Backbone + encoder: feature 2, encoder output 4 -> 2 + 4.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.predict(torch.tensor([1]), None), torch.Tensor([6]))
        # Backbone only: feature 2, no encoder output -> 2. This case
        # previously called ``model.loss`` by mistake, so predict was never
        # exercised for this configuration.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.predict(torch.tensor([1]), None), torch.Tensor([2]))

    def test_forward(self):
        """Check ``_forward``; the decoder is invoked via ``__call__``.

        ``DummyModule.__call__`` only uses its first argument (presumably
        the extracted feature), so adding an encoder does not change the
        result; only the backbone does.
        """
        # Decoder only: feature 1 -> 1 + 1.
        model = EncoderDecoderRecognizer(
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model._forward(torch.tensor([1]), None), torch.Tensor([2]))
        # Encoder output is ignored by the dummy decoder: still 1 + 1.
        model = EncoderDecoderRecognizer(
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model._forward(torch.tensor([1]), None), torch.Tensor([2]))
        # Backbone + encoder: feature 2 -> 2 + 1.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model._forward(torch.tensor([1]), None), torch.Tensor([3]))
        # Backbone only: feature 2 -> 2 + 1.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model._forward(torch.tensor([1]), None), torch.Tensor([3]))