mmcv/tests/test_ops/test_bias_act.py
Yifei Yang 869dbf1bf2
[Feature] Add Ops of StyleGAN3 (#2290)
* add bias_act

* support bias_act

* support filtered_lrelu

* support filtered_lrelu and upfirdn2d

* support conv2d_gradfix and fix filtered_lrelu

* fix lint

* fix lint

* fix c++ lint

* fix part comments

* fix lint

* rm redundant header

* fix upgrade pip

* fix as comment

* fix c++ lint

* fix ci

* fix-ut

* fix as comments

* add grad check

* remove redundant template

* Update mmcv/ops/bias_act.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* add typehint

* fix as comment:

* complete type hints

* fix lint

* add test for conv_gradfix

* add test for conv_gradfix

* fix lint

* modify licenses and ops.md

* add zh op md

* add torch version policy for conv2d_gradfix

* fix lint

* fix as comments

* rename impl

* rm redundant function and add ut

* fix as comment

* fix lint

* fix lint

* fix as comments

* fix lint

* fix ut

* fix as comment

* fix as comment

* fix as comment

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
2023-03-13 16:05:11 +08:00

145 lines
5.3 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.ops import bias_act
from mmcv.ops.bias_act import EasyDict
# Pick the gradient-checking backend: parrots' gradcheck when that framework
# is importable, otherwise PyTorch's gradcheck/gradgradcheck.
# ``_USING_PARROTS`` records which one was chosen so the tests can pass the
# backend-specific keyword names (``delta``/``pt_atol`` vs ``eps``/``atol``).
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck, gradgradcheck
    _USING_PARROTS = False
else:
    _USING_PARROTS = True
class TestBiasAct:
    """Unit tests for ``mmcv.ops.bias_act`` and the ``EasyDict`` helper."""

    # Activation names exercised by the forward-shape checks.
    ACTS = ('relu', 'lrelu', 'tanh', 'sigmoid', 'elu', 'selu', 'softplus',
            'swish')

    @classmethod
    def setup_class(cls):
        # Shared (1, 3) input and a length-3 bias. Both are created with
        # requires_grad=True so they can also be fed to gradcheck.
        cls.input_tensor = torch.randn((1, 3), requires_grad=True)
        cls.bias = torch.randn(3, requires_grad=True)

    def _check_forward(self, device):
        """Run the forward-pass checks for ``bias_act`` on ``device``.

        Shared by the CPU and CUDA tests so both devices are held to exactly
        the same assertions.

        Args:
            device (str): Device the tensors are moved to, e.g. ``'cpu'`` or
                ``'cuda'``.
        """
        x = self.input_tensor.to(device)
        b = self.bias.to(device)

        out = bias_act(x, b)
        assert out.shape == (1, 3)

        # test with different dim
        input_tensor = torch.randn((1, 1, 3),
                                   device=device,
                                   requires_grad=True)
        bias = torch.randn(3, device=device, requires_grad=True)
        out = bias_act(input_tensor, bias, dim=2)
        assert out.shape == (1, 1, 3)

        # test with different act
        for act in self.ACTS:
            out = bias_act(x, b, act=act)
            assert out.shape == (1, 3), act

        # test with different alpha
        out = bias_act(x, b, act='lrelu', alpha=0.1)
        assert out.shape == (1, 3)

        # test with different gain: halving the gain halves the output
        out1 = bias_act(x, b, act='lrelu', gain=0.2)
        out2 = bias_act(x, b, act='lrelu', gain=0.1)
        assert torch.allclose(out1, out2 * 2)

        # test with different clamp: every output must respect its OWN bound
        # (previously the 0.2 case was only checked against 0.5).
        for clamp in (0.5, 0.2):
            out = bias_act(x, b, act='lrelu', clamp=clamp)
            assert out.max() <= clamp
            assert out.min() >= -clamp

    def test_bias_act_cpu(self):
        self._check_forward('cpu')

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_bias_act_cuda(self):
        inputs = (self.input_tensor.cuda(), self.bias.cuda())
        # parrots and torch name the step/tolerance arguments differently;
        # torch additionally checks second-order gradients.
        if _USING_PARROTS:
            gradcheck(bias_act, inputs, delta=1e-4, pt_atol=1e-3)
        else:
            gradcheck(bias_act, inputs, eps=1e-4, atol=1e-3)
            gradgradcheck(bias_act, inputs, eps=1e-4, atol=1e-3)
        self._check_forward('cuda')

    def test_easy_dict(self):
        easy_dict = EasyDict(
            func=lambda x, **_: x,
            def_alpha=0,
            def_gain=1,
            cuda_idx=1,
            ref='',
            has_2nd_grad=False)
        # attribute-style read returns the value passed at construction
        assert easy_dict.def_alpha == 0
        assert easy_dict.func(3) == 3
        # attribute-style write is visible on the next read
        easy_dict.def_alpha = 1
        assert easy_dict.def_alpha == 1
        # attribute-style delete removes the entry
        del easy_dict.def_alpha
        assert not hasattr(easy_dict, 'def_alpha')