# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>

# Implementation of vectorized Rabin-Karp string search.
# See http://0x80.pl/articles/simd-strfind.html for
# details. These implementations are modified to not
# perform any out-of-bounds memory accesses.

# Count the trailing (least-significant) zero bits of the N-bit unsigned
# integer `n` by calling the `llvm.cttz` intrinsic instantiated for width N.
# The intrinsic's second operand is passed as `i1 true`; per the LLVM language
# reference this makes the result poison when `n` is zero, so callers must only
# pass non-zero values (forward_find only calls this inside `while mask:`).
@pure
@llvm
def cttz(n: UInt[N], N: Static[int]) -> UInt[N]:
    declare i{=N} @llvm.cttz.i{=N}(i{=N}, i1)
    %0 = call i{=N} @llvm.cttz.i{=N}(i{=N} %n, i1 true)
    ret i{=N} %0

# Count the leading (most-significant) zero bits of the N-bit unsigned
# integer `n` by calling the `llvm.ctlz` intrinsic instantiated for width N.
# The intrinsic's second operand is passed as `i1 true`; per the LLVM language
# reference this makes the result poison when `n` is zero, so callers must only
# pass non-zero values (backward_find only calls this inside `while mask:`).
@pure
@llvm
def ctlz(n: UInt[N], N: Static[int]) -> UInt[N]:
    declare i{=N} @llvm.ctlz.i{=N}(i{=N}, i1)
    %0 = call i{=N} @llvm.ctlz.i{=N}(i{=N} %n, i1 true)
    ret i{=N} %0

# Build a 16-bit candidate mask for the 16 start positions i .. i + 15 of `s`.
#
# `firstb` and `lastb` are each broadcast into a 16-lane byte vector and
# compared lane-wise against the 16 bytes loaded from s + i and from
# s + i + k - 1 respectively. The two comparison results are ANDed and the
# resulting <16 x i1> is bitcast to an i16, so a set bit marks a start position
# whose first and last needle bytes both match. The bytes in between are NOT
# compared here; the caller has to verify them.
#
# Both loads are unaligned (align 1) 16-byte loads with no bounds checking, so
# the caller must guarantee that s + i + k - 1 + 16 does not run past the
# buffer (forward_find does this with its `i + k + 16 - 1 <= n` loop guard).
# `n` and `needle` are accepted but never referenced by the IR body.
#
# NOTE(review): forward_find treats bit b (counting from the least significant
# bit) as lane b; that bitcast layout presumably assumes a little-endian
# target -- confirm.
@pure
@llvm
def forward_mask(s: Ptr[byte], n: int, needle: Ptr[byte], k: int, i: int, firstb: byte, lastb: byte) -> u16:
    %first0 = insertelement <16 x i8> undef, i8 %firstb, i64 0
    %first = shufflevector <16 x i8> %first0, <16 x i8> poison, <16 x i32> zeroinitializer
    %last0 = insertelement <16 x i8> undef, i8 %lastb, i64 0
    %last = shufflevector <16 x i8> %last0, <16 x i8> poison, <16 x i32> zeroinitializer
    %offset0 = add i64 %i, %k
    %offset = sub i64 %offset0, 1
    %ptr_first = getelementptr inbounds i8, ptr %s, i64 %i
    %ptr_last = getelementptr inbounds i8, ptr %s, i64 %offset
    %block_first = load <16 x i8>, ptr %ptr_first, align 1
    %block_last = load <16 x i8>, ptr %ptr_last, align 1
    %eq_first = icmp eq <16 x i8> %first, %block_first
    %eq_last = icmp eq <16 x i8> %last, %block_last
    %mask0 = and <16 x i1> %eq_first, %eq_last
    %mask = bitcast <16 x i1> %mask0 to i16
    ret i16 %mask

# Build a 16-bit candidate mask for a backward scan whose current last-byte
# position is `i`.
#
# With j = i - 15, the "last byte" block is the 16 bytes at s + j (indices
# i - 15 .. i) and the "first byte" block is the 16 bytes at s + j - k + 1.
# `lastb` and `firstb` are broadcast into 16-lane byte vectors and compared
# lane-wise against those blocks; the two results are ANDed and the resulting
# <16 x i1> is bitcast to an i16. A set bit for lane b therefore means
# s[j + b] == lastb and s[j + b - k + 1] == firstb. The bytes in between are NOT
# compared here; the caller has to verify them.
#
# The temporaries %eq_first / %eq_last hold the last-byte and first-byte
# comparisons respectively (the opposite of what their names suggest); this is
# harmless because the two are only ANDed together.
#
# Both loads are unaligned (align 1) 16-byte loads with no bounds checking, so
# the caller must guarantee i - 15 - (k - 1) >= 0 (backward_find does this with
# its `i - (k - 1) - (16 - 1) >= 0` loop guard). `n` and `needle` are accepted
# but never referenced by the IR body.
#
# NOTE(review): backward_find maps the highest set bit to the right-most
# candidate via ctlz; that bitcast layout presumably assumes a little-endian
# target -- confirm.
@pure
@llvm
def backward_mask(s: Ptr[byte], n: int, needle: Ptr[byte], k: int, i: int, firstb: byte, lastb: byte) -> u16:
    %j0 = sub i64 %i, 16
    %j = add i64 %j0, 1
    %first0 = insertelement <16 x i8> undef, i8 %firstb, i64 0
    %first = shufflevector <16 x i8> %first0, <16 x i8> poison, <16 x i32> zeroinitializer
    %last0 = insertelement <16 x i8> undef, i8 %lastb, i64 0
    %last = shufflevector <16 x i8> %last0, <16 x i8> poison, <16 x i32> zeroinitializer
    %offset0 = sub i64 %j, %k
    %offset = add i64 %offset0, 1
    %ptr_first = getelementptr inbounds i8, ptr %s, i64 %offset
    %ptr_last = getelementptr inbounds i8, ptr %s, i64 %j
    %block_first = load <16 x i8>, ptr %ptr_first, align 1
    %block_last = load <16 x i8>, ptr %ptr_last, align 1
    %eq_first = icmp eq <16 x i8> %last, %block_last
    %eq_last = icmp eq <16 x i8> %first, %block_first
    %mask0 = and <16 x i1> %eq_first, %eq_last
    %mask = bitcast <16 x i1> %mask0 to i16
    ret i16 %mask

def forward_find(s: Ptr[byte], n: int, needle: Ptr[byte], k: int):
    # Return the index of the first occurrence of the k-byte `needle` inside
    # the n-byte buffer `s`, or -1 when there is none.
    #
    # Special cases: an empty needle matches at index 0, a needle longer than
    # the buffer never matches, and a one-byte needle is delegated to memchr.

    if k == 0:
        return 0

    if n < k:
        return -1

    if k == 1:
        hit = _C.memchr(s, i32(int(needle[0])), n)
        return hit - s if hit else -1

    head = needle[0]
    tail = needle[k - 1]
    # The first and last bytes are pre-filtered below, so only the k - 2 bytes
    # in between still need a memcmp.
    inner = needle + 1
    inner_len = k - 2
    base = 0

    # Vectorized phase: examine 16 start positions per step for as long as the
    # two 16-byte loads done by forward_mask stay inside the buffer.
    while base + k + 16 - 1 <= n:
        bits = forward_mask(s, n, needle, k, base, head, tail)
        while bits:
            lane = int(cttz(bits))
            if _C.memcmp(s + base + lane + 1, inner, inner_len) == i32(0):
                return base + lane
            # Drop the lowest set bit and move on to the next candidate.
            bits = bits & (bits - u16(1))
        base += 16

    # Scalar phase: the vector loop stopped because base + k + 15 > n, so
    # fewer than 16 start positions remain; check them one byte at a time.
    pos = base
    while pos + k <= n:
        if head == s[pos] and tail == s[pos + k - 1] and _C.memcmp(s + pos + 1, inner, inner_len) == i32(0):
            return pos
        pos += 1

    return -1

def backward_find(s: Ptr[byte], n: int, needle: Ptr[byte], k: int):
    # Return the start index of the last occurrence of the k-byte `needle`
    # inside the n-byte buffer `s`, or -1 when there is none.
    #
    # Special cases: an empty needle matches at index n, a needle longer than
    # the buffer never matches, and a one-byte needle is found with a plain
    # backward byte scan.

    if k == 0:
        return n

    if n < k:
        return -1

    if k == 1:
        i = n - 1
        while i >= 0:
            if s[i] == needle[0]:
                return i
            i -= 1
        return -1

    firstb = needle[0]
    lastb = needle[k - 1]
    # `i` always indexes the byte that would be the LAST byte of a match.
    i = n - 1

    # Vectorized phase: examine 16 candidate end positions (i - 15 .. i) per
    # step for as long as the two 16-byte loads done by backward_mask stay at
    # or after the start of the buffer.
    while i - (k - 1) - (16 - 1) >= 0:
        mask = backward_mask(s, n, needle, k, i, firstb, lastb)
        while mask:
            # The highest set bit is the right-most candidate; `bitpos` is its
            # distance from the top bit, i.e. from end position `i`.
            bitpos = int(ctlz(mask))
            if _C.memcmp(s + i - (k - 1) - bitpos + 1, needle + 1, k - 2) == i32(0):
                return i - (k - 1) - bitpos
            # Clear the bit that was just examined.
            mask &= ~(u16(1) << u16(16 - 1 - bitpos))
        i -= 16

    # Scalar phase: the vector loop stopped because i - (k - 1) < 15, so fewer
    # than 16 candidate end positions remain. Walk them strictly downwards.
    # This used to be a hand-unrolled sequence of 16 stanzas, one of which
    # stepped to `i + 10` instead of `i - 10`; that read bytes beyond the
    # current window and could report a match that is not the right-most one.
    # A single descending loop visits exactly the intended positions.
    j = i
    while j - k + 1 >= 0:
        if lastb == s[j] and firstb == s[j - k + 1] and _C.memcmp(s + j - k + 2, needle + 1, k - 2) == i32(0):
            return j - k + 1
        j -= 1

    return -1

def find(haystack: str, needle: str):
    # Index of the first occurrence of `needle` in `haystack`, or -1 if it
    # does not occur. Thin wrapper that unpacks both strings into
    # (pointer, length) pairs for forward_find.
    hay_ptr = haystack.ptr
    hay_len = haystack.len
    pat_ptr = needle.ptr
    pat_len = needle.len
    return forward_find(hay_ptr, hay_len, pat_ptr, pat_len)

def rfind(haystack: str, needle: str):
    # Start index of the last occurrence of `needle` in `haystack`, or -1 if
    # it does not occur. Thin wrapper that unpacks both strings into
    # (pointer, length) pairs for backward_find.
    hay_ptr = haystack.ptr
    hay_len = haystack.len
    pat_ptr = needle.ptr
    pat_len = needle.len
    return backward_find(hay_ptr, hay_len, pat_ptr, pat_len)

def count(haystack: str, needle: str):
    # Number of non-overlapping occurrences of `needle` in `haystack`.
    # An empty needle is reported as occurring len(haystack) + 1 times.
    hay_len = haystack.len
    pat_len = needle.len

    if pat_len == 0:
        return hay_len + 1

    base = haystack.ptr
    pat = needle.ptr
    start = 0  # offset into haystack where the next search begins
    total = 0

    hit = forward_find(base + start, hay_len - start, pat, pat_len)
    while hit != -1:
        total += 1
        # Resume right after the match so that occurrences never overlap.
        start += hit + pat_len
        hit = forward_find(base + start, hay_len - start, pat, pat_len)
    return total

def count_with_max(haystack: str, needle: str, maxcount: int):
    # Number of non-overlapping occurrences of `needle` in `haystack`,
    # stopping as soon as the running count equals `maxcount`.
    # A `maxcount` of 0 yields 0. For an empty needle the result is
    # len(haystack) + 1 when that is below `maxcount`, otherwise `maxcount`.
    if maxcount == 0:
        return 0

    hay_len = haystack.len
    pat_len = needle.len

    if pat_len == 0:
        slots = hay_len + 1
        if slots < maxcount:
            return slots
        return maxcount

    base = haystack.ptr
    pat = needle.ptr
    start = 0  # offset into haystack where the next search begins
    total = 0

    # `total` starts at 0 and maxcount is non-zero here, so at least one search
    # is always attempted; the loop ends once the count hits maxcount exactly.
    while total != maxcount:
        hit = forward_find(base + start, hay_len - start, pat, pat_len)
        if hit == -1:
            break
        # Resume right after the match so that occurrences never overlap.
        start += hit + pat_len
        total += 1
    return total