mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
136 lines
5.4 KiB
Python
136 lines
5.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import torch
|
|
|
|
from mmocr.models.textrecog.recognizers import EncoderDecoderRecognizer
|
|
from mmocr.registry import MODELS
|
|
|
|
|
|
class TestEncoderDecoderRecognizer(TestCase):
    """Unit tests for ``EncoderDecoderRecognizer`` wired up with dummy parts.

    Every sub-module (preprocessor, backbone, encoder, decoder) is replaced
    by :class:`DummyModule`, which only performs simple integer arithmetic.
    The expected tensors in the assertions are therefore plain sums/products
    of the input ``1`` and the ``value`` given to each dummy.
    """

    @MODELS.register_module()
    class DummyModule:
        """Arithmetic stand-in that can be built from a config dict.

        It is registered in ``MODELS`` so the tests can refer to it as
        ``dict(type='DummyModule', value=...)``.

        Args:
            value: Constant added to the input when the module is called.
        """

        def __init__(self, value):
            self.value = value

        def __call__(self, x, *args, **kwargs):
            """Return ``x + value``; any further arguments are ignored."""
            return x + self.value

        def predict(self, x, y, *args, **kwargs):
            """Return ``x`` if ``y`` is None, otherwise ``x + y``."""
            if y is None:
                return x
            return x + y

        def loss(self, x, y, *args, **kwargs):
            """Return ``x`` if ``y`` is None, otherwise ``x * y``."""
            if y is None:
                return x
            return x * y

    def test_init(self):
        """Check the decoder requirement and optional component wiring."""
        # Decoder is not allowed to be None
        with self.assertRaises(AssertionError):
            EncoderDecoderRecognizer()

        optional_parts = ['backbone', 'preprocessor', 'encoder']
        for attr in optional_parts:
            recognizer = EncoderDecoderRecognizer(
                **{
                    attr: dict(type='DummyModule', value=1),
                    'decoder': dict(type='DummyModule', value=2)
                })
            # Only the configured optional part may exist as an attribute.
            self.assertTrue(hasattr(recognizer, attr))
            self.assertFalse(
                any(
                    hasattr(recognizer, t) for t in optional_parts
                    if t != attr))

    def test_extract_feat(self):
        """``extract_feat`` adds the preprocessor and backbone values."""
        # Neither preprocessor nor backbone: the input is expected unchanged.
        model = EncoderDecoderRecognizer(
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.extract_feat(torch.tensor([1])), torch.Tensor([1]))

        # Backbone only: 1 + 1.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.extract_feat(torch.tensor([1])), torch.Tensor([2]))

        # Preprocessor only: 1 + 2.
        model = EncoderDecoderRecognizer(
            preprocessor=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.extract_feat(torch.tensor([1])), torch.Tensor([3]))

        # Preprocessor and backbone: 1 + 2 + 1.
        model = EncoderDecoderRecognizer(
            preprocessor=dict(type='DummyModule', value=2),
            backbone=dict(type='DummyModule', value=1),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.extract_feat(torch.tensor([1])), torch.Tensor([4]))

    def test_loss(self):
        """``loss`` goes through ``DummyModule.loss`` of the decoder.

        With an encoder the expectation is ``feat * (feat + encoder value)``;
        without one the decoder receives ``None`` and returns ``feat``.
        """
        # Decoder only: feat = 1, no encoder output -> 1.
        model = EncoderDecoderRecognizer(
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.loss(torch.tensor([1]), None), torch.Tensor([1]))

        # Encoder + decoder: 1 * (1 + 2).
        model = EncoderDecoderRecognizer(
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.loss(torch.tensor([1]), None), torch.Tensor([3]))

        # Backbone + encoder + decoder: feat = 2, so 2 * (2 + 2).
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.loss(torch.tensor([1]), None), torch.Tensor([8]))

        # Backbone + decoder: feat = 2, no encoder output -> 2.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.loss(torch.tensor([1]), None), torch.Tensor([2]))

    def test_predict(self):
        """``predict`` goes through ``DummyModule.predict`` of the decoder.

        With an encoder the expectation is ``feat + (feat + encoder value)``;
        without one the decoder receives ``None`` and returns ``feat``.
        """
        # Decoder only: feat = 1, no encoder output -> 1.
        model = EncoderDecoderRecognizer(
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.predict(torch.tensor([1]), None), torch.Tensor([1]))

        # Encoder + decoder: 1 + (1 + 2).
        model = EncoderDecoderRecognizer(
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.predict(torch.tensor([1]), None), torch.Tensor([4]))

        # Backbone + encoder + decoder: feat = 2, so 2 + (2 + 2).
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.predict(torch.tensor([1]), None), torch.Tensor([6]))

        # Backbone + decoder: feat = 2, no encoder output -> 2.
        # This case must exercise ``predict``; it previously called ``loss``
        # by mistake, leaving this predict path untested.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model.predict(torch.tensor([1]), None), torch.Tensor([2]))

    def test_forward(self):
        """``_forward`` goes through ``DummyModule.__call__`` of the decoder.

        ``__call__`` ignores everything but its first argument, so the
        expectation is always ``feat + decoder value`` regardless of whether
        an encoder is configured.
        """
        # Decoder only: 1 + 1.
        model = EncoderDecoderRecognizer(
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model._forward(torch.tensor([1]), None), torch.Tensor([2]))

        # Encoder + decoder: the encoder output does not affect the sum.
        model = EncoderDecoderRecognizer(
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model._forward(torch.tensor([1]), None), torch.Tensor([2]))

        # Backbone + encoder + decoder: feat = 2, so 2 + 1.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            encoder=dict(type='DummyModule', value=2),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model._forward(torch.tensor([1]), None), torch.Tensor([3]))

        # Backbone + decoder: feat = 2, so 2 + 1.
        model = EncoderDecoderRecognizer(
            backbone=dict(type='DummyModule', value=1),
            decoder=dict(type='DummyModule', value=1))
        self.assertEqual(
            model._forward(torch.tensor([1]), None), torch.Tensor([3]))