codon/stdlib/experimental/simd.codon

886 lines
34 KiB
Python

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
# Fixed-width SIMD vector with N lanes of element type T. The inline-LLVM
# methods of this class treat a Vec[T, N] value as the LLVM vector `<N x T>`.
@tuple(container=False) # disallow default __getitem__
class Vec[T, N: Static[int]]:
# Class-level constant vectors: every byte lane set to 0x00 / 0xff for the
# 16- and 32-lane u8 vectors, plus a 4-lane f64 vector of 0.0. The u8
# constants are used by the mask helpers (__ne__ and __eq__ against a bool).
ZERO_16x8i = Vec[u8,16](u8(0))
FF_16x8i = Vec[u8,16](u8(0xff))
ZERO_32x8i = Vec[u8,32](u8(0))
FF_32x8i = Vec[u8,32](u8(0xff))
ZERO_4x64f = Vec[f64,4](0.0)
# Integer broadcast constructors. Each one inserts `val` into lane 0 of an
# undefined vector and then shuffles with an all-zero mask, which copies
# lane 0 into every lane of the result.

# Broadcast a u8 into 16 lanes.
@llvm
def _mm_set1_epi8(val: u8) -> Vec[u8, 16]:
    %lane0 = insertelement <16 x i8> undef, i8 %val, i32 0
    %splat = shufflevector <16 x i8> %lane0, <16 x i8> undef, <16 x i32> zeroinitializer
    ret <16 x i8> %splat

# Broadcast a u32 into 4 lanes.
@llvm
def _mm_set1_epi32(val: u32) -> Vec[u32, 4]:
    %lane0 = insertelement <4 x i32> undef, i32 %val, i32 0
    %splat = shufflevector <4 x i32> %lane0, <4 x i32> undef, <4 x i32> zeroinitializer
    ret <4 x i32> %splat

# Broadcast a u8 into 32 lanes.
@llvm
def _mm256_set1_epi8(val: u8) -> Vec[u8, 32]:
    %lane0 = insertelement <32 x i8> undef, i8 %val, i32 0
    %splat = shufflevector <32 x i8> %lane0, <32 x i8> undef, <32 x i32> zeroinitializer
    ret <32 x i8> %splat

# Broadcast a u32 into 8 lanes.
@llvm
def _mm256_set1_epi32(val: u32) -> Vec[u32, 8]:
    %lane0 = insertelement <8 x i32> undef, i32 %val, i32 0
    %splat = shufflevector <8 x i32> %lane0, <8 x i32> undef, <8 x i32> zeroinitializer
    ret <8 x i32> %splat

# Broadcast a u64 into 4 lanes.
@llvm
def _mm256_set1_epi64x(val: u64) -> Vec[u64, 4]:
    %lane0 = insertelement <4 x i64> undef, i64 %val, i32 0
    %splat = shufflevector <4 x i64> %lane0, <4 x i64> undef, <4 x i32> zeroinitializer
    ret <4 x i64> %splat

# Broadcast a u64 into 8 lanes.
@llvm
def _mm512_set1_epi64(val: u64) -> Vec[u64, 8]:
    %lane0 = insertelement <8 x i64> undef, i64 %val, i32 0
    %splat = shufflevector <8 x i64> %lane0, <8 x i64> undef, <8 x i32> zeroinitializer
    ret <8 x i64> %splat
# Vector loads that declare only byte alignment (`align 1`) on the source
# pointer. `data` is untyped on the Codon side; the IR casts it to a pointer
# to the vector type before loading.

# Load 4 x u32 starting at `data`.
@llvm
def _mm_load_epi32(data) -> Vec[u32, 4]:
    %vptr = bitcast i32* %data to <4 x i32>*
    %lanes = load <4 x i32>, <4 x i32>* %vptr, align 1
    ret <4 x i32> %lanes

# Load 16 x u8 starting at `data`.
@llvm
def _mm_loadu_si128(data) -> Vec[u8, 16]:
    %vptr = bitcast i8* %data to <16 x i8>*
    %lanes = load <16 x i8>, <16 x i8>* %vptr, align 1
    ret <16 x i8> %lanes

# Load 32 x u8 starting at `data`.
@llvm
def _mm256_loadu_si256(data) -> Vec[u8, 32]:
    %vptr = bitcast i8* %data to <32 x i8>*
    %lanes = load <32 x i8>, <32 x i8>* %vptr, align 1
    ret <32 x i8> %lanes
# Integer vector loads for the 256/512-bit widths.
#
# The loads are emitted with `align 1`: the constructors hand these helpers
# list/str data pointers (optionally advanced by an arbitrary offset), and a
# `load` without an explicit `align` makes LLVM assume the vector type's ABI
# alignment, which such pointers are not guaranteed to satisfy. This matches
# the 128-bit `_mm_load_epi32` helper, which already uses `align 1`.

# Load 8 x u32 starting at `data`.
@llvm
def _mm256_load_epi32(data) -> Vec[u32, 8]:
    %vptr = bitcast i32* %data to <8 x i32>*
    %lanes = load <8 x i32>, <8 x i32>* %vptr, align 1
    ret <8 x i32> %lanes

# Load 4 x u64 starting at `data`.
@llvm
def _mm256_load_epi64(data) -> Vec[u64, 4]:
    %vptr = bitcast i64* %data to <4 x i64>*
    %lanes = load <4 x i64>, <4 x i64>* %vptr, align 1
    ret <4 x i64> %lanes

# Load 8 x u64 starting at `data`.
@llvm
def _mm512_load_epi64(data) -> Vec[u64, 8]:
    %vptr = bitcast i64* %data to <8 x i64>*
    %lanes = load <8 x i64>, <8 x i64>* %vptr, align 1
    ret <8 x i64> %lanes
# Floating-point broadcast constructors: insert `val` at lane 0, then shuffle
# with an all-zero mask so that every lane receives lane 0.

# Broadcast an f32 into 8 lanes.
@llvm
def _mm256_set1_ps(val: f32) -> Vec[f32, 8]:
    %lane0 = insertelement <8 x float> undef, float %val, i32 0
    %splat = shufflevector <8 x float> %lane0, <8 x float> undef, <8 x i32> zeroinitializer
    ret <8 x float> %splat

# Broadcast an f64 into 4 lanes.
@llvm
def _mm256_set1_pd(val: f64) -> Vec[f64, 4]:
    %lane0 = insertelement <4 x double> undef, double %val, i64 0
    %splat = shufflevector <4 x double> %lane0, <4 x double> undef, <4 x i32> zeroinitializer
    ret <4 x double> %splat

# Broadcast an f64 into 8 lanes.
@llvm
def _mm512_set1_pd(val: f64) -> Vec[f64, 8]:
    %lane0 = insertelement <8 x double> undef, double %val, i64 0
    %splat = shufflevector <8 x double> %lane0, <8 x double> undef, <8 x i32> zeroinitializer
    ret <8 x double> %splat

# Broadcast an f32 into 16 lanes.
@llvm
def _mm512_set1_ps(val: f32) -> Vec[f32, 16]:
    %lane0 = insertelement <16 x float> undef, float %val, i32 0
    %splat = shufflevector <16 x float> %lane0, <16 x float> undef, <16 x i32> zeroinitializer
    ret <16 x float> %splat
# Unaligned floating-point vector loads.
#
# Each load carries an explicit `align 1`. Without it LLVM assumes the ABI
# alignment of the vector type, so loading from an element pointer that is
# only aligned for a single float/double (e.g. a list's data pointer plus an
# offset, as the constructors pass) would be undefined behaviour.

# Load 8 x f32 starting at `data`.
@llvm
def _mm256_loadu_ps(data: Ptr[f32]) -> Vec[f32, 8]:
    %vptr = bitcast float* %data to <8 x float>*
    %lanes = load <8 x float>, <8 x float>* %vptr, align 1
    ret <8 x float> %lanes

# Load 4 x f64 starting at `data`.
@llvm
def _mm256_loadu_pd(data: Ptr[f64]) -> Vec[f64, 4]:
    %vptr = bitcast double* %data to <4 x double>*
    %lanes = load <4 x double>, <4 x double>* %vptr, align 1
    ret <4 x double> %lanes

# Load 8 x f64 starting at `data`.
@llvm
def _mm512_loadu_pd(data: Ptr[f64]) -> Vec[f64, 8]:
    %vptr = bitcast double* %data to <8 x double>*
    %lanes = load <8 x double>, <8 x double>* %vptr, align 1
    ret <8 x double> %lanes

# Load 16 x f32 starting at `data`.
@llvm
def _mm512_loadu_ps(data: Ptr[f32]) -> Vec[f32, 16]:
    %vptr = bitcast float* %data to <16 x float>*
    %lanes = load <16 x float>, <16 x float>* %vptr, align 1
    ret <16 x float> %lanes
# Widening helpers: keep the low half of a byte vector and sign-extend each
# kept byte to a 32-bit lane.

# Low 8 bytes of a 16-byte vector -> 8 sign-extended 32-bit lanes.
@llvm
def _mm256_cvtepi8_epi32(vec: Vec[u8, 16]) -> Vec[u32, 8]:
    %low = shufflevector <16 x i8> %vec, <16 x i8> undef, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
    %wide = sext <8 x i8> %low to <8 x i32>
    ret <8 x i32> %wide

# Low 16 bytes of a 32-byte vector -> 16 sign-extended 32-bit lanes.
# Note that, despite the "epi64" in the name, the result lanes are 32 bits wide.
@llvm
def _mm512_cvtepi8_epi64(vec: Vec[u8, 32]) -> Vec[u32, 16]:
    %low = shufflevector <32 x i8> %vec, <32 x i8> undef, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
    %wide = sext <16 x i8> %low to <16 x i32>
    ret <16 x i32> %wide
# Bit-level reinterpretation of integer vectors as floating-point vectors.
# These are `bitcast`s: the lane bits are unchanged, no numeric conversion
# takes place.

# 8 x u32 bits viewed as 8 x f32.
@llvm
def _mm256_castsi256_ps(vec: Vec[u32, 8]) -> Vec[f32, 8]:
    %as_fp = bitcast <8 x i32> %vec to <8 x float>
    ret <8 x float> %as_fp

# 4 x u64 bits viewed as 4 x f64.
@llvm
def _mm256_castsi256_pd(vec: Vec[u64, 4]) -> Vec[f64, 4]:
    %as_fp = bitcast <4 x i64> %vec to <4 x double>
    ret <4 x double> %as_fp

# 16 x u32 bits viewed as 16 x f32.
@llvm
def _mm512_castsi512_ps(vec: Vec[u32, 16]) -> Vec[f32, 16]:
    %as_fp = bitcast <16 x i32> %vec to <16 x float>
    ret <16 x float> %as_fp
def __new__(x, T: type, N: Static[int]) -> Vec[T, N]:
    """
    Build a Vec[T, N] from `x`, dispatching at compile time on (T, N) and on
    the type of `x`.

    Scalars are broadcast to every lane; Ptr, List and str arguments are
    loaded from memory starting at the pointer / the container's data pointer.
    Handled (T, N) pairs: (u8, 16), (u8, 32), (u32, 4), (u32, 8), (u64, 4),
    (u64, 8), (f32, 8), (f32, 16), (f64, 4), (f64, 8). Any combination that
    matches no branch ends in compile_error.

    Plain `int` arguments for the u32/u64 vectors must be non-negative; this
    is checked with an assert before the value is converted.
    """
    if isinstance(T, u8) and N == 16:
        if isinstance(x, u8) or isinstance(x, byte):  # TODO: u8<->byte
            return Vec._mm_set1_epi8(x)
        if isinstance(x, Ptr[u8]) or isinstance(x, Ptr[byte]):
            return Vec._mm_loadu_si128(x)
        if isinstance(x, str):
            return Vec._mm_loadu_si128(x.ptr)
    if isinstance(T, u8) and N == 32:
        if isinstance(x, u8) or isinstance(x, byte):  # TODO: u8<->byte
            return Vec._mm256_set1_epi8(x)
        if isinstance(x, Ptr[u8]) or isinstance(x, Ptr[byte]):
            return Vec._mm256_loadu_si256(x)
        if isinstance(x, str):
            return Vec._mm256_loadu_si256(x.ptr)
    if isinstance(T, u32) and N == 4:
        if isinstance(x, int):
            assert x >= 0, "SIMD: No support for negative int vectors added yet."
            return Vec._mm_set1_epi32(u32(x))
        if isinstance(x, u32):
            return Vec._mm_set1_epi32(x)
        if isinstance(x, Ptr[u32]):
            return Vec._mm_load_epi32(x)
        if isinstance(x, List[u32]):
            return Vec._mm_load_epi32(x.arr.ptr)
        if isinstance(x, str):
            return Vec._mm_load_epi32(x.ptr)
    if isinstance(T, u32) and N == 8:
        if isinstance(x, int):
            assert x >= 0, "SIMD: No support for negative int vectors added yet."
            return Vec._mm256_set1_epi32(u32(x))
        # The element type of this vector is u32, so the scalar / pointer /
        # list forms must be u32-based (mirrors the (u32, 4) branch above);
        # _mm256_set1_epi32 takes a u32 and _mm256_load_epi32 reads i32 lanes.
        if isinstance(x, u32):
            return Vec._mm256_set1_epi32(x)
        if isinstance(x, Ptr[u32]):
            return Vec._mm256_load_epi32(x)
        if isinstance(x, List[u32]):
            return Vec._mm256_load_epi32(x.arr.ptr)
        if isinstance(x, str):
            return Vec._mm256_load_epi32(x.ptr)
    if isinstance(T, u64) and N == 4:
        if isinstance(x, int):
            assert x >= 0, "SIMD: No support for negative int vectors added yet."
            return Vec._mm256_set1_epi64x(u64(x))
        if isinstance(x, u64):
            return Vec._mm256_set1_epi64x(x)
        if isinstance(x, Ptr[u64]):
            return Vec._mm256_load_epi64(x)
        if isinstance(x, List[u64]):
            return Vec._mm256_load_epi64(x.arr.ptr)
        if isinstance(x, str):
            return Vec._mm256_load_epi64(x.ptr)
    if isinstance(T, u64) and N == 8:
        if isinstance(x, int):
            assert x >= 0, "SIMD: No support for negative int vectors added yet."
            return Vec._mm512_set1_epi64(u64(x))
        if isinstance(x, u64):
            return Vec._mm512_set1_epi64(x)
        if isinstance(x, Ptr[u64]):
            return Vec._mm512_load_epi64(x)
        if isinstance(x, List[u64]):
            return Vec._mm512_load_epi64(x.arr.ptr)
        if isinstance(x, str):
            return Vec._mm512_load_epi64(x.ptr)
    if isinstance(T, f32) and N == 8:
        if isinstance(x, f32):
            return Vec._mm256_set1_ps(x)
        if isinstance(x, Ptr[f32]):
            return Vec._mm256_loadu_ps(x)
        if isinstance(x, List[f32]):
            return Vec._mm256_loadu_ps(x.arr.ptr)
        if isinstance(x, Vec[u8, 16]):
            # Sign-extend the low 8 bytes to 32-bit lanes, then reinterpret
            # those bits as f32 (a bitcast, not a numeric conversion).
            return Vec._mm256_castsi256_ps(Vec._mm256_cvtepi8_epi32(x))
    if isinstance(T, f32) and N == 16:
        if isinstance(x, f32):
            return Vec._mm512_set1_ps(x)
        if isinstance(x, Ptr[f32]):
            return Vec._mm512_loadu_ps(x)
        if isinstance(x, List[f32]):
            return Vec._mm512_loadu_ps(x.arr.ptr)
        if isinstance(x, Vec[u8, 32]):
            # Same as above for the low 16 bytes of a 32-byte vector.
            return Vec._mm512_castsi512_ps(Vec._mm512_cvtepi8_epi64(x))
    if isinstance(T, f64) and N == 4:
        if isinstance(x, f64):
            return Vec._mm256_set1_pd(x)
        if isinstance(x, Ptr[f64]):
            return Vec._mm256_loadu_pd(x)
        if isinstance(x, List[f64]):
            return Vec._mm256_loadu_pd(x.arr.ptr)
    if isinstance(T, f64) and N == 8:
        if isinstance(x, f64):
            return Vec._mm512_set1_pd(x)
        if isinstance(x, Ptr[f64]):
            return Vec._mm512_loadu_pd(x)
        if isinstance(x, List[f64]):
            return Vec._mm512_loadu_pd(x.arr.ptr)
    compile_error("invalid SIMD vector constructor")
def __new__(x: str, offset: int = 0) -> Vec[u8, N]:
    """Load N bytes of the string `x`, starting `offset` bytes into its data."""
    start = x.ptr + offset
    return Vec(start, u8, N)

def __new__(x: List[T], offset: int = 0) -> Vec[T, N]:
    """Load N elements of the list `x`, starting at element index `offset`."""
    start = x.arr.ptr + offset
    return Vec(start, T, N)

def __new__(x) -> Vec[T, N]:
    """Forward `x` to the (x, T, N) constructor using this class's T and N."""
    return Vec(x, T, N)
# Lane-wise byte equality. The i1 comparison result is sign-extended, so an
# equal lane becomes 0xff and a differing lane becomes 0x00.

@llvm
def _mm_cmpeq_epi8(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]:
    %same = icmp eq <16 x i8> %x, %y
    %lanemask = sext <16 x i1> %same to <16 x i8>
    ret <16 x i8> %lanemask

def __eq__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]:
    """Byte-wise `==` for 16-lane vectors; returns a 0xff/0x00 lane mask."""
    lanemask = Vec._mm_cmpeq_epi8(self, other)
    return lanemask

@llvm
def _mm256_cmpeq_epi8(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]:
    %same = icmp eq <32 x i8> %x, %y
    %lanemask = sext <32 x i1> %same to <32 x i8>
    ret <32 x i8> %lanemask

def __eq__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]:
    """Byte-wise `==` for 32-lane vectors; returns a 0xff/0x00 lane mask."""
    lanemask = Vec._mm256_cmpeq_epi8(self, other)
    return lanemask
# "and-not": computes (~x) & y on every byte. `~x` is formed by XOR-ing with
# an all-ones constant vector.

@llvm
def _mm_andnot_si128(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]:
    %not_x = xor <16 x i8> %x, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
    %kept = and <16 x i8> %y, %not_x
    ret <16 x i8> %kept

def __ne__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]:
    """Byte-wise `!=`: the complement of the equality mask."""
    equal_mask = self == other
    return Vec._mm_andnot_si128(equal_mask, Vec.FF_16x8i)

@llvm
def _mm256_andnot_si256(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]:
    %not_x = xor <32 x i8> %x, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
    %kept = and <32 x i8> %y, %not_x
    ret <32 x i8> %kept

def __ne__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]:
    """Byte-wise `!=`: the complement of the equality mask."""
    equal_mask = self == other
    return Vec._mm256_andnot_si256(equal_mask, Vec.FF_32x8i)
def __eq__(self: Vec[u8, 16], other: bool) -> Vec[u8, 16]:
    """
    Compare a byte mask against a boolean.

    `self == False` yields the bitwise complement of `self` (~self & 0xff).
    `self == True` yields `self` unchanged. The and-not against the zero
    vector used here before evaluated to (~self & 0), i.e. an all-zero vector
    for every input, so the True case never reported any lane.
    """
    if other:
        return self
    return Vec._mm_andnot_si128(self, Vec.FF_16x8i)

def __eq__(self: Vec[u8, 32], other: bool) -> Vec[u8, 32]:
    """
    32-lane variant of the mask-vs-bool comparison: `self == False` returns
    the bitwise complement of `self`, `self == True` returns `self`.
    """
    if other:
        return self
    return Vec._mm256_andnot_si256(self, Vec.FF_32x8i)
# Bitwise AND for the integer vector shapes u8x16, u8x32 and u64x4, each
# paired with the `&` operator overload that forwards to it.

@llvm
def _mm_and_si128(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]:
    %both = and <16 x i8> %x, %y
    ret <16 x i8> %both

def __and__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]:
    """`self & other` on 16 bytes."""
    both = Vec._mm_and_si128(self, other)
    return both

@llvm
def _mm_and_si256(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]:
    %both = and <32 x i8> %x, %y
    ret <32 x i8> %both

def __and__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]:
    """`self & other` on 32 bytes."""
    both = Vec._mm_and_si256(self, other)
    return both

@llvm
def _mm256_and_si256(x: Vec[u64, 4], y: Vec[u64, 4]) -> Vec[u64, 4]:
    %both = and <4 x i64> %x, %y
    ret <4 x i64> %both

def __and__(self: Vec[u64, 4], other: Vec[u64, 4]) -> Vec[u64, 4]:
    """`self & other` on four 64-bit lanes."""
    both = Vec._mm256_and_si256(self, other)
    return both
# Bitwise AND on float vectors: both operands are reinterpreted as 32-bit
# integers, AND-ed, and the resulting bits are reinterpreted as floats.

@llvm
def _mm256_and_ps(x: Vec[f32, 8], y: Vec[f32, 8]) -> Vec[f32, 8]:
    %xbits = bitcast <8 x float> %x to <8 x i32>
    %ybits = bitcast <8 x float> %y to <8 x i32>
    %both = and <8 x i32> %xbits, %ybits
    %as_fp = bitcast <8 x i32> %both to <8 x float>
    ret <8 x float> %as_fp

def __and__(self: Vec[f32, 8], other: Vec[f32, 8]) -> Vec[f32, 8]:
    """Bit-pattern `&` of two 8-lane f32 vectors."""
    both = Vec._mm256_and_ps(self, other)
    return both

@llvm
def _mm512_and_ps(x: Vec[f32, 16], y: Vec[f32, 16]) -> Vec[f32, 16]:
    %xbits = bitcast <16 x float> %x to <16 x i32>
    %ybits = bitcast <16 x float> %y to <16 x i32>
    %both = and <16 x i32> %xbits, %ybits
    %as_fp = bitcast <16 x i32> %both to <16 x float>
    ret <16 x float> %as_fp

def __and__(self: Vec[f32, 16], other: Vec[f32, 16]) -> Vec[f32, 16]:
    """Bit-pattern `&` of two 16-lane f32 vectors."""
    both = Vec._mm512_and_ps(self, other)
    return both
# Bitwise OR for the byte vectors u8x16 and u8x32, each paired with the `|`
# operator overload that forwards to it.

@llvm
def _mm_or_si128(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]:
    %either = or <16 x i8> %x, %y
    ret <16 x i8> %either

def __or__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]:
    """`self | other` on 16 bytes."""
    either = Vec._mm_or_si128(self, other)
    return either

@llvm
def _mm_or_si256(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]:
    %either = or <32 x i8> %x, %y
    ret <32 x i8> %either

def __or__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]:
    """`self | other` on 32 bytes."""
    either = Vec._mm_or_si256(self, other)
    return either
# Bitwise OR on float vectors: both operands are reinterpreted as 32-bit
# integers, OR-ed, and the resulting bits are reinterpreted as floats.

@llvm
def _mm256_or_ps(x: Vec[f32, 8], y: Vec[f32, 8]) -> Vec[f32, 8]:
    %xbits = bitcast <8 x float> %x to <8 x i32>
    %ybits = bitcast <8 x float> %y to <8 x i32>
    %either = or <8 x i32> %xbits, %ybits
    %as_fp = bitcast <8 x i32> %either to <8 x float>
    ret <8 x float> %as_fp

def __or__(self: Vec[f32, 8], other: Vec[f32, 8]) -> Vec[f32, 8]:
    """Bit-pattern `|` of two 8-lane f32 vectors."""
    either = Vec._mm256_or_ps(self, other)
    return either

@llvm
def _mm512_or_ps(x: Vec[f32, 16], y: Vec[f32, 16]) -> Vec[f32, 16]:
    %xbits = bitcast <16 x float> %x to <16 x i32>
    %ybits = bitcast <16 x float> %y to <16 x i32>
    %either = or <16 x i32> %xbits, %ybits
    %as_fp = bitcast <16 x i32> %either to <16 x float>
    ret <16 x float> %as_fp

def __or__(self: Vec[f32, 16], other: Vec[f32, 16]) -> Vec[f32, 16]:
    """Bit-pattern `|` of two 16-lane f32 vectors."""
    either = Vec._mm512_or_ps(self, other)
    return either
# Byte shift by 8 positions on a 16-byte vector: the shuffle selects indices
# 8..23 of (vec ++ zero vector), so result lanes 0..7 hold vec[8..15] and
# result lanes 8..15 are zero.
@llvm
def _mm_bsrli_si128_8(vec: Vec[u8, 16]) -> Vec[u8, 16]:
    %moved = shufflevector <16 x i8> %vec, <16 x i8> zeroinitializer, <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23>
    ret <16 x i8> %moved

# Lane-wise floating-point addition of two 8-lane f32 vectors.
@llvm
def _mm256_add_ps(x: Vec[f32, 8], y: Vec[f32, 8]) -> Vec[f32, 8]:
    %total = fadd <8 x float> %x, %y
    ret <8 x float> %total
def __rshift__(self: Vec[u8, 16], shift: Static[int]) -> Vec[u8, 16]:
    """
    `self >> shift` for a compile-time `shift`, moving whole bytes toward
    lane 0. Only 0 (identity) and 8 (via _mm_bsrli_si128_8) are accepted;
    every other value is a compile error.
    """
    if shift == 8:
        return Vec._mm_bsrli_si128_8(self)
    elif shift == 0:
        return self
    else:
        compile_error("invalid bitshift")
# Byte shift by 16 positions on a 32-byte vector: the shuffle selects indices
# 16..47 of (vec ++ zero vector), so result lanes 0..15 hold vec[16..31] and
# result lanes 16..31 are zero.
@llvm
def _mm_bsrli_256(vec: Vec[u8, 32]) -> Vec[u8, 32]:
    %moved = shufflevector <32 x i8> %vec, <32 x i8> zeroinitializer, <32 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31, i32 32, i32 33, i32 34, i32 35, i32 36, i32 37, i32 38, i32 39, i32 40, i32 41, i32 42, i32 43, i32 44, i32 45, i32 46, i32 47>
    ret <32 x i8> %moved

def __rshift__(self: Vec[u8, 32], shift: Static[int]) -> Vec[u8, 32]:
    """
    `self >> shift` for a compile-time `shift`, moving whole bytes toward
    lane 0. Only 0 (identity) and 16 (via _mm_bsrli_256) are accepted; every
    other value is a compile error.
    """
    if shift == 16:
        return Vec._mm_bsrli_256(self)
    elif shift == 0:
        return self
    else:
        compile_error("invalid bitshift")
# @llvm # https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction
# def sum(self: Vec[f32, 8]) -> f32:
# %0 = shufflevector <8 x float> %self, <8 x float> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
# %1 = shufflevector <8 x float> %self, <8 x float> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
# %2 = fadd <4 x float> %0, %1
# %3 = shufflevector <4 x float> %2, <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 3, i32 undef>
# %4 = fadd <4 x float> %2, %3
# %5 = shufflevector <4 x float> %4, <4 x float> poison, <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
# %6 = fadd <4 x float> %4, %5
# %7 = extractelement <4 x float> %6, i32 0
# ret float %7
def sum(self: Vec[f32, 8], x: f32 = f32(0.0)) -> f32:
    """
    Horizontal sum of the eight f32 lanes, accumulated left to right in lane
    order on top of the starting value `x` (default 0.0).
    """
    total = x
    for lane in staticrange(8):
        total = total + self[lane]
    return total
def __repr__(self):
    """
    Render the vector as "<lane0,lane1,...>".

    Each lane is converted with str() before joining: scatter() returns the
    raw lane values (List[T]) and str.join only accepts strings.
    """
    return f"<{','.join(str(lane) for lane in self.scatter())}>"
# Methods below added ad-hoc. TODO: Add Intel intrinsics equivalents for them and integrate them neatly above.
# Constructors
# Width-generic broadcast: returns a Vec[T, SLS] whose every lane is `val`
# (insert at lane 0, then shuffle with an all-zero mask).
@llvm
def generic_init[SLS: Static[int]](val: T) -> Vec[T, SLS]:
    %lane0 = insertelement <{=SLS} x {=T}> undef, {=T} %val, i32 0
    %splat = shufflevector <{=SLS} x {=T}> %lane0, <{=SLS} x {=T}> undef, <{=SLS} x i32> zeroinitializer
    ret <{=SLS} x {=T}> %splat
# Width-generic load: reads SLS consecutive elements of type T starting at
# `data` into a Vec[T, SLS].
#
# The source pointer type is spelled `{=T}*` so that the element type is
# substituted into the IR (a bare `T*` is not an LLVM type). The load is
# marked `align 1` because `data` is an element pointer and nothing here
# guarantees it meets the vector type's ABI alignment.
@llvm
def generic_load[SLS: Static[int]](data: Ptr[T]) -> Vec[T, SLS]:
    %vptr = bitcast {=T}* %data to <{=SLS} x {=T}>*
    %lanes = load <{=SLS} x {=T}>, <{=SLS} x {=T}>* %vptr, align 1
    ret <{=SLS} x {=T}> %lanes
# Bitwise intrinsics
# Generic lane-wise AND / OR / XOR. Each operator has a vector-vector form
# and a vector-scalar form; the scalar form first broadcasts `other` to all
# N lanes (insert at lane 0 + all-zero shuffle mask).

@llvm
def __and__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
    %res = and <{=N} x {=T}> %self, %other
    ret <{=N} x {=T}> %res

@llvm
def __and__(self: Vec[T, N], other: T) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %splat = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %res = and <{=N} x {=T}> %self, %splat
    ret <{=N} x {=T}> %res

@llvm
def __or__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
    %res = or <{=N} x {=T}> %self, %other
    ret <{=N} x {=T}> %res

@llvm
def __or__(self: Vec[T, N], other: T) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %splat = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %res = or <{=N} x {=T}> %self, %splat
    ret <{=N} x {=T}> %res

@llvm
def __xor__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
    %res = xor <{=N} x {=T}> %self, %other
    ret <{=N} x {=T}> %res

@llvm
def __xor__(self: Vec[T, N], other: T) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %splat = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %res = xor <{=N} x {=T}> %self, %splat
    ret <{=N} x {=T}> %res
# Generic lane-wise shifts. `<<` uses `shl`; `>>` uses `lshr` (logical shift,
# zero-filling). The scalar forms broadcast the shift amount to all N lanes.

@llvm
def __lshift__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
    %res = shl <{=N} x {=T}> %self, %other
    ret <{=N} x {=T}> %res

@llvm
def __lshift__(self: Vec[T, N], other: T) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %amount = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %res = shl <{=N} x {=T}> %self, %amount
    ret <{=N} x {=T}> %res

@llvm
def __rshift__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
    %res = lshr <{=N} x {=T}> %self, %other
    ret <{=N} x {=T}> %res

@llvm
def __rshift__(self: Vec[T, N], other: T) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %amount = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %res = lshr <{=N} x {=T}> %self, %amount
    ret <{=N} x {=T}> %res
# Invert every lane of a 1-bit mask vector by XOR-ing it with an all-ones
# (i1 1) vector.
@llvm
def bit_flip(self: Vec[u1, N]) -> Vec[u1, N]:
    %lane0 = insertelement <{=N} x i1> undef, i1 1, i32 0
    %ones = shufflevector <{=N} x i1> %lane0, <{=N} x i1> undef, <{=N} x i32> zeroinitializer
    %flipped = xor <{=N} x i1> %self, %ones
    ret <{=N} x i1> %flipped

# Logical right shift of every 128-bit lane by 64 bits, i.e. move the upper
# half of each lane into its lower half and zero the upper half.
@llvm
def shift_half(self: Vec[u128, N]) -> Vec[u128, N]:
    %lane0 = insertelement <{=N} x i128> undef, i128 64, i32 0
    %amount = shufflevector <{=N} x i128> %lane0, <{=N} x i128> undef, <{=N} x i32> zeroinitializer
    %upper = lshr <{=N} x i128> %self, %amount
    ret <{=N} x i128> %upper
# Comparisons
# Lane-wise `>=` producing a 1-bit-per-lane mask (Vec[u1, N]).
# u64 lanes use the unsigned comparison `icmp uge`; f64 lanes use the ordered
# comparison `fcmp oge`. Scalar right-hand sides are broadcast first.

@llvm
def __ge__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u1, N]:
    %cmp = icmp uge <{=N} x i64> %self, %other
    ret <{=N} x i1> %cmp

@llvm
def __ge__(self: Vec[u64, N], other: u64) -> Vec[u1, N]:
    %lane0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
    %splat = shufflevector <{=N} x i64> %lane0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
    %cmp = icmp uge <{=N} x i64> %self, %splat
    ret <{=N} x i1> %cmp

@llvm
def __ge__(self: Vec[f64, N], other: Vec[f64, N]) -> Vec[u1, N]:
    %cmp = fcmp oge <{=N} x double> %self, %other
    ret <{=N} x i1> %cmp

@llvm
def __ge__(self: Vec[f64, N], other: f64) -> Vec[u1, N]:
    %lane0 = insertelement <{=N} x double> undef, double %other, i32 0
    %splat = shufflevector <{=N} x double> %lane0, <{=N} x double> undef, <{=N} x i32> zeroinitializer
    %cmp = fcmp oge <{=N} x double> %self, %splat
    ret <{=N} x i1> %cmp
# Lane-wise unsigned `<=` on u64 lanes (`icmp ule`), producing a
# 1-bit-per-lane mask. The scalar form broadcasts `other` first.

@llvm
def __le__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u1, N]:
    %cmp = icmp ule <{=N} x i64> %self, %other
    ret <{=N} x i1> %cmp

@llvm
def __le__(self: Vec[u64, N], other: u64) -> Vec[u1, N]:
    %lane0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
    %splat = shufflevector <{=N} x i64> %lane0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
    %cmp = icmp ule <{=N} x i64> %self, %splat
    ret <{=N} x i1> %cmp
# Arithmetic intrinsics
# Lane-wise negation computed as (0 - self) with the integer `sub`
# instruction.
# NOTE(review): `sub` is an integer instruction, so this presumably only
# works for integer element types - confirm before using with f32/f64.
@llvm
def __neg__(self: Vec[T, N]) -> Vec[T, N]:
    %negated = sub <{=N} x {=T}> zeroinitializer, %self
    ret <{=N} x {=T}> %negated

# Lane-wise unsigned remainder (`urem`) of u64 lanes, vector divisor.
@llvm
def __mod__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u64, N]:
    %rem = urem <{=N} x i64> %self, %other
    ret <{=N} x i64> %rem

# Lane-wise unsigned remainder of u64 lanes by a scalar divisor, which is
# broadcast to all lanes first.
@llvm
def __mod__(self: Vec[u64, N], other: u64) -> Vec[u64, N]:
    %lane0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
    %splat = shufflevector <{=N} x i64> %lane0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
    %rem = urem <{=N} x i64> %self, %splat
    ret <{=N} x i64> %rem
# Addition. `add` emits the integer instruction and `fadd` the floating-point
# one; each has a vector-vector and a vector-scalar (broadcast) form.
# `__add__` picks between them at compile time from the element type T.

@llvm
def add(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
    %total = add <{=N} x {=T}> %self, %other
    ret <{=N} x {=T}> %total

@llvm
def add(self: Vec[T, N], other: T) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %splat = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %total = add <{=N} x {=T}> %self, %splat
    ret <{=N} x {=T}> %total

@llvm
def fadd(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
    %total = fadd <{=N} x {=T}> %self, %other
    ret <{=N} x {=T}> %total

@llvm
def fadd(self: Vec[T, N], other: T) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %splat = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %total = fadd <{=N} x {=T}> %self, %splat
    ret <{=N} x {=T}> %total

def __add__(self: Vec[T, N], other) -> Vec[T, N]:
    """
    `self + other`: floating-point add for f32/f64 lanes, integer add for
    u8/u64/u128 lanes; any other element type is a compile error.
    """
    if isinstance(T, f32) or isinstance(T, f64):
        return self.fadd(other)
    if isinstance(T, u8) or isinstance(T, u64) or isinstance(T, u128):
        return self.add(other)
    compile_error("invalid SIMD vector addition")
# Subtraction. `sub` emits the integer instruction and `fsub` the
# floating-point one; each has a vector-vector and a vector-scalar
# (broadcast) form. `__sub__` picks between them at compile time from T.

@llvm
def sub(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
    %diff = sub <{=N} x {=T}> %self, %other
    ret <{=N} x {=T}> %diff

@llvm
def sub(self: Vec[T, N], other: T) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %splat = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %diff = sub <{=N} x {=T}> %self, %splat
    ret <{=N} x {=T}> %diff

@llvm
def fsub(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
    %diff = fsub <{=N} x {=T}> %self, %other
    ret <{=N} x {=T}> %diff

@llvm
def fsub(self: Vec[T, N], other: T) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %splat = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %diff = fsub <{=N} x {=T}> %self, %splat
    ret <{=N} x {=T}> %diff

def __sub__(self: Vec[T, N], other) -> Vec[T, N]:
    """
    `self - other`: integer subtract for u8/u64/u128 lanes, floating-point
    subtract for f32/f64 lanes; any other element type is a compile error.

    u128 is accepted here so that the set of supported element types matches
    __add__ and __mul__.
    """
    if isinstance(T, u8) or isinstance(T, u64) or isinstance(T, u128):
        return self.sub(other)
    if isinstance(T, f32) or isinstance(T, f64):
        return self.fsub(other)
    compile_error("invalid SIMD vector subtraction")
# Multiplication. `mul` emits the integer instruction and `fmul` the
# floating-point one; each has a vector-vector and a vector-scalar
# (broadcast) form. `__mul__` picks between them at compile time from T.

@llvm
def mul(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
    %prod = mul <{=N} x {=T}> %self, %other
    ret <{=N} x {=T}> %prod

@llvm
def fmul(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
    %prod = fmul <{=N} x {=T}> %self, %other
    ret <{=N} x {=T}> %prod

@llvm
def mul(self: Vec[T, N], other: T) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %splat = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %prod = mul <{=N} x {=T}> %self, %splat
    ret <{=N} x {=T}> %prod

@llvm
def fmul(self: Vec[T, N], other: T) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %splat = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %prod = fmul <{=N} x {=T}> %self, %splat
    ret <{=N} x {=T}> %prod

def __mul__(self: Vec[T, N], other) -> Vec[T, N]:
    """
    `self * other`: floating-point multiply for f32/f64 lanes, integer
    multiply for u8/u64/u128 lanes; any other element type is a compile error.
    """
    if isinstance(T, f32) or isinstance(T, f64):
        return self.fmul(other)
    if isinstance(T, u8) or isinstance(T, u64) or isinstance(T, u128):
        return self.mul(other)
    compile_error("invalid SIMD vector multiplication")
# True division of u64 vectors. Both operands are converted to f64 with
# `uitofp` (unsigned integer -> floating point) and divided lane-wise, so the
# result is a Vec[f64, N].

# Generic-width vector / vector.
@llvm
def __truediv__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[f64, N]:
    %num = uitofp <{=N} x i64> %self to <{=N} x double>
    %den = uitofp <{=N} x i64> %other to <{=N} x double>
    %quot = fdiv <{=N} x double> %num, %den
    ret <{=N} x double> %quot

# 4-lane vector / vector.
@llvm
def __truediv__(self: Vec[u64, 4], other: Vec[u64, 4]) -> Vec[f64, 4]:
    %num = uitofp <4 x i64> %self to <4 x double>
    %den = uitofp <4 x i64> %other to <4 x double>
    %quot = fdiv <4 x double> %num, %den
    ret <4 x double> %quot

# 8-lane vector / vector.
@llvm
def __truediv__(self: Vec[u64, 8], other: Vec[u64, 8]) -> Vec[f64, 8]:
    %num = uitofp <8 x i64> %self to <8 x double>
    %den = uitofp <8 x i64> %other to <8 x double>
    %quot = fdiv <8 x double> %num, %den
    ret <8 x double> %quot

# 8-lane vector / scalar. The scalar is converted to double, inserted at
# lane 0 and broadcast; the broadcast must be taken from that inserted value
# (%lane0), not from the converted numerator, otherwise every lane would be
# divided by self[0] and `other` would be ignored.
@llvm
def __truediv__(self: Vec[u64, 8], other: u64) -> Vec[f64, 8]:
    %num = uitofp <8 x i64> %self to <8 x double>
    %den = uitofp i64 %other to double
    %lane0 = insertelement <8 x double> undef, double %den, i32 0
    %splat = shufflevector <8 x double> %lane0, <8 x double> undef, <8 x i32> zeroinitializer
    %quot = fdiv <8 x double> %num, %splat
    ret <8 x double> %quot
# Overflow-reporting unsigned add/subtract on u64 lanes. Each returns the
# pair produced by the llvm.uadd/usub.with.overflow intrinsic: the wrapped
# result vector and a 1-bit-per-lane overflow mask.

@llvm
def add_overflow(self: Vec[u64, N], other: Vec[u64, N]) -> Tuple[Vec[u64, N], Vec[u1, N]]:
    declare {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
    %pair = call {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %other)
    ret {<{=N} x i64>, <{=N} x i1>} %pair

# Scalar right-hand side: broadcast `other`, then add with overflow.
@llvm
def add_overflow(self: Vec[u64, N], other: u64) -> Tuple[Vec[u64, N], Vec[u1, N]]:
    declare {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
    %lane0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
    %splat = shufflevector <{=N} x i64> %lane0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
    %pair = call {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %splat)
    ret {<{=N} x i64>, <{=N} x i1>} %pair

@llvm
def sub_overflow(self: Vec[u64, N], other: Vec[u64, N]) -> Tuple[Vec[u64, N], Vec[u1, N]]:
    declare {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
    %pair = call {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %other)
    ret {<{=N} x i64>, <{=N} x i1>} %pair

def sub_overflow_commutative(self: Vec[u64, N], other: u64) -> Tuple[Vec[u64, N], Vec[u1, N]]:
    """
    Subtraction with the operands swapped: builds a Vec[u64, N] from the
    scalar `other` and subtracts `self` from it (other - self), returning the
    same (result, overflow mask) pair as sub_overflow.
    """
    minuend = Vec[u64, N](other)
    return minuend.sub_overflow(self)

# Scalar right-hand side: broadcast `other`, then subtract with overflow.
@llvm
def sub_overflow(self: Vec[u64, N], other: u64) -> Tuple[Vec[u64, N], Vec[u1, N]]:
    declare {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
    %lane0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
    %splat = shufflevector <{=N} x i64> %lane0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
    %pair = call {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %splat)
    ret {<{=N} x i64>, <{=N} x i1>} %pair
# Full-width multiply: both u64 operands are zero-extended to 128 bits and
# multiplied, giving the complete 128-bit product per lane.

@llvm
def zext_mul(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u128, N]:
    %lhs = zext <{=N} x i64> %self to <{=N} x i128>
    %rhs = zext <{=N} x i64> %other to <{=N} x i128>
    %prod = mul nuw <{=N} x i128> %lhs, %rhs
    ret <{=N} x i128> %prod

# Scalar right-hand side: broadcast `other` as u64, then widen and multiply.
@llvm
def zext_mul(self: Vec[u64, N], other: u64) -> Vec[u128, N]:
    %lhs = zext <{=N} x i64> %self to <{=N} x i128>
    %lane0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
    %splat = shufflevector <{=N} x i64> %lane0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
    %rhs = zext <{=N} x i64> %splat to <{=N} x i128>
    %prod = mul nuw <{=N} x i128> %lhs, %rhs
    ret <{=N} x i128> %prod
# 64x64 -> 128-bit multiply per lane. Returns the low 64 bits of each product
# and stores the high 64 bits through `hi`.
#
# The shift-by-64 vector is built as a broadcast sized by N (insert at lane 0
# + all-zero shuffle mask) rather than a literal 8-element constant, so the
# IR is well-typed for every lane count, not only N == 8.
@nocapture
@llvm
def mulx(self: Vec[u64, N], other: Vec[u64, N], hi: Ptr[Vec[u64, N]]) -> Vec[u64, N]:
    %lhs = zext <{=N} x i64> %self to <{=N} x i128>
    %rhs = zext <{=N} x i64> %other to <{=N} x i128>
    %prod = mul nuw <{=N} x i128> %lhs, %rhs
    %lane0 = insertelement <{=N} x i128> undef, i128 64, i32 0
    %amount = shufflevector <{=N} x i128> %lane0, <{=N} x i128> undef, <{=N} x i32> zeroinitializer
    %upper = lshr <{=N} x i128> %prod, %amount
    %hi_part = trunc <{=N} x i128> %upper to <{=N} x i64>
    store <{=N} x i64> %hi_part, <{=N} x i64>* %hi, align 8
    %lo_part = trunc <{=N} x i128> %prod to <{=N} x i64>
    ret <{=N} x i64> %lo_part
def mulhi(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u64, N]:
    """
    High 64 bits of the 128-bit product of each pair of u64 lanes.

    Computed entirely in registers as (zext(self) * zext(other)) >> 64,
    truncated back to 64 bits, by composing zext_mul and shift_trunc_half.
    This avoids allocating a scratch Vec through Ptr[...](1) on every call
    just to receive mulx's out-parameter.
    """
    return self.zext_mul(other).shift_trunc_half()
# Lane-wise math on f64 vectors; each forwards to the LLVM vector intrinsic
# of the same name, instantiated for <N x double>.

# llvm.sqrt on every lane.
@llvm
def sqrt(self: Vec[f64, N]) -> Vec[f64, N]:
    declare <{=N} x double> @llvm.sqrt.v{=N}f64(<{=N} x double>)
    %res = call <{=N} x double> @llvm.sqrt.v{=N}f64(<{=N} x double> %self)
    ret <{=N} x double> %res

# llvm.log on every lane.
@llvm
def log(self: Vec[f64, N]) -> Vec[f64, N]:
    declare <{=N} x double> @llvm.log.v{=N}f64(<{=N} x double>)
    %res = call <{=N} x double> @llvm.log.v{=N}f64(<{=N} x double> %self)
    ret <{=N} x double> %res

# llvm.cos on every lane.
@llvm
def cos(self: Vec[f64, N]) -> Vec[f64, N]:
    declare <{=N} x double> @llvm.cos.v{=N}f64(<{=N} x double>)
    %res = call <{=N} x double> @llvm.cos.v{=N}f64(<{=N} x double> %self)
    ret <{=N} x double> %res

# llvm.fabs on every lane.
@llvm
def fabs(self: Vec[f64, N]) -> Vec[f64, N]:
    declare <{=N} x double> @llvm.fabs.v{=N}f64(<{=N} x double>)
    %res = call <{=N} x double> @llvm.fabs.v{=N}f64(<{=N} x double> %self)
    ret <{=N} x double> %res
# Conversion intrinsics
# Zero-extend every u64 lane to 128 bits.
@llvm
def zext_double(self: Vec[u64, N]) -> Vec[u128, N]:
    %wide = zext <{=N} x i64> %self to <{=N} x i128>
    ret <{=N} x i128> %wide

# Keep the low 64 bits of every 128-bit lane.
@llvm
def trunc_half(self: Vec[u128, N]) -> Vec[u64, N]:
    %low = trunc <{=N} x i128> %self to <{=N} x i64>
    ret <{=N} x i64> %low

# Keep the high 64 bits of every 128-bit lane: logical shift right by 64,
# then truncate to 64 bits.
@llvm
def shift_trunc_half(self: Vec[u128, N]) -> Vec[u64, N]:
    %lane0 = insertelement <{=N} x i128> undef, i128 64, i32 0
    %amount = shufflevector <{=N} x i128> %lane0, <{=N} x i128> undef, <{=N} x i32> zeroinitializer
    %upper = lshr <{=N} x i128> %self, %amount
    %high = trunc <{=N} x i128> %upper to <{=N} x i64>
    ret <{=N} x i64> %high

# Convert every f64 lane to an unsigned 64-bit integer (`fptoui`).
@llvm
def to_u64(self: Vec[f64, N]) -> Vec[u64, N]:
    %ints = fptoui <{=N} x double> %self to <{=N} x i64>
    ret <{=N} x i64> %ints

# Zero-extend every 1-bit mask lane to a u64 (0 or 1).
@llvm
def to_u64(self: Vec[u1, N]) -> Vec[u64, N]:
    %ints = zext <{=N} x i1> %self to <{=N} x i64>
    ret <{=N} x i64> %ints

# Convert every lane to f64 with `uitofp`, i.e. the lanes are interpreted as
# unsigned integers.
@llvm
def to_float(self: Vec[T, N]) -> Vec[f64, N]:
    %floats = uitofp <{=N} x {=T}> %self to <{=N} x double>
    ret <{=N} x double> %floats
# Predication intrinsics
# Predicated subtraction: lanes whose `mask` bit is set receive
# (self - other); all other lanes keep the value from `self`. The difference
# is computed with the integer `sub` instruction for every lane and the
# `select` then picks per lane.

@llvm
def sub_if(self: Vec[T, N], other: Vec[T, N], mask: Vec[u1, N]) -> Vec[T, N]:
    %diff = sub <{=N} x {=T}> %self, %other
    %picked = select <{=N} x i1> %mask, <{=N} x {=T}> %diff, <{=N} x {=T}> %self
    ret <{=N} x {=T}> %picked

# Scalar subtrahend: `other` is broadcast to all lanes first.
@llvm
def sub_if(self: Vec[T, N], other: T, mask: Vec[u1, N]) -> Vec[T, N]:
    %lane0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
    %splat = shufflevector <{=N} x {=T}> %lane0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
    %diff = sub <{=N} x {=T}> %self, %splat
    %picked = select <{=N} x i1> %mask, <{=N} x {=T}> %diff, <{=N} x {=T}> %self
    ret <{=N} x {=T}> %picked
# Gather-scatter
# `self[idx]`: extract a single lane. `idx` is passed to `extractelement` as
# a 64-bit index; no bounds check is performed here.
@llvm
def __getitem__(self: Vec[T, N], idx) -> T:
    %lane = extractelement <{=N} x {=T}> %self, i64 %idx
    ret {=T} %lane
# Misc
def copy(self: Vec[T, N]) -> Vec[T, N]:
    """Return this vector itself (the method simply returns `self`)."""
    return self

# Per-lane blend: where a `mask` bit is set the lane comes from `self`,
# otherwise from `other`.
@llvm
def mask(self: Vec[T, N], mask: Vec[u1, N], other: Vec[T, N]) -> Vec[T, N]:
    %blend = select <{=N} x i1> %mask, <{=N} x {=T}> %self, <{=N} x {=T}> %other
    ret <{=N} x {=T}> %blend

def scatter(self: Vec[T, N]) -> List[T]:
    """Return the N lanes as a list, in lane order."""
    lanes = List[T]()
    for lane in staticrange(N):
        lanes.append(self[lane])
    return lanes
# Short aliases for commonly used vector shapes:
# 16 x u8, 32 x u8 and 8 x f32.
u8x16 = Vec[u8, 16]
u8x32 = Vec[u8, 32]
f32x8 = Vec[f32, 8]
# Pointer reinterpretation helpers between a pointer to u64 vectors and a
# pointer to plain u64 elements. Only the pointer type changes; no data is
# read or copied.

# Ptr[Vec[u64, N]] -> Ptr[u64]
@llvm
def bitcast_scatter[N: Static[int]](ptr_in: Ptr[Vec[u64, N]]) -> Ptr[u64]:
    %elems = bitcast <{=N} x i64>* %ptr_in to i64*
    ret i64* %elems

# Ptr[u64] -> Ptr[Vec[u64, N]]
@llvm
def bitcast_vectorize[N: Static[int]](ptr_in: Ptr[u64]) -> Ptr[Vec[u64, N]]:
    %vecs = bitcast i64* %ptr_in to <{=N} x i64>*
    ret <{=N} x i64>* %vecs