# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>

import util
from .npdatetime import datetime64, timedelta64, busdaycalendar, _apply_busines_day_offset, \
                        _apply_busines_day_count, _apply_is_business_day, _BUSDAY_FORWARD, \
                        _BUSDAY_FOLLOWING, _BUSDAY_BACKWARD, _BUSDAY_PRECEDING, \
                        _BUSDAY_MODIFIEDFOLLOWING, _BUSDAY_MODIFIEDPRECEDING, _BUSDAY_NAT, \
                        _BUSDAY_RAISE
from .ndarray import ndarray, flagsobj

def _check_out(out, shape):
    """
    Validate a caller-supplied output array against an expected shape.

    Emits a compile error if `out` is not an `ndarray` or if its number of
    dimensions differs from the static length of `shape`; raises `ValueError`
    at run time if the shapes differ.
    """
    if not isinstance(out, ndarray):
        compile_error("output must be an array")

    if out.ndim != staticlen(shape):
        compile_error("output parameter has the wrong number of dimensions")

    if out.shape != shape:
        raise ValueError("output parameter has incorrect shape")

############
# Creation #
############

def _inner_type(t):
    """
    Return a value whose type is the scalar element type of `t`.

    For an `ndarray` the result of calling its dtype is returned; for a list
    or tuple the first element is inspected recursively; anything else is
    returned unchanged.
    """
    if isinstance(t, ndarray):
        return t.dtype()
    elif isinstance(t, List) or isinstance(t, Tuple):
        # Only the first element is consulted to determine the element type.
        return _inner_type(t[0])
    else:
        return t

def _extract_shape(t):
    """
    Derive the shape tuple of a nested list/tuple/array structure.

    Only the first element of each nesting level is followed, so the result
    is not checked for consistency across siblings here. Scalars yield `()`.
    """
    if isinstance(t, ndarray):
        return t.shape
    elif isinstance(t, List) or isinstance(t, Tuple):
        rest = _extract_shape(t[0])
        return (len(t), *rest)
    else:
        return ()

def _flatten(t, shape, D: type):
    """
    Copy the elements of the nested sequence `t` into a newly allocated flat
    buffer of element type `D`, in nested iteration order.

    `t` may nest lists, tuples, arrays and scalars. The buffer holds
    `util.count(shape)` elements; this function does not check that `t`
    really has that shape. Returns the pointer to the buffer.
    """
    def helper(t, p, start: int, D: type):
        # Writes the elements of `t` into `p` beginning at offset `start` and
        # returns the offset just past the last element written.
        for s in t:
            if isinstance(s, ndarray):
                cc, fc = s._contig
                if s.dtype is D and cc:
                    # Same element type and C-contiguous: one bulk copy.
                    str.memcpy((p + start).as_byte(), s.data.as_byte(), s.nbytes)
                    start += s.size
                else:
                    # Element-wise copy with conversion to `D`. `start` is
                    # advanced once per element here, so it must not be
                    # bumped by `s.size` again afterwards.
                    for idx in util.multirange(s.shape):
                        p[start] = util.cast(s._ptr(idx)[0], D)
                        start += 1
            elif isinstance(s, List) or isinstance(s, Tuple):
                start = helper(s, p, start, D)
            else:
                p[start] = util.cast(s, D)
                start += 1
        return start

    p = Ptr[D](util.count(shape))
    helper(t, p, 0, D)
    return p

def _validate_shape(t, shape):
    """
    Check that the nested structure `t` has exactly the given `shape`.

    Raises `ValueError` if a level has the wrong length, if an embedded
    array's shape differs from the remaining `shape`, or if an element that
    should be indexable has no `__getitem__`.
    """
    def error():
        raise ValueError('array dimensions mismatch (jagged arrays are not allowed)')

    if staticlen(shape) == 0:
        return
    else:
        if not hasattr(type(t), "__getitem__"):
            error()
        else:
            if isinstance(t, ndarray):
                if not util.tuple_equal(t.shape, shape):
                    error()
            else:
                if len(t) != shape[0]:
                    error()
                # Every child must match the remaining dimensions.
                for s in t:
                    _validate_shape(s, shape[1:])

def _array(a, dtype: type = NoneType, copy: bool = True, order: str = 'K'):
    """
    Core of `array()`: build an `ndarray` from an array, a nested
    list/tuple, or a scalar.

    When `dtype` is `NoneType` it is taken from the input (the array's dtype,
    or the element type found by `_inner_type`). `order` is validated with
    `ndarray._check_order`. Array inputs are converted through `astype`;
    nested sequences are validated and flattened into a fresh buffer; any
    other value becomes a 0-d array.

    Raises `ValueError` for jagged nested sequences.
    """
    if dtype is NoneType:
        if isinstance(a, ndarray):
            return _array(a, a.dtype, copy, order)
        else:
            return _array(a, type(_inner_type(a)), copy, order)

    ndarray._check_order(order)

    if isinstance(a, ndarray):
        if copy:
            return a.astype(dtype, order=order, copy=True)
        else:
            if order == 'K' or order == 'A':
                return a.astype(dtype, copy=False)

            cc, fc = a._contig

            if order == 'C':
                if cc:
                    return a.astype(dtype, copy=False)
                else:
                    return a.astype(dtype, order='C', copy=False)

            if order == 'F':
                if fc:
                    return a.astype(dtype, copy=False)
                else:
                    return a.astype(dtype, order='F', copy=False)
    elif isinstance(a, List) or isinstance(a, Tuple):
        shape = _extract_shape(a)
        # Validate BEFORE flattening: the buffer is sized from `shape`, which
        # is derived from first elements only, so flattening a jagged input
        # could write past the end of the buffer.
        _validate_shape(a, shape)
        data = _flatten(a, shape, dtype)
        result = ndarray(shape, data)

        if order == 'F':
            return result.astype(dtype, order='F')
        else:
            return result
    else:
        shape = ()
        data = Ptr[dtype](1)
        data[0] = util.cast(a, dtype)
        return ndarray(shape, data)

def array(a, dtype: type = NoneType, copy: bool = True, order: str = 'K', ndmin: Static[int] = 0):
    """
    Create an array from `a` via `_array`, then prepend length-1 dimensions
    so that the result has at least `ndmin` dimensions.
    """
    result = _array(a, dtype=dtype, copy=copy, order=order)
    if staticlen(result.shape) < ndmin:
        # Leading ones pad the shape up to `ndmin` dimensions.
        o = (1,) * (ndmin - staticlen(result.shape))
        return result.reshape((*o, *result.shape))
    return result

def asarray(a, dtype: type = NoneType, order: str = 'K'):
    """Convert `a` to an array by calling `array` with `copy=False`."""
    return array(a, dtype, False, order)

def asanyarray(a, dtype: type = NoneType, order: str = 'K'):
    """Alias of `asarray` with the same arguments."""
    return asarray(a, dtype, order)

def asarray_chkfinite(a, dtype: type = NoneType, order: str = 'K'):
    """
    Convert `a` to an array and verify that every element is finite.

    Raises `ValueError` if any float/float32 element, or either component of
    a complex/complex64 element, is not finite. Elements of other types are
    always accepted.
    """
    # Note: this is c/p from ndmath
    def isfinite(x):
        if isinstance(x, float) or isinstance(x, float32):
            return util.isfinite(x)
        elif isinstance(x, complex) or isinstance(x, complex64):
            return util.isfinite(x.real) and util.isfinite(x.imag)
        else:
            return True

    a = asarray(a, dtype=dtype, order=order)
    for idx in util.multirange(a.shape):
        if not isfinite(a._ptr(idx)[0]):
            raise ValueError("array must not contain infs or NaNs")
    return a

def empty(shape, dtype: type = float, order: str = 'C'):
    """
    Allocate an array of the given shape and dtype without initialising it.

    An integer `shape` is treated as a 1-tuple. Raises `ValueError` for a
    negative dimension or for an `order` other than 'C' or 'F'.
    """
    if isinstance(shape, int):
        return empty((shape,), dtype, order)

    for dim in shape:
        if dim < 0:
            raise ValueError('negative dimensions are not allowed')

    if order != 'C' and order != 'F':
        raise ValueError("'order' must be 'C' or 'F'")

    buf = Ptr[dtype](util.count(shape))
    return ndarray(shape, buf, fcontig=(order == 'F'))

def empty_like(prototype, dtype: type = NoneType, order: str = 'K', shape = None):
    """
    Allocate an uninitialised array modelled on `prototype`.

    `dtype` defaults to the prototype's dtype and `shape` to the prototype's
    shape. `order` selects the layout:
      - 'C' / 'F': passed to `empty` as is;
      - 'A': 'F' if the prototype is F-contiguous, else 'C';
      - 'K': 'C' or 'F' when the prototype is contiguous that way (or has at
        most one dimension); otherwise, when `shape` has the prototype's
        number of dimensions, strides are built so that axes keep the
        relative order given by `util.sort_by_stride`. If the number of
        dimensions differs, 'C' is used.
    """
    prototype = asarray(prototype)

    if dtype is NoneType:
        # Forward the caller's `shape` (not None) so an explicit shape is
        # honoured even when the dtype is defaulted.
        return empty_like(prototype, dtype=prototype.dtype, order=order, shape=shape)

    if shape is None:
        return empty_like(prototype, dtype=dtype, order=order, shape=prototype.shape)

    cc, fc = prototype._contig

    if order == 'A':
        order = 'F' if fc else 'C'
    elif order == 'K':
        if staticlen(shape) == prototype.ndim:
            if cc or prototype.ndim <= 1:
                order = 'C'
            elif fc:
                order = 'F'
            else:
                # Fill a zeroed strides tuple in place through a raw pointer.
                strides = (0,) * staticlen(prototype.strides)
                pstrides = Ptr[int](__ptr__(strides).as_byte())
                r = util.tuple_range(staticlen(shape))
                perm, strides_sorted = util.sort_by_stride(r, prototype.strides)
                stride = util.sizeof(dtype)
                ndim = prototype.ndim

                # Walk the permutation from last to first, assigning
                # progressively larger strides.
                for idim in range(ndim - 1, -1, -1):
                    iperm = perm[idim]
                    pstrides[iperm] = stride
                    stride *= shape[iperm]

                data = Ptr[dtype](util.count(shape))
                return ndarray(shape, strides, data)
        else:
            order = 'C'

    return empty(shape, dtype, order)

def zeros(shape, dtype: type = float, order: str = 'C'):
    """Allocate an array with `empty` and zero all of its bytes."""
    out = empty(shape, dtype, order)
    raw = out.data.as_byte()
    str.memset(raw, byte(0), out.nbytes)
    return out

def zeros_like(prototype, dtype: type = NoneType, order: str = 'K'):
    """Allocate an array with `empty_like` and zero all of its bytes."""
    out = empty_like(prototype, dtype, order)
    raw = out.data.as_byte()
    str.memset(raw, byte(0), out.nbytes)
    return out

def ones(shape, dtype: type = float, order: str = 'C'):
    """Allocate an array with `empty` and set every element to one."""
    out = empty(shape, dtype, order)
    out.map(lambda elem: util.cast(1, out.dtype), inplace=True)
    return out

def ones_like(prototype, dtype: type = NoneType, order: str = 'K'):
    """Allocate an array with `empty_like` and set every element to one."""
    out = empty_like(prototype, dtype, order)
    out.map(lambda elem: util.cast(1, out.dtype), inplace=True)
    return out

def identity(n: int, dtype: type = float):
    """Return an n-by-n array of zeros with ones on the main diagonal."""
    mat = zeros((n, n), dtype)
    buf = mat.data
    # Consecutive diagonal entries are n + 1 elements apart in the flat buffer.
    diag_step = n + 1
    for row in range(n):
        buf[row * diag_step] = dtype(1)
    return mat

def _diag_count(k: int, n: int, m: int):
    count = 0
    if k < m and k > -n:
        count = min(n, m)
        if k > 0:
            d = max(m - n, 0)
            if k > d:
                count -= (k - d)
        elif k < 0:
            d = max(n - m, 0)
            if k < -d:
                count -= (-k - d)
    return count

def eye(N: int, M: Optional[int] = None, k: int = 0, dtype: type = float, order: str = 'C'):
    """
    Return an N-by-M array (M defaults to N) of zeros with ones on the
    `k`-th diagonal.
    """
    rows: int = N
    cols: int = M if M is not None else rows
    out = zeros((rows, cols), dtype=dtype, order=order)
    ndiag = _diag_count(k, rows, cols)

    if k >= 0:
        # Diagonal on or above the main one: shift the column index.
        for i in range(ndiag):
            out[i, i + k] = dtype(1)
    else:
        # Diagonal below the main one: shift the row index.
        for i in range(ndiag):
            out[i - k, i] = dtype(1)

    return out

def diag(v, k: int = 0):
    """
    Build a diagonal matrix from a 1-d input, or take the `k`-th diagonal of
    a 2-d input.

    For 1-d input a new zero-filled square array of side `len(v) + abs(k)` is
    returned with the elements of `v` on diagonal `k`. For 2-d input the
    result is constructed directly over `v`'s data pointer with stride
    `sn + sm`; no element copy is made here. Any other dimensionality is a
    compile error.
    """
    v = asarray(v)
    data = v.data
    T = type(data[0])

    if staticlen(v.shape) == 1:
        count = v.shape[0]
        n = count + abs(k)
        result = zeros((n, n), dtype=T)
        p = result.data
        # Move to the first element of the requested diagonal.
        if k > 0:
            p += k
        elif k < 0:
            p += (-k) * n

        for i in range(count):
            q = v._ptr((i,))
            p[0] = q[0]
            # Next diagonal element: one row down, one column right.
            p += n + 1

        return result
    elif staticlen(v.shape) == 2:
        n, m = v.shape
        sn, sm = v.strides
        new_shape = (_diag_count(k, n, m),)
        # Stepping along a diagonal advances one row and one column at once.
        new_strides = (sn + sm,)
        new_data = data
        if new_shape[0] > 0:
            if k >= 0:
                new_data = v._ptr((0, k))
            else:
                new_data = v._ptr((-k, 0))
        return ndarray(new_shape, new_strides, new_data)
    else:
        compile_error('Input must be 1- or 2-d.')

def diagflat(v, k: int = 0):
    """Flatten `v` and place its elements on the `k`-th diagonal via `diag`."""
    flat = asarray(v).flatten()
    return diag(flat, k)

def tri(N: int, M: Optional[int] = None, k: int = 0, dtype: type = float):
    """
    Return an N-by-M array (M defaults to N) with ones at and below the
    `k`-th diagonal and zeros elsewhere.
    """
    rows: int = N
    cols: int = M if M is not None else rows
    out = zeros((rows, cols), dtype=dtype)
    buf = out.data
    for r in range(rows):
        # Columns 0 .. r + k (clipped to the row width) are set in row r.
        stop = min(r + k + 1, cols)
        row_start = r * cols
        for c in range(stop):
            buf[row_start + c] = dtype(1)
    return out

def triu(x, k: int = 0):
    """
    Upper triangle of an array.

    0-d input is a compile error. For 1-d input of length n an n-by-n array
    is built whose row i holds `x[j]` for columns `j >= max(0, i + k)` and
    zeros elsewhere. For inputs with two or more dimensions a copy is
    returned in which, over the last two axes, entries with column
    `j < i + k` are set to zero.
    """
    x = asarray(x)
    T = x.dtype

    if staticlen(x.shape) == 0:
        compile_error('Cannot call triu on 0-d array.')
    elif staticlen(x.shape) == 1:
        n = x.shape[0]
        result = zeros((n, n), dtype=T)
        p = result.data
        for i in range(n):
            for j in range(max(0, i + k), n):
                p[i*n + j] = x[j]
        return result
    else:
        y = x.copy()
        n, m = x.shape[-2], x.shape[-1]
        # Full slices over all leading axes so the assignment applies to
        # every matrix in the stack.
        pre = (slice(None, None, None),) * (staticlen(x.shape) - 2)
        for i in range(n):
            for j in range(min(i + k, m)):
                y[(*pre, i, j)] = T(0)
        return y

def tril(x, k: int = 0):
    """
    Lower triangle of an array.

    0-d input is a compile error. For 1-d input of length n an n-by-n array
    is built whose row i holds `x[j]` for columns `j < min(i + k + 1, n)` and
    zeros elsewhere. For inputs with two or more dimensions a copy is
    returned in which, over the last two axes, entries with column
    `j >= max(0, i + k + 1)` are set to zero.
    """
    x = asarray(x)
    T = x.dtype

    if staticlen(x.shape) == 0:
        compile_error('Cannot call tril on 0-d array.')
    elif staticlen(x.shape) == 1:
        n = x.shape[0]
        result = zeros((n, n), dtype=T)
        p = result.data
        for i in range(n):
            for j in range(min(i + k + 1, n)):
                p[i*n + j] = x[j]
        return result
    else:
        y = x.copy()
        n, m = x.shape[-2], x.shape[-1]
        # Full slices over all leading axes so the assignment applies to
        # every matrix in the stack.
        pre = (slice(None, None, None),) * (staticlen(x.shape) - 2)
        for i in range(n):
            for j in range(max(0, i + k + 1), m):
                y[(*pre, i, j)] = T(0)
        return y

def vander(x, N: Optional[int] = None, increasing: bool = False):
    """
    Build a Vandermonde matrix from the 1-d input `x`.

    The result has `len(x)` rows and `N` columns (default `len(x)`); each row
    holds powers of the corresponding element, decreasing from left to right
    unless `increasing` is set. Non-1-d input is a compile error.
    """
    x = asarray(x)

    if staticlen(x.shape) != 1:
        compile_error('x must be a one-dimensional array or sequence.')

    T = x.dtype
    rows: int = x.shape[0]
    cols: int = N if N is not None else rows
    out = zeros((rows, cols), dtype=T)
    buf = out.data

    for r in range(rows):
        base = x._ptr((r,))[0]
        row_start = r * cols
        for c in range(cols):
            if increasing:
                exponent = c
            else:
                exponent = cols - c - 1
            buf[row_start + c] = base ** exponent

    return out

# Returns whether the sign bit of the double `x` is set: the value is
# reinterpreted as a 64-bit integer and tested for being negative.
@pure
@llvm
def _signbit(x: float) -> bool:
    %y = bitcast double %x to i64
    %z = icmp slt i64 %y, 0
    %b = zext i1 %z to i8
    ret i8 %b

# Bounds of a signed 64-bit integer (-2**63 and 2**63 - 1); used by
# `_safe_ceil` to range-check a length before converting it to int.
_imin: Static[int] = -9223372036854775808
_imax: Static[int] = 9223372036854775807

def _safe_ceil(value: float):
    """
    Round `value` up and convert it to int.

    Raises `ValueError` if the rounded value is NaN and `OverflowError` if it
    lies outside the range spanned by `_imin` and `_imax`.
    """
    ceiled = util.ceil64(value)
    if util.isnan64(ceiled):
        raise ValueError('arange: cannot compute array length')

    # NaN has been excluded above, so the two one-sided tests are exhaustive.
    lower = float(_imin)
    upper = float(_imax)
    if ceiled < lower or ceiled > upper:
        raise OverflowError('arange: overflow while computing length')

    return int(ceiled)

def _datetime_arange(start, stop, step, dtype: type):
    """
    `arange` implementation for datetime64 / timedelta64 dtypes.

    `start`, `stop` and `step` may be datetime64, timedelta64, int or str
    (strings are parsed through `dtype`); `step` may also be None (meaning 1)
    but may not be a datetime or a string. A non-datetime `dtype` or an
    unsupported input type is a compile error. Raises `ValueError` if any of
    the three values is NaT.
    """
    def convert(x, dtype: type):
        # Reduce an input to its raw integer value.
        if isinstance(x, datetime64) or isinstance(x, timedelta64):
            return x.value
        elif isinstance(x, int):
            return x
        elif isinstance(x, str):
            return dtype(x, dtype.base).value
        else:
            compile_error("datetime_arange inputs must be datetime64, timedelta64 or int")

    def nat(x: int):
        # Whether the raw value `x` is flagged as NaT by timedelta64.
        return timedelta64(x, 'generic')._nat

    if not (isinstance(dtype, datetime64) or isinstance(dtype, timedelta64)):
        compile_error("datetime_arange was given a non-datetime dtype")

    a = convert(start, dtype)
    b = convert(stop, dtype)
    c = 0

    if step is None:
        c = 1
    elif isinstance(step, datetime64) or isinstance(step, str):
        compile_error("cannot use a datetime as a step in arange")
    else:
        c = convert(step, dtype)

    # A datetime (or string) start combined with a non-datetime stop: the
    # stop is treated as an offset relative to the start.
    if ((isinstance(start, datetime64) or isinstance(start, str)) and not
        (isinstance(stop, datetime64) or isinstance(stop, str))):
        b += a

    if nat(a) or nat(b) or nat(c):
        raise ValueError("arange: cannot use NaT (not-a-time) datetime values")

    length = len(range(a, b, c))
    ans = empty(length, dtype)

    for i in range(length):
        ans.data[i] = dtype(a, dtype.base)
        a += c

    return ans

def arange(start: float, stop: float, step: float, dtype: type = float):
    """
    Return evenly spaced values `start + j * step` for `j` in
    `range(length)`, where `length` is `ceil((stop - start) / step)`.

    Raises `ValueError` if `step` is zero or the length cannot be computed,
    and `OverflowError` if the length does not fit in an int (both from
    `_safe_ceil`). A non-positive length yields an empty array.
    """
    if step == 0.0:
        raise ValueError('step cannot be zero')

    delta = stop - start
    tmp_len = delta / step

    length = 0
    if tmp_len == 0.0 and delta != 0.0:
        # The quotient underflowed to zero although the interval is not
        # empty: its sign decides between zero and one element.
        if _signbit(tmp_len):
            length = 0
        else:
            length = 1
    else:
        length = _safe_ceil(tmp_len)

    if length <= 0:
        return empty(0, dtype=dtype)

    result = empty(length, dtype=dtype)
    p = result.data
    # Fill exactly `length` slots. Each value is computed from its index
    # rather than by repeated addition, so rounding error can neither
    # accumulate nor make the loop run past the end of the buffer.
    for j in range(length):
        p[j] = util.cast(start + j * step, dtype)

    return result

@overload
def arange(stop: float, step: float, dtype: type = float):
    """Float overload taking `stop` and `step`; the start is 0.0."""
    return arange(0.0, stop, step, dtype)

@overload
def arange(start: float, stop: float, dtype: type = float):
    """Float overload taking `start` and `stop`; the step is 1.0."""
    return arange(start, stop, 1.0, dtype)

@overload
def arange(stop: float, dtype: type = float):
    """Float overload taking only `stop`; the start is 0.0 and the step 1.0."""
    return arange(0.0, stop, 1.0, dtype)

@overload
def arange(start: int, stop: int, step: int, dtype: type = int):
    """
    Integer overload: the values of `range(start, stop, step)` cast to
    `dtype`. A datetime64/timedelta64 `dtype` is delegated to
    `_datetime_arange`.
    """
    if isinstance(dtype, datetime64) or isinstance(dtype, timedelta64):
        return _datetime_arange(start, stop, step, dtype)

    length = len(range(start, stop, step))
    result = empty(length, dtype=dtype)
    p = result.data
    j = 0

    for i in range(start, stop, step):
        p[j] = util.cast(i, dtype)
        j += 1

    return result

@overload
def arange(stop: int, step: int, dtype: type = int):
    """Integer overload taking `stop` and `step`; the start is 0."""
    return arange(0, stop, step, dtype)

@overload
def arange(start: int, stop: int, dtype: type = int):
    """Integer overload taking `start` and `stop`; the step is 1."""
    return arange(start, stop, 1, dtype)

@overload
def arange(stop: int, dtype: type = int):
    """Integer overload taking only `stop`; the start is 0 and the step 1."""
    return arange(0, stop, 1, dtype)

@overload
def arange(start: datetime64, stop, step = None, dtype: type = NoneType):
    """
    datetime64 overload; the result dtype defaults to the type of `start`.
    """
    if dtype is NoneType:
        return _datetime_arange(start, stop, step, type(start))
    else:
        return _datetime_arange(start, stop, step, dtype)

@overload
def arange(start: timedelta64, stop, step, dtype: type = NoneType):
    """
    timedelta64 overload with an explicit step; the result dtype defaults to
    the type of `start`.
    """
    if dtype is NoneType:
        return _datetime_arange(start, stop, step, type(start))
    else:
        return _datetime_arange(start, stop, step, dtype)

@overload
def arange(start: timedelta64, stop, dtype: type = NoneType):
    """
    timedelta64 overload with a step of 1; the result dtype defaults to the
    type of `start`.
    """
    if dtype is NoneType:
        return _datetime_arange(start, stop, 1, type(start))
    else:
        return _datetime_arange(start, stop, 1, dtype)

@overload
def arange(stop: timedelta64, dtype: type = NoneType):
    """
    timedelta64 overload taking only `stop`: starts at 0 with a step of 1.
    The result dtype defaults to the type of `stop`.
    """
    if dtype is NoneType:
        return _datetime_arange(0, stop, 1, type(stop))
    else:
        return _datetime_arange(0, stop, 1, dtype)

@overload
def arange(start: str, stop, step = None, dtype: type = datetime64['D', 1]):
    """
    String overload: `start` is handed to `_datetime_arange`, which parses
    strings through `dtype` (default `datetime64['D', 1]`).
    """
    return _datetime_arange(start, stop, step, dtype)

def linspace(start: float, stop: float, num: int = 50,
             endpoint: bool = True, retstep: Static[int] = False,
             dtype: type = float):
    """
    Return `num` evenly spaced samples over the interval from `start` to
    `stop`.

    With `endpoint` set, `stop` is the last sample (when `num > 1`);
    otherwise it is excluded. With `retstep` set, a `(samples, step)` pair is
    returned; `step` is NaN when there are no intervals to divide by.

    Raises `ValueError` if `num` is negative.

    NOTE(review): a second `linspace` with the same signature is defined
    later in this file; presumably that definition supersedes this one -
    confirm.
    """
    if num < 0:
        raise ValueError(f'Number of samples, {num}, must be non-negative.')

    delta = stop - start
    div = (num - 1) if endpoint else num
    # Only divide when there is at least one interval; otherwise the step is
    # reported as NaN.
    step = util.nan64()

    result = empty(num, dtype=dtype)
    p = result.data

    if div > 0:
        step = delta / div
        if step == 0:
            # The step underflowed: scale the index first to keep precision.
            for i in range(num):
                p[i] = util.cast(((i / div) * delta) + start, dtype)
        else:
            for i in range(num):
                p[i] = util.cast((i * step) + start, dtype)
    else:
        for i in range(num):
            p[i] = util.cast((i * delta) + start, dtype)

    if endpoint and num > 1:
        # Store `stop` itself, converted to `dtype` like every other sample.
        p[num - 1] = util.cast(stop, dtype)

    if retstep:
        return result, step
    else:
        return result

def _linlogspace(start: float, stop: float, num: int = 50, base: float = 10.0,
                 out_sign: int = 1, endpoint: bool = True, retstep: Static[int] = False,
                 dtype: type = float, log: Static[int] = False):
    """
    Shared implementation of linearly and logarithmically spaced samples.

    Produces `num` evenly spaced values from `start` to `stop`. When `log` is
    set each value `y` is replaced by `base ** y`; every value is then
    multiplied by `out_sign` and cast to `dtype`. With `endpoint` set and
    `num > 1`, the last sample is computed from `stop` itself. With `retstep`
    set, a `(samples, step)` pair is returned; `step` is NaN when there are
    no intervals to divide by.

    Raises `ValueError` if `num` is negative.
    """
    if num < 0:
        raise ValueError(f'Number of samples, {num}, must be non-negative.')

    delta = stop - start
    div = (num - 1) if endpoint else num
    # Only divide when there is at least one interval (div can be 0 for
    # num == 0, or num == 1 with endpoint); otherwise the step stays NaN.
    step = util.nan64()

    result = empty(num, dtype=dtype)
    p = result.data

    if div > 0:
        step = delta / div
        if step == 0:
            # The step underflowed: scale the index first to keep precision.
            for i in range(num):
                y = ((i / div) * delta) + start
                if log:
                    y = util.pow64(base, y)
                y *= out_sign
                p[i] = util.cast(y, dtype)
        else:
            for i in range(num):
                y = (i * step) + start
                if log:
                    y = util.pow64(base, y)
                y *= out_sign
                p[i] = util.cast(y, dtype)
    else:
        for i in range(num):
            y = (i * delta) + start
            if log:
                y = util.pow64(base, y)
            y *= out_sign
            p[i] = util.cast(y, dtype)

    if endpoint and num > 1:
        # Recompute the last sample from `stop` directly.
        y = stop
        if log:
            y = util.pow64(base, y)
        y *= out_sign
        p[num - 1] = util.cast(y, dtype)

    if retstep:
        return result, step
    else:
        return result

def linspace(start: float, stop: float, num: int = 50,
             endpoint: bool = True, retstep: Static[int] = False,
             dtype: type = float):
    """
    Return `num` evenly spaced samples from `start` to `stop` by delegating
    to `_linlogspace` with `log=False`.
    """
    return _linlogspace(start=start, stop=stop, num=num,
                        endpoint=endpoint, retstep=retstep,
                        dtype=dtype, log=False)

def logspace(start: float, stop: float, num: int = 50,
             endpoint: bool = True, base: float = 10.0,
             retstep: Static[int] = False,
             dtype: type = float):
    """
    Return `num` samples `base ** y` for `y` evenly spaced from `start` to
    `stop`, by delegating to `_linlogspace` with `log=True`.
    """
    return _linlogspace(start=start, stop=stop, num=num,
                        endpoint=endpoint, retstep=retstep,
                        dtype=dtype, base=base, log=True)

def geomspace(start: float, stop: float, num: int = 50,
              endpoint: bool = True, dtype: type = float):
    """
    Return `num` samples spaced evenly on a log scale between `start` and
    `stop`.

    The base-10 logarithms of the endpoints are passed to `_linlogspace` with
    `base=10.0` and `log=True`. When both endpoints are negative they are
    negated first and the sign is restored through `out_sign`.

    Raises `ValueError` if either endpoint is zero.
    """
    if start == 0 or stop == 0:
        raise ValueError('Geometric sequence cannot include zero')

    out_sign = 1
    if start < 0 and stop < 0:
        start, stop = -start, -stop
        out_sign = -out_sign

    # NOTE(review): `x - x` is zero for finite x, so these two lines leave
    # finite endpoints unchanged; presumably they exist to propagate a
    # non-finite endpoint to the other one - confirm.
    start = start + (stop - stop)
    stop = stop + (start - start)
    log_start = util.log10_64(start)
    log_stop = util.log10_64(stop)

    return _linlogspace(start=log_start, stop=log_stop, num=num,
                        endpoint=endpoint, retstep=False,
                        dtype=dtype, base=10.0, out_sign=out_sign,
                        log=True)

def fromfunction(function, shape, dtype: type = float, **kwargs):
    """
    Build an array of the given `shape` by calling `function` once per
    element with that element's coordinates (each cast to `dtype`) and the
    extra keyword arguments.

    The element type of the result is the type `function` returns when given
    zero-valued coordinates of type `dtype`.
    """
    result_dtype = type(
                    function(
                     *tuple(util.zero(dtype)
                            for _ in staticrange(staticlen(shape))),
                     **kwargs))
    result = empty(shape, dtype=result_dtype)
    for idx in util.multirange(shape):
        p = result._ptr(idx)
        args = tuple(util.cast(i, dtype) for i in idx)
        p[0] = function(*args, **kwargs)
    return result

def fromiter(iterable, dtype: type, count: int = -1):
    """
    Build a 1-d array of `dtype` from an iterable.

    With a negative `count` the whole iterable is consumed. Otherwise exactly
    `count` items are read, and `ValueError` is raised if the iterable runs
    out early.
    """
    if count < 0:
        return array([a for a in iterable], dtype=dtype)

    out = empty((count,), dtype=dtype)
    if count != 0:
        buf = out.data
        filled = 0
        for item in iterable:
            buf[filled] = util.cast(item, dtype)
            filled += 1
            if filled == count:
                break
        if filled != count:
            raise ValueError(f'iterator too short: Expected {count} but iterator had only {filled} items.')
    return out

def frombuffer(buffer: str, dtype: type = float, count: int = -1, offset: int = 0):
    """
    Interpret the bytes of `buffer` as a 1-d array of `dtype` without
    copying.

    `offset` is the byte position at which reading starts. A negative `count`
    means "as many whole elements as fit after `offset`".

    Raises `ValueError` if `offset` lies outside the buffer or if an explicit
    `count` needs more bytes than remain after `offset`.
    """
    nbytes = len(buffer)
    if offset < 0 or offset > nbytes:
        raise ValueError(f'offset must be non-negative and no greater than buffer length ({nbytes})')

    itemsize = util.sizeof(dtype)
    # Only the bytes after `offset` are available to the array.
    available = nbytes - offset

    if count < 0:
        count = available // itemsize
    elif count * itemsize > available:
        raise ValueError('buffer is smaller than requested size')

    p = Ptr[dtype](buffer.ptr + offset)
    return ndarray((count,), p)

################
# Broadcasting #
################

def broadcast_shapes(*args):
    """
    Compute the shape that results from broadcasting the given shapes
    against each other.

    Each argument is a tuple of ints or a single int (treated as a 1-tuple).
    With no arguments the result is `()`.

    Raises `ValueError` if any dimension is negative or if the shapes are
    not broadcast-compatible.
    """
    def _largest(args):
        # The argument with the most dimensions (the first such on ties).
        if staticlen(args) == 1:
            return args[0]

        a = args[0]
        b = _largest(args[1:])
        if staticlen(b) > staticlen(a):
            return b
        else:
            return a

    def _ensure_tuple(x):
        if isinstance(x, Tuple):
            return x
        else:
            return (x,)

    if staticlen(args) == 0:
        return ()

    args = tuple(_ensure_tuple(a) for a in args)
    for a in args:
        for i in a:
            if i < 0:
                raise ValueError('negative dimensions are not allowed')

    t = _largest(args)
    N: Static[int] = staticlen(t)
    # The answer tuple is filled in place through a raw pointer.
    ans = (0,) * N
    p = Ptr[int](__ptr__(ans).as_byte())

    for i in staticrange(N):
        p[i] = t[i]

    for a in args:
        for i in staticrange(staticlen(a)):
            # Shapes are aligned at their trailing dimension.
            x = a[len(a) - 1 - i]
            q = p + (len(t) - 1 - i)
            y = q[0]

            if y == 1:
                q[0] = x
            elif x != 1 and x != y:
                raise ValueError('shape mismatch: objects cannot be broadcast to a single shape')

    return ans

def broadcast_to(x, shape):
    """
    Return an array of the given `shape` built over `x`'s data pointer, with
    zero strides along broadcast dimensions.

    A non-tuple `shape` is treated as a 1-tuple. It is a compile error for
    `shape` to have fewer dimensions than `x`. Raises `ValueError` if a
    dimension of `x` neither equals the matching trailing dimension of
    `shape` nor is 1.
    """
    x = asarray(x)

    if not isinstance(shape, Tuple):
        return broadcast_to(x, (shape,))

    N: Static[int] = staticlen(x.shape)

    if staticlen(shape) < N:
        compile_error('input operand has more dimensions than allowed by the axis remapping')

    shape1, shape2 = shape[:-N], shape[-N:]
    # Strides for the trailing N dimensions, filled in place via a pointer.
    substrides = (0,) * N
    p = Ptr[int](__ptr__(substrides).as_byte())

    if N > 0:
        for i in range(N):
            a = x.shape[i]
            b = shape2[i]

            if a == b:
                p[i] = x.strides[i]
            elif a == 1:
                # Length-1 dimension stretched to `b`: reuse the same element.
                p[i] = 0
            else:
                raise ValueError(f'cannot broadcast array of shape {x.shape} to shape {shape}')

    # Newly added leading dimensions get stride 0.
    z = (0,) * (staticlen(shape) - N)
    new_strides = (*z, *substrides)
    return ndarray(shape, new_strides, x.data)

def broadcast_arrays(*args):
    """
    Broadcast all arguments against each other and return the results as a
    list. Non-array arguments are first converted with `array`.
    """
    def _ensure_array(x):
        if isinstance(x, ndarray):
            return x
        else:
            return array(x)

    args = tuple(_ensure_array(a) for a in args)
    shapes = tuple(a.shape for a in args)
    bshape = broadcast_shapes(*shapes)
    return [broadcast_to(a, bshape) for a in args]

def meshgrid(*xi, copy: bool = True, sparse: Static[int] = False, indexing: Static[str] = 'xy'):
    def make_shape(i, ndim: Static[int]):
        t = (1,) * ndim
        p = Ptr[int](__ptr__(t).as_byte())
        p[i] = -1
        return t

    def build_output(xi, i: int = 0, ndim: Static[int]):
        if staticlen(xi) == 0:
            return ()

        x = xi[0]
        y = array(x).reshape(make_shape(i, ndim))
        rest = build_output(xi[1:], i + 1, ndim)
        return (y, *rest)

    if indexing != 'xy' and indexing != 'ij':
        compile_error("Valid values for `indexing` are 'xy' and 'ij'.")

    ndim: Static[int] = staticlen(xi)
    s0 = (1,) * ndim
    output = build_output(xi, ndim=ndim)

    if indexing == 'xy' and ndim > 1:
        # switch first and second axis
        output0 = output[0].reshape(1, -1, *s0[2:])
        output1 = output[1].reshape(-1, 1, *s0[2:])
        output = (output0, output1, *output[2:])

    if not sparse:
        # Return the full N-D matrix (not only the 1-D vector)
        return [a for a in broadcast_arrays(*output)]

    if copy:
        return [a.copy() for a in output]

    return [a for a in output]

# Iterator object returned by `broadcast()`: holds a tuple of arrays together
# with their common broadcast shape and yields, position by position, a tuple
# with one element from each array.
class _broadcast[A, S]:
    # Tuple of the arrays being broadcast together.
    _arrays: A
    # The broadcast shape shared by all arrays.
    _shape: S
    # Flat position of the next element to be yielded.
    _index: int

    def __init__(self, arrays: A, shape: S):
        self._arrays = arrays
        self._shape = shape
        self._index = 0

    @property
    def shape(self):
        return self._shape

    @property
    def index(self):
        return self._index

    @property
    def size(self):
        # Total number of elements in the broadcast shape.
        return util.count(self.shape)

    @property
    def ndim(self):
        return staticlen(self.shape)

    @property
    def nd(self):
        # Alias of `ndim`.
        return self.ndim

    @property
    def numiter(self):
        # Number of arrays taking part in the broadcast.
        return staticlen(self._arrays)

    @property
    def iters(self):
        # One flat iterator per array, each over the array broadcast to the
        # common shape and positioned at the current index.
        def get_flat(arr, index, bshape):
            f = broadcast_to(arr, bshape).flat
            f.index = index
            return f

        return tuple(get_flat(a, self.index, self.shape) for a in self._arrays)

    def __iter__(self):
        arrays = self._arrays
        n = self.size

        while self.index < n:
            idx = util.index_to_coords(self.index, self.shape)
            # Advance before yielding so `index` already refers to the next
            # element while the consumer handles this one.
            self._index += 1
            yield tuple(a._ptr(idx, broadcast=True)[0] for a in arrays)

    def reset(self):
        # Rewind to the first element.
        self._index = 0

def broadcast(*args):
    """
    Convert the arguments to arrays and wrap them, together with their
    common broadcast shape, in a `_broadcast` object.
    """
    arrs = tuple(asarray(a) for a in args)
    shapes = tuple(a.shape for a in arrs)
    return _broadcast(arrs, broadcast_shapes(*shapes))

def full(shape, fill_value, dtype: type = NoneType, order: str = 'C'):
    """
    Return a new array of the given shape filled with `fill_value`.

    `shape` may be an int or a tuple. `fill_value` may be a scalar or an
    array-like that broadcasts to `shape`. `dtype` defaults to the dtype of
    `fill_value`; when given, every value is cast to it.

    Raises `ValueError` if a non-scalar `fill_value` cannot be broadcast to
    `shape`.
    """
    # Normalise an integer shape to a 1-tuple and use it everywhere below.
    if isinstance(shape, int):
        sh = (shape,)
    else:
        sh = shape

    fv = asarray(fill_value)
    if fv.ndim != 0:
        broadcast_to(fv, sh)  # error check

    if dtype is NoneType:
        result = empty(sh, fv.dtype, order)
        if fv.ndim == 0:
            e = fv.item()
            result.map(lambda x: e, inplace=True)
        else:
            for idx in util.multirange(sh):
                result._ptr(idx)[0] = fv._ptr(idx, broadcast=True)[0]
        return result
    else:
        result = empty(sh, dtype, order)
        if fv.ndim == 0:
            e = fv.item()
            result.map(lambda x: util.cast(e, result.dtype), inplace=True)
        else:
            for idx in util.multirange(sh):
                result._ptr(idx)[0] = util.cast(fv._ptr(idx, broadcast=True)[0], result.dtype)
        return result

def full_like(prototype, fill_value, dtype: type = NoneType, order: str = 'K'):
    """
    Return an array allocated with `empty_like(prototype, ...)` and filled
    with `fill_value`.

    `fill_value` may be a scalar or an array-like that broadcasts to the
    prototype's shape. `dtype` defaults to the dtype of `fill_value`; when
    given, every value is cast to it.
    """
    prototype = asarray(prototype)
    shape = prototype.shape

    fv = asarray(fill_value)
    if fv.ndim != 0:
        broadcast_to(fv, shape)  # error check

    if dtype is NoneType:
        result = empty_like(prototype, fv.dtype, order)
        if fv.ndim == 0:
            e = fv.item()
            result.map(lambda x: e, inplace=True)
        else:
            for idx in util.multirange(shape):
                result._ptr(idx)[0] = fv._ptr(idx, broadcast=True)[0]
        return result
    else:
        result = empty_like(prototype, dtype, order)
        if fv.ndim == 0:
            e = fv.item()
            result.map(lambda x: util.cast(e, result.dtype), inplace=True)
        else:
            for idx in util.multirange(shape):
                result._ptr(idx)[0] = util.cast(fv._ptr(idx, broadcast=True)[0], result.dtype)
        return result

################
# Manipulation #
################

def copyto(dst: ndarray, src, where = True):
    """
    Copy values from `src` into `dst`, optionally only where `where` is true.

    `src` and `where` are broadcast to `dst`'s shape and each copied element
    is cast to `dst`'s dtype. A boolean `where` of False copies nothing.
    """
    src = asarray(src)
    dst_dtype = dst.dtype
    src_dtype = src.dtype

    if isinstance(where, bool):
        if not where:
            return

        # Fast path: identical dtypes and `_contig_match` succeeds, so the
        # data is copied with a single memcpy.
        if dst_dtype is src_dtype and src._contig_match(dst):
            str.memcpy(dst.data.as_byte(), src.data.as_byte(), dst.nbytes)
            return

    src = broadcast_to(src, dst.shape)
    where = broadcast_to(asarray(where), dst.shape)

    for idx in util.multirange(dst.shape):
        w = where._ptr(idx)
        if w[0]:
            p = src._ptr(idx)
            q = dst._ptr(idx)
            q[0] = util.cast(p[0], dst_dtype)

def ndim(a):
    """Number of dimensions of `a` once converted to an array."""
    arr = asarray(a)
    return arr.ndim

def size(a, axis: Optional[int] = None):
    """
    Total number of elements of `a`, or its length along `axis` when one is
    given.
    """
    arr = asarray(a)
    if axis is not None:
        return arr.shape[axis]
    return arr.size

def shape(a):
    """
    Shape of `a`. For non-array input the shape is derived from the nesting
    structure and validated, so jagged input raises `ValueError`.
    """
    if isinstance(a, ndarray):
        return a.shape
    else:
        dims = _extract_shape(a)
        _validate_shape(a, dims)
        return dims

def reshape(a, newshape, order: str = 'C'):
    """Convert `a` to an array and reshape it to `newshape`."""
    arr = asarray(a)
    return arr.reshape(newshape, order=order)

def transpose(a, axes=None):
    """
    Convert `a` to an array and permute its axes according to `axes`
    (forwarded unchanged to `ndarray.transpose`).

    Like the sibling wrappers `reshape`, `ravel` and `swapaxes`, the input is
    passed through `asarray` first, so lists and tuples are accepted as well
    as arrays.
    """
    return asarray(a).transpose(axes)

def ravel(a, order: str = 'C'):
    """Convert `a` to an array and flatten it with `ndarray.ravel`."""
    arr = asarray(a)
    return arr.ravel(order=order)

def ascontiguousarray(a, dtype: type = NoneType):
    """Convert `a` to an array, requesting 'C' order."""
    return asarray(a, dtype, 'C')

def asfortranarray(a, dtype: type = NoneType):
    """Convert `a` to an array, requesting 'F' order."""
    return asarray(a, dtype, 'F')

def asfarray(a, dtype: type = float):
    """
    Convert `a` to an array of a floating-point or complex dtype. Any
    `dtype` other than float, float32, complex or complex64 is replaced by
    float.
    """
    if (dtype is float or
        dtype is float32 or
        dtype is complex or
        dtype is complex64):
        return asarray(a, dtype=dtype)
    return asfarray(a, float)

def moveaxis(a, source, destination):
    """
    Move the axes listed in `source` to the positions listed in
    `destination`; the remaining axes keep their relative order.

    Raises `ValueError` if `source` and `destination` differ in length.
    """
    a = asarray(a)
    source = util.normalize_axis_tuple(source, a.ndim, 'source')
    destination = util.normalize_axis_tuple(destination, a.ndim, 'destination')
    if len(source) != len(destination):
        raise ValueError('`source` and `destination` arguments must have '
                         'the same number of elements')

    # Start from the axes that stay put, then insert each moved axis at its
    # destination, lowest destination first.
    order = [n for n in range(a.ndim) if n not in source]
    for dest, src in sorted(zip(destination, source)):
        order.insert(dest, src)

    # Reinterpret the list's storage as a tuple of the same type as `a.shape`.
    order = Ptr[type(a.shape)](order.arr.ptr.as_byte())[0]
    return a.transpose(order)

def swapaxes(a, axis1: int, axis2: int):
    """Convert `a` to an array and interchange `axis1` and `axis2`."""
    arr = asarray(a)
    return arr.swapaxes(axis1, axis2)

def atleast_1d(*arys):
    """
    Convert each input to an array with at least one dimension: 0-d inputs
    are reshaped to shape (1,), others are returned as converted.

    A single input yields a single array; several inputs yield a tuple.
    """
    def atl1d(a):
        a = asarray(a)
        if staticlen(a.shape) == 0:
            return a.reshape(1)
        else:
            return a

    if staticlen(arys) == 1:
        return atl1d(arys[0])
    else:
        return tuple(atl1d(a) for a in arys)

def atleast_2d(*arys):
    """
    Convert each input to an array with at least two dimensions: 0-d inputs
    become shape (1, 1) and 1-d inputs gain a new leading axis.

    A single input yields a single array; several inputs yield a tuple.
    """
    def atl2d(a):
        a = asarray(a)
        if staticlen(a.shape) == 0:
            return a.reshape(1, 1)
        elif staticlen(a.shape) == 1:
            return a[None, :]
        else:
            return a

    if staticlen(arys) == 1:
        return atl2d(arys[0])
    else:
        return tuple(atl2d(a) for a in arys)

def atleast_3d(*arys):
    """
    Convert each input to an array with at least three dimensions: 0-d
    inputs become shape (1, 1, 1), 1-d inputs gain a new axis on each side,
    and 2-d inputs gain a new trailing axis.

    A single input yields a single array; several inputs yield a tuple.
    """
    def atl3d(a):
        a = asarray(a)
        if staticlen(a.shape) == 0:
            return a.reshape(1, 1, 1)
        elif staticlen(a.shape) == 1:
            return a[None, :, None]
        elif staticlen(a.shape) == 2:
            return a[:, :, None]
        else:
            return a

    if staticlen(arys) == 1:
        return atl3d(arys[0])
    else:
        return tuple(atl3d(a) for a in arys)

def require(a, dtype: type = NoneType, requirements = None):
    """
    Return an array built from ``a`` that honours the given requirements.

    ``requirements`` is an iterable of flag names ('C', 'F', 'A', 'W', 'O',
    'E' or their long spellings). Only the memory-order flags and 'O'
    (OWNDATA, which forces a copy) influence the result; the remaining flags
    are parsed and validated but otherwise ignored.

    Raises ``ValueError`` for an unknown flag name or when both C and F order
    are requested.
    """
    REQ_C: Static[int] = 1
    REQ_F: Static[int] = 2
    REQ_A: Static[int] = 4
    REQ_W: Static[int] = 8
    REQ_O: Static[int] = 16
    REQ_E: Static[int] = 32

    # No requirements given (None or empty): plain conversion.
    if requirements is None:
        return asarray(a, dtype=dtype)

    if not requirements:
        return asarray(a, dtype=dtype)

    # Fold the flag names into a bit mask.
    flags = 0
    for name in requirements:
        if name in ('C', 'C_CONTIGUOUS', 'CONTIGUOUS'):
            flags |= REQ_C
        elif name in ('F', 'F_CONTIGUOUS', 'FORTRAN'):
            flags |= REQ_F
        elif name in ('A', 'ALIGNED'):
            flags |= REQ_A
        elif name in ('W', 'WRITEABLE'):
            flags |= REQ_W
        elif name in ('O', 'OWNDATA'):
            flags |= REQ_O
        elif name in ('E', 'ENSUREARRAY'):
            flags |= REQ_E
        else:
            raise ValueError("invalid requirement: " + repr(name))

    want_c = (flags & REQ_C) != 0
    want_f = (flags & REQ_F) != 0
    if want_c and want_f:
        raise ValueError('Cannot specify both "C" and "F" order')

    layout = 'A'
    if want_f:
        layout = 'F'
    elif want_c:
        layout = 'C'

    # Codon-NumPy ignores other flags/properties currently

    must_copy = (flags & REQ_O) != 0
    return array(a, dtype=dtype, order=layout, copy=must_copy)

def _copy_data_c(dst: cobj, src: cobj, shape: Ptr[int],
                strides: Ptr[int], dst_strides: Ptr[int],
                ndim: Static[int], element_size: int,
                block: int, block_dim: int):
    # Recursively copy strided data from `src` to `dst`, peeling off the
    # leading dimension at each level (shape/strides pointers advance by one).
    #
    # dst, src:     byte pointers to the first element at the current level
    # shape:        extents of the remaining dimensions
    # strides:      source byte strides of the remaining dimensions
    # dst_strides:  destination byte strides of the remaining dimensions
    # ndim:         number of remaining dimensions (static)
    # element_size: not used in this body, only forwarded to the recursion
    # block:        number of bytes copied by a single memcpy
    # block_dim:    number of remaining dimensions at which one memcpy of
    #               `block` bytes covers everything that is left

    # Base case for the recursion on the static `ndim`.
    # NOTE(review): presumably this is what stops compile-time instantiation
    # of ever smaller `ndim` values — confirm.
    if ndim < 0:
        return

    if ndim == block_dim:
        str.memcpy(dst, src, block)
    else:
        src_stride = strides[0]
        dst_stride = dst_strides[0]
        for j in range(shape[0]):
            _copy_data_c(dst + j * dst_stride, src + j * src_stride,
                         shape + 1, strides + 1, dst_strides + 1,
                         ndim - 1, element_size, block, block_dim)

def _copy_data_f(dst: cobj, src: cobj, shape: Ptr[int],
                strides: Ptr[int], dst_strides: Ptr[int],
                ndim: Static[int], element_size: int,
                block: int, block_dim: int):
    # Recursively copy strided data from `src` to `dst`, peeling off the
    # trailing dimension at each level (the shape/strides pointers stay put
    # and index `ndim - 1` is used instead).
    #
    # dst, src:     byte pointers to the first element at the current level
    # shape:        extents of all dimensions
    # strides:      source byte strides of all dimensions
    # dst_strides:  destination byte strides of all dimensions
    # ndim:         number of leading dimensions still to be handled (static)
    # element_size: not used in this body, only forwarded to the recursion
    # block:        number of bytes copied by a single memcpy
    # block_dim:    number of remaining dimensions at which one memcpy of
    #               `block` bytes covers everything that is left

    # Base case for the recursion on the static `ndim`.
    # NOTE(review): presumably this is what stops compile-time instantiation
    # of ever smaller `ndim` values — confirm.
    if ndim < 0:
        return

    if ndim == block_dim:
        str.memcpy(dst, src, block)
    else:
        src_stride = strides[ndim - 1]
        dst_stride = dst_strides[ndim - 1]
        for j in range(shape[ndim - 1]):
            _copy_data_f(dst + j * dst_stride, src + j * src_stride,
                         shape, strides, dst_strides,
                         ndim - 1, element_size, block, block_dim)

def concatenate(arrays, axis = 0, out = None, dtype: type = NoneType):
    """
    Join a tuple or list of arrays along an existing axis.

    arrays: tuple or list of array-likes. Tuple elements are each passed
        through ``asarray``; list elements are converted only when they are
        not already ``ndarray``s.
    axis: axis along which to join, or ``None`` to flatten every input and
        join the results into a one-dimensional array.
    out: optional destination array; mutually exclusive with ``dtype``.
    dtype: optional element type of the result. When neither ``out`` nor
        ``dtype`` is given, tuple inputs use the type obtained by coercing
        all input dtypes together, while list inputs use the dtype of the
        first element.

    Raises ``ValueError`` for an empty list, for input shapes that differ off
    the concatenation axis, and for an ``out`` of the wrong shape. An empty
    tuple, zero-dimensional inputs with a non-None ``axis``, a non-array
    ``out``, and supplying both ``out`` and ``dtype`` are compile-time errors.
    """
    def check_array_dims_static(arrays, ndim: Static[int]):
        # Compile-time check that every array in the tuple has `ndim` dims.
        if staticlen(arrays) > 0:
            if staticlen(arrays[0].shape) != ndim:
                compile_error("all the input arrays must have same number of dimensions")
            check_array_dims_static(arrays[1:], ndim)

    def find_out_type(arrays, dtype: type):
        # Fold util.coerce over the dtypes of the tuple, starting from
        # `dtype`; the caller takes the type of the returned value.
        if staticlen(arrays) == 0:
            return dtype()
        else:
            x = util.coerce(arrays[0].dtype, dtype)
            return find_out_type(arrays[1:], type(x))

    def concat_inner(arrays, axis: int, out, dtype: type):
        # Join along an integer axis into `out`, or into a new array of type
        # `dtype` when `out` is None.

        # Callers may pass only `out`; derive the element type from it.
        if out is not None and dtype is NoneType:
            return concat_inner(arrays, axis=axis, out=out, dtype=out.dtype)

        if out is None and dtype is NoneType:
            compile_error("[internal error] bad out and dtype given to concat_inner")

        arr0 = asarray(arrays[0])
        ndim: Static[int] = arrays[0].ndim
        axis = util.normalize_axis_index(axis, ndim)
        shape = arr0.shape
        pshape = Ptr[int](__ptr__(shape).as_byte())

        # Accumulate the result shape in `shape`: sum along `axis`, require
        # equality everywhere else.
        for i in range(1, len(arrays)):
            arr = arrays[i]
            arr_shape = arr.shape

            if arr.ndim != ndim:
                compile_error("all the input arrays must have same number of dimensions")

            for idim in staticrange(ndim):
                if idim == axis:
                    pshape[idim] += arr_shape[idim]
                elif pshape[idim] != arr_shape[idim]:
                    raise ValueError("all the input array dimensions except for "
                                     "the concatenation axis must match exactly")

        corder = True
        if out is not None:
            # `dtype` is always a concrete type at this point (it is derived
            # from `out` above), so the out/dtype exclusivity rule cannot be
            # checked here; the public entry point enforces it.
            _check_out(out, shape)
            ret = out
        else:
            # Pick the memory order shared by the majority of the inputs.
            num_c = 0
            num_f = 0
            for array in arrays:
                cc, fc = array._contig
                if cc:
                    num_c += 1
                if fc:
                    num_f += 1
            corder = (num_c >= num_f)
            ret = empty(shape, dtype, order=('C' if corder else 'F'))

        element_size = util.sizeof(dtype)
        offset = 0  # position along `axis` where the next input starts
        dst_strides = ret.strides
        dst_stride_axis = dst_strides[axis]

        for array in arrays:
            array_shape = array.shape

            if array.dtype is dtype:
                # Same element type: copy raw bytes in the largest blocks
                # that are contiguous in both source and destination.
                strides = array.strides
                dst = ret.data.as_byte() + offset * dst_stride_axis
                src = array.data.as_byte()
                block = element_size
                block_dim = ndim

                if corder:
                    # Grow the block from the last dimension backwards.
                    i = ndim - 1
                    while i >= 0:
                        if strides[i] == block and dst_strides[i] == block:
                            block *= array_shape[i]
                        else:
                            block_dim = ndim - 1 - i
                            break
                        i -= 1
                    _copy_data_c(dst, src, Ptr[int](__ptr__(array_shape).as_byte()),
                                 Ptr[int](__ptr__(strides).as_byte()),
                                 Ptr[int](__ptr__(dst_strides).as_byte()),
                                 array.ndim, element_size, block, block_dim)
                else:
                    # Grow the block from the first dimension forwards.
                    i = 0
                    while i < ndim:
                        if strides[i] == block and dst_strides[i] == block:
                            block *= array_shape[i]
                        else:
                            block_dim = i
                            break
                        i += 1
                    _copy_data_f(dst, src, Ptr[int](__ptr__(array_shape).as_byte()),
                                 Ptr[int](__ptr__(strides).as_byte()),
                                 Ptr[int](__ptr__(dst_strides).as_byte()),
                                 array.ndim, element_size, block, block_dim)
            else:
                # Different element type: convert element by element.
                for src_idx in util.multirange(array_shape):
                    dst_idx = util.tuple_add(src_idx, axis, offset)
                    ret._ptr(dst_idx)[0] = util.cast(array._ptr(src_idx)[0], dtype)

            offset += array_shape[axis]

        return ret

    def concat_flatten(arrays, out):
        # Write every input, flattened, into the one-dimensional `out`.
        dtype = out.dtype
        out_ccontig = out._contig[0]
        i = 0

        for a in arrays:
            if a.dtype is dtype and out_ccontig and a._contig[0]:
                q = out._ptr((i,))
                n = a.size
                str.memcpy(q.as_byte(), a.data.as_byte(), n * util.sizeof(dtype))
                i += n
            else:
                for idx in util.multirange(a.shape):
                    p = a._ptr(idx)
                    q = out._ptr((i,))
                    q[0] = util.cast(p[0], dtype)
                    i += 1

        return out

    def concat_tuple(arrays, axis = 0, out = None, dtype: type = NoneType):
        if staticlen(arrays) == 0:
            compile_error("need at least one array to concatenate")

        arrays = tuple(asarray(arr) for arr in arrays)

        if axis is None:
            tot = 0
            for a in arrays:
                tot += a.size
            shape = (tot,)

            if out is None:
                if dtype is NoneType:
                    x = find_out_type(arrays[1:], arrays[0].dtype)
                    return concat_flatten(arrays, empty(shape, dtype=type(x)))
                else:
                    return concat_flatten(arrays, empty(shape, dtype=dtype))
            else:
                if staticlen(out.shape) != 1:
                    compile_error("Output array has wrong dimensionality")

                if not util.tuple_equal(out.shape, shape):
                    raise ValueError("Output array is the wrong shape")

                return concat_flatten(arrays, out)
        else:
            ndim: Static[int] = staticlen(arrays[0].shape)
            if ndim == 0:
                compile_error("zero-dimensional arrays cannot be concatenated")

            check_array_dims_static(arrays[1:], ndim)

            if out is None:
                if dtype is NoneType:
                    x = find_out_type(arrays[1:], arrays[0].dtype)
                    return concat_inner(arrays, axis, out=None, dtype=type(x))
                else:
                    return concat_inner(arrays, axis, out=None, dtype=dtype)
            else:
                return concat_inner(arrays, axis, out, dtype=out.dtype)

    def concat_list(arrays, axis = 0, out = None, dtype: type = NoneType):
        if len(arrays) == 0:
            raise ValueError("need at least one array to concatenate")

        if not isinstance(arrays[0], ndarray):
            arrays = [asarray(arr) for arr in arrays]

        if axis is None:
            tot = 0
            for a in arrays:
                tot += a.size
            shape = (tot,)

            if out is None:
                if dtype is NoneType:
                    return concat_flatten(arrays, empty(shape, dtype=arrays[0].dtype))
                else:
                    return concat_flatten(arrays, empty(shape, dtype=dtype))
            else:
                if staticlen(out.shape) != 1:
                    compile_error("Output array has wrong dimensionality")

                if not util.tuple_equal(out.shape, shape):
                    raise ValueError("Output array is the wrong shape")

                return concat_flatten(arrays, out)
        else:
            ndim: Static[int] = staticlen(arrays[0].shape)
            if ndim == 0:
                compile_error("zero-dimensional arrays cannot be concatenated")

            for arr in arrays:
                if arr.ndim != arrays[0].ndim:
                    raise ValueError("all the input arrays must have same number of dimensions")

            # Validate the shapes up front so that mismatches are reported
            # before any output is allocated.
            shape = arrays[0].shape
            pshape = Ptr[int](__ptr__(shape).as_byte())
            axis = util.normalize_axis_index(axis, ndim)

            for iarrays in range(1, len(arrays)):
                arr_shape = arrays[iarrays].shape

                for idim in range(ndim):
                    if idim == axis:
                        pshape[idim] += arr_shape[idim]
                    elif pshape[idim] != arr_shape[idim]:
                        raise ValueError("all the input array dimensions except for the "
                                         "concatenation axis must match exactly")

            if out is None:
                if dtype is NoneType:
                    return concat_inner(arrays, axis, out=None, dtype=arrays[0].dtype)
                else:
                    return concat_inner(arrays, axis, out=None, dtype=dtype)
            else:
                return concat_inner(arrays, axis, out=out, dtype=out.dtype)

    if out is not None:
        if dtype is not NoneType:
            compile_error("concatenate() only takes `out` or `dtype` as an argument, but both were provided.")

        if not isinstance(out, ndarray):
            compile_error("'out' must be an array")

    if isinstance(arrays, Tuple):
        return concat_tuple(arrays, axis=axis, out=out, dtype=dtype)
    else:
        return concat_list(arrays, axis=axis, out=out, dtype=dtype)

def expand_dims(a, axis):
    """
    Return ``a`` reshaped so that a new axis of length one appears at every
    position listed in ``axis`` (positions refer to the expanded shape).
    """
    a = asarray(a)
    ndim_in: Static[int] = staticlen(a.shape)
    # The first normalisation is used only for its static length, i.e. the
    # number of axes being added.
    ndim_out: Static[int] = (ndim_in +
        staticlen(util.normalize_axis_tuple(axis, 9999999)))
    new_axes = util.normalize_axis_tuple(axis, ndim_out)

    src_shape = a.shape
    src_dims = Ptr[int](__ptr__(src_shape).as_byte())
    dst_shape = (1,) * ndim_out
    dst_dims = Ptr[int](__ptr__(dst_shape).as_byte())

    # Every slot that is not a new axis receives the next original extent;
    # the new-axis slots keep their initial value of 1.
    taken = 0
    for pos in range(ndim_out):
        if pos in new_axes:
            continue
        dst_dims[pos] = src_dims[taken]
        taken += 1

    return a.reshape(dst_shape)

def _apply(fn, seq):
    if isinstance(seq, Tuple):
        return tuple(fn(a) for a in seq)
    elif isinstance(seq, List):
        return [fn(a) for a in seq]
    else:
        compile_error("expected a tuple or a list as input")

def stack(arrays, axis: int = 0, out = None, dtype: type = NoneType):
    """
    Join equally-shaped arrays along a new axis inserted at position ``axis``.

    Inputs that are neither a tuple nor a list of ``ndarray``s are simply
    converted with ``asarray``. ``out`` and ``dtype`` are forwarded to
    ``concatenate``.

    Raises ``ValueError`` when ``arrays`` is empty or when the inputs do not
    all have the same shape.
    """
    if not (isinstance(arrays, Tuple) or
            (isinstance(arrays, List) and
             isinstance(arrays[0], ndarray))):
        return asarray(arrays)

    if len(arrays) == 0:
        raise ValueError("need at least one array to stack")

    converted = _apply(asarray, arrays)
    ref_shape = converted[0].shape
    for item in converted:
        if not util.tuple_equal(ref_shape, item.shape):
            raise ValueError("all input arrays must have the same shape")

    # The result has one more dimension than the inputs.
    out_axis = util.normalize_axis_index(axis, converted[0].ndim + 1)
    expanded = _apply(lambda item: expand_dims(item, axis=out_axis), converted)
    return concatenate(expanded, axis=out_axis, out=out, dtype=dtype)

def vstack(tup, dtype: type = NoneType):
    """
    Stack arrays along the first axis after promoting each one to at least
    two dimensions. Inputs that are neither a tuple nor a list of
    ``ndarray``s are simply converted with ``asarray``.
    """
    if not (isinstance(tup, Tuple) or
            (isinstance(tup, List) and
             isinstance(tup[0], ndarray))):
        return asarray(tup)

    rows = _apply(atleast_2d, tup)
    return concatenate(rows, axis=0, dtype=dtype)

def hstack(tup, dtype: type = NoneType):
    """
    Stack arrays horizontally after promoting each one to at least one
    dimension: along axis 0 when the first array is one-dimensional, along
    axis 1 otherwise. Inputs that are neither a tuple nor a list of
    ``ndarray``s are simply converted with ``asarray``.
    """
    if not (isinstance(tup, Tuple) or
            (isinstance(tup, List) and
             isinstance(tup[0], ndarray))):
        return asarray(tup)

    parts = _apply(atleast_1d, tup)
    join_axis = 0 if (len(parts) > 0 and parts[0].ndim == 1) else 1
    return concatenate(parts, axis=join_axis, dtype=dtype)

def dstack(tup):
    """Join arrays along the third axis after promoting each to at least three dimensions."""
    slabs = _apply(atleast_3d, tup)
    return concatenate(slabs, axis=2)

# `row_stack` is an alias of `vstack`.
row_stack = vstack

def column_stack(tup):
    """
    Join the inputs as columns of a two-dimensional array.

    Inputs with fewer than two dimensions are first turned into columns
    (promoted to two dimensions, then transposed); inputs with two or more
    dimensions are used as they are. The pieces are joined along axis 1.
    """
    def as_columns(v):
        arr = asarray(v)
        if staticlen(arr.shape) >= 2:
            return arr
        else:
            return array(arr, copy=False, ndmin=2).T

    columns = _apply(as_columns, tup)
    return concatenate(columns, axis=1)

def repeat(a, repeats, axis = None):
    """
    Repeat the elements of ``a``.

    a: array-like; converted with ``asarray(a, order='C')``, and the copy
        loops below read its data assuming that C layout.
    repeats: either a single int applied to every element, or a sequence
        giving one count per element (``axis=None``) or per index along
        ``axis``.
    axis: ``None`` to repeat the flattened elements into a one-dimensional
        result, or an int to repeat whole sub-blocks along that axis.

    Raises ``ValueError`` if any repeat count is negative or if the length of
    a ``repeats`` sequence does not match the array/axis size. Any other kind
    of ``axis`` is a compile-time error.
    """
    def neg_rep_error():
        raise ValueError("negative dimensions are not allowed")

    a = asarray(a, order='C')
    dtype = a.dtype
    shape = a.shape
    pshape = Ptr[int](__ptr__(shape).as_byte())

    if isinstance(repeats, int):
        if repeats < 0:
            neg_rep_error()
    else:
        for rep in repeats:
            if rep < 0:
                neg_rep_error()

    if isinstance(axis, int):
        axis = util.normalize_axis_index(axis, a.ndim)
        n = pshape[axis]

        # One "chunk" is everything below a single index along `axis`.
        nel = 1
        for i in range(axis + 1, a.ndim):
            nel *= pshape[i]
        chunk = nel * a.itemsize

        # Number of index combinations above `axis`.
        n_outer = 1
        for i in range(axis):
            n_outer *= pshape[i]

        rep_shape = shape
        prep_shape = Ptr[int](__ptr__(rep_shape).as_byte())

        if isinstance(repeats, int):
            prep_shape[axis] *= repeats
            repeated = empty(rep_shape, dtype=dtype)
            old_data = a.data.as_byte()
            new_data = repeated.data.as_byte()

            # Walk both buffers linearly, emitting every source chunk
            # `repeats` times. This fills the whole output, so no further
            # element-wise pass is needed.
            for i in range(n_outer):
                for j in range(n):
                    for k in range(repeats):
                        str.memcpy(new_data, old_data, chunk)
                        new_data += chunk
                    old_data += chunk

            return repeated
        else:
            axis_dim = prep_shape[axis]

            if len(repeats) != axis_dim:
                raise ValueError("length of 'repeats' does not match axis size")

            # The output extent along `axis` is the sum of the counts.
            prep_shape[axis] = 0
            for rep in repeats:
                prep_shape[axis] += rep

            repeated = empty(rep_shape, dtype=dtype)
            old_data = a.data.as_byte()
            new_data = repeated.data.as_byte()

            for i in range(n_outer):
                for j in range(n):
                    for k in range(repeats[j]):
                        str.memcpy(new_data, old_data, chunk)
                        new_data += chunk
                    old_data += chunk

            return repeated
    elif axis is None:
        if isinstance(repeats, int):
            a_size = a.size
            rep_size = a_size * repeats
            repeated = empty((rep_size,), dtype=dtype)
            p = a.data
            q = repeated.data
            off = 0

            for i in range(a_size):
                elem = p[i]
                for _ in range(repeats):
                    q[off] = elem
                    off += 1

            return repeated
        else:
            a_size = a.size

            if len(repeats) != a_size:
                raise ValueError("length of 'repeats' does not match array size")

            rep_tot = 0
            for rep in repeats:
                rep_tot += rep

            repeated = empty((rep_tot,), dtype=dtype)
            p = a.data
            q = repeated.data
            off = 0

            for i in range(a_size):
                elem = p[i]
                for _ in range(repeats[i]):
                    q[off] = elem
                    off += 1

            return repeated
    else:
        compile_error("'axis' must be None or an int")

def delete(arr, obj, axis = None):
    """
    Return a new array with the sub-arrays selected by ``obj`` removed along
    ``axis``.

    arr: array-like, converted with ``asarray``.
    obj: an int, a slice, a sequence of ints (duplicates are removed once),
        or a sequence of bools with one entry per index along ``axis``.
    axis: the axis to delete along, or ``None`` to operate on the flattened
        array.

    The result is allocated in F order when ``arr`` is F-contiguous but not
    C-contiguous, and in C order otherwise. When nothing is selected for
    deletion by a slice or an int sequence, a copy of ``arr`` is returned.

    Raises ``ValueError`` when a boolean ``obj`` does not match the axis
    length. Other index element types and other kinds of ``axis`` are
    compile-time errors.
    """
    arr = asarray(arr)
    dtype = arr.dtype
    shape = arr.shape
    newshape = shape
    pnewshape = Ptr[int](__ptr__(newshape).as_byte())

    cc, fc = arr._contig
    arrorder = 'F' if (fc and not cc) else 'C'

    if isinstance(axis, int):
        axis = util.normalize_axis_index(axis, arr.ndim)
        shape_no_axis = util.tuple_delete(shape, axis)
        N = arr.shape[axis]

        if isinstance(obj, int):
            pnewshape[axis] -= 1
            new = empty(newshape, dtype, arrorder)
            obj = util.normalize_index(obj, axis, N)

            # Entries before the deleted index keep their position ...
            for i in range(obj):
                for idx0 in util.multirange(shape_no_axis):
                    idx = util.tuple_insert(idx0, axis, i)
                    p = arr._ptr(idx)
                    q = new._ptr(idx)
                    q[0] = p[0]

            # ... entries after it move down by one.
            for i in range(obj + 1, N):
                for idx0 in util.multirange(shape_no_axis):
                    idx1 = util.tuple_insert(idx0, axis, i)
                    idx2 = util.tuple_insert(idx0, axis, i - 1)
                    p = arr._ptr(idx1)
                    q = new._ptr(idx2)
                    q[0] = p[0]

            return new
        elif isinstance(obj, slice):
            # adjust_indices yields (start, stop, step, length), as it is
            # unpacked in insert(); the length is recomputed from the range.
            start, stop, step, length = obj.adjust_indices(N)
            xr = range(start, stop, step)
            numtodel = len(xr)

            if numtodel <= 0:
                return arr.copy(order=arrorder)

            # Normalise a negative step to the equivalent ascending bounds.
            if step < 0:
                step = -step
                start = xr[-1]
                stop = xr[0] + 1

            pnewshape[axis] -= numtodel
            new = empty(newshape, dtype, arrorder)

            # Leading part, before the slice.
            if start:
                for i in range(start):
                    for idx0 in util.multirange(shape_no_axis):
                        idx = util.tuple_insert(idx0, axis, i)
                        p = arr._ptr(idx)
                        q = new._ptr(idx)
                        q[0] = p[0]

            # Trailing part, after the slice, shifted down by the number of
            # deleted entries.
            if stop != N:
                for i in range(stop, N):
                    for idx0 in util.multirange(shape_no_axis):
                        idx1 = util.tuple_insert(idx0, axis, i)
                        idx2 = util.tuple_insert(idx0, axis, i - numtodel)
                        p = arr._ptr(idx1)
                        q = new._ptr(idx2)
                        q[0] = p[0]

            # With a step other than 1, the entries inside [start, stop) that
            # the slice skips are kept and packed right after the leading part.
            if step != 1:
                off = start
                for i in range(start, stop):
                    if i in xr:
                        continue

                    for idx0 in util.multirange(shape_no_axis):
                        idx1 = util.tuple_insert(idx0, axis, i)
                        idx2 = util.tuple_insert(idx0, axis, off)
                        p = arr._ptr(idx1)
                        q = new._ptr(idx2)
                        q[0] = p[0]

                    off += 1

            return new
        else:
            if isinstance(obj[0], int):
                remove = [util.normalize_index(r, axis, N) for r in obj]
                remove.sort()

                # remove duplicates
                n = len(remove)
                numtodel = 0

                if n < 2:
                    numtodel = n
                else:
                    j = 0
                    for i in range(n - 1):
                        if remove[i] != remove[i+1]:
                            remove[j] = remove[i]
                            j += 1

                    remove[j] = remove[n-1]
                    j += 1
                    numtodel = j

                if numtodel == 0:
                    return arr.copy(order=arrorder)

                pnewshape[axis] -= numtodel
                new = empty(newshape, dtype, arrorder)

                # `curr` counts the entries deleted so far, which is also the
                # amount by which kept entries are shifted down.
                curr = 0
                skip = remove[curr]

                for i in range(N):
                    if i == skip:
                        curr += 1
                        if curr < numtodel:
                            skip = remove[curr]
                        continue

                    for idx0 in util.multirange(shape_no_axis):
                        idx1 = util.tuple_insert(idx0, axis, i)
                        idx2 = util.tuple_insert(idx0, axis, i - curr)
                        p = arr._ptr(idx1)
                        q = new._ptr(idx2)
                        q[0] = p[0]

                return new
            elif isinstance(obj[0], bool):
                if len(obj) != N:
                    raise ValueError(
                        f"boolean array argument obj to delete must be one dimensional and match the axis length of {N}")

                numtodel = 0
                for r in obj:
                    if r:
                        numtodel += 1

                pnewshape[axis] -= numtodel
                new = empty(newshape, dtype, arrorder)
                off = 0

                for i in range(N):
                    if obj[i]:
                        continue

                    for idx0 in util.multirange(shape_no_axis):
                        idx1 = util.tuple_insert(idx0, axis, i)
                        idx2 = util.tuple_insert(idx0, axis, off)
                        p = arr._ptr(idx1)
                        q = new._ptr(idx2)
                        q[0] = p[0]

                    off += 1

                return new
            else:
                compile_error("arrays used as indices must be of integer (or boolean) type")
    elif axis is None:
        return delete(arr.ravel(), obj, axis=0)
    else:
        compile_error("'axis' must be None or an int")

def append(arr, values, axis = None):
    """
    Return a new array consisting of ``arr`` followed by ``values``.

    arr, values: array-likes, converted with ``asarray``.
    axis: an int to append along that axis (both inputs must then have the
        same number of dimensions and equal extents on every other axis), or
        ``None`` to flatten both inputs and return a one-dimensional array.

    The element type of the result is the coercion of both input dtypes.
    The result is allocated in F order when ``arr`` is F-contiguous but not
    C-contiguous, and in C order otherwise.

    Raises ``ValueError`` when, for an integer ``axis``, the shapes differ off
    that axis. Differing dimensionality with an integer ``axis`` and any
    other kind of ``axis`` are compile-time errors.
    """
    arr = asarray(arr)
    values = asarray(values)

    dtype1 = arr.dtype
    dtype2 = values.dtype
    dtype_out = type(util.coerce(dtype1, dtype2))

    ndim: Static[int] = staticlen(arr.shape)
    shape = arr.shape
    val_shape = values.shape
    arr_nbytes = arr.nbytes
    val_nbytes = values.nbytes
    # Bytes taken by arr's elements once stored as dtype_out; this, not
    # arr.nbytes, is where the values start inside the output buffer.
    arr_out_nbytes = arr.size * util.sizeof(dtype_out)

    newshape = shape
    pnewshape = Ptr[int](__ptr__(newshape).as_byte())

    cc, fc = arr._contig
    val_cc, val_fc = values._contig
    f_order = fc and not cc
    arrorder = 'F' if f_order else 'C'

    if isinstance(axis, int):
        if staticlen(arr.shape) != staticlen(values.shape):
            compile_error("'arr' and 'values' must have the same number of dimensions")

        axis = util.normalize_axis_index(axis, ndim)

        if util.tuple_delete(val_shape, axis) != util.tuple_delete(shape, axis):
            raise ValueError("'arr' and 'values' must have the same shape aside from specified axis")

        pnewshape[axis] += val_shape[axis]
        new = empty(newshape, dtype_out, arrorder)
        off = shape[axis]

        # arr and values each occupy one contiguous run of the output only
        # when appending along the slowest-varying axis of the output's
        # layout, and a raw copy is valid only if the respective source is
        # itself contiguous in that same layout.
        if f_order:
            outer_axis = (axis == ndim - 1)
            arr_flat = outer_axis  # arr is F-contiguous whenever f_order holds
            val_flat = outer_axis and val_fc
        else:
            outer_axis = (axis == 0)
            arr_flat = outer_axis and cc
            val_flat = outer_axis and val_cc

        if dtype1 is dtype_out and arr_flat:
            str.memcpy(new.data.as_byte(), arr.data.as_byte(), arr_nbytes)
        else:
            for idx in util.multirange(shape):
                p = arr._ptr(idx)
                q = new._ptr(idx)
                q[0] = util.cast(p[0], dtype_out)

        if dtype2 is dtype_out and val_flat:
            q = new.data.as_byte() + arr_out_nbytes
            str.memcpy(q, values.data.as_byte(), val_nbytes)
        else:
            for idx1 in util.multirange(val_shape):
                idx2 = util.tuple_add(idx1, axis, off)
                p = values._ptr(idx1)
                q = new._ptr(idx2)
                q[0] = util.cast(p[0], dtype_out)

        return new
    elif axis is None:
        new = empty((arr.size + values.size,), dtype_out, arrorder)
        q = new.data
        off = 0

        if dtype1 is dtype_out and cc:
            str.memcpy(q.as_byte(), arr.data.as_byte(), arr_nbytes)
            # Keep the element cursor in step with the raw copy so that the
            # values land after arr's elements.
            off = arr.size
        else:
            for idx in util.multirange(shape):
                p = arr._ptr(idx)
                q[off] = util.cast(p[0], dtype_out)
                off += 1

        if dtype2 is dtype_out and val_cc:
            str.memcpy((q + off).as_byte(), values.data.as_byte(), val_nbytes)
        else:
            for idx in util.multirange(val_shape):
                p = values._ptr(idx)
                q[off] = util.cast(p[0], dtype_out)
                off += 1

        return new
    else:
        compile_error("'axis' must be None or an int")

def insert(arr, obj, values, axis = None):
    # Return a new array with `values` inserted along `axis` before the
    # positions given by `obj` (an int, a slice, or an iterable of ints).
    # With axis=None the array is flattened first. The result keeps the
    # element type of `arr`; inserted values are cast to it. When there is
    # nothing to insert, a copy of `arr` is returned.
    def normalize_special(idx: int, axis: int, n: int, numnew: int):
        # Wrap a negative index by `n`, then bounds-check it against
        # n + numnew. Raises IndexError when out of range.
        idx0 = idx
        if idx < 0:
            idx += n
        if idx < 0 or idx >= n + numnew:
            raise IndexError(f'index {idx0} is out of bounds for axis {axis} with size {n + numnew}')
        return idx

    arr = asarray(arr)
    values = asarray(values)
    dtype = arr.dtype
    shape = arr.shape
    newshape = shape
    pnewshape = Ptr[int](__ptr__(newshape).as_byte())

    # The result is F-ordered only when arr is F- but not C-contiguous.
    cc, fc = arr._contig
    arrorder = 'F' if (fc and not cc) else 'C'
    ndim: Static[int] = staticlen(arr.shape)

    if isinstance(axis, int):
        axis = util.normalize_axis_index(axis, arr.ndim)
        N = pnewshape[axis]
        numnew = 0
        # (position, rank) pairs; rank is the entry's position within `obj`.
        indices: List[Tuple[int,int]]

        if isinstance(obj, slice):
            start, stop, step, length = obj.adjust_indices(N)
            numnew = length
            indices = [(normalize_special(a, axis, N, numnew), i)
                         for i, a in enumerate(range(start, stop, step))]
            indices.sort()
        elif isinstance(obj, int):
            numnew = 1
            indices = [(normalize_special(obj, axis, N, numnew), 0)]
        else:
            numnew = len(obj)
            indices = [(normalize_special(a, axis, N, numnew), i)
                         for i, a in enumerate(obj)]
            indices.sort()

        numnew = len(indices)

        # Translate positions in the original array into positions in the
        # result: each sorted entry is shifted by the number of insertions
        # that precede it.
        for i in range(numnew):
            idx, rank = indices[i]
            idx += i
            util.normalize_index(idx, axis, N + numnew)  # error check
            indices[i] = (idx, rank)

        if numnew == 0:
            return arr.copy(order=arrorder)

        pnewshape[axis] += numnew
        new = empty(newshape, dtype, arrorder)
        newshape_no_axis = util.tuple_delete(newshape, axis)
        curr = 0  # next entry of `indices` to be placed
        off = 0   # next index along `axis` to read from `arr`
        next_index, next_rank = indices[curr]
        # True when `values` supplies one leading entry per insertion; each
        # insertion then reads values[rank], otherwise `values` as a whole is
        # broadcast into every inserted slot.
        multiple_values = False
        if staticlen(values.shape) > 0:
            multiple_values = (len(values) == numnew)

        for ai in range(pnewshape[axis]):
            if ai == next_index:
                # This output position receives inserted data.
                for idx0 in util.multirange(newshape_no_axis):
                    idx1 = util.tuple_insert(idx0, axis, 0)
                    idx2 = util.tuple_insert(idx0, axis, ai)
                    q = new._ptr(idx2)

                    if staticlen(values.shape) == 0:
                        q[0] = util.cast(values.data[0], dtype)
                    else:
                        if multiple_values:
                            p = asarray(values[next_rank])._ptr(idx1, broadcast=True)
                            q[0] = util.cast(p[0], dtype)
                        else:
                            p = values._ptr(idx1, broadcast=True)
                            q[0] = util.cast(p[0], dtype)

                curr += 1
                if curr < numnew:
                    next_index, next_rank = indices[curr]
            else:
                # This output position receives the next slice of `arr`.
                for idx0 in util.multirange(newshape_no_axis):
                    idx1 = util.tuple_insert(idx0, axis, off)
                    idx2 = util.tuple_insert(idx0, axis, ai)
                    p = arr._ptr(idx1)
                    q = new._ptr(idx2)
                    q[0] = util.cast(p[0], dtype)

                off += 1

        return new
    elif axis is None:
        return insert(arr.ravel(), obj, values, axis=0)
    else:
        compile_error("'axis' must be None or an int")

def array_split(ary, indices_or_sections, axis: int = 0):
    """
    Split ``ary`` along ``axis`` into a list of sub-arrays that share the
    original data.

    indices_or_sections: an int N splits the axis into N sections, the first
        ``len % N`` of which are one element longer than the rest; otherwise
        an iterable of indices at which the axis is cut. Cut indices are
        clamped to the axis length and negative ones wrap around.

    Raises ``ValueError`` when an integer section count is not positive.
    """
    def correct(idx: int, n: int):
        # Clamp/wrap a cut index into the range [0, n].
        if idx > n:
            return n
        elif idx < -n:
            return 0
        elif idx < 0:
            return idx + n
        else:
            return idx

    def slice_axis(arr, axis: int, start: int, stop: int):
        # View of `arr` restricted to [start, stop) along `axis`.
        # The dimensionality comes from `arr` itself, not from the enclosing
        # function's argument, which may not be an array yet.
        ndim: Static[int] = staticlen(arr.shape)
        dtype = arr.dtype

        shape = arr.shape
        pshape = Ptr[int](__ptr__(shape).as_byte())
        limit = pshape[axis]
        start = correct(start, limit)
        stop = correct(stop, limit)

        base = (0,) * ndim
        pbase = Ptr[int](__ptr__(base).as_byte())
        pbase[axis] = start
        pshape[axis] = stop - start if stop > start else 0
        # An empty section gets a null data pointer.
        sub = arr._ptr(base) if stop > start else Ptr[dtype]()

        return ndarray(shape, arr.strides, sub)

    ary = asarray(ary)
    axis = util.normalize_axis_index(axis, ary.ndim)
    ntotal = ary.shape[axis]

    if isinstance(indices_or_sections, int):
        nsections = indices_or_sections
        if nsections <= 0:
            raise ValueError("number sections must be larger than 0.")

        neach, extras = divmod(ntotal, nsections)
        result = List(capacity=nsections)
        start = 0

        for i in range(nsections):
            stop = start + neach + (1 if i < extras else 0)
            result.append(slice_axis(ary, axis, start, stop))
            start = stop

        return result
    else:
        result = List(capacity=(len(indices_or_sections) + 1))
        prev = 0

        for s in indices_or_sections:
            result.append(slice_axis(ary, axis, prev, s))
            prev = s

        result.append(slice_axis(ary, axis, prev, ntotal))
        return result

def split(ary, indices_or_sections, axis: int = 0):
    """
    Split ``ary`` along ``axis`` like ``array_split``, but require an integer
    section count to divide the axis length exactly.

    Raises ``ValueError`` when an integer section count is not positive or
    does not divide the axis length evenly.
    """
    ary = asarray(ary)
    if isinstance(indices_or_sections, int):
        sections = indices_or_sections
        axis = util.normalize_axis_index(axis, ary.ndim)
        # Reject non-positive counts before the modulo below, which would
        # otherwise divide by zero for sections == 0.
        if sections <= 0:
            raise ValueError("number sections must be larger than 0.")
        N = ary.shape[axis]
        if N % sections:
            raise ValueError("array split does not result in an equal division")
    return array_split(ary, indices_or_sections, axis)

def vsplit(ary, indices_or_sections):
    """
    Split an array along its first axis.

    Inputs with fewer than two dimensions are rejected at compile time.
    """
    arr = asarray(ary)

    if staticlen(arr.shape) < 2:
        compile_error("vsplit only works on arrays of 2 or more dimensions")

    return split(arr, indices_or_sections, axis=0)

def hsplit(ary, indices_or_sections):
    """
    Split an array along its second axis, or along its only axis when the
    input is 1-dimensional.

    0-dimensional inputs are rejected at compile time.
    """
    arr = asarray(ary)

    if staticlen(arr.shape) == 0:
        compile_error("hsplit only works on arrays of 1 or more dimensions")

    split_axis = 1 if arr.ndim > 1 else 0
    return split(arr, indices_or_sections, split_axis)

def dsplit(ary, indices_or_sections):
    """
    Split an array along its third axis.

    Inputs with fewer than three dimensions are rejected at compile time.
    """
    arr = asarray(ary)

    if staticlen(arr.shape) < 3:
        compile_error("dsplit only works on arrays of 3 or more dimensions")

    return split(arr, indices_or_sections, axis=2)

def trim_zeros(filt, trim: str = 'fb'):
    """
    Strip leading and/or trailing falsy (zero) entries from a 1-D array.

    ``trim`` selects the side(s): 'f' for the front, 'b' for the back, and
    'fb' / 'bf' for both. The mode is matched case-insensitively, so 'F',
    'B' and 'FB' behave like their lower-case forms. Any other string trims
    nothing. The result is obtained by slicing the input array.
    """
    filt = asarray(filt)

    if staticlen(filt.shape) != 1:
        compile_error("trim_zeros() only takes 1-dimensional arrays as input")

    # Normalise the case so upper-case modes are not silently ignored.
    mode = trim.lower()
    both = (mode == 'fb' or mode == 'bf')
    trim_front = both or mode == 'f'
    trim_back = both or mode == 'b'

    if trim_front:
        n = len(filt)
        i = 0
        # Advance past the leading run of zeros.
        while i < n and not filt[i]:
            i += 1
        filt = filt[i:]

    if trim_back:
        n = len(filt)
        i = n - 1
        # Walk back over the trailing run of zeros; i ends at -1 if all zero.
        while i >= 0 and not filt[i]:
            i -= 1
        filt = filt[:i+1]

    return filt

def flip(m, axis = None):
    """
    Reverse the order of elements along the given axis or axes.

    ``axis`` may be ``None`` (all axes), a single int, or a tuple of ints.
    The result is built on ``m``'s existing buffer: strides of the flipped
    axes are negated and the base pointer is moved to the last element along
    each of them.
    """
    m = asarray(m)
    ndim: Static[int] = staticlen(m.shape)

    if axis is None:
        return flip(m, util.tuple_range(ndim))
    elif isinstance(axis, int):
        return flip(m, (axis,))

    flipped = util.normalize_axis_tuple(axis, ndim)
    dtype = m.dtype
    dims = m.shape
    new_strides = m.strides
    stride_ptr = Ptr[int](__ptr__(new_strides).as_byte())
    byte_shift = 0

    for d in staticrange(ndim):
        if d in flipped:
            step = stride_ptr[d]
            # Start at the last element of this axis and walk backwards.
            byte_shift += (dims[d] - 1) * step
            stride_ptr[d] = -step

    base = Ptr[dtype](m.data.as_byte() + byte_shift)
    return ndarray(dims, new_strides, base)

def fliplr(m):
    """
    Reverse the element order along axis 1.

    Inputs with fewer than two dimensions are rejected at compile time.
    """
    arr = asarray(m)
    if staticlen(arr.shape) < 2:
        compile_error("Input must be >= 2-d.")
    return flip(arr, 1)

def flipud(m):
    """
    Reverse the element order along axis 0.

    0-dimensional inputs are rejected at compile time.
    """
    arr = asarray(m)
    if staticlen(arr.shape) < 1:
        compile_error("Input must be >= 1-d.")
    return flip(arr, 0)

def rot90(m, k: int = 1, axes: Tuple[int,int] = (0, 1)):
    """
    Rotate an array by ``k`` quarter turns in the plane given by ``axes``.

    ``k`` is reduced modulo 4. The rotation is composed from ``flip`` and
    ``transpose`` calls (a plain slice for zero turns).

    Raises ``ValueError`` when the two axes coincide or are out of range.
    """
    m = asarray(m)
    ndim: Static[int] = staticlen(m.shape)
    ax0, ax1 = axes

    if ax0 == ax1 or abs(ax0 - ax1) == ndim:
        raise ValueError("Axes must be different.")

    if not (-ndim <= ax0 < ndim and -ndim <= ax1 < ndim):
        raise ValueError(f"Axes={axes} out of range for array of ndim={m.ndim}.")

    turns = k % 4

    if turns == 0:
        return m[:]
    if turns == 2:
        # Half turn: reverse both axes, no transpose needed.
        return flip(flip(m, ax0), ax1)

    # Axis permutation that exchanges the two rotation axes.
    perm = util.tuple_swap(util.tuple_range(ndim), ax0, ax1)

    if turns == 1:
        return flip(m, ax1).transpose(perm)

    # turns == 3
    return flip(m.transpose(perm), ax1)

def resize(a, new_shape):
    """
    Return a new array of shape ``new_shape`` filled with the elements of
    ``a``, repeating them cyclically when the new array is larger.

    An int ``new_shape`` is treated as a 1-tuple. If either ``a`` or the
    requested shape is empty, a zero-filled array is returned.

    Raises ``ValueError`` if any requested dimension is negative.
    """
    if isinstance(new_shape, int):
        return resize(a, (new_shape,))

    src = asarray(a)

    total = 1
    for extent in new_shape:
        if extent < 0:
            raise ValueError('all elements of `new_shape` must be non-negative')
        total *= extent

    if src.size == 0 or total == 0:
        # Nothing to copy from, or nothing to copy into: zero fill.
        return zeros(new_shape, dtype=src.dtype)

    out = empty(new_shape, dtype=src.dtype)
    filled = 0

    # Keep sweeping over the source until the destination is full.
    while filled < total:
        for idx in util.multirange(src.shape):
            out.data[filled] = src._ptr(idx)[0]
            filled += 1
            if filled == total:
                break

    return out

def tile(A, reps):
    """
    Build an array by repeating ``A`` ``reps[i]`` times along each axis ``i``.

    An int ``reps`` is treated as a 1-tuple. When ``A`` and ``reps`` have
    different lengths, the shorter one is left-padded with ones so both
    describe the same number of dimensions.
    """
    if isinstance(reps, int):
        return tile(A, (reps,))

    src = asarray(A)
    src_ndim: Static[int] = staticlen(src.shape)
    n_reps: Static[int] = staticlen(reps)

    if src_ndim < n_reps:
        # Promote the array with leading unit axes.
        padded_shape = ((1,) * (n_reps - src_ndim)) + src.shape
        return tile(src.reshape(padded_shape), reps)

    if src_ndim > n_reps:
        # Promote the repetition counts with leading ones.
        return tile(src, ((1,) * (src_ndim - n_reps)) + reps)

    out_shape = util.tuple_apply(int.__mul__, src.shape, reps)
    out = empty(out_shape, dtype=src.dtype)

    for dst_idx in util.multirange(out_shape):
        # Each output element maps back to the source index modulo its shape.
        src_idx = util.tuple_apply(int.__mod__, dst_idx, src.shape)
        out._ptr(dst_idx)[0] = src._ptr(src_idx)[0]

    return out

def roll(a, shift, axis = None):
    a = asarray(a)
    ndim: Static[int] = staticlen(a.shape)

    if axis is None:
        return roll(a.ravel(), shift, axis=0).reshape(a.shape)

    if isinstance(axis, int):
        return roll(a, shift, (axis,))

    if isinstance(shift, int):
        return roll(a, (shift,), axis)

    na: Static[int] = staticlen(axis)
    ns: Static[int] = staticlen(shift)

    if na == 0:
        compile_error("empty tuple given for 'axis'")

    if ns == 0:
        compile_error("empty tuple given for 'shift'")

    if na == 1 and ns != 1:
        return roll(a, shift, axis * ns)

    if na != 1 and ns == 1:
        return roll(a, shift * na, axis)

    if staticlen(axis) != staticlen(shift):
        compile_error("'shift' and 'axis' must be tuples of the same size")

    axis = util.normalize_axis_tuple(axis, a.ndim, allow_duplicates=True)

    shifts = (0,) * ndim
    pshifts = Ptr[int](__ptr__(shifts).as_byte())

    for i in staticrange(staticlen(axis)):
        pshifts[axis[i]] += shift[i]

    new = empty_like(a)

    @pure
    @llvm
    def cdiv(a: int, b: int) -> int:
        %0 = sdiv i64 %a, %b
        ret i64 %0

    def pymod(a: int, b: int):
        d = cdiv(a, b)
        m = a - d * b
        if m and ((b ^ m) < 0):
            m += b
        return m

    for idx1 in util.multirange(a.shape):
        idx2 = util.tuple_apply(int.__add__, idx1, shifts)
        idx2 = util.tuple_apply(pymod, idx2, a.shape)
        p = a._ptr(idx1)
        q = new._ptr(idx2)
        q[0] = p[0]

    return new

def _atleast_nd(a, ndim: Static[int]):
    """Convert ``a`` via ``array`` with ``ndmin=ndim`` and ``copy=False``."""
    promoted = array(a, ndmin=ndim, copy=False)
    return promoted

def _accumulate(values):
    import itertools
    return list(itertools.accumulate(values))

def _longest_shape(a):
    """
    Return a tuple of zeros whose length is the largest array dimensionality
    found in the nested ``block`` argument ``a``.

    Lists are represented by their first element; tuples are examined
    element by element. Non-array leaves count as 0-dimensional.

    Raises ``ValueError`` for an empty list; empty tuples are a compile
    error.
    """
    if isinstance(a, ndarray):
        return (0,) * staticlen(a.shape)
    elif isinstance(a, List):
        if len(a) == 0:
            raise ValueError("Lists passed to block cannot be empty")
        return _longest_shape(a[0])
    elif isinstance(a, Tuple):
        if staticlen(a) == 0:
            compile_error("Tuples passed to block cannot be empty")
        if staticlen(a) == 1:
            return _longest_shape(a[0])
        # Compare the head against the best of the remaining elements.
        head = _longest_shape(a[0])
        rest = _longest_shape(a[1:])
        if staticlen(rest) >= staticlen(head):
            return rest
        else:
            return head
    else:
        return ()

def _bottom_index(a):
    """
    Return a tuple of zeros with one entry per list/tuple nesting level of
    the ``block`` argument ``a``.

    For tuples, every element must have the same nesting depth; a mismatch
    is a compile error.
    """
    if isinstance(a, ndarray):
        return ()
    elif isinstance(a, List):
        inner = _bottom_index(a[0])
        return (0,) + inner
    elif isinstance(a, Tuple):
        first = _bottom_index(a[0])
        for k in staticrange(1, staticlen(a)):
            if staticlen(_bottom_index(a[k])) != staticlen(first):
                compile_error("Depths of block argument are mismatched")

        return (0,) + first
    else:
        return ()

def _final_size(a):
    """
    Count the total number of elements in the nested ``block`` argument.

    Arrays contribute their ``size``, lists and tuples the sum over their
    items, and any other leaf counts as one element.
    """
    if isinstance(a, ndarray):
        return a.size
    elif isinstance(a, List) or isinstance(a, Tuple):
        total = 0
        for item in a:
            total += _final_size(item)
        return total
    else:
        return 1

def _block(arrays, max_depth: Static[int], result_ndim: Static[int], depth: Static[int] = 0):
    """
    Recursively assemble nested blocks by concatenation.

    Below ``max_depth`` each item is assembled first and the results are
    joined along axis ``depth - max_depth``. At ``max_depth`` the leaf is
    promoted to ``result_ndim`` dimensions.
    """
    if depth < max_depth:
        pieces = [_block(item, max_depth, result_ndim, depth+1)
                  for item in arrays]
        # Negative axis: deeper nesting levels join along later axes.
        return concatenate(pieces, axis=depth - max_depth)
    else:
        return _atleast_nd(arrays, result_ndim)

def _block_concatenate(arrays, list_ndim: Static[int], result_ndim: Static[int]):
    """
    Assemble ``arrays`` with ``_block``.

    When there is no list nesting (``list_ndim == 0``), a copy of the
    ``_block`` result is returned instead of the result itself.
    """
    assembled = _block(arrays, list_ndim, result_ndim)
    if list_ndim == 0:
        return assembled.copy()
    return assembled

def _concatenate_shapes(shapes, axis: Static[int]):
    """
    Compute the shape produced by concatenating arrays of the given
    ``shapes`` along ``axis``, plus the slice each input occupies.

    Returns ``(shape, slice_prefixes)`` where ``slice_prefixes[i]`` is a
    1-tuple holding the slice along ``axis`` covered by input ``i``.

    Raises ``ValueError`` if the shapes disagree anywhere except on ``axis``.
    """
    # Extent of every input along the concatenation axis.
    extents = [shp[axis] for shp in shapes]

    # Any input can serve as the reference for the non-concatenated axes.
    reference = shapes[0]
    lead = reference[:axis]
    trail = reference[axis+1:]

    for shp in shapes:
        if shp[:axis] != lead or shp[axis+1:] != trail:
            raise ValueError(
                f"Mismatched array shapes in block along axis {axis}.")

    shape = lead + (sum(extents),) + trail

    # Walk the running offset to get each input's [begin, end) range.
    slice_prefixes = []
    begin = 0
    for extent in extents:
        end = begin + extent
        slice_prefixes.append((slice(begin, end),))
        begin = end

    return shape, slice_prefixes

def _block_info_recursion(arrs, max_depth: Static[int], result_ndim: Static[int], depth: Static[int] = 0):
    """
    Collect the information needed to assemble nested blocks by slicing.

    Returns ``(shape, slices, arrays)``: the shape of the assembled result
    at this nesting level, one slice tuple per leaf array, and the flat list
    of leaf arrays (each promoted to ``result_ndim`` dimensions).
    """
    if depth < max_depth:
        child_shapes = []
        child_slices = []
        child_arrays = []

        for item in arrs:
            shp, slc, leaves = _block_info_recursion(item, max_depth, result_ndim, depth+1)
            child_shapes.append(shp)
            child_slices.append(slc)
            child_arrays.append(leaves)

        axis: Static[int] = result_ndim - max_depth + depth
        shape, prefixes = _concatenate_shapes(child_shapes, axis)

        # Put this level's slice in front of every slice from the children.
        merged_slices = []
        for prefix, inner in zip(prefixes, child_slices):
            for tail in inner:
                merged_slices.append(prefix + tail)

        # One flat list of leaves, in the same order as the slices.
        flat_arrays = [leaf for group in child_arrays for leaf in group]

        return shape, merged_slices, flat_arrays
    else:
        leaf = _atleast_nd(arrs, result_ndim)
        return leaf.shape, [()], [leaf]

def _block_slicing(arrays, list_ndim: Static[int], result_ndim: Static[int]):
    """
    Assemble nested blocks by allocating the result once and assigning each
    leaf array into its slice.

    The result uses 'F' order only when every leaf is F-contiguous and not
    all of them are C-contiguous; otherwise 'C' order.
    """
    shape, slices, leaves = _block_info_recursion(
        arrays, list_ndim, result_ndim)
    dtype = leaves[0].dtype

    all_c = True
    all_f = True
    for leaf in leaves:
        c_ok, f_ok = leaf._contig
        all_c = all_c and c_ok
        all_f = all_f and f_ok
        if not (all_c or all_f):
            # Neither layout can still hold for all leaves.
            break

    order = 'C'
    if all_f and not all_c:
        order = 'F'
    result = empty(shape=shape, dtype=dtype, order=order)

    for the_slice, leaf in zip(slices, leaves):
        result[(Ellipsis,) + the_slice] = leaf
    return result

def block(arrays):
    """
    Assemble an array from nested lists/tuples of blocks.

    The result dimensionality is the larger of the nesting depth and the
    highest leaf dimensionality. Large inputs are assembled by slicing into
    a preallocated result, smaller ones by repeated concatenation.
    """
    longest = _longest_shape(arrays)
    bottom = _bottom_index(arrays)
    n_elements = _final_size(arrays)

    list_ndim: Static[int] = staticlen(bottom)
    arr_ndim: Static[int] = staticlen(longest)
    result_ndim: Static[int] = arr_ndim if arr_ndim > list_ndim else list_ndim

    # Note: This is just the heuristic NumPy uses. It has not been tuned
    # for Codon.
    if list_ndim * n_elements <= (2 * 512 * 512):
        return _block_concatenate(arrays, list_ndim, result_ndim)
    else:
        return _block_slicing(arrays, list_ndim, result_ndim)

def _close(a, b, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False):
    """
    Scalar closeness test: ``abs(a - b) <= atol + rtol * abs(b)``.

    Operands of different types are first brought to a common type, and
    anything that is not a float or complex type is converted with
    ``util.to_float``. For non-finite operands: if either is NaN the result
    is true only when ``equal_nan`` is set and both are NaN; otherwise the
    operands must compare equal.
    """
    A = type(a)
    B = type(b)
    if A is not B:
        C = type(util.coerce(A, B))
        if (C is not float and
            C is not float32 and
            C is not float16 and
            C is not complex and
            C is not complex64):
            # Common type is not inexact: compare as floats.
            return _close(util.to_float(a),
                          util.to_float(b),
                          rtol=rtol,
                          atol=atol,
                          equal_nan=equal_nan)
        else:
            return _close(util.cast(a, C),
                          util.cast(b, C),
                          rtol=rtol,
                          atol=atol,
                          equal_nan=equal_nan)
    elif (A is not float and
          A is not float32 and
          A is not float16 and
          A is not complex and
          A is not complex64):
        return _close(util.to_float(a),
                      util.to_float(b),
                      rtol=rtol,
                      atol=atol,
                      equal_nan=equal_nan)

    if A is float or A is float32 or A is float16:
        if util.isfinite(a) and util.isfinite(b):
            return abs(a - b) <= A(atol) + A(rtol) * abs(b)

        a_nan = util.isnan(a)
        b_nan = util.isnan(b)
        if a_nan or b_nan:
            return equal_nan and a_nan and b_nan

        # Only infinities are left: they must match exactly.
        return a == b
    else:  # complex or complex64
        R = type(a.real)
        a_finite = util.isfinite(a.real) and util.isfinite(a.imag)
        b_finite = util.isfinite(b.real) and util.isfinite(b.imag)
        if a_finite and b_finite:
            return abs(a - b) <= R(atol) + R(rtol) * abs(b)

        a_nan = util.isnan(a.real) or util.isnan(a.imag)
        b_nan = util.isnan(b.real) or util.isnan(b.imag)
        if a_nan or b_nan:
            return equal_nan and a_nan and b_nan

        return a == b

def isclose(a, b, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False):
    """
    Element-wise closeness test with broadcasting.

    Returns a boolean array of the broadcast shape, or a single ``_close``
    result when both inputs are 0-dimensional.
    """
    lhs = asarray(a)
    rhs = asarray(b)

    if lhs.ndim == 0 and rhs.ndim == 0:
        return _close(lhs.item(), rhs.item(), rtol=rtol, atol=atol, equal_nan=equal_nan)

    out_shape = broadcast_shapes(lhs.shape, rhs.shape)
    out = empty(out_shape, bool)

    for idx in util.multirange(out_shape):
        left = lhs._ptr(idx, broadcast=True)[0]
        right = rhs._ptr(idx, broadcast=True)[0]
        out._ptr(idx)[0] = _close(left, right, rtol=rtol, atol=atol, equal_nan=equal_nan)

    return out

def allclose(a, b, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False):
    """
    Return whether every broadcast pair of elements is close per ``_close``.

    Stops at the first pair that is not close.
    """
    lhs = asarray(a)
    rhs = asarray(b)

    if lhs.ndim == 0 and rhs.ndim == 0:
        return _close(lhs.item(), rhs.item(), rtol=rtol, atol=atol, equal_nan=equal_nan)

    common_shape = broadcast_shapes(lhs.shape, rhs.shape)

    for idx in util.multirange(common_shape):
        left = lhs._ptr(idx, broadcast=True)[0]
        right = rhs._ptr(idx, broadcast=True)[0]
        close = _close(left, right, rtol=rtol, atol=atol, equal_nan=equal_nan)
        if not close:
            return False

    return True

def array_equal(a1, a2, equal_nan: bool = False):
    """
    Return whether two arrays have the same shape and equal elements.

    Elements are compared after casting to the coerced common dtype. With
    ``equal_nan`` set, a pair of NaNs counts as equal.
    """
    from .ndmath import isnan
    a1 = asarray(a1)
    a2 = asarray(a2)

    if a1.ndim != a2.ndim:
        return False

    if a1.shape != a2.shape:
        return False

    dtype = type(util.coerce(a1.dtype, a2.dtype))

    if a1._contig_match(a2):
        # Matching contiguous layouts: walk both buffers linearly.
        lhs = a1.data
        rhs = a2.data
        for pos in range(a1.size):
            x = util.cast(lhs[pos], dtype)
            y = util.cast(rhs[pos], dtype)
            if x != y:
                if not equal_nan or not (isnan(x) and isnan(y)):
                    return False
    else:
        for idx in util.multirange(a1.shape):
            x = util.cast(a1._ptr(idx)[0], dtype)
            y = util.cast(a2._ptr(idx)[0], dtype)
            if x != y:
                if not equal_nan or not (isnan(x) and isnan(y)):
                    return False

    return True

def array_equiv(a1, a2):
    """
    Return whether two arrays are broadcast-compatible and equal element by
    element after broadcasting.

    Elements are compared after casting to the coerced common dtype.
    """
    def can_broadcast(s1, s2):
        # Align the trailing dimensions, then require each pair of extents to
        # be equal or to contain a 1.
        if staticlen(s1) > staticlen(s2):
            return can_broadcast(s1[-staticlen(s2):], s2)
        elif staticlen(s1) < staticlen(s2):
            return can_broadcast(s1, s2[-staticlen(s1):])
        else:
            for k in staticrange(staticlen(s1)):
                m = s1[k]
                n = s2[k]
                if m != n and m != 1 and n != 1:
                    return False
            return True

    lhs = asarray(a1)
    rhs = asarray(a2)

    if not can_broadcast(lhs.shape, rhs.shape):
        return False

    common_shape = broadcast_shapes(lhs.shape, rhs.shape)
    dtype = type(util.coerce(lhs.dtype, rhs.dtype))

    for idx in util.multirange(common_shape):
        x = util.cast(lhs._ptr(idx, broadcast=True)[0], dtype)
        y = util.cast(rhs._ptr(idx, broadcast=True)[0], dtype)
        if x != y:
            return False

    return True

def squeeze(a, axis = None):
    """
    Remove the given length-one axes from ``a``.

    ``axis`` is mandatory here (an int or a tuple of ints); omitting it is a
    compile error. The result reuses ``a``'s data pointer with the selected
    axes dropped from shape and strides.

    Raises ``ValueError`` if a selected axis does not have length one.
    """
    if axis is None:
        compile_error("squeeze() must specify 'axis' argument in Codon-NumPy")

    if isinstance(axis, int):
        return squeeze(a, (axis,))

    if not isinstance(axis, Tuple):
        compile_error("'axis' must be an int or a tuple of ints")

    a = asarray(a)
    dropped = util.normalize_axis_tuple(axis, a.ndim)
    old_shape = a.shape
    old_strides = a.strides

    # Output tuples are filled in place through raw pointers.
    out_shape = (0,) * (staticlen(old_shape) - staticlen(dropped))
    out_strides = (0,) * (staticlen(old_shape) - staticlen(dropped))
    shape_ptr = Ptr[int](__ptr__(out_shape).as_byte())
    strides_ptr = Ptr[int](__ptr__(out_strides).as_byte())

    kept = 0
    for d in staticrange(staticlen(old_shape)):
        if d in dropped:
            if old_shape[d] != 1:
                raise ValueError("cannot select an axis to squeeze out which has size not equal to one")
        else:
            shape_ptr[kept] = old_shape[d]
            strides_ptr[kept] = old_strides[d]
            kept += 1

    return ndarray(out_shape, out_strides, a.data)

def pad(array, pad_width, mode = 'constant', **kwargs):
    """
    Return a padded copy of ``array``.

    ``pad_width`` gives the number of entries added before and after each
    axis: an int, a 1- or 2-tuple of ints, a 1-tuple holding a (before,
    after) pair, or one (before, after) pair per axis. ``mode`` is either
    one of the strings 'empty', 'constant', 'wrap', 'edge', 'linear_ramp',
    'maximum', 'minimum', 'mean', 'median', 'reflect', 'symmetric', or a
    callable applied to every 1-D lane of the padded array. Mode-specific
    options ('constant_values', 'stat_length', 'end_values',
    'reflect_type') are read from ``kwargs``.

    A 0-dimensional input is returned as-is.

    Raises ``ValueError`` for negative pad widths, for an unknown mode
    string, when padding an empty array with a mode other than 'constant' or
    'empty', and for non-positive 'stat_length' values.
    """
    from .ndmath import isnan

    def unpack_params(x, iaxis: int, name: Static[str]):
        # Resolve a per-axis option to a (before, after) pair for axis
        # `iaxis`. Accepts a scalar, a 1- or 2-tuple of scalars, or a tuple
        # of pairs.
        if isinstance(x, Tuple):
            if staticlen(x) == 1:
                if isinstance(x[0], Tuple):
                    if staticlen(x[0]) == 1:
                        return x[0][0], x[0][0]
                    elif staticlen(x[0]) == 2:
                        return x[0]
                    else:
                        compile_error("invalid parameter '" + name + "' given")
                else:
                    return x[0], x[0]
            elif staticlen(x) == 2:
                if isinstance(x[0], Tuple):
                    return x[iaxis]
                else:
                    return x
            else:
                return x[iaxis]
        else:
            return x, x

    def unpack_stat_length(k: int, iaxis: int, kwargs):
        # Read 'stat_length' (default: k) for this axis and clamp both
        # values to at most k.
        stat_length = kwargs.get('stat_length', k)
        s1, s2 = unpack_params(stat_length, iaxis, 'stat_length')

        if s1 <= 0 or s2 <= 0:
            raise ValueError("'stat_length' must contain positive values")

        s1 = min(s1, k)
        s2 = min(s2, k)

        return s1, s2

    def round_if_needed(x, dtype: type):
        # Integer dtypes get rounded via util.rint before the cast; other
        # dtypes are cast directly.
        if dtype is int or isinstance(dtype, Int) or isinstance(dtype, UInt):
            return util.cast(util.rint(x), dtype)
        else:
            return util.cast(x, dtype)

    def pad_from_function(a: ndarray, pw, padding_func, kwargs, extra = None):
        # Apply `padding_func` to every 1-D lane of `a`, one axis after the
        # other. Each lane is a strided view into `a`, so the function
        # writes the padding in place.
        shape = a.shape
        strides = a.strides
        ndim: Static[int] = staticlen(shape)

        for i in staticrange(ndim):
            length = shape[i]
            stride = strides[i]

            for idx in util.multirange(util.tuple_delete(shape, i)):
                p = a._ptr(util.tuple_insert(idx, i, 0))
                vector = ndarray((length,), (stride,), p)
                padding_func(vector, pw[i], i, kwargs, extra)

    def pad_constant(vector: ndarray,
                     iaxis_pad_width: Tuple[int, int],
                     iaxis: int,
                     kwargs,
                     extra):
        # Fill the first p1 and last p2 entries with 'constant_values'
        # (default 0), cast to the vector's dtype.
        p1, p2 = iaxis_pad_width
        cval = kwargs.get('constant_values', 0)
        dtype = vector.dtype
        c1, c2 = unpack_params(cval, iaxis, 'constant_values')
        n = vector.size

        # Note we don't need to do range checks since padded
        # array size will always ensure we're not out of bounds,
        # as long as each `iaxis_pad_width` is non-negative.

        for i in range(0, p1):
            vector._ptr((i,))[0] = util.cast(c1, dtype)

        for i in range(n - p2, n):
            vector._ptr((i,))[0] = util.cast(c2, dtype)

    def pad_wrap(vector: ndarray,
                 iaxis_pad_width: Tuple[int, int],
                 iaxis: int,
                 kwargs,
                 extra):
        # Fill each pad entry from the entry k positions towards the
        # interior, where k is the unpadded length.
        p1, p2 = iaxis_pad_width
        n = vector.size
        k = n - (p1 + p2)

        # Ascending order on the right and descending order on the left, so
        # a source entry that itself lies in the pad region has already been
        # written when it is read.
        for i in range(n - p2, n):
            vector._ptr((i,))[0] = vector._ptr((i - k,))[0]

        for i in range(p1 - 1, -1, -1):
            vector._ptr((i,))[0] = vector._ptr((i + k,))[0]

    def pad_edge(vector: ndarray,
                 iaxis_pad_width: Tuple[int, int],
                 iaxis: int,
                 kwargs,
                 extra):
        # Replicate the first / last unpadded entry into the pad regions.
        p1, p2 = iaxis_pad_width
        n = vector.size
        c1 = vector._ptr((p1,))[0]
        c2 = vector._ptr((n - 1 - p2,))[0]

        for i in range(0, p1):
            vector._ptr((i,))[0] = c1

        for i in range(n - p2, n):
            vector._ptr((i,))[0] = c2

    def pad_max(vector: ndarray,
                iaxis_pad_width: Tuple[int, int],
                iaxis: int,
                kwargs,
                extra):
        # Pad with the maximum of the first s1 / last s2 unpadded entries.
        p1, p2 = iaxis_pad_width
        n = vector.size
        k = n - (p1 + p2)
        s1, s2 = unpack_stat_length(k, iaxis, kwargs)

        m1 = vector._ptr((p1,))[0]
        for i in range(p1 + 1, p1 + s1):
            e = vector._ptr((i,))[0]
            if e > m1:
                m1 = e

        # Both windows cover the whole unpadded data: reuse the result.
        if s1 == k and s2 == k:
            m2 = m1
        else:
            m2 = vector._ptr((n - p2 - s2,))[0]
            for i in range(n - p2 - s2 + 1, n - p2):
                e = vector._ptr((i,))[0]
                if e > m2:
                    m2 = e

        for i in range(0, p1):
            vector._ptr((i,))[0] = m1

        for i in range(n - p2, n):
            vector._ptr((i,))[0] = m2

    def pad_min(vector: ndarray,
                iaxis_pad_width: Tuple[int, int],
                iaxis: int,
                kwargs,
                extra):
        # Pad with the minimum of the first s1 / last s2 unpadded entries.
        p1, p2 = iaxis_pad_width
        n = vector.size
        k = n - (p1 + p2)
        s1, s2 = unpack_stat_length(k, iaxis, kwargs)

        m1 = vector._ptr((p1,))[0]
        for i in range(p1 + 1, p1 + s1):
            e = vector._ptr((i,))[0]
            if e < m1:
                m1 = e

        # Both windows cover the whole unpadded data: reuse the result.
        if s1 == k and s2 == k:
            m2 = m1
        else:
            m2 = vector._ptr((n - p2 - s2,))[0]
            for i in range(n - p2 - s2 + 1, n - p2):
                e = vector._ptr((i,))[0]
                if e < m2:
                    m2 = e

        for i in range(0, p1):
            vector._ptr((i,))[0] = m1

        for i in range(n - p2, n):
            vector._ptr((i,))[0] = m2

    def pad_mean(vector: ndarray,
                 iaxis_pad_width: Tuple[int, int],
                 iaxis: int,
                 kwargs,
                 extra):
        # Pad with the mean of the first s1 / last s2 unpadded entries,
        # rounded for integer dtypes.
        p1, p2 = iaxis_pad_width
        n = vector.size
        k = n - (p1 + p2)
        s1, s2 = unpack_stat_length(k, iaxis, kwargs)

        # m1 / m2 hold the window sums; they are divided below.
        m1 = vector._ptr((p1,))[0]
        for i in range(p1 + 1, p1 + s1):
            m1 += vector._ptr((i,))[0]

        if s1 == k and s2 == k:
            m2 = m1
        else:
            m2 = vector._ptr((n - p2 - s2,))[0]
            for i in range(n - p2 - s2 + 1, n - p2):
                m2 += vector._ptr((i,))[0]

        dtype = vector.dtype
        av1 = round_if_needed(m1 / util.cast(s1, dtype), dtype)
        av2 = round_if_needed(m2 / util.cast(s2, dtype), dtype)

        for i in range(0, p1):
            vector._ptr((i,))[0] = av1

        for i in range(n - p2, n):
            vector._ptr((i,))[0] = av2

    def pad_median(vector: ndarray,
                   iaxis_pad_width: Tuple[int, int],
                   iaxis: int,
                   kwargs,
                   extra):
        # Pad with the median of the first s1 / last s2 unpadded entries.
        # `extra` is a scratch buffer supplied by the caller. If a NaN is
        # found in a window, that NaN is used as the pad value.
        p1, p2 = iaxis_pad_width
        n = vector.size
        k = n - (p1 + p2)
        s1, s2 = unpack_stat_length(k, iaxis, kwargs)

        buf = extra
        nan_idx = -1
        for i in range(p1, p1 + s1):
            e = vector._ptr((i,))[0]
            if isnan(e):
                nan_idx = i
                break
            else:
                buf[i - p1] = e

        dtype = vector.dtype
        if nan_idx >= 0:
            m1 = vector._ptr((nan_idx,))[0]
        else:
            # util.median returns a pair; when the two differ they are
            # averaged (presumably the two middle values for an even-sized
            # window - confirm against util.median).
            m1a, m1b = util.median(buf, s1)
            if m1a == m1b:
                m1 = m1a
            else:
                m1 = round_if_needed((m1a + m1b) / util.cast(2, dtype), dtype)

        if s1 == k and s2 == k:
            m2 = m1
        else:
            nan_idx = -1
            base = n - p2 - s2
            for i in range(base, n - p2):
                e = vector._ptr((i,))[0]
                if isnan(e):
                    nan_idx = i
                    break
                else:
                    buf[i - base] = e

            if nan_idx >= 0:
                m2 = vector._ptr((nan_idx,))[0]
            else:
                m2a, m2b = util.median(buf, s2)
                if m2a == m2b:
                    m2 = m2a
                else:
                    m2 = round_if_needed((m2a + m2b) / util.cast(2, dtype), dtype)

        for i in range(0, p1):
            vector._ptr((i,))[0] = m1

        for i in range(n - p2, n):
            vector._ptr((i,))[0] = m2

    def pad_linear_ramp(vector: ndarray,
                        iaxis_pad_width: Tuple[int, int],
                        iaxis: int,
                        kwargs,
                        extra):
        # Pad with a linear ramp between 'end_values' (default 0) at the
        # outer ends and the adjacent edge values of the unpadded data.
        def fill_linear(vec: ndarray, offset: int, start: float,
                        stop: float, num: int, rev: bool):
            # Write `num` evenly spaced values beginning at `start` (`stop`
            # itself is never written) into vec[offset:offset+num], in
            # reverse position order when `rev` is set.
            if num == 0:
                return

            dtype = vec.dtype
            delta = stop - start
            step = delta / num

            for i in range(num):
                j = num - 1 - i if rev else i
                p = vec._ptr((j + offset,))
                # When step evaluates to zero, scale by delta instead.
                e = ((i / num) * delta) + start if not step else (i * step) + start
                p[0] = round_if_needed(e, dtype)

        p1, p2 = iaxis_pad_width
        n = vector.size
        end_values = kwargs.get('end_values', 0)
        start1, end2 = unpack_params(end_values, iaxis, 'end_values')
        end1 = vector._ptr((p1,))[0]
        start2 = vector._ptr((n - 1 - p2,))[0]
        fill_linear(vector,
                    offset=0,
                    start=util.cast(start1, float),
                    stop=util.cast(end1, float),
                    num=p1,
                    rev=False)
        fill_linear(vector,
                    offset=(n - p2),
                    start=util.cast(end2, float),
                    stop=util.cast(start2, float),
                    num=p2,
                    rev=True)

    def pad_reflect_or_symmetric(vector: ndarray,
                                 iaxis_pad_width: Tuple[int, int],
                                 iaxis: int,
                                 kwargs,
                                 extra):
        # Mirror the data into the pad regions. `extra` is 1 for 'reflect'
        # (the edge entry is skipped) and 0 for 'symmetric' (the edge entry
        # is repeated). With reflect_type 'odd', each mirrored value e is
        # replaced by 2*edge - e.
        p1, p2 = iaxis_pad_width
        n = vector.size
        k = n - (p1 + p2)
        diff = extra

        # A single unpadded entry is simply replicated.
        if k == 1:
            e = vector._ptr((p1,))[0]

            for i in range(0, p1):
                vector._ptr((i,))[0] = e

            for i in range(n - p2, n):
                vector._ptr((i,))[0] = e

            return

        even = kwargs.get('reflect_type', 'even') != 'odd'

        # left side
        # Each pass of the outer loop mirrors one chunk of up to k - diff
        # entries around position z + 1; later chunks read entries written
        # by earlier ones.
        i = p1 - 1
        while i >= 0:
            z = i
            edge = vector._ptr((z + 1,))[0]

            for j in range(k - diff):
                e = vector._ptr((z + j + (1 + diff),))[0]
                if not even:
                    e = (edge + edge) - e

                vector._ptr((i,))[0] = e
                i -= 1
                if i < 0:
                    break

        # right side
        i = n - p2
        while i < n:
            z = i
            edge = vector._ptr((z - 1,))[0]

            for j in range(k - diff):
                e = vector._ptr((z - j - (1 + diff),))[0]
                if not even:
                    e = (edge + edge) - e

                vector._ptr((i,))[0] = e
                i += 1
                if i >= n:
                    break

    a = asarray(array)
    s = a.shape
    dtype = a.dtype
    ndim: Static[int] = staticlen(s)

    if ndim == 0:
        return a

    if isinstance(mode, str):
        if a.size == 0 and mode not in ('empty', 'constant'):
            raise ValueError("can't pad empty array using modes other than 'constant' or 'empty'")

    # Normalise pad_width to one (before, after) pair per axis.
    if isinstance(pad_width, int):
        pw = ((pad_width, pad_width),) * ndim
    elif isinstance(pad_width, Tuple[int]):
        pw = ((pad_width[0], pad_width[0]),) * ndim
    elif isinstance(pad_width, Tuple[int, int]):
        pw = (pad_width,) * ndim
    elif isinstance(pad_width, Tuple[Tuple[int, int]]):
        pw = (pad_width[0],) * ndim
    elif staticlen(pad_width) == 1:
        pw = pad_width * ndim
    else:
        pw = pad_width

    if staticlen(pw) != ndim:
        compile_error("invalid pad_width")

    for p in pw:
        if p[0] < 0 or p[1] < 0:
            raise ValueError("padding can't be negative")

    new_shape = tuple(s[i] + pw[i][0] + pw[i][1] for i in staticrange(ndim))

    # String modes start from an `empty` result; a callable mode starts
    # from a `zeros` result.
    if isinstance(mode, str):
        ans = empty(new_shape, dtype)
    else:
        ans = zeros(new_shape, dtype)

    # copy in original array
    for idx in util.multirange(s):
        idx1 = tuple(idx[i] + pw[i][0] for i in staticrange(ndim))
        p = a._ptr(idx)
        q = ans._ptr(idx1)
        q[0] = p[0]

    if isinstance(mode, str):
        if mode == 'empty':
            pass  # do nothing
        elif mode == 'constant':
            pad_from_function(ans, pw, pad_constant, kwargs)
        elif mode == 'wrap':
            pad_from_function(ans, pw, pad_wrap, kwargs)
        elif mode == 'edge':
            pad_from_function(ans, pw, pad_edge, kwargs)
        elif mode == 'linear_ramp':
            pad_from_function(ans, pw, pad_linear_ramp, kwargs)
        elif mode == 'maximum':
            pad_from_function(ans, pw, pad_max, kwargs)
        elif mode == 'minimum':
            pad_from_function(ans, pw, pad_min, kwargs)
        elif mode == 'mean':
            pad_from_function(ans, pw, pad_mean, kwargs)
        elif mode == 'median':
            # Size the shared scratch buffer for the largest statistic
            # window over all axes.
            buf_size = 0
            for iaxis in range(ndim):
                s1, s2 = unpack_stat_length(s[iaxis], iaxis, kwargs)
                if s1 > buf_size:
                    buf_size = s1
                if s2 > buf_size:
                    buf_size = s2
            buf = Ptr[dtype](buf_size)
            pad_from_function(ans, pw, pad_median, kwargs, extra=buf)
        elif mode == 'reflect':
            pad_from_function(ans, pw, pad_reflect_or_symmetric, kwargs, extra=1)
        elif mode == 'symmetric':
            pad_from_function(ans, pw, pad_reflect_or_symmetric, kwargs, extra=0)
        else:
            raise ValueError(f"mode {repr(mode)} is not supported")
    else:
        # User-supplied callable; it is invoked as
        # mode(vector, (before, after), iaxis, kwargs, None).
        pad_from_function(ans, pw, mode, kwargs)

    return ans

#############
# Searching #
#############

def nonzero(a):
    """
    Return the indices of the truthy elements of ``a`` as a tuple of int
    arrays, one per dimension.

    A 0-dimensional input yields a 1-tuple: ``[0]`` if the value is truthy,
    otherwise an empty array.
    """
    a = asarray(a)

    if staticlen(a.shape) == 0:
        # Note: this is technically deprecated behavior
        truthy = bool(a.item())
        hits = empty(1 if truthy else 0, int)
        if truthy:
            hits.data[0] = 0
        return (hits,)
    else:
        # First pass: count the hits so the outputs can be sized exactly.
        count = 0
        if a._is_contig:
            buf = a.data
            for pos in range(a.size):
                if bool(buf[pos]):
                    count += 1
        else:
            for idx in util.multirange(a.shape):
                if bool(a._ptr(idx)[0]):
                    count += 1

        out = tuple(empty(count, int) for _ in a.shape)
        row = 0

        # Second pass: record the multi-index of every hit.
        for idx in util.multirange(a.shape):
            if bool(a._ptr(idx)[0]):
                for d in staticrange(staticlen(out)):
                    out[d].data[row] = idx[d]
                row += 1

        return out

def flatnonzero(a):
    """
    Return a 1-d ``int`` array with the flat positions of the truthy
    elements of ``a``.

    Inputs with at most one dimension are delegated to ``nonzero``. For
    higher dimensions the flat position is the running count of elements
    in the order ``util.multirange`` visits them.
    """
    a = asarray(a)

    if staticlen(a.shape) <= 1:
        return nonzero(a)[0]
    else:
        total = a.size
        c_contig, f_contig = a._contig

        # Counting does not depend on visiting order, so either contiguous
        # layout can be scanned straight through the data buffer.
        count = 0
        if c_contig or f_contig:
            buf = a.data
            for pos in range(total):
                if bool(buf[pos]):
                    count += 1
        else:
            for coord in util.multirange(a.shape):
                if bool(a._ptr(coord)[0]):
                    count += 1

        ans = empty(count, int)
        out_pos = 0

        if c_contig:
            # The buffer offset is used directly as the flat index.
            buf = a.data
            for pos in range(total):
                if bool(buf[pos]):
                    ans.data[out_pos] = pos
                    out_pos += 1
        else:
            # Track the flat index by hand while walking the coordinates.
            flat = 0
            for coord in util.multirange(a.shape):
                if bool(a._ptr(coord)[0]):
                    ans.data[out_pos] = flat
                    out_pos += 1
                flat += 1

        return ans

def argwhere(a):
    """
    Return an ``int`` array of shape ``(num_true, a.ndim)`` whose rows are
    the coordinates of the truthy elements of ``a``.

    A 0-d input produces an array of shape ``(1, 0)`` when its value is
    truthy and ``(0, 0)`` otherwise.
    """
    a = asarray(a)

    if staticlen(a.shape) == 0:
        rows = 1 if bool(a.item()) else 0
        return empty((rows, 0), int)
    else:
        # Pass 1: count the truthy elements to size the result.
        count = 0
        if a._is_contig:
            buf = a.data
            for pos in range(a.size):
                if bool(buf[pos]):
                    count += 1
        else:
            for coord in util.multirange(a.shape):
                if bool(a._ptr(coord)[0]):
                    count += 1

        ans = empty((count, a.ndim), int)

        # Pass 2: write one row of coordinates per truthy element.
        row = 0
        for coord in util.multirange(a.shape):
            if not bool(a._ptr(coord)[0]):
                continue
            for col in range(a.ndim):
                ans._ptr((row, col))[0] = coord[col]
            row += 1

        return ans

def where(condition, x, y):
    """
    Choose elementwise between ``x`` and ``y``: take from ``x`` where
    ``condition`` is truthy and from ``y`` elsewhere.

    The result dtype is the coercion of the dtypes of ``x`` and ``y``.
    When ``condition._contig_match`` holds for both operands, the three
    data buffers are walked in lockstep; otherwise the operands are
    broadcast against each other.
    """
    condition = asarray(condition)
    x = asarray(x)
    y = asarray(y)
    T = type(util.coerce(x.dtype, y.dtype))

    if not (condition._contig_match(x) and condition._contig_match(y)):
        # General path: index every operand with broadcasting.
        out_shape = broadcast_shapes(condition.shape, x.shape, y.shape)
        ans = empty(out_shape, dtype=T)

        for coord in util.multirange(out_shape):
            dst = ans._ptr(coord)
            if condition._ptr(coord, broadcast=True)[0]:
                dst[0] = util.cast(x._ptr(coord, broadcast=True)[0], T)
            else:
                dst[0] = util.cast(y._ptr(coord, broadcast=True)[0], T)

        return ans
    else:
        # Fast path: allocate the result in the condition's order and scan
        # the raw buffers linearly.
        c_contig, f_contig = condition._contig
        ans = empty(condition.shape, dtype=T, order=('C' if c_contig else 'F'))

        cond_buf = condition.data
        x_buf = x.data
        y_buf = y.data
        dst = ans.data

        for pos in range(condition.size):
            if cond_buf[pos]:
                dst[pos] = util.cast(x_buf[pos], T)
            else:
                dst[pos] = util.cast(y_buf[pos], T)

        return ans

@overload
def where(condition):
    """Single-argument form: the same result as ``nonzero`` on ``condition``."""
    cond = asarray(condition)
    return nonzero(cond)

def extract(condition, arr):
    """
    Return the elements of ``arr`` at the flat positions where ``condition``
    is truthy.

    Both inputs are flattened with ``ravel`` first. The result is a 1-d
    array of ``arr.dtype`` with one entry per truthy flag.

    Raises:
        IndexError: if ``condition`` is truthy at a flat position that does
            not exist in ``arr``.
    """
    condition = ravel(asarray(condition))
    arr = ravel(asarray(arr))
    n = arr.size

    # Validate every selected position, not just the number of selections.
    # Comparing only the count against `n` would let a flag past the end of
    # `arr` slip through and leave trailing output slots uninitialised.
    num_true = 0
    for i in range(condition.size):
        if condition.data[i]:
            if i >= n:
                raise IndexError(f"index {i} is out of bounds for axis 0 with size {n}")
            num_true += 1

    ans = empty(num_true, arr.dtype)
    k = 0

    # All truthy positions are known to be < n here, so scanning the common
    # prefix visits every one of them.
    for i in range(min(condition.size, n)):
        if condition.data[i]:
            ans.data[k] = arr.data[i]
            k += 1

    return ans

def _less_than(a, b):
    """Return ``a < b`` after casting both operands to their coerced common type."""
    C = type(util.coerce(type(a), type(b)))
    lhs = util.cast(a, C)
    rhs = util.cast(b, C)
    return lhs < rhs

def _bisect_left(p: Ptr[T], s: int, n: int, x: X, lo: int, hi: int, T: type, X: type):
    """
    Bisect ``[lo, hi)`` for the first position whose element is not less
    than ``x``. Elements are read from ``p`` using a byte stride of ``s``.
    Returns the final ``(lo, hi)`` pair.
    """
    base = p.as_byte()
    while lo < hi:
        mid = (lo + hi) >> 1
        probe = Ptr[T](base + (mid * s))[0]

        if not _less_than(probe, x):
            hi = mid
        else:
            lo = mid + 1
    return lo, hi

def _bisect_right(p: Ptr[T], s: int, n: int, x: X, lo: int, hi: int, T: type, X: type):
    """
    Bisect ``[lo, hi)`` for the first position whose element is greater
    than ``x``. Elements are read from ``p`` using a byte stride of ``s``.
    Returns the final ``(lo, hi)`` pair.
    """
    base = p.as_byte()
    while lo < hi:
        mid = (lo + hi) >> 1
        probe = Ptr[T](base + (mid * s))[0]

        if not _less_than(x, probe):
            lo = mid + 1
        else:
            hi = mid
    return lo, hi

def _bisect_left_perm(p: Ptr[T], s: int, n: int, x: X, perm: Ptr[int],
                      sp: int, lo: int, hi: int, T: type, X: type):
    """
    Left bisection over ``[lo, hi)`` through a permutation: position ``mid``
    is first mapped via ``perm`` (byte stride ``sp``) and the mapped index
    is then read from ``p`` (byte stride ``s``). Returns ``(lo, hi)``.
    """
    data_base = p.as_byte()
    perm_base = perm.as_byte()
    while lo < hi:
        mid = (lo + hi) >> 1
        mapped = Ptr[int](perm_base + (mid * sp))[0]
        probe = Ptr[T](data_base + (mapped * s))[0]

        if not _less_than(probe, x):
            hi = mid
        else:
            lo = mid + 1
    return lo, hi

def _bisect_right_perm(p: Ptr[T], s: int, n: int, x: X, perm: Ptr[int],
                       sp: int, lo: int, hi: int, T: type, X: type):
    """
    Right bisection over ``[lo, hi)`` through a permutation: position ``mid``
    is first mapped via ``perm`` (byte stride ``sp``) and the mapped index
    is then read from ``p`` (byte stride ``s``). Returns ``(lo, hi)``.
    """
    data_base = p.as_byte()
    perm_base = perm.as_byte()
    while lo < hi:
        mid = (lo + hi) >> 1
        mapped = Ptr[int](perm_base + (mid * sp))[0]
        probe = Ptr[T](data_base + (mapped * s))[0]

        if not _less_than(x, probe):
            lo = mid + 1
        else:
            hi = mid
    return lo, hi

def searchsorted(a, v, side: str = 'left', sorter = None):
    """
    Find, by bisection, the positions in the 1-d array ``a`` at which the
    values ``v`` would be inserted.

    ``side`` selects the left or right bisection helper. When ``sorter`` is
    given it must be a 1-d ``int`` array of the same size as ``a`` whose
    entries are valid indices into ``a``; lookups then go through it.

    Returns a single index when ``v`` is 0-d, otherwise an ``int`` array
    with the shape of ``v``.

    Raises:
        ValueError: if ``side`` is not 'left' or 'right', if
            ``sorter.size`` differs from ``a.size``, or if ``sorter``
            contains an index outside ``[0, a.size)``.
    """
    a = asarray(a)
    v = asarray(v)

    if staticlen(a.shape) != 1:
        compile_error("searchsorted requires 1-dimensional input array")

    n = a.size
    s = a.strides[0]

    left = False
    if side == 'left':
        left = True
    elif side == 'right':
        left = False
    else:
        raise ValueError(f"side must be 'left' or 'right' (got {repr(side)})")

    if sorter is None:
        perm = None
        sp = 0
    else:
        perm = asarray(sorter)

        if staticlen(perm.shape) != 1:
            compile_error("sorter argument must be 1-dimensional")

        if perm.dtype is not int:
            compile_error("sorter must only contain integers")

        if perm.size != n:
            raise ValueError("sorter.size must equal a.size")

        # Validate up front: the bisection helpers dereference these
        # indices through raw pointers.
        for sorter_idx in perm:
            if sorter_idx < 0 or sorter_idx >= n:
                raise ValueError("Sorter index out of range.")

        sp = perm.strides[0]

    # copy for large needles to improve cache performance
    if v.size > a.size:
        a = ascontiguousarray(a)
        s = a.strides[0]

    if staticlen(v.shape) == 0:
        # Scalar needle: one bisection over the whole range.
        x = v.item()
        if sorter is None:
            if left:
                return _bisect_left(a.data, s, n, x, lo=0, hi=n)[0]
            else:
                return _bisect_right(a.data, s, n, x, lo=0, hi=n)[0]
        else:
            if left:
                return _bisect_left_perm(a.data, s, n, x, perm.data, sp, lo=0, hi=n)[0]
            else:
                return _bisect_right_perm(a.data, s, n, x, perm.data, sp, lo=0, hi=n)[0]
    else:
        ans = empty(v.shape, int)

        if v.size == 0:
            return ans

        last_x = v._ptr((0,) * staticlen(v.shape))[0]
        lo = 0
        hi = n

        for idx in util.multirange(v.shape):
            x = v._ptr(idx)[0]

            # Reuse one bound from the previous needle: a larger needle
            # keeps the previous `lo` and only resets `hi`; otherwise `lo`
            # is reset and `hi` is widened by one (capped at n).
            if last_x < x:
                hi = n
            else:
                lo = 0
                hi = hi + 1 if hi < n else n

            last_x = x

            if sorter is None:
                if left:
                    lo, hi = _bisect_left(a.data, s, n, x, lo=lo, hi=hi)
                else:
                    lo, hi = _bisect_right(a.data, s, n, x, lo=lo, hi=hi)
            else:
                if left:
                    lo, hi = _bisect_left_perm(a.data, s, n, x, perm.data, sp, lo=lo, hi=hi)
                else:
                    lo, hi = _bisect_right_perm(a.data, s, n, x, perm.data, sp, lo=lo, hi=hi)

            ans._ptr(idx)[0] = lo

        return ans

############
# Indexing #
############

def take(a, indices, axis = None, out = None, mode: str = 'raise'):
    """
    Gather slices of ``a`` along ``axis`` at the positions in ``indices``.

    With ``axis=None`` the input is raveled and the gather runs along axis 0.
    The result shape is ``a.shape`` with the ``axis`` entry replaced by
    ``indices.shape``. ``mode`` controls out-of-range indices: 'raise'
    goes through ``util.normalize_index``, 'wrap' adds or subtracts the axis
    length until the index is in range, and 'clip' clamps it to
    ``[0, axis_length - 1]``.

    Returns ``res.item()`` when the result is 0-d, otherwise the result
    array.

    Raises:
        ValueError: if ``out`` has the wrong shape or ``mode`` is unknown.
        IndexError: if the taken axis is empty but the result is not.
    """
    a = atleast_1d(asarray(a, order='C'))

    if axis is None:
        return take(a.ravel(), indices, axis=0, out=out, mode=mode)

    indices = asarray(indices, order='C')

    axis = util.normalize_axis_index(axis, a.ndim)
    # n: product of dims before `axis`; m: number of indices;
    # chunk: product of dims after `axis` (converted to bytes further down).
    n = 1
    m = 1
    chunk = 1
    nd: Static[int] = staticlen(a.shape) + staticlen(indices.shape) - 1
    shape = (0,) * nd
    ashape0 = a.shape
    ishape0 = indices.shape
    # View the shape tuples through int pointers so they can be indexed
    # with runtime values; `shape` is filled in place through `pshape`.
    ashape = Ptr[int](__ptr__(ashape0).as_byte())
    ishape = Ptr[int](__ptr__(ishape0).as_byte())
    pshape = Ptr[int](__ptr__(shape).as_byte())

    for i in staticrange(nd):
        if i < axis:
            pshape[i] = ashape[i]
            n *= pshape[i]
        else:
            if i < axis + indices.ndim:
                pshape[i] = ishape[i - axis]
                m *= pshape[i]
            else:
                pshape[i] = ashape[i - indices.ndim + 1]
                chunk *= pshape[i]

    if out is None:
        res = empty(shape, dtype=a.dtype)
    elif isinstance(out, ndarray):
        if staticlen(out.shape) != staticlen(shape) or out.dtype is not a.dtype:
            compile_error("output array does not match result of ndarray.take")

        if out.shape != shape:
            raise ValueError("output array does not match result of ndarray.take")

        # NOTE(review): if `asarray(..., order='C')` copies a non-C-contiguous
        # `out`, the gathered data would land in the copy rather than in the
        # caller's array — confirm against asarray's behavior.
        res = asarray(out, order='C')
    else:
        compile_error("output must be an array")

    max_item = a.shape[axis]
    nelem = chunk
    itemsize = res.itemsize
    # From here on `chunk` is the size in bytes of one slice after `axis`.
    chunk *= itemsize
    src = a.data.as_byte()
    dest = res.data.as_byte()
    indices_data = indices.data

    if max_item == 0 and res.size != 0:
        raise IndexError("cannot do a non-empty take from an empty axes.")

    # Each branch walks the n leading blocks; within a block it copies one
    # chunk per index, then advances `src` past the whole taken axis.
    if mode == 'raise':
        for i in range(n):
            for j in range(m):
                tmp = indices_data[j]
                tmp = util.normalize_index(tmp, axis, max_item)
                tmp_src = src + tmp * chunk
                str.memcpy(dest, tmp_src, chunk)
                dest += chunk
            src += chunk * max_item
    elif mode == 'wrap':
        for i in range(n):
            for j in range(m):
                tmp = indices_data[j]
                if tmp < 0:
                    while tmp < 0:
                        tmp += max_item
                elif tmp >= max_item:
                    while tmp >= max_item:
                        tmp -= max_item

                tmp_src = src + tmp * chunk
                str.memcpy(dest, tmp_src, chunk)
                dest += chunk
            src += chunk * max_item
    elif mode == 'clip':
        for i in range(n):
            for j in range(m):
                tmp = indices_data[j]
                if tmp < 0:
                    tmp = 0
                elif tmp >= max_item:
                    tmp = max_item - 1

                tmp_src = src + tmp * chunk
                str.memcpy(dest, tmp_src, chunk)
                dest += chunk
            src += chunk * max_item
    else:
        raise ValueError(f"clipmode must be one of 'clip', 'raise', or 'wrap' (got {repr(mode)})")

    if res.ndim == 0:
        return res.item()
    else:
        return res

def indices(dimensions, dtype: type = int, sparse: Static[int] = False):
    """
    Build index grids for an array of shape ``dimensions``.

    With ``sparse`` set, returns a tuple of ``N`` arrays where the ``i``-th
    is ``arange(dimensions[i])`` reshaped to have length 1 in every
    dimension except ``i``. Otherwise returns a single array of shape
    ``(N,) + dimensions`` whose ``i``-th entry is assigned from that same
    reshaped range.
    """
    if not isinstance(dimensions, Tuple):
        compile_error("dimensions must be a tuple of integers")

    N: Static[int] = staticlen(dimensions)
    unit = (1,) * N

    if sparse:
        return tuple(
            arange(dimensions[i], dtype=dtype).reshape(
                unit[:i] + (dimensions[i],) + unit[i+1:]) for i in staticrange(N))

    grid = empty((N,) + dimensions, dtype=dtype)

    for i in staticrange(N):
        extent = dimensions[i]
        axis_shape = unit[:i] + (extent,) + unit[i+1:]
        grid[i] = arange(extent, dtype=dtype).reshape(axis_shape)

    return grid

def ix_(*args):
    """
    Turn ``N`` 1-d sequences into a tuple of ``N`` arrays, where the ``k``-th
    array has the ``k``-th sequence's entries along dimension ``k`` and
    length 1 along every other dimension. Boolean sequences are first
    replaced by the indices of their ``True`` entries.
    """
    def ix_one(new, nd: Static[int], k: Static[int]):
        # Reshape one argument so its data lies along dimension k of nd.
        new = asarray(new)
        if staticlen(new.shape) != 1:
            compile_error("Cross index must be 1 dimensional")

        if new.dtype is bool:
            newx = new.nonzero()[0]
        else:
            newx = new

        leading = (1,) * k
        trailing = (1,) * (nd - k - 1)
        return newx.reshape(leading + (newx.size,) + trailing)

    nd: Static[int] = staticlen(args)
    return tuple(ix_one(args[k], nd, k) for k in staticrange(nd))

def ravel_multi_index(multi_index, dims, mode: str = 'raise', order: str = 'C'):
    """
    Convert coordinates into flat indices for an array of shape ``dims``.

    ``multi_index`` may be a list of ints or a tuple of ints (a single
    coordinate, returning a single index), a tuple containing non-int
    entries (stacked with ``vstack``), or anything ``asarray`` turns into a
    2-d array with one row per dimension (returning a 1-d ``int`` array).
    An int ``dims`` is treated as a 1-tuple. ``mode`` selects how
    out-of-range coordinates are handled and ``order`` selects between
    ``util.coords_to_index`` ('C') and ``util.coords_to_findex`` ('F').

    Raises:
        ValueError: for a non-positive dimension, an unknown ``mode`` or
            ``order``, a ``multi_index`` whose length differs from
            ``len(dims)``, or (in 'raise' mode) an out-of-range coordinate.
    """
    def fix_index(idx: int, axis: int, dim: int, mode: str):
        # Bring one coordinate into [0, dim) according to `mode`; `axis`
        # is accepted but not used.
        if mode == 'raise':
            if idx < 0 or idx >= dim:
                raise ValueError("invalid entry in coordinates array")
        elif mode == 'wrap':
            if idx < 0:
                while idx < 0:
                    idx += dim
            elif idx >= dim:
                while idx >= dim:
                    idx -= dim
        elif mode == 'clip':
            if idx < 0:
                idx = 0
            elif idx >= dim:
                idx = dim - 1

        return idx

    def gather_non_ints(x):
        # Return the tuple of entries of `x` that are not ints; an empty
        # result means `x` is a plain tuple of integer coordinates.
        if staticlen(x) == 0:
            return ()

        i = x[0]
        rest = gather_non_ints(x[1:])

        if isinstance(i, int):
            return rest
        else:
            return (i,) + rest

    if isinstance(dims, int):
        return ravel_multi_index(multi_index, (dims,), mode=mode, order=order)

    for d in dims:
        if d <= 0:
            raise ValueError("dimensions must be positive")

    if mode not in ('raise', 'wrap', 'clip'):
        raise ValueError(f"clipmode must be one of 'clip', 'raise', or 'wrap' (got {repr(mode)})")

    corder = True
    if order == 'C':
        corder = True
    elif order == 'F':
        corder = False
    else:
        raise ValueError("only 'C' or 'F' order is permitted")

    N: Static[int] = staticlen(dims)

    if isinstance(multi_index, List[int]):
        # A list of ints is a single coordinate; its length is only known
        # at run time.
        if len(multi_index) != N:
            raise ValueError(f"parameter multi_index must be a sequence of length {N}")

        idx = tuple(fix_index(multi_index[j], j, dims[j], mode) for j in staticrange(N))
        return util.coords_to_index(idx, dims) if corder else util.coords_to_findex(idx, dims)
    elif isinstance(multi_index, Tuple):
        if staticlen(gather_non_ints(multi_index)) == 0:
            # A tuple of ints is a single coordinate, checked at compile time.
            if staticlen(multi_index) != staticlen(dims):
                compile_error("parameter multi_index does not match dims in size")

            idx = tuple(fix_index(multi_index[j], j, dims[j], mode) for j in staticrange(N))
            return util.coords_to_index(idx, dims) if corder else util.coords_to_findex(idx, dims)

        midx = vstack(multi_index)
    else:
        midx = asarray(multi_index)

    if staticlen(midx.shape) != 2:
        compile_error("multi_index must be 2 dimensional")

    if len(multi_index) != N:
        raise ValueError(f"parameter multi_index must be a sequence of length {N}")

    # One output per column of midx: column i holds the N coordinates of
    # the i-th point.
    ans = empty(midx.shape[1], int)

    for i in range(ans.size):
        idx = tuple(midx._ptr((j, i))[0] for j in staticrange(N))
        idx = tuple(fix_index(idx[j], j, dims[j], mode) for j in staticrange(N))
        res = util.coords_to_index(idx, dims) if corder else util.coords_to_findex(idx, dims)
        ans.data[i] = res

    return ans

def unravel_index(indices, shape, order: str = 'C'):
    """
    Convert flat indices into coordinates for an array of shape ``shape``.

    An int ``shape`` is treated as a 1-tuple. A single int index returns a
    coordinate tuple; an array of indices returns a tuple with one ``int``
    array per dimension, each shaped like ``indices``. ``order`` selects
    between ``util.index_to_coords`` ('C') and ``util.index_to_fcoords``
    ('F').

    Raises:
        ValueError: for an unknown ``order``, a negative dimension, or an
            index outside ``[0, product of shape)``.
    """
    def check(idx: int, n: int):
        # Reject flat indices outside [0, n).
        if not (idx >= 0 and idx < n):
            raise ValueError(f"index {idx} is out of bounds for array with size {n}")
        return idx

    if isinstance(shape, int):
        return unravel_index(indices, (shape,), order=order)

    if order != 'C' and order != 'F':
        raise ValueError("only 'C' or 'F' order is permitted")
    corder = order == 'C'

    N: Static[int] = staticlen(shape)
    if N == 0:
        if not isinstance(indices, int):
            compile_error("multiple indices are not supported for 0d arrays")

        check(indices, 1)
        return ()

    # Total element count, validating each extent along the way.
    total = 1
    for extent in shape:
        if extent < 0:
            raise ValueError("dimensions must be non-negative")
        total *= extent

    if isinstance(indices, int):
        check(indices, total)
        if corder:
            return util.index_to_coords(indices, shape)
        else:
            return util.index_to_fcoords(indices, shape)

    indices = asarray(indices)
    ans = tuple(empty(indices.shape, int) for _ in shape)

    for pos in util.multirange(indices.shape):
        flat = indices._ptr(pos)[0]
        check(flat, total)
        coords = util.index_to_coords(flat, shape) if corder else util.index_to_fcoords(flat, shape)

        for dim in staticrange(N):
            ans[dim]._ptr(pos)[0] = coords[dim]

    return ans

def diag_indices(n: int, ndim: int):
    """
    Return a list holding ``ndim`` references to ``arange(n)``. This is the
    variant for an ``ndim`` that is only known at run time.
    """
    diag = arange(n)
    ans = []
    for _ in range(ndim):
        ans.append(diag)
    return ans

@overload
def diag_indices(n: int, ndim: Static[int] = 2):
    """
    Return a tuple holding ``ndim`` references to ``arange(n)``. This is the
    variant for a compile-time ``ndim``.
    """
    return (arange(n),) * ndim

def diag_indices_from(arr):
    """
    Return ``diag_indices`` for an array that has at least two dimensions,
    all of equal length.

    Raises:
        ValueError: if the dimensions of ``arr`` are not all equal.
    """
    arr = asarray(arr)
    dims = arr.shape
    ndim: Static[int] = staticlen(dims)

    if ndim < 2:
        compile_error("input array must be at least 2-d")

    side = dims[0]
    for i in staticrange(1, ndim):
        if dims[i] != side:
            raise ValueError("All dimensions of input must be of equal length")

    return diag_indices(side, ndim)

def mask_indices(n: int, mask_func, k = 0):
    """
    Apply ``mask_func(m, k)`` to an ``(n, n)`` array of integer ones and
    return the ``nonzero`` indices of the entries that are not zero.
    """
    masked = mask_func(ones((n, n), int), k)
    return nonzero(masked != 0)

def tril_indices(n: int, k: int = 0, m: Optional[int] = None):
    """
    Return the index arrays of the positions selected by the boolean mask
    ``tri(n, m, k=k)``, obtained by boolean-indexing each sparse index grid
    broadcast to the mask's shape.
    """
    mask = tri(n, m, k=k, dtype=bool)
    return tuple(broadcast_to(grid, mask.shape)[mask]
                 for grid in indices(mask.shape, sparse=True))

def tril_indices_from(arr, k: int = 0):
    """Return ``tril_indices`` sized from the two dimensions of the 2-d ``arr``."""
    arr = asarray(arr)
    if staticlen(arr.shape) != 2:
        compile_error("input array must be 2-d")
    rows = arr.shape[-2]
    cols = arr.shape[-1]
    return tril_indices(rows, k=k, m=cols)

def triu_indices(n: int, k: int = 0, m: Optional[int] = None):
    """
    Return the index arrays of the positions selected by the complement of
    the boolean mask ``tri(n, m, k=k - 1)``, obtained by boolean-indexing
    each sparse index grid broadcast to the mask's shape.
    """
    mask = ~tri(n, m, k=k - 1, dtype=bool)
    return tuple(broadcast_to(grid, mask.shape)[mask]
                 for grid in indices(mask.shape, sparse=True))

def triu_indices_from(arr, k: int = 0):
    """Return ``triu_indices`` sized from the two dimensions of the 2-d ``arr``."""
    arr = asarray(arr)
    if staticlen(arr.shape) != 2:
        compile_error("input array must be 2-d")
    rows = arr.shape[-2]
    cols = arr.shape[-1]
    return triu_indices(rows, k=k, m=cols)

def take_along_axis(arr, indices, axis):
    """
    Gather values from ``arr`` using ``indices`` along ``axis``.

    With ``axis=None``, ``indices`` must be 1-d and is applied to the flat
    view of ``arr`` with bounds checking. Otherwise ``arr`` and ``indices``
    must have the same number of dimensions; their remaining dimensions are
    broadcast together and each 1-d lane along ``axis`` is indexed
    independently.
    """
    def dim_mismatch():
        compile_error("`indices` and `arr` must have the same number of dimensions")

    arr = asarray(arr)
    indices = asarray(indices)

    if indices.dtype is not int:
        compile_error("`indices` must be an integer array")

    if axis is None:
        if staticlen(indices.shape) != 1:
            dim_mismatch()

        count = indices.size
        out = empty(count, arr.dtype)
        for pos in range(count):
            flat = indices._ptr((pos,))[0]
            out.data[pos] = arr._get_flat(flat, check=True)

        return out

    if staticlen(arr.shape) != staticlen(indices.shape):
        dim_mismatch()

    axis = util.normalize_axis_index(axis, arr.ndim)
    src_len = arr.shape[axis]
    lane_len = indices.shape[axis]

    # Broadcast every dimension except `axis`, then put the lane length back.
    outer_shape = broadcast_shapes(util.tuple_delete(arr.shape, axis),
                                   util.tuple_delete(indices.shape, axis))
    out = empty(util.tuple_insert(outer_shape, axis, lane_len), arr.dtype)

    src_stride = arr.strides[axis]
    idx_stride = indices.strides[axis]
    dst_stride = out.strides[axis]

    for outer in util.multirange(outer_shape):
        # 1-d views over the lanes that start at coordinate 0 along `axis`.
        origin = util.tuple_insert(outer, axis, 0)
        src_lane = ndarray((src_len,), (src_stride,), arr._ptr(origin, broadcast=True))
        idx_lane = ndarray((lane_len,), (idx_stride,), indices._ptr(origin, broadcast=True))
        dst_lane = ndarray((lane_len,), (dst_stride,), out._ptr(origin))

        for j in range(lane_len):
            picked = idx_lane._ptr((j,))[0]
            dst_lane._ptr((j,))[0] = src_lane[picked]

    return out

def choose(a, choices, out = None, mode: str = 'raise'):
    """
    Build an array by picking, at each position, the element of
    ``choices[a[position]]`` at that position.

    ``a`` must be an integer array. ``choices`` may be a tuple (every entry
    is converted to an array and all are broadcast together with ``a``), a
    list, or an array; for the latter two the result shape is the broadcast
    of ``a.shape`` with the shape of the first choice only. The result
    dtype is the dtype of the first choice. ``mode`` selects how values of
    ``a`` outside ``[0, len(choices))`` are handled.

    Raises:
        ValueError: for an unknown ``mode``, an empty ``choices``, an
            ``out`` of the wrong shape, or (in 'raise' mode) an
            out-of-range entry in ``a``.
    """
    MODE_RAISE: Static[int] = 0
    MODE_WRAP : Static[int] = 1
    MODE_CLIP : Static[int] = 2

    a = asarray(a)

    if a.dtype is not int:
        compile_error("first argument must be an integer array")

    xmode = 0
    if mode == 'raise':
        xmode = MODE_RAISE
    elif mode == 'wrap':
        xmode = MODE_WRAP
    elif mode == 'clip':
        xmode = MODE_CLIP
    else:
        raise ValueError(f"clipmode must be one of 'clip', 'raise', or 'wrap' (got {repr(mode)})")

    n = len(choices)
    if n == 0:
        raise ValueError("0-length sequence.")

    if isinstance(choices, Tuple):
        xchoices1 = tuple(asarray(c) for c in choices)
        bshape = broadcast_shapes(*(tuple(c.shape for c in xchoices1) + (a.shape,)))
        xchoices = tuple(broadcast_to(c, bshape) for c in xchoices1)
    elif isinstance(choices, List):
        if not isinstance(choices[0], ndarray):
            xchoices = [asarray(c) for c in choices]
        else:
            xchoices = choices
        # Only the first choice's shape takes part in the broadcast here.
        bshape = broadcast_shapes(xchoices[0].shape, a.shape)
    elif isinstance(choices, ndarray):
        xchoices = choices
        bshape = broadcast_shapes(xchoices[0].shape, a.shape)
    else:
        compile_error("'choices' must be a list, tuple or array")

    dtype = xchoices[0].dtype

    if out is None:
        ans = empty(bshape, dtype)
    else:
        if not isinstance(out, ndarray):
            compile_error("'out' must be an array")

        if out.dtype is not dtype:
            compile_error("'out' has incorrect dtype")

        if staticlen(bshape) != staticlen(out.shape):
            compile_error("'out' has incorrect number of dimensions")

        if bshape != out.shape:
            raise ValueError("'out' has incorrect shape")

        ans = out

    for idx in util.multirange(bshape):
        mi = a._ptr(idx, broadcast=True)[0]

        # Bring the selector into [0, n) according to the requested mode.
        if mi < 0 or mi >= n:
            if xmode == MODE_RAISE:
                raise ValueError("invalid entry in choice array")
            elif xmode == MODE_WRAP:
                if mi < 0:
                    while mi < 0:
                        mi += n;
                else:
                    while mi >= n:
                        mi -= n
            elif xmode == MODE_CLIP:
                if mi < 0:
                    mi = 0
                elif mi >= n:
                    mi = n - 1

        choice = xchoices[mi]._ptr(idx, broadcast=True)[0]
        ans._ptr(idx)[0] = choice

    return ans

def compress(condition, a, axis = None, out = None):
    """
    Select the entries of ``a`` for which the 1-d ``condition`` is truthy.

    With ``axis=None`` the selection runs over the elements of ``a`` in
    ``util.multirange`` order and the result is 1-d. Otherwise whole slices
    along ``axis`` are kept and the result has ``a.shape`` with the ``axis``
    entry replaced by the number of truthy flags. ``condition`` may be
    shorter than the selected extent. When ``out`` is supplied, results are
    written into it (cast to its dtype) and it is returned.

    Raises:
        IndexError: if ``condition`` is truthy at a position beyond the
            selected extent.
        ValueError: if ``out`` has the wrong length or shape.
    """
    condition = asarray(condition)
    a = asarray(a)

    if staticlen(condition.shape) != 1:
        compile_error("condition must be a 1-d array")

    if axis is None:
        # Count the selections, rejecting any truthy flag past the end of a.
        num_true = 0
        i = 0
        n = a.size
        for c in condition:
            if c:
                if i >= n:
                    raise IndexError(f"index {i} is out of bounds for axis 0 with size {n}")
                num_true += 1
            i += 1

        if out is None:
            ans = empty(num_true, a.dtype)
            ans_cc = True
        else:
            if not isinstance(out, ndarray):
                compile_error("'out' must be an array")

            if staticlen(out.shape) != 1:
                compile_error("'out' must be 1-dimensional")

            if out.size != num_true:
                raise ValueError("'out' has incorrect length")

            ans = out
            ans_cc = ans._contig[0]

        a_cc = a._contig[0]
        k = 0
        if a_cc and ans_cc:
            # Both buffers can be indexed linearly.
            for i in range(min(condition.size, n)):
                if condition._ptr((i,))[0]:
                    ans.data[k] = util.cast(a.data[i], ans.dtype)
                    k += 1
        else:
            # Walk a by coordinates while counting the flat position in i.
            i = 0
            for idx in util.multirange(a.shape):
                if i >= condition.size:
                    break

                if condition[i]:
                    ans._ptr((k,))[0] = util.cast(a._ptr(idx)[0], ans.dtype)
                    k += 1

                i += 1

        return ans

    axis = util.normalize_axis_index(axis, a.ndim)
    num_true = 0
    i = 0
    n = a.shape[axis]
    for c in condition:
        if c:
            if i >= n:
                raise IndexError(f"index {i} is out of bounds for axis {axis} with size {n}")
            num_true += 1
        i += 1

    ans_shape = util.tuple_set(a.shape, axis, num_true)

    if out is None:
        ans = empty(ans_shape, a.dtype)
    else:
        if not isinstance(out, ndarray):
            compile_error("'out' must be an array")

        if staticlen(out.shape) != staticlen(ans_shape):
            compile_error("'out' has incorrect number of dimensions")

        if out.shape != ans_shape:
            raise ValueError("'out' has incorrect shape")

        ans = out

    sub_shape = util.tuple_delete(ans_shape, axis)
    k = 0

    # Copy slice i of a into slice k of the result for every truthy flag.
    for i in range(min(condition.size, n)):
        if condition._ptr((i,))[0]:
            for idx in util.multirange(sub_shape):
                idx1 = util.tuple_insert(idx, axis, i)
                idx2 = util.tuple_insert(idx, axis, k)

                p = a._ptr(idx1)
                q = ans._ptr(idx2)
                q[0] = util.cast(p[0], ans.dtype)
            k += 1

    return ans

def diagonal(a, offset: int = 0, axis1: int = 0, axis2: int = 1):
    """
    Return a view of the diagonals of ``a`` taken over the 2-d planes
    spanned by ``axis1`` and ``axis2``.

    ``axis1`` acts as the row axis and ``axis2`` as the column axis: a
    positive ``offset`` starts the diagonal ``offset`` steps along
    ``axis2``, a negative one ``-offset`` steps along ``axis1``. The
    result drops both axes and appends a trailing dimension holding the
    diagonal. It shares ``a``'s data buffer.

    Raises:
        ValueError: if ``axis1`` and ``axis2`` refer to the same axis.
    """
    a = asarray(a)
    shape = a.shape
    strides = a.strides
    ndim: Static[int] = staticlen(shape)

    if ndim < 2:
        compile_error("diag requires an array of at least two dimensions")

    axis1 = util.normalize_axis_index(axis1, ndim)
    axis2 = util.normalize_axis_index(axis2, ndim)

    if axis1 == axis2:
        raise ValueError("axis1 and axis2 cannot be the same")

    # Read extents and strides by role, NOT by position: swapping the axes
    # here would apply `offset` to the wrong axis when axis1 > axis2.
    dim1 = shape[axis1]
    dim2 = shape[axis2]
    stride1 = strides[axis1]
    stride2 = strides[axis2]
    data = a.data

    if offset >= 0:
        offset_stride = stride2
        dim2 -= offset
    else:
        offset = -offset
        offset_stride = stride1
        dim1 -= offset

    diag_size = dim2 if dim2 < dim1 else dim1
    if diag_size < 0:
        diag_size = 0
    else:
        # Only move the base pointer when the diagonal lies inside the array.
        data = Ptr[a.dtype](data.as_byte() + (offset * offset_stride))

    # Ordering only matters for deletion: remove the higher axis first so
    # the lower axis index is still valid for the second delete.
    lo_axis = min(axis1, axis2)
    hi_axis = max(axis1, axis2)

    ret_shape = util.tuple_delete(util.tuple_delete(shape, hi_axis), lo_axis) + (diag_size,)
    ret_strides = util.tuple_delete(util.tuple_delete(strides, hi_axis), lo_axis) + (stride1 + stride2,)

    return ndarray(ret_shape, ret_strides, data)

def select(condlist: List, choicelist: List, default = 0):
    """
    For each position, return the element of the first choice whose
    condition is truthy there, or ``default`` when none is.

    Conditions are broadcast together, choices are broadcast together, and
    the result shape is the broadcast of both with ``default``. The result
    dtype is the coercion of the first choice's dtype with the dtype of
    ``default``.

    Raises:
        ValueError: if the two lists differ in length or are empty.
    """
    n = len(condlist)
    if n != len(choicelist):
        raise ValueError(
            "list of cases must be same length as list of conditions")

    if n == 0:
        raise ValueError("select with an empty condition list is not possible")

    conds = [asarray(cond) for cond in condlist]

    if conds[0].dtype is not bool:
        compile_error("condlist entries should be boolean ndarray")

    # Common shape of all conditions, then broadcast each one to it.
    cond_bshape = conds[0].shape
    for cond in conds:
        cond_bshape = broadcast_shapes(cond_bshape, cond.shape)
    for i in range(n):
        conds[i] = broadcast_to(conds[i], cond_bshape)

    # Same treatment for the choices.
    picks = [asarray(choice) for choice in choicelist]
    choice_bshape = picks[0].shape
    for pick in picks:
        choice_bshape = broadcast_shapes(choice_bshape, pick.shape)
    for i in range(n):
        picks[i] = broadcast_to(picks[i], choice_bshape)

    fallback = asarray(default)
    ans_shape = broadcast_shapes(cond_bshape, choice_bshape, fallback.shape)
    dtype = type(util.coerce(picks[0].dtype, fallback.dtype))
    ans = empty(ans_shape, dtype)

    for coord in util.multirange(ans_shape):
        dst = ans._ptr(coord)

        # Index of the first condition that holds here, or -1 if none does.
        winner = -1
        for i in range(n):
            if conds[i]._ptr(coord, broadcast=True)[0]:
                winner = i
                break

        if winner >= 0:
            dst[0] = util.cast(picks[winner]._ptr(coord, broadcast=True)[0], dtype)
        else:
            dst[0] = util.cast(fallback._ptr(coord, broadcast=True)[0], dtype)

    return ans

def place(arr: ndarray, mask, vals):
    """
    Overwrite, in place, the elements of ``arr`` at the flat positions where
    ``mask`` is truthy, taking replacement values from ``vals`` in order and
    starting over from its first element whenever it is exhausted.

    Raises:
        ValueError: if ``mask`` and ``arr`` differ in size, or if ``vals``
            is empty while ``mask`` has a truthy entry.
    """
    mask = asarray(mask)
    vals = asarray(vals)

    total = arr.size
    nv = vals.size

    if mask.size != total:
        raise ValueError("place: mask and data must be the same size")

    if nv <= 0:
        if mask.any():
            raise ValueError("Cannot insert from an empty array!")
        return

    arr_cc = arr._contig[0]
    mask_cc = mask._contig[0]
    vals_cc = vals._contig[0]
    cursor = 0  # next position to read from vals

    if arr_cc and mask_cc and vals_cc:
        # All three buffers can be indexed linearly.
        for pos in range(total):
            if not mask.data[pos]:
                continue
            arr.data[pos] = util.cast(vals.data[cursor], arr.dtype)
            cursor += 1
            if cursor >= nv:
                cursor = 0
    else:
        for pos in range(total):
            if not mask._get_flat(pos, check=False):
                continue
            arr._set_flat(pos, vals._get_flat(cursor, check=False), check=False)
            cursor += 1
            if cursor >= nv:
                cursor = 0

def put(a: ndarray, ind, v, mode: str = 'raise'):
    """
    Write values from ``v`` into ``a`` at the flat positions listed in
    ``ind``, reusing ``v`` from its first element whenever it is exhausted.
    Nothing happens when ``ind`` or ``v`` is empty.

    ``mode`` selects how positions outside ``[0, a.size)`` are handled:
    'raise' rejects them, 'wrap' shifts them by multiples of ``a.size`` and
    'clip' clamps them.

    Raises:
        ValueError: for an unknown ``mode``, an out-of-range position in
            'raise' mode, or an empty ``a`` with 'wrap' or 'clip'.
    """
    def fix_index(idx: int, dim: int, mode: str):
        # Bring one flat position into [0, dim) according to `mode`.
        if idx >= 0 and idx < dim:
            return idx

        if mode == 'raise':
            raise ValueError("invalid entry in coordinates array")
        elif mode == 'wrap':
            while idx < 0:
                idx += dim
            while idx >= dim:
                idx -= dim
        elif mode == 'clip':
            idx = 0 if idx < 0 else dim - 1

        return idx

    if mode != 'raise' and mode != 'wrap' and mode != 'clip':
        raise ValueError(f"clipmode must be one of 'clip', 'raise', or 'wrap' (got {repr(mode)})")

    ind = asarray(ind)
    v = asarray(v)

    na = a.size
    ni = ind.size
    nv = v.size

    if ni == 0 or nv == 0:
        return

    if na == 0 and (mode == 'wrap' or mode == 'clip'):
        raise ValueError("empty array given to put")

    a_cc = a._contig[0]
    ind_cc = ind._contig[0]
    v_cc = v._contig[0]
    cursor = 0  # next position to read from v

    if a_cc and ind_cc and v_cc:
        # All three buffers can be indexed linearly.
        for pos in range(ni):
            target = fix_index(util.cast(ind.data[pos], int), na, mode)
            a.data[target] = util.cast(v.data[cursor], a.dtype)
            cursor += 1
            if cursor >= nv:
                cursor = 0
    else:
        for pos in range(ni):
            target = fix_index(util.cast(ind._get_flat(pos, check=False), int), na, mode)
            a._set_flat(target, v._get_flat(cursor, check=False), check=True)
            cursor += 1
            if cursor >= nv:
                cursor = 0

def put_along_axis(arr: ndarray, indices, values, axis):
    """
    Write ``values`` into ``arr`` at the positions given by ``indices``
    along ``axis``, modifying ``arr`` in place.

    With ``axis=None``, ``indices`` are bounds-checked flat positions into
    ``arr`` and ``values`` is reused from its start whenever it is
    exhausted. Otherwise ``indices`` and ``values`` are broadcast together,
    must end up with as many dimensions as ``arr``, and each 1-d lane along
    ``axis`` is written independently.
    """
    def dim_mismatch():
        compile_error("`indices` and `arr` must have the same number of dimensions")

    indices = asarray(indices)
    values = asarray(values)

    if indices.dtype is not int:
        compile_error("`indices` must be an integer array")

    if axis is None:
        nv = values.size
        cursor = 0  # next position to read from values

        for pos in range(indices.size):
            arr._set_flat(indices._get_flat(pos, check=False),
                          values._get_flat(cursor, check=False),
                          check=True)
            cursor += 1
            if cursor >= nv:
                cursor = 0

        return

    common = broadcast_shapes(indices.shape, values.shape)
    idx_b = broadcast_to(indices, common)
    val_b = broadcast_to(values, common)

    if staticlen(arr.shape) != staticlen(idx_b.shape):
        dim_mismatch()

    axis = util.normalize_axis_index(axis, arr.ndim)
    dst_len = arr.shape[axis]
    lane_len = idx_b.shape[axis]

    # Broadcast every dimension except `axis`.
    outer_shape = broadcast_shapes(util.tuple_delete(arr.shape, axis),
                                   util.tuple_delete(idx_b.shape, axis))

    dst_stride = arr.strides[axis]
    idx_stride = idx_b.strides[axis]
    val_stride = val_b.strides[axis]

    for outer in util.multirange(outer_shape):
        # 1-d views over the lanes that start at coordinate 0 along `axis`.
        origin = util.tuple_insert(outer, axis, 0)
        dst_lane = ndarray((dst_len,), (dst_stride,), arr._ptr(origin, broadcast=True))
        idx_lane = ndarray((lane_len,), (idx_stride,), idx_b._ptr(origin, broadcast=True))
        val_lane = ndarray((lane_len,), (val_stride,), val_b._ptr(origin, broadcast=True))

        for j in range(lane_len):
            target = idx_lane._ptr((j,))[0]
            dst_lane[target] = util.cast(val_lane._ptr((j,))[0], arr.dtype)

def putmask(a: ndarray, mask, values):
    """
    Set ``a``'s element at flat position ``i`` to ``values[i % values.size]``
    wherever ``mask`` is truthy at ``i``, modifying ``a`` in place.

    Unlike ``place``, the position in ``values`` advances on every element
    of ``a``, not only on the masked ones. Nothing happens when ``a`` or
    ``values`` is empty.

    Raises:
        ValueError: if ``mask`` and ``a`` differ in size.
    """
    mask = asarray(mask)
    values = asarray(values)
    na = a.size
    nv = values.size

    if na != mask.size:
        raise ValueError("putmask: mask and data must be the same size")

    if na == 0 or nv == 0:
        return

    cc1, _ = a._contig
    cc2, _ = mask._contig
    cc3, _ = values._contig
    j = 0

    # The fast path indexes all three raw buffers linearly, so it is only
    # valid when every one of them is C-contiguous (same guard as `place`
    # and `put`). If just one qualified, the strided operands would be read
    # or written at the wrong offsets.
    if cc1 and cc2 and cc3:
        for i in range(na):
            if mask.data[i]:
                # Cast explicitly so a `values` dtype that differs from
                # `a.dtype` is handled as in `place` and `put`.
                a.data[i] = util.cast(values.data[j], a.dtype)
            j += 1
            if j >= nv:
                j = 0
    else:
        for i in range(na):
            if mask._get_flat(i, check=False):
                a._set_flat(i, values._get_flat(j, check=False), check=False)
            j += 1
            if j >= nv:
                j = 0

def fill_diagonal(a: ndarray, val, wrap: bool = False):
    """
    Write ``val`` along the main diagonal of ``a`` in place.

    The diagonal is addressed through flat indices ``0, step, 2*step, ...``.
    ``val`` is converted with ``asarray`` and cycled when it has fewer
    elements than the diagonal.

    Parameters:
        a: array with at least two dimensions (compile-time requirement).
        val: scalar or array-like supplying the diagonal values.
        wrap: only consulted for 2-d input; when false the fill stops before
            flat index ``a.shape[1] ** 2``.

    Raises:
        ValueError: for more than two dimensions when the dimension lengths
            are not all equal.
    """
    if staticlen(a.shape) < 2:
        compile_error("array must be at least 2-d")

    val = asarray(val)
    num_vals = val.size
    stop = a.size

    # Nothing to write, or nothing to write with.
    if num_vals == 0 or stop == 0:
        return

    if staticlen(a.shape) == 2:
        ncols = a.shape[1]
        stride = ncols + 1
        if not wrap:
            stop = min(stop, ncols * ncols)
    else:
        first_len = a.shape[0]
        for d in range(1, a.ndim):
            if a.shape[d] != first_len:
                raise ValueError("All dimensions of input must be of equal length")
        # stride = 1 + sum of the running products of all but the last dimension
        stride = 1
        running = 1
        for extent in a.shape[:-1]:
            running *= extent
            stride += running

    src = 0
    for dst in range(0, stop, stride):
        a._set_flat(dst, val._get_flat(src, check=False), check=False)
        src = src + 1 if src + 1 < num_vals else 0

########
# Misc #
########

def _power_of_ten(n: int):
    """
    Return ``10.0 ** n`` as a float.

    Exponents below 9 are read from a small constant table; larger ones start
    from ``1e9`` and multiply by ten once per remaining step. No check is made
    for negative ``n``.
    """
    if n >= 9:
        result = 1e9
        for _ in range(n - 9):
            result *= 10.
        return result

    table = (1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8)
    # View the tuple as a float buffer so it can be indexed with a runtime value.
    return Ptr[float](__ptr__(table).as_byte())[n]

def _round_int(x: T, decimals: int, T: type):
    """
    Round an integer-like value ``x`` to ``decimals`` places.

    A non-negative ``decimals`` returns ``x`` unchanged. A negative value
    rounds to the corresponding power of ten: ``x`` is converted to float,
    divided by the power, passed through ``util.rint``, scaled back and
    converted to ``T`` via ``int``.
    """
    if decimals < 0:
        scale = _power_of_ten(-decimals)
        as_float = util.cast(x, float)
        return T(int(util.rint(as_float / scale) * scale))
    else:
        return x

def _round(x: T, decimals: int, T: type):
    """
    Round a single scalar ``x`` to ``decimals`` places.

    - complex / complex64: real and imaginary parts are rounded separately.
    - int, Int, UInt, byte, bool: delegated to ``_round_int``.
    - float / float32: scaled by a power of ten around ``util.rint``.

    Any other type is rejected at compile time.
    """
    if isinstance(x, complex) or isinstance(x, complex64):
        real_part = _round(x.real, decimals)
        imag_part = _round(x.imag, decimals)
        return T(real_part, imag_part)

    if (isinstance(x, int) or
        isinstance(x, Int) or
        isinstance(x, UInt) or
        isinstance(x, byte) or
        isinstance(x, bool)):
        return _round_int(x, decimals)

    if not isinstance(x, float) and not isinstance(x, float32):
        compile_error("don't know how to round type '" + T.__name__ + "'")

    if decimals > 0:
        scale = T(_power_of_ten(decimals))
        return util.rint(x * scale) / scale
    elif decimals < 0:
        scale = T(_power_of_ten(-decimals))
        return util.rint(x / scale) * scale
    else:
        return util.rint(x)

def round(a, decimals: int = 0, out = None):
    """
    Round every element of ``a`` to ``decimals`` places using ``_round``.

    Parameters:
        a: array-like input.
        decimals: number of places; negative values round to powers of ten.
        out: optional destination array. It must be an ndarray with the same
            dtype as ``a`` (both enforced at compile time) and the same shape.

    Returns:
        The rounded scalar for 0-d input; otherwise ``out`` when supplied, or
        a new array created with ``empty_like`` (order 'F' only if the input
        is F-contiguous and not C-contiguous).

    Raises:
        ValueError: if ``out`` has a different shape than ``a``.
    """
    a = asarray(a)

    if staticlen(a.shape) == 0:
        return _round(a.data[0], decimals)

    src_c, src_f = a._contig
    count = a.size

    if out is None:
        ans = empty_like(a, order=('F' if src_f and not src_c else 'C'))
    else:
        if not isinstance(out, ndarray):
            compile_error("output must be an array")

        if out.dtype is not a.dtype:
            compile_error("output has wrong type")

        if not util.tuple_equal(a.shape, out.shape):
            raise ValueError("invalid output shape")

        ans = out

    dst_c, dst_f = ans._contig
    same_layout = (src_c and dst_c) or (src_f and dst_f)

    if same_layout:
        # Source and destination share a contiguous layout: walk the raw buffers.
        src = a.data
        dst = ans.data
        for k in range(count):
            dst[k] = _round(src[k], decimals)
    else:
        for idx in util.multirange(a.shape):
            ans._ptr(idx)[0] = _round(a._ptr(idx)[0], decimals)

    return ans

# `around` is bound to the same function as `round`.
around = round

def _clip_full(a, a_min, a_max, out):
    """
    Clip ``a`` to ``[a_min, a_max]`` element-wise with broadcasting.

    Each output element is ``min(a_max, max(a, a_min))`` computed with plain
    ``>`` / ``<`` comparisons after all three operands are cast to the output
    dtype.

    Parameters:
        a, a_min, a_max: array-likes that are broadcast against each other.
        out: ``None`` to allocate a result of dtype ``a.dtype``, or an ndarray
            whose shape equals the broadcast shape.

    Raises:
        ValueError: if ``out`` has the wrong shape.
    """
    a = asarray(a)
    a_min = asarray(a_min)
    a_max = asarray(a_max)
    shape = broadcast_shapes(a.shape, a_min.shape, a_max.shape)

    if out is None:
        ans = empty(shape, a.dtype)
    else:
        if not isinstance(out, ndarray):
            compile_error("'out' must be an array")

        if staticlen(shape) != staticlen(out.shape):
            compile_error("'out' has incorrect number of dimensions")

        if shape != out.shape:
            raise ValueError("'out' has incorrect shape")

        ans = out

    for idx in util.multirange(shape):
        x = util.cast(a._ptr(idx, broadcast=True)[0], ans.dtype)
        lo = util.cast(a_min._ptr(idx, broadcast=True)[0], ans.dtype)
        hi = util.cast(a_max._ptr(idx, broadcast=True)[0], ans.dtype)

        raised = x if x > lo else lo
        ans._ptr(idx)[0] = hi if hi < raised else raised

    return ans

def _clip_min(a, a_min, out):
    """
    Apply only a lower bound: each output element is ``max(a, a_min)``.

    ``a`` and ``a_min`` are broadcast against each other and cast to the
    output dtype before comparison with ``>``.

    Parameters:
        a, a_min: array-likes.
        out: ``None`` to allocate a result of dtype ``a.dtype``, or an ndarray
            whose shape equals the broadcast shape.

    Raises:
        ValueError: if ``out`` has the wrong shape.
    """
    a = asarray(a)
    a_min = asarray(a_min)
    shape = broadcast_shapes(a.shape, a_min.shape)

    if out is None:
        ans = empty(shape, a.dtype)
    else:
        if not isinstance(out, ndarray):
            compile_error("'out' must be an array")

        if staticlen(shape) != staticlen(out.shape):
            compile_error("'out' has incorrect number of dimensions")

        if shape != out.shape:
            raise ValueError("'out' has incorrect shape")

        ans = out

    for idx in util.multirange(shape):
        x = util.cast(a._ptr(idx, broadcast=True)[0], ans.dtype)
        lo = util.cast(a_min._ptr(idx, broadcast=True)[0], ans.dtype)
        ans._ptr(idx)[0] = x if x > lo else lo

    return ans

def _clip_max(a, a_max, out):
    """
    Apply only an upper bound: each output element is ``min(a_max, a)``.

    ``a`` and ``a_max`` are broadcast against each other and cast to the
    output dtype before comparison with ``<``.

    Parameters:
        a, a_max: array-likes.
        out: ``None`` to allocate a result of dtype ``a.dtype``, or an ndarray
            whose shape equals the broadcast shape.

    Raises:
        ValueError: if ``out`` has the wrong shape.
    """
    a = asarray(a)
    a_max = asarray(a_max)
    shape = broadcast_shapes(a.shape, a_max.shape)

    if out is None:
        ans = empty(shape, a.dtype)
    else:
        if not isinstance(out, ndarray):
            compile_error("'out' must be an array")

        if staticlen(shape) != staticlen(out.shape):
            compile_error("'out' has incorrect number of dimensions")

        if shape != out.shape:
            raise ValueError("'out' has incorrect shape")

        ans = out

    for idx in util.multirange(shape):
        x = util.cast(a._ptr(idx, broadcast=True)[0], ans.dtype)
        hi = util.cast(a_max._ptr(idx, broadcast=True)[0], ans.dtype)
        ans._ptr(idx)[0] = hi if hi < x else x

    return ans

def clip(a, a_min, a_max, out = None):
    """
    Limit the values of ``a`` to the interval given by ``a_min`` / ``a_max``.

    Either bound may be ``None`` to leave that side open; passing ``None``
    for both is a compile-time error. The work is delegated to
    ``_clip_min``, ``_clip_max`` or ``_clip_full``.
    """
    if a_min is None:
        if a_max is None:
            compile_error("One of max or min must be given")
        else:
            return _clip_max(a, a_max, out)
    elif a_max is None:
        return _clip_min(a, a_min, out)
    else:
        return _clip_full(a, a_min, a_max, out)

def ndenumerate(arr):
    """
    Generate ``(index_tuple, element)`` pairs for every element of ``arr``.

    ``arr`` is first converted with ``asarray``; indices are produced in the
    order of ``util.multirange`` over its shape.
    """
    view = asarray(arr)
    for index in util.multirange(view.shape):
        element = view._ptr(index)[0]
        yield (index, element)

def ndindex(*shape):
    """
    Return an iterator over all index tuples of an array with the given shape.

    The shape may be passed as separate integers or as one tuple, which is
    unpacked and handled by a recursive call.

    Raises:
        ValueError: if any dimension is negative.
    """
    if staticlen(shape) == 1:
        # A single tuple argument is the shape itself.
        if isinstance(shape[0], Tuple):
            return ndindex(*shape[0])

    for extent in shape:
        if extent < 0:
            raise ValueError("negative dimensions are not allowed")

    return util.multirange(shape)

def iterable(y):
    """
    Report whether ``y`` can be iterated over.

    Objects without ``__iter__`` are not iterable. An ndarray is iterable
    only when it has at least one dimension. Everything else that defines
    ``__iter__`` is iterable.
    """
    if not hasattr(y, "__iter__"):
        return False
    elif isinstance(y, ndarray):
        # 0-d arrays are treated as non-iterable.
        return staticlen(y.shape) != 0
    else:
        return True

def packbits(a, axis = None, bitorder: str = 'big'):
    """
    Pack the truth values of ``a`` into bits of a ``u8`` array.

    Each element contributes one bit (1 if truthy, 0 otherwise); eight
    consecutive bits form one output byte and a trailing partial byte is
    zero-padded.

    Parameters:
        a: array-like of bool, int, byte, Int or UInt dtype (compile-time check).
        axis: ``None`` to pack all elements (in ``util.multirange`` order)
            into a 1-d result, or an int axis along which to pack.
        bitorder: 'big' puts the first input bit in the most significant
            position of each byte; 'little' puts it in the least significant.

    Raises:
        ValueError: if ``bitorder`` is neither 'big' nor 'little'.
    """
    a = asarray(a)

    if (a.dtype is not bool and
        a.dtype is not int and
        a.dtype is not byte and
        not isinstance(a.dtype, Int) and
        not isinstance(a.dtype, UInt)):
        compile_error("Expected an input array of integer or boolean data type")

    lsb_first = False
    if bitorder == 'little':
        lsb_first = True
    elif bitorder != 'big':
        raise ValueError("'order' must be either 'little' or 'big'")

    if axis is None:
        num_bytes = ((a.size - 1) >> 3) + 1
        ans = empty(num_bytes, dtype=u8)
        nbits = 0        # bits accumulated in `acc` so far
        out_pos = 0      # next output byte to write
        acc = u8(0)

        for idx in util.multirange(a.shape):
            if nbits == 8:
                # Current byte is full: flush it and start a new one.
                ans.data[out_pos] = acc
                out_pos += 1
                acc = u8(0)
                nbits = 0

            bit = u8(1 if a._ptr(idx)[0] else 0)
            if lsb_first:
                acc = (acc >> u8(1)) | (bit << u8(7))
            else:
                acc = (acc << u8(1)) | bit
            nbits += 1

        if out_pos < num_bytes:
            # Align the (possibly partial) last byte, leaving zero padding.
            pad = u8(8 - nbits)
            if lsb_first:
                acc >>= pad
            else:
                acc <<= pad
            ans.data[out_pos] = acc

        return ans

    if not isinstance(axis, int):
        compile_error("'axis' must be an int or None")

    axis = util.normalize_axis_index(axis, a.ndim)
    length = a.shape[axis]
    num_bytes = ((length - 1) >> 3) + 1
    ans = empty(util.tuple_set(a.shape, axis, num_bytes), dtype=u8)

    for outer in util.multirange(util.tuple_delete(a.shape, axis)):
        nbits = 0
        out_pos = 0
        acc = u8(0)

        for pos in range(length):
            if nbits == 8:
                ans._ptr(util.tuple_insert(outer, axis, out_pos))[0] = acc
                out_pos += 1
                acc = u8(0)
                nbits = 0

            bit = u8(1 if a._ptr(util.tuple_insert(outer, axis, pos))[0] else 0)
            if lsb_first:
                acc = (acc >> u8(1)) | (bit << u8(7))
            else:
                acc = (acc << u8(1)) | bit
            nbits += 1

        if out_pos < num_bytes:
            pad = u8(8 - nbits)
            if lsb_first:
                acc >>= pad
            else:
                acc <<= pad
            ans._ptr(util.tuple_insert(outer, axis, out_pos))[0] = acc

    return ans

def unpackbits(a, axis = None, count = None, bitorder: str = 'big'):
    """
    Expand each ``u8`` element of ``a`` into eight 0/1 ``u8`` elements.

    Parameters:
        a: array-like of dtype ``u8`` (compile-time check).
        axis: ``None`` to unpack all elements (in ``util.multirange`` order)
            into a 1-d result, or an int axis along which to unpack.
        count: ``None`` to keep all bits; a non-negative value sets the
            output length along the unpacked axis (extra positions are filled
            with zeros); a negative value drops that many bits from the end.
        bitorder: 'big' emits the most significant bit first, 'little' the
            least significant bit first.

    Raises:
        ValueError: if ``bitorder`` is invalid, or if a negative ``count``
            removes more bits than are available.
    """
    a = asarray(a)

    if a.dtype is not u8:
        compile_error("Expected an input array of unsigned byte data type")

    lsb_first = False
    if bitorder == 'little':
        lsb_first = True
    elif bitorder != 'big':
        raise ValueError("'order' must be either 'little' or 'big'")

    if axis is None:
        total = a.size * 8

        if count is not None:
            if count >= 0:
                total = count
            else:
                if -count > total:
                    raise ValueError("-count larger than number of elements")

                total += count

        ans = empty(total, dtype=u8)
        pos = 0

        for idx in util.multirange(a.shape):
            packed = a._ptr(idx)[0]

            # Emit at most 8 bits, and never more than the output still needs.
            for i in range(min(8, total - pos)):
                shift = u8(i if lsb_first else 7 - i)
                ans.data[pos] = (packed >> shift) & u8(1)
                pos += 1

            if pos >= total:
                break

        # `count` asked for more bits than the input holds: zero-fill the rest.
        while pos < total:
            ans.data[pos] = u8(0)
            pos += 1

        return ans

    if not isinstance(axis, int):
        compile_error("'axis' must be an int or None")

    axis = util.normalize_axis_index(axis, a.ndim)
    length = a.shape[axis]
    total = length * 8

    if count is not None:
        if count >= 0:
            total = count
        else:
            if -count > total:
                raise ValueError("-count larger than number of elements")

            total += count

    ans = empty(util.tuple_set(a.shape, axis, total), dtype=u8)

    if total == 0:
        return ans

    for outer in util.multirange(util.tuple_delete(a.shape, axis)):
        pos = 0

        for m in range(length):
            packed = a._ptr(util.tuple_insert(outer, axis, m))[0]

            for i in range(min(8, total - pos)):
                shift = u8(i if lsb_first else 7 - i)
                ans._ptr(util.tuple_insert(outer, axis, pos))[0] = (packed >> shift) & u8(1)
                pos += 1

            if pos >= total:
                break

        while pos < total:
            ans._ptr(util.tuple_insert(outer, axis, pos))[0] = u8(0)
            pos += 1

    return ans

def _is_pos_neg_inf(x, pos: bool, out = None):
    """
    Shared implementation of ``isposinf`` and ``isneginf``.

    An element tests true when ``isinf`` holds for it and its sign bit is
    clear (``pos=True``) or set (``pos=False``).

    Parameters:
        x: array-like input.
        pos: selects positive (True) or negative (False) infinity.
        out: optional ndarray with the same shape as ``x``; results are cast
            to its dtype.

    Returns:
        ``out`` when given; otherwise a new bool array, or a scalar for 0-d
        input.

    Raises:
        ValueError: if ``out`` has a different shape than ``x``.
    """
    from .ndmath import isinf, signbit
    x = asarray(x)

    if out is None:
        ans = empty(x.shape, bool)
    else:
        if not isinstance(out, ndarray):
            compile_error("'out' must be an array")

        if out.ndim != x.ndim:
            compile_error("'out' has incorrect number of dimensions")

        if x.shape != out.shape:
            raise ValueError("'out' has incorrect shape")

        ans = out

    for idx in util.multirange(x.shape):
        elem = x._ptr(idx)[0]
        if pos:
            hit = isinf(elem) and not signbit(elem)
        else:
            hit = isinf(elem) and signbit(elem)
        ans._ptr(idx)[0] = util.cast(hit, ans.dtype)

    if out is None and ans.ndim == 0:
        return ans.item()
    else:
        return ans

def isposinf(x, out = None):
    """Element-wise test for positive infinity; see ``_is_pos_neg_inf``."""
    return _is_pos_neg_inf(x, True, out)

def isneginf(x, out = None):
    """Element-wise test for negative infinity; see ``_is_pos_neg_inf``."""
    return _is_pos_neg_inf(x, False, out)

def iscomplex(x):
    """
    Element-wise test for a non-zero imaginary part.

    For complex / complex64 input each element maps to ``bool(c.imag)``;
    every other dtype yields all False. 0-d results are returned as a scalar.
    """
    x = asarray(x)
    if x.dtype is not complex and x.dtype is not complex64:
        ans = zeros(x.shape, bool)
    else:
        ans = x.map(lambda c: bool(c.imag))

    if ans.ndim == 0:
        return ans.item()
    else:
        return ans

def iscomplexobj(x):
    """
    Return True when the element type of ``x`` is complex or complex64.

    Only the dtype is inspected, not the values. Non-array input is first
    converted with ``asarray``.
    """
    if not isinstance(x, ndarray):
        elem_type = asarray(x).dtype
        return elem_type is complex or elem_type is complex64
    else:
        return x.dtype is complex or x.dtype is complex64

def isreal(x):
    """
    Element-wise test for a zero imaginary part.

    For complex / complex64 input each element maps to ``not bool(c.imag)``;
    every other dtype yields all True. 0-d results are returned as a scalar.
    """
    x = asarray(x)
    if x.dtype is not complex and x.dtype is not complex64:
        ans = ones(x.shape, bool)
    else:
        ans = x.map(lambda c: not bool(c.imag))

    if ans.ndim == 0:
        return ans.item()
    else:
        return ans

def isrealobj(x):
    """Return the negation of ``iscomplexobj(x)``."""
    has_complex_type = iscomplexobj(x)
    return not has_complex_type

def isfortran(a: ndarray):
    """Return True when ``a`` is F-contiguous but not C-contiguous."""
    c_contig, f_contig = a._contig
    return (not c_contig) and f_contig

def isscalar(element):
    """
    Return True when the type of ``element`` is one of the scalar types:
    int, float, complex, complex64, bool, byte, any Int / UInt, str or
    NoneType.
    """
    kind = type(element)
    return ((kind is int or kind is float or kind is complex or kind is complex64) or
            (kind is bool or kind is byte) or
            (isinstance(kind, Int) or isinstance(kind, UInt)) or
            (kind is str or kind is NoneType))

def _array_get_part(arr: ndarray, imag: Static[int]):
    """
    Return the real (``imag`` false) or imaginary (``imag`` true) part of ``arr``.

    For complex / complex64 arrays the result reuses the input's shape,
    strides and buffer, with the data pointer moved forward by one component
    for the imaginary part. For other dtypes the real part is ``arr`` itself
    and the imaginary part is a newly allocated zero-filled array.
    """
    if arr.dtype is complex:
        skip = util.sizeof(float) if imag else 0
        base = arr.data.as_byte() + skip
        return ndarray(arr.shape, arr.strides, Ptr[float](base))
    elif arr.dtype is complex64:
        skip = util.sizeof(float32) if imag else 0
        base = arr.data.as_byte() + skip
        return ndarray(arr.shape, arr.strides, Ptr[float32](base))
    else:
        if not imag:
            return arr
        else:
            count = arr.size
            zeros_buf = Ptr[arr.dtype](count)
            str.memset(zeros_buf.as_byte(), byte(0), count * arr.itemsize)
            return ndarray(arr.shape, zeros_buf)

def real(val):
    """Return the real part of ``val`` (converted with ``asarray``)."""
    arr = asarray(val)
    return _array_get_part(arr, imag=False)

def imag(val):
    """Return the imaginary part of ``val`` (converted with ``asarray``)."""
    arr = asarray(val)
    return _array_get_part(arr, imag=True)

def _real_set(arr: ndarray, val):
    """
    Assign ``val`` (broadcast to ``arr.shape``) to the real part of ``arr``.

    Complex and complex64 elements keep their imaginary part; for any other
    dtype the element is simply replaced by the cast value.
    """
    source = broadcast_to(asarray(val), arr.shape)
    for idx in util.multirange(arr.shape):
        new_real = source._ptr(idx)[0]
        slot = arr._ptr(idx)
        if arr.dtype is complex:
            old_imag = slot[0].imag
            slot[0] = complex(util.cast(new_real, float), old_imag)
        elif arr.dtype is complex64:
            old_imag = slot[0].imag
            slot[0] = complex64(util.cast(new_real, float32), old_imag)
        else:
            slot[0] = util.cast(new_real, arr.dtype)

def _imag_set(arr: ndarray, val):
    """
    Assign ``val`` (broadcast to ``arr.shape``) to the imaginary part of
    ``arr``, leaving every element's real part untouched.

    Parameters:
        arr: array of dtype complex or complex64; any other dtype is a
            compile-time error since there is no imaginary part to set.
        val: scalar or array-like; each value is cast to the component type
            (float for complex, float32 for complex64).
    """
    val = broadcast_to(asarray(val), arr.shape)
    for idx in util.multirange(arr.shape):
        x = val._ptr(idx)[0]
        p = arr._ptr(idx)
        if arr.dtype is complex:
            p[0] = complex(p[0].real, util.cast(x, float))
        elif arr.dtype is complex64:
            # Keep the existing *real* component; only the imaginary one changes.
            p[0] = complex64(p[0].real, util.cast(x, float32))
        else:
            compile_error("array does not have imaginary part to set")

# Adds method and property forms of module-level routines to ndarray.
# Most methods simply forward `self` to the function of the same name.
@extend
class ndarray:
    # Forwards to take() with `self` as the source array.
    def take(self, indices, axis = None, out = None, mode: str = 'raise'):
        return take(self, indices, axis=axis, out=out, mode=mode)

    # Forwards to squeeze().
    def squeeze(self, axis = None):
        return squeeze(self, axis)

    # Forwards to nonzero().
    def nonzero(self):
        return nonzero(self)

    # Forwards to searchsorted() with `self` as the array being searched.
    def searchsorted(self, v, side: str = 'left', sorter = None):
        return searchsorted(self, v=v, side=side, sorter=sorter)

    # Forwards to repeat().
    def repeat(self, repeats, axis = None):
        return repeat(self, repeats, axis=axis)

    # Forwards to compress(); note the function takes the condition first.
    def compress(self, condition, axis = None, out = None):
        return compress(condition, self, axis=axis, out=out)

    # Forwards to choose() with `self` as the first argument.
    def choose(self, choices, out = None, mode: str = 'raise'):
        return choose(self, choices, out, mode)

    # Forwards to diagonal().
    def diagonal(self, offset: int = 0, axis1: int = 0, axis2: int = 1):
        return diagonal(self, offset, axis1=axis1, axis2=axis2)

    # Forwards to put() with `self` as the target array.
    def put(self, ind, v, mode: str = 'raise'):
        return put(self, ind, v, mode)

    # Forwards to the module-level round().
    def round(self, decimals: int = 0, out = None):
        return round(self, decimals, out=out)

    # Forwards to the module-level clip(); `min`/`max` map to a_min/a_max.
    def clip(self, min = None, max = None, out = None):
        return clip(self, min, max, out)

    # Real part of the array, as returned by the module-level real().
    @property
    def real(self):
        return real(self)

    # Assigns `val` (broadcast to the array's shape) to the real part in place.
    @real.setter
    def real(self, val):
        _real_set(self, val)

    # Imaginary part of the array, as returned by the module-level imag().
    @property
    def imag(self):
        return imag(self)

    # Assigns `val` (broadcast to the array's shape) to the imaginary part in
    # place; only compiles for complex dtypes.
    @imag.setter
    def imag(self, val):
        _imag_set(self, val)

    # Element-wise conjugate when the element type has a `conjugate` method;
    # otherwise returns `self` itself (not a copy).
    def conj(self):
        if hasattr(dtype(), 'conjugate'):
            return self._op_unary(lambda a: a.conjugate())
        else:
            return self

    # Alias for conj().
    def conjugate(self):
        return self.conj()

    # Assigns through flat indices: flat element i of `self` receives flat
    # element (i mod val.size) of `val`, cast to this array's dtype.
    @flat.setter
    def flat(self, val):
        val = asarray(val)
        j = 0
        m = val.size

        # An empty source leaves the array untouched.
        if m == 0:
            return

        for i in range(self.size):
            p = self._ptr_flat(i, check=False)
            # Wrap around so a shorter `val` is reused cyclically.
            if j == m:
                j = 0
            q = val._ptr_flat(j, check=False)
            p[0] = util.cast(q[0], dtype)
            j += 1

# Not supported:
#   - rollaxis (use is discouraged in NumPy docs)
#   - asmatrix (matrix class not supported)

############
# Datetime #
############

@extend
class busdaycalendar:

    @property
    def weekmask(self):
        """Return a new 7-element bool array copied from ``self._wm``."""
        mask = empty(7, bool)
        for day in range(7):
            mask.data[day] = self._wm[day]
        return mask

    @property
    def holidays(self):
        """
        Return a new ``datetime64['D', 1]`` array holding the
        ``self._nholidays`` entries of ``self._holidays``.
        """
        count = self._nholidays
        dates = empty(count, datetime64['D', 1])
        for k in range(len(dates)):
            dates.data[k] = self._holidays[k]
        return dates

def _get_busdaycal(weekmask=None, holidays=None, busdaycal=None):
    """
    Resolve the business-day calendar for the ``busday`` functions.

    When ``busdaycal`` is given it is returned as is; supplying ``weekmask``
    or ``holidays`` alongside it, or passing something that is not a
    ``busdaycalendar``, is a compile-time error. Otherwise a new calendar is
    built from ``weekmask`` (default ``"1111100"``) and ``holidays``.
    """
    if busdaycal is not None:
        if weekmask is not None or holidays is not None:
            compile_error("Cannot supply both the weekmask/holidays and the "
                          "busdaycal parameters to busday_offset()")

        if not isinstance(busdaycal, busdaycalendar):
            compile_error("busdaycal parameter must be a busdaycalendar")

        return busdaycal
    else:
        if weekmask is None:
            return busdaycalendar(weekmask="1111100", holidays=holidays)
        else:
            return busdaycalendar(weekmask=weekmask, holidays=holidays)

def busday_offset(dates, offsets, roll: str = "raise", weekmask=None,
                  holidays=None, busdaycal=None, out=None):
    """
    Apply business-day offsets to dates, element-wise with broadcasting.

    Parameters:
        dates: array-like converted to ``datetime64['D', 1]``.
        offsets: array-like of dtype int (compile-time check).
        roll: one of "forward", "following", "backward", "preceding",
            "modifiedfollowing", "modifiedpreceding", "nat", "raise".
        weekmask, holidays, busdaycal: calendar selection, resolved by
            ``_get_busdaycal``.
        out: optional ``datetime64['D', 1]`` ndarray whose shape equals the
            broadcast shape of ``dates`` and ``offsets``.

    Returns:
        ``out`` when given; otherwise a new array, or a scalar for 0-d results.

    Raises:
        ValueError: for an unknown ``roll`` value or a wrongly shaped ``out``.
    """
    cal = _get_busdaycal(weekmask=weekmask, holidays=holidays, busdaycal=busdaycal)
    mask = cal._wm
    mask_busdays = mask.count

    dates = asarray(dates, dtype=datetime64['D', 1])
    offsets = asarray(offsets)

    if offsets.dtype is not int:
        compile_error("offsets parameter must be an array of integers")

    roll_codes = {
        "forward": _BUSDAY_FORWARD,
        "following": _BUSDAY_FOLLOWING,
        "backward": _BUSDAY_BACKWARD,
        "preceding": _BUSDAY_PRECEDING,
        "modifiedfollowing": _BUSDAY_MODIFIEDFOLLOWING,
        "modifiedpreceding": _BUSDAY_MODIFIEDPRECEDING,
        "nat": _BUSDAY_NAT,
        "raise": _BUSDAY_RAISE,
    }

    if roll not in roll_codes:
        raise ValueError(f"Invalid business day roll parameter \"{roll}\"")
    roll_code = roll_codes[roll]

    shape = broadcast_shapes(dates.shape, offsets.shape)

    if out is None:
        ans = empty(shape, datetime64['D', 1])
    else:
        if not isinstance(out, ndarray):
            compile_error("'out' must be an array")

        if out.dtype is not datetime64['D', 1]:
            compile_error("'out' must have dtype datetime64[D]")

        if out.ndim != staticlen(shape):
            compile_error("'out' has incorrect number of dimensions")

        if out.shape != shape:
            raise ValueError("'out' has incorrect shape")

        ans = out

    for idx in util.multirange(shape):
        date = dates._ptr(idx, broadcast=True)[0]
        offset = offsets._ptr(idx, broadcast=True)[0]
        ans._ptr(idx)[0] = _apply_busines_day_offset(
            date, offset, roll_code, mask, mask_busdays,
            cal._holidays, cal._holidays + cal._nholidays)

    if out is None and ans.ndim == 0:
        return ans.item()
    else:
        return ans

def busday_count(begindates, enddates, weekmask=None,
                 holidays=None, busdaycal=None, out=None):
    """
    Count business days between paired begin and end dates, element-wise
    with broadcasting, via ``_apply_busines_day_count``.

    Parameters:
        begindates, enddates: array-likes converted to ``datetime64['D', 1]``.
        weekmask, holidays, busdaycal: calendar selection, resolved by
            ``_get_busdaycal``.
        out: optional int ndarray; its shape takes part in the broadcast.

    Returns:
        ``out`` when given; otherwise a new int array, or a scalar for 0-d
        results.
    """
    cal = _get_busdaycal(weekmask=weekmask, holidays=holidays, busdaycal=busdaycal)
    mask = cal._wm
    mask_busdays = mask.count
    begindates = asarray(begindates, dtype=datetime64['D', 1])
    enddates = asarray(enddates, dtype=datetime64['D', 1])

    if out is None:
        shape = broadcast_shapes(begindates.shape, enddates.shape)
        ans = empty(shape, int)
    else:
        if not isinstance(out, ndarray):
            compile_error("'out' must be an array")

        if out.dtype is not int:
            compile_error("'out' must have dtype int")

        shape = broadcast_shapes(begindates.shape, enddates.shape, out.shape)
        ans = out

    for idx in util.multirange(shape):
        first = begindates._ptr(idx, broadcast=True)[0]
        last = enddates._ptr(idx, broadcast=True)[0]
        n_days = _apply_busines_day_count(first, last, mask, mask_busdays,
                                          cal._holidays, cal._holidays + cal._nholidays)
        # A user-supplied `out` is addressed with broadcasting enabled.
        ans._ptr(idx, broadcast=(out is not None))[0] = n_days

    if out is None and ans.ndim == 0:
        return ans.item()
    else:
        return ans

def is_busday(dates, weekmask=None, holidays=None, busdaycal=None, out=None):
    """
    Test, element-wise, whether each date is a business day according to the
    calendar's week mask and holiday list.

    Parameters:
        dates: array-like converted to ``datetime64['D', 1]``.
        weekmask, holidays, busdaycal: calendar selection, resolved by
            ``_get_busdaycal``.
        out: optional bool ndarray with the same shape as ``dates``.

    Returns:
        ``out`` when given; otherwise a new bool array, or a scalar for 0-d
        input.

    Raises:
        ValueError: if ``out`` has a different shape than ``dates``.
    """
    cal = _get_busdaycal(weekmask=weekmask, holidays=holidays, busdaycal=busdaycal)
    wm = cal._wm
    dates = asarray(dates, dtype=datetime64['D', 1])

    if out is not None:
        if not isinstance(out, ndarray):
            compile_error("'out' must be an array")

        if out.dtype is not bool:
            compile_error("'out' must have dtype bool")

        if out.ndim != dates.ndim:
            compile_error("'out' has incorrect number of dimensions")

        if out.shape != dates.shape:
            raise ValueError("'out' has incorrect shape")

        ans = out
    else:
        ans = empty(dates.shape, bool)

    for idx in util.multirange(dates.shape):
        d = dates._ptr(idx)[0]
        ans._ptr(idx)[0] = _apply_is_business_day(d, wm, cal._holidays, cal._holidays + cal._nholidays)

    if ans.ndim == 0 and out is None:
        return ans.item()
    else:
        return ans

def datetime_data(dtype: type):
    """
    Return ``(dtype.base, dtype.num)`` for a datetime64 or timedelta64 type.

    Any other type is rejected at compile time.
    """
    if not (isinstance(dtype, datetime64) or isinstance(dtype, timedelta64)):
        compile_error("cannot get datetime metadata from non-datetime type")
    base = dtype.base
    num = dtype.num
    return (base, num)

def datetime_as_string(arr, unit=None, timezone: str = "naive"):
    """
    Convert each datetime64 element of ``arr`` to a string by calling its
    ``_as_string(unit=..., timezone=...)`` method.

    Input whose dtype is not a datetime64 is rejected at compile time.
    """
    values = asarray(arr)
    if not isinstance(values.dtype, datetime64):
        compile_error("input must have type NumPy datetime")
    return values.map(lambda stamp: stamp._as_string(unit=unit, timezone=timezone))