mirror of https://github.com/open-mmlab/mmcv.git
277 lines
7.1 KiB
Python
277 lines
7.1 KiB
Python
import pytest
|
|
import torch
|
|
|
|
import mmcv
|
|
from mmcv.utils import TORCH_VERSION
|
|
|
|
# pytest marker for tests that only make sense on a parrots build: they are
# skipped whenever the running torch reports a version other than 'parrots'.
skip_no_parrots = pytest.mark.skipif(
    TORCH_VERSION != 'parrots',
    reason='test case under parrots environment',
)
|
|
|
|
|
|
class TestJit(object):
    """Tests for the ``mmcv.jit`` decorator.

    Tests that inspect the compiled-function cache (``func._cache``) or that
    rely on parrots-only checks are marked ``skip_no_parrots``. The remaining
    tests compare a decorated function against an undecorated reference
    implementation of the same computation.
    """

    def test_add_dict(self):
        """A jitted function taking and returning dicts matches plain Python."""

        @mmcv.jit
        def add_dict(oper):
            rets = oper['x'] + oper['y']
            return {'result': rets}

        def add_dict_pyfunc(oper):
            rets = oper['x'] + oper['y']
            return {'result': rets}

        a = torch.rand((3, 4))
        b = torch.rand((3, 4))
        oper = {'x': a, 'y': b}

        rets_t = add_dict(oper)
        rets = add_dict_pyfunc(oper)
        # The jitted function must preserve the key of the returned dict;
        # asserting on the reference output alone could never fail.
        assert 'result' in rets_t
        assert (rets_t['result'] == rets['result']).all()

    def test_add_list(self):
        """A jitted function taking a list of dicts plus kwargs matches Python."""

        @mmcv.jit
        def add_list(oper, x, y):
            rets = {}
            for idx, pair in enumerate(oper):
                rets[f'k{idx}'] = pair['x'] + pair['y']
            rets[f'k{len(oper)}'] = x + y
            return rets

        def add_list_pyfunc(oper, x, y):
            rets = {}
            for idx, pair in enumerate(oper):
                rets[f'k{idx}'] = pair['x'] + pair['y']
            rets[f'k{len(oper)}'] = x + y
            return rets

        pair_num = 3
        oper = []
        for _ in range(pair_num):
            oper.append({'x': torch.rand((3, 4)), 'y': torch.rand((3, 4))})
        a = torch.rand((3, 4))
        b = torch.rand((3, 4))
        rets = add_list_pyfunc(oper, x=a, y=b)
        rets_t = add_list(oper, x=a, y=b)
        # One key per pair plus one for the ``x + y`` term.
        for idx in range(pair_num + 1):
            assert f'k{idx}' in rets_t
            assert (rets[f'k{idx}'] == rets_t[f'k{idx}']).all()

    @skip_no_parrots
    def test_jit_cache(self):
        """Taking a different Python branch adds a new compiled cache entry."""

        @mmcv.jit
        def func(oper):
            if oper['const'] > 1:
                return oper['x'] * 2 + oper['y']
            else:
                return oper['x'] * 2 - oper['y']

        def pyfunc(oper):
            if oper['const'] > 1:
                return oper['x'] * 2 + oper['y']
            else:
                return oper['x'] * 2 - oper['y']

        assert len(func._cache._cache) == 0

        oper = {'const': 2, 'x': torch.rand((3, 4)), 'y': torch.rand((3, 4))}
        rets_plus = pyfunc(oper)
        rets_plus_t = func(oper)
        assert (rets_plus == rets_plus_t).all()
        assert len(func._cache._cache) == 1

        # Changing the Python scalar flips the branch, so a second entry is
        # expected in the cache.
        oper['const'] = 0.5
        rets_minus = pyfunc(oper)
        rets_minus_t = func(oper)
        assert (rets_minus == rets_minus_t).all()
        assert len(func._cache._cache) == 2

        # (2x + y) + (2x - y) == 4x, so dividing by 4 must recover x.
        rets_a = (rets_minus_t + rets_plus_t) / 4
        assert torch.allclose(oper['x'], rets_a)

    @skip_no_parrots
    def test_jit_shape(self):
        """Inputs of a different shape add a new compiled cache entry."""

        @mmcv.jit
        def func(a):
            return a + 1

        assert len(func._cache._cache) == 0

        a = torch.ones((3, 4))
        r = func(a)
        assert r.shape == (3, 4)
        assert (r == 2).all()
        assert len(func._cache._cache) == 1

        a = torch.ones((2, 3, 4))
        r = func(a)
        assert r.shape == (2, 3, 4)
        assert (r == 2).all()
        assert len(func._cache._cache) == 2

    @skip_no_parrots
    def test_jit_kwargs(self):
        """Positional and keyword call forms share one cache entry."""

        @mmcv.jit
        def func(a, b):
            return torch.mean((a - b) * (a - b))

        assert len(func._cache._cache) == 0
        x = torch.rand((16, 32))
        y = torch.rand((16, 32))
        func(x, y)
        assert len(func._cache._cache) == 1
        func(x, b=y)
        assert len(func._cache._cache) == 1
        func(b=y, a=x)
        assert len(func._cache._cache) == 1

    def test_jit_derivate(self):
        """``derivate=True`` keeps autograd working through the jitted call."""

        @mmcv.jit(derivate=True)
        def func(x, y):
            return (x + 2) * (y - 2)

        a = torch.rand((3, 4))
        b = torch.rand((3, 4))
        a.requires_grad = True

        # d/da [(a + 2) * (b - 2)] == (b - 2), scaled by the upstream grad.
        c = func(a, b)
        assert c.requires_grad
        d = torch.empty_like(c)
        d.fill_(1.0)
        c.backward(d)
        assert torch.allclose(a.grad, (b - 2))
        assert b.grad is None

        # Reset the accumulated gradient before the second backward pass.
        a.grad = None
        c = func(a, b)
        assert c.requires_grad
        d = torch.empty_like(c)
        d.fill_(2.7)
        c.backward(d)
        assert torch.allclose(a.grad, 2.7 * (b - 2))
        assert b.grad is None

    def test_jit_optimize(self):
        """``optimize=True`` gives the same result as plain Python."""

        @mmcv.jit(optimize=True)
        def func(a, b):
            return torch.mean((a - b) * (a - b))

        def pyfunc(a, b):
            return torch.mean((a - b) * (a - b))

        a = torch.rand((16, 32))
        b = torch.rand((16, 32))

        c = func(a, b)
        d = pyfunc(a, b)
        assert torch.allclose(c, d)

    @mmcv.skip_no_elena
    def test_jit_coderize(self):
        """``coderize=True`` gives the same result as plain Python on CUDA."""
        if not torch.cuda.is_available():
            # Report the test as skipped instead of letting it pass silently.
            pytest.skip('test_jit_coderize requires CUDA')

        @mmcv.jit(coderize=True)
        def func(a, b):
            return (a + b) * (a - b)

        def pyfunc(a, b):
            return (a + b) * (a - b)

        a = torch.rand((16, 32), device='cuda')
        b = torch.rand((16, 32), device='cuda')

        c = func(a, b)
        d = pyfunc(a, b)
        assert torch.allclose(c, d)

    def test_jit_value_dependent(self):
        """A function containing ``torch.nonzero`` still matches plain Python."""

        @mmcv.jit
        def func(a, b):
            # The result is deliberately discarded; the call itself is what is
            # exercised (nonzero's output size depends on the input values).
            torch.nonzero(a)
            return torch.mean((a - b) * (a - b))

        def pyfunc(a, b):
            torch.nonzero(a)
            return torch.mean((a - b) * (a - b))

        a = torch.rand((16, 32))
        b = torch.rand((16, 32))

        c = func(a, b)
        d = pyfunc(a, b)
        assert torch.allclose(c, d)

    @skip_no_parrots
    def test_jit_check_input(self):
        """``check_input`` rejects a function whose output is nondeterministic."""

        def func(x):
            # rand_like makes the output differ between runs on the same input.
            y = torch.rand_like(x)
            return x + y

        a = torch.ones((3, 4))
        with pytest.raises(AssertionError):
            mmcv.jit(func, check_input=(a, ))

    @skip_no_parrots
    def test_jit_partial_shape(self):
        """With ``full_shape=False`` the cache is shared across sizes of one rank."""

        @mmcv.jit(full_shape=False)
        def func(a, b):
            return torch.mean((a - b) * (a - b))

        def pyfunc(a, b):
            return torch.mean((a - b) * (a - b))

        # Two 2-D inputs of different sizes are expected to share one entry.
        a = torch.rand((3, 4))
        b = torch.rand((3, 4))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 1

        a = torch.rand((6, 5))
        b = torch.rand((6, 5))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 1

        # A new rank (3-D) adds exactly one more entry, reused for other 3-D
        # sizes.
        a = torch.rand((3, 4, 5))
        b = torch.rand((3, 4, 5))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 2

        a = torch.rand((1, 9, 8))
        b = torch.rand((1, 9, 8))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 2

    def test_instance_method(self):
        """A jitted method reads the state of the instance it is called on."""

        class T(object):

            def __init__(self, shape):
                self._c = torch.rand(shape)

            @mmcv.jit
            def test_method(self, x, y):
                return (x * self._c) + y

        shape = (16, 32)
        t = T(shape)
        a = torch.rand(shape)
        b = torch.rand(shape)
        res = (a * t._c) + b
        jit_res = t.test_method(a, b)
        assert torch.allclose(res, jit_res)

        # A fresh instance has a different ``_c``; the jitted method must use
        # it rather than values captured from the first instance.
        t = T(shape)
        res = (a * t._c) + b
        jit_res = t.test_method(a, b)
        assert torch.allclose(res, jit_res)