# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.ops import bias_act
from mmcv.ops.bias_act import EasyDict

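# Prefer parrots' gradcheck when it is available; otherwise fall back to
# torch.autograd, which additionally provides gradgradcheck for
# second-order gradient checks.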
_USING_PARROTS = True
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck, gradgradcheck
    _USING_PARROTS = False


class TestBiasAct:
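    """Tests for the ``bias_act`` op.

    ``bias_act`` adds a per-channel bias to the input and then applies the
    requested activation, optionally with custom ``alpha``, ``gain`` and
    ``clamp`` values (the op follows StyleGAN3's ``bias_act``). The tests
    below check output shapes and a few simple numerical properties on
    random inputs.
    """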

    @classmethod
    def setup_class(cls):
        cls.input_tensor = torch.randn((1, 3), requires_grad=True)
        cls.bias = torch.randn(3, requires_grad=True)

    def test_bias_act_cpu(self):
        out = bias_act(self.input_tensor, self.bias)
        assert out.shape == (1, 3)

        # test with different dim
        input_tensor = torch.randn((1, 1, 3), requires_grad=True)
        bias = torch.randn(3, requires_grad=True)
        out = bias_act(input_tensor, bias, dim=2)
        assert out.shape == (1, 1, 3)

        # test with different act
        out = bias_act(self.input_tensor, self.bias, act='relu')
        assert out.shape == (1, 3)
        out = bias_act(self.input_tensor, self.bias, act='lrelu')
        assert out.shape == (1, 3)
        out = bias_act(self.input_tensor, self.bias, act='tanh')
        assert out.shape == (1, 3)
        out = bias_act(self.input_tensor, self.bias, act='sigmoid')
        assert out.shape == (1, 3)
        out = bias_act(self.input_tensor, self.bias, act='elu')
        assert out.shape == (1, 3)
        out = bias_act(self.input_tensor, self.bias, act='selu')
        assert out.shape == (1, 3)
        out = bias_act(self.input_tensor, self.bias, act='softplus')
        assert out.shape == (1, 3)
        out = bias_act(self.input_tensor, self.bias, act='swish')
        assert out.shape == (1, 3)

        # test with different alpha
        out = bias_act(self.input_tensor, self.bias, act='lrelu', alpha=0.1)
        assert out.shape == (1, 3)

        # test with different gain
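        # ``gain`` linearly scales the activated output, so the same input
        # with gain=0.2 should yield twice the output obtained with gain=0.1.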
        out1 = bias_act(self.input_tensor, self.bias, act='lrelu', gain=0.2)
        out2 = bias_act(self.input_tensor, self.bias, act='lrelu', gain=0.1)
        assert torch.allclose(out1, out2 * 2)

        # test with different clamp
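        # ``clamp`` limits the magnitude of the output, so the maxima below
        # must not exceed the corresponding clamp values.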
        out1 = bias_act(self.input_tensor, self.bias, act='lrelu', clamp=0.5)
        out2 = bias_act(self.input_tensor, self.bias, act='lrelu', clamp=0.2)
        assert out1.max() <= 0.5
        assert out2.max() <= 0.2

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_bias_act_cuda(self):
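        # Numerically verify first- and second-order gradients of the CUDA
        # op against finite differences (gradgradcheck is only available in
        # the torch.autograd fallback).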
        if _USING_PARROTS:
            gradcheck(
                bias_act, (self.input_tensor.cuda(), self.bias.cuda()),
                delta=1e-4,
                pt_atol=1e-3)
        else:
            gradcheck(
                bias_act, (self.input_tensor.cuda(), self.bias.cuda()),
                eps=1e-4,
                atol=1e-3)

            gradgradcheck(
                bias_act, (self.input_tensor.cuda(), self.bias.cuda()),
                eps=1e-4,
                atol=1e-3)

        out = bias_act(self.input_tensor.cuda(), self.bias.cuda())
        assert out.shape == (1, 3)

        # test with different dim
        input_tensor = torch.randn((1, 1, 3), requires_grad=True).cuda()
        bias = torch.randn(3, requires_grad=True).cuda()
        out = bias_act(input_tensor, bias, dim=2)
        assert out.shape == (1, 1, 3)

        # test with different act
        out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='relu')
        assert out.shape == (1, 3)

        out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='lrelu')
        assert out.shape == (1, 3)
        out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='tanh')
        assert out.shape == (1, 3)
        out = bias_act(
            self.input_tensor.cuda(), self.bias.cuda(), act='sigmoid')
        assert out.shape == (1, 3)
        out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='elu')
        assert out.shape == (1, 3)
        out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='selu')
        assert out.shape == (1, 3)
        out = bias_act(
            self.input_tensor.cuda(), self.bias.cuda(), act='softplus')
        assert out.shape == (1, 3)
        out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='swish')
        assert out.shape == (1, 3)

        # test with different alpha
        out = bias_act(
            self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', alpha=0.1)
        assert out.shape == (1, 3)

        # test with different gain
        out1 = bias_act(
            self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', gain=0.2)
        out2 = bias_act(
            self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', gain=0.1)
        assert torch.allclose(out1, out2 * 2)

        # test with different clamp
        out1 = bias_act(
            self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', clamp=0.5)
        out2 = bias_act(
            self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', clamp=0.2)
        assert out1.max() <= 0.5
        assert out2.max() <= 0.2

    def test_easy_dict(self):
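        # ``EasyDict`` is the dict-with-attribute-access container that
        # ``bias_act`` uses for its activation registry entries; exercise
        # attribute read, write and delete on it.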
        easy_dict = EasyDict(
            func=lambda x, **_: x,
            def_alpha=0,
            def_gain=1,
            cuda_idx=1,
            ref='',
            has_2nd_grad=False)
        _ = easy_dict.def_alpha
        easy_dict.def_alpha = 1
        del easy_dict.def_alpha