# Copyright (C) 2022-2023 Exaloop Inc.

# Fixed-width SIMD vector with N lanes of element type T.
#
# Most operations are thin wrappers around hand-written LLVM IR (the `@llvm`
# functions below), named after the x86 intrinsic they mirror. Only the
# (T, N) combinations that have an explicit overload are supported:
# u8 x 16, u8 x 32, f32 x 8 and f32 x 16. Anything else is rejected at compile
# time (either by `compile_error` or by overload resolution).
#
# NOTE: `@llvm` bodies are raw LLVM IR, so they are documented with `#`
# comments placed above the decorator instead of docstrings.
@tuple(container=False)  # disallow default __getitem__
class Vec[T, N: Static[int]]:
    # Broadcast constants used by the mask helpers (__ne__, __eq__ with bool).
    ZERO_16x8i = Vec[u8,16](u8(0))
    FF_16x8i = Vec[u8,16](u8(0xff))
    ZERO_32x8i = Vec[u8,32](u8(0))
    FF_32x8i = Vec[u8,32](u8(0xff))

    # ------------------------------------------------------------------ #
    # Broadcast / load primitives
    # ------------------------------------------------------------------ #

    # Broadcast one byte to all 16 lanes.
    @llvm
    def _mm_set1_epi8(val: u8) -> Vec[u8, 16]:
        %0 = insertelement <16 x i8> undef, i8 %val, i32 0
        %1 = shufflevector <16 x i8> %0, <16 x i8> undef, <16 x i32> zeroinitializer
        ret <16 x i8> %1

    # Broadcast one byte to all 32 lanes.
    @llvm
    def _mm256_set1_epi8(val: u8) -> Vec[u8, 32]:
        %0 = insertelement <32 x i8> undef, i8 %val, i32 0
        %1 = shufflevector <32 x i8> %0, <32 x i8> undef, <32 x i32> zeroinitializer
        ret <32 x i8> %1

    # Load 16 bytes from `data`; `align 1` makes this an unaligned load.
    @llvm
    def _mm_loadu_si128(data) -> Vec[u8, 16]:
        %0 = bitcast i8* %data to <16 x i8>*
        %1 = load <16 x i8>, <16 x i8>* %0, align 1
        ret <16 x i8> %1

    # Load 32 bytes from `data`; `align 1` makes this an unaligned load.
    @llvm
    def _mm256_loadu_si256(data) -> Vec[u8, 32]:
        %0 = bitcast i8* %data to <32 x i8>*
        %1 = load <32 x i8>, <32 x i8>* %0, align 1
        ret <32 x i8> %1

    # Broadcast one float to all 8 lanes.
    @llvm
    def _mm256_set1_ps(val: f32) -> Vec[f32, 8]:
        %0 = insertelement <8 x float> undef, float %val, i32 0
        %1 = shufflevector <8 x float> %0, <8 x float> undef, <8 x i32> zeroinitializer
        ret <8 x float> %1

    # Broadcast one float to all 16 lanes.
    @llvm
    def _mm512_set1_ps(val: f32) -> Vec[f32, 16]:
        %0 = insertelement <16 x float> undef, float %val, i32 0
        %1 = shufflevector <16 x float> %0, <16 x float> undef, <16 x i32> zeroinitializer
        ret <16 x float> %1

    # Load 8 floats from `data`.
    # The explicit `align 1` is required: an LLVM `load` without `align`
    # assumes the ABI alignment of <8 x float>, which the pointers handed in by
    # the constructors (e.g. `x.arr.ptr + offset`) do not guarantee.
    @llvm
    def _mm256_loadu_ps(data: Ptr[f32]) -> Vec[f32, 8]:
        %0 = bitcast float* %data to <8 x float>*
        %1 = load <8 x float>, <8 x float>* %0, align 1
        ret <8 x float> %1

    # Load 16 floats from `data`; `align 1` for the same reason as
    # `_mm256_loadu_ps`.
    @llvm
    def _mm512_loadu_ps(data: Ptr[f32]) -> Vec[f32, 16]:
        %0 = bitcast float* %data to <16 x float>*
        %1 = load <16 x float>, <16 x float>* %0, align 1
        ret <16 x float> %1

    # ------------------------------------------------------------------ #
    # Widening / reinterpreting conversions
    # ------------------------------------------------------------------ #

    # Take the low 8 bytes of `vec` and sign-extend each to a 32-bit lane.
    @llvm
    def _mm256_cvtepi8_epi32(vec: Vec[u8, 16]) -> Vec[u32, 8]:
        %0 = shufflevector <16 x i8> %vec, <16 x i8> undef, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
        %1 = sext <8 x i8> %0 to <8 x i32>
        ret <8 x i32> %1

    # Take the low 16 bytes of `vec` and sign-extend each to a 32-bit lane.
    # NOTE(review): despite the `epi64` suffix this widens to 32-bit lanes.
    @llvm
    def _mm512_cvtepi8_epi64(vec: Vec[u8, 32]) -> Vec[u32, 16]:
        %0 = shufflevector <32 x i8> %vec, <32 x i8> undef, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
        %1 = sext <16 x i8> %0 to <16 x i32>
        ret <16 x i32> %1

    # Reinterpret the bits of 8 x 32-bit integers as 8 floats (no numeric
    # conversion).
    @llvm
    def _mm256_castsi256_ps(vec: Vec[u32, 8]) -> Vec[f32, 8]:
        %0 = bitcast <8 x i32> %vec to <8 x float>
        ret <8 x float> %0

    # Reinterpret the bits of 16 x 32-bit integers as 16 floats (no numeric
    # conversion).
    @llvm
    def _mm512_castsi512_ps(vec: Vec[u32, 16]) -> Vec[f32, 16]:
        %0 = bitcast <16 x i32> %vec to <16 x float>
        ret <16 x float> %0

    # ------------------------------------------------------------------ #
    # Constructors
    # ------------------------------------------------------------------ #

    def __new__(x, T: type, N: Static[int]) -> Vec[T, N]:
        """
        Build a `Vec[T, N]` from `x`, dispatching statically on `T`, `N` and
        the type of `x`.

        Supported sources:
          - u8 x 16 / u8 x 32: a `u8`/`byte` (broadcast), a `Ptr[u8]`/`Ptr[byte]`
            (load), or a `str` (load from its buffer).
          - f32 x 8 / f32 x 16: an `f32` (broadcast), a `Ptr[f32]` (load), a
            `List[f32]` (load from its buffer), or a `Vec[u8, 16]`/`Vec[u8, 32]`
            (low bytes widened to 32 bits, then bit-reinterpreted as floats).

        Loads read N elements starting at the given address; no bounds check
        is performed here. Any other combination is a compile-time error.
        """
        if isinstance(T, u8) and N == 16:
            if isinstance(x, u8) or isinstance(x, byte):  # TODO: u8<->byte
                return Vec._mm_set1_epi8(x)
            if isinstance(x, Ptr[u8]) or isinstance(x, Ptr[byte]):
                return Vec._mm_loadu_si128(x)
            if isinstance(x, str):
                return Vec._mm_loadu_si128(x.ptr)
        if isinstance(T, u8) and N == 32:
            if isinstance(x, u8) or isinstance(x, byte):  # TODO: u8<->byte
                return Vec._mm256_set1_epi8(x)
            if isinstance(x, Ptr[u8]) or isinstance(x, Ptr[byte]):
                return Vec._mm256_loadu_si256(x)
            if isinstance(x, str):
                return Vec._mm256_loadu_si256(x.ptr)
        if isinstance(T, f32) and N == 8:
            if isinstance(x, f32):
                return Vec._mm256_set1_ps(x)
            if isinstance(x, Ptr[f32]):  # TODO: multi-elif does NOT work with statics [why?!]
                return Vec._mm256_loadu_ps(x)
            if isinstance(x, List[f32]):
                return Vec._mm256_loadu_ps(x.arr.ptr)
            if isinstance(x, Vec[u8, 16]):
                return Vec._mm256_castsi256_ps(Vec._mm256_cvtepi8_epi32(x))
        if isinstance(T, f32) and N == 16:
            if isinstance(x, f32):
                return Vec._mm512_set1_ps(x)
            if isinstance(x, Ptr[f32]):  # TODO: multi-elif does NOT work with statics [why?!]
                return Vec._mm512_loadu_ps(x)
            if isinstance(x, List[f32]):
                return Vec._mm512_loadu_ps(x.arr.ptr)
            if isinstance(x, Vec[u8, 32]):
                return Vec._mm512_castsi512_ps(Vec._mm512_cvtepi8_epi64(x))
        compile_error("invalid SIMD vector constructor")

    def __new__(x: str, offset: int = 0) -> Vec[u8, N]:
        """Load N bytes from the buffer of `x`, starting `offset` bytes in."""
        return Vec(x.ptr + offset, u8, N)

    def __new__(x: List[T], offset: int = 0) -> Vec[T, N]:
        """Load N elements from the buffer of `x`, starting at index `offset`."""
        return Vec(x.arr.ptr + offset, T, N)

    def __new__(x) -> Vec[T, N]:
        """Build a `Vec[T, N]` from `x` using the class's own `T` and `N`."""
        return Vec(x, T, N)

    # ------------------------------------------------------------------ #
    # Lane-wise comparison (results are byte masks: 0xff = true, 0x00 = false)
    # ------------------------------------------------------------------ #

    # Lane-wise equality; each result byte is all-ones where equal, else zero.
    @llvm
    def _mm_cmpeq_epi8(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]:
        %0 = icmp eq <16 x i8> %x, %y
        %1 = sext <16 x i1> %0 to <16 x i8>
        ret <16 x i8> %1

    def __eq__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]:
        """Lane-wise `==`, returning a byte mask rather than a `bool`."""
        return Vec._mm_cmpeq_epi8(self, other)

    # Lane-wise equality; each result byte is all-ones where equal, else zero.
    @llvm
    def _mm256_cmpeq_epi8(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]:
        %0 = icmp eq <32 x i8> %x, %y
        %1 = sext <32 x i1> %0 to <32 x i8>
        ret <32 x i8> %1

    def __eq__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]:
        """Lane-wise `==`, returning a byte mask rather than a `bool`."""
        return Vec._mm256_cmpeq_epi8(self, other)

    # Computes (~x) & y.
    @llvm
    def _mm_andnot_si128(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]:
        %0 = xor <16 x i8> %x, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
        %1 = and <16 x i8> %y, %0
        ret <16 x i8> %1

    def __ne__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]:
        """Lane-wise `!=`: the bitwise complement of the `==` mask."""
        return Vec._mm_andnot_si128((self == other), Vec.FF_16x8i)

    # Computes (~x) & y.
    @llvm
    def _mm256_andnot_si256(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]:
        %0 = xor <32 x i8> %x, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
        %1 = and <32 x i8> %y, %0
        ret <32 x i8> %1

    def __ne__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]:
        """Lane-wise `!=`: the bitwise complement of the `==` mask."""
        return Vec._mm256_andnot_si256((self == other), Vec.FF_32x8i)

    def __eq__(self: Vec[u8, 16], other: bool) -> Vec[u8, 16]:
        """
        Compare a mask vector against a boolean.

        `self == False` yields the bitwise complement of `self`.
        `self == True` yields `(~self) & 0`, i.e. the all-zero vector.
        NOTE(review): the `True` case discards `self` entirely; returning
        `self` unchanged looks like the intent - confirm against callers
        before changing.
        """
        if not other:
            return Vec._mm_andnot_si128(self, Vec.FF_16x8i)
        else:
            return Vec._mm_andnot_si128(self, Vec.ZERO_16x8i)

    def __eq__(self: Vec[u8, 32], other: bool) -> Vec[u8, 32]:
        """
        Compare a mask vector against a boolean.

        `self == False` yields the bitwise complement of `self`.
        `self == True` yields `(~self) & 0`, i.e. the all-zero vector.
        NOTE(review): the `True` case discards `self` entirely; returning
        `self` unchanged looks like the intent - confirm against callers
        before changing.
        """
        if not other:
            return Vec._mm256_andnot_si256(self, Vec.FF_32x8i)
        else:
            return Vec._mm256_andnot_si256(self, Vec.ZERO_32x8i)

    # ------------------------------------------------------------------ #
    # Bitwise AND
    # ------------------------------------------------------------------ #

    # Bitwise AND of two 16-byte vectors.
    @llvm
    def _mm_and_si128(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]:
        %0 = and <16 x i8> %x, %y
        ret <16 x i8> %0

    def __and__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]:
        """Bitwise `&` of two byte vectors."""
        return Vec._mm_and_si128(self, other)

    # Bitwise AND of two 32-byte vectors.
    @llvm
    def _mm_and_si256(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]:
        %0 = and <32 x i8> %x, %y
        ret <32 x i8> %0

    def __and__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]:
        """Bitwise `&` of two byte vectors."""
        return Vec._mm_and_si256(self, other)

    # Bitwise AND of the raw bit patterns of two 8-float vectors.
    @llvm
    def _mm256_and_ps(x: Vec[f32, 8], y: Vec[f32, 8]) -> Vec[f32, 8]:
        %0 = bitcast <8 x float> %x to <8 x i32>
        %1 = bitcast <8 x float> %y to <8 x i32>
        %2 = and <8 x i32> %0, %1
        %3 = bitcast <8 x i32> %2 to <8 x float>
        ret <8 x float> %3

    def __and__(self: Vec[f32, 8], other: Vec[f32, 8]) -> Vec[f32, 8]:
        """Bitwise `&` on the underlying bit patterns of the float lanes."""
        return Vec._mm256_and_ps(self, other)

    # Bitwise AND of the raw bit patterns of two 16-float vectors.
    @llvm
    def _mm512_and_ps(x: Vec[f32, 16], y: Vec[f32, 16]) -> Vec[f32, 16]:
        %0 = bitcast <16 x float> %x to <16 x i32>
        %1 = bitcast <16 x float> %y to <16 x i32>
        %2 = and <16 x i32> %0, %1
        %3 = bitcast <16 x i32> %2 to <16 x float>
        ret <16 x float> %3

    def __and__(self: Vec[f32, 16], other: Vec[f32, 16]) -> Vec[f32, 16]:
        """Bitwise `&` on the underlying bit patterns of the float lanes."""
        return Vec._mm512_and_ps(self, other)

    # ------------------------------------------------------------------ #
    # Bitwise OR
    # ------------------------------------------------------------------ #

    # Bitwise OR of two 16-byte vectors.
    @llvm
    def _mm_or_si128(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]:
        %0 = or <16 x i8> %x, %y
        ret <16 x i8> %0

    def __or__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]:
        """Bitwise `|` of two byte vectors."""
        return Vec._mm_or_si128(self, other)

    # Bitwise OR of two 32-byte vectors.
    @llvm
    def _mm_or_si256(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]:
        %0 = or <32 x i8> %x, %y
        ret <32 x i8> %0

    def __or__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]:
        """Bitwise `|` of two byte vectors."""
        return Vec._mm_or_si256(self, other)

    # Bitwise OR of the raw bit patterns of two 8-float vectors.
    @llvm
    def _mm256_or_ps(x: Vec[f32, 8], y: Vec[f32, 8]) -> Vec[f32, 8]:
        %0 = bitcast <8 x float> %x to <8 x i32>
        %1 = bitcast <8 x float> %y to <8 x i32>
        %2 = or <8 x i32> %0, %1
        %3 = bitcast <8 x i32> %2 to <8 x float>
        ret <8 x float> %3

    def __or__(self: Vec[f32, 8], other: Vec[f32, 8]) -> Vec[f32, 8]:
        """Bitwise `|` on the underlying bit patterns of the float lanes."""
        return Vec._mm256_or_ps(self, other)

    # Bitwise OR of the raw bit patterns of two 16-float vectors.
    @llvm
    def _mm512_or_ps(x: Vec[f32, 16], y: Vec[f32, 16]) -> Vec[f32, 16]:
        %0 = bitcast <16 x float> %x to <16 x i32>
        %1 = bitcast <16 x float> %y to <16 x i32>
        %2 = or <16 x i32> %0, %1
        %3 = bitcast <16 x i32> %2 to <16 x float>
        ret <16 x float> %3

    def __or__(self: Vec[f32, 16], other: Vec[f32, 16]) -> Vec[f32, 16]:
        """Bitwise `|` on the underlying bit patterns of the float lanes."""
        return Vec._mm512_or_ps(self, other)

    # ------------------------------------------------------------------ #
    # Byte shifts and arithmetic
    # ------------------------------------------------------------------ #

    # Shift the whole 16-byte vector right by 8 bytes, filling with zeros:
    # result lanes 0..7 are input lanes 8..15, result lanes 8..15 come from the
    # zero vector (shuffle indices >= 16 select from the second operand).
    @llvm
    def _mm_bsrli_si128_8(vec: Vec[u8, 16]) -> Vec[u8, 16]:
        %0 = shufflevector <16 x i8> %vec, <16 x i8> zeroinitializer, <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23>
        ret <16 x i8> %0

    # Lane-wise float addition.
    @llvm
    def _mm256_add_ps(x: Vec[f32, 8], y: Vec[f32, 8]) -> Vec[f32, 8]:
        %0 = fadd <8 x float> %x, %y
        ret <8 x float> %0

    def __add__(self: Vec[f32, 8], other: Vec[f32, 8]) -> Vec[f32, 8]:
        """Lane-wise `+` of two 8-float vectors."""
        return Vec._mm256_add_ps(self, other)

    def __rshift__(self: Vec[u8, 16], shift: Static[int]) -> Vec[u8, 16]:
        """
        Byte-wise right shift of the whole vector by a compile-time amount.

        Only `shift == 0` (identity) and `shift == 8` are implemented; any
        other amount is a compile-time error.
        """
        if shift == 0:
            return self
        elif shift == 8:
            return Vec._mm_bsrli_si128_8(self)
        else:
            compile_error("invalid bitshift")

    # Shift the whole 32-byte vector right by 16 bytes, filling with zeros:
    # result lanes 0..15 are input lanes 16..31, result lanes 16..31 come from
    # the zero vector (shuffle indices >= 32 select from the second operand).
    @llvm
    def _mm_bsrli_256(vec: Vec[u8, 32]) -> Vec[u8, 32]:
        %0 = shufflevector <32 x i8> %vec, <32 x i8> zeroinitializer, <32 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31, i32 32, i32 33, i32 34, i32 35, i32 36, i32 37, i32 38, i32 39, i32 40, i32 41, i32 42, i32 43, i32 44, i32 45, i32 46, i32 47>
        ret <32 x i8> %0

    def __rshift__(self: Vec[u8, 32], shift: Static[int]) -> Vec[u8, 32]:
        """
        Byte-wise right shift of the whole vector by a compile-time amount.

        Only `shift == 0` (identity) and `shift == 16` are implemented; any
        other amount is a compile-time error.
        """
        if shift == 0:
            return self
        elif shift == 16:
            return Vec._mm_bsrli_256(self)
        else:
            compile_error("invalid bitshift")

    # ------------------------------------------------------------------ #
    # Reductions, element access, formatting
    # ------------------------------------------------------------------ #

    # Disabled shuffle-based horizontal sum, kept for reference (the constant
    # shuffle masks are elided here as `<...>`):
    # https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction
    # @llvm
    # def sum(self: Vec[f32, 8]) -> f32:
    #     %0 = shufflevector <8 x float> %self, <8 x float> undef, <4 x i32> <...>
    #     %1 = shufflevector <8 x float> %self, <8 x float> poison, <4 x i32> <...>
    #     %2 = fadd <4 x float> %0, %1
    #     %3 = shufflevector <4 x float> %2, <4 x float> undef, <4 x i32> <...>
    #     %4 = fadd <4 x float> %2, %3
    #     %5 = shufflevector <4 x float> %4, <4 x float> poison, <4 x i32> <...>
    #     %6 = fadd <4 x float> %4, %5
    #     %7 = extractelement <4 x float> %6, i32 0
    #     ret float %7

    def sum(self: Vec[f32, 8], x: f32 = f32(0.0)) -> f32:
        """
        Return `x` plus the sum of all 8 lanes.

        Lanes are added strictly left to right (lane 0 first), starting from
        `x`, so the floating-point rounding order is fixed.
        """
        return x + self[0] + self[1] + self[2] + self[3] + self[4] + self[5] + self[6] + self[7]

    # Extract lane `n` (a compile-time constant) as a scalar of type T.
    @llvm
    def __getitem__(self, n: Static[int]) -> T:
        %0 = extractelement <{=N} x {=T}> %self, i32 {=n}
        ret {=T} %0

    def __repr__(self):
        """
        Format the vector as `<lane0, lane1, ...>`.

        Every lane is formatted for any N (the loop is unrolled at compile time
        because `__getitem__` needs a static index).
        """
        lanes = List[str]()
        for i in staticrange(N):
            lanes.append(f"{self[i]}")
        return "<" + ", ".join(lanes) + ">"

    def scatter(self: Vec[T, N]) -> List[T]:
        """Return the N lanes as a list, in lane order."""
        return [self[i] for i in staticrange(N)]


# Convenience aliases for the supported vector shapes.
u8x16 = Vec[u8, 16]
u8x32 = Vec[u8, 32]
f32x8 = Vec[f32, 8]