mmocr/tests/models/textrecog/backbones/test_resnet31_ocr.py

30 lines
751 B
Python
Raw Normal View History

2022-07-14 11:57:35 +00:00
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmocr.models.textrecog.backbones import ResNet31OCR
class TestResNet31OCR(TestCase):
    """Unit tests for the ``ResNet31OCR`` text-recognition backbone."""

    def test_forward(self):
        """Check constructor validation and the forward output shape.

        Malformed constructor arguments are expected to raise
        ``AssertionError``. With the default configuration, a
        ``(1, 3, 32, 160)`` input batch is expected to produce a
        ``(1, 512, 4, 40)`` feature map.
        """
        # Malformed constructor arguments must be rejected.
        # NOTE(review): the first positional argument is presumably
        # ``base_channels`` and expected to be an int; confirm against the
        # backbone's signature.
        with self.assertRaises(AssertionError):
            ResNet31OCR(2.5)
        # A bare int where a per-stage sequence is presumably expected.
        with self.assertRaises(AssertionError):
            ResNet31OCR(3, layers=5)
        with self.assertRaises(AssertionError):
            ResNet31OCR(3, channels=5)

        # Test ResNet31OCR forward with the default configuration.
        model = ResNet31OCR()
        model.init_weights()
        model.train()
        imgs = torch.randn(1, 3, 32, 160)
        feat = model(imgs)
        # Height shrinks 8x (32 -> 4), width shrinks 4x (160 -> 40), and the
        # output has 512 channels.
        self.assertEqual(feat.shape, torch.Size([1, 512, 4, 40]))