mirror of https://github.com/open-mmlab/mmcv.git
47 lines
1.7 KiB
Python
47 lines
1.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmcv.ops import SAConv2d
|
|
|
|
|
|
def test_sacconv():
    """Check that ``SAConv2d`` produces the same output shape as ``nn.Conv2d``.

    Each case builds an ``SAConv2d`` and a reference ``nn.Conv2d`` with the
    same channel / kernel / padding / dilation / groups arguments, runs both
    on the same input and compares only the output shapes (not values).
    Cases covered: default configuration, dilation >= 2, deformable
    convolution (``use_deform=True``) and grouped convolution (groups >= 2).
    """
    # test with normal case
    x = torch.rand(1, 3, 256, 256)
    saconv = SAConv2d(3, 5, kernel_size=3, padding=1)
    sac_out = saconv(x)
    refer_conv = nn.Conv2d(3, 5, kernel_size=3, padding=1)
    refer_out = refer_conv(x)
    assert sac_out.shape == refer_out.shape

    # test with dilation >= 2
    dilated_saconv = SAConv2d(3, 5, kernel_size=3, padding=2, dilation=2)
    dilated_sac_out = dilated_saconv(x)
    refer_conv = nn.Conv2d(3, 5, kernel_size=3, padding=2, dilation=2)
    refer_out = refer_conv(x)
    assert dilated_sac_out.shape == refer_out.shape

    # test with deform: exercise the GPU path when CUDA is present,
    # otherwise fall back to running the same check on CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    deform_x = x.to(device)
    deform_saconv = SAConv2d(
        3, 5, kernel_size=3, padding=1, use_deform=True).to(device)
    # the module and its input share a device, so the output is already
    # there; no extra device transfer is needed
    deform_sac_out = deform_saconv(deform_x)
    refer_conv = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device)
    refer_out = refer_conv(deform_x)
    assert deform_sac_out.shape == refer_out.shape

    # test with groups >= 2 (in/out channels must be divisible by groups)
    x = torch.rand(1, 4, 256, 256)
    group_saconv = SAConv2d(4, 4, kernel_size=3, padding=1, groups=2)
    group_sac_out = group_saconv(x)
    refer_conv = nn.Conv2d(4, 4, kernel_size=3, padding=1, groups=2)
    refer_out = refer_conv(x)
    assert group_sac_out.shape == refer_out.shape
|