1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
codon/stdlib/algorithms/strings.codon
Ibrahim Numanagić 5de233a64e
Dynamic Polymorphism (#58)
* Use Static[] for static inheritance

* Support .seq extension

* Fix #36

* Polymorphic typechecking; vtables [wip]

* v-table dispatch [wip]

* vtable routing [wip; bug]

* vtable routing [MVP]

* Fix texts

* Add union type support

* Update FAQs

* Clarify

* Add BSL license

* Add makeUnion

* Add IR UnionType

* Update union representation in LLVM

* Update README

* Update README.md

* Update README

* Update README.md

* Add benchmarks

* Add more benchmarks and README

* Add primes benchmark

* Update benchmarks

* Fix cpp

* Clean up list

* Update faq.md

* Add binary trees benchmark

* Add fannkuch benchmark

* Fix paths

* Add PyPy

* Abort on fail

* More benchmarks

* Add cpp word_count

* Update set_partition cpp

* Add nbody cpp

* Add TAQ cpp; fix word_count timing

* Update CODEOWNERS

* Update README

* Update README.md

* Update CODEOWNERS

* Fix bench script

* Update binary_trees.cpp

* Update taq.cpp

* Fix primes benchmark

* Add mandelbrot benchmark

* Fix OpenMP init

* Add Module::unsafeGetUnionType

* UnionType [wip] [skip ci]

* Integrate IR unions and Union

* UnionType refactor [skip ci]

* Update README.md

* Update docs

* UnionType [wip] [skip ci]

* UnionType and automatic unions

* Add Slack

* Update faq.md

* Refactor types

* New error reporting [wip]

* New error reporting [wip]

* peglib updates [wip] [skip_ci]

* Fix parsing issues

* Fix parsing issues

* Fix error reporting issues

* Make sure random module matches Python

* Update releases.md

* Fix tests

* Fix #59

* Fix #57

* Fix #50

* Fix #49

* Fix #26; Fix #51; Fix #47; Fix #49

* Fix collection extension methods

* Fix #62

* Handle *args/**kwargs with Callable[]; Fix #43

* Fix #43

* Fix Ptr.__sub__; Fix polymorphism issues

* Add typeinfo

* clang-format

* Upgrade fmtlib to v9; Use CPM for fmtlib; format spec support; __format__ support

* Use CPM for semver and toml++

* Remove extension check

* Revamp str methods

* Update str.zfill

* Fix thunk crashes [wip] [skip_ci]

* Fix str.__reversed__

* Fix count_with_max

* Fix vtable memory allocation issues

* Add poly AST tests

* Use PDQsort when stability does not matter

* Fix dotted imports; Fix  issues

* Fix kwargs passing to Python

* Fix #61

* Fix #37

* Add isinstance support for unions; Union methods return Union type if different

* clang-format

* Nicely format error tracebacks

* Fix build issues; clang-format

* Fix OpenMP init

* Fix OpenMP init

* Update README.md

* Fix tests

* Update license [skip ci]

* Update license [ci skip]

* Add copyright header to all source files

* Fix super(); Fix error recovery in ClassStmt

* Clean up whitespace [ci skip]

* Use Python 3.9 on CI

* Print info in random test

* Fix single unions

* Update random_test.codon

* Fix polymorphic thunk instantiation

* Fix random test

* Add operator.attrgetter and operator.methodcaller

* Add code documentation

* Update documentation

* Update README.md

* Fix tests

* Fix random init

Co-authored-by: A. R. Shajii <ars@ars.me>
2022-12-04 19:45:21 -05:00

394 lines
12 KiB
Python

# Copyright (C) 2022 Exaloop Inc. <https://exaloop.io>
# Implementation of vectorized Rabin-Karp string search.
# See http://0x80.pl/articles/simd-strfind.html for
# details. These implementations are modified to not
# perform any out-of-bounds memory accesses.
# Count the trailing zero bits of an N-bit unsigned integer using the
# `llvm.cttz` intrinsic. The intrinsic is called with its second operand set
# to `true`, so the result for n == 0 is not defined: only call this with a
# non-zero value.
@pure
@llvm
def cttz(n: UInt[N], N: Static[int]) -> UInt[N]:
    declare i{=N} @llvm.cttz.i{=N}(i{=N}, i1)
    %tz = call i{=N} @llvm.cttz.i{=N}(i{=N} %n, i1 true)
    ret i{=N} %tz
# Count the leading zero bits of an N-bit unsigned integer using the
# `llvm.ctlz` intrinsic. The intrinsic is called with its second operand set
# to `true`, so the result for n == 0 is not defined: only call this with a
# non-zero value.
@pure
@llvm
def ctlz(n: UInt[N], N: Static[int]) -> UInt[N]:
    declare i{=N} @llvm.ctlz.i{=N}(i{=N}, i1)
    %lz = call i{=N} @llvm.ctlz.i{=N}(i{=N} %n, i1 true)
    ret i{=N} %lz
# Candidate mask for the 16 haystack positions i .. i + 15 (forward search).
#
# Lane p of the result is set when s[i + p] == firstb and
# s[i + p + k - 1] == lastb, i.e. when a needle of length k whose first and
# last bytes are firstb/lastb could start at position i + p. Lane 0 is the
# least significant bit of the returned value on little-endian targets.
#
# Two unaligned 16-byte loads are issued, at s + i and at s + i + k - 1; the
# caller is responsible for keeping both windows inside the buffer.
# `n` and `needle` are part of the signature but are not read by the IR body.
@pure
@llvm
def forward_mask(s: Ptr[byte], n: int, needle: Ptr[byte], k: int, i: int, firstb: byte, lastb: byte) -> u16:
    %tail_off0 = add i64 %i, %k
    %tail_off = sub i64 %tail_off0, 1
    %head_ptr = getelementptr inbounds i8, ptr %s, i64 %i
    %tail_ptr = getelementptr inbounds i8, ptr %s, i64 %tail_off
    %head_block = load <16 x i8>, ptr %head_ptr, align 1
    %tail_block = load <16 x i8>, ptr %tail_ptr, align 1
    %head_seed = insertelement <16 x i8> undef, i8 %firstb, i64 0
    %head_splat = shufflevector <16 x i8> %head_seed, <16 x i8> poison, <16 x i32> zeroinitializer
    %tail_seed = insertelement <16 x i8> undef, i8 %lastb, i64 0
    %tail_splat = shufflevector <16 x i8> %tail_seed, <16 x i8> poison, <16 x i32> zeroinitializer
    %head_eq = icmp eq <16 x i8> %head_splat, %head_block
    %tail_eq = icmp eq <16 x i8> %tail_splat, %tail_block
    %both_eq = and <16 x i1> %head_eq, %tail_eq
    %bits = bitcast <16 x i1> %both_eq to i16
    ret i16 %bits
# Candidate mask for the 16 haystack positions ending at i (backward search).
#
# With j = i - 15, lane p of the result is set when s[j + p] == lastb and
# s[j + p - k + 1] == firstb, i.e. when a needle of length k whose first and
# last bytes are firstb/lastb could END at position j + p. The most
# significant bit of the returned value corresponds to position i on
# little-endian targets.
#
# Two unaligned 16-byte loads are issued, at s + j and at s + j - k + 1; the
# caller is responsible for keeping both windows inside the buffer.
# `n` and `needle` are part of the signature but are not read by the IR body.
@pure
@llvm
def backward_mask(s: Ptr[byte], n: int, needle: Ptr[byte], k: int, i: int, firstb: byte, lastb: byte) -> u16:
    %win0 = sub i64 %i, 16
    %win = add i64 %win0, 1
    %head_off0 = sub i64 %win, %k
    %head_off = add i64 %head_off0, 1
    %head_ptr = getelementptr inbounds i8, ptr %s, i64 %head_off
    %tail_ptr = getelementptr inbounds i8, ptr %s, i64 %win
    %head_block = load <16 x i8>, ptr %head_ptr, align 1
    %tail_block = load <16 x i8>, ptr %tail_ptr, align 1
    %head_seed = insertelement <16 x i8> undef, i8 %firstb, i64 0
    %head_splat = shufflevector <16 x i8> %head_seed, <16 x i8> poison, <16 x i32> zeroinitializer
    %tail_seed = insertelement <16 x i8> undef, i8 %lastb, i64 0
    %tail_splat = shufflevector <16 x i8> %tail_seed, <16 x i8> poison, <16 x i32> zeroinitializer
    %tail_eq = icmp eq <16 x i8> %tail_splat, %tail_block
    %head_eq = icmp eq <16 x i8> %head_splat, %head_block
    %both_eq = and <16 x i1> %tail_eq, %head_eq
    %bits = bitcast <16 x i1> %both_eq to i16
    ret i16 %bits
def forward_find(s: Ptr[byte], n: int, needle: Ptr[byte], k: int):
    # Find the first occurrence of needle[0:k] inside s[0:n].
    # Returns the starting offset of the match, or -1 when there is none.
    # An empty needle matches at offset 0.
    if k == 0:
        return 0
    if n < k:
        return -1
    if k == 1:
        # Single-byte needle: defer to memchr.
        hit = _C.memchr(s, i32(int(needle[0])), n)
        return hit - s if hit else -1

    head = needle[0]
    tail = needle[k - 1]
    # First and last bytes are compared separately, so only the k - 2 bytes
    # in between still need a memcmp.
    inner = k - 2

    # Vector phase: examine 16 start positions per iteration for as long as
    # both 16-byte windows read by forward_mask stay inside s[0:n].
    base = 0
    while base + k + 16 - 1 <= n:
        candidates = forward_mask(s, n, needle, k, base, head, tail)
        while candidates:
            lane = int(cttz(candidates))
            if _C.memcmp(s + base + lane + 1, needle + 1, inner) == i32(0):
                return base + lane
            # Drop the lowest set bit and move on to the next candidate.
            candidates = candidates & (candidates - u16(1))
        base += 16

    # Scalar tail: probe up to 16 further start positions one at a time,
    # stopping as soon as the needle would no longer fit.
    for step in range(16):
        pos = base + step
        if pos + k > n:
            break
        if head == s[pos] and tail == s[pos + k - 1] and _C.memcmp(s + pos + 1, needle + 1, inner) == i32(0):
            return pos
    return -1
def backward_find(s: Ptr[byte], n: int, needle: Ptr[byte], k: int):
    # Find the last occurrence of needle[0:k] inside s[0:n].
    # Returns the starting offset of the match, or -1 when there is none.
    # An empty needle matches at offset n.
    if k == 0:
        return n
    if n < k:
        return -1
    if k == 1:
        # Single-byte needle: plain backward scan.
        i = n - 1
        while i >= 0:
            if s[i] == needle[0]:
                return i
            i -= 1
        return -1

    firstb = needle[0]
    lastb = needle[k - 1]

    # `i` is the index of the LAST byte of the candidate match.
    i = n - 1

    # Vector phase: examine 16 end positions (i - 15 .. i) per iteration for
    # as long as both 16-byte windows read by backward_mask start at or after
    # s[0].
    while i - (k - 1) - (16 - 1) >= 0:
        mask = backward_mask(s, n, needle, k, i, firstb, lastb)
        while mask:
            # The highest set bit is the right-most candidate.
            bitpos = int(ctlz(mask))
            if _C.memcmp(s + i - (k - 1) - bitpos + 1, needle + 1, k - 2) == i32(0):
                return i - (k - 1) - bitpos
            # Clear that bit and move on to the next candidate to the left.
            mask &= ~(u16(1) << u16(16 - 1 - bitpos))
        i -= 16

    # Scalar tail: probe up to 16 further end positions one at a time, moving
    # left, and stop as soon as the needle would start before s[0].
    # The offsets are generated by a loop rather than written out by hand:
    # the hand-unrolled version used `i + 10` for the eleventh probe, which
    # skipped end position i - 10 and read s[i + 10].
    for step in range(16):
        j = i - step
        if j - k + 1 < 0:
            break
        if lastb == s[j] and firstb == s[j - k + 1] and _C.memcmp(s + j - k + 2, needle + 1, k - 2) == i32(0):
            return j - k + 1
    return -1
def find(haystack: str, needle: str):
    # Offset of the first occurrence of `needle` in `haystack`, as computed by
    # forward_find on the raw byte buffers; -1 when there is no match.
    hay_ptr = haystack.ptr
    hay_len = haystack.len
    pat_ptr = needle.ptr
    pat_len = needle.len
    return forward_find(hay_ptr, hay_len, pat_ptr, pat_len)
def rfind(haystack: str, needle: str):
    # Offset of the last occurrence of `needle` in `haystack`, as computed by
    # backward_find on the raw byte buffers; -1 when there is no match.
    hay_ptr = haystack.ptr
    hay_len = haystack.len
    pat_ptr = needle.ptr
    pat_len = needle.len
    return backward_find(hay_ptr, hay_len, pat_ptr, pat_len)
def count(haystack: str, needle: str):
    # Number of non-overlapping occurrences of `needle` in `haystack`,
    # scanning left to right. An empty needle yields len(haystack) + 1.
    k = needle.len
    if k == 0:
        return haystack.len + 1

    cursor = haystack.ptr
    remaining = haystack.len
    hits = 0
    pos = forward_find(cursor, remaining, needle.ptr, k)
    while pos != -1:
        hits += 1
        # Resume the search right after the match that was just found.
        advance = pos + k
        cursor += advance
        remaining -= advance
        pos = forward_find(cursor, remaining, needle.ptr, k)
    return hits
def count_with_max(haystack: str, needle: str, maxcount: int):
    # Like count(), but stops as soon as `maxcount` occurrences have been
    # found. A `maxcount` of 0 returns 0 immediately. For an empty needle the
    # result is len(haystack) + 1 when that is below `maxcount`, otherwise
    # `maxcount` itself.
    if maxcount == 0:
        return 0

    k = needle.len
    total = haystack.len
    if k == 0:
        gaps = total + 1
        if gaps < maxcount:
            return gaps
        else:
            return maxcount

    cursor = haystack.ptr
    remaining = total
    hits = 0
    pos = forward_find(cursor, remaining, needle.ptr, k)
    while pos != -1:
        hits += 1
        if hits == maxcount:
            return hits
        # Resume the search right after the match that was just found.
        advance = pos + k
        cursor += advance
        remaining -= advance
        pos = forward_find(cursor, remaining, needle.ptr, k)
    return hits