# Copyright (C) 2022-2025 Exaloop Inc.

# Implementation of vectorized Rabin-Karp string search.
# See http://0x80.pl/articles/simd-strfind.html for
# details. These implementations are modified to not
# perform any out-of-bounds memory accesses.


# Count the trailing zero bits of `n` via the `llvm.cttz` intrinsic.
# The intrinsic is called with its second argument set to `true`, which makes
# the result poison for a zero input: callers must pass a nonzero value.
@pure
@llvm
def cttz(n: UInt[N], N: Static[int]) -> UInt[N]:
    declare i{=N} @llvm.cttz.i{=N}(i{=N}, i1)
    %0 = call i{=N} @llvm.cttz.i{=N}(i{=N} %n, i1 true)
    ret i{=N} %0


# Count the leading zero bits of `n` via the `llvm.ctlz` intrinsic.
# The intrinsic is called with its second argument set to `true`, which makes
# the result poison for a zero input: callers must pass a nonzero value.
@pure
@llvm
def ctlz(n: UInt[N], N: Static[int]) -> UInt[N]:
    declare i{=N} @llvm.ctlz.i{=N}(i{=N}, i1)
    %0 = call i{=N} @llvm.ctlz.i{=N}(i{=N} %n, i1 true)
    ret i{=N} %0


# Build a 16-bit candidate mask for the 16 needle start positions i .. i + 15.
# Two unaligned 16-byte blocks are loaded: s[i .. i + 15] (candidate first
# bytes) and s[i + k - 1 .. i + k + 14] (candidate last bytes). Lane b of the
# result is set iff s[i + b] == firstb and s[i + b + k - 1] == lastb.
# The caller must guarantee that i + k + 15 <= n so both loads stay in bounds.
# `n` and `needle` are not used by the body.
# NOTE(review): the callers' use of cttz assumes lane 0 lands in the least
# significant bit of the i16 (little-endian lane order) - confirm for
# big-endian targets.
@pure
@llvm
def forward_mask(s: Ptr[byte], n: int, needle: Ptr[byte], k: int, i: int, firstb: byte, lastb: byte) -> u16:
    %first0 = insertelement <16 x i8> undef, i8 %firstb, i64 0
    %first = shufflevector <16 x i8> %first0, <16 x i8> poison, <16 x i32> zeroinitializer
    %last0 = insertelement <16 x i8> undef, i8 %lastb, i64 0
    %last = shufflevector <16 x i8> %last0, <16 x i8> poison, <16 x i32> zeroinitializer
    %offset0 = add i64 %i, %k
    %offset = sub i64 %offset0, 1
    %ptr_first = getelementptr inbounds i8, ptr %s, i64 %i
    %ptr_last = getelementptr inbounds i8, ptr %s, i64 %offset
    %block_first = load <16 x i8>, ptr %ptr_first, align 1
    %block_last = load <16 x i8>, ptr %ptr_last, align 1
    %eq_first = icmp eq <16 x i8> %first, %block_first
    %eq_last = icmp eq <16 x i8> %last, %block_last
    %mask0 = and <16 x i1> %eq_first, %eq_last
    %mask = bitcast <16 x i1> %mask0 to i16
    ret i16 %mask


# Build a 16-bit candidate mask for the 16 needle end positions i - 15 .. i.
# With j = i - 15, two unaligned 16-byte blocks are loaded: s[j .. i]
# (candidate last bytes) and s[j - k + 1 .. i - k + 1] (candidate first
# bytes). Lane b of the result is set iff s[j + b] == lastb and
# s[j + b - k + 1] == firstb.
# The caller must guarantee that i - (k - 1) - 15 >= 0 and i < n so both loads
# stay in bounds. `n` and `needle` are not used by the body.
# NOTE(review): the callers' use of ctlz assumes lane 15 lands in the most
# significant bit of the i16 (little-endian lane order) - confirm for
# big-endian targets.
@pure
@llvm
def backward_mask(s: Ptr[byte], n: int, needle: Ptr[byte], k: int, i: int, firstb: byte, lastb: byte) -> u16:
    %j0 = sub i64 %i, 16
    %j = add i64 %j0, 1
    %first0 = insertelement <16 x i8> undef, i8 %firstb, i64 0
    %first = shufflevector <16 x i8> %first0, <16 x i8> poison, <16 x i32> zeroinitializer
    %last0 = insertelement <16 x i8> undef, i8 %lastb, i64 0
    %last = shufflevector <16 x i8> %last0, <16 x i8> poison, <16 x i32> zeroinitializer
    %offset0 = sub i64 %j, %k
    %offset = add i64 %offset0, 1
    %ptr_first = getelementptr inbounds i8, ptr %s, i64 %offset
    %ptr_last = getelementptr inbounds i8, ptr %s, i64 %j
    %block_first = load <16 x i8>, ptr %ptr_first, align 1
    %block_last = load <16 x i8>, ptr %ptr_last, align 1
    %eq_first = icmp eq <16 x i8> %first, %block_first
    %eq_last = icmp eq <16 x i8> %last, %block_last
    %mask0 = and <16 x i1> %eq_first, %eq_last
    %mask = bitcast <16 x i1> %mask0 to i16
    ret i16 %mask


def forward_find(s: Ptr[byte], n: int, needle: Ptr[byte], k: int):
    """
    Return the lowest index at which the `k` bytes at `needle` occur within
    the `n` bytes at `s`, or -1 if they do not occur.

    An empty needle (`k == 0`) matches at index 0. A one-byte needle is
    delegated to `memchr`.
    """
    if k == 0:
        return 0

    if n < k:
        return -1

    if k == 1:
        p = _C.memchr(s, i32(int(needle[0])), n)
        return p - s if p else -1

    firstb = needle[0]
    lastb = needle[k - 1]
    i = 0

    # Vector phase: test 16 start positions per iteration. The loop bound
    # keeps the last-byte block (which ends at index i + k + 14) inside s.
    while i + k + 16 - 1 <= n:
        mask = forward_mask(s, n, needle, k, i, firstb, lastb)

        while mask:
            bitpos = int(cttz(mask))
            # First and last bytes already match; compare the k - 2 interior
            # bytes (a zero-length compare when k == 2).
            if _C.memcmp(s + i + bitpos + 1, needle + 1, k - 2) == i32(0):
                return i + bitpos
            # Clear the lowest set bit and move on to the next candidate.
            mask = mask & (mask - u16(1))

        i += 16

    # Scalar tail. The vector loop exits once i + k + 15 > n, so at most 15
    # start positions (i .. n - k) remain; a plain loop covers them and
    # replaces the former 16-way hand-unrolled chain.
    j = i
    while j + k <= n:
        if firstb == s[j] and lastb == s[j + k - 1] and _C.memcmp(s + j + 1, needle + 1, k - 2) == i32(0):
            return j
        j += 1

    return -1


def backward_find(s: Ptr[byte], n: int, needle: Ptr[byte], k: int):
    """
    Return the highest index at which the `k` bytes at `needle` occur within
    the `n` bytes at `s`, or -1 if they do not occur.

    An empty needle (`k == 0`) matches at index `n`. A one-byte needle is
    handled by a simple backward scan.
    """
    if k == 0:
        return n

    if n < k:
        return -1

    if k == 1:
        i = n - 1
        while i >= 0:
            if s[i] == needle[0]:
                return i
            i -= 1
        return -1

    firstb = needle[0]
    lastb = needle[k - 1]
    # `i` is the index of the candidate match's LAST byte.
    i = n - 1

    # Vector phase: test 16 end positions per iteration. The loop bound keeps
    # the first-byte block (which starts at index i - 15 - (k - 1)) inside s.
    while i - (k - 1) - (16 - 1) >= 0:
        mask = backward_mask(s, n, needle, k, i, firstb, lastb)

        while mask:
            # The highest set bit is the rightmost candidate, so bitpos is the
            # distance of that candidate's last byte below i.
            bitpos = int(ctlz(mask))
            # First and last bytes already match; compare the k - 2 interior
            # bytes (a zero-length compare when k == 2).
            if _C.memcmp(s + i - (k - 1) - bitpos + 1, needle + 1, k - 2) == i32(0):
                return i - (k - 1) - bitpos
            # Clear the bit just examined.
            mask &= ~(u16(1) << u16(16 - 1 - bitpos))

        i -= 16

    # Scalar tail. The vector loop exits once i - (k - 1) - 15 < 0, so at most
    # 15 end positions (i down to k - 1) remain. A plain loop replaces the
    # former 16-way hand-unrolled chain, one step of which used `i + 10`
    # instead of `i - 10`: that read past the candidate window and skipped
    # the match ending at i - 10.
    j = i
    while j - k + 1 >= 0:
        if lastb == s[j] and firstb == s[j - k + 1] and _C.memcmp(s + j - k + 2, needle + 1, k - 2) == i32(0):
            return j - k + 1
        j -= 1

    return -1


def find(haystack: str, needle: str):
    """
    Return the lowest index of `needle` in `haystack`, or -1 if absent.
    """
    return forward_find(haystack.ptr, haystack.len, needle.ptr, needle.len)


def rfind(haystack: str, needle: str):
    """
    Return the highest index of `needle` in `haystack`, or -1 if absent.
    """
    return backward_find(haystack.ptr, haystack.len, needle.ptr, needle.len)


def count(haystack: str, needle: str):
    """
    Return the number of non-overlapping occurrences of `needle` in
    `haystack`, scanning left to right.

    An empty needle yields `len(haystack) + 1`.
    """
    occ = 0
    tmp = haystack.ptr
    n = haystack.len
    k = needle.len

    if k == 0:
        return n + 1

    while True:
        # Search only the part of the haystack that has not been consumed yet.
        pos = forward_find(tmp, n - (tmp - haystack.ptr), needle.ptr, k)
        if pos == -1:
            break
        # Skip past the whole match so occurrences never overlap.
        tmp += pos + k
        occ += 1

    return occ


def count_with_max(haystack: str, needle: str, maxcount: int):
    """
    Like `count`, but stop as soon as `maxcount` occurrences have been found.

    Returns 0 when `maxcount == 0`. For an empty needle the result is
    `len(haystack) + 1` if that is below `maxcount`, otherwise `maxcount`.
    """
    occ = 0
    tmp = haystack.ptr
    n = haystack.len
    k = needle.len

    if maxcount == 0:
        return 0

    if k == 0:
        return n + 1 if n + 1 < maxcount else maxcount

    while True:
        # Search only the part of the haystack that has not been consumed yet.
        pos = forward_find(tmp, n - (tmp - haystack.ptr), needle.ptr, k)
        if pos == -1:
            break
        # Skip past the whole match so occurrences never overlap.
        tmp += pos + k
        occ += 1
        if occ == maxcount:
            return occ

    return occ