diff --git a/stdlib/internal/types/complex.codon b/stdlib/internal/types/complex.codon index d78d8ee1..ba4bbedd 100644 --- a/stdlib/internal/types/complex.codon +++ b/stdlib/internal/types/complex.codon @@ -1,5 +1,10 @@ # Copyright (C) 2022-2023 Exaloop Inc. +@tuple +class complex64: + real: float32 + imag: float32 + @tuple class complex: real: float @@ -284,3 +289,280 @@ class int: class float: def __suffix_j__(x: float) -> complex: return complex(0, x) + +f32 = float32 + +@extend +class complex64: + def __new__() -> complex64: + return (f32(0.0), f32(0.0)) + + def __new__(other): + return complex64(other.__complex__()) + + def __new__(other: complex) -> complex64: + return (f32(other.real), f32(other.imag)) + + def __new__(real, imag) -> complex64: + return (f32(float(real)), f32(float(imag))) + + def __complex__(self) -> complex: + return complex(float(self.real), float(self.imag)) + + def __bool__(self) -> bool: + return self.real != f32(0.0) and self.imag != f32(0.0) + + def __pos__(self) -> complex64: + return self + + def __neg__(self) -> complex64: + return complex64(-self.real, -self.imag) + + def __abs__(self) -> f32: + @pure + @C + def hypotf(a: f32, b: f32) -> f32: + pass + + return hypotf(self.real, self.imag) + + def __copy__(self) -> complex64: + return self + + def __hash__(self) -> int: + return self.real.__hash__() + self.imag.__hash__() * 1000003 + + def __add__(self, other) -> complex64: + return self + complex64(other) + + def __sub__(self, other) -> complex64: + return self - complex64(other) + + def __mul__(self, other) -> complex64: + return self * complex64(other) + + def __truediv__(self, other) -> complex64: + return self / complex64(other) + + def __eq__(self, other) -> bool: + return self == complex64(other) + + def __ne__(self, other) -> bool: + return self != complex64(other) + + def __pow__(self, other) -> complex64: + return self ** complex64(other) + + def __radd__(self, other) -> complex64: + return complex64(other) + self + + def __rsub__(self, other) -> complex64: + return complex64(other) - self + + def __rmul__(self, other) -> complex64: + return complex64(other) * self + + def __rtruediv__(self, other) -> complex64: + return complex64(other) / self + + def __rpow__(self, other) -> complex64: + return complex64(other) ** self + + def __add__(self, other: complex64) -> complex64: + return complex64(self.real + other.real, self.imag + other.imag) + + def __sub__(self, other: complex64) -> complex64: + return complex64(self.real - other.real, self.imag - other.imag) + + def __mul__(self, other: complex64) -> complex64: + a = (self.real * other.real) - (self.imag * other.imag) + b = (self.real * other.imag) + (self.imag * other.real) + return complex64(a, b) + + def __truediv__(self, other: complex64) -> complex64: + a = self + b = other + abs_breal = (-b.real) if b.real < f32(0) else b.real + abs_bimag = (-b.imag) if b.imag < f32(0) else b.imag + + if abs_breal >= abs_bimag: + # divide tops and bottom by b.real + if abs_breal == f32(0.0): + # errno = EDOM + return complex64(0.0, 0.0) + else: + ratio = b.imag / b.real + denom = b.real + b.imag * ratio + return complex64( + (a.real + a.imag * ratio) / denom, (a.imag - a.real * ratio) / denom + ) + elif abs_bimag >= abs_breal: + # divide tops and bottom by b.imag + ratio = b.real / b.imag + denom = b.real * ratio + b.imag + # assert b.imag != 0.0 + return complex64( + (a.real * ratio + a.imag) / denom, (a.imag * ratio - a.real) / denom + ) + else: + nan = 0.0 / 0.0 + return complex64(nan, nan) + + def __eq__(self, other: complex64) -> bool: + return self.real == other.real and self.imag == other.imag + + def __ne__(self, other: complex64) -> bool: + return not (self == other) + + def __pow__(self, other: int) -> complex64: + def powu(x: complex64, n: int) -> complex64: + mask = 1 + r = complex64(1.0, 0.0) + p = x + while mask > 0 and n >= mask: + if n & mask: + r = r * p + mask <<= 1 + p = p * p + return r + + if other > 0: + return powu(self, other) + else: + return complex64(1.0, 0.0) / powu(self, -other) + + def __pow__(self, other: complex64) -> complex64: + @pure + @C + def hypotf(a: f32, b: f32) -> f32: + pass + + @pure + @C + def atan2f(a: f32, b: f32) -> f32: + pass + + @pure + @llvm + def exp(x: f32) -> f32: + declare float @llvm.exp.f32(float) + %y = call float @llvm.exp.f32(float %x) + ret float %y + + @pure + @llvm + def pow(x: f32, y: f32) -> f32: + declare float @llvm.pow.f32(float, float) + %z = call float @llvm.pow.f32(float %x, float %y) + ret float %z + + @pure + @llvm + def log(x: f32) -> f32: + declare float @llvm.log.f32(float) + %y = call float @llvm.log.f32(float %x) + ret float %y + + @pure + @llvm + def sin(x: f32) -> f32: + declare float @llvm.sin.f32(float) + %y = call float @llvm.sin.f32(float %x) + ret float %y + + @pure + @llvm + def cos(x: f32) -> f32: + declare float @llvm.cos.f32(float) + %y = call float @llvm.cos.f32(float %x) + ret float %y + + if other.real == f32(0.0) and other.imag == f32(0.0): + return complex64(1.0, 0.0) + elif self.real == f32(0.0) and self.imag == f32(0.0): + # if other.imag != 0. or other.real < 0.: errno = EDOM + return complex64(0.0, 0.0) + else: + vabs = hypotf(self.real, self.imag) + len = pow(vabs, other.real) + at = atan2f(self.imag, self.real) + phase = at * other.real + if other.imag != f32(0.0): + len /= exp(at * other.imag) + phase += other.imag * log(vabs) + return complex64(len * cos(phase), len * sin(phase)) + + def __repr__(self) -> str: + @pure + @llvm + def copysign(x: f32, y: f32) -> f32: + declare float @llvm.copysign.f32(float, float) + %z = call float @llvm.copysign.f32(float %x, float %y) + ret float %z + + @pure + @llvm + def fabs(x: f32) -> f32: + declare float @llvm.fabs.f32(float) + %y = call float @llvm.fabs.f32(float %x) + ret float %y + + if self.real == f32(0.0) and copysign(f32(1.0), self.real) == f32(1.0): + return f"complex64({self.imag}j)" + else: + sign = "+" + if self.imag < f32(0.0) or ( + self.imag == f32(0.0) and copysign(f32(1.0), self.imag) == f32(-1.0) + ): + sign = "-" + return f"complex64({self.real}{sign}{fabs(self.imag)}j)" + + def conjugate(self) -> complex64: + return complex64(self.real, -self.imag) + + # helpers + def _phase(self) -> f32: + @pure + @C + def atan2f(a: f32, b: f32) -> f32: + pass + + return atan2f(self.imag, self.real) + + def _polar(self) -> Tuple[f32, f32]: + return (self.__abs__(), self._phase()) + + @pure + @llvm + def _exp(x: f32) -> f32: + declare float @llvm.exp.f32(float) + %y = call float @llvm.exp.f32(float %x) + ret float %y + + @pure + @llvm + def _sqrt(x: f32) -> f32: + declare float @llvm.sqrt.f32(float) + %y = call float @llvm.sqrt.f32(float %x) + ret float %y + + @pure + @llvm + def _cos(x: f32) -> f32: + declare float @llvm.cos.f32(float) + %y = call float @llvm.cos.f32(float %x) + ret float %y + + @pure + @llvm + def _sin(x: f32) -> f32: + declare float @llvm.sin.f32(float) + %y = call float @llvm.sin.f32(float %x) + ret float %y + + @pure + @llvm + def _log(x: f32) -> f32: + declare float @llvm.log.f32(float) + %y = call float @llvm.log.f32(float %x) + ret float %y diff --git a/test/stdlib/cmath_test.codon b/test/stdlib/cmath_test.codon index 1d2da3bf..440e3693 100644 --- a/test/stdlib/cmath_test.codon +++ b/test/stdlib/cmath_test.codon @@ -794,3 +794,32 @@ def test_cmath_testcases(): test_cmath_testcases() + + +@test +def test_complex64(): + c64 = complex64 + z = c64(.5 + .5j) + assert +z == z + assert -z == c64(-.5 - .5j) + assert abs(z) == float32(0.7071067811865476) + assert z + 1 == c64(1.5 + .5j) + assert 1j + z == c64(.5 + 1.5j) + assert z * 2 == c64(1 + 1j) + assert 2j * z == c64(-1 + 1j) + assert z / .5 == c64(1 + 1j) + assert 1j / z == c64(1 + 1j) + assert z ** 2 == c64(.5j) + y = 1j ** z + assert math.isclose(float(y.real), 0.32239694194483454) + assert math.isclose(float(y.imag), 0.32239694194483454) + assert z != -z + assert z != 0 + assert z.real == float32(.5) + assert (z + 1j).imag == float32(1.5) + assert z.conjugate() == c64(.5 - .5j) + assert z.__copy__() == z + assert hash(z) + assert c64(complex(z)) == z + +test_complex64()