1
0
mirror of https://github.com/open-mmlab/mmocr.git synced 2025-06-03 21:54:47 +08:00
mmocr/tests/models/textrecog/encoders/test_channel_reduction_encoder.py

27 lines
820 B
Python
Raw Normal View History

2022-07-08 16:09:06 +00:00
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
2022-07-12 10:46:11 +00:00
from mmocr.data import TextRecogDataSample
2022-07-08 16:09:06 +00:00
from mmocr.models.textrecog.encoders import ChannelReductionEncoder
class TestChannelReductionEncoder(unittest.TestCase):
    """Shape test for ``ChannelReductionEncoder``.

    Feeds a random 512-channel feature map through an encoder constructed
    as ``ChannelReductionEncoder(512, 256)`` and checks that the output has
    256 channels while the batch and spatial sizes (2, 8, 25) are unchanged.
    """

    def setUp(self):
        # Random input feature map of shape (2, 512, 8, 25).
        self.feat = torch.randn(2, 512, 8, 25)

        # One data sample per batch element, each carrying a `valid_ratio`
        # entry in its metainfo.
        samples = []
        for ratio in (0.9, 1.0):
            sample = TextRecogDataSample()
            sample.set_metainfo(dict(valid_ratio=ratio))
            samples.append(sample)
        self.data_info = samples

    def test_encoder(self):
        encoder = ChannelReductionEncoder(512, 256)
        encoder.train()

        # Channels go from 512 to 256; batch and spatial dims are kept.
        expected_shape = torch.Size([2, 256, 8, 25])
        out_enc = encoder(self.feat, self.data_info)
        self.assertEqual(out_enc.shape, expected_shape)