mirror of https://github.com/open-mmlab/mmcv.git
65 lines
2.3 KiB
Python
65 lines
2.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.autograd import gradcheck, gradgradcheck
|
|
|
|
from mmcv.ops import conv2d, conv_transpose2d
|
|
from mmcv.utils import IS_MUSA_AVAILABLE
|
|
|
|
|
|
class TestCond2d:
    """Forward-shape and gradient checks for ``mmcv.ops.conv2d``.

    Each accelerator test moves the shared class-level input/weight to its
    device and runs the same checks via ``_check_conv2d``.

    NOTE: "Cond" in the class name is a typo for "Conv"; the name is kept
    unchanged so existing pytest node IDs keep working.
    """

    @classmethod
    def setup_class(cls):
        # A dedicated, seeded generator makes the random input/weight (and
        # therefore any gradcheck failure) reproducible without reseeding the
        # global RNG that other tests may depend on.
        gen = torch.Generator().manual_seed(0)
        cls.input = torch.randn(
            (1, 3, 32, 32), requires_grad=True, generator=gen)
        cls.weight = nn.Parameter(torch.randn(1, 3, 3, 3, generator=gen))

    def _check_conv2d(self, x, weight):
        """Assert the output shape and run gradcheck / gradgradcheck.

        Args:
            x: Input tensor, already moved to the device under test.
            weight: Weight tensor, on the same device as ``x``.
        """
        # The trailing positional args are presumably (bias, stride, padding),
        # mirroring ``torch.nn.functional.conv2d`` -- confirm against
        # ``mmcv.ops.conv2d``.
        args = (x, weight, None, 1, 1)
        res = conv2d(*args)
        assert res.shape == (1, 1, 32, 32)
        # Loose eps/atol, presumably because the tensors use the default
        # (single) precision, where finite differences are noisy.
        gradcheck(conv2d, args, eps=1e-2, atol=0.1)
        gradgradcheck(conv2d, args, eps=1e-2, atol=0.1)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_conv2d_cuda(self):
        self._check_conv2d(self.input.cuda(), self.weight.cuda())

    @pytest.mark.skipif(not IS_MUSA_AVAILABLE, reason='requires musa')
    def test_conv2d_musa(self):
        self._check_conv2d(self.input.musa(), self.weight.musa())
|
|
|
|
|
|
class TestCond2dTansposed:
    """Forward-shape and gradient checks for ``mmcv.ops.conv_transpose2d``.

    Each accelerator test moves the shared class-level input/weight to its
    device and runs the same checks via ``_check_conv_transpose2d``.

    NOTE: "Cond" and "Tansposed" in the class name are typos for "Conv" and
    "Transposed"; the name is kept unchanged so existing pytest node IDs keep
    working.
    """

    @classmethod
    def setup_class(cls):
        # A dedicated, seeded generator makes the random input/weight (and
        # therefore any gradcheck failure) reproducible without reseeding the
        # global RNG that other tests may depend on.
        gen = torch.Generator().manual_seed(0)
        cls.input = torch.randn(
            (1, 3, 32, 32), requires_grad=True, generator=gen)
        # Presumably laid out as (in_channels, out_channels, kH, kW), as in
        # ``torch.nn.functional.conv_transpose2d``: 3 input channels map to
        # the single output channel asserted below.
        cls.weight = nn.Parameter(torch.randn(3, 1, 3, 3, generator=gen))

    def _check_conv_transpose2d(self, x, weight):
        """Assert the output shape and run gradcheck / gradgradcheck.

        Args:
            x: Input tensor, already moved to the device under test.
            weight: Weight tensor, on the same device as ``x``.
        """
        # The trailing positional args are presumably (bias, stride, padding),
        # mirroring ``torch.nn.functional.conv_transpose2d`` -- confirm
        # against ``mmcv.ops.conv_transpose2d``.
        args = (x, weight, None, 1, 1)
        res = conv_transpose2d(*args)
        assert res.shape == (1, 1, 32, 32)
        gradcheck(conv_transpose2d, args, eps=1e-2, atol=1e-2)
        gradgradcheck(conv_transpose2d, args, eps=1e-2, atol=1e-2)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_conv2d_transposed_cuda(self):
        self._check_conv_transpose2d(self.input.cuda(), self.weight.cuda())

    @pytest.mark.skipif(not IS_MUSA_AVAILABLE, reason='requires musa')
    def test_conv2d_transposed_musa(self):
        self._check_conv_transpose2d(self.input.musa(), self.weight.musa())
|