2021-08-17 17:39:30 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2021-04-06 05:20:43 -05:00
|
|
|
import pytest
|
2021-04-06 15:16:27 +08:00
|
|
|
import torch
|
|
|
|
|
2022-05-30 16:56:28 +08:00
|
|
|
from mmocr.models.textdet.necks import FPN_UNet
|
2021-04-06 05:20:43 -05:00
|
|
|
|
|
|
|
|
|
|
|
def test_fpn_unet_neck():
    """Check FPN_UNet constructor validation and its forward output shape.

    Four random feature maps of spatial sizes s, s/2, s/4 and s/8 (with
    channel counts ``in_channels``) are passed through the neck. The output
    must have ``out_channels`` channels at 4x the spatial size of the
    largest input map.
    """
    s = 64
    feat_sizes = [s // 2**i for i in range(4)]  # [64, 32, 16, 8]
    in_channels = [8, 16, 32, 64]
    out_channels = 4

    # len(in_channels) is not equal to 4
    with pytest.raises(AssertionError):
        FPN_UNet(in_channels + [128], out_channels)

    # `out_channels` is not int type
    with pytest.raises(AssertionError):
        FPN_UNet(in_channels, [2, 4])

    # One random (1, C, H, W) map per level; channels and sizes are paired
    # level by level.
    feats = [
        torch.rand(1, channels, size, size)
        for channels, size in zip(in_channels, feat_sizes)
    ]

    fpn_unet_neck = FPN_UNet(in_channels, out_channels)
    fpn_unet_neck.init_weights()

    out_neck = fpn_unet_neck(feats)
    assert out_neck.shape == torch.Size([1, out_channels, s * 4, s * 4])
|