codon/stdlib/numpy/util.codon

3018 lines
79 KiB
Python

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
from .npdatetime import datetime64, timedelta64, _promote as dt_promote, \
_parse_datetime_type as dt_parse
##############
# Exceptions #
##############
class AxisError(Static[Exception]):
    # Exception raised for an out-of-range axis argument (see
    # normalize_axis_index); it is registered under the type name
    # "numpy.AxisError".
    def __init__(self, message: str = ''):
        super().__init__("numpy.AxisError", message)
class TooHardError(Static[Exception]):
    # Exception registered under the type name "numpy.TooHardError".
    def __init__(self, message: str = ''):
        super().__init__("numpy.TooHardError", message)
##############
# Singletons #
##############
# Field-less tuple class; presumably used as a "no value supplied" marker
# for optional arguments -- confirm against callers.
@tuple
class _NoValue:
    pass
###################
# Tuples / Shapes #
###################
def tuple_get(tup, idx):
    """
    Return element `idx` of `tup`, where `idx` may be a runtime value.

    The tuple's storage is viewed as an array of the first element's
    type, so the read is only meaningful when every element has that
    type. No bounds check is performed.
    """
    elems = Ptr[type(tup[0])](__ptr__(tup).as_byte())
    return elems[idx]
def tuple_set(tup, idx, elm, init = 0):
    """
    Return a copy of `tup` with position `idx` replaced by `elm`.

    `init` supplies the element type (and placeholder value) of the
    result; `tup` is read as an array of that same type.
    """
    result = (init,) * staticlen(tup)
    src = Ptr[type(init)](__ptr__(tup).as_byte())
    dst = Ptr[type(init)](__ptr__(result).as_byte())
    for pos in range(staticlen(tup)):
        if pos != idx:
            dst[pos] = src[pos]
        else:
            dst[pos] = elm
    return result
def tuple_add(tup, idx, inc, init = 0):
    """
    Return a copy of `tup` with `inc` added to the element at `idx`.

    `init` supplies the element type (and placeholder value) of the
    result; `tup` is read as an array of that same type.
    """
    result = (init,) * staticlen(tup)
    src = Ptr[type(init)](__ptr__(tup).as_byte())
    dst = Ptr[type(init)](__ptr__(result).as_byte())
    for pos in range(staticlen(tup)):
        if pos != idx:
            dst[pos] = src[pos]
        else:
            dst[pos] = src[pos] + inc
    return result
def tuple_delete(tup, idx, init = 0):
    """
    Return a tuple one element shorter than `tup`, omitting position `idx`.

    If `idx` never matches a source position that is visited, the result
    is simply the first len(tup) - 1 elements. `init` supplies the
    element type of the result.
    """
    result = (init,) * (staticlen(tup) - 1)
    src = Ptr[type(init)](__ptr__(tup).as_byte())
    dst = Ptr[type(init)](__ptr__(result).as_byte())
    read = 0
    for write in range(staticlen(tup) - 1):
        if read == idx:
            # Step over the slot being removed.
            read += 1
        dst[write] = src[read]
        read += 1
    return result
def tuple_insert(tup, idx, elm, init = 0):
    """
    Return a tuple one element longer than `tup`, with `elm` placed at
    position `idx` and the original elements kept in order around it.

    `init` supplies the element type of the result.
    """
    result = (init,) * (staticlen(tup) + 1)
    src = Ptr[type(init)](__ptr__(tup).as_byte())
    dst = Ptr[type(init)](__ptr__(result).as_byte())
    read = 0
    for write in range(staticlen(tup) + 1):
        if write == idx:
            dst[write] = elm
            continue
        dst[write] = src[read]
        read += 1
    return result
def tuple_swap(tup, idx1, idx2, init = 0):
    """
    Return a copy of `tup` with the elements at `idx1` and `idx2`
    exchanged. `init` only supplies the element type used to address
    the tuple's storage.
    """
    swapped = tup
    elems = Ptr[type(init)](__ptr__(swapped).as_byte())
    elems[idx1], elems[idx2] = elems[idx2], elems[idx1]
    return swapped
def tuple_perm(tup, perm):
    """
    Reorder `tup` so that output position i holds tup[perm[i]].

    A `perm` of None returns `tup` unchanged, as does a tuple of static
    length <= 1. Differing static lengths are a compile-time error.
    """
    if perm is None:
        return tup
    if staticlen(tup) != staticlen(perm):
        compile_error("[internal error] tuple_perm got different tuple lengths")
    if staticlen(tup) <= 1:
        return tup
    return tuple(tuple_get(tup, p) for p in perm)
def tuple_perm_inv(perm):
    """
    Return the inverse of the integer permutation `perm`, i.e. a tuple
    `iperm` with iperm[perm[i]] == i for every i.
    """
    iperm = (0,) * staticlen(perm)
    p = Ptr[int](__ptr__(perm).as_byte())
    q = Ptr[int](__ptr__(iperm).as_byte())
    for i in staticrange(staticlen(perm)):
        # Scatter: the value perm[i] names the slot that receives i.
        q[p[i]] = i
    return iperm
def tuple_apply(fn, tup1, tup2):
    """
    Apply `fn` element-wise to two equal-length tuples and return the
    results as a tuple of the same type as `tup1`.

    Both tuples are addressed as arrays of tup1's first element type.
    Differing static lengths are a compile-time error; empty input
    yields ().
    """
    if staticlen(tup1) != staticlen(tup2):
        compile_error("[internal error] tuple_apply got different tuple lengths")
    if staticlen(tup1) == 0:
        return ()
    T = type(tup1[0])
    # Start from a copy of tup1 and overwrite every slot below.
    tup3 = tup1
    p1 = Ptr[T](__ptr__(tup1).as_byte())
    p2 = Ptr[T](__ptr__(tup2).as_byte())
    p3 = Ptr[T](__ptr__(tup3).as_byte())
    for i in range(len(tup1)):
        p3[i] = fn(p1[i], p2[i])
    return tup3
def tuple_insert_tuple(tup, idx, ins, init = 0):
    """
    Return `tup` with all elements of `ins` spliced in before position
    `idx`: tup[:idx] + ins + tup[idx:].

    `init` supplies the element type of the result; `tup` and `ins` are
    both read as arrays of that type.
    """
    tup2 = (init,) * (staticlen(tup) + staticlen(ins))
    p1 = Ptr[type(init)](__ptr__(tup).as_byte())
    p2 = Ptr[type(init)](__ptr__(ins).as_byte())
    q = Ptr[type(init)](__ptr__(tup2).as_byte())
    # j is the write cursor into the output across the three segments.
    j = 0
    for i in range(idx):
        q[j] = p1[i]
        j += 1
    for i in range(len(ins)):
        q[j] = p2[i]
        j += 1
    for i in range(idx, len(tup)):
        q[j] = p1[i]
        j += 1
    return tup2
def count(shape):
    """
    Return the product of all entries of `shape` (the number of elements
    of an array with that shape); an empty shape gives 1.
    """
    if staticlen(shape) == 0:
        return 1
    n_elems = 1
    axis = 0
    while axis < staticlen(shape):
        n_elems = n_elems * shape[axis]
        axis += 1
    return n_elems
def tuple_range(n: Static[int]):
    """
    Return the tuple (0, 1, ..., n - 1) for a static `n`, built by
    recursion on `n`; n <= 0 gives ().
    """
    if n <= 0:
        return ()
    return (*tuple_range(n - 1), n - 1)
def tuple_equal(t1, t2):
    """
    Compare two tuples for equality; tuples of different static length
    are reported unequal without evaluating `==` on them.
    """
    if staticlen(t1) != staticlen(t2):
        return False
    return t1 == t2
def broadcast(shape1, shape2):
    """
    Compute the broadcast of two shape tuples.

    The shapes are aligned on their trailing dimensions. Leading
    dimensions present only in the longer shape are copied through; for
    each aligned pair, a dimension of 1 takes the other shape's value,
    and otherwise the two must match.

    Raises ValueError if an aligned pair differs and neither is 1.
    """
    D1: Static[int] = staticlen(shape1)
    D2: Static[int] = staticlen(shape2)
    Dmin: Static[int] = D1 if D1 < D2 else D2
    Dmax: Static[int] = D1 if D1 > D2 else D2
    Diff: Static[int] = Dmax - Dmin
    # Trailing Dmin dimensions of each shape, which are compared pairwise.
    t1 = shape1[-Dmin:]
    t2 = shape2[-Dmin:]
    ans = (0,) * Dmax
    p = Ptr[int](__ptr__(ans).as_byte())
    # Leading dimensions come straight from whichever shape is longer.
    for i in staticrange(Diff):
        if D1 > D2:
            p[i] = shape1[i]
        else:
            p[i] = shape2[i]
    for i in staticrange(Dmin):
        a = t1[i]
        b = t2[i]
        dim = a
        if a == 1:
            dim = b
        elif b != 1 and a != b:
            raise ValueError(f'operands could not be broadcast together with shapes {shape1} {shape2}')
        p[i + Diff] = dim
    return ans
###########
# Indexes #
###########
def normalize_index(idx: int, axis: int, n: int):
    """
    Map a possibly negative index into the range [0, n).

    Negative values count from the end (n is added once). `axis` is only
    used in the error message.

    Raises IndexError if the adjusted index is still outside [0, n).
    """
    idx0 = idx
    pos = idx0 + n if idx0 < 0 else idx0
    if pos < 0 or pos >= n:
        raise IndexError(f'index {idx0} is out of bounds for axis {axis} with size {n}')
    return pos
def normalize_axis_index(axis: int, ndim: int, prefix: str = ''):
    """
    Map an axis in [-ndim, ndim) to its non-negative equivalent.

    Raises AxisError when `axis` lies outside that range; a non-empty
    `prefix` is prepended to the message followed by ': '.
    """
    if not (-ndim <= axis < ndim):
        raise AxisError(f"{prefix}{': ' if prefix else ''}axis {axis} is out of bounds for array of dimension {ndim}")
    return axis + ndim if axis < 0 else axis
def has_duplicate(t):
    """
    Return True if any two entries of the indexable sequence `t` compare
    equal, False otherwise.
    """
    # use O(n^2) method since this ndim should be small
    for later in range(1, len(t)):
        candidate = t[later]
        for earlier in range(later):
            if candidate == t[earlier]:
                return True
    return False
def normalize_axis_tuple(axis, ndim: int, argname: str = '', allow_duplicates: bool = False):
    """
    Normalize an axis specification against an array of `ndim` dimensions.

    `axis` may be an int, a tuple, or a List[int]; every entry is passed
    through normalize_axis_index with `argname` as the message prefix.
    A tuple or int input yields a tuple, a list input yields a list.
    Any other type is a compile-time error.

    Raises ValueError on repeated axes unless `allow_duplicates` is set.
    """
    if isinstance(axis, Tuple):
        t = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
    elif isinstance(axis, List[int]):
        t = [normalize_axis_index(ax, ndim, argname) for ax in axis]
    elif isinstance(axis, int):
        t = (normalize_axis_index(axis, ndim, argname),)
    else:
        compile_error("invalid axis argument: type must be int, tuple of int, or list of int")
    if not allow_duplicates and has_duplicate(t):
        if argname:
            raise ValueError(f"repeated axis in `{argname}` argument")
        else:
            raise ValueError("repeated axis")
    return t
def reconstruct_index(t1, t2, mask):
    """
    Interleave two integer tuples according to a boolean `mask`.

    Output position i takes the next unused element of `t1` where
    mask[i] is true, and the next unused element of `t2` otherwise.
    The static lengths of `t1` and `t2` must sum to that of `mask`
    (checked at compile time).
    """
    if staticlen(t1) + staticlen(t2) != staticlen(mask):
        compile_error("[internal error] bad index reconstruction")
    t = (0,) * staticlen(mask)
    p1 = Ptr[int](__ptr__(t1).as_byte())
    p2 = Ptr[int](__ptr__(t2).as_byte())
    p = Ptr[int](__ptr__(t).as_byte())
    m = Ptr[bool](__ptr__(mask).as_byte())
    # k1 / k2 are the read cursors into t1 / t2.
    k1, k2 = 0, 0
    for i in range(staticlen(mask)):
        if m[i]:
            p[i] = p1[k1]
            k1 += 1
        else:
            p[i] = p2[k2]
            k2 += 1
    return t
def index_to_coords(index: int, limits):
    """
    Convert a flat `index` into per-axis coordinates for extents
    `limits`, with the last axis varying fastest. Empty `limits`
    gives ().
    """
    if staticlen(limits) == 0:
        return ()
    coords = (0,) * staticlen(limits)
    pcoords = Ptr[int](__ptr__(coords).as_byte())
    # Peel off axes from last to first: remainder is the coordinate,
    # quotient carries on to the next axis.
    for i in range(staticlen(coords) - 1, -1, -1):
        d, m = divmod(index, limits[i])
        pcoords[i] = m
        index = d
    return coords
def index_to_fcoords(index: int, limits):
    """
    Convert a flat `index` into per-axis coordinates for extents
    `limits`, with the first axis varying fastest. Empty `limits`
    gives ().
    """
    if staticlen(limits) == 0:
        return ()
    coords = (0,) * staticlen(limits)
    pcoords = Ptr[int](__ptr__(coords).as_byte())
    # Peel off axes from first to last: remainder is the coordinate,
    # quotient carries on to the next axis.
    for i in range(staticlen(coords)):
        d, m = divmod(index, limits[i])
        pcoords[i] = m
        index = d
    return coords
def coords_to_index(coords: S, limits: S, S: type):
    """
    Convert per-axis `coords` into a flat index for extents `limits`,
    with the last axis varying fastest (inverse of index_to_coords).
    `coords` and `limits` must have the same type `S`.
    """
    # s is the running stride: product of the extents already visited.
    s = 1
    idx = 0
    for i in staticrange(staticlen(coords) - 1, -1, -1):
        idx += coords[i] * s
        s *= limits[i]
    return idx
def coords_to_findex(coords: S, limits: S, S: type):
    """
    Convert per-axis `coords` into a flat index for extents `limits`,
    with the first axis varying fastest (inverse of index_to_fcoords).
    `coords` and `limits` must have the same type `S`.
    """
    # s is the running stride: product of the extents already visited.
    s = 1
    idx = 0
    for i in staticrange(staticlen(coords)):
        idx += coords[i] * s
        s *= limits[i]
    return idx
#######################
# Sorting / Searching #
#######################
def sort(p: Ptr[T], n: int, key, T: type):
    """
    Sort the `n` elements starting at `p` with pdq_sort_array, ordering
    by `key`. The pointer is wrapped in an Array rather than copied, so
    the sort presumably happens in place -- confirm against
    algorithms.pdqsort.
    """
    from algorithms.pdqsort import pdq_sort_array
    pdq_sort_array(Array(p, n), n, key)
def sort_by_stride(shape, strides):
    """
    Return (shape, strides) reordered so that strides are in
    non-increasing order of absolute value, with `shape` permuted
    identically. Inputs with at most one dimension are returned as-is.
    """
    if staticlen(strides) <= 1:
        return shape, strides
    # Work on copies; the pointers below alias these local tuples.
    strides_sorted = strides
    shape_sorted = shape
    pstrides = Ptr[int](__ptr__(strides_sorted).as_byte())
    pshape = Ptr[int](__ptr__(shape_sorted).as_byte())
    n = len(strides)
    # Bubble sort: each pass moves the smallest remaining |stride| to the
    # end, so the active range shrinks by one per pass.
    while True:
        swapped = False
        for i in range(1, n):
            if abs(pstrides[i - 1]) < abs(pstrides[i]):
                tmp = pstrides[i]
                pstrides[i] = pstrides[i - 1]
                pstrides[i - 1] = tmp
                tmp = pshape[i]
                pshape[i] = pshape[i - 1]
                pshape[i - 1] = tmp
                swapped = True
        if not swapped:
            break
        n -= 1
    return shape_sorted, strides_sorted
# Following is adapted from Numba - github.com/numba/numba
def _partition(A, low: int, high: int):
    """
    Partition A[low..high] (inclusive) around a median-of-three pivot.

    On return, the pivot sits at the returned index, everything before
    it within the range is <= pivot, and everything after is > pivot.
    """
    mid = (low + high) >> 1
    # Median of three: order A[low], A[mid], A[high].
    if A[mid] < A[low]:
        held = A[low]
        A[low] = A[mid]
        A[mid] = held
    if A[high] < A[mid]:
        held = A[high]
        A[high] = A[mid]
        A[mid] = held
    if A[mid] < A[low]:
        held = A[low]
        A[low] = A[mid]
        A[mid] = held
    pivot = A[mid]
    # Park the pivot at the end of the range while scanning.
    held = A[high]
    A[high] = A[mid]
    A[mid] = held
    store = low
    scan = low
    while scan < high:
        if A[scan] <= pivot:
            held = A[store]
            A[store] = A[scan]
            A[scan] = held
            store += 1
        scan += 1
    # Move the pivot into its final position.
    held = A[store]
    A[store] = A[high]
    A[high] = held
    return store

def _select(arry, k: int, low: int, high: int):
    """
    Rearrange arry[low..high] so that arry[k] holds the value it would
    have if the range were sorted, and return that value.
    """
    pivot_at = _partition(arry, low, high)
    while pivot_at != k:
        if pivot_at < k:
            low = pivot_at + 1
        else:
            high = pivot_at - 1
        pivot_at = _partition(arry, low, high)
    return arry[k]

def _select_two(arry, k: int, low: int, high: int):
    """
    Like _select, but places both the k-th and (k + 1)-th order
    statistics and returns them as a pair.
    """
    while True:
        # assert high > low # by construction
        pivot_at = _partition(arry, low, high)
        if pivot_at < k:
            low = pivot_at + 1
            continue
        if pivot_at > k + 1:
            high = pivot_at - 1
            continue
        if pivot_at == k:
            # k is placed; find k + 1 in the right part.
            _select(arry, k + 1, pivot_at + 1, high)
        else:
            # pivot_at == k + 1 is placed; find k in the left part.
            _select(arry, k, low, pivot_at - 1)
        break
    return arry[k], arry[k + 1]

def median(arr, n: int):
    """
    Return the two middle values of the first `n` entries of `arr` as a
    pair, partially reordering `arr` in the process.

    For odd `n` both entries of the pair are the single middle value;
    for even `n` they are the lower and upper middle values.
    """
    last = n - 1
    middle = n >> 1
    if (n & 1) != 0:
        m = _select(arr, middle, 0, last)
        return (m, m)
    lower, upper = _select_two(arr, middle - 1, 0, last)
    return (lower, upper)
#############
# Iterators #
#############
def multirange(limits):
    """
    Generate every coordinate tuple within `limits`, with the last
    coordinate varying fastest.

    Dimensions 0 through 6 are written out as explicit nested loops;
    higher dimensions loop over the first axis and recurse on the rest.
    """
    if staticlen(limits) == 0:
        yield ()
    elif staticlen(limits) == 1:
        x = limits[0]
        for i in range(x):
            yield (i,)
    elif staticlen(limits) == 2:
        x, y = limits
        for i in range(x):
            for j in range(y):
                yield (i, j)
    elif staticlen(limits) == 3:
        x, y, z = limits
        for i in range(x):
            for j in range(y):
                for k in range(z):
                    yield (i, j, k)
    elif staticlen(limits) == 4:
        n0, n1, n2, n3 = limits
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    for i3 in range(n3):
                        yield (i0, i1, i2, i3)
    elif staticlen(limits) == 5:
        n0, n1, n2, n3, n4 = limits
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    for i3 in range(n3):
                        for i4 in range(n4):
                            yield (i0, i1, i2, i3, i4)
    elif staticlen(limits) == 6:
        n0, n1, n2, n3, n4, n5 = limits
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    for i3 in range(n3):
                        for i4 in range(n4):
                            for i5 in range(n5):
                                yield (i0, i1, i2, i3, i4, i5)
    else:
        # General case: first axis outermost, remaining axes by recursion.
        n = limits[0]
        for i in range(n):
            for idx in multirange(limits[1:]):
                yield (i,) + idx
def fmultirange(limits):
    """
    Generate every coordinate tuple within `limits`, with the first
    coordinate varying fastest.

    Each yielded tuple is in axis order: entry k ranges over
    [0, limits[k]). Dimensions 0 through 6 are written out as explicit
    nested loops; higher dimensions recurse on the trailing axes and
    loop over the first axis innermost.
    """
    if staticlen(limits) == 0:
        yield ()
    elif staticlen(limits) == 1:
        x = limits[0]
        for i in range(x):
            yield (i,)
    elif staticlen(limits) == 2:
        x, y = limits
        for j in range(y):
            for i in range(x):
                yield (i, j)
    elif staticlen(limits) == 3:
        x, y, z = limits
        for k in range(z):
            for j in range(y):
                for i in range(x):
                    yield (i, j, k)
    elif staticlen(limits) == 4:
        n0, n1, n2, n3 = limits
        for i3 in range(n3):
            for i2 in range(n2):
                for i1 in range(n1):
                    for i0 in range(n0):
                        yield (i0, i1, i2, i3)
    elif staticlen(limits) == 5:
        n0, n1, n2, n3, n4 = limits
        for i4 in range(n4):
            for i3 in range(n3):
                for i2 in range(n2):
                    for i1 in range(n1):
                        for i0 in range(n0):
                            yield (i0, i1, i2, i3, i4)
    elif staticlen(limits) == 6:
        n0, n1, n2, n3, n4, n5 = limits
        for i5 in range(n5):
            for i4 in range(n4):
                for i3 in range(n3):
                    for i2 in range(n2):
                        for i1 in range(n1):
                            for i0 in range(n0):
                                yield (i0, i1, i2, i3, i4, i5)
    else:
        # General case: trailing axes by recursion (outer), first axis
        # innermost. The first-axis coordinate must stay in slot 0 so the
        # tuple matches the axis order of `limits`.
        n = limits[0]
        for idx in fmultirange(limits[1:]):
            for i in range(n):
                yield (i,) + idx
########
# Math #
########
# Return the LLVM value of `x` unchanged but typed as `D`. No conversion
# instruction is emitted, so `T` and `D` must lower to the same LLVM type.
@pure
@llvm
def noop(x: T, D: type, T: type) -> D:
    ret {=D} %x
# Return the all-zero value of type `T` (LLVM `zeroinitializer`).
@pure
@llvm
def zero(T: type) -> T:
    ret {=T} zeroinitializer
# Positive infinity as a 64-bit float.
@pure
@llvm
def inf64() -> float:
    ret double 0x7FF0000000000000

# Positive infinity as a 32-bit float (constant written in the
# double-width hex form).
@pure
@llvm
def inf32() -> float32:
    ret float 0x7FF0000000000000

# Positive infinity as a 16-bit float (constant written in the
# double-width hex form).
@pure
@llvm
def inf16() -> float16:
    ret half 0x7FF0000000000000

def inf(T: type):
    """
    Return positive infinity of float type `T` (float, float32 or
    float16). Other types fall through without a return value.
    """
    if T is float:
        return inf64()
    elif T is float32:
        return inf32()
    elif T is float16:
        return inf16()
# Quiet NaN as a 64-bit float.
@pure
@llvm
def nan64() -> float:
    ret double 0x7FF8000000000000

# Quiet NaN as a 32-bit float (constant written in the double-width
# hex form).
@pure
@llvm
def nan32() -> float32:
    ret float 0x7FF8000000000000

# Quiet NaN as a 16-bit float (constant written in the double-width
# hex form).
@pure
@llvm
def nan16() -> float16:
    ret half 0x7FF8000000000000

def nan(T: type):
    """
    Return a NaN of float type `T` (float, float32 or float16). Other
    types fall through without a return value.
    """
    if T is float:
        return nan64()
    elif T is float32:
        return nan32()
    elif T is float16:
        return nan16()
# Smallest positive normal 64-bit float, 2**-1022.
@pure
@llvm
def minnum64() -> float:
    ret double 0x10000000000000

# Smallest positive normal 32-bit float, 2**-126 (constant written in
# the double-width hex form).
@pure
@llvm
def minnum32() -> float32:
    ret float 0x3810000000000000

# Smallest positive normal 16-bit float, 2**-14 (bit pattern 0x0400).
# This previously returned 0xHFBFF (-65504, the most negative finite
# half), which did not match the 64- and 32-bit variants.
@pure
@llvm
def minnum16() -> float16:
    ret half 0xH0400

def minnum(T: type):
    """
    Return the smallest positive normal value of float type `T`
    (float, float32 or float16). Other types fall through without a
    return value.
    """
    if T is float:
        return minnum64()
    elif T is float32:
        return minnum32()
    elif T is float16:
        return minnum16()
# Largest finite 64-bit float.
@pure
@llvm
def maxnum64() -> float:
    ret double 0x7FEFFFFFFFFFFFFF

# Largest finite 32-bit float (constant written in the double-width
# hex form).
@pure
@llvm
def maxnum32() -> float32:
    ret float 0x47EFFFFFE0000000

# Largest finite 16-bit float, 65504 (bit pattern 0x7BFF).
@pure
@llvm
def maxnum16() -> float16:
    ret half 0xH7BFF

def maxnum(T: type):
    """
    Return the largest finite value of float type `T` (float, float32
    or float16). Other types fall through without a return value.
    """
    if T is float:
        return maxnum64()
    elif T is float32:
        return maxnum32()
    elif T is float16:
        return maxnum16()
# Constant 2**-52 as a 64-bit float.
@pure
@llvm
def eps64() -> float:
    ret double 0x3CB0000000000000

# Constant 2**-23 as a 32-bit float (written in the double-width hex
# form).
@pure
@llvm
def eps32() -> float32:
    ret float 0x3E80000000000000

# Constant 2**-10 as a 16-bit float (bit pattern 0x1400).
@pure
@llvm
def eps16() -> float16:
    ret half 0xH1400

def eps(T: type):
    """
    Return the epsilon constant of float type `T` (float, float32 or
    float16): 2**-52, 2**-23 and 2**-10 respectively. Other types fall
    through without a return value.
    """
    if T is float:
        return eps64()
    elif T is float32:
        return eps32()
    elif T is float16:
        return eps16()
def mantdig64():
    """Number of significand bits of a 64-bit float."""
    return 53

def mantdig32():
    """Number of significand bits of a 32-bit float."""
    return 24

def mantdig16():
    """Number of significand bits of a 16-bit float."""
    return 11

def mantdig(T: type):
    """
    Return the number of significand bits of float type `T` (float,
    float32 or float16). Other types fall through without a return
    value.
    """
    if T is float:
        return mantdig64()
    if T is float32:
        return mantdig32()
    if T is float16:
        return mantdig16()
def maxexp64():
    """Maximum binary exponent constant for a 64-bit float."""
    return 1024

def maxexp32():
    """Maximum binary exponent constant for a 32-bit float."""
    return 128

def maxexp16():
    """Maximum binary exponent constant for a 16-bit float."""
    return 16

def maxexp(T: type):
    """
    Return the maximum binary exponent constant of float type `T`
    (float, float32 or float16). Other types fall through without a
    return value.
    """
    if T is float:
        return maxexp64()
    if T is float32:
        return maxexp32()
    if T is float16:
        return maxexp16()
# Generic one-instruction LLVM conversions. Each takes a value `x` of
# type `T` and returns it converted to the explicitly supplied type `D`.
# No checking is done here; the operand/result types must be valid for
# the underlying LLVM instruction.

# Reinterpret the bits of `x` as type `D` (LLVM `bitcast`).
@pure
@llvm
def bitcast(x: T, D: type, T: type) -> D:
    %y = bitcast {=T} %x to {=D}
    ret {=D} %y

# Unsigned integer to floating point (LLVM `uitofp`).
@pure
@llvm
def uitofp(x: T, D: type, T: type) -> D:
    %y = uitofp {=T} %x to {=D}
    ret {=D} %y

# Signed integer to floating point (LLVM `sitofp`).
@pure
@llvm
def sitofp(x: T, D: type, T: type) -> D:
    %y = sitofp {=T} %x to {=D}
    ret {=D} %y

# Floating point to unsigned integer (LLVM `fptoui`).
@pure
@llvm
def fptoui(x: T, D: type, T: type) -> D:
    %y = fptoui {=T} %x to {=D}
    ret {=D} %y

# Floating point to signed integer (LLVM `fptosi`).
@pure
@llvm
def fptosi(x: T, D: type, T: type) -> D:
    %y = fptosi {=T} %x to {=D}
    ret {=D} %y

# Zero-extend an integer to a wider integer type (LLVM `zext`).
@pure
@llvm
def zext(x: T, D: type, T: type) -> D:
    %y = zext {=T} %x to {=D}
    ret {=D} %y

# Sign-extend an integer to a wider integer type (LLVM `sext`).
@pure
@llvm
def sext(x: T, D: type, T: type) -> D:
    %y = sext {=T} %x to {=D}
    ret {=D} %y

# Truncate an integer to a narrower integer type (LLVM `trunc`).
@pure
@llvm
def itrunc(x: T, D: type, T: type) -> D:
    %y = trunc {=T} %x to {=D}
    ret {=D} %y

# Extend a float to a wider float type (LLVM `fpext`).
@pure
@llvm
def fpext(x: T, D: type, T: type) -> D:
    %y = fpext {=T} %x to {=D}
    ret {=D} %y

# Truncate a float to a narrower float type (LLVM `fptrunc`).
@pure
@llvm
def fptrunc(x: T, D: type, T: type) -> D:
    %y = fptrunc {=T} %x to {=D}
    ret {=D} %y
# Minimum of two 64-bit floats via the llvm.minimum.f64 intrinsic.
@pure
@llvm
def fmin64(x: float, y: float) -> float:
    declare double @llvm.minimum.f64(double, double)
    %z = call double @llvm.minimum.f64(double %x, double %y)
    ret double %z

# Minimum of two 32-bit floats via the llvm.minimum.f32 intrinsic.
@pure
@llvm
def fmin32(x: float32, y: float32) -> float32:
    declare float @llvm.minimum.f32(float, float)
    %z = call float @llvm.minimum.f32(float %x, float %y)
    ret float %z

# Minimum of two 16-bit floats via the llvm.minimum.f16 intrinsic.
@pure
@llvm
def fmin16(x: float16, y: float16) -> float16:
    declare half @llvm.minimum.f16(half, half)
    %z = call half @llvm.minimum.f16(half %x, half %y)
    ret half %z

def fmin(x, y):
    """
    Minimum of `x` and `y` using the llvm.minimum intrinsic of the
    matching width. Both arguments must be the same float type (float,
    float32 or float16); other combinations fall through without a
    return value.
    """
    if isinstance(x, float) and isinstance(y, float):
        return fmin64(x, y)
    elif isinstance(x, float32) and isinstance(y, float32):
        return fmin32(x, y)
    elif isinstance(x, float16) and isinstance(y, float16):
        return fmin16(x, y)
# Maximum of two 64-bit floats via the llvm.maximum.f64 intrinsic.
@pure
@llvm
def fmax64(x: float, y: float) -> float:
    declare double @llvm.maximum.f64(double, double)
    %z = call double @llvm.maximum.f64(double %x, double %y)
    ret double %z

# Maximum of two 32-bit floats via the llvm.maximum.f32 intrinsic.
@pure
@llvm
def fmax32(x: float32, y: float32) -> float32:
    declare float @llvm.maximum.f32(float, float)
    %z = call float @llvm.maximum.f32(float %x, float %y)
    ret float %z

# Maximum of two 16-bit floats via the llvm.maximum.f16 intrinsic.
@pure
@llvm
def fmax16(x: float16, y: float16) -> float16:
    declare half @llvm.maximum.f16(half, half)
    %z = call half @llvm.maximum.f16(half %x, half %y)
    ret half %z

def fmax(x, y):
    """
    Maximum of `x` and `y` using the llvm.maximum intrinsic of the
    matching width. Both arguments must be the same float type (float,
    float32 or float16); other combinations fall through without a
    return value.
    """
    if isinstance(x, float) and isinstance(y, float):
        return fmax64(x, y)
    elif isinstance(x, float32) and isinstance(y, float32):
        return fmax32(x, y)
    elif isinstance(x, float16) and isinstance(y, float16):
        return fmax16(x, y)
# Minimum of two 64-bit floats via the llvm.minimumnum.f64 intrinsic.
@pure
@llvm
def fminnum64(x: float, y: float) -> float:
    declare double @llvm.minimumnum.f64(double, double)
    %z = call double @llvm.minimumnum.f64(double %x, double %y)
    ret double %z

# Minimum of two 32-bit floats via the llvm.minimumnum.f32 intrinsic.
@pure
@llvm
def fminnum32(x: float32, y: float32) -> float32:
    declare float @llvm.minimumnum.f32(float, float)
    %z = call float @llvm.minimumnum.f32(float %x, float %y)
    ret float %z

# Minimum of two 16-bit floats via the llvm.minimumnum.f16 intrinsic.
@pure
@llvm
def fminnum16(x: float16, y: float16) -> float16:
    declare half @llvm.minimumnum.f16(half, half)
    %z = call half @llvm.minimumnum.f16(half %x, half %y)
    ret half %z

def fminnum(x, y):
    """
    Minimum of `x` and `y` using the llvm.minimumnum intrinsic of the
    matching width. Both arguments must be the same float type (float,
    float32 or float16); other combinations fall through without a
    return value.
    """
    if isinstance(x, float) and isinstance(y, float):
        return fminnum64(x, y)
    elif isinstance(x, float32) and isinstance(y, float32):
        return fminnum32(x, y)
    elif isinstance(x, float16) and isinstance(y, float16):
        return fminnum16(x, y)
# Maximum of two 64-bit floats via the llvm.maximumnum.f64 intrinsic.
@pure
@llvm
def fmaxnum64(x: float, y: float) -> float:
    declare double @llvm.maximumnum.f64(double, double)
    %z = call double @llvm.maximumnum.f64(double %x, double %y)
    ret double %z

# Maximum of two 32-bit floats via the llvm.maximumnum.f32 intrinsic.
@pure
@llvm
def fmaxnum32(x: float32, y: float32) -> float32:
    declare float @llvm.maximumnum.f32(float, float)
    %z = call float @llvm.maximumnum.f32(float %x, float %y)
    ret float %z

# Maximum of two 16-bit floats via the llvm.maximumnum.f16 intrinsic.
@pure
@llvm
def fmaxnum16(x: float16, y: float16) -> float16:
    declare half @llvm.maximumnum.f16(half, half)
    %z = call half @llvm.maximumnum.f16(half %x, half %y)
    ret half %z

def fmaxnum(x, y):
    """
    Maximum of `x` and `y` using the llvm.maximumnum intrinsic of the
    matching width. Both arguments must be the same float type (float,
    float32 or float16); other combinations fall through without a
    return value.
    """
    if isinstance(x, float) and isinstance(y, float):
        return fmaxnum64(x, y)
    elif isinstance(x, float32) and isinstance(y, float32):
        return fmaxnum32(x, y)
    elif isinstance(x, float16) and isinstance(y, float16):
        return fmaxnum16(x, y)
def _fp16_op_via_fp32(x: float16, op) -> float16:
    """
    Evaluate the unary float32 function `op` on a float16 value by
    widening the input to float32 and narrowing the result back.
    """
    widened = fpext(x, float32)
    return fptrunc(op(widened), float16)
def _fp16_op_via_fp32_2(x: float16, y: float16, op) -> float16:
    """
    Evaluate the binary float32 function `op` on two float16 values by
    widening both inputs to float32 and narrowing the result back.
    """
    wide_x = fpext(x, float32)
    wide_y = fpext(y, float32)
    return fptrunc(op(wide_x, wide_y), float16)
# Reverse the byte order of `x` using the llvm.bswap intrinsic
# instantiated for type `T`.
@pure
@llvm
def bswap(x: T, T: type) -> T:
    declare {=T} @llvm.bswap.{=T}({=T})
    %y = call {=T} @llvm.bswap.{=T}({=T} %x)
    ret {=T} %y
# True if |x| equals +inf (ordered compare, so NaN yields False).
@pure
@llvm
def isinf64(x: float) -> bool:
    declare double @llvm.fabs.f64(double)
    %a = call double @llvm.fabs.f64(double %x)
    %y = fcmp oeq double %a, 0x7FF0000000000000
    %z = zext i1 %y to i8
    ret i8 %z

# 32-bit variant of isinf64.
@pure
@llvm
def isinf32(x: float32) -> bool:
    declare float @llvm.fabs.f32(float)
    %a = call float @llvm.fabs.f32(float %x)
    %y = fcmp oeq float %a, 0x7FF0000000000000
    %z = zext i1 %y to i8
    ret i8 %z

# 16-bit variant of isinf64.
@pure
@llvm
def isinf16(x: float16) -> bool:
    declare half @llvm.fabs.f16(half)
    %a = call half @llvm.fabs.f16(half %x)
    %y = fcmp oeq half %a, 0x7FF0000000000000
    %z = zext i1 %y to i8
    ret i8 %z

def isinf(x):
    """
    Return True if `x` is positive or negative infinity. `x` must be a
    float, float32 or float16; other types fall through without a
    return value.
    """
    if isinstance(x, float):
        return isinf64(x)
    elif isinstance(x, float32):
        return isinf32(x)
    elif isinstance(x, float16):
        return isinf16(x)
# True if x is NaN: an "unordered" compare against 0.0 is true exactly
# when an operand is NaN.
@pure
@llvm
def isnan64(x: float) -> bool:
    %y = fcmp uno double %x, 0.000000e+00
    %z = zext i1 %y to i8
    ret i8 %z

# 32-bit variant of isnan64.
@pure
@llvm
def isnan32(x: float32) -> bool:
    %y = fcmp uno float %x, 0.000000e+00
    %z = zext i1 %y to i8
    ret i8 %z

# 16-bit variant of isnan64.
@pure
@llvm
def isnan16(x: float16) -> bool:
    %y = fcmp uno half %x, 0.000000e+00
    %z = zext i1 %y to i8
    ret i8 %z

def isnan(x):
    """
    Return True if `x` is NaN. `x` must be a float, float32 or float16;
    other types fall through without a return value.
    """
    if isinstance(x, float):
        return isnan64(x)
    elif isinstance(x, float32):
        return isnan32(x)
    elif isinstance(x, float16):
        return isnan16(x)
# True if the sign bit of x is set: the bits are reinterpreted as a
# signed integer of the same width and tested for < 0.
@pure
@llvm
def signbit64(x: float) -> bool:
    %n = bitcast double %x to i64
    %s = icmp slt i64 %n, 0
    %b = zext i1 %s to i8
    ret i8 %b

# 32-bit variant of signbit64.
@pure
@llvm
def signbit32(x: float32) -> bool:
    %n = bitcast float %x to i32
    %s = icmp slt i32 %n, 0
    %b = zext i1 %s to i8
    ret i8 %b

# 16-bit variant of signbit64.
@pure
@llvm
def signbit16(x: float16) -> bool:
    %n = bitcast half %x to i16
    %s = icmp slt i16 %n, 0
    %b = zext i1 %s to i8
    ret i8 %b

def signbit(x):
    """
    Return True if the sign bit of `x` is set. `x` must be a float,
    float32 or float16; other types fall through without a return
    value.
    """
    if isinstance(x, float):
        return signbit64(x)
    elif isinstance(x, float32):
        return signbit32(x)
    elif isinstance(x, float16):
        return signbit16(x)
# Absolute value of a 64-bit float via the llvm.fabs.f64 intrinsic.
@pure
@llvm
def fabs64(x: float) -> float:
    declare double @llvm.fabs.f64(double)
    %y = call double @llvm.fabs.f64(double %x)
    ret double %y

# Absolute value of a 32-bit float via the llvm.fabs.f32 intrinsic.
@pure
@llvm
def fabs32(x: float32) -> float32:
    declare float @llvm.fabs.f32(float)
    %y = call float @llvm.fabs.f32(float %x)
    ret float %y

# Absolute value of a 16-bit float via the llvm.fabs.f16 intrinsic.
@pure
@llvm
def fabs16(x: float16) -> float16:
    declare half @llvm.fabs.f16(half)
    %y = call half @llvm.fabs.f16(half %x)
    ret half %y

def fabs(x):
    """
    Return the absolute value of `x`. `x` must be a float, float32 or
    float16; other types fall through without a return value.
    """
    if isinstance(x, float):
        return fabs64(x)
    elif isinstance(x, float32):
        return fabs32(x)
    elif isinstance(x, float16):
        return fabs16(x)
# "Ordered and not equal" comparison (LLVM `fcmp one`): true only when
# neither operand is NaN and x != y.
@pure
@llvm
def _ordered_not_equal64(x: float, y: float) -> bool:
    %tmp = fcmp one double %x, %y
    %res = zext i1 %tmp to i8
    ret i8 %res

# 32-bit variant of _ordered_not_equal64.
@pure
@llvm
def _ordered_not_equal32(x: float32, y: float32) -> bool:
    %tmp = fcmp one float %x, %y
    %res = zext i1 %tmp to i8
    ret i8 %res

# 16-bit variant of _ordered_not_equal64.
@pure
@llvm
def _ordered_not_equal16(x: float16, y: float16) -> bool:
    %tmp = fcmp one half %x, %y
    %res = zext i1 %tmp to i8
    ret i8 %res
def isfinite64(x: float):
    """True if the 64-bit float `x` is neither infinite nor NaN."""
    magnitude = fabs64(x)
    return _ordered_not_equal64(magnitude, inf64())

def isfinite32(x: float32):
    """True if the 32-bit float `x` is neither infinite nor NaN."""
    magnitude = fabs32(x)
    return _ordered_not_equal32(magnitude, inf32())

def isfinite16(x: float16):
    """True if the 16-bit float `x` is neither infinite nor NaN."""
    magnitude = fabs16(x)
    return _ordered_not_equal16(magnitude, inf16())

def isfinite(x):
    """
    Return True if `x` is finite: |x| is compared "ordered and not
    equal" against +inf, so both infinities and NaN give False. `x`
    must be a float, float32 or float16; other types fall through
    without a return value.
    """
    if isinstance(x, float):
        return isfinite64(x)
    if isinstance(x, float32):
        return isfinite32(x)
    if isinstance(x, float16):
        return isfinite16(x)
# Round a 64-bit float to an integral value via llvm.rint.f64.
@pure
@llvm
def rint64(x: float) -> float:
    declare double @llvm.rint.f64(double)
    %y = call double @llvm.rint.f64(double %x)
    ret double %y

# Round a 32-bit float to an integral value via llvm.rint.f32.
@pure
@llvm
def rint32(x: float32) -> float32:
    declare float @llvm.rint.f32(float)
    %y = call float @llvm.rint.f32(float %x)
    ret float %y

# Round a 16-bit float to an integral value via llvm.rint.f16.
@pure
@llvm
def rint16(x: float16) -> float16:
    declare half @llvm.rint.f16(half)
    %y = call half @llvm.rint.f16(half %x)
    ret half %y

def rint(x):
    """
    Round `x` to an integral value using the llvm.rint intrinsic of the
    matching width. `x` must be a float, float32 or float16; other
    types fall through without a return value.
    """
    if isinstance(x, float):
        return rint64(x)
    elif isinstance(x, float32):
        return rint32(x)
    elif isinstance(x, float16):
        return rint16(x)
# Floor of a 64-bit float via llvm.floor.f64.
@pure
@llvm
def floor64(x: float) -> float:
    declare double @llvm.floor.f64(double)
    %y = call double @llvm.floor.f64(double %x)
    ret double %y

# Floor of a 32-bit float via llvm.floor.f32.
@pure
@llvm
def floor32(x: float32) -> float32:
    declare float @llvm.floor.f32(float)
    %y = call float @llvm.floor.f32(float %x)
    ret float %y

# Floor of a 16-bit float via llvm.floor.f16.
@pure
@llvm
def floor16(x: float16) -> float16:
    declare half @llvm.floor.f16(half)
    %y = call half @llvm.floor.f16(half %x)
    ret half %y

def floor(x):
    """
    Return the floor of `x` as a float of the same type. `x` must be a
    float, float32 or float16; other types fall through without a
    return value.
    """
    if isinstance(x, float):
        return floor64(x)
    elif isinstance(x, float32):
        return floor32(x)
    elif isinstance(x, float16):
        return floor16(x)
# Ceiling of a 64-bit float via llvm.ceil.f64.
@pure
@llvm
def ceil64(x: float) -> float:
    declare double @llvm.ceil.f64(double)
    %y = call double @llvm.ceil.f64(double %x)
    ret double %y

# Ceiling of a 32-bit float via llvm.ceil.f32.
@pure
@llvm
def ceil32(x: float32) -> float32:
    declare float @llvm.ceil.f32(float)
    %y = call float @llvm.ceil.f32(float %x)
    ret float %y

# Ceiling of a 16-bit float via llvm.ceil.f16.
@pure
@llvm
def ceil16(x: float16) -> float16:
    declare half @llvm.ceil.f16(half)
    %y = call half @llvm.ceil.f16(half %x)
    ret half %y

def ceil(x):
    """
    Return the ceiling of `x` as a float of the same type. `x` must be
    a float, float32 or float16; other types fall through without a
    return value.
    """
    if isinstance(x, float):
        return ceil64(x)
    elif isinstance(x, float32):
        return ceil32(x)
    elif isinstance(x, float16):
        return ceil16(x)
# Truncate a 64-bit float toward zero via llvm.trunc.f64.
@pure
@llvm
def trunc64(x: float) -> float:
    declare double @llvm.trunc.f64(double)
    %y = call double @llvm.trunc.f64(double %x)
    ret double %y

# Truncate a 32-bit float toward zero via llvm.trunc.f32.
@pure
@llvm
def trunc32(x: float32) -> float32:
    declare float @llvm.trunc.f32(float)
    %y = call float @llvm.trunc.f32(float %x)
    ret float %y

# Truncate a 16-bit float toward zero via llvm.trunc.f16.
@pure
@llvm
def trunc16(x: float16) -> float16:
    declare half @llvm.trunc.f16(half)
    %y = call half @llvm.trunc.f16(half %x)
    ret half %y

def trunc(x):
    """
    Return `x` truncated to an integral value, as a float of the same
    type. `x` must be a float, float32 or float16; other types fall
    through without a return value.
    """
    if isinstance(x, float):
        return trunc64(x)
    elif isinstance(x, float32):
        return trunc32(x)
    elif isinstance(x, float16):
        return trunc16(x)
# e**x for a 64-bit float via llvm.exp.f64.
@pure
@llvm
def exp64(x: float) -> float:
    declare double @llvm.exp.f64(double)
    %y = call double @llvm.exp.f64(double %x)
    ret double %y

# e**x for a 32-bit float via llvm.exp.f32.
@pure
@llvm
def exp32(x: float32) -> float32:
    declare float @llvm.exp.f32(float)
    %y = call float @llvm.exp.f32(float %x)
    ret float %y

# e**x for a 16-bit float via llvm.exp.f16.
@pure
@llvm
def exp16(x: float16) -> float16:
    declare half @llvm.exp.f16(half)
    %y = call half @llvm.exp.f16(half %x)
    ret half %y

def exp(x):
    """
    Return e raised to the power `x`. `x` must be a float, float32 or
    float16; other types fall through without a return value.
    """
    if isinstance(x, float):
        return exp64(x)
    elif isinstance(x, float32):
        return exp32(x)
    elif isinstance(x, float16):
        return exp16(x)
# 2**x for a 64-bit float via llvm.exp2.f64.
@pure
@llvm
def exp2_64(x: float) -> float:
    declare double @llvm.exp2.f64(double)
    %y = call double @llvm.exp2.f64(double %x)
    ret double %y

# 2**x for a 32-bit float via llvm.exp2.f32.
@pure
@llvm
def exp2_32(x: float32) -> float32:
    declare float @llvm.exp2.f32(float)
    %y = call float @llvm.exp2.f32(float %x)
    ret float %y

# 2**x for a 16-bit float via llvm.exp2.f16.
@pure
@llvm
def exp2_16(x: float16) -> float16:
    declare half @llvm.exp2.f16(half)
    %y = call half @llvm.exp2.f16(half %x)
    ret half %y

def exp2(x):
    """
    Return 2 raised to the power `x`. `x` must be a float, float32 or
    float16; other types fall through without a return value.
    """
    if isinstance(x, float):
        return exp2_64(x)
    elif isinstance(x, float32):
        return exp2_32(x)
    elif isinstance(x, float16):
        return exp2_16(x)
def expm1_64(x: float):
    """exp(x) - 1 for a 64-bit float, via C `expm1`."""
    return _C.expm1(x)

def expm1_32(x: float32):
    """exp(x) - 1 for a 32-bit float, via C `expm1f`."""
    return _C.expm1f(x)

def expm1(x):
    """
    Return exp(x) - 1. `x` must be a float, float32 or float16; the
    float16 case is evaluated in float32 and narrowed back. Other types
    fall through without a return value.
    """
    if isinstance(x, float):
        return expm1_64(x)
    if isinstance(x, float32):
        return expm1_32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, expm1_32)
# Natural logarithm of a 64-bit float via llvm.log.f64.
@pure
@llvm
def log64(x: float) -> float:
    declare double @llvm.log.f64(double)
    %y = call double @llvm.log.f64(double %x)
    ret double %y

# Natural logarithm of a 32-bit float via llvm.log.f32.
@pure
@llvm
def log32(x: float32) -> float32:
    declare float @llvm.log.f32(float)
    %y = call float @llvm.log.f32(float %x)
    ret float %y

# Natural logarithm of a 16-bit float via llvm.log.f16.
@pure
@llvm
def log16(x: float16) -> float16:
    declare half @llvm.log.f16(half)
    %y = call half @llvm.log.f16(half %x)
    ret half %y

def log(x):
    """
    Return the natural logarithm of `x`. `x` must be a float, float32
    or float16; other types fall through without a return value.
    """
    if isinstance(x, float):
        return log64(x)
    elif isinstance(x, float32):
        return log32(x)
    elif isinstance(x, float16):
        return log16(x)
# Base-2 logarithm of a 64-bit float via llvm.log2.f64.
@pure
@llvm
def log2_64(x: float) -> float:
    declare double @llvm.log2.f64(double)
    %y = call double @llvm.log2.f64(double %x)
    ret double %y

# Base-2 logarithm of a 32-bit float via llvm.log2.f32.
@pure
@llvm
def log2_32(x: float32) -> float32:
    declare float @llvm.log2.f32(float)
    %y = call float @llvm.log2.f32(float %x)
    ret float %y

# Base-2 logarithm of a 16-bit float via llvm.log2.f16.
@pure
@llvm
def log2_16(x: float16) -> float16:
    declare half @llvm.log2.f16(half)
    %y = call half @llvm.log2.f16(half %x)
    ret half %y

def log2(x):
    """
    Return the base-2 logarithm of `x`. `x` must be a float, float32 or
    float16; other types fall through without a return value.
    """
    if isinstance(x, float):
        return log2_64(x)
    elif isinstance(x, float32):
        return log2_32(x)
    elif isinstance(x, float16):
        return log2_16(x)
# Base-10 logarithm of a 64-bit float via llvm.log10.f64.
@pure
@llvm
def log10_64(x: float) -> float:
    declare double @llvm.log10.f64(double)
    %y = call double @llvm.log10.f64(double %x)
    ret double %y

# Base-10 logarithm of a 32-bit float via llvm.log10.f32.
@pure
@llvm
def log10_32(x: float32) -> float32:
    declare float @llvm.log10.f32(float)
    %y = call float @llvm.log10.f32(float %x)
    ret float %y

# Base-10 logarithm of a 16-bit float via llvm.log10.f16.
@pure
@llvm
def log10_16(x: float16) -> float16:
    declare half @llvm.log10.f16(half)
    %y = call half @llvm.log10.f16(half %x)
    ret half %y

def log10(x):
    """
    Return the base-10 logarithm of `x`. `x` must be a float, float32
    or float16; other types fall through without a return value.
    """
    if isinstance(x, float):
        return log10_64(x)
    elif isinstance(x, float32):
        return log10_32(x)
    elif isinstance(x, float16):
        return log10_16(x)
def log1p64(x: float):
    """log(1 + x) for a 64-bit float, via C `log1p`."""
    return _C.log1p(x)

def log1p32(x: float32):
    """log(1 + x) for a 32-bit float, via C `log1pf`."""
    return _C.log1pf(x)

def log1p(x):
    """
    Return log(1 + x). `x` must be a float, float32 or float16; the
    float16 case is evaluated in float32 and narrowed back. Other types
    fall through without a return value.
    """
    if isinstance(x, float):
        return log1p64(x)
    if isinstance(x, float32):
        return log1p32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, log1p32)
# Square root of a 64-bit float via llvm.sqrt.f64.
@pure
@llvm
def sqrt64(x: float) -> float:
    declare double @llvm.sqrt.f64(double)
    %y = call double @llvm.sqrt.f64(double %x)
    ret double %y

# Square root of a 32-bit float via llvm.sqrt.f32.
@pure
@llvm
def sqrt32(x: float32) -> float32:
    declare float @llvm.sqrt.f32(float)
    %y = call float @llvm.sqrt.f32(float %x)
    ret float %y

# Square root of a 16-bit float via llvm.sqrt.f16.
@pure
@llvm
def sqrt16(x: float16) -> float16:
    declare half @llvm.sqrt.f16(half)
    %y = call half @llvm.sqrt.f16(half %x)
    ret half %y

def sqrt(x):
    """
    Return the square root of `x`. `x` must be a float, float32 or
    float16; other types fall through without a return value.
    """
    if isinstance(x, float):
        return sqrt64(x)
    elif isinstance(x, float32):
        return sqrt32(x)
    elif isinstance(x, float16):
        return sqrt16(x)
def cbrt64(x: float):
    """Cube root of a 64-bit float, via C `cbrt`."""
    return _C.cbrt(x)

def cbrt32(x: float32):
    """Cube root of a 32-bit float, via C `cbrtf`."""
    return _C.cbrtf(x)

def cbrt(x):
    """
    Return the cube root of `x`. `x` must be a float, float32 or
    float16; the float16 case is evaluated in float32 and narrowed
    back. Other types fall through without a return value.
    """
    if isinstance(x, float):
        return cbrt64(x)
    if isinstance(x, float32):
        return cbrt32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, cbrt32)
# Sine of a 64-bit float via llvm.sin.f64.
@pure
@llvm
def sin64(x: float) -> float:
    declare double @llvm.sin.f64(double)
    %y = call double @llvm.sin.f64(double %x)
    ret double %y

# Sine of a 32-bit float via llvm.sin.f32.
@pure
@llvm
def sin32(x: float32) -> float32:
    declare float @llvm.sin.f32(float)
    %y = call float @llvm.sin.f32(float %x)
    ret float %y

# Sine of a 16-bit float via llvm.sin.f16.
@pure
@llvm
def sin16(x: float16) -> float16:
    declare half @llvm.sin.f16(half)
    %y = call half @llvm.sin.f16(half %x)
    ret half %y

def sin(x):
    """
    Return the sine of `x`. `x` must be a float, float32 or float16;
    other types fall through without a return value.
    """
    if isinstance(x, float):
        return sin64(x)
    elif isinstance(x, float32):
        return sin32(x)
    elif isinstance(x, float16):
        return sin16(x)
# Cosine of a 64-bit float via llvm.cos.f64.
@pure
@llvm
def cos64(x: float) -> float:
    declare double @llvm.cos.f64(double)
    %y = call double @llvm.cos.f64(double %x)
    ret double %y

# Cosine of a 32-bit float via llvm.cos.f32.
@pure
@llvm
def cos32(x: float32) -> float32:
    declare float @llvm.cos.f32(float)
    %y = call float @llvm.cos.f32(float %x)
    ret float %y

# Cosine of a 16-bit float via llvm.cos.f16.
@pure
@llvm
def cos16(x: float16) -> float16:
    declare half @llvm.cos.f16(half)
    %y = call half @llvm.cos.f16(half %x)
    ret half %y

def cos(x):
    """
    Return the cosine of `x`. `x` must be a float, float32 or float16;
    other types fall through without a return value.
    """
    if isinstance(x, float):
        return cos64(x)
    elif isinstance(x, float32):
        return cos32(x)
    elif isinstance(x, float16):
        return cos16(x)
def tan64(x: float):
    """Tangent of a 64-bit float, via C `tan`."""
    return _C.tan(x)

def tan32(x: float32):
    """Tangent of a 32-bit float, via C `tanf`."""
    return _C.tanf(x)

def tan(x):
    """
    Return the tangent of `x`. `x` must be a float, float32 or float16;
    the float16 case is evaluated in float32 and narrowed back. Other
    types fall through without a return value.
    """
    if isinstance(x, float):
        return tan64(x)
    if isinstance(x, float32):
        return tan32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, tan32)
def asin64(x: float):
    """Arc sine of a 64-bit float, via C `asin`."""
    return _C.asin(x)

def asin32(x: float32):
    """Arc sine of a 32-bit float, via C `asinf`."""
    return _C.asinf(x)

def asin(x):
    """
    Return the arc sine of `x`. `x` must be a float, float32 or
    float16; the float16 case is evaluated in float32 and narrowed
    back. Other types fall through without a return value.
    """
    if isinstance(x, float):
        return asin64(x)
    if isinstance(x, float32):
        return asin32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, asin32)
def acos64(x: float):
    """Arc cosine of a 64-bit float, via C `acos`."""
    return _C.acos(x)

def acos32(x: float32):
    """Arc cosine of a 32-bit float, via C `acosf`."""
    return _C.acosf(x)

def acos(x):
    """
    Return the arc cosine of `x`. `x` must be a float, float32 or
    float16; the float16 case is evaluated in float32 and narrowed
    back. Other types fall through without a return value.
    """
    if isinstance(x, float):
        return acos64(x)
    if isinstance(x, float32):
        return acos32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, acos32)
def atan64(x: float):
    """Arc tangent of a 64-bit float, via C `atan`."""
    return _C.atan(x)

def atan32(x: float32):
    """Arc tangent of a 32-bit float, via C `atanf`."""
    return _C.atanf(x)

def atan(x):
    """
    Return the arc tangent of `x`. `x` must be a float, float32 or
    float16; the float16 case is evaluated in float32 and narrowed
    back. Other types fall through without a return value.
    """
    if isinstance(x, float):
        return atan64(x)
    if isinstance(x, float32):
        return atan32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, atan32)
def sinh64(x: float):
    """Hyperbolic sine of a 64-bit float, via C `sinh`."""
    return _C.sinh(x)

def sinh32(x: float32):
    """Hyperbolic sine of a 32-bit float, via C `sinhf`."""
    return _C.sinhf(x)

def sinh(x):
    """
    Return the hyperbolic sine of `x`. `x` must be a float, float32 or
    float16; the float16 case is evaluated in float32 and narrowed
    back. Other types fall through without a return value.
    """
    if isinstance(x, float):
        return sinh64(x)
    if isinstance(x, float32):
        return sinh32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, sinh32)
def cosh64(x: float):
    """Hyperbolic cosine of a 64-bit float, via C `cosh`."""
    return _C.cosh(x)

def cosh32(x: float32):
    """Hyperbolic cosine of a 32-bit float, via C `coshf`."""
    return _C.coshf(x)

def cosh(x):
    """
    Return the hyperbolic cosine of `x`. `x` must be a float, float32
    or float16; the float16 case is evaluated in float32 and narrowed
    back. Other types fall through without a return value.
    """
    if isinstance(x, float):
        return cosh64(x)
    if isinstance(x, float32):
        return cosh32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, cosh32)
def tanh64(x: float):
    """Hyperbolic tangent of a 64-bit float, via C `tanh`."""
    return _C.tanh(x)

def tanh32(x: float32):
    """Hyperbolic tangent of a 32-bit float, via C `tanhf`."""
    return _C.tanhf(x)

def tanh(x):
    """
    Return the hyperbolic tangent of `x`. `x` must be a float, float32
    or float16; the float16 case is evaluated in float32 and narrowed
    back. Other types fall through without a return value.
    """
    if isinstance(x, float):
        return tanh64(x)
    if isinstance(x, float32):
        return tanh32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, tanh32)
def asinh64(x: float):
    """Inverse hyperbolic sine of a 64-bit float, via C `asinh`."""
    return _C.asinh(x)

def asinh32(x: float32):
    """Inverse hyperbolic sine of a 32-bit float, via C `asinhf`."""
    return _C.asinhf(x)

def asinh(x):
    """
    Return the inverse hyperbolic sine of `x`. `x` must be a float,
    float32 or float16; the float16 case is evaluated in float32 and
    narrowed back. Other types fall through without a return value.
    """
    if isinstance(x, float):
        return asinh64(x)
    if isinstance(x, float32):
        return asinh32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, asinh32)
def acosh64(x: float):
    """Inverse hyperbolic cosine of a 64-bit float, via C `acosh`."""
    return _C.acosh(x)

def acosh32(x: float32):
    """Inverse hyperbolic cosine of a 32-bit float, via C `acoshf`."""
    return _C.acoshf(x)

def acosh(x):
    """
    Return the inverse hyperbolic cosine of `x`. `x` must be a float,
    float32 or float16; the float16 case is evaluated in float32 and
    narrowed back. Other types fall through without a return value.
    """
    if isinstance(x, float):
        return acosh64(x)
    if isinstance(x, float32):
        return acosh32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, acosh32)
def atanh64(x: float):
    """Inverse hyperbolic tangent of a 64-bit float, via C `atanh`."""
    return _C.atanh(x)

def atanh32(x: float32):
    """Inverse hyperbolic tangent of a 32-bit float, via C `atanhf`."""
    return _C.atanhf(x)

def atanh(x):
    """
    Return the inverse hyperbolic tangent of `x`. `x` must be a float,
    float32 or float16; the float16 case is evaluated in float32 and
    narrowed back. Other types fall through without a return value.
    """
    if isinstance(x, float):
        return atanh64(x)
    if isinstance(x, float32):
        return atanh32(x)
    if isinstance(x, float16):
        return _fp16_op_via_fp32(x, atanh32)
def atan2_64(x: float, y: float):
    """Two-argument arc tangent for 64-bit floats: C `atan2(x, y)`."""
    return _C.atan2(x, y)

def atan2_32(x: float32, y: float32):
    """Two-argument arc tangent for 32-bit floats: C `atan2f(x, y)`."""
    return _C.atan2f(x, y)

def atan2(x, y):
    """
    Return the C library's atan2 applied to (x, y), in that argument
    order. Both arguments must be the same float type (float, float32
    or float16); the float16 case is evaluated in float32 and narrowed
    back. Other combinations fall through without a return value.
    """
    if isinstance(x, float) and isinstance(y, float):
        return atan2_64(x, y)
    if isinstance(x, float32) and isinstance(y, float32):
        return atan2_32(x, y)
    if isinstance(x, float16) and isinstance(y, float16):
        return _fp16_op_via_fp32_2(x, y, atan2_32)
# Magnitude of x with the sign of y, via llvm.copysign.f64.
@pure
@llvm
def copysign64(x: float, y: float) -> float:
    declare double @llvm.copysign.f64(double, double)
    %z = call double @llvm.copysign.f64(double %x, double %y)
    ret double %z

# Magnitude of x with the sign of y, via llvm.copysign.f32.
@pure
@llvm
def copysign32(x: float32, y: float32) -> float32:
    declare float @llvm.copysign.f32(float, float)
    %z = call float @llvm.copysign.f32(float %x, float %y)
    ret float %z

# Magnitude of x with the sign of y, via llvm.copysign.f16.
@pure
@llvm
def copysign16(x: float16, y: float16) -> float16:
    declare half @llvm.copysign.f16(half, half)
    %z = call half @llvm.copysign.f16(half %x, half %y)
    ret half %z

def copysign(x, y):
    """
    Return a value with the magnitude of `x` and the sign of `y`. Both
    arguments must be the same float type (float, float32 or float16);
    other combinations fall through without a return value.
    """
    if isinstance(x, float) and isinstance(y, float):
        return copysign64(x, y)
    elif isinstance(x, float32) and isinstance(y, float32):
        return copysign32(x, y)
    elif isinstance(x, float16) and isinstance(y, float16):
        return copysign16(x, y)
# x raised to the power y for 64-bit floats, via llvm.pow.f64.
@pure
@llvm
def pow64(x: float, y: float) -> float:
    declare double @llvm.pow.f64(double, double)
    %z = call double @llvm.pow.f64(double %x, double %y)
    ret double %z

# x raised to the power y for 32-bit floats, via llvm.pow.f32.
@pure
@llvm
def pow32(x: float32, y: float32) -> float32:
    declare float @llvm.pow.f32(float, float)
    %z = call float @llvm.pow.f32(float %x, float %y)
    ret float %z

# x raised to the power y for 16-bit floats, via llvm.pow.f16.
@pure
@llvm
def pow16(x: float16, y: float16) -> float16:
    declare half @llvm.pow.f16(half, half)
    %z = call half @llvm.pow.f16(half %x, half %y)
    ret half %z

def pow(x, y):
    """
    Return `x` raised to the power `y`. Both arguments must be the same
    float type (float, float32 or float16); other combinations fall
    through without a return value.
    """
    if isinstance(x, float) and isinstance(y, float):
        return pow64(x, y)
    elif isinstance(x, float32) and isinstance(y, float32):
        return pow32(x, y)
    elif isinstance(x, float16) and isinstance(y, float16):
        return pow16(x, y)
# Double-precision frexp via the llvm.frexp.f64 intrinsic. The intrinsic
# yields {double, i32}; the i32 exponent is sign-extended to i64 so the
# result can be returned as Tuple[float, int].
@pure
@llvm
def frexp64(x: float) -> Tuple[float, int]:
    declare { double, i32 } @llvm.frexp.f64(double)
    %y = call { double, i32 } @llvm.frexp.f64(double %x)
    %y0 = extractvalue { double, i32 } %y, 0
    %y1 = extractvalue { double, i32 } %y, 1
    %z1 = sext i32 %y1 to i64
    %r0 = insertvalue { double, i64 } undef, double %y0, 0
    %r1 = insertvalue { double, i64 } %r0, i64 %z1, 1
    ret { double, i64 } %r1
# Single-precision frexp via the llvm.frexp.f32 intrinsic. The intrinsic
# yields {float, i32}; the i32 exponent is sign-extended to i64 so the
# result can be returned as Tuple[float32, int].
@pure
@llvm
def frexp32(x: float32) -> Tuple[float32, int]:
    declare { float, i32 } @llvm.frexp.f32(float)
    %y = call { float, i32 } @llvm.frexp.f32(float %x)
    %y0 = extractvalue { float, i32 } %y, 0
    %y1 = extractvalue { float, i32 } %y, 1
    %z1 = sext i32 %y1 to i64
    %r0 = insertvalue { float, i64 } undef, float %y0, 0
    %r1 = insertvalue { float, i64 } %r0, i64 %z1, 1
    ret { float, i64 } %r1
# Half-precision frexp via the llvm.frexp.f16 intrinsic. The intrinsic
# yields {half, i32}; the i32 exponent is sign-extended to i64 so the
# result can be returned as Tuple[float16, int].
@pure
@llvm
def frexp16(x: float16) -> Tuple[float16, int]:
    declare { half, i32 } @llvm.frexp.f16(half)
    %y = call { half, i32 } @llvm.frexp.f16(half %x)
    %y0 = extractvalue { half, i32 } %y, 0
    %y1 = extractvalue { half, i32 } %y, 1
    %z1 = sext i32 %y1 to i64
    %r0 = insertvalue { half, i64 } undef, half %y0, 0
    %r1 = insertvalue { half, i64 } %r0, i64 %z1, 1
    ret { half, i64 } %r1
# Type-dispatching frexp: picks the routine matching the float width of x.
# Unsupported argument types fall through without returning a value.
def frexp(x):
    if isinstance(x, float):
        return frexp64(x)
    if isinstance(x, float32):
        return frexp32(x)
    if isinstance(x, float16):
        return frexp16(x)
# Double-precision ldexp via the llvm.ldexp.f64 intrinsic. The 64-bit
# exponent is truncated to i32 before the call, so its high bits are dropped.
@pure
@llvm
def ldexp64(x: float, exp: int) -> float:
    declare double @llvm.ldexp.f64(double, i32)
    %e = trunc i64 %exp to i32
    %y = call double @llvm.ldexp.f64(double %x, i32 %e)
    ret double %y
# Single-precision ldexp via the llvm.ldexp.f32 intrinsic. The 64-bit
# exponent is truncated to i32 before the call, so its high bits are dropped.
@pure
@llvm
def ldexp32(x: float32, exp: int) -> float32:
    declare float @llvm.ldexp.f32(float, i32)
    %e = trunc i64 %exp to i32
    %y = call float @llvm.ldexp.f32(float %x, i32 %e)
    ret float %y
# Half-precision ldexp: widen to float32, scale with ldexp32, narrow back.
def ldexp16(x: float16, exp: int) -> float16:
    widened = fpext(x, float32)
    scaled = ldexp32(widened, exp)
    return fptrunc(scaled, float16)
# Type-dispatching ldexp: picks the routine matching the float width of x.
# Unsupported argument types fall through without returning a value.
def ldexp(x, exp: int):
    if isinstance(x, float):
        return ldexp64(x, exp)
    if isinstance(x, float32):
        return ldexp32(x, exp)
    if isinstance(x, float16):
        return ldexp16(x, exp)
# Mathematical constants, written with more digits than a double can hold.
LOGE2 = 0.693147180559945309417232121458176568  # ln(2)
LOGE10 = 2.302585092994045684017991454684364208  # ln(10)
LOG2E = 1.442695040888963407359924681001892137  # log2(e)
LOG10E = 0.434294481903251827651128918916605082  # log10(e)
SQRT2 = 1.414213562373095048801688724209698079  # sqrt(2)
PI = 3.141592653589793238462643383279502884  # pi
PI_2 = 1.570796326794896619231321691639751442  # pi / 2
E = 2.718281828459045235360287471352662498  # e
# Double-precision log(exp(x) + exp(y)), written so that only the exponential
# of a non-positive difference is ever taken.
def logaddexp64(x: float, y: float):
    if x == y:
        # exp(x) + exp(x) == 2 * exp(x)
        return x + LOGE2
    diff = x - y
    if diff > 0:
        return x + _C.log1p(exp64(-diff))
    if diff <= 0:
        return y + _C.log1p(exp64(diff))
    # Neither comparison holds (e.g. diff is NaN): return the difference as is.
    return diff
# Single-precision log(exp(x) + exp(y)), written so that only the exponential
# of a non-positive difference is ever taken.
def logaddexp32(x: float32, y: float32):
    if x == y:
        # exp(x) + exp(x) == 2 * exp(x); the addition is done in double
        # precision and then narrowed.
        return float32(float(x) + LOGE2)
    diff = x - y
    if diff > float32(0):
        return x + _C.log1pf(exp32(-diff))
    if diff <= float32(0):
        return y + _C.log1pf(exp32(diff))
    # Neither comparison holds (e.g. diff is NaN): return the difference as is.
    return diff
# Type-dispatching logaddexp. Both operands must share one of the supported
# float types; any other combination falls through without returning a value.
def logaddexp(x, y):
    if isinstance(x, float) and isinstance(y, float):
        return logaddexp64(x, y)
    if isinstance(x, float32) and isinstance(y, float32):
        return logaddexp32(x, y)
    if isinstance(x, float16) and isinstance(y, float16):
        # float16 is delegated to the shared helper together with the
        # float32 routine.
        return _fp16_op_via_fp32_2(x, y, logaddexp32)
# Double-precision log2(2**x + 2**y). log2(1 + t) is obtained as
# LOG2E * log1p(t).
def logaddexp2_64(x: float, y: float):
    if x == y:
        # 2**x + 2**x == 2**(x + 1)
        return x + 1
    diff = x - y
    if diff > 0:
        return x + LOG2E * _C.log1p(exp2_64(-diff))
    if diff <= 0:
        return y + LOG2E * _C.log1p(exp2_64(diff))
    # Neither comparison holds (e.g. diff is NaN): return the difference as is.
    return diff
# Single-precision log2(2**x + 2**y). The LOG2E scaling is carried out in
# double precision and the product is narrowed back to float32.
def logaddexp2_32(x: float32, y: float32):
    if x == y:
        # 2**x + 2**x == 2**(x + 1)
        return x + float32(1)
    diff = x - y
    if diff > float32(0):
        return x + float32(LOG2E * float(_C.log1pf(exp2_32(-diff))))
    if diff <= float32(0):
        return y + float32(LOG2E * float(_C.log1pf(exp2_32(diff))))
    # Neither comparison holds (e.g. diff is NaN): return the difference as is.
    return diff
# Type-dispatching logaddexp2. Both operands must share one of the supported
# float types; any other combination falls through without returning a value.
def logaddexp2(x, y):
    if isinstance(x, float) and isinstance(y, float):
        return logaddexp2_64(x, y)
    if isinstance(x, float32) and isinstance(y, float32):
        return logaddexp2_32(x, y)
    if isinstance(x, float16) and isinstance(y, float16):
        # float16 is delegated to the shared helper together with the
        # float32 routine.
        return _fp16_op_via_fp32_2(x, y, logaddexp2_32)
# Double-precision hypot, forwarded to the C library's hypot.
def hypot64(x: float, y: float):
    length = _C.hypot(x, y)
    return length
# Single-precision hypot, forwarded to the C library's hypotf.
def hypot32(x: float32, y: float32):
    length = _C.hypotf(x, y)
    return length
# Type-dispatching hypot. Both operands must share one of the supported
# float types; any other combination falls through without returning a value.
def hypot(x, y):
    if isinstance(x, float) and isinstance(y, float):
        return hypot64(x, y)
    if isinstance(x, float32) and isinstance(y, float32):
        return hypot32(x, y)
    if isinstance(x, float16) and isinstance(y, float16):
        # float16 is delegated to the shared helper together with the
        # float32 routine.
        return _fp16_op_via_fp32_2(x, y, hypot32)
# Double-precision two-argument arctangent: hands (x, y), in that order,
# to the C library's atan2.
# NOTE(review): atan2_64 is already defined earlier in this file with the
# same body; this later definition looks redundant.
def atan2_64(x: float, y: float):
    angle = _C.atan2(x, y)
    return angle
# Single-precision two-argument arctangent: hands (x, y), in that order,
# to the C library's atan2f.
# NOTE(review): atan2_32 is already defined earlier in this file with the
# same body; this later definition looks redundant.
def atan2_32(x: float32, y: float32):
    angle = _C.atan2f(x, y)
    return angle
# Type-dispatching atan2. Both operands must share one of the supported
# float types; any other combination falls through without returning a value.
# NOTE(review): atan2 is already defined earlier in this file with the same
# body; this later definition looks redundant.
def atan2(x, y):
    if isinstance(x, float) and isinstance(y, float):
        return atan2_64(x, y)
    if isinstance(x, float32) and isinstance(y, float32):
        return atan2_32(x, y)
    if isinstance(x, float16) and isinstance(y, float16):
        # float16 is delegated to the shared helper together with the
        # float32 routine.
        return _fp16_op_via_fp32_2(x, y, atan2_32)
# Double-precision modf. Returns a 2-tuple: the value returned by C modf
# first, then the value C modf stored through its output pointer.
def modf64(x: float):
    whole = float()
    frac = _C.modf(x, __ptr__(whole))
    return (frac, whole)
# Single-precision modf. Returns a 2-tuple: the value returned by C modff
# first, then the value C modff stored through its output pointer.
def modf32(x: float32):
    whole = float32()
    frac = _C.modff(x, __ptr__(whole))
    return (frac, whole)
# Half-precision modf: widen to float32, call C modff, then narrow both
# results back to float16. Result order matches modf32.
def modf16(x: float16):
    widened = fpext(x, float32)
    whole = float32()
    frac = _C.modff(widened, __ptr__(whole))
    return (fptrunc(frac, float16), fptrunc(whole, float16))
# Type-dispatching modf: picks the routine matching the float width of x.
# Unsupported argument types fall through without returning a value.
def modf(x):
    if isinstance(x, float):
        return modf64(x)
    if isinstance(x, float32):
        return modf32(x)
    if isinstance(x, float16):
        return modf16(x)
# Double-precision nextafter, forwarded to the C library's nextafter.
def nextafter64(x: float, y: float):
    neighbor = _C.nextafter(x, y)
    return neighbor
# Single-precision nextafter, forwarded to the C library's nextafterf.
def nextafter32(x: float32, y: float32):
    neighbor = _C.nextafterf(x, y)
    return neighbor
# Half-precision nextafter, computed directly on the 16-bit encoding
# (bit 15 = sign, bits 0..14 = exponent and mantissa).
#
# Returns:
#   - NaN if either argument is NaN,
#   - x if x and y compare equal (including +0 versus -0),
#   - otherwise the float16 adjacent to x in the direction of y.
def nextafter16(x: float16, y: float16):
    y_u16 = bitcast(y, u16)
    x_u16 = bitcast(x, u16)
    one = u16(1)

    if isnan16(x) or isnan16(y):
        return nan16()
    # Non-NaN equality: identical bit patterns, or both values are zeros of
    # either sign (no magnitude bit set in x or y). The previous test,
    # (x & y) == 0x8000, wrongly matched any two negative values whose
    # magnitude bits are disjoint (e.g. -1.0 and -2.0) and missed +0 == -0.
    elif x_u16 == y_u16 or not ((x_u16 | y_u16) & u16(0x7fff)):  # x == y (non-nan)
        return x
    elif not (x_u16 & u16(0x7fff)):  # x == 0
        # Smallest subnormal carrying the sign of y.
        return bitcast((y_u16 & u16(0x8000)) + one, float16)
    elif not (x_u16 & u16(0x8000)):  # x > 0
        if i16(x_u16) > i16(y_u16):  # x > y
            return bitcast(x_u16 - one, float16)
        else:
            return bitcast(x_u16 + one, float16)
    else:
        # x < 0: decreasing the encoding moves the value towards zero.
        if (not (y_u16 & u16(0x8000))) or (x_u16 & u16(0x7fff)) > (y_u16 & u16(0x7fff)):  # x < y
            return bitcast(x_u16 - one, float16)
        else:
            return bitcast(x_u16 + one, float16)
# Type-dispatching nextafter. Both operands must share one of the supported
# float types; any other combination falls through without returning a value.
def nextafter(x, y):
    if isinstance(x, float) and isinstance(y, float):
        return nextafter64(x, y)
    if isinstance(x, float32) and isinstance(y, float32):
        return nextafter32(x, y)
    if isinstance(x, float16) and isinstance(y, float16):
        return nextafter16(x, y)
# Signed integer division for Int[N] using the LLVM sdiv instruction
# directly; no check on the divisor is performed here.
@pure
@llvm
def cdiv_int(x: Int[N], y: Int[N], N: Static[int]) -> Int[N]:
    %z = sdiv i{=N} %x, %y
    ret i{=N} %z
# Overload of cdiv_int for the built-in 64-bit int, using the LLVM sdiv
# instruction directly; no check on the divisor is performed here.
@overload
@pure
@llvm
def cdiv_int(x: int, y: int) -> int:
    %z = sdiv i64 %x, %y
    ret i64 %z
# Signed integer remainder for Int[N] using the LLVM srem instruction
# directly; no check on the divisor is performed here.
@pure
@llvm
def cmod_int(x: Int[N], y: Int[N], N: Static[int]) -> Int[N]:
    %z = srem i{=N} %x, %y
    ret i{=N} %z
# Overload of cmod_int for the built-in 64-bit int, using the LLVM srem
# instruction directly; no check on the divisor is performed here.
@overload
@pure
@llvm
def cmod_int(x: int, y: int) -> int:
    %z = srem i64 %x, %y
    ret i64 %z
# Floor division for Int[N]: starts from the sdiv quotient and subtracts one
# when the operands have opposite signs and the division is inexact.
def pydiv(x: Int[N], y: Int[N], N: Static[int]):
    quot = cdiv_int(x, y)
    sign_bit = Int[N](1) << Int[N](N - 1)
    # Non-zero exactly when the sign bits of x and y differ.
    signs_differ = (x ^ y) & sign_bit
    if signs_differ and quot * y != x:
        quot -= Int[N](1)
    return quot
# Floor division for 64-bit int: starts from the sdiv quotient and subtracts
# one when the operands have opposite signs and the division is inexact.
@overload
def pydiv(x: int, y: int) -> int:
    quot = cdiv_int(x, y)
    sign_bit = 1 << 63
    # Non-zero exactly when the sign bits of x and y differ.
    signs_differ = (x ^ y) & sign_bit
    if signs_differ and quot * y != x:
        quot -= 1
    return quot
# Modulo for Int[N] whose result takes the sign of the divisor: starts from
# the srem remainder and adds y when it is non-zero and the operand signs
# differ.
def pymod(x: Int[N], y: Int[N], N: Static[int]):
    rem = cmod_int(x, y)
    sign_bit = Int[N](1) << Int[N](N - 1)
    # Non-zero exactly when the sign bits of x and y differ.
    signs_differ = (x ^ y) & sign_bit
    if rem and signs_differ:
        rem += y
    return rem
# Modulo for 64-bit int whose result takes the sign of the divisor: starts
# from the srem remainder and adds y when it is non-zero and the operand
# signs differ.
@overload
def pymod(x: int, y: int) -> int:
    rem = cmod_int(x, y)
    sign_bit = 1 << 63
    # Non-zero exactly when the sign bits of x and y differ.
    signs_differ = (x ^ y) & sign_bit
    if rem and signs_differ:
        rem += y
    return rem
# Double-precision division using the LLVM fdiv instruction directly.
@pure
@llvm
def cdiv64(x: float, y: float) -> float:
    %z = fdiv double %x, %y
    ret double %z
# Single-precision division using the LLVM fdiv instruction directly.
@pure
@llvm
def cdiv32(x: float32, y: float32) -> float32:
    %z = fdiv float %x, %y
    ret float %z
# Half-precision division using the LLVM fdiv instruction directly.
@pure
@llvm
def cdiv16(x: float16, y: float16) -> float16:
    %z = fdiv half %x, %y
    ret half %z
# Type-dispatching float division. Both operands must share one of the
# supported float types; any other combination falls through without
# returning a value.
def cdiv(x, y):
    if isinstance(x, float) and isinstance(y, float):
        return cdiv64(x, y)
    if isinstance(x, float32) and isinstance(y, float32):
        return cdiv32(x, y)
    if isinstance(x, float16) and isinstance(y, float16):
        return cdiv16(x, y)
# Double-precision remainder using the LLVM frem instruction directly.
@pure
@llvm
def cmod64(x: float, y: float) -> float:
    %z = frem double %x, %y
    ret double %z
# Single-precision remainder using the LLVM frem instruction directly.
@pure
@llvm
def cmod32(x: float32, y: float32) -> float32:
    %z = frem float %x, %y
    ret float %z
# Half-precision remainder using the LLVM frem instruction directly.
@pure
@llvm
def cmod16(x: float16, y: float16) -> float16:
    %z = frem half %x, %y
    ret half %z
# Type-dispatching float remainder. Both operands must share one of the
# supported float types; any other combination falls through without
# returning a value.
def cmod(x, y):
    if isinstance(x, float) and isinstance(y, float):
        return cmod64(x, y)
    if isinstance(x, float32) and isinstance(y, float32):
        return cmod32(x, y)
    if isinstance(x, float16) and isinstance(y, float16):
        return cmod16(x, y)
# Floating-point modulo whose result takes the sign of the divisor y,
# built on top of cmod.
def pyfmod(x, y):
    zero = type(x)()
    rem = cmod(x, y)
    if not rem:
        # Zero remainder: return a zero that carries the sign of y.
        return copysign(zero, y)
    if (y < zero) != (rem < zero):
        # Remainder and divisor disagree in sign: shift by one divisor.
        rem += y
    return rem
# Lookup into a private constant table of 30 doubles (coefficients passed to
# _chbevl64/_chbevl32 by _i0_1_64/_i0_1_32). Returns entry `idx`; no bounds
# check is performed, so callers must keep 0 <= idx < 30.
# The getelementptr source type now matches the table's real type
# ([30 x double]); it was previously written as [256 x double], which did not
# change the computed address but contradicted the global's declaration.
@pure
@llvm
def _i0A(idx: int) -> float:
    @data = private unnamed_addr constant [30 x double] [double 0xBC545CB72134D0EF, double 0x3C833362977DA589, double 0xBCB184EB721EBBB4, double 0x3CDEE6D893F65EBA, double 0xBD0A5022C297FBEB, double 0x3D359B464B262627, double 0xBD61164C62EE1AF0, double 0x3D89FE2FE19BD324, double 0xBDB2FC957A946ABC, double 0x3DDA98BECC743C10, double 0xBE01D4FE13AE9556, double 0x3E26D903A454CB34, double 0xBE4BEAF68C0B30AB, double 0x3E703B769D4D6435, double 0xBE91EC638F227F8D, double 0x3EB2BF24978CF4AC, double 0xBED2866FCBA56427, double 0x3EF13F58BE9A2859, double 0xBF0E2B2659C41D5A, double 0x3F28B51B74107CAB, double 0xBF42E2FD1F15EB52, double 0x3F5ADC758A12100E, double 0xBF71B65E201AA849, double 0x3F859961F3DDE3DD, double 0xBF984E9EF121B6F0, double 0x3FA93E8ACEA8A32D, double 0xBFB84B70342D06EA, double 0x3FC5F7AC77AC88C0, double 0xBFD37FEBC057CD8D, double 0x3FE5A84E9035A22A], align 16
    %p = getelementptr inbounds [30 x double], ptr @data, i64 0, i64 %idx
    %x = load double, ptr %p, align 8
    ret double %x
# Lookup into a private constant table of 25 doubles (coefficients passed to
# _chbevl64/_chbevl32 by _i0_2_64/_i0_2_32). Returns entry `idx`; no bounds
# check is performed, so callers must keep 0 <= idx < 25.
# The getelementptr source type now matches the table's real type
# ([25 x double]); it was previously written as [256 x double], which did not
# change the computed address but contradicted the global's declaration.
@pure
@llvm
def _i0B(idx: int) -> float:
    @data = private unnamed_addr constant [25 x double] [double 0xBC60ADB754CA8B19, double 0xBC5646DA66119130, double 0x3C89BE1812D98421, double 0x3C83F3DD076041CD, double 0xBCB4600BABD21FE4, double 0xBCB8AEE7D908DE38, double 0x3CDFEE7DA3EAFB1F, double 0x3CF12A919094E6D7, double 0xBD0583FE7E65629A, double 0xBD275D99CF68BB32, double 0x3D1156FF0D5FC545, double 0x3D5B1C8C6B83C073, double 0x3D694347FA268CEC, double 0xBD7F904303178D66, double 0xBDAD0FD7357E7BF2, double 0xBDC1511D08397425, double 0x3DAA24FEABE8004F, double 0x3E00F9CCC0F46F75, double 0x3E2D2C64A9225B87, double 0x3E58569280D6D56D, double 0x3E8B8007D9CD616E, double 0x3EC8412BC101C586, double 0x3F120FA378999E52, double 0x3F6B998CA2E59049, double 0x3FE9BE62ACA809CB], align 16
    %p = getelementptr inbounds [25 x double], ptr @data, i64 0, i64 %idx
    %x = load double, ptr %p, align 8
    ret double %x
# Entry counts of the _i0A and _i0B coefficient tables; they match the
# [30 x double] and [25 x double] array sizes declared in those functions.
_NUM_I0A: Static[int] = 30
_NUM_I0B: Static[int] = 25
# Double-precision evaluation of the three-term recurrence
#   b0 = x*b1 - b2 + vals(i),  i = 1 .. nvals-1,
# seeded with b0 = vals(0), b1 = 0, returning 0.5 * (b0 - b2).
# `vals` is a callable mapping an index to a coefficient. Presumably this is
# a Chebyshev series evaluation (per the name). Requires nvals >= 2, since
# the value subtracted at the end is only assigned inside the loop.
def _chbevl64(x: float, vals, nvals: int):
    cur = vals(0)
    prev = 0.0
    for k in range(1, nvals):
        older = prev
        prev = cur
        cur = x * prev - older + vals(k)
    return 0.5 * (cur - older)
# Single-precision evaluation of the three-term recurrence
#   b0 = x*b1 - b2 + vals(i),  i = 1 .. nvals-1,
# seeded with b0 = vals(0), b1 = 0, returning 0.5 * (b0 - b2).
# Each coefficient from `vals` is converted to float32 before use. Requires
# nvals >= 2, since the value subtracted at the end is only assigned inside
# the loop.
def _chbevl32(x: float32, vals, nvals: int):
    cur = float32(vals(0))
    prev = float32(0.0)
    for k in range(1, nvals):
        older = prev
        prev = cur
        cur = x * prev - older + float32(vals(k))
    return float32(0.5) * (cur - older)
# Double-precision i0 branch used by i0_64 for |x| <= 8: exp(x) times the
# _i0A series evaluated at x/2 - 2.
def _i0_1_64(x: float):
    arg = x / 2.0 - 2
    return exp64(x) * _chbevl64(arg, _i0A, _NUM_I0A)
# Double-precision i0 branch used by i0_64 for |x| > 8: exp(x) times the
# _i0B series evaluated at 32/x - 2, divided by sqrt(x).
def _i0_2_64(x: float):
    arg = 32.0 / x - 2.0
    return exp64(x) * _chbevl64(arg, _i0B, _NUM_I0B) / sqrt64(x)
# Single-precision i0 branch used by i0_32 for |x| <= 8: exp(x) times the
# _i0A series evaluated at x/2 - 2.
def _i0_1_32(x: float32):
    arg = x / float32(2.0) - float32(2)
    return exp32(x) * _chbevl32(arg, _i0A, _NUM_I0A)
# Single-precision i0 branch used by i0_32 for |x| > 8: exp(x) times the
# _i0B series evaluated at 32/x - 2, divided by sqrt(x).
def _i0_2_32(x: float32):
    arg = float32(32.0) / x - float32(2.0)
    return exp32(x) * _chbevl32(arg, _i0B, _NUM_I0B) / sqrt32(x)
# Double-precision i0: works on |x| and picks one of two series, split at 8.
def i0_64(x: float):
    magnitude = fabs64(x)
    # Keep the `<=` form: a NaN magnitude fails it and takes the second branch.
    return _i0_1_64(magnitude) if magnitude <= 8.0 else _i0_2_64(magnitude)
# Single-precision i0: works on |x| and picks one of two series, split at 8.
def i0_32(x: float32):
    magnitude = fabs32(x)
    # Keep the `<=` form: a NaN magnitude fails it and takes the second branch.
    return _i0_1_32(magnitude) if magnitude <= float32(8.0) else _i0_2_32(magnitude)
# Type-dispatching i0: picks the routine matching the float width of x.
# Unsupported argument types fall through without returning a value.
def i0(x):
    if isinstance(x, float):
        return i0_64(x)
    if isinstance(x, float32):
        return i0_32(x)
    if isinstance(x, float16):
        # float16 is delegated to the shared helper together with the
        # float32 routine.
        return _fp16_op_via_fp32(x, i0_32)
# Dispatch an element-wise loop to an external C routine named
# cnp_<op>_float64 / cnp_<op>_float32, selected by the static function name
# `func` and the element type `T`.
#
# Parameters:
#   in1, is1 - first input pointer and its accompanying int argument
#              (presumably a stride; confirm against the C side)
#   in2, is2 - second input pointer and int argument; only forwarded for the
#              two-input routines ('arctan2' and 'hypot')
#   out, os  - output pointer and its accompanying int argument
#   n        - final int argument passed to every routine (presumably the
#              element count; confirm against the C side)
#   func     - static name of the operation, e.g. 'arccos', 'exp', 'hypot'
#   T        - element type; only float and float32 have routines
#
# Any (func, T) combination not listed below does nothing.
def call_vectorized_loop(in1: Ptr[T], is1: int, in2: Ptr[T], is2: int, out: Ptr[T],
                         os: int, n: int, func: Static[str], T: type):
    # External C entry points: one float64 and one float32 variant per operation.
    from C import cnp_acos_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_acos_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_acosh_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_acosh_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_asin_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_asin_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_asinh_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_asinh_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_atan_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_atan_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_atanh_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_atanh_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_atan2_float64(Ptr[float], int, Ptr[float], int, Ptr[float], int, int)
    from C import cnp_atan2_float32(Ptr[float32], int, Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_cos_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_cos_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_exp_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_exp_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_exp2_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_exp2_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_expm1_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_expm1_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_log_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_log_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_log10_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_log10_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_log1p_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_log1p_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_log2_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_log2_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_sin_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_sin_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_sinh_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_sinh_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_tanh_float64(Ptr[float], int, Ptr[float], int, int)
    from C import cnp_tanh_float32(Ptr[float32], int, Ptr[float32], int, int)
    from C import cnp_hypot_float64(Ptr[float], int, Ptr[float], int, Ptr[float], int, int)
    from C import cnp_hypot_float32(Ptr[float32], int, Ptr[float32], int, Ptr[float32], int, int)

    # Both `func` and `T` are compile-time values, so each test below compares
    # static quantities.
    if func == 'arccos' and T is float:
        cnp_acos_float64(in1, is1, out, os, n)
    elif func == 'arccos' and T is float32:
        cnp_acos_float32(in1, is1, out, os, n)
    elif func == 'arccosh' and T is float:
        cnp_acosh_float64(in1, is1, out, os, n)
    elif func == 'arccosh' and T is float32:
        cnp_acosh_float32(in1, is1, out, os, n)
    elif func == 'arcsin' and T is float:
        cnp_asin_float64(in1, is1, out, os, n)
    elif func == 'arcsin' and T is float32:
        cnp_asin_float32(in1, is1, out, os, n)
    elif func == 'arcsinh' and T is float:
        cnp_asinh_float64(in1, is1, out, os, n)
    elif func == 'arcsinh' and T is float32:
        cnp_asinh_float32(in1, is1, out, os, n)
    elif func == 'arctan' and T is float:
        cnp_atan_float64(in1, is1, out, os, n)
    elif func == 'arctan' and T is float32:
        cnp_atan_float32(in1, is1, out, os, n)
    elif func == 'arctanh' and T is float:
        cnp_atanh_float64(in1, is1, out, os, n)
    elif func == 'arctanh' and T is float32:
        cnp_atanh_float32(in1, is1, out, os, n)
    elif func == 'arctan2' and T is float:
        # Two-input routine: the second operand is forwarded as well.
        cnp_atan2_float64(in1, is1, in2, is2, out, os, n)
    elif func == 'arctan2' and T is float32:
        cnp_atan2_float32(in1, is1, in2, is2, out, os, n)
    elif func == 'cos' and T is float:
        cnp_cos_float64(in1, is1, out, os, n)
    elif func == 'cos' and T is float32:
        cnp_cos_float32(in1, is1, out, os, n)
    elif func == 'exp' and T is float:
        cnp_exp_float64(in1, is1, out, os, n)
    elif func == 'exp' and T is float32:
        cnp_exp_float32(in1, is1, out, os, n)
    elif func == 'exp2' and T is float:
        cnp_exp2_float64(in1, is1, out, os, n)
    elif func == 'exp2' and T is float32:
        cnp_exp2_float32(in1, is1, out, os, n)
    elif func == 'expm1' and T is float:
        cnp_expm1_float64(in1, is1, out, os, n)
    elif func == 'expm1' and T is float32:
        cnp_expm1_float32(in1, is1, out, os, n)
    elif func == 'log' and T is float:
        cnp_log_float64(in1, is1, out, os, n)
    elif func == 'log' and T is float32:
        cnp_log_float32(in1, is1, out, os, n)
    elif func == 'log10' and T is float:
        cnp_log10_float64(in1, is1, out, os, n)
    elif func == 'log10' and T is float32:
        cnp_log10_float32(in1, is1, out, os, n)
    elif func == 'log1p' and T is float:
        cnp_log1p_float64(in1, is1, out, os, n)
    elif func == 'log1p' and T is float32:
        cnp_log1p_float32(in1, is1, out, os, n)
    elif func == 'log2' and T is float:
        cnp_log2_float64(in1, is1, out, os, n)
    elif func == 'log2' and T is float32:
        cnp_log2_float32(in1, is1, out, os, n)
    elif func == 'sin' and T is float:
        cnp_sin_float64(in1, is1, out, os, n)
    elif func == 'sin' and T is float32:
        cnp_sin_float32(in1, is1, out, os, n)
    elif func == 'sinh' and T is float:
        cnp_sinh_float64(in1, is1, out, os, n)
    elif func == 'sinh' and T is float32:
        cnp_sinh_float32(in1, is1, out, os, n)
    elif func == 'tanh' and T is float:
        cnp_tanh_float64(in1, is1, out, os, n)
    elif func == 'tanh' and T is float32:
        cnp_tanh_float32(in1, is1, out, os, n)
    elif func == 'hypot' and T is float:
        # Two-input routine: the second operand is forwarded as well.
        cnp_hypot_float64(in1, is1, in2, is2, out, os, n)
    elif func == 'hypot' and T is float32:
        cnp_hypot_float32(in1, is1, in2, is2, out, os, n)
    else:
        # No routine for this (func, T) pair: deliberately a no-op.
        pass
#########
# Types #
#########
# Size of type X as reported by internal.gc.sizeof.
def sizeof(X: type):
    from internal.gc import sizeof as gc_sizeof
    nbytes = gc_sizeof(X)
    return nbytes
# Forwards to internal.gc.atomic for type X and returns its result.
def atomic(X: type):
    from internal.gc import atomic as gc_atomic
    is_atomic = gc_atomic(X)
    return is_atomic
# Releases the memory behind p by passing its byte pointer to
# internal.gc.free.
def free(p: Ptr[T], T: type):
    from internal.gc import free as gc_free
    raw = p.as_byte()
    gc_free(raw)
# Reallocates a typed buffer through internal.gc.realloc. `newsize` and
# `oldsize` are element counts; both are scaled by sizeof(T) before the call.
# Returns the result re-wrapped as Ptr[T].
def realloc(p: Ptr[T], newsize: int, oldsize: int, T: type):
    from internal.gc import realloc as gc_realloc
    elem_size = sizeof(T)
    raw = gc_realloc(p.as_byte(), newsize * elem_size, oldsize * elem_size)
    return Ptr[T](raw)
# Byte strides of a contiguous array of element type X with the given shape.
# With forder=True the first axis gets the smallest stride; otherwise the
# last axis does. An empty shape yields an empty tuple.
def strides(shape, forder: bool, X: type):
    if staticlen(shape) == 0:
        return ()

    result = (0,) * staticlen(shape)
    # The result tuple is filled in place through a raw int pointer.
    slots = Ptr[int](__ptr__(result).as_byte())
    ndim = len(shape)
    step = sizeof(X)
    for k in range(ndim):
        axis = k if forder else ndim - 1 - k
        slots[axis] = step
        step *= shape[axis]
    return result
# Convert value `x` to type `T`.
#
# Every test on `X` (the type of x) and `T` compares types, so the branch is
# chosen by the source/target type pair. Pairs with no dedicated branch fall
# back to calling T(x).
def cast(x, T: type):
    X = type(x)

    # Need to support specially:
    #   - bool
    #   - int
    #   - byte
    #   - Int[N]
    #   - UInt[N]
    #   - float
    #   - float32
    #   - complex
    #   - complex64

    # datetime64 -> datetime64 and timedelta64 -> timedelta64 use the value's
    # own _cast method.
    if ((isinstance(X, datetime64) and isinstance(T, datetime64)) or
        (isinstance(X, timedelta64) and isinstance(T, timedelta64))):
        return x._cast(T)
    elif isinstance(X, datetime64) or isinstance(X, timedelta64):
        # Date/time source to any other target: convert the underlying
        # .value; for float targets a value flagged _nat becomes NaN.
        if T is float or T is float32 or T is float16:
            return nan(T) if x._nat else cast(x.value, T)
        else:
            return cast(x.value, T)
    elif T is X or T is NoneType:
        # Same type, or no target type requested: return x unchanged.
        return x
    elif T is bool:
        return bool(x)
    elif isinstance(T, datetime64) or isinstance(T, timedelta64):
        # Sentinel integer used below for NaN inputs (the minimum int64).
        _DATETIME_NAT: Static[int] = -9_223_372_036_854_775_808
        if X is float or X is float32 or X is float16:
            return T(_DATETIME_NAT, T.base, T.num) if isnan(x) else T(cast(x, int), T.base, T.num)
        elif X is complex or X is complex64:
            return T(_DATETIME_NAT, T.base, T.num) if (isnan(x.real) or isnan(x.imag)) else T(cast(x, int), T.base, T.num)
        elif X is str:
            return T(x, T.base, T.num)
        else:
            return T(cast(x, int), T.base, T.num)
    elif T is int:
        if X is bool:
            return 1 if x else 0
        elif X is int:
            return x
        elif X is byte:
            return int(x)
        elif isinstance(X, Int):
            return int(x)
        elif isinstance(X, UInt):
            return int(x)
        elif X is float:
            return int(x)
        elif X is float16:
            return int(float(x))
        elif X is float32:
            return int(float(x))
        elif X is complex:
            # Complex sources contribute only their real part.
            return int(x.real)
        elif X is complex64:
            return int(float(x.real))
        else:
            return T(x)
    elif T is byte:
        return byte(int(x))
    elif isinstance(T, Int):
        if X is bool:
            return T(1 if x else 0)
        elif X is int:
            return T(x)
        elif X is byte:
            # byte is handled as an 8-bit unsigned source (zero-extended).
            if T.N > 8:
                return zext(x, T)
            elif T.N < 8:
                return itrunc(x, T)
            else:
                return noop(x, T)
        elif isinstance(X, Int):
            # Signed source: sign-extend when widening.
            if T.N > X.N:
                return sext(x, T)
            elif T.N < X.N:
                return itrunc(x, T)
            else:
                return noop(x, T)
        elif isinstance(X, UInt):
            # Unsigned source: zero-extend when widening.
            if T.N > X.N:
                return zext(x, T)
            elif T.N < X.N:
                return itrunc(x, T)
            else:
                return noop(x, T)
        elif X is float:
            return fptosi(x, T)
        elif X is float16:
            return fptosi(x, T)
        elif X is float32:
            return fptosi(x, T)
        elif X is complex:
            return cast(x.real, T)
        elif X is complex64:
            return cast(x.real, T)
        else:
            return T(x)
    elif isinstance(T, UInt):
        if X is bool:
            return T(1 if x else 0)
        elif X is int:
            return T(x)
        elif X is byte:
            if T.N > 8:
                return zext(x, T)
            elif T.N < 8:
                return itrunc(x, T)
            else:
                return noop(x, T)
        elif isinstance(X, Int):
            # Signed source: sign-extend when widening, even to an unsigned
            # target.
            if T.N > X.N:
                return sext(x, T)
            elif T.N < X.N:
                return itrunc(x, T)
            else:
                return noop(x, T)
        elif isinstance(X, UInt):
            if T.N > X.N:
                return zext(x, T)
            elif T.N < X.N:
                return itrunc(x, T)
            else:
                return noop(x, T)
        elif X is float:
            return fptoui(x, T)
        elif X is float16:
            return fptoui(x, T)
        elif X is float32:
            return fptoui(x, T)
        elif X is complex:
            return cast(x.real, T)
        elif X is complex64:
            return cast(x.real, T)
        else:
            return T(x)
    elif T is float:
        if X is bool:
            return 1. if x else 0.
        elif X is int:
            return sitofp(x, T)
        elif X is byte:
            return uitofp(x, T)
        elif isinstance(X, Int):
            return sitofp(x, T)
        elif isinstance(X, UInt):
            return uitofp(x, T)
        elif X is float:
            return x
        elif X is float16:
            return float(x)
        elif X is float32:
            return float(x)
        elif X is complex:
            return x.real
        elif X is complex64:
            return float(x.real)
        else:
            return T(x)
    elif T is float16:
        if X is bool:
            return float16(1. if x else 0.)
        elif X is int:
            return sitofp(x, T)
        elif X is byte:
            return uitofp(x, T)
        elif isinstance(X, Int):
            return sitofp(x, T)
        elif isinstance(X, UInt):
            return uitofp(x, T)
        elif X is float:
            return float16(x)
        elif X is float16:
            return x
        elif X is float32:
            return fptrunc(x, T)
        elif X is complex:
            return float16(x.real)
        elif X is complex64:
            return fptrunc(x.real, T)
        else:
            return T(x)
    elif T is float32:
        if X is bool:
            return float32(1. if x else 0.)
        elif X is int:
            return sitofp(x, T)
        elif X is byte:
            return uitofp(x, T)
        elif isinstance(X, Int):
            return sitofp(x, T)
        elif isinstance(X, UInt):
            return uitofp(x, T)
        elif X is float:
            return float32(x)
        elif X is float16:
            return fpext(x, T)
        elif X is float32:
            return x
        elif X is complex:
            return float32(x.real)
        elif X is complex64:
            return x.real
        else:
            return T(x)
    elif T is complex:
        if X is complex64:
            return complex(x)
        else:
            # Non-complex sources become the real part; the imaginary part is 0.
            return complex(cast(x, float), 0.0)
    elif T is complex64:
        if X is complex:
            return complex64(x)
        else:
            # Non-complex sources become the real part; the imaginary part is 0.
            return complex64(cast(x, float32), float32(0))
    else:
        return T(x)
# Compute the common type of T1 and T2 and return a default-constructed
# instance of it (callers take the type of the returned value).
#
# Rules visible below:
#   - identical types coerce to themselves; NoneType yields the other type;
#   - two date/time types are combined by dt_promote;
#   - a date/time type paired with any other supported type yields the
#     date/time type;
#   - all remaining pairs are listed once in the table, and the mirrored pair
#     is handled by swapping the arguments (coerce(T2, T1));
#   - pairs with no rule raise a compile error.
def coerce(T1: type, T2: type):
    def coerce_error(T1: type, T2: type):
        compile_error("cannot coerce types '" + T1.__name__ + "' and '" + T2.__name__ + "'")

    # Common type of two fixed-width integer types (Int[N] / UInt[N]).
    def coerce_ints_helper(T1: type, T2: type):
        if isinstance(T1, Int) and isinstance(T2, Int):
            if T1.N >= T2.N:
                return T1()
            else:
                return T2()
        elif isinstance(T1, UInt) and isinstance(T2, UInt):
            if T1.N >= T2.N:
                return T1()
            else:
                return T2()
        elif isinstance(T1, Int) and isinstance(T2, UInt):
            if T1.N > T2.N:
                return T1()
            else:
                # The signed type cannot hold every unsigned value: use a
                # signed type twice as wide, or float once 64 bits is reached.
                if T2.N == 8 or T2.N == 16 or T2.N == 32:
                    return Int[2 * T2.N]()
                elif T2.N >= 64:
                    return float()
        elif isinstance(T1, UInt) and isinstance(T2, Int):
            return coerce_ints_helper(T2, T1)
        else:
            coerce_error(T1, T2)

    # Like coerce_ints_helper, but lets T1/T2 be types (such as int or byte)
    # that are represented by the fixed-width stand-ins I1/I2 during the
    # computation. When the result equals a stand-in, the original type is
    # returned instead.
    def coerce_ints(T1: type, T2: type, I1: type = T1, I2: type = T2):
        if T1 is T2:
            return T1()

        if T1 is I1 and T2 is I2:
            return coerce_ints_helper(T1, T2)

        if T1 is I1:
            x = coerce_ints_helper(T1, I2)
            if type(x) is I2:
                return T2()
            else:
                return x

        if T2 is I2:
            x = coerce_ints_helper(I1, T2)
            if type(x) is I1:
                return T1()
            else:
                return x

        x = coerce_ints_helper(I1, I2)
        if type(x) is I1:
            return T1()
        elif type(x) is I2:
            return T2()
        else:
            return x

    # Need to support specially:
    #   - bool
    #   - int
    #   - byte
    #   - Int[N]
    #   - UInt[N]
    #   - float
    #   - float32
    #   - complex
    #   - complex64

    if T1 is T2:
        return T1()

    if T1 is NoneType:
        return T2()

    if T2 is NoneType:
        return T1()

    if ((isinstance(T1, datetime64) or isinstance(T1, timedelta64)) and
        (isinstance(T2, datetime64) or isinstance(T2, timedelta64))):
        return dt_promote(T1(), T2())

    if T1 is bool:
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return int()
        elif T2 is byte:
            return byte()
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            return T2()
        elif T2 is float:
            return float()
        elif T2 is float16:
            return float16()
        elif T2 is float32:
            return float32()
        elif T2 is complex:
            return complex()
        elif T2 is complex64:
            return complex64()
        elif isinstance(T2, datetime64):
            return T2()
        elif isinstance(T2, timedelta64):
            return T2()
    elif T1 is int:
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return coerce(T2, T1)
        elif T2 is byte:
            return int()
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            # int takes part in integer coercion as Int[64].
            return coerce_ints(int, T2, I1=Int[64])
        elif T2 is float:
            return float()
        elif T2 is float16:
            return float()
        elif T2 is float32:
            return float()
        elif T2 is complex:
            return complex()
        elif T2 is complex64:
            return complex()
        elif isinstance(T2, datetime64):
            return T2()
        elif isinstance(T2, timedelta64):
            return T2()
    elif T1 is byte:
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return coerce(T2, T1)
        elif T2 is byte:
            return coerce(T2, T1)
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            # byte takes part in integer coercion as UInt[8]. The original
            # type passed here must be byte (it was `int`), so that a UInt[8]
            # result maps back to byte instead of being widened to int.
            return coerce_ints(byte, T2, I1=UInt[8])
        elif T2 is float:
            return float()
        elif T2 is float16:
            return float16()
        elif T2 is float32:
            return float32()
        elif T2 is complex:
            return complex()
        elif T2 is complex64:
            return complex64()
        elif isinstance(T2, datetime64):
            return T2()
        elif isinstance(T2, timedelta64):
            return T2()
    elif isinstance(T1, Int):
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return coerce(T2, T1)
        elif T2 is byte:
            return coerce(T2, T1)
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            return coerce_ints(T1, T2)
        elif T2 is float:
            return float()
        elif T2 is float16:
            # Choose the float width by the integer's bit width.
            if T1.N >= 32:
                return float()
            elif T1.N >= 16:
                return float32()
            else:
                return float16()
        elif T2 is float32:
            if T1.N >= 32:
                return float()
            else:
                return float32()
        elif T2 is complex:
            return complex()
        elif T2 is complex64:
            if T1.N >= 32:
                return complex()
            else:
                return complex64()
        elif isinstance(T2, datetime64):
            return T2()
        elif isinstance(T2, timedelta64):
            return T2()
    elif isinstance(T1, UInt):
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return coerce(T2, T1)
        elif T2 is byte:
            return coerce(T2, T1)
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            return coerce_ints(T1, T2)
        elif T2 is float:
            return float()
        elif T2 is float16:
            # Choose the float width by the integer's bit width.
            if T1.N >= 32:
                return float()
            elif T1.N >= 16:
                return float32()
            else:
                return float16()
        elif T2 is float32:
            if T1.N >= 32:
                return float()
            else:
                return float32()
        elif T2 is complex:
            return complex()
        elif T2 is complex64:
            if T1.N >= 32:
                return complex()
            else:
                return complex64()
        elif isinstance(T2, datetime64):
            return T2()
        elif isinstance(T2, timedelta64):
            return T2()
    elif T1 is float:
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return coerce(T2, T1)
        elif T2 is byte:
            return coerce(T2, T1)
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            return coerce(T2, T1)
        elif T2 is float:
            return coerce(T2, T1)
        elif T2 is float16:
            return float()
        elif T2 is float32:
            return float()
        elif T2 is complex:
            return complex()
        elif T2 is complex64:
            return complex()
        elif isinstance(T2, datetime64):
            return T2()
        elif isinstance(T2, timedelta64):
            return T2()
    elif T1 is float16:
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return coerce(T2, T1)
        elif T2 is byte:
            return coerce(T2, T1)
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            return coerce(T2, T1)
        elif T2 is float:
            return coerce(T2, T1)
        elif T2 is float16:
            return coerce(T2, T1)
        elif T2 is float32:
            return float32()
        elif T2 is complex:
            return complex()
        elif T2 is complex64:
            return complex64()
        elif isinstance(T2, datetime64):
            return T2()
        elif isinstance(T2, timedelta64):
            return T2()
    elif T1 is float32:
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return coerce(T2, T1)
        elif T2 is byte:
            return coerce(T2, T1)
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            return coerce(T2, T1)
        elif T2 is float:
            return coerce(T2, T1)
        elif T2 is float16:
            return coerce(T2, T1)
        elif T2 is float32:
            return coerce(T2, T1)
        elif T2 is complex:
            return complex()
        elif T2 is complex64:
            return complex64()
        elif isinstance(T2, datetime64):
            return T2()
        elif isinstance(T2, timedelta64):
            return T2()
    elif T1 is complex:
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return coerce(T2, T1)
        elif T2 is byte:
            return coerce(T2, T1)
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            return coerce(T2, T1)
        elif T2 is float:
            return coerce(T2, T1)
        elif T2 is float16:
            return coerce(T2, T1)
        elif T2 is float32:
            return coerce(T2, T1)
        elif T2 is complex:
            return coerce(T2, T1)
        elif T2 is complex64:
            return complex()
        elif isinstance(T2, datetime64):
            return T2()
        elif isinstance(T2, timedelta64):
            return T2()
    elif T1 is complex64:
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return coerce(T2, T1)
        elif T2 is byte:
            return coerce(T2, T1)
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            return coerce(T2, T1)
        elif T2 is float:
            return coerce(T2, T1)
        elif T2 is float16:
            return coerce(T2, T1)
        elif T2 is float32:
            return coerce(T2, T1)
        elif T2 is complex:
            return coerce(T2, T1)
        elif T2 is complex64:
            return coerce(T2, T1)
        elif isinstance(T2, datetime64):
            return T2()
        elif isinstance(T2, timedelta64):
            return T2()
    elif isinstance(T1, datetime64):
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return coerce(T2, T1)
        elif T2 is byte:
            return coerce(T2, T1)
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            return coerce(T2, T1)
        elif T2 is float:
            return coerce(T2, T1)
        elif T2 is float16:
            return coerce(T2, T1)
        elif T2 is float32:
            return coerce(T2, T1)
        elif T2 is complex:
            return coerce(T2, T1)
        elif T2 is complex64:
            return coerce(T2, T1)
    elif isinstance(T1, timedelta64):
        if T2 is bool:
            return coerce(T2, T1)
        elif T2 is int:
            return coerce(T2, T1)
        elif T2 is byte:
            return coerce(T2, T1)
        elif isinstance(T2, Int) or isinstance(T2, UInt):
            return coerce(T2, T1)
        elif T2 is float:
            return coerce(T2, T1)
        elif T2 is float16:
            return coerce(T2, T1)
        elif T2 is float32:
            return coerce(T2, T1)
        elif T2 is complex:
            return coerce(T2, T1)
        elif T2 is complex64:
            return coerce(T2, T1)

    # No rule matched this pair of types.
    coerce_error(T1, T2)
# Operand types to use for a binary operation on T1 and T2, returned as a
# pair of default-constructed instances. If either side is a datetime64 or
# timedelta64 the operands keep their own types; otherwise both sides use
# the common type from coerce.
def op_types(T1: type, T2: type):
    if (isinstance(T1, datetime64) or isinstance(T2, datetime64) or
        isinstance(T1, timedelta64) or isinstance(T2, timedelta64)):
        return T1(), T2()

    common = coerce(T1, T2)
    return (common, common)
# Convert x to a floating-point value, choosing the width by x's type:
#   - float / float32 / float16 are returned unchanged;
#   - float128 is truncated to float;
#   - int and 64/32-bit integers become float;
#   - 16-bit integers become float32;
#   - 8-bit integers and byte become float16;
#   - datetime64 / timedelta64 become NaN when flagged _nat, otherwise their
#     .value as float;
#   - anything else goes through float(x).
def to_float(x):
    if (isinstance(x, float) or
        isinstance(x, float32) or
        isinstance(x, float16)):
        return x
    elif isinstance(x, float128):
        return fptrunc(x, float)
    elif (isinstance(x, int) or
          isinstance(x, i64) or
          isinstance(x, u64) or
          isinstance(x, i32) or
          isinstance(x, u32)):
        return cast(x, float)
    elif (isinstance(x, i16) or
          isinstance(x, u16)):
        return cast(x, float32)
    elif (isinstance(x, i8) or
          isinstance(x, u8) or
          isinstance(x, byte)):
        return cast(x, float16)
    elif (isinstance(x, datetime64) or
          isinstance(x, timedelta64)):
        return nan64() if x._nat else cast(x.value, float)
    else:
        return float(x)
# Translate a static dtype string (e.g. 'f8', 'int32', '<u2', 'M8[ns]') into
# a default-constructed value of the matching type.
#
# A leading '<', '|' or '=' is stripped and the rest is parsed again; a
# leading '>' is rejected at compile time, as is any unrecognized string.
def str_to_dtype(s: Static[str]):
    # Byte-order prefixes.
    if s[:1] == '>':
        compile_error("big-endian data types are not supported")
    if s[:1] == '<' or s[:1] == '|' or s[:1] == '=':
        return str_to_dtype(s[1:])

    # Boolean.
    if s == '?' or s == 'bool' or s == 'bool_':
        return bool()

    # Signed integers.
    if s == 'b' or s == 'byte' or s == 'i1' or s == 'int8':
        return i8()
    if s == 'i2' or s == 'int16' or s == 'short':
        return i16()
    if s == 'i4' or s == 'int32':
        return i32()
    if s == 'i8' or s == 'int64' or s == 'long' or s == 'longlong' or s == 'int' or s == 'int_':
        return int()

    # Unsigned integers.
    if s == 'B' or s == 'ubyte' or s == 'u1' or s == 'uint8':
        return u8()
    if s == 'u2' or s == 'uint16' or s == 'ushort':
        return u16()
    if s == 'u4' or s == 'uint32':
        return u32()
    if s == 'u8' or s == 'uint64' or s == 'ulong' or s == 'ulonglong':
        return u64()

    # Floating point.
    if s == 'f2' or s == 'float16' or s == 'half':
        return float16()
    if s == 'f4' or s == 'float32':
        return float32()
    if s == 'f8' or s == 'float64' or s == 'float' or s == 'float_':
        return float()

    # Complex.
    if s == 'c8' or s == 'complex64':
        return complex64()
    if s == 'c16' or s == 'complex128' or s == 'complex' or s == 'complex_':
        return complex()

    # Date/time: bare codes are the generic unit; bracketed forms are parsed
    # by dt_parse.
    if s == 'm8':
        return timedelta64['generic', 1]()
    if s == 'M8':
        return datetime64['generic', 1]()
    if s[:11] == 'datetime64[' or s[:12] == 'timedelta64[' or s[:3] == 'm8[' or s[:3] == 'M8[':
        return dt_parse(s)

    # TODO: add these when applicable
    # O -> Python objects
    # S / a -> zero-terminated bytes
    # U -> Unicode string
    # V -> raw data

    compile_error("data type '" + s + "' not understood")
# Return the short type-code string for `dtype` (e.g. 'f8', 'i4', 'M8[ns]').
# With include_byteorder=True the code is prefixed with '|' for single-byte
# types and '<' for everything else. Unsupported types are rejected at
# compile time.
def dtype_to_str(dtype: type, include_byteorder: bool = True):
    prefixed = include_byteorder

    # Single-byte types.
    if dtype is bool:
        return '|b1' if prefixed else 'b1'
    elif dtype is byte:
        return '|u1' if prefixed else 'u1'
    elif dtype is i8:
        return '|i1' if prefixed else 'i1'
    elif dtype is u8:
        return '|u1' if prefixed else 'u1'
    # Wider integers.
    elif dtype is i16:
        return '<i2' if prefixed else 'i2'
    elif dtype is u16:
        return '<u2' if prefixed else 'u2'
    elif dtype is i32:
        return '<i4' if prefixed else 'i4'
    elif dtype is u32:
        return '<u4' if prefixed else 'u4'
    elif dtype is i64 or dtype is int:
        return '<i8' if prefixed else 'i8'
    elif dtype is u64:
        return '<u8' if prefixed else 'u8'
    # Floating point and complex.
    elif dtype is float16:
        return '<f2' if prefixed else 'f2'
    elif dtype is float32:
        return '<f4' if prefixed else 'f4'
    elif dtype is float:
        return '<f8' if prefixed else 'f8'
    elif dtype is complex64:
        return '<c8' if prefixed else 'c8'
    elif dtype is complex:
        return '<c16' if prefixed else 'c16'
    # Date/time: the unit is appended in brackets unless it is 'generic';
    # the multiplier is written only when it is not 1.
    elif isinstance(dtype, timedelta64):
        if dtype.base == 'generic':
            return '<m8' if prefixed else 'm8'
        elif dtype.num == 1:
            return f'<m8[{dtype.base}]' if prefixed else f'm8[{dtype.base}]'
        else:
            return f'<m8[{dtype.num}{dtype.base}]' if prefixed else f'm8[{dtype.num}{dtype.base}]'
    elif isinstance(dtype, datetime64):
        if dtype.base == 'generic':
            return '<M8' if prefixed else 'M8'
        elif dtype.num == 1:
            return f'<M8[{dtype.base}]' if prefixed else f'M8[{dtype.base}]'
        else:
            return f'<M8[{dtype.num}{dtype.base}]' if prefixed else f'M8[{dtype.num}{dtype.base}]'
    else:
        # TODO: add these when applicable
        # O -> Python objects
        # S / a -> zero-terminated bytes
        # U -> Unicode string
        # V -> raw data
        compile_error("unsupported data type: '" + dtype.__name__ + "'")