# (c) 2022 Exaloop Inc. All rights reserved.
# Adapted in part from Google's Python re2 wrapper
# https://github.com/google/re2/blob/abseil/python/re2.py

# Flag bits accepted by compile(). The single-letter names below are aliases
# that share the bit of the corresponding long name.
ASCII      = 1 << 0
DEBUG      = 1 << 1
IGNORECASE = 1 << 2
LOCALE     = 1 << 3
MULTILINE  = 1 << 4
DOTALL     = 1 << 5
VERBOSE    = 1 << 6

A = ASCII
I = IGNORECASE
L = LOCALE
M = MULTILINE
S = DOTALL
X = VERBOSE

# Anchoring modes handed to seq_re_match: Pattern.search uses NONE,
# Pattern.match uses START, Pattern.fullmatch uses BOTH.
_ANCHOR_NONE  = 0
_ANCHOR_START = 1
_ANCHOR_BOTH  = 2

@tuple
class Span:
    # Half-open offsets [start, end) of a match or group in the subject string.
    # The pair (-1, -1) is the sentinel for "did not participate".
    start: int
    end: int

    def __bool__(self):
        """Truthy for every span except the (-1, -1) sentinel."""
        return self.start != -1 or self.end != -1

# Run the compiled regex handle `re` with the given anchor mode (one of the
# _ANCHOR_* constants) between offsets `pos` and `endpos` of `string`, and
# return a pointer to the resulting spans. Callers treat spans[0] as the whole
# match (a falsy Span means "no match") and spans[g] as capturing group g for
# g in 1..groups.
# NOTE(review): the lifetime/ownership of the returned buffer is decided by
# the C side and is not visible here; confirm before relying on it.
@C
@nocapture
def seq_re_match(re: cobj,
                 anchor: int,
                 string: str,
                 pos: int,
                 endpos: int) -> Ptr[Span]:
    pass

# Single-span counterpart of seq_re_match: same inputs, but a Span is returned
# by value.
# NOTE(review): not referenced in the visible part of this module; presumably
# it yields only the whole-match span. Confirm against the C implementation.
@C
@pure
def seq_re_match_one(re: cobj,
                     anchor: int,
                     string: str,
                     pos: int,
                     endpos: int) -> Span:
    pass

# Number of capturing groups in the compiled regex (backs Pattern.groups).
@C
@pure
def seq_re_pattern_groups(re: cobj) -> int:
    pass

# Index of the capturing group called `name`. Used by Match._get_group_str,
# whose bounds check raises IndexError for results outside 0..groups.
# NOTE(review): unknown names are presumably reported as an out-of-range
# index; confirm against the C implementation.
@C
@pure
def seq_re_group_name_to_index(re: cobj, name: str) -> int:
    pass

# Name of capturing group `index`. Pattern.groupindex treats an empty (falsy)
# result as "this group is unnamed".
@C
@pure
def seq_re_group_index_to_name(re: cobj, index: int) -> str:
    pass

# NOTE(review): not used in the visible part of this module. The signature
# suggests it writes arrays of group names and indices through the two
# out-pointers and returns their count; who frees those arrays is not visible
# here. Confirm against the C implementation before using it.
@C
@nocapture
def seq_re_pattern_groupindex(re: cobj,
                              names: Ptr[Ptr[str]],
                              indices: Ptr[Ptr[int]]) -> int:
    pass

# Compilation error message for the handle returned by seq_re_compile.
# compile() treats a non-empty result as failure and raises `error` with it.
@C
@pure
def seq_re_pattern_error(re: cobj) -> str:
    pass

# Backend for the module-level escape(): returns the escaped form of `pattern`.
@C
@pure
def seq_re_escape(pattern: str) -> str:
    pass

# Backend for the module-level purge().
# NOTE(review): presumably clears a cache of compiled regexes kept by the C
# side; confirm against the implementation.
@C
def seq_re_purge() -> void:
    pass

# Compile `pattern` with the flag bits defined at the top of this module and
# return an opaque handle. Failures are not raised here: compile() queries
# seq_re_pattern_error on the handle afterwards.
@C
def seq_re_compile(pattern: str, flags: int) -> cobj:
    pass

class error:
    # Exception type raised for patterns that fail to compile and for bad
    # replacement templates.
    _hdr: ExcHeader
    pattern: str  # text of the offending pattern ("" when not supplied)

    def __init__(self):
        self._hdr = ("re.error", "", "", "", 0, 0)
        self.pattern = ""

    def __init__(self, message: str, pattern: str = ''):
        self._hdr = ("re.error", message, "", "", 0, 0)
        self.pattern = pattern

    @property
    def message(self) -> str:
        """The error text stored in the exception header."""
        return self._hdr.msg

    @property
    def msg(self) -> str:
        """Alias of `message`, mirroring CPython's re.error.msg."""
        return self._hdr.msg

@tuple
class Pattern:
    # A compiled regular expression. Methods are attached later in this module
    # via @extend.
    pattern: str  # source text given to compile()
    flags: int    # flag bits given to compile()
    _re: cobj     # opaque handle returned by seq_re_compile

def compile(pattern: str, flags: int = 0):
    """
    Compile `pattern` with `flags` into a Pattern.

    Raises `error`, carrying the pattern text, when the backend reports a
    compilation failure.
    """
    handle = seq_re_compile(pattern, flags)
    problem = seq_re_pattern_error(handle)
    if problem:
        raise error(problem, pattern)
    return Pattern(pattern, flags, handle)

def search(pattern: str, string: str, flags: int = 0):
    """Compile `pattern` and return its first match anywhere in `string`, or None."""
    compiled = compile(pattern, flags)
    return compiled.search(string)

def match(pattern: str, string: str, flags: int = 0):
    """Compile `pattern` and match it at the start of `string`; None on failure."""
    compiled = compile(pattern, flags)
    return compiled.match(string)

def fullmatch(pattern: str, string: str, flags: int = 0):
    """Compile `pattern` and require it to match the whole of `string`; None on failure."""
    compiled = compile(pattern, flags)
    return compiled.fullmatch(string)

def finditer(pattern: str, string: str, flags: int = 0):
    """Compile `pattern` and return a generator over its matches in `string`."""
    compiled = compile(pattern, flags)
    return compiled.finditer(string)

def findall(pattern: str, string: str, flags: int = 0):
    """Compile `pattern` and return the list produced by Pattern.findall on `string`."""
    compiled = compile(pattern, flags)
    return compiled.findall(string)

def split(pattern: str, string: str, maxsplit: int = 0, flags: int = 0):
    """Compile `pattern` and split `string` on it (see Pattern.split)."""
    compiled = compile(pattern, flags)
    return compiled.split(string, maxsplit)

def sub(pattern: str, repl, string: str, count: int = 0, flags: int = 0):
    """Compile `pattern` and replace up to `count` matches in `string` (0 = all)."""
    compiled = compile(pattern, flags)
    return compiled.sub(repl, string, count)

def subn(pattern: str, repl, string: str, count: int = 0, flags: int = 0):
    """Like sub(), but return (new_string, number_of_substitutions)."""
    compiled = compile(pattern, flags)
    return compiled.subn(repl, string, count)

def escape(pattern: str):
    """Return the escaped form of `pattern`, as produced by the C backend."""
    escaped = seq_re_escape(pattern)
    return escaped

def purge():
    """Delegate to seq_re_purge (presumably clears the backend's compiled-regex cache; confirm)."""
    seq_re_purge()

@tuple
class Match:
    # Result of a successful search. _spans[g] holds the offsets of group g:
    # g == 0 is the whole match and 1..re.groups are the capturing groups. A
    # group that did not participate holds the falsy Span(-1, -1).
    _spans: Ptr[Span]
    pos: int     # start bound recorded by the Pattern method that built this
    endpos: int  # end bound recorded by the Pattern method that built this
    re: Pattern  # the pattern that produced this match
    string: str  # the subject string that was searched

    def _get_group_int(self, g: int, n: int):
        """Return the Span of group `g`; raise IndexError unless 0 <= g <= n."""
        if not (0 <= g <= n):
            raise IndexError("no such group")
        return self._spans[g]

    def _get_group_str(self, g: str, n: int):
        """Return the Span of the group named `g`."""
        # NOTE(review): an unknown name is presumably reported by the C side as
        # an out-of-range index, which the bounds check turns into IndexError.
        return self._get_group_int(seq_re_group_name_to_index(self.re._re, g), n)

    def _get_group(self, g, n: int):
        """Return the Span of group `g`, given as an int, a name, or any __index__-able value."""
        if isinstance(g, int):
            return self._get_group_int(g, n)
        elif isinstance(g, str):
            return self._get_group_str(g, n)
        else:
            return self._get_group(g.__index__(), n)

    def _span_match(self, span: Span):
        """Return the text covered by `span`, or None for the unmatched sentinel."""
        if not span:
            return None
        return self.string._slice(span.start, span.end)

    def _get_match(self, g, n: int):
        """Return the text of group `g`, or None if it did not participate."""
        span = self._get_group(g, n)
        return self._span_match(span)

    def _group_multi(self, n: int, *args):
        """Return a tuple holding the text of every group listed in `args`."""
        # Recursing over the static argument pack builds the tuple element by
        # element at compile time.
        if staticlen(args) == 1:
            return (self._get_match(args[0], n),)
        else:
            return (self._get_match(args[0], n), *self._group_multi(n, *args[1:]))

    def group(self, *args):
        """
        Return matched text. With no arguments this is the whole match (a str).
        With one argument it is that group's text, or None. With several it is
        a tuple of those results.
        """
        if staticlen(args) == 0:
            # A Match is only built when group 0 matched, so unwrapping is safe.
            return ~(self._get_match(0, 1))
        elif staticlen(args) == 1:
            return self._get_match(args[0], self.re.groups)
        else:
            return self._group_multi(self.re.groups, *args)

    def __getitem__(self, g):
        """m[g] is the same as m.group(g)."""
        return self._get_match(g, self.re.groups)

    def start(self, group = 0):
        """Start offset of `group`, or -1 if it did not participate."""
        return self._get_group(group, self.re.groups).start

    def end(self, group = 0):
        """End offset of `group`, or -1 if it did not participate."""
        return self._get_group(group, self.re.groups).end

    def span(self, group = 0):
        """(start, end) of `group`; (-1, -1) if it did not participate."""
        start, end = self._get_group(group, self.re.groups)
        return start, end

    def _split(template: str):
        # Split a replacement template into pieces. Even indices are literal
        # text (escape sequences still encoded). Odd indices are group
        # references of the form \N, \NN or \g<name>.
        backslash = '\\'
        pieces = ['']
        index = template.find(backslash)

        # A backslash followed by three octal digits is an octal escape, not a
        # group reference, even though GROUP would also match its prefix.
        OCTAL = compile(r'\\[0-7][0-7][0-7]')
        GROUP = compile(r'\\[1-9][0-9]?|\\g<\w+>')

        while index != -1:
            piece, template = template[:index], template[index:]
            pieces[-1] += piece

            octal_match = OCTAL.match(template)
            group_match = GROUP.match(template)

            if (not octal_match) and group_match:
                index = group_match.end()
                piece, template = template[:index], template[index:]
                pieces.extend((piece, ''))
            else:
                # Ordinary escape: keep backslash + next char as literal text;
                # _unescape resolves it later.
                index = 2
                piece, template = template[:index], template[index:]
                pieces[-1] += piece

            index = template.find(backslash)

        pieces[-1] += template
        return pieces

    def _unescape(s: str):
        # Resolve escape sequences in a literal template piece following
        # CPython's replacement-template rules: \a \b \f \n \r \t \v and \\ are
        # translated, octal escapes take at most three digits and must not
        # exceed 0o377, unknown ASCII-letter escapes are errors, and any other
        # escaped character (including quotes) keeps its backslash.
        r = []
        n = len(s)
        i = 0
        while i < n:
            if s[i] == '\\' and i + 1 < n:
                c = s[i + 1]
                if c == 'a':
                    r.append('\a')
                    i += 1
                elif c == 'b':
                    r.append('\b')
                    i += 1
                elif c == 'f':
                    r.append('\f')
                    i += 1
                elif c == 'n':
                    r.append('\n')
                    i += 1
                elif c == 'r':
                    r.append('\r')
                    i += 1
                elif c == 't':
                    r.append('\t')
                    i += 1
                elif c == 'v':
                    r.append('\v')
                    i += 1
                elif c == '\\':
                    r.append('\\')
                    i += 1
                elif '0' <= c <= '7':
                    # Digits live at s[i+1 .. i+3]: never more than three.
                    k = i + 2
                    while k < n and k - i <= 3 and '0' <= s[k] <= '7':
                        k += 1
                    digits = s[i + 1:k]
                    code = int(digits, 8)
                    if code > 255:  # 0o377, the largest single byte
                        raise error(f"octal escape value \\{digits} outside of range 0-0o377")
                    p = Ptr[byte](1)
                    p[0] = byte(code)
                    r.append(str(p, 1))
                    i = k - 1
                elif c.isalpha():
                    raise error(f"bad escape \\{c} at position {i}")
                else:
                    # Keep the backslash; the escaped character is appended on
                    # the next iteration.
                    r.append(s[i])
            else:
                r.append(s[i])
            i += 1

        return str.cat(r)

    def expand(self, template: str):
        """
        Return `template` with group references replaced by the matched text
        (empty for groups that did not participate) and escapes resolved.
        """
        def get_or_empty(s: Optional[str]):
            return s if s is not None else ''

        pieces = list(Match._split(template))
        INT = compile(r'[+-]?\d+')

        for index, piece in enumerate(pieces):
            if not (index % 2):
                # Even pieces are literal text.
                pieces[index] = Match._unescape(piece)
            else:
                # Odd pieces are references: the short numeric form (at most
                # three characters) or the \g<...> form, whose body is either
                # a number or a group name.
                if len(piece) <= 3:
                    pieces[index] = get_or_empty(self[int(piece[1:])])
                else:
                    group = piece[3:-1]
                    if INT.fullmatch(group):
                        pieces[index] = get_or_empty(self[int(group)])
                    else:
                        pieces[index] = get_or_empty(self[group])
        return str.cat(pieces)

    @property
    def lastindex(self):
        """
        Index of the participating capturing group whose match ends furthest
        right (ties go to the lowest index), or None if no group participated.
        """
        max_end = -1
        max_group = None
        for group in range(1, self.re.groups + 1):
            end = self._spans[group].end
            if max_end < end:
                max_end = end
                max_group = group
        return max_group

    @property
    def lastgroup(self):
        """Name of the `lastindex` group, or None if there is none or it is unnamed."""
        max_group = self.lastindex
        if max_group is None:
            return None
        name = seq_re_group_index_to_name(self.re._re, max_group)
        # An empty name marks an unnamed group (cf. Pattern.groupindex);
        # CPython reports None in that case.
        if not name:
            return None
        return name

    def groups(self, default: Optional[str] = None):
        """List of every capturing group's text, `default` for non-participating groups."""
        def get_or_default(item, default):
            return item if item is not None else default

        n = self.re.groups
        return [get_or_default(self._span_match(self._spans[i]), default)
                for i in range(1, n + 1)]

    def groupdict(self, default: Optional[str] = None):
        """Dict from group name to matched text, `default` for non-participating groups."""
        d = {}
        for group, index in self.re.groupindex.items():
            item = self[index]
            d[group] = item if item is not None else default
        return d

    def __copy__(self):
        """Match is a @tuple value, so a copy can be the same value."""
        return self

    def __deepcopy__(self):
        """Match is a @tuple value, so a deep copy can be the same value."""
        return self

    def __bool__(self):
        """A Match is always truthy."""
        return True

@extend
class Pattern:
    @property
    def groups(self):
        """Number of capturing groups in the pattern."""
        return seq_re_pattern_groups(self._re)

    @property
    def groupindex(self):
        """Dict from each named group to its index; unnamed groups are omitted."""
        d = {}
        for i in range(1, self.groups + 1):
            name = seq_re_group_index_to_name(self._re, i)
            if name:
                d[name] = i
        return d

    def _clamp_range(string: str, pos: Optional[int], endpos: Optional[int]):
        # Resolve optional search bounds to concrete offsets clamped into
        # [0, len(string)]; None means start / end of the string.
        n = len(string)
        posx = 0 if pos is None else max(0, min(~pos, n))
        endposx = n if endpos is None else max(0, min(~endpos, n))
        return posx, endposx

    def _match_one(self, anchor: int, string: str, pos: Optional[int], endpos: Optional[int]):
        """Run a single search with anchor mode `anchor`; return a Match or None."""
        posx, endposx = Pattern._clamp_range(string, pos, endpos)

        if posx > endposx:
            return None

        spans = seq_re_match(self._re, anchor, string, posx, endposx)
        if not spans[0]:
            return None

        return Match(spans, posx, endposx, self, string)

    def _match(self, anchor: int, string: str, pos: Optional[int], endpos: Optional[int]):
        """
        Yield successive non-overlapping matches between pos and endpos.
        Every Match records the caller's bounds as pos/endpos, as in CPython.
        """
        posx, endposx = Pattern._clamp_range(string, pos, endpos)

        if posx > endposx:
            return

        cur = posx
        while cur <= endposx:
            spans = seq_re_match(self._re, anchor, string, cur, endposx)

            if not spans[0]:
                break

            yield Match(spans, posx, endposx, self, string)

            whole = spans[0]
            if whole.start == whole.end:
                # Empty match: resume one byte past it. Resuming at its own
                # offset would find the same empty match again whenever it lay
                # after `cur`. (Consequently a non-empty match starting at that
                # same offset is not reported; CPython 3.7+ would report it.)
                cur = whole.end + 1
            else:
                cur = whole.end

    def search(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        """First match anywhere in string[pos:endpos], or None."""
        return self._match_one(_ANCHOR_NONE, string, pos, endpos)

    def match(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        """Match anchored at `pos`, or None."""
        return self._match_one(_ANCHOR_START, string, pos, endpos)

    def fullmatch(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        """Match that must span all of string[pos:endpos], or None."""
        return self._match_one(_ANCHOR_BOTH, string, pos, endpos)

    def finditer(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        """Generator over all non-overlapping matches."""
        return self._match(_ANCHOR_NONE, string, pos, endpos)

    def findall(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        """List of the whole-match text of every non-overlapping match."""
        # NOTE: unlike CPython, this always returns whole-match strings, even
        # when the pattern has capturing groups.
        return [m.group() for m in self.finditer(string, pos, endpos)]

    def _split(self, cb, string: str, maxsplit: int = 0, T: type = str):
        """
        Shared engine of split() and subn(). Collects the text between
        successive matches, inserting cb(match) after each gap, and stops
        after `maxsplit` matches when it is positive. A negative value uses no
        match at all. Returns (pieces, number of matches used).
        """
        if maxsplit < 0:
            return [T(string)], 0

        pieces: List[T] = []
        end = 0
        numsplit = 0
        for match in self.finditer(string):
            if (maxsplit > 0 and numsplit >= maxsplit):
                break
            pieces.append(string[end:match.start()])
            pieces.extend(cb(match))
            end = match.end()
            numsplit += 1
        pieces.append(string[end:])
        return pieces, numsplit

    def split(self, string: str, maxsplit: int = 0):
        """
        Split `string` on matches. Captured groups' text (None for
        non-participating groups) is interleaved between the pieces.
        """
        cb = lambda match: [match[group] for group in range(1, self.groups + 1)]
        pieces, _ = self._split(cb, string, maxsplit, Optional[str])
        return pieces

    def _repl(match, repl):
        # Replacement text for one match: expand a string template, otherwise
        # call `repl` with the match.
        if isinstance(repl, str):
            return match.expand(repl)
        else:
            return repl(match)

    def subn(self, repl, string: str, count: int = 0):
        """Replace up to `count` matches (0 = all); return (new_string, substitutions)."""
        cb = lambda match: [Pattern._repl(match, repl)]
        pieces, numsplit = self._split(cb, string, count, str)
        joined_pieces = str.cat(pieces)
        return joined_pieces, numsplit

    def sub(self, repl, string: str, count: int = 0):
        """Replace up to `count` matches (0 = all) and return the new string."""
        joined_pieces, _ = self.subn(repl, string, count)
        return joined_pieces

    def __bool__(self):
        """A Pattern is always truthy."""
        return True