# File: codon/test/transform/canonical.codon
#
# (Source-listing metadata: 201 lines, 4.9 KiB, highlighted as Python.)

# Two-component vector used as the operand type for the canonicalization tests
# in this file. Every operator prints its name and operands before computing,
# so the expected-output comments in the tests record exactly which operator
# calls survive compilation and in what operand order.
#
# NOTE(review): the methods are annotated @pure even though they print. The
# print is only an observation channel for the tests, so the annotation is
# deliberately not truthful. Presumably the canonicalization pass only rewrites
# operators it believes are side-effect free; confirm against the pass.
@tuple
class Vec[T]:
    x: T  # first component
    y: T  # second component

    # Euclidean magnitude sqrt(x*x + y*y). All comparison operators below
    # compare vectors through this value rather than component-wise.
    @pure
    def __abs__(self):
        return ((self.x * self.x) + (self.y * self.y)) ** 0.5

    # Element-wise vector + vector. Declared commutative and associative, so
    # the compiler may reorder and regroup the operands of an addition chain.
    @pure
    @commutative
    @associative
    def __add__(self, other: Vec[T]):
        print 'vec add', self, other
        return Vec[T](self.x + other.x, self.y + other.y)

    # Adds the scalar `other` to both components. test_add_mul_canon relies on
    # this overload when `a + c*a` is factored into `(c + 1) * a`.
    @pure
    @commutative
    @associative
    def __add__(self, other: T):
        print 'vec add', self, other
        return Vec[T](self.x + other, self.y + other)

    # Element-wise vector - vector. Deliberately NOT declared commutative, so
    # an expression such as `a - f(b)` must keep its operand order.
    @pure
    def __sub__(self, other: Vec[T]):
        print 'vec sub', self, other
        return Vec[T](self.x - other.x, self.y - other.y)

    # Subtracts the scalar `other` from both components (not commutative).
    @pure
    def __sub__(self, other: T):
        print 'vec sub', self, other
        return Vec[T](self.x - other, self.y - other)

    # Element-wise product. Declared commutative, associative and distributive
    # (presumably over the @commutative/@associative `__add__` above), which
    # lets the compiler factor `a*b + c*a` into `(b + c) * a`.
    @pure
    @commutative
    @associative
    @distributive
    def __mul__(self, other: Vec[T]):
        print 'vec mul', self, other
        return Vec[T](self.x * other.x, self.y * other.y)

    # Scales both components by `other`. test_add_mul_canon expects repeated
    # additions such as `a + a + a + a + a` to collapse into one call of this
    # overload (`a * 5`).
    @pure
    @commutative
    @associative
    @distributive
    def __mul__(self, other: T):
        print 'vec mul', self, other
        return Vec[T](self.x * other, self.y * other)

    # Equality by magnitude, not component-wise: e.g. Vec(1, 0) == Vec(0, 1)
    # is true. Declared commutative, so the compiler may swap the operands.
    @pure
    @commutative
    def __eq__(self, other: Vec[T]):
        print 'vec eq', self, other
        return abs(self) == abs(other)

    # Inequality by magnitude; commutative like __eq__.
    @pure
    @commutative
    def __ne__(self, other: Vec[T]):
        print 'vec ne', self, other
        return abs(self) != abs(other)

    # Ordering operators, all by magnitude. They are not commutative; to put
    # operands in canonical order the compiler must switch to the mirrored
    # operator instead (test_inequality_canon expects `a < f(b)` to be
    # emitted as `f(b) > a`), which is why all four are defined.
    @pure
    def __lt__(self, other: Vec[T]):
        print 'vec lt', self, other
        return abs(self) < abs(other)

    @pure
    def __le__(self, other: Vec[T]):
        print 'vec le', self, other
        return abs(self) <= abs(other)

    @pure
    def __gt__(self, other: Vec[T]):
        print 'vec gt', self, other
        return abs(self) > abs(other)

    @pure
    def __ge__(self, other: Vec[T]):
        print 'vec ge', self, other
        return abs(self) >= abs(other)
# Checks operand reordering in chains of a commutative/associative operator.
# `f` is a @pure identity whose only role is to turn an operand into a call
# result. Judging by the `# ->` annotations, the canonicalizer moves more deeply
# nested call results to the left and plain variables to the right.
@test
def test_op_chain_canon():
    @pure
    def f(a): return a

    # Single swap: the call result becomes the left operand.
    a = Vec(1, 2)
    b = Vec(3, 4)
    c = a + f(b) # -> f(b) + a
    assert (c.x, c.y) == (4, 6)
    # EXPECT: vec add (x: 3, y: 4) (x: 1, y: 2)

    # Three-operand chain, left-associated: reordered so the most deeply
    # nested call comes first and the plain variable last.
    a = Vec(1, 2)
    b = Vec(3, 4)
    c = Vec(5, 6)
    d = f(a + f(b) + f(f(c))) # -> f(f(f(c)) + f(b) + a)
    assert (d.x, d.y) == (9, 12)
    # EXPECT: vec add (x: 5, y: 6) (x: 3, y: 4)
    # EXPECT: vec add (x: 8, y: 10) (x: 1, y: 2)

    # Same chain but right-associated: associativity lets it reach the same
    # canonical form, so the printed calls are identical to the case above.
    a = Vec(1, 2)
    b = Vec(3, 4)
    c = Vec(5, 6)
    d = f(a + (f(b) + f(f(c)))) # -> f(f(f(c)) + f(b) + a)
    assert (d.x, d.y) == (9, 12)
    # EXPECT: vec add (x: 5, y: 6) (x: 3, y: 4)
    # EXPECT: vec add (x: 8, y: 10) (x: 1, y: 2)

    # Vec.__sub__ is not commutative, so its operands must stay in place.
    a = Vec(1, 2)
    b = Vec(3, 4)
    c = a - f(b) # -> no change
    assert (c.x, c.y) == (-2, -2)
    # EXPECT: vec sub (x: 1, y: 2) (x: 3, y: 4)

    # don't canon float ops
    # Float `+` is not associative. Left to right this is
    # (1e100 + -1e100) + 1.0 == 1.0, whereas the reordered chain
    # (1.0 + -1e100) + 1e100 would evaluate to 0.0 and fail the assert.
    assert f(1e100) + f(f(-1e100)) + f(f(f(1.))) == 1.
test_op_chain_canon()
class C:
    # Minimal int wrapper that is orderable ONLY through `__lt__`: it has no
    # `__gt__`, `__le__` or `__ge__`. test_inequality_canon uses it to check
    # that the canonicalizer never swaps a comparison into an operator the
    # type does not provide.
    n: int

    def __lt__(self: C, other: C):
        lhs = self.n
        rhs = other.n
        return lhs < rhs
# Checks canonicalization of comparisons. `==`/`!=` are declared commutative on
# Vec, so their operands can simply be swapped. For the ordering operators the
# mirrored operator must be used instead (`a < x` becomes `x > a`).
@test
def test_inequality_canon():
    @pure
    def f(a): return a
    a = Vec(1,1)
    b = Vec(2,2)

    # Call result already on the left: every comparison is emitted unchanged.
    assert not (f(a) == b)
    assert f(a) != b
    assert f(a) < b
    assert f(a) <= b
    assert not (f(a) > b)
    assert not (f(a) >= b)
    # EXPECT: vec eq (x: 1, y: 1) (x: 2, y: 2)
    # EXPECT: vec ne (x: 1, y: 1) (x: 2, y: 2)
    # EXPECT: vec lt (x: 1, y: 1) (x: 2, y: 2)
    # EXPECT: vec le (x: 1, y: 1) (x: 2, y: 2)
    # EXPECT: vec gt (x: 1, y: 1) (x: 2, y: 2)
    # EXPECT: vec ge (x: 1, y: 1) (x: 2, y: 2)

    # Call result on the right: the operands are swapped. eq/ne keep their
    # operator, while the ordering operators flip (< to >, <= to >=, > to <,
    # >= to <=), as the printed operator names below show.
    assert not (a == f(b))
    assert a != f(b)
    assert a < f(b)
    assert a <= f(b)
    assert not (a > f(b))
    assert not (a >= f(b))
    # EXPECT: vec eq (x: 2, y: 2) (x: 1, y: 1)
    # EXPECT: vec ne (x: 2, y: 2) (x: 1, y: 1)
    # EXPECT: vec gt (x: 2, y: 2) (x: 1, y: 1)
    # EXPECT: vec ge (x: 2, y: 2) (x: 1, y: 1)
    # EXPECT: vec lt (x: 2, y: 2) (x: 1, y: 1)
    # EXPECT: vec le (x: 2, y: 2) (x: 1, y: 1)

    c1 = C(1)
    c2 = C(2)
    # ensure we don't use missing ops
    # C defines only `__lt__`, so rewriting `c1 < f(c2)` into `f(c2) > c1`
    # would need an operator C does not have; the comparison must be left as is.
    assert c1 < f(c2)
test_inequality_canon()
# Checks factoring and collapsing of add/mul expressions, driven by the
# @distributive annotation on Vec.__mul__ together with the commutative and
# associative annotations on Vec.__add__.
@test
def test_add_mul_canon():
    @pure
    def f(a): return a
    a = Vec(1,1)
    b = Vec(2,2)
    c = Vec(3,3)

    # Common factor `a` is pulled out: two multiplies become one add + one multiply.
    d = (a*f(b) + c*a) # -> (f(b) + c) * a
    assert (d.x, d.y) == (5, 5)
    # EXPECT: vec add (x: 2, y: 2) (x: 3, y: 3)
    # EXPECT: vec mul (x: 5, y: 5) (x: 1, y: 1)

    # A bare `a` is apparently treated as `a * 1` when factoring, which is why
    # the scalar overload `c + 1` is printed.
    d = (a + c*a) # -> (c + 1) * a
    assert (d.x, d.y) == (4, 4)
    # EXPECT: vec add (x: 3, y: 3) 1
    # EXPECT: vec mul (x: 4, y: 4) (x: 1, y: 1)

    # Same as above with the operands of `+` in the other order.
    d = (c*a + a) # -> (c + 1) * a
    assert (d.x, d.y) == (4, 4)
    # EXPECT: vec add (x: 3, y: 3) 1
    # EXPECT: vec mul (x: 4, y: 4) (x: 1, y: 1)

    # Repeated addition of the same vector collapses into one scalar multiply.
    a = Vec(1,1)
    b = a + a + a + a + a
    assert (b.x, b.y) == (5, 5)
    # EXPECT: vec mul (x: 1, y: 1) 5

    # Mixed repeats: the coefficients 1+2+3+4+5 are summed into a single `a * 15`.
    a = Vec(1,1)
    b = a + a*2 + a*3 + a*4 + a*5
    assert (b.x, b.y) == (15, 15)
    # EXPECT: vec mul (x: 1, y: 1) 15

    x = f(100.) # don't distribute float ops
    # Exact as written for x == 100.0: 10.0 + 20.0 == 30.0. The factored form
    # x * (0.1 + 0.2) evaluates to 30.000000000000004 and would fail the assert.
    assert (x * 0.1) + (x * 0.2) == 30.
test_add_mul_canon()