mmocr/tests/models/common/modules/test_transformer_module.py
2022-07-21 10:58:04 +08:00

16 lines
368 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmocr.models.common.modules import PositionalEncoding
class TestPositionalEncoding(TestCase):
    """Unit tests for ``PositionalEncoding``."""

    def test_forward(self):
        """The forward pass must return a tensor with the input's shape."""
        pos_encoder = PositionalEncoding()
        # NOTE(review): input is presumably (batch, seq_len, d_model) with
        # d_model = 512 matching the module's default width -- confirm
        # against PositionalEncoding's constructor defaults.
        x = torch.rand(1, 30, 512)
        out = pos_encoder(x)
        # Use the TestCase assertion rather than a bare ``assert``: bare
        # asserts are stripped under ``python -O`` and give an uninformative
        # failure message under the plain unittest runner.
        self.assertEqual(out.size(), x.size())