mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
Add complex64 type
This commit is contained in:
parent
e1ddbf8fdd
commit
ac97877edd
@ -1,5 +1,10 @@
|
||||
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
@tuple
|
||||
class complex64:
|
||||
real: float32
|
||||
imag: float32
|
||||
|
||||
@tuple
|
||||
class complex:
|
||||
real: float
|
||||
@ -284,3 +289,280 @@ class int:
|
||||
class float:
|
||||
def __suffix_j__(x: float) -> complex:
|
||||
return complex(0, x)
|
||||
|
||||
f32 = float32
|
||||
|
||||
@extend
|
||||
class complex64:
|
||||
def __new__() -> complex64:
|
||||
return (f32(0.0), f32(0.0))
|
||||
|
||||
def __new__(other):
|
||||
return complex64(other.__complex__())
|
||||
|
||||
def __new__(other: complex) -> complex64:
|
||||
return (f32(other.real), f32(other.imag))
|
||||
|
||||
def __new__(real, imag) -> complex64:
|
||||
return (f32(float(real)), f32(float(imag)))
|
||||
|
||||
def __complex__(self) -> complex:
|
||||
return complex(float(self.real), float(self.imag))
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.real != f32(0.0) and self.imag != f32(0.0)
|
||||
|
||||
def __pos__(self) -> complex64:
|
||||
return self
|
||||
|
||||
def __neg__(self) -> complex64:
|
||||
return complex64(-self.real, -self.imag)
|
||||
|
||||
def __abs__(self) -> f32:
|
||||
@pure
|
||||
@C
|
||||
def hypotf(a: f32, b: f32) -> f32:
|
||||
pass
|
||||
|
||||
return hypotf(self.real, self.imag)
|
||||
|
||||
def __copy__(self) -> complex64:
|
||||
return self
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.real.__hash__() + self.imag.__hash__() * 1000003
|
||||
|
||||
def __add__(self, other) -> complex64:
|
||||
return self + complex64(other)
|
||||
|
||||
def __sub__(self, other) -> complex64:
|
||||
return self - complex64(other)
|
||||
|
||||
def __mul__(self, other) -> complex64:
|
||||
return self * complex64(other)
|
||||
|
||||
def __truediv__(self, other) -> complex64:
|
||||
return self / complex64(other)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return self == complex64(other)
|
||||
|
||||
def __ne__(self, other) -> bool:
|
||||
return self != complex64(other)
|
||||
|
||||
def __pow__(self, other) -> complex64:
|
||||
return self ** complex64(other)
|
||||
|
||||
def __radd__(self, other) -> complex64:
|
||||
return complex64(other) + self
|
||||
|
||||
def __rsub__(self, other) -> complex64:
|
||||
return complex64(other) - self
|
||||
|
||||
def __rmul__(self, other) -> complex64:
|
||||
return complex64(other) * self
|
||||
|
||||
def __rtruediv__(self, other) -> complex64:
|
||||
return complex64(other) / self
|
||||
|
||||
def __rpow__(self, other) -> complex64:
|
||||
return complex64(other) ** self
|
||||
|
||||
def __add__(self, other: complex64) -> complex64:
|
||||
return complex64(self.real + other.real, self.imag + other.imag)
|
||||
|
||||
def __sub__(self, other: complex64) -> complex64:
|
||||
return complex64(self.real - other.real, self.imag - other.imag)
|
||||
|
||||
def __mul__(self, other: complex64) -> complex64:
|
||||
a = (self.real * other.real) - (self.imag * other.imag)
|
||||
b = (self.real * other.imag) + (self.imag * other.real)
|
||||
return complex64(a, b)
|
||||
|
||||
def __truediv__(self, other: complex64) -> complex64:
|
||||
a = self
|
||||
b = other
|
||||
abs_breal = (-b.real) if b.real < f32(0) else b.real
|
||||
abs_bimag = (-b.imag) if b.imag < f32(0) else b.imag
|
||||
|
||||
if abs_breal >= abs_bimag:
|
||||
# divide tops and bottom by b.real
|
||||
if abs_breal == f32(0.0):
|
||||
# errno = EDOM
|
||||
return complex64(0.0, 0.0)
|
||||
else:
|
||||
ratio = b.imag / b.real
|
||||
denom = b.real + b.imag * ratio
|
||||
return complex64(
|
||||
(a.real + a.imag * ratio) / denom, (a.imag - a.real * ratio) / denom
|
||||
)
|
||||
elif abs_bimag >= abs_breal:
|
||||
# divide tops and bottom by b.imag
|
||||
ratio = b.real / b.imag
|
||||
denom = b.real * ratio + b.imag
|
||||
# assert b.imag != 0.0
|
||||
return complex64(
|
||||
(a.real * ratio + a.imag) / denom, (a.imag * ratio - a.real) / denom
|
||||
)
|
||||
else:
|
||||
nan = 0.0 / 0.0
|
||||
return complex64(nan, nan)
|
||||
|
||||
def __eq__(self, other: complex64) -> bool:
|
||||
return self.real == other.real and self.imag == other.imag
|
||||
|
||||
def __ne__(self, other: complex64) -> bool:
|
||||
return not (self == other)
|
||||
|
||||
def __pow__(self, other: int) -> complex64:
|
||||
def powu(x: complex64, n: int) -> complex64:
|
||||
mask = 1
|
||||
r = complex64(1.0, 0.0)
|
||||
p = x
|
||||
while mask > 0 and n >= mask:
|
||||
if n & mask:
|
||||
r = r * p
|
||||
mask <<= 1
|
||||
p = p * p
|
||||
return r
|
||||
|
||||
if other > 0:
|
||||
return powu(self, other)
|
||||
else:
|
||||
return complex64(1.0, 0.0) / powu(self, -other)
|
||||
|
||||
def __pow__(self, other: complex64) -> complex64:
|
||||
@pure
|
||||
@C
|
||||
def hypotf(a: f32, b: f32) -> f32:
|
||||
pass
|
||||
|
||||
@pure
|
||||
@C
|
||||
def atan2f(a: f32, b: f32) -> f32:
|
||||
pass
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def exp(x: f32) -> f32:
|
||||
declare float @llvm.exp.f32(float)
|
||||
%y = call float @llvm.exp.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def pow(x: f32, y: f32) -> f32:
|
||||
declare float @llvm.pow.f32(float, float)
|
||||
%z = call float @llvm.pow.f32(float %x, float %y)
|
||||
ret float %z
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log(x: f32) -> f32:
|
||||
declare float @llvm.log.f32(float)
|
||||
%y = call float @llvm.log.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def sin(x: f32) -> f32:
|
||||
declare float @llvm.sin.f32(float)
|
||||
%y = call float @llvm.sin.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def cos(x: f32) -> f32:
|
||||
declare float @llvm.cos.f32(float)
|
||||
%y = call float @llvm.cos.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
if other.real == f32(0.0) and other.imag == f32(0.0):
|
||||
return complex64(1.0, 0.0)
|
||||
elif self.real == f32(0.0) and self.imag == f32(0.0):
|
||||
# if other.imag != 0. or other.real < 0.: errno = EDOM
|
||||
return complex64(0.0, 0.0)
|
||||
else:
|
||||
vabs = hypotf(self.real, self.imag)
|
||||
len = pow(vabs, other.real)
|
||||
at = atan2f(self.imag, self.real)
|
||||
phase = at * other.real
|
||||
if other.imag != f32(0.0):
|
||||
len /= exp(at * other.imag)
|
||||
phase += other.imag * log(vabs)
|
||||
return complex64(len * cos(phase), len * sin(phase))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@pure
|
||||
@llvm
|
||||
def copysign(x: f32, y: f32) -> f32:
|
||||
declare float @llvm.copysign.f32(float, float)
|
||||
%z = call float @llvm.copysign.f32(float %x, float %y)
|
||||
ret float %z
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def fabs(x: f32) -> f32:
|
||||
declare float @llvm.fabs.f32(float)
|
||||
%y = call float @llvm.fabs.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
if self.real == f32(0.0) and copysign(f32(1.0), self.real) == f32(1.0):
|
||||
return f"complex64({self.imag}j)"
|
||||
else:
|
||||
sign = "+"
|
||||
if self.imag < f32(0.0) or (
|
||||
self.imag == f32(0.0) and copysign(f32(1.0), self.imag) == f32(-1.0)
|
||||
):
|
||||
sign = "-"
|
||||
return f"complex64({self.real}{sign}{fabs(self.imag)}j)"
|
||||
|
||||
def conjugate(self) -> complex64:
|
||||
return complex64(self.real, -self.imag)
|
||||
|
||||
# helpers
|
||||
def _phase(self) -> f32:
|
||||
@pure
|
||||
@C
|
||||
def atan2f(a: f32, b: f32) -> f32:
|
||||
pass
|
||||
|
||||
return atan2f(self.imag, self.real)
|
||||
|
||||
def _polar(self) -> Tuple[f32, f32]:
|
||||
return (self.__abs__(), self._phase())
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _exp(x: f32) -> f32:
|
||||
declare float @llvm.exp.f32(float)
|
||||
%y = call float @llvm.exp.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _sqrt(x: f32) -> f32:
|
||||
declare float @llvm.sqrt.f32(float)
|
||||
%y = call float @llvm.sqrt.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _cos(x: f32) -> f32:
|
||||
declare float @llvm.cos.f32(float)
|
||||
%y = call float @llvm.cos.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _sin(x: f32) -> f32:
|
||||
declare float @llvm.sin.f32(float)
|
||||
%y = call float @llvm.sin.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _log(x: f32) -> f32:
|
||||
declare float @llvm.log.f32(float)
|
||||
%y = call float @llvm.log.f32(float %x)
|
||||
ret float %y
|
||||
|
@ -794,3 +794,32 @@ def test_cmath_testcases():
|
||||
|
||||
|
||||
test_cmath_testcases()
|
||||
|
||||
|
||||
@test
|
||||
def test_complex64():
|
||||
c64 = complex64
|
||||
z = c64(.5 + .5j)
|
||||
assert +z == z
|
||||
assert -z == c64(-.5 - .5j)
|
||||
assert abs(z) == float32(0.7071067811865476)
|
||||
assert z + 1 == c64(1.5 + .5j)
|
||||
assert 1j + z == c64(.5 + 1.5j)
|
||||
assert z * 2 == c64(1 + 1j)
|
||||
assert 2j * z == c64(-1 + 1j)
|
||||
assert z / .5 == c64(1 + 1j)
|
||||
assert 1j / z == c64(1 + 1j)
|
||||
assert z ** 2 == c64(.5j)
|
||||
y = 1j ** z
|
||||
assert math.isclose(float(y.real), 0.32239694194483454)
|
||||
assert math.isclose(float(y.imag), 0.32239694194483454)
|
||||
assert z != -z
|
||||
assert z != 0
|
||||
assert z.real == float32(.5)
|
||||
assert (z + 1j).imag == float32(1.5)
|
||||
assert z.conjugate() == c64(.5 - .5j)
|
||||
assert z.__copy__() == z
|
||||
assert hash(z)
|
||||
assert c64(complex(z)) == z
|
||||
|
||||
test_complex64()
|
||||
|
Loading…
x
Reference in New Issue
Block a user