# Test fixture for operator rewriting ("canon" in the test names below).
#
# How to read this file:
#   * Every Vec operator prints a trace line ('vec add', 'vec mul', ...) with
#     its two operands before computing the result, so the order in which
#     operators are invoked -- and the order of their operands -- is visible
#     in the program output.
#   * The existing '# -> ...' comments show the form an expression is expected
#     to be rewritten into, and the EXPECT comments that follow list the trace
#     lines that form prints.
#   * NOTE(review): the precise meaning of @pure / @commutative / @associative /
#     @distributive is defined by the compiler, not in this file; presumably
#     they are what permits the rewrites exercised here -- confirm against the
#     compiler documentation.

# Generic two-component value with fields x and y of the same type T.
@tuple
class Vec[T]:
    x: T
    y: T

    # Magnitude: sqrt(x*x + y*y). Unlike the operators below it prints nothing.
    @pure
    def __abs__(self):
        return ((self.x * self.x) + (self.y * self.y)) ** 0.5

    # Vec + Vec: component-wise sum.
    @pure
    @commutative
    @associative
    def __add__(self, other: Vec[T]):
        print 'vec add', self, other
        return Vec[T](self.x + other.x, self.y + other.y)

    # Vec + scalar: adds the scalar to both components. Shares the 'vec add'
    # trace label with the Vec overload; the printed operand tells them apart.
    @pure
    @commutative
    @associative
    def __add__(self, other: T):
        print 'vec add', self, other
        return Vec[T](self.x + other, self.y + other)

    # Vec - Vec: component-wise difference. Only @pure (not marked commutative
    # or associative).
    @pure
    def __sub__(self, other: Vec[T]):
        print 'vec sub', self, other
        return Vec[T](self.x - other.x, self.y - other.y)

    # Vec - scalar: subtracts the scalar from both components.
    @pure
    def __sub__(self, other: T):
        print 'vec sub', self, other
        return Vec[T](self.x - other, self.y - other)

    # Vec * Vec: component-wise product (not a dot product).
    @pure
    @commutative
    @associative
    @distributive
    def __mul__(self, other: Vec[T]):
        print 'vec mul', self, other
        return Vec[T](self.x * other.x, self.y * other.y)

    # Vec * scalar: scales both components.
    @pure
    @commutative
    @associative
    @distributive
    def __mul__(self, other: T):
        print 'vec mul', self, other
        return Vec[T](self.x * other, self.y * other)

    # All six comparisons below compare magnitudes (abs), not components, so
    # two different vectors of equal length compare equal.
    @pure
    @commutative
    def __eq__(self, other: Vec[T]):
        print 'vec eq', self, other
        return abs(self) == abs(other)

    @pure
    @commutative
    def __ne__(self, other: Vec[T]):
        print 'vec ne', self, other
        return abs(self) != abs(other)

    # The ordering comparisons are @pure only (not marked commutative).
    @pure
    def __lt__(self, other: Vec[T]):
        print 'vec lt', self, other
        return abs(self) < abs(other)

    @pure
    def __le__(self, other: Vec[T]):
        print 'vec le', self, other
        return abs(self) <= abs(other)

    @pure
    def __gt__(self, other: Vec[T]):
        print 'vec gt', self, other
        return abs(self) > abs(other)

    @pure
    def __ge__(self, other: Vec[T]):
        print 'vec ge', self, other
        return abs(self) >= abs(other)

# Chains of a single operator: checks both the computed value (asserts) and the
# operand order actually used (trace output).
@test
def test_op_chain_canon():
    # Identity function; used to make an operand a call expression rather than
    # a plain variable.
    @pure
    def f(a): return a

    # Single add: the trace shows the call operand first, the variable second.
    a = Vec(1, 2)
    b = Vec(3, 4)
    c = a + f(b)  # -> f(b) + a
    assert (c.x, c.y) == (4, 6)
    # EXPECT: vec add (x: 3, y: 4) (x: 1, y: 2)

    # Left-nested chain: per the trace, c and b are added first, then a.
    a = Vec(1, 2)
    b = Vec(3, 4)
    c = Vec(5, 6)
    d = f(a + f(b) + f(f(c)))  # -> f(f(f(c)) + f(b) + a)
    assert (d.x, d.y) == (9, 12)
    # EXPECT: vec add (x: 5, y: 6) (x: 3, y: 4)
    # EXPECT: vec add (x: 8, y: 10) (x: 1, y: 2)

    # Same operands, explicitly right-nested: expected to produce the same
    # rewritten form and the same trace as the left-nested case above.
    a = Vec(1, 2)
    b = Vec(3, 4)
    c = Vec(5, 6)
    d = f(a + (f(b) + f(f(c))))  # -> f(f(f(c)) + f(b) + a)
    assert (d.x, d.y) == (9, 12)
    # EXPECT: vec add (x: 5, y: 6) (x: 3, y: 4)
    # EXPECT: vec add (x: 8, y: 10) (x: 1, y: 2)

    # Subtraction: operands must stay in source order (trace shows a, then b).
    a = Vec(1, 2)
    b = Vec(3, 4)
    c = a - f(b)  # -> no change
    assert (c.x, c.y) == (-2, -2)
    # EXPECT: vec sub (x: 1, y: 2) (x: 3, y: 4)

    # don't canon float ops
    # As written this is (1e100 + -1e100) + 1. == 1.; grouping the last two
    # terms first would absorb the 1. into -1e100 and yield 0., so the assert
    # only holds if the float chain is left alone.
    assert f(1e100) + f(f(-1e100)) + f(f(f(1.))) == 1.
test_op_chain_canon()

# Type that defines __lt__ and no other comparison; used by
# test_inequality_canon below.
class C:
    n: int

    def __lt__(self: C, other: C):
        return self.n < other.n

# Comparisons: checks which comparison method is invoked, and with which
# operand order, depending on which side the call expression f(...) is on.
@test
def test_inequality_canon():
    # Identity function; used to make an operand a call expression.
    @pure
    def f(a): return a

    a = Vec(1,1)
    b = Vec(2,2)

    # Call operand on the left: the trace shows each operator called as
    # written, with operands in source order.
    assert not (f(a) == b)
    assert f(a) != b
    assert f(a) < b
    assert f(a) <= b
    assert not (f(a) > b)
    assert not (f(a) >= b)
    # EXPECT: vec eq (x: 1, y: 1) (x: 2, y: 2)
    # EXPECT: vec ne (x: 1, y: 1) (x: 2, y: 2)
    # EXPECT: vec lt (x: 1, y: 1) (x: 2, y: 2)
    # EXPECT: vec le (x: 1, y: 1) (x: 2, y: 2)
    # EXPECT: vec gt (x: 1, y: 1) (x: 2, y: 2)
    # EXPECT: vec ge (x: 1, y: 1) (x: 2, y: 2)

    # Call operand on the right: the trace shows the operands swapped, with
    # < / <= appearing as gt / ge and > / >= appearing as lt / le (eq and ne
    # keep their names). The truth values asserted are unchanged.
    assert not (a == f(b))
    assert a != f(b)
    assert a < f(b)
    assert a <= f(b)
    assert not (a > f(b))
    assert not (a >= f(b))
    # EXPECT: vec eq (x: 2, y: 2) (x: 1, y: 1)
    # EXPECT: vec ne (x: 2, y: 2) (x: 1, y: 1)
    # EXPECT: vec gt (x: 2, y: 2) (x: 1, y: 1)
    # EXPECT: vec ge (x: 2, y: 2) (x: 1, y: 1)
    # EXPECT: vec lt (x: 2, y: 2) (x: 1, y: 1)
    # EXPECT: vec le (x: 2, y: 2) (x: 1, y: 1)

    c1 = C(1)
    c2 = C(2)
    # ensure we don't use missing ops
    # Swapping the operands as in the Vec case above would require C.__gt__,
    # which C does not define; this must keep using C.__lt__.
    assert c1 < f(c2)
test_inequality_canon()

# Mixed add/mul expressions: checks that sums of products sharing a factor are
# evaluated as a single product, as shown by the trace output.
@test
def test_add_mul_canon():
    # Identity function; used to make an operand a call expression.
    @pure
    def f(a): return a

    a = Vec(1,1)
    b = Vec(2,2)
    c = Vec(3,3)

    # Common factor a on both terms: one add followed by one mul.
    d = (a*f(b) + c*a)  # -> (f(b) + c) * a
    assert (d.x, d.y) == (5, 5)
    # EXPECT: vec add (x: 2, y: 2) (x: 3, y: 3)
    # EXPECT: vec mul (x: 5, y: 5) (x: 1, y: 1)

    # A bare 'a' term counts as a*1: the trace shows the scalar 1 being added
    # to c (the Vec + scalar overload) before the multiply.
    d = (a + c*a)  # -> (c + 1) * a
    assert (d.x, d.y) == (4, 4)
    # EXPECT: vec add (x: 3, y: 3) 1
    # EXPECT: vec mul (x: 4, y: 4) (x: 1, y: 1)

    # Same as above with the terms in the other order; same expected trace.
    d = (c*a + a)  # -> (c + 1) * a
    assert (d.x, d.y) == (4, 4)
    # EXPECT: vec add (x: 3, y: 3) 1
    # EXPECT: vec mul (x: 4, y: 4) (x: 1, y: 1)

    # Five copies of a: the trace shows a single multiply by 5 and no adds.
    a = Vec(1,1)
    b = a + a + a + a + a
    assert (b.x, b.y) == (5, 5)
    # EXPECT: vec mul (x: 1, y: 1) 5

    # Coefficients 1+2+3+4+5 are combined: a single multiply by 15, no adds.
    a = Vec(1,1)
    b = a + a*2 + a*3 + a*4 + a*5
    assert (b.x, b.y) == (15, 15)
    # EXPECT: vec mul (x: 1, y: 1) 15

    # Factoring x out would compute x * (0.1 + 0.2); 0.1 + 0.2 is not exactly
    # 0.3 in binary floating point, so the equality below presumably only holds
    # when the two products are computed separately as written.
    x = f(100.)  # don't distribute float ops
    assert (x * 0.1) + (x * 0.2) == 30.
test_add_mul_canon()