2022-02-24 09:24:25 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2020-07-08 17:29:15 +08:00
|
|
|
"""
|
|
|
|
CommandLine:
|
|
|
|
pytest tests/test_merge_cells.py
|
|
|
|
"""
|
2022-05-20 15:13:13 +08:00
|
|
|
import math
|
|
|
|
|
|
|
|
import pytest
|
2020-07-08 17:29:15 +08:00
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
from mmcv.ops.merge_cells import (BaseMergeCell, ConcatCell, GlobalPoolingCell,
|
|
|
|
SumCell)
|
|
|
|
|
|
|
|
|
2022-05-20 15:13:13 +08:00
|
|
|
# The (14, 7) sizes below exercise the case where the input size is not
# evenly divisible by the target size.
|
|
|
|
@pytest.mark.parametrize(
    'inputs_x, inputs_y',
    [
        (torch.randn([2, 256, 16, 16]), torch.randn([2, 256, 32, 32])),
        (torch.randn([2, 256, 14, 7]), torch.randn([2, 256, 32, 32])),
    ])
def test_sum_cell(inputs_x, inputs_y):
    """Check the output size of ``SumCell`` for several ``out_size`` values.

    One cell instance is called three times: with ``out_size`` taken from
    ``inputs_x``, with ``out_size`` taken from ``inputs_y``, and without
    ``out_size``. In the last case the output is expected to have the size
    of ``inputs_y``.
    """
    cell = SumCell(256, 256)

    # (extra keyword arguments for the call, expected output size)
    scenarios = (
        ({'out_size': inputs_x.shape[-2:]}, inputs_x.size()),
        ({'out_size': inputs_y.shape[-2:]}, inputs_y.size()),
        ({}, inputs_y.size()),
    )
    for call_kwargs, expected_size in scenarios:
        merged = cell(inputs_x, inputs_y, **call_kwargs)
        assert merged.size() == expected_size
|
2020-07-08 17:29:15 +08:00
|
|
|
|
|
|
|
|
2022-05-20 15:13:13 +08:00
|
|
|
@pytest.mark.parametrize(
    'inputs_x, inputs_y',
    [
        (torch.randn([2, 256, 16, 16]), torch.randn([2, 256, 32, 32])),
        (torch.randn([2, 256, 14, 7]), torch.randn([2, 256, 32, 32])),
    ])
def test_concat_cell(inputs_x, inputs_y):
    """Check the output size of ``ConcatCell`` for several ``out_size`` values.

    One cell instance is called three times: with ``out_size`` taken from
    ``inputs_x``, with ``out_size`` taken from ``inputs_y``, and without
    ``out_size``. In the last case the output is expected to have the size
    of ``inputs_y``.
    """
    cell = ConcatCell(256, 256)

    # (extra keyword arguments for the call, expected output size)
    scenarios = (
        ({'out_size': inputs_x.shape[-2:]}, inputs_x.size()),
        ({'out_size': inputs_y.shape[-2:]}, inputs_y.size()),
        ({}, inputs_y.size()),
    )
    for call_kwargs, expected_size in scenarios:
        merged = cell(inputs_x, inputs_y, **call_kwargs)
        assert merged.size() == expected_size
|
2020-07-08 17:29:15 +08:00
|
|
|
|
|
|
|
|
2022-05-20 15:13:13 +08:00
|
|
|
@pytest.mark.parametrize(
    'inputs_x, inputs_y',
    [
        (torch.randn([2, 256, 16, 16]), torch.randn([2, 256, 32, 32])),
        (torch.randn([2, 256, 14, 7]), torch.randn([2, 256, 32, 32])),
    ])
def test_global_pool_cell(inputs_x, inputs_y):
    """Check that ``GlobalPoolingCell`` returns a tensor sized like x.

    Two configurations are exercised, both with ``out_size`` taken from
    ``inputs_x``: one built with ``with_out_conv=False`` and one built
    with the positional arguments ``(256, 256)``.
    """
    spatial_size = inputs_x.shape[-2:]

    # Factories keep construction lazy, so each cell is built right before
    # it is called, in the same order as two sequential stanzas would.
    cell_factories = (
        lambda: GlobalPoolingCell(with_out_conv=False),
        lambda: GlobalPoolingCell(256, 256),
    )
    for make_cell in cell_factories:
        cell = make_cell()
        pooled = cell(inputs_x, inputs_y, out_size=spatial_size)
        assert pooled.size() == inputs_x.size()
|
|
|
|
|
|
|
|
|
2022-05-20 15:13:13 +08:00
|
|
|
@pytest.mark.parametrize('target_size', [(256, 256), (128, 128), (64, 64),
                                         (14, 7)])
def test_resize_methods(target_size):
    """Compare ``BaseMergeCell._resize`` with a reference built from ``F``.

    A random ``(2, 256, 128, 128)`` tensor is resized to ``target_size``.

    * If the input height or width is not larger than the target, the
      result must equal ``F.interpolate`` for each tested upsample mode.
    * Otherwise the result must equal a max-pooling of the input, after
      zero-padding it so that both spatial dimensions are multiples of
      the target size.

    In both cases the spatial size of the result must be ``target_size``.

    Args:
        target_size (tuple[int, int]): Requested output ``(height, width)``.
    """
    inputs_x = torch.randn([2, 256, 128, 128])
    h, w = inputs_x.shape[-2:]
    target_h, target_w = target_size

    if h <= target_h or w <= target_w:
        # Upsample path: the reference is a plain interpolation.
        for method in ('nearest', 'bilinear'):
            merge_cell = BaseMergeCell(upsample_mode=method)
            merge_cell_out = merge_cell._resize(inputs_x, target_size)
            gt_out = F.interpolate(inputs_x, size=target_size, mode=method)
            assert merge_cell_out.shape[-2:] == target_size
            assert merge_cell_out.equal(gt_out), (
                f'{method} resize differs from F.interpolate: '
                f'{merge_cell_out.shape} vs {gt_out.shape}')
        return

    # Downsample path: the reference is a max-pooling whose kernel and
    # stride equal the (padded) input size divided by the target size.
    merge_cell = BaseMergeCell()
    merge_cell_out = merge_cell._resize(inputs_x, target_size)

    # Keep ``inputs_x`` untouched; the reference works on its own tensor.
    padded_x = inputs_x
    if h % target_h != 0 or w % target_w != 0:
        # Pad up to the next multiple of the target size. The padding is
        # split between both sides; an odd remainder goes right/bottom.
        pad_h = math.ceil(h / target_h) * target_h - h
        pad_w = math.ceil(w / target_w) * target_w - w
        pad_l = pad_w // 2
        pad_r = pad_w - pad_l
        pad_t = pad_h // 2
        pad_b = pad_h - pad_t
        pad = (pad_l, pad_r, pad_t, pad_b)
        padded_x = F.pad(inputs_x, pad, mode='constant', value=0.0)

    kernel_size = (padded_x.shape[-2] // target_h,
                   padded_x.shape[-1] // target_w)
    gt_out = F.max_pool2d(
        padded_x, kernel_size=kernel_size, stride=kernel_size)

    assert merge_cell_out.shape[-2:] == target_size
    # ``equal`` also compares shapes, so broadcasting cannot hide a
    # size mismatch the way an element-wise ``==`` could.
    assert merge_cell_out.equal(gt_out), (
        f'downsample differs from max-pool reference: '
        f'{merge_cell_out.shape} vs {gt_out.shape}')
|