1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Add complex64 type

This commit is contained in:
A. R. Shajii 2023-03-12 17:03:13 -04:00
parent e1ddbf8fdd
commit ac97877edd
2 changed files with 311 additions and 0 deletions

View File

@ -1,5 +1,10 @@
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
@tuple
class complex64:
real: float32
imag: float32
@tuple
class complex:
real: float
@ -284,3 +289,280 @@ class int:
class float:
def __suffix_j__(x: float) -> complex:
return complex(0, x)
f32 = float32
@extend
class complex64:
def __new__() -> complex64:
return (f32(0.0), f32(0.0))
def __new__(other):
return complex64(other.__complex__())
def __new__(other: complex) -> complex64:
return (f32(other.real), f32(other.imag))
def __new__(real, imag) -> complex64:
return (f32(float(real)), f32(float(imag)))
def __complex__(self) -> complex:
return complex(float(self.real), float(self.imag))
def __bool__(self) -> bool:
return self.real != f32(0.0) and self.imag != f32(0.0)
def __pos__(self) -> complex64:
return self
def __neg__(self) -> complex64:
return complex64(-self.real, -self.imag)
def __abs__(self) -> f32:
@pure
@C
def hypotf(a: f32, b: f32) -> f32:
pass
return hypotf(self.real, self.imag)
def __copy__(self) -> complex64:
return self
def __hash__(self) -> int:
return self.real.__hash__() + self.imag.__hash__() * 1000003
def __add__(self, other) -> complex64:
return self + complex64(other)
def __sub__(self, other) -> complex64:
return self - complex64(other)
def __mul__(self, other) -> complex64:
return self * complex64(other)
def __truediv__(self, other) -> complex64:
return self / complex64(other)
def __eq__(self, other) -> bool:
return self == complex64(other)
def __ne__(self, other) -> bool:
return self != complex64(other)
def __pow__(self, other) -> complex64:
return self ** complex64(other)
def __radd__(self, other) -> complex64:
return complex64(other) + self
def __rsub__(self, other) -> complex64:
return complex64(other) - self
def __rmul__(self, other) -> complex64:
return complex64(other) * self
def __rtruediv__(self, other) -> complex64:
return complex64(other) / self
def __rpow__(self, other) -> complex64:
return complex64(other) ** self
def __add__(self, other: complex64) -> complex64:
return complex64(self.real + other.real, self.imag + other.imag)
def __sub__(self, other: complex64) -> complex64:
return complex64(self.real - other.real, self.imag - other.imag)
def __mul__(self, other: complex64) -> complex64:
a = (self.real * other.real) - (self.imag * other.imag)
b = (self.real * other.imag) + (self.imag * other.real)
return complex64(a, b)
def __truediv__(self, other: complex64) -> complex64:
a = self
b = other
abs_breal = (-b.real) if b.real < f32(0) else b.real
abs_bimag = (-b.imag) if b.imag < f32(0) else b.imag
if abs_breal >= abs_bimag:
# divide tops and bottom by b.real
if abs_breal == f32(0.0):
# errno = EDOM
return complex64(0.0, 0.0)
else:
ratio = b.imag / b.real
denom = b.real + b.imag * ratio
return complex64(
(a.real + a.imag * ratio) / denom, (a.imag - a.real * ratio) / denom
)
elif abs_bimag >= abs_breal:
# divide tops and bottom by b.imag
ratio = b.real / b.imag
denom = b.real * ratio + b.imag
# assert b.imag != 0.0
return complex64(
(a.real * ratio + a.imag) / denom, (a.imag * ratio - a.real) / denom
)
else:
nan = 0.0 / 0.0
return complex64(nan, nan)
def __eq__(self, other: complex64) -> bool:
return self.real == other.real and self.imag == other.imag
def __ne__(self, other: complex64) -> bool:
return not (self == other)
def __pow__(self, other: int) -> complex64:
def powu(x: complex64, n: int) -> complex64:
mask = 1
r = complex64(1.0, 0.0)
p = x
while mask > 0 and n >= mask:
if n & mask:
r = r * p
mask <<= 1
p = p * p
return r
if other > 0:
return powu(self, other)
else:
return complex64(1.0, 0.0) / powu(self, -other)
def __pow__(self, other: complex64) -> complex64:
@pure
@C
def hypotf(a: f32, b: f32) -> f32:
pass
@pure
@C
def atan2f(a: f32, b: f32) -> f32:
pass
@pure
@llvm
def exp(x: f32) -> f32:
declare float @llvm.exp.f32(float)
%y = call float @llvm.exp.f32(float %x)
ret float %y
@pure
@llvm
def pow(x: f32, y: f32) -> f32:
declare float @llvm.pow.f32(float, float)
%z = call float @llvm.pow.f32(float %x, float %y)
ret float %z
@pure
@llvm
def log(x: f32) -> f32:
declare float @llvm.log.f32(float)
%y = call float @llvm.log.f32(float %x)
ret float %y
@pure
@llvm
def sin(x: f32) -> f32:
declare float @llvm.sin.f32(float)
%y = call float @llvm.sin.f32(float %x)
ret float %y
@pure
@llvm
def cos(x: f32) -> f32:
declare float @llvm.cos.f32(float)
%y = call float @llvm.cos.f32(float %x)
ret float %y
if other.real == f32(0.0) and other.imag == f32(0.0):
return complex64(1.0, 0.0)
elif self.real == f32(0.0) and self.imag == f32(0.0):
# if other.imag != 0. or other.real < 0.: errno = EDOM
return complex64(0.0, 0.0)
else:
vabs = hypotf(self.real, self.imag)
len = pow(vabs, other.real)
at = atan2f(self.imag, self.real)
phase = at * other.real
if other.imag != f32(0.0):
len /= exp(at * other.imag)
phase += other.imag * log(vabs)
return complex64(len * cos(phase), len * sin(phase))
def __repr__(self) -> str:
@pure
@llvm
def copysign(x: f32, y: f32) -> f32:
declare float @llvm.copysign.f32(float, float)
%z = call float @llvm.copysign.f32(float %x, float %y)
ret float %z
@pure
@llvm
def fabs(x: f32) -> f32:
declare float @llvm.fabs.f32(float)
%y = call float @llvm.fabs.f32(float %x)
ret float %y
if self.real == f32(0.0) and copysign(f32(1.0), self.real) == f32(1.0):
return f"complex64({self.imag}j)"
else:
sign = "+"
if self.imag < f32(0.0) or (
self.imag == f32(0.0) and copysign(f32(1.0), self.imag) == f32(-1.0)
):
sign = "-"
return f"complex64({self.real}{sign}{fabs(self.imag)}j)"
def conjugate(self) -> complex64:
return complex64(self.real, -self.imag)
# helpers
def _phase(self) -> f32:
@pure
@C
def atan2f(a: f32, b: f32) -> f32:
pass
return atan2f(self.imag, self.real)
def _polar(self) -> Tuple[f32, f32]:
return (self.__abs__(), self._phase())
@pure
@llvm
def _exp(x: f32) -> f32:
declare float @llvm.exp.f32(float)
%y = call float @llvm.exp.f32(float %x)
ret float %y
@pure
@llvm
def _sqrt(x: f32) -> f32:
declare float @llvm.sqrt.f32(float)
%y = call float @llvm.sqrt.f32(float %x)
ret float %y
@pure
@llvm
def _cos(x: f32) -> f32:
declare float @llvm.cos.f32(float)
%y = call float @llvm.cos.f32(float %x)
ret float %y
@pure
@llvm
def _sin(x: f32) -> f32:
declare float @llvm.sin.f32(float)
%y = call float @llvm.sin.f32(float %x)
ret float %y
@pure
@llvm
def _log(x: f32) -> f32:
declare float @llvm.log.f32(float)
%y = call float @llvm.log.f32(float %x)
ret float %y

View File

@ -794,3 +794,32 @@ def test_cmath_testcases():
test_cmath_testcases()
@test
def test_complex64():
c64 = complex64
z = c64(.5 + .5j)
assert +z == z
assert -z == c64(-.5 - .5j)
assert abs(z) == float32(0.7071067811865476)
assert z + 1 == c64(1.5 + .5j)
assert 1j + z == c64(.5 + 1.5j)
assert z * 2 == c64(1 + 1j)
assert 2j * z == c64(-1 + 1j)
assert z / .5 == c64(1 + 1j)
assert 1j / z == c64(1 + 1j)
assert z ** 2 == c64(.5j)
y = 1j ** z
assert math.isclose(float(y.real), 0.32239694194483454)
assert math.isclose(float(y.imag), 0.32239694194483454)
assert z != -z
assert z != 0
assert z.real == float32(.5)
assert (z + 1j).imag == float32(1.5)
assert z.conjugate() == c64(.5 - .5j)
assert z.__copy__() == z
assert hash(z)
assert c64(complex(z)) == z
test_complex64()