mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
16 lines
368 B
Python
16 lines
368 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import torch
|
|
|
|
from mmocr.models.common.modules import PositionalEncoding
|
|
|
|
|
|
class TestPositionalEncoding(TestCase):
    """Unit tests for the ``PositionalEncoding`` module."""

    def test_forward(self):
        """Running the default encoder must not change the input's shape."""
        encoder = PositionalEncoding()
        # NOTE(review): input is presumably (batch, seq_len, feature_dim);
        # 512 is assumed to match the encoder's default width — confirm.
        feats = torch.rand(1, 30, 512)
        encoded = encoder(feats)
        self.assertEqual(encoded.size(), feats.size())