1
0
mirror of https://github.com/open-mmlab/mmocr.git synced 2025-06-03 21:54:47 +08:00
2022-07-21 10:58:04 +08:00

30 lines
751 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmocr.models.textrecog.backbones import ResNet31OCR
class TestResNet31OCR(TestCase):
    """Unit tests for the ``ResNet31OCR`` text-recognition backbone."""

    def test_forward(self):
        """Check constructor argument validation and forward output shape."""
        # Each (args, kwargs) pair must make the constructor call raise
        # ``AssertionError``: a float first argument, a scalar ``layers``
        # and a scalar ``channels``.
        invalid_configs = (
            ((2.5, ), {}),
            ((3, ), {'layers': 5}),
            ((3, ), {'channels': 5}),
        )
        for args, kwargs in invalid_configs:
            with self.assertRaises(AssertionError):
                ResNet31OCR(*args, **kwargs)

        # Forward a random batch through a default-configured ResNet31OCR
        # in training mode and verify the resulting feature-map shape.
        backbone = ResNet31OCR()
        backbone.init_weights()
        backbone.train()
        inputs = torch.randn(1, 3, 32, 160)
        outputs = backbone(inputs)
        expected_shape = torch.Size([1, 512, 4, 40])
        self.assertEqual(outputs.shape, expected_shape)