mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
* Use Static[] for static inheritance * Support .seq extension * Fix #36 * Polymorphic typechecking; vtables [wip] * v-table dispatch [wip] * vtable routing [wip; bug] * vtable routing [MVP] * Fix texts * Add union type support * Update FAQs * Clarify * Add BSL license * Add makeUnion * Add IR UnionType * Update union representation in LLVM * Update README * Update README.md * Update README * Update README.md * Add benchmarks * Add more benchmarks and README * Add primes benchmark * Update benchmarks * Fix cpp * Clean up list * Update faq.md * Add binary trees benchmark * Add fannkuch benchmark * Fix paths * Add PyPy * Abort on fail * More benchmarks * Add cpp word_count * Update set_partition cpp * Add nbody cpp * Add TAQ cpp; fix word_count timing * Update CODEOWNERS * Update README * Update README.md * Update CODEOWNERS * Fix bench script * Update binary_trees.cpp * Update taq.cpp * Fix primes benchmark * Add mandelbrot benchmark * Fix OpenMP init * Add Module::unsafeGetUnionType * UnionType [wip] [skip ci] * Integrate IR unions and Union * UnionType refactor [skip ci] * Update README.md * Update docs * UnionType [wip] [skip ci] * UnionType and automatic unions * Add Slack * Update faq.md * Refactor types * New error reporting [wip] * New error reporting [wip] * peglib updates [wip] [skip_ci] * Fix parsing issues * Fix parsing issues * Fix error reporting issues * Make sure random module matches Python * Update releases.md * Fix tests * Fix #59 * Fix #57 * Fix #50 * Fix #49 * Fix #26; Fix #51; Fix #47; Fix #49 * Fix collection extension methods * Fix #62 * Handle *args/**kwargs with Callable[]; Fix #43 * Fix #43 * Fix Ptr.__sub__; Fix polymorphism issues * Add typeinfo * clang-format * Upgrade fmtlib to v9; Use CPM for fmtlib; format spec support; __format__ support * Use CPM for semver and toml++ * Remove extension check * Revamp str methods * Update str.zfill * Fix thunk crashes [wip] [skip_ci] * Fix str.__reversed__ * Fix count_with_max * Fix vtable memory 
allocation issues * Add poly AST tests * Use PDQsort when stability does not matter * Fix dotted imports; Fix issues * Fix kwargs passing to Python * Fix #61 * Fix #37 * Add isinstance support for unions; Union methods return Union type if different * clang-format * Nicely format error tracebacks * Fix build issues; clang-format * Fix OpenMP init * Fix OpenMP init * Update README.md * Fix tests * Update license [skip ci] * Update license [ci skip] * Add copyright header to all source files * Fix super(); Fix error recovery in ClassStmt * Clean up whitespace [ci skip] * Use Python 3.9 on CI * Print info in random test * Fix single unions * Update random_test.codon * Fix polymorhic thunk instantiation * Fix random test * Add operator.attrgetter and operator.methodcaller * Add code documentation * Update documentation * Update README.md * Fix tests * Fix random init Co-authored-by: A. R. Shajii <ars@ars.me>
394 lines
12 KiB
Python
394 lines
12 KiB
Python
# Copyright (C) 2022 Exaloop Inc. <https://exaloop.io>
|
|
|
|
# Implementation of vectorized Rabin-Karp string search.
|
|
# See http://0x80.pl/articles/simd-strfind.html for
|
|
# details. These implementations are modified to not
|
|
# perform any out-of-bounds memory accesses.
|
|
|
|
# Count trailing zero bits of an N-bit unsigned integer using the
# llvm.cttz intrinsic. The `i1 true` flag tells LLVM that a zero input
# yields poison, so callers must only pass non-zero values (forward_find
# only calls this inside `while mask:`).
@pure
@llvm
def cttz(n: UInt[N], N: Static[int]) -> UInt[N]:
    declare i{=N} @llvm.cttz.i{=N}(i{=N}, i1)
    %tz = call i{=N} @llvm.cttz.i{=N}(i{=N} %n, i1 true)
    ret i{=N} %tz
|
|
|
|
# Count leading zero bits of an N-bit unsigned integer using the
# llvm.ctlz intrinsic. The `i1 true` flag tells LLVM that a zero input
# yields poison, so callers must only pass non-zero values (backward_find
# only calls this inside `while mask:`).
@pure
@llvm
def ctlz(n: UInt[N], N: Static[int]) -> UInt[N]:
    declare i{=N} @llvm.ctlz.i{=N}(i{=N}, i1)
    %lz = call i{=N} @llvm.ctlz.i{=N}(i{=N} %n, i1 true)
    ret i{=N} %lz
|
|
|
|
# Build a 16-bit candidate mask for the match START positions i .. i+15.
# Lane m of the result is set when s[i + m] == firstb and
# s[i + m + k - 1] == lastb. forward_find decodes lane m as bit m via cttz.
# Loads s[i .. i+15] and s[i+k-1 .. i+k+14], so the caller must ensure
# i + k + 15 <= n. `n` and `needle` are not read here; they keep the
# parameter list identical to backward_mask.
@pure
@llvm
def forward_mask(s: Ptr[byte], n: int, needle: Ptr[byte], k: int, i: int, firstb: byte, lastb: byte) -> u16:
    %vfirst_ins = insertelement <16 x i8> poison, i8 %firstb, i64 0
    %vfirst = shufflevector <16 x i8> %vfirst_ins, <16 x i8> poison, <16 x i32> zeroinitializer
    %vlast_ins = insertelement <16 x i8> poison, i8 %lastb, i64 0
    %vlast = shufflevector <16 x i8> %vlast_ins, <16 x i8> poison, <16 x i32> zeroinitializer
    %km1 = sub i64 %k, 1
    %tail_off = add i64 %i, %km1
    %p_head = getelementptr inbounds i8, ptr %s, i64 %i
    %p_tail = getelementptr inbounds i8, ptr %s, i64 %tail_off
    %head_bytes = load <16 x i8>, ptr %p_head, align 1
    %tail_bytes = load <16 x i8>, ptr %p_tail, align 1
    %head_hit = icmp eq <16 x i8> %head_bytes, %vfirst
    %tail_hit = icmp eq <16 x i8> %tail_bytes, %vlast
    %both = and <16 x i1> %head_hit, %tail_hit
    %bits = bitcast <16 x i1> %both to i16
    ret i16 %bits
|
|
|
|
# Build a 16-bit candidate mask for the match END positions i-15 .. i.
# Lane m of the result is set when s[i - 15 + m] == lastb and
# s[i - 15 + m - (k - 1)] == firstb, i.e. a match ending at i - 15 + m
# is possible. backward_find decodes the highest set lane via ctlz.
# Loads s[i-15 .. i] and s[i-15-(k-1) .. i-(k-1)], so the caller must ensure
# i - (k - 1) - 15 >= 0. `n` and `needle` are not read here; they keep the
# parameter list identical to forward_mask.
@pure
@llvm
def backward_mask(s: Ptr[byte], n: int, needle: Ptr[byte], k: int, i: int, firstb: byte, lastb: byte) -> u16:
    %j = sub i64 %i, 15
    %vfirst_ins = insertelement <16 x i8> poison, i8 %firstb, i64 0
    %vfirst = shufflevector <16 x i8> %vfirst_ins, <16 x i8> poison, <16 x i32> zeroinitializer
    %vlast_ins = insertelement <16 x i8> poison, i8 %lastb, i64 0
    %vlast = shufflevector <16 x i8> %vlast_ins, <16 x i8> poison, <16 x i32> zeroinitializer
    %km1 = sub i64 %k, 1
    %head_off = sub i64 %j, %km1
    %p_head = getelementptr inbounds i8, ptr %s, i64 %head_off
    %p_tail = getelementptr inbounds i8, ptr %s, i64 %j
    %head_bytes = load <16 x i8>, ptr %p_head, align 1
    %tail_bytes = load <16 x i8>, ptr %p_tail, align 1
    %tail_hit = icmp eq <16 x i8> %tail_bytes, %vlast
    %head_hit = icmp eq <16 x i8> %head_bytes, %vfirst
    %both = and <16 x i1> %tail_hit, %head_hit
    %bits = bitcast <16 x i1> %both to i16
    ret i16 %bits
|
|
|
|
def forward_find(s: Ptr[byte], n: int, needle: Ptr[byte], k: int):
    # Return the index of the first occurrence of needle[0:k] in s[0:n],
    # or -1 if there is none. An empty needle matches at index 0.
    if k == 0:
        return 0
    if n < k:
        return -1
    if k == 1:
        hit = _C.memchr(s, i32(int(needle[0])), n)
        return hit - s if hit else -1

    head = needle[0]
    tail = needle[k - 1]
    base = 0

    # SIMD phase: test 16 start positions per iteration, but only while the
    # last-byte block s[base+k-1 .. base+k+14] lies entirely inside s.
    while base + k + 15 <= n:
        hits = forward_mask(s, n, needle, k, base, head, tail)
        while hits:
            lane = int(cttz(hits))
            # first/last bytes already matched; compare the interior bytes
            if _C.memcmp(s + base + lane + 1, needle + 1, k - 2) == i32(0):
                return base + lane
            hits &= hits - u16(1)  # drop the lowest set bit
        base += 16

    # Scalar phase: check the remaining start positions one at a time,
    # stopping as soon as the needle would run past the end of s.
    for j in range(base, base + 16):
        if j + k > n:
            break
        if head == s[j] and tail == s[j + k - 1] and _C.memcmp(s + j + 1, needle + 1, k - 2) == i32(0):
            return j

    return -1
|
|
|
|
def backward_find(s: Ptr[byte], n: int, needle: Ptr[byte], k: int):
    # Return the index of the LAST occurrence of needle[0:k] in s[0:n],
    # or -1 if there is none. An empty needle matches at index n.
    if k == 0:
        return n

    if n < k:
        return -1

    if k == 1:
        i = n - 1
        while i >= 0:
            if s[i] == needle[0]:
                return i
            i -= 1
        return -1

    firstb = needle[0]
    lastb = needle[k - 1]
    # i is the highest candidate END position (index of the needle's last byte).
    i = n - 1

    # SIMD phase: test 16 end positions i-15 .. i per iteration, only while the
    # first-byte block starting at i-15-(k-1) is still inside s.
    while i - (k - 1) - (16 - 1) >= 0:
        mask = backward_mask(s, n, needle, k, i, firstb, lastb)
        while mask:
            # highest set lane = right-most candidate, checked first
            bitpos = int(ctlz(mask))
            if _C.memcmp(s + i - (k - 1) - bitpos + 1, needle + 1, k - 2) == i32(0):
                return i - (k - 1) - bitpos
            mask &= ~(u16(1) << u16(16 - 1 - bitpos))
        i -= 16

    # Scalar phase: walk the remaining end positions i, i-1, ... downward,
    # stopping once the match start would fall before s[0]. A loop replaces
    # the former hand-unrolled copy, which checked i + 10 instead of i - 10
    # (skipping a candidate and reading past the end of s).
    for j in range(i, i - 16, -1):
        start = j - k + 1
        if start < 0:
            break
        if lastb == s[j] and firstb == s[start] and _C.memcmp(s + start + 1, needle + 1, k - 2) == i32(0):
            return start

    return -1
|
|
|
|
def find(haystack: str, needle: str):
    # Index of the first occurrence of `needle` in `haystack`, or -1;
    # delegates to the SIMD forward search over the raw byte buffers.
    return forward_find(s=haystack.ptr, n=haystack.len, needle=needle.ptr, k=needle.len)
|
|
|
|
def rfind(haystack: str, needle: str):
    # Index of the last occurrence of `needle` in `haystack`, or -1;
    # delegates to the SIMD backward search over the raw byte buffers.
    return backward_find(s=haystack.ptr, n=haystack.len, needle=needle.ptr, k=needle.len)
|
|
|
|
def count(haystack: str, needle: str):
    # Number of non-overlapping occurrences of `needle` in `haystack`.
    # An empty needle matches at each of the len(haystack) + 1 boundaries.
    total = haystack.len
    width = needle.len
    if width == 0:
        return total + 1

    found = 0
    offset = 0
    hit = forward_find(haystack.ptr, total, needle.ptr, width)
    while hit != -1:
        found += 1
        # resume just past this match so occurrences never overlap
        offset += hit + width
        hit = forward_find(haystack.ptr + offset, total - offset, needle.ptr, width)
    return found
|
|
|
|
def count_with_max(haystack: str, needle: str, maxcount: int):
    # Count non-overlapping occurrences of `needle` in `haystack`, stopping
    # once `maxcount` have been found. maxcount == 0 yields 0; a negative
    # maxcount means "no limit" (the search loop below can never reach it).
    if maxcount == 0:
        return 0

    n = haystack.len
    k = needle.len

    if k == 0:
        # An empty needle matches at each of the n + 1 boundaries. Treat a
        # negative maxcount as unlimited here too, instead of returning the
        # negative maxcount itself as the "count".
        total = n + 1
        if maxcount < 0 or total < maxcount:
            return total
        return maxcount

    occ = 0
    tmp = haystack.ptr
    while True:
        pos = forward_find(tmp, n - (tmp - haystack.ptr), needle.ptr, k)
        if pos == -1:
            break
        tmp += pos + k  # skip past the match: occurrences do not overlap
        occ += 1
        if occ == maxcount:
            return occ
    return occ
|