From 56c00d36c2abe76afc18ac0e9e52e12a330a96a0 Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Thu, 6 Feb 2025 14:11:52 -0500 Subject: [PATCH] Add additional int-float operators --- stdlib/internal/types/float.codon | 414 ++++++++++++++++++++++++++++++ test/core/arithmetic.codon | 52 ++++ 2 files changed, 466 insertions(+) diff --git a/stdlib/internal/types/float.codon b/stdlib/internal/types/float.codon index 584514bf..15e3601e 100644 --- a/stdlib/internal/types/float.codon +++ b/stdlib/internal/types/float.codon @@ -4,6 +4,22 @@ from internal.attributes import commutative from internal.gc import alloc_atomic, free from internal.types.complex import complex +def _float_int_pow(a: F, b: int, F: type) -> F: + abs_exp = b.__abs__() + result = F(1) + factor = a + + while abs_exp: + if abs_exp & 1: + result *= factor + factor *= factor + abs_exp >>= 1 + + if b < 0: + result = F(1) / result + + return result + @extend class float: def __new__() -> float: @@ -401,6 +417,50 @@ class float: def imag(self) -> float: return 0.0 + @commutative + def __add__(self: float, b: int) -> float: + return self + float(b) + + def __sub__(self: float, b: int) -> float: + return self - float(b) + + @commutative + def __mul__(self: float, b: int) -> float: + return self * float(b) + + def __floordiv__(self, b: int) -> float: + return self // float(b) + + def __truediv__(self: float, b: int) -> float: + return self / float(b) + + def __mod__(self: float, b: int) -> float: + return self % float(b) + + def __divmod__(self, b: int): + return self.__divmod__(float(b)) + + def __eq__(self: float, b: int) -> bool: + return self == float(b) + + def __ne__(self: float, b: int) -> bool: + return self != float(b) + + def __lt__(self: float, b: int) -> bool: + return self < float(b) + + def __gt__(self: float, b: int) -> bool: + return self > float(b) + + def __le__(self: float, b: int) -> bool: + return self <= float(b) + + def __ge__(self: float, b: int) -> bool: + return self >= float(b) + + def __pow__(self: float, b: int) -> float: + return _float_int_pow(self, b) + @extend class float32: @pure @@ -755,6 +815,50 @@ class float32: def __match__(self, obj: float32) -> bool: return self == obj + @commutative + def __add__(self: float32, b: int) -> float32: + return self + float32(b) + + def __sub__(self: float32, b: int) -> float32: + return self - float32(b) + + @commutative + def __mul__(self: float32, b: int) -> float32: + return self * float32(b) + + def __floordiv__(self, b: int) -> float32: + return self // float32(b) + + def __truediv__(self: float32, b: int) -> float32: + return self / float32(b) + + def __mod__(self: float32, b: int) -> float32: + return self % float32(b) + + def __divmod__(self, b: int): + return self.__divmod__(float32(b)) + + def __eq__(self: float32, b: int) -> bool: + return self == float32(b) + + def __ne__(self: float32, b: int) -> bool: + return self != float32(b) + + def __lt__(self: float32, b: int) -> bool: + return self < float32(b) + + def __gt__(self: float32, b: int) -> bool: + return self > float32(b) + + def __le__(self: float32, b: int) -> bool: + return self <= float32(b) + + def __ge__(self: float32, b: int) -> bool: + return self >= float32(b) + + def __pow__(self: float32, b: int) -> float32: + return _float_int_pow(self, b) + @extend class float16: @pure @@ -1055,6 +1159,50 @@ class float16: def __match__(self, obj: float16) -> bool: return self == obj + @commutative + def __add__(self: float16, b: int) -> float16: + return self + float16(b) + + def __sub__(self: float16, b: int) -> float16: + return self - float16(b) + + @commutative + def __mul__(self: float16, b: int) -> float16: + return self * float16(b) + + def __floordiv__(self, b: int) -> float16: + return self // float16(b) + + def __truediv__(self: float16, b: int) -> float16: + return self / float16(b) + + def __mod__(self: float16, b: int) -> float16: + return self % float16(b) + + def __divmod__(self, b: int): + return self.__divmod__(float16(b)) + + def __eq__(self: float16, b: int) -> bool: + return self == float16(b) + + def __ne__(self: float16, b: int) -> bool: + return self != float16(b) + + def __lt__(self: float16, b: int) -> bool: + return self < float16(b) + + def __gt__(self: float16, b: int) -> bool: + return self > float16(b) + + def __le__(self: float16, b: int) -> bool: + return self <= float16(b) + + def __ge__(self: float16, b: int) -> bool: + return self >= float16(b) + + def __pow__(self: float16, b: int) -> float16: + return _float_int_pow(self, b) + @extend class bfloat16: @pure @@ -1355,6 +1503,50 @@ class bfloat16: def __match__(self, obj: bfloat16) -> bool: return self == obj + @commutative + def __add__(self: bfloat16, b: int) -> bfloat16: + return self + bfloat16(b) + + def __sub__(self: bfloat16, b: int) -> bfloat16: + return self - bfloat16(b) + + @commutative + def __mul__(self: bfloat16, b: int) -> bfloat16: + return self * bfloat16(b) + + def __floordiv__(self, b: int) -> bfloat16: + return self // bfloat16(b) + + def __truediv__(self: bfloat16, b: int) -> bfloat16: + return self / bfloat16(b) + + def __mod__(self: bfloat16, b: int) -> bfloat16: + return self % bfloat16(b) + + def __divmod__(self, b: int): + return self.__divmod__(bfloat16(b)) + + def __eq__(self: bfloat16, b: int) -> bool: + return self == bfloat16(b) + + def __ne__(self: bfloat16, b: int) -> bool: + return self != bfloat16(b) + + def __lt__(self: bfloat16, b: int) -> bool: + return self < bfloat16(b) + + def __gt__(self: bfloat16, b: int) -> bool: + return self > bfloat16(b) + + def __le__(self: bfloat16, b: int) -> bool: + return self <= bfloat16(b) + + def __ge__(self: bfloat16, b: int) -> bool: + return self >= bfloat16(b) + + def __pow__(self: bfloat16, b: int) -> bfloat16: + return _float_int_pow(self, b) + @extend class float128: @pure @@ -1652,6 +1844,50 @@ class float128: def __match__(self, obj: float128) -> bool: return self == obj + @commutative + def __add__(self: float128, b: int) -> float128: + return self + float128(b) + + def __sub__(self: float128, b: int) -> float128: + return self - float128(b) + + @commutative + def __mul__(self: float128, b: int) -> float128: + return self * float128(b) + + def __floordiv__(self, b: int) -> float128: + return self // float128(b) + + def __truediv__(self: float128, b: int) -> float128: + return self / float128(b) + + def __mod__(self: float128, b: int) -> float128: + return self % float128(b) + + def __divmod__(self, b: int): + return self.__divmod__(float128(b)) + + def __eq__(self: float128, b: int) -> bool: + return self == float128(b) + + def __ne__(self: float128, b: int) -> bool: + return self != float128(b) + + def __lt__(self: float128, b: int) -> bool: + return self < float128(b) + + def __gt__(self: float128, b: int) -> bool: + return self > float128(b) + + def __le__(self: float128, b: int) -> bool: + return self <= float128(b) + + def __ge__(self: float128, b: int) -> bool: + return self >= float128(b) + + def __pow__(self: float128, b: int) -> float128: + return _float_int_pow(self, b) + @extend class float: def __suffix_f32__(double) -> float32: @@ -1666,6 +1902,184 @@ class float: def __suffix_f128__(double) -> float128: return float128.__new__(double) +@extend +class int: + @commutative + def __add__(self, b: float32) -> float32: + return float32(self) + b + + def __sub__(self, b: float32) -> float32: + return float32(self) - b + + @commutative + def __mul__(self, b: float32) -> float32: + return float32(self) * b + + def __floordiv__(self, b: float32) -> float32: + return float32(self) // b + + def __truediv__(self, b: float32) -> float32: + return float32(self) / b + + def __mod__(self, b: float32) -> float32: + return float32(self) % b + + def __divmod__(self, b: float32): + return float32(self).__divmod__(b) + + def __pow__(self, b: float32) -> float32: + return float32(self) ** b + + def __eq__(self, b: float32) -> bool: + return float32(self) == b + + def __ne__(self, b: float32) -> bool: + return float32(self) != b + + def __lt__(self, b: float32) -> bool: + return float32(self) < b + + def __gt__(self, b: float32) -> bool: + return float32(self) > b + + def __le__(self, b: float32) -> bool: + return float32(self) <= b + + def __ge__(self, b: float32) -> bool: + return float32(self) >= b + + @commutative + def __add__(self, b: float16) -> float16: + return float16(self) + b + + def __sub__(self, b: float16) -> float16: + return float16(self) - b + + @commutative + def __mul__(self, b: float16) -> float16: + return float16(self) * b + + def __floordiv__(self, b: float16) -> float16: + return float16(self) // b + + def __truediv__(self, b: float16) -> float16: + return float16(self) / b + + def __mod__(self, b: float16) -> float16: + return float16(self) % b + + def __divmod__(self, b: float16): + return float16(self).__divmod__(b) + + def __pow__(self, b: float16) -> float16: + return float16(self) ** b + + def __eq__(self, b: float16) -> bool: + return float16(self) == b + + def __ne__(self, b: float16) -> bool: + return float16(self) != b + + def __lt__(self, b: float16) -> bool: + return float16(self) < b + + def __gt__(self, b: float16) -> bool: + return float16(self) > b + + def __le__(self, b: float16) -> bool: + return float16(self) <= b + + def __ge__(self, b: float16) -> bool: + return float16(self) >= b + + @commutative + def __add__(self, b: bfloat16) -> bfloat16: + return bfloat16(self) + b + + def __sub__(self, b: bfloat16) -> bfloat16: + return bfloat16(self) - b + + @commutative + def __mul__(self, b: bfloat16) -> bfloat16: + return bfloat16(self) * b + + def __floordiv__(self, b: bfloat16) -> bfloat16: + return bfloat16(self) // b + + def __truediv__(self, b: bfloat16) -> bfloat16: + return bfloat16(self) / b + + def __mod__(self, b: bfloat16) -> bfloat16: + return bfloat16(self) % b + + def __divmod__(self, b: bfloat16): + return bfloat16(self).__divmod__(b) + + def __pow__(self, b: bfloat16) -> bfloat16: + return bfloat16(self) ** b + + def __eq__(self, b: bfloat16) -> bool: + return bfloat16(self) == b + + def __ne__(self, b: bfloat16) -> bool: + return bfloat16(self) != b + + def __lt__(self, b: bfloat16) -> bool: + return bfloat16(self) < b + + def __gt__(self, b: bfloat16) -> bool: + return bfloat16(self) > b + + def __le__(self, b: bfloat16) -> bool: + return bfloat16(self) <= b + + def __ge__(self, b: bfloat16) -> bool: + return bfloat16(self) >= b + + @commutative + def __add__(self, b: float128) -> float128: + return float128(self) + b + + def __sub__(self, b: float128) -> float128: + return float128(self) - b + + @commutative + def __mul__(self, b: float128) -> float128: + return float128(self) * b + + def __floordiv__(self, b: float128) -> float128: + return float128(self) // b + + def __truediv__(self, b: float128) -> float128: + return float128(self) / b + + def __mod__(self, b: float128) -> float128: + return float128(self) % b + + def __divmod__(self, b: float128): + return float128(self).__divmod__(b) + + def __pow__(self, b: float128) -> float128: + return float128(self) ** b + + def __eq__(self, b: float128) -> bool: + return float128(self) == b + + def __ne__(self, b: float128) -> bool: + return float128(self) != b + + def __lt__(self, b: float128) -> bool: + return float128(self) < b + + def __gt__(self, b: float128) -> bool: + return float128(self) > b + + def __le__(self, b: float128) -> bool: + return float128(self) <= b + + def __ge__(self, b: float128) -> bool: + return float128(self) >= b + f16 = float16 bf16 = bfloat16 f32 = float32 diff --git a/test/core/arithmetic.codon b/test/core/arithmetic.codon index 147ed4b8..e1220c77 100644 --- a/test/core/arithmetic.codon +++ b/test/core/arithmetic.codon @@ -199,3 +199,55 @@ def test_float_out_of_range_parse(): assert 1e10000 == float('inf') test_float_out_of_range_parse() + +@test +def test_int_float_ops(F: type): + def check(got, exp=True): + return (exp == got) and (type(exp) is type(got)) + + # standard + assert check(F(1.5) + 1, F(2.5)) + assert check(F(1.5) - 1, F(0.5)) + assert check(F(1.5) * 2, F(3.0)) + assert check(F(1.5) / 2, F(0.75)) + assert check(F(3.5) // 2, F(1.0)) + assert check(F(3.5) % 2, F(1.5)) + assert check(F(3.5) ** 2, F(12.25)) + assert check(divmod(F(3.5), 2), (F(1.0), F(1.5))) + + # right-hand ops + assert check(1 + F(1.5), F(2.5)) + assert check(1 - F(1.5), F(-0.5)) + assert check(2 * F(1.5), F(3.0)) + assert check(2 / F(2.5), F(0.8)) + assert check(2 // F(1.5), F(1.0)) + assert check(2 % F(1.5), F(0.5)) + assert check(4 ** F(2.5), F(32.0)) + assert check(divmod(4, F(2.5)), (F(1.0), F(1.5))) + + # comparisons + assert check(F(1.0) == 1) + assert check(F(2.0) != 1) + assert check(F(0.0) < 1) + assert check(F(2.0) > 1) + assert check(F(0.0) <= 1) + assert check(F(2.0) >= 1) + assert check(1 == F(1.0)) + assert check(1 != F(2.0)) + assert check(1 < F(2.0)) + assert check(1 > F(0.0)) + assert check(1 <= F(2.0)) + assert check(1 >= F(0.0)) + + # power + assert check(F(3.5) ** 1, F(3.5)) + assert check(F(3.5) ** 2, F(12.25)) + assert check(F(3.5) ** 3, F(42.875)) + assert check(F(4.0) ** -1, F(0.25)) + assert check(F(4.0) ** -2, F(0.0625)) + assert check(F(4.0) ** -3, F(0.015625)) + assert check(F(3.5) ** 0, F(1.0)) + +test_int_float_ops(float) +test_int_float_ops(float32) +test_int_float_ops(float16)