Add complex tests

pull/1/head
A. R. Shajii 2021-10-04 13:07:11 -04:00
parent 587ce851c4
commit 6ed45a644c
2 changed files with 284 additions and 41 deletions

View File

@ -40,8 +40,7 @@ class complex:
return self
def __hash__(self):
# TODO
return self.real.__hash__() ^ self.imag.__hash__()
return self.real.__hash__() + self.imag.__hash__()*1000003
def __add__(self, other: complex):
return complex(self.real + other.real, self.imag + other.imag)
@ -55,10 +54,29 @@ class complex:
return complex(a, b)
def __truediv__(self, other: complex):
h = (other.real * other.real) + (other.imag * other.imag)
a = ((self.real * other.real) + (self.imag * other.imag)) / h
b = ((self.imag * other.real) - (self.real * other.imag)) / h
return complex(a, b)
a = self
b = other
abs_breal = (-b.real) if b.real < 0 else b.real
abs_bimag = (-b.imag) if b.imag < 0 else b.imag
if abs_breal >= abs_bimag:
# divide tops and bottom by b.real
if abs_breal == 0.0:
# errno = EDOM
return complex(0.0, 0.0)
else:
ratio = b.imag / b.real
denom = b.real + b.imag*ratio
return complex((a.real + a.imag*ratio)/denom, (a.imag - a.real*ratio)/denom)
elif abs_bimag >= abs_breal:
# divide tops and bottom by b.imag
ratio = b.real / b.imag
denom = b.real*ratio + b.imag
# assert b.imag != 0.0
return complex((a.real*ratio + a.imag)/denom, (a.imag*ratio - a.real)/denom)
else:
nan = 0./0.
return complex(nan, nan)
def __eq__(self, other: complex):
return self.real == other.real and self.imag == other.imag
@ -66,20 +84,81 @@ class complex:
def __ne__(self, other: complex):
return not (self == other)
def __pow__(self, other: int):
def powu(x: complex, n: int):
mask = 1
r = complex(1.0, 0.0)
p = x
while mask > 0 and n >= mask:
if n & mask:
r = r*p
mask <<= 1
p = p*p
return r
if other > 0:
return powu(self, other)
else:
return complex(1.0, 0.0) / powu(self, -other)
def __pow__(self, other: complex):
x = other.real
y = other.imag
absa = self.__abs__()
if absa == 0.0:
return complex(0.0, 0.0)
arga = self._phase()
r = absa ** x
theta = x * arga
if y != 0.0:
r = r * complex._exp(-y * arga)
theta = theta + y*complex._log(absa)
w = complex(r * complex._cos(theta), r * complex._sin(theta))
return w
@pure
@C
def hypot(a: float, b: float) -> float: pass
@pure
@C
def atan2(a: float, b: float) -> float: pass
@pure
@llvm
def exp(x: float) -> float:
declare double @llvm.exp.f64(double)
%y = call double @llvm.exp.f64(double %x)
ret double %y
@pure
@llvm
def pow(x: float, y: float) -> float:
declare double @llvm.pow.f64(double, double)
%z = call double @llvm.pow.f64(double %x, double %y)
ret double %z
@pure
@llvm
def log(x: float) -> float:
declare double @llvm.log.f64(double)
%y = call double @llvm.log.f64(double %x)
ret double %y
@pure
@llvm
def sin(x: float) -> float:
declare double @llvm.sin.f64(double)
%y = call double @llvm.sin.f64(double %x)
ret double %y
@pure
@llvm
def cos(x: float) -> float:
declare double @llvm.cos.f64(double)
%y = call double @llvm.cos.f64(double %x)
ret double %y
if other.real == 0. and other.imag == 0.:
return complex(1., 0.)
elif self.real == 0. and self.imag == 0.:
# if other.imag != 0. or other.real < 0.: errno = EDOM
return complex(0., 0.)
else:
vabs = hypot(self.real, self.imag)
len = pow(vabs, other.real)
at = atan2(self.imag, self.real)
phase = at * other.real
if other.imag != 0.:
len /= exp(at * other.imag)
phase += other.imag * log(vabs)
return complex(len * cos(phase), len * sin(phase))
def __add__(self, other):
return self + complex(other)
@ -118,13 +197,27 @@ class complex:
return complex(other) ** self
def __str__(self):
if self.real == 0.0:
@pure
@llvm
def copysign(x: float, y: float) -> float:
declare double @llvm.copysign.f64(double, double)
%z = call double @llvm.copysign.f64(double %x, double %y)
ret double %z
@pure
@llvm
def fabs(x: float) -> float:
declare double @llvm.fabs.f64(double)
%y = call double @llvm.fabs.f64(double %x)
ret double %y
if self.real == 0.0 and copysign(1., self.real) == 1.:
return f'{self.imag}j'
else:
if self.imag >= 0:
return f'{self.real}+{self.imag}j'
else:
return f'{self.real}-{-self.imag}j'
sign = '+'
if self.imag < 0.0 or (self.imag == 0.0 and copysign(1., self.imag) == -1.):
sign = '-'
return f'({self.real}{sign}{fabs(self.imag)}j)'
def conjugate(self):
return complex(self.real, -self.imag)

View File

@ -5,6 +5,173 @@ INF = float('inf')
NAN = float('nan')
j = complex(0, 1)
def float_identical(x, y):
if math.isnan(x) or math.isnan(y):
if math.isnan(x) and math.isnan(y):
return True
elif x == y:
if x != 0.0:
return True
# both zero; check that signs match
elif math.copysign(1.0, x) == math.copysign(1.0, y):
return True
else:
return False
return False
def complex_identical(x, y):
return float_identical(x.real, y.real) and float_identical(x.imag, y.imag)
###########
# complex #
###########
ZERO_DIVISION = (
(1+1*j, 0+0*j),
(1+1*j, 0.0+0*j),
(1+1*j, 0+0*j),
(1.0+0*j, 0+0*j),
(1+0*j, 0+0*j),
)
def close_abs(x, y, eps=1e-9):
"""Return true iff floats x and y "are close"."""
# put the one with larger magnitude second
if abs(x) > abs(y):
x, y = y, x
if y == 0:
return abs(x) < eps
if x == 0:
return abs(y) < eps
# check that relative difference < eps
return abs((x-y)/y) < eps
def close_complex(x, y, eps=1e-9):
a = complex(x)
b = complex(y)
return close_abs(a.real, b.real, eps) and close_abs(a.imag, b.imag, eps)
def check_div(x, y):
"""Compute complex z=x*y, and check that z/x==y and z/y==x."""
z = x * y
if x != 0:
q = z / x
if not close_complex(q, y):
return False
q = z.__truediv__(x)
if not close_complex(q, y):
return False
if y != 0:
q = z / y
if not close_complex(q, x):
return False
q = z.__truediv__(y)
if not close_complex(q, x):
return False
return True
@test
def test_truediv():
from random import random
simple_real = [float(i) for i in range(-5, 6)]
simple_complex = [complex(x, y) for x in simple_real for y in simple_real]
for x in simple_complex:
for y in simple_complex:
assert check_div(x, y)
# A naive complex division algorithm (such as in 2.0) is very prone to
# nonsense errors for these (overflows and underflows).
assert check_div(complex(1e200, 1e200), 1+0*j)
assert check_div(complex(1e-200, 1e-200), 1+0*j)
# Just for fun.
for i in range(100):
check_div(complex(random(), random()), complex(random(), random()))
assert close_complex(complex.__truediv__(2+0*j, 1+1*j), 1-1*j)
for denom_real, denom_imag in [(0., NAN), (NAN, 0.), (NAN, NAN)]:
z = complex(0, 0) / complex(denom_real, denom_imag)
assert math.isnan(z.real)
assert math.isnan(z.imag)
test_truediv()
@test
def test_richcompare():
assert not complex.__eq__(1+1*j, 1<<10000)
assert complex.__eq__(1+1*j, 1+1*j)
assert not complex.__eq__(1+1*j, 2+2*j)
assert not complex.__ne__(1+1*j, 1+1*j)
assert complex.__ne__(1+1*j, 2+2*j), True
for i in range(1, 100):
f = i / 100.0
assert complex.__eq__(f+0*j, f)
assert not complex.__ne__(f+0*j, f)
assert not complex.__eq__(complex(f, f), f)
assert complex.__ne__(complex(f, f), f)
import operator
assert operator.eq(1+1*j, 1+1*j) == True
assert operator.eq(1+1*j, 2+2*j) == False
assert operator.ne(1+1*j, 1+1*j) == False
assert operator.ne(1+1*j, 2+2*j) == True
test_richcompare()
@test
def test_pow():
def pow(a, b): return a ** b
assert close_complex(pow(1+1*j, 0+0*j), 1.0)
assert close_complex(pow(0+0*j, 2+0*j), 0.0)
assert close_complex(pow(1*j, -1), 1/(1*j))
assert close_complex(pow(1*j, 200), 1)
a = 3.33+4.43*j
assert a ** (0*j) == 1
assert a ** (0.+0.*j) == 1
assert (3*j) ** (0*j) == 1
assert (3*j) ** 0 == 1
# The following is used to exercise certain code paths
assert a ** 105 == a ** 105
assert a ** -105 == a ** -105
assert a ** -30 == a ** -30
assert (0.0*j) ** 0 == 1
test_pow()
@test
def test_conjugate():
assert close_complex(complex(5.3, 9.8).conjugate(), 5.3-9.8*j)
test_conjugate()
@test
def test_cabs():
nums = [complex(x/3., y/7.) for x in range(-9,9) for y in range(-9,9)]
for num in nums:
assert close_complex((num.real**2 + num.imag**2) ** 0.5, abs(num))
test_cabs()
@test
def test_negative_zero_repr_str():
def test(v, expected):
return str(v) == expected
assert test(complex(0., 1.), "1j")
assert test(complex(-0., 1.), "(-0+1j)")
assert test(complex(0., -1.), "-1j")
assert test(complex(-0., -1.), "(-0-1j)")
assert test(complex(0., 0.), "0j")
assert test(complex(0., -0.), "-0j")
assert test(complex(-0., 0.), "(-0+0j)")
assert test(complex(-0., -0.), "(-0-0j)")
test_negative_zero_repr_str()
#########
# cmath #
#########
complex_zeros = [complex(x, y) for x in [0.0, -0.0] for y in [0.0, -0.0]]
complex_infinities = [complex(x, y) for x, y in [
(INF, 0.0), # 1st quadrant
@ -43,23 +210,6 @@ complex_nans = [complex(x, y) for x, y in [
(INF, NAN)
]]
def float_identical(x, y):
if math.isnan(x) or math.isnan(y):
if math.isnan(x) and math.isnan(y):
return True
elif x == y:
if x != 0.0:
return True
# both zero; check that signs match
elif math.copysign(1.0, x) == math.copysign(1.0, y):
return True
else:
return False
return False
def complex_identical(x, y):
return float_identical(x.real, y.real) and float_identical(x.imag, y.imag)
@llvm
@pure
def small() -> float: