2021-04-02 23:54:57 +08:00
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from mmocr.models.textrecog import SegHead
|
|
|
|
|
|
|
|
|
2021-04-06 15:17:20 +08:00
|
|
|
def test_seg_head():
    """Check SegHead argument validation and the shape of its output.

    Constructing the head with ``num_classes`` given as a string or as a
    negative integer must raise ``AssertionError``. A head built with
    ``num_classes=37`` and fed a one-element tuple holding a tensor of
    shape (1, 128, 32, 32) must return a tensor of shape (1, 37, 32, 32).
    """
    # Each of these invalid class counts must be rejected at construction.
    for bad_num_classes in ('100', -1):
        with pytest.raises(AssertionError):
            SegHead(num_classes=bad_num_classes)

    head = SegHead(num_classes=37)

    # The head is called with a tuple of feature tensors (here a single one).
    feats = (torch.rand(1, 128, 32, 32), )
    logits = head(feats)

    assert logits.shape == torch.Size([1, 37, 32, 32])
|