# Mirror of https://github.com/open-mmlab/mmocr.git (synced 2025-06-03 21:54:47 +08:00)
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
from unittest import TestCase
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from mmocr.models.common.modules import PositionalEncoding
|
||
|
|
||
|
|
||
|
class TestPositionalEncoding(TestCase):
    """Unit tests for :class:`PositionalEncoding`."""

    def test_forward(self):
        """The forward pass must preserve the input tensor's size.

        A default-constructed encoder is applied to a random tensor of
        size ``(1, 30, 512)``; the output is expected to have exactly the
        same size as the input.
        """
        # NOTE(review): the last dim (512) presumably has to match the
        # encoder's default hidden size, and 30 must not exceed its maximum
        # supported sequence length — confirm against PositionalEncoding's
        # defaults before changing these numbers.
        pos_encoder = PositionalEncoding()
        x = torch.rand(1, 30, 512)
        out = pos_encoder(x)
        # assertEqual (instead of a bare assert) reports both sizes on
        # failure and is not stripped when Python runs with -O.
        self.assertEqual(out.size(), x.size())
|