# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>

import util
from .reductions import mean, std, percentile
from .ndarray import ndarray
from .routines import array, asarray, full, broadcast_to, swapaxes, empty, concatenate, around, diag, clip, zeros, searchsorted, linspace
from .ndmath import multiply, true_divide, conjugate, power, logical_and, ceil, sqrt, subtract, min
from .linalg_sym import dot
from .sorting import sort, argsort

def average(a,
            axis=None,
            weights=None,
            returned: Static[int] = False,
            keepdims: Static[int] = False):
    """
    Compute the average of ``a`` along ``axis``, optionally weighted.

    Parameters:
        a: Array-like input; converted with ``asarray``.
        axis: Axis forwarded to the ``mean``/``sum`` reductions.
        weights: Optional array-like of weights. When ``None`` the result
            is simply ``a.mean(axis, keepdims=keepdims)``.
        returned: When true, return ``(average, sum_of_weights)`` instead
            of only the average.
        keepdims: Forwarded to the reductions.

    Raises:
        TypeError: ``a`` and ``weights`` have the same number of dimensions
            but different shapes, and ``axis`` is ``None`` or ``weights``
            is not 1-D.
        ValueError: ``a`` and ``weights`` have the same number of
            dimensions but different shapes, and the length of ``weights``
            does not match ``a`` along the selected axis.

    When ``a`` and ``weights`` differ in number of dimensions, the
    equivalent axis / 1-D checks are reported through ``compile_error``
    instead of being raised.
    """

    # Returns a *value* of the result dtype; callers recover the dtype with
    # `type(...)`. int/bool inputs are additionally coerced with float.
    def result_type(a_dtype: type, w_dtype: type):
        common_dtype = type(util.coerce(a_dtype, w_dtype))
        if isinstance(a_dtype, int) or isinstance(a_dtype, bool):
            return type(util.coerce(common_dtype, float))()
        else:
            return common_dtype()

    a = asarray(a)

    if weights is None:
        # Unweighted: defer to the mean reduction.
        avg = a.mean(axis, keepdims=keepdims)
        if returned:
            if isinstance(avg, ndarray):
                # Number of input elements contributing to each output cell.
                scl = a.size / avg.size
                return avg, full(avg.shape, scl, dtype=avg.dtype)
            else:
                # Scalar result: every element contributed.
                scl = a.size
                return avg, util.cast(scl, type(avg))
        else:
            return avg
    else:
        w = asarray(weights)
        result_dtype = type(result_type(a.dtype, w.dtype))

        if a.ndim != w.ndim:
            if axis is None:
                compile_error(
                    "Axis must be specified when shapes of a and weights differ."
                )

            if w.ndim != 1:
                compile_error(
                    "1D weights expected when shapes of a and weights differ.")

            # View the 1-D weights with a.ndim dimensions and move the
            # weights axis from the last position to `axis`.
            wgt = broadcast_to(w, ((1, ) * (a.ndim - 1)) + w.shape).swapaxes(
                -1, axis)
            scl = wgt.sum(axis=axis, dtype=result_dtype, keepdims=keepdims)
            avg = multiply(a, wgt, dtype=result_dtype).sum(
                axis, keepdims=keepdims) / scl

            if returned:
                if scl.shape != avg.shape:
                    scl = broadcast_to(scl, avg.shape).copy()
                return avg, scl
            else:
                return avg
        else:
            # Integer axis used only for the length check below: `axis`
            # itself, its first element when it is a non-int, non-None
            # value, or 0 when no axis was given.
            ax = 0
            if isinstance(axis, int):
                ax = axis
            elif axis is not None:
                ax = axis[0]

            if a.shape != w.shape:
                if axis is None:
                    raise TypeError(
                        "Axis must be specified when shapes of a and weights differ."
                    )

                if w.ndim != 1:
                    raise TypeError(
                        "1D weights expected when shapes of a and weights differ."
                    )

                if w.shape[0] != a.shape[ax]:
                    raise ValueError(
                        "Length of weights not compatible with specified axis."
                    )

            # Returns `axis` unchanged (or None); the result replaces the
            # integer `ax` used for validation above.
            def get_axis(axis):
                if axis is not None:
                    return axis
                else:
                    return None

            ax = get_axis(axis)

            scl = w.sum(axis=ax, dtype=result_dtype, keepdims=keepdims)
            avg = multiply(a, w, dtype=result_dtype).sum(
                ax, keepdims=keepdims) / scl

            if returned:
                if isinstance(scl, ndarray) and isinstance(avg, ndarray):
                    if scl.shape != avg.shape:
                        scl = broadcast_to(scl, avg.shape).copy()
                return avg, util.cast(scl, type(avg))
            else:
                return avg

def cov(m,
        y=None,
        rowvar: bool = True,
        bias: bool = False,
        ddof: Optional[int] = None,
        fweights=None,
        aweights=None,
        dtype: type = NoneType):
    """
    Estimate a covariance matrix from observations.

    Parameters:
        m: Array-like with at most 2 dimensions holding the observations.
        y: Optional additional observations (at most 2 dimensions); they are
            concatenated to ``m`` along axis 0 after conversion.
        rowvar: When false (and the converted array does not have a single
            row) the input is transposed before use.
        bias: Used only when ``ddof`` is ``None``: ``ddof`` becomes 0 when
            ``bias`` is true and 1 otherwise.
        ddof: Explicit delta degrees of freedom; overrides ``bias``.
        fweights: Optional 1-D frequency weights; must be non-negative and
            integer-valued, one per sample.
        aweights: Optional 1-D observation weights; must be non-negative,
            one per sample.
        dtype: Result dtype. When left at ``NoneType`` it is derived from
            the input dtype(s) coerced with ``float``.

    Returns:
        The covariance array, or a scalar (``c.item()``) when ``m`` is 1-D
        and ``y`` is ``None``.

    Raises:
        ValueError: Weight length does not match the number of samples, or
            a weight is negative.
        TypeError: ``fweights`` contains a non-integer value.

    Dimensionality problems of ``m``, ``y``, ``fweights`` and ``aweights``
    are reported through ``compile_error``.

    The caller's ``fweights`` array is never modified.
    """

    # Returns a *value* of the result dtype; callers recover the dtype with
    # `type(...)`.
    def result_type(a_dtype: type, y_dtype: type):
        common_dtype = type(util.coerce(a_dtype, y_dtype))
        if isinstance(a_dtype, int) or isinstance(a_dtype, bool):
            return util.coerce(common_dtype, float)
        else:
            return common_dtype()

    def get_dtype(m, y, dtype: type):
        if dtype is NoneType:
            if y is None:
                return result_type(m.dtype, float)
            else:
                tmp_dtype = result_type(m.dtype, y.dtype)
                return result_type(type(tmp_dtype), float)
        else:
            return dtype()

    m = asarray(m)
    if m.ndim > 2:
        compile_error("m has more than 2 dimensions")

    if y is not None:
        y2 = asarray(y)
        if y2.ndim > 2:
            compile_error("y has more than 2 dimensions")
    else:
        y2 = y

    dtype2 = type(get_dtype(m, y2, dtype))
    X = array(m, ndmin=2, dtype=dtype2)

    if not rowvar and X.shape[0] != 1:
        X = X.T

    if X.shape[0] == 0:
        if m.ndim == 1 and y is None:
            return util.cast(util.nan64(), dtype2)
        else:
            return empty((0, 0), dtype2)

    if y2 is not None:
        y2 = array(y2, copy=False, ndmin=2, dtype=dtype2)
        if not rowvar and y2.shape[0] != 1:
            y2 = y2.T
        X = concatenate((X, y2), axis=0)

    ddof2: int = 0
    if ddof is None:
        if not bias:
            ddof2 = 1
        else:
            ddof2 = 0
    else:
        ddof2 = ddof

    # Get the product of frequencies and weights.
    # An empty array is the sentinel for "no weights supplied".
    w = ndarray[float, 1]((0, ), (0, ), Ptr[float]())

    if fweights is not None:
        fweights = asarray(fweights, dtype=float)

        if fweights.ndim == 0:
            compile_error("fweights must be at least 1-dimensional")
        elif fweights.ndim > 1:
            compile_error("cannot handle multidimensional fweights")

        if fweights.shape[0] != X.shape[1]:
            raise ValueError("incompatible numbers of samples and fweights")

        for item in fweights:
            if item != around(item):
                raise TypeError("fweights must be integer")
            elif item < 0:
                raise ValueError("fweights cannot be negative")

        w = fweights

    if aweights is not None:
        aweights2 = asarray(aweights, dtype=float)

        if aweights2.ndim == 0:
            compile_error("aweights must be at least 1-dimensional")
        elif aweights2.ndim > 1:
            compile_error("cannot handle multidimensional aweights")

        if aweights2.shape[0] != X.shape[1]:
            raise ValueError("incompatible numbers of samples and aweights")

        for item in aweights2:
            if item < 0:
                raise ValueError("aweights cannot be negative")

        if w.size == 0:
            w = aweights2
        else:
            # Build a new array rather than multiplying in place: `w` is
            # `fweights` here, and it can be the caller's own array when
            # `asarray` did not need to copy, so `w *= aweights2` would
            # modify the caller's data.
            w = w * aweights2

    if w.size == 0:
        avg1, w_sum1 = average(X, axis=1, weights=None, returned=True)
        avg = asarray(avg1, dtype=dtype2)
        w_sum = asarray(w_sum1, dtype=dtype2)
    else:
        avg1, w_sum1 = average(X, axis=1, weights=w, returned=True)
        avg = asarray(avg1, dtype=dtype2)
        w_sum = asarray(w_sum1, dtype=dtype2)

    # Only the first entry of the weight sums is used for normalization.
    w_sum = w_sum[0]
    fact = util.cast(0, dtype2)

    # Determine the normalization
    if w.size == 0:
        fact = util.cast(X.shape[1] - ddof2, dtype2)
    elif ddof2 == 0:
        fact = util.cast(w_sum, dtype2)
    elif aweights is None:
        fact = util.cast(w_sum - ddof2, dtype2)
    else:
        fact = util.cast(w_sum - ddof2 * sum(w * aweights2) / w_sum, dtype2)

    if util.cast(fact, float) <= 0.0:
        # warn "Degrees of freedom <= 0 for slice"
        fact = util.cast(0, dtype2)

    # Center each variable (row) on its average.
    X -= avg[:, None]

    if w.size == 0:
        X_T = X.T
    else:
        Xw = multiply(X, w)
        X_T = Xw.T

    c = dot(X, X_T.conj())
    c *= true_divide(1, fact)
    if m.ndim == 1 and y is None:
        return c.item()
    else:
        return c

def corrcoef(x, y=None, rowvar=True, dtype: type = NoneType):
    """
    Correlation coefficients derived from ``cov(x, y, rowvar, dtype)``.

    The covariance is divided by the square root of its diagonal along
    both axes, and the real (and, for complex dtypes, imaginary) parts are
    clipped to ``[-1, 1]``. When the covariance is neither 1-D nor 2-D,
    ``c / c`` is returned instead.
    """
    c = cov(x, y=y, rowvar=rowvar, dtype=dtype)
    c_arr = asarray(c)
    if c_arr.ndim != 1 and c_arr.ndim != 2:
        # scalar covariance
        # nan if incorrect value (nan, inf, 0), 1 otherwise
        return c / c

    # Normalise rows, then columns, by the per-variable standard deviation.
    sigma = sqrt(diag(c).real)
    c /= sigma[:, None]
    c /= sigma[None, :]

    # Keep both components inside [-1, 1].
    clip(c.real, -1, 1, out=c.real)
    if c.dtype is complex or c.dtype is complex64:
        clip(c.imag, -1, 1, out=c.imag)

    return c

def _correlate(a, b, mode: str):
    """
    Sliding dot product of ``a`` with ``b``, used by ``correlate``.

    Parameters:
        a, b: Arrays read through ``size``, ``strides[0]``, ``data`` and
            ``dtype``. The caller (``correlate``) passes 1-D arrays with
            ``a.size >= b.size``.
        mode: ``'valid'``, ``'same'`` or ``'full'``; selects how many
            partially-overlapping positions are emitted on each side.

    Returns:
        A new array whose dtype is the coercion of both input dtypes.

    Raises:
        ValueError: ``mode`` is not one of the three accepted strings.
    """

    # Dot product of a window of `d` with the `nk`-element kernel `k`, for
    # `nd` consecutive window positions. `nk` is static so the inner
    # `staticrange` loop has a compile-time trip count. Strides here are in
    # elements.
    def kernel(d, dstride: int, nd: int, dtype: type,
               k, kstride: int, nk: Static[int], ktype: type,
               out, ostride: int):
        for i in range(nd):
            acc = util.zero(dtype)
            for j in staticrange(nk):
                acc += d[(i + j) * dstride] * k[j * kstride]
            out[i * ostride] = acc

    # Dispatches to `kernel` with a static kernel length for lengths 1..11
    # when both inputs share a dtype. Returns False when it did nothing, so
    # the caller falls back to the generic loop.
    def small_correlate(d, dstride: int, nd: int, dtype: type,
                        k, kstride: int, nk: int, ktype: type,
                        out, ostride: int):
        if dtype is not ktype:
            return False

        # Convert the incoming byte strides to element strides.
        dstride //= util.sizeof(dtype)
        kstride //= util.sizeof(dtype)
        ostride //= util.sizeof(dtype)

        if nk == 1:
            kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
                   k=k, kstride=kstride, nk=1, ktype=ktype,
                   out=out, ostride=ostride)
        elif nk == 2:
            kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
                   k=k, kstride=kstride, nk=2, ktype=ktype,
                   out=out, ostride=ostride)
        elif nk == 3:
            kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
                   k=k, kstride=kstride, nk=3, ktype=ktype,
                   out=out, ostride=ostride)
        elif nk == 4:
            kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
                   k=k, kstride=kstride, nk=4, ktype=ktype,
                   out=out, ostride=ostride)
        elif nk == 5:
            kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
                   k=k, kstride=kstride, nk=5, ktype=ktype,
                   out=out, ostride=ostride)
        elif nk == 6:
            kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
                   k=k, kstride=kstride, nk=6, ktype=ktype,
                   out=out, ostride=ostride)
        elif nk == 7:
            kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
                   k=k, kstride=kstride, nk=7, ktype=ktype,
                   out=out, ostride=ostride)
        elif nk == 8:
            kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
                   k=k, kstride=kstride, nk=8, ktype=ktype,
                   out=out, ostride=ostride)
        elif nk == 9:
            kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
                   k=k, kstride=kstride, nk=9, ktype=ktype,
                   out=out, ostride=ostride)
        elif nk == 10:
            kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
                   k=k, kstride=kstride, nk=10, ktype=ktype,
                   out=out, ostride=ostride)
        elif nk == 11:
            kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
                   k=k, kstride=kstride, nk=11, ktype=ktype,
                   out=out, ostride=ostride)
        else:
            return False

        return True

    # Generic strided dot product of `n` elements; `is1`/`is2` are byte
    # strides and the result is accumulated in (and stored as) T3.
    def dot(_ip1: Ptr[T1], is1: int, _ip2: Ptr[T2], is2: int, op: Ptr[T3], n: int,
            T1: type, T2: type, T3: type):
        ip1 = _ip1.as_byte()
        ip2 = _ip2.as_byte()
        ans = util.zero(T3)

        for i in range(n):
            e1 = Ptr[T1](ip1)[0]
            e2 = Ptr[T2](ip2)[0]
            ans += util.cast(e1, T3) * util.cast(e2, T3)
            ip1 += is1
            ip2 += is2

        op[0] = ans

    # Advance a typed pointer by `s` bytes.
    def incr(p: Ptr[T], s: int, T: type):
        return Ptr[T](p.as_byte() + s)

    n1 = a.size
    n2 = b.size
    length = n1
    n = n2

    # n_left / n_right: number of partially-overlapping output positions
    # produced before / after the fully-overlapping ones.
    if mode == 'valid':
        length = length - n + 1
        n_left = 0
        n_right = 0
    elif mode == 'same':
        n_left = n >> 1
        n_right = n - n_left - 1
    elif mode == 'full':
        n_right = n - 1
        n_left = n - 1
        length = length + n - 1
    else:
        raise ValueError(
            f"mode must be one of 'valid', 'same', or 'full' (got {repr(mode)})"
        )

    dt = type(util.coerce(a.dtype, b.dtype))
    ret = empty(length, dtype=dt)

    is1 = a.strides[0]
    is2 = b.strides[0]
    op = ret.data
    os = ret.itemsize
    ip1 = a.data
    # Start `n_left` elements into `b`, so the first window uses only the
    # trailing `n - n_left` elements of `b`.
    ip2 = Ptr[b.dtype](b.data.as_byte() + n_left * is2)
    n = n - n_left

    # Left edge: the overlap grows by one element per output.
    for i in range(n_left):
        dot(ip1, is1, ip2, is2, op, n)
        n += 1
        ip2 = incr(ip2, -is2)
        op = incr(op, os)

    # Fully-overlapping positions: try the specialised kernels first.
    if small_correlate(ip1, is1, n1 - n2 + 1, a.dtype,
                       ip2, is2, n, b.dtype,
                       op, os):
        ip1 = incr(ip1, is1 * (n1 - n2 + 1))
        op = incr(op, os * (n1 - n2 + 1))
    else:
        for i in range(n1 - n2 + 1):
            dot(ip1, is1, ip2, is2, op, n)
            ip1 = incr(ip1, is1)
            op = incr(op, os)

    # Right edge: the overlap shrinks by one element per output.
    for i in range(n_right):
        n -= 1
        dot(ip1, is1, ip2, is2, op, n)
        ip1 = incr(ip1, is1)
        op = incr(op, os)

    return ret

def correlate(a, b, mode: str = 'valid'):
    """
    Cross-correlation of two 1-D sequences.

    Both inputs are converted with ``asarray``; a complex ``b`` is
    conjugated first. When ``a`` is shorter than ``b`` the operands are
    swapped for ``_correlate`` and the result is reversed.

    Raises:
        ValueError: Either input is empty.

    Inputs that are not 1-D are rejected through ``compile_error``.
    """
    a = asarray(a)
    b = asarray(b)

    if a.ndim != 1 or b.ndim != 1:
        compile_error('object too deep for desired array')

    len_a = a.size
    len_b = b.size

    if len_a == 0:
        raise ValueError("first argument cannot be empty")
    if len_b == 0:
        raise ValueError("second argument cannot be empty")

    if b.dtype is complex or b.dtype is complex64:
        b = b.conjugate()

    # `_correlate` is always called with the longer operand first.
    if len_a >= len_b:
        return _correlate(a, b, mode=mode)
    return _correlate(b, a, mode=mode)[::-1]

def bincount(x, weights=None, minlength: int = 0):
    """
    Count occurrences of each non-negative integer value in ``x``.

    Parameters:
        x: 1-D array-like; converted to an ``int`` array.
        weights: Optional 1-D array-like of the same length as ``x``. When
            given, each element of ``x`` adds its weight (as ``float``)
            instead of 1 to its bin.
        minlength: Minimum length of the result.

    Returns:
        An ``int`` array (no weights) or ``float`` array (weights) of
        length ``max(max(x) + 1, minlength)``. For empty ``x`` the result
        has length ``minlength``.

    Raises:
        ValueError: ``minlength`` is negative, ``x`` contains a negative
            value, or ``weights`` and ``x`` differ in length.

    Inputs that are not exactly 1-D are rejected through ``compile_error``.
    """
    x = asarray(x).astype(int)

    if x.ndim > 1:
        compile_error("object too deep for desired array")
    elif x.ndim < 1:
        compile_error("object of too small depth for desired array")

    if minlength < 0:
        raise ValueError("'minlength' must not be negative")

    # Only inspect min/max when there is data: an empty input has no
    # extrema, and its result is just `minlength` zero bins.
    max_val = minlength
    if x.size > 0:
        mn, mx = x._minmax()
        if mn < 0:
            raise ValueError("'list' argument must have no negative elements")
        if mx + 1 > max_val:
            max_val = mx + 1

    if weights is None:
        result = zeros(max_val, int)
        for i in range(len(x)):
            result._ptr((x._ptr((i, ))[0], ))[0] += 1
        return result
    else:
        weights = asarray(weights).astype(float)

        if weights.ndim > 1:
            compile_error("object too deep for desired array")
        elif weights.ndim < 1:
            compile_error("object of too small depth for desired array")

        if len(weights) != len(x):
            raise ValueError(
                "The weights and list don't have the same length.")

        result = zeros(max_val, float)
        for i in range(len(x)):
            result._ptr((x._ptr((i, ))[0], ))[0] += weights._ptr((i, ))[0]
        return result

def _monotonicity(bins):
    """
    Classify a 1-D sequence of bin edges.

    Returns 1 when no element is smaller than its predecessor, -1 when
    that fails but no element is larger than its predecessor, and 0
    otherwise. Inputs that are not 1-D are rejected through
    ``compile_error``.
    """
    if bins.ndim < 1:
        compile_error("object of too small depth for desired array")
    elif bins.ndim > 1:
        compile_error("object too deep for desired array")

    may_rise = True
    may_fall = True

    pos = 0
    last = len(bins) - 1
    while pos < last:
        cur = bins._ptr((pos, ))[0]
        nxt = bins._ptr((pos + 1, ))[0]
        if cur < nxt:
            may_fall = False
        if nxt < cur:
            may_rise = False
        # Both directions ruled out: no need to look further.
        if not may_rise and not may_fall:
            return 0
        pos += 1

    # At least one flag is still set here; "rising" takes precedence.
    return 1 if may_rise else -1

def digitize(x, bins, right: bool = False):
    """
    Return the indices of the bins to which each value of ``x`` belongs.

    ``bins`` must be monotonic (see ``_monotonicity``). Decreasing bins are
    handled by searching the reversed edges and subtracting the result
    from ``len(bins)``.

    Raises:
        ValueError: ``bins`` is neither increasing nor decreasing.

    Complex ``x`` is rejected through ``compile_error``.
    """
    x = asarray(x)
    bins = asarray(bins)

    if x.dtype is complex or x.dtype is complex64:
        compile_error("x may not be complex")

    direction = _monotonicity(bins)
    if direction == 0:
        raise ValueError("bins must be monotonically increasing or decreasing")

    # this is backwards because the arguments below are swapped
    if right:
        side = 'left'
    else:
        side = 'right'

    if direction != -1:
        return searchsorted(bins, x, side=side)

    # reverse the bins, and invert the results
    return len(bins) - searchsorted(bins[::-1], x, side=side)

def _ravel_and_check_weights(a, weights):
    """
    Flatten ``a`` and ``weights`` after checking that their shapes match.

    Returns ``(flat_a, flat_weights)``; ``flat_weights`` is ``weights``
    itself when no weights were given. A ``bool`` array is converted to
    ``int`` before flattening.

    Raises:
        ValueError: ``weights`` does not have the same shape as ``a``.
    """
    arr = asarray(a)

    if weights is None:
        flat_w = weights
    else:
        w_arr = asarray(weights)
        if w_arr.shape != arr.shape:
            raise ValueError('weights should have the same shape as a.')
        flat_w = w_arr.ravel()

    # Ensure that the array is a "subtractable" dtype
    if arr.dtype is bool:
        # TODO: change to unsigned int after
        flat_a = arr.astype(int).ravel()
    else:
        flat_a = arr.ravel()

    return flat_a, flat_w

# Alias for the builtin `range`: several functions below take a parameter
# named `range`, which hides the builtin inside their bodies.
_range = range

def _unsigned_subtract(a, b):
    """
    Compute ``a - b`` with signed-integer results reinterpreted as unsigned.

    Parameters:
        a: An ``ndarray`` or a scalar.
        b: A scalar.

    Both operands are first converted to their coerced common type. For
    ``int`` the result is ``u64``; for ``Int[N]`` it is ``UInt[N]``; for
    any other type the plain difference is returned.
    """
    if isinstance(a, ndarray):
        dt = type(util.coerce(a.dtype, type(b)))
        a = a.astype(dt)
        b = util.cast(b, dt)

        if dt is int:
            c = a - b
            return c.astype(u64)
        elif isinstance(dt, Int):
            # Same approach as the `int` case above: subtract element-wise,
            # then convert the array. `UInt[N](...)` is only used on
            # scalars elsewhere in this function, so it is not applied to
            # the array here.
            c = a - b
            return c.astype(UInt[dt.N])
        else:
            return a - b
    else:
        dt = type(util.coerce(type(a), type(b)))
        a = util.cast(a, dt)
        b = util.cast(b, dt)

        if dt is int:
            return u64(a) - u64(b)
        elif isinstance(dt, Int):
            return UInt[dt.N](a) - UInt[dt.N](b)
        else:
            return a - b

def _get_outer_edges(a, range):
    """
    Determine the outer bin edges ``(first_edge, last_edge)``.

    Parameters:
        a: Array of data; only consulted when ``range`` is ``None``.
        range: Optional pair ``(min, max)``; both are converted to float.

    Returns:
        The pair of edges. With no ``range`` and empty data the edges are
        0 and 1. Equal edges are widened by 0.5 on each side.

    Raises:
        ValueError: The supplied range has ``min > max``, or the supplied
            or autodetected edges are not finite.
    """
    if range is not None:
        first_edge = float(range[0])
        last_edge = float(range[1])
        if first_edge > last_edge:
            raise ValueError('max must be larger than min in range parameter.')
        if not (util.isfinite(first_edge) and util.isfinite(last_edge)):
            raise ValueError(
                f"supplied range of [{first_edge}, {last_edge}] is not finite")
    elif a.size == 0:
        # handle empty arrays. Can't determine range, so use 0-1.
        if not (a.dtype is complex or a.dtype is complex64):
            first_edge, last_edge = 0.0, 1.0
        else:
            first_edge, last_edge = 0 + 0j, 1 + 0j
    else:
        t = a._minmax()
        if not (a.dtype is complex or a.dtype is complex64):
            first_edge, last_edge = float(t[0]), float(t[1])
        else:
            # Reuse the extrema computed above instead of scanning again.
            first_edge, last_edge = t
        if not (util.isfinite(first_edge) and util.isfinite(last_edge)):
            raise ValueError(
                f"autodetected range of [{first_edge}, {last_edge}] is not finite"
            )

    # expand empty range to avoid divide by zero
    if first_edge == last_edge:
        first_edge = first_edge - 0.5
        last_edge = last_edge + 0.5

    return first_edge, last_edge

def _ptp(x):
    """Peak-to-peak span of ``x``: its maximum minus its minimum, computed
    with ``_unsigned_subtract``."""
    lowest, highest = x._minmax()
    return _unsigned_subtract(highest, lowest)

def _hist_bin_sqrt(x):
    """Square-root bin width: data span divided by ``sqrt(x.size)``."""
    span = float(_ptp(x))
    return span / util.sqrt(float(x.size))

def _hist_bin_sturges(x, range):
    """Sturges bin width: data span divided by ``log2(x.size) + 1``.

    ``range`` is accepted for a uniform estimator signature and is unused.
    """
    span = float(_ptp(x))
    return span / (util.log2(float(x.size)) + 1.0)

def _hist_bin_rice(x):
    """Rice bin width: data span divided by ``2 * x.size ** (1/3)``."""
    span = float(_ptp(x))
    return span / (2.0 * x.size**(1.0 / 3))

def _hist_bin_scott(x):
    """Scott bin width: ``(24 * sqrt(pi) / x.size) ** (1/3)`` times the
    standard deviation of ``x``."""
    scale = (24.0 * util.PI**0.5 / x.size)**(1.0 / 3.0)
    return scale * std(x)

def _hist_bin_stone(x, range, histogram):
    """
    Bin width chosen by minimising ``jhat`` over candidate bin counts.

    Parameters:
        x: Data array.
        range: Range forwarded to ``histogram`` for each candidate.
        histogram: The histogram function to evaluate candidates with
            (passed in by the caller).

    Returns:
        The data span divided by the best bin count, or 0.0 when there is
        at most one element or the span is zero.
    """
    n = x.size
    ptp_x = float(_ptp(x))
    if n <= 1 or ptp_x == 0:
        return 0.0

    # Score for a given bin count, computed from the bin width and the
    # per-bin proportions; lower is better.
    def jhat(nbins):
        hh = ptp_x / nbins
        p_k = histogram(x, bins=nbins, range=range)[0] / n
        return (2 - (n + 1) * p_k.dot(p_k)) / hh

    # Candidates are 1 .. max(100, floor(sqrt(n))).
    nbins_upper_bound = max(100, int(util.sqrt(float(n))))
    min_nbins = 1.
    min_value = float('inf')

    # `_range` is the builtin range; the `range` parameter hides it here.
    for nbins in _range(1, nbins_upper_bound + 1):
        current_value = jhat(nbins)
        if current_value < min_value:
            min_value = current_value
            min_nbins = nbins

    nbins = min_nbins
    # if nbins == nbins_upper_bound:
    #     warn "The number of bins estimated may be suboptimal."
    return ptp_x / nbins

def _hist_bin_doane(x):
    """
    Doane bin width.

    Returns the data span divided by
    ``1 + log2(x.size) + log2(1 + |g1| / sg1)``, where ``g1`` is the mean
    of the cubed standardised data. Returns 0.0 when ``x`` has at most two
    elements or its standard deviation is not positive.
    """
    if x.size <= 2:
        return 0.0

    sg1 = util.sqrt(6.0 * (x.size - 2) / ((x.size + 1.0) * (x.size + 3)))
    sigma = std(x)
    # Written as a negated `>` so that a NaN deviation also yields 0.0.
    if not sigma > 0.0:
        return 0.0

    # Standardise and cube in place to avoid extra temporaries.
    centered = x - mean(x)
    true_divide(centered, sigma, centered)
    power(centered, 3, centered)
    g1 = mean(centered)

    denom = (float(1.0) + util.log2(float(x.size)) +
             util.log2(1.0 + abs(g1) / sg1))
    return float(_ptp(x)) / denom

def _hist_bin_fd(x, range):
    """Freedman-Diaconis bin width: ``2 * IQR * x.size ** (-1/3)``.

    ``range`` is accepted for a uniform estimator signature and is unused.
    """
    quartiles = percentile(x, [75, 25])
    upper = quartiles[0]
    lower = quartiles[1]
    iqr = subtract(upper, lower)
    return 2.0 * iqr * x.size**(-1.0 / 3.0)

def _hist_bin_auto(x, range):
    """Smaller of the Freedman-Diaconis and Sturges widths, falling back
    to Sturges when the Freedman-Diaconis width is zero."""
    fd_width = _hist_bin_fd(x, range)
    sturges_width = _hist_bin_sturges(x, range)
    if not fd_width:
        # limited variance, so we return a len dependent bw estimator
        return sturges_width
    return min(fd_width, sturges_width)

def _diff(a):
    """Differences between consecutive elements of a 1-D array; the result
    has one element fewer than the input and the same dtype."""
    a = asarray(a)
    if a.ndim != 1:
        compile_error("[internal error] expected 1-d array")

    out = empty(a.size - 1, dtype=a.dtype)
    for k in range(out.size):
        nxt = a._ptr((k + 1, ))[0]
        cur = a._ptr((k, ))[0]
        out.data[k] = nxt - cur
    return out

def _get_bin_edges(a, bins, range, weights, histogram):
    """
    Compute the bin edges for a histogram.

    Parameters:
        a: Flattened data array.
        bins: An estimator name (``str``), a scalar bin count, or a 1-D
            array of edges.
        range: Optional ``(min, max)`` pair forwarded to
            ``_get_outer_edges``.
        weights: Optional weights; only checked to reject estimator names
            combined with weighted data.
        histogram: Histogram function handed to the ``'stone'`` estimator.

    Returns:
        ``(bin_edges, (first_edge, last_edge, n_equal_bins))`` for equal
        width bins, or ``(bin_edges, None)`` when explicit edges were
        given.

    Raises:
        ValueError: Unknown estimator name, a bin count below 1, or
            explicit edges that are not non-decreasing.
    """

    # Returns a *value* of the edge dtype: integer-like coerced types
    # become float, anything else is kept.
    def get_bin_type(first_edge, last_edge, a):
        T1 = type(util.coerce(type(first_edge), type(last_edge)))
        T2 = type(util.coerce(T1, a.dtype))
        if T2 is int or T2 is byte or isinstance(T2, Int) or isinstance(
                T2, UInt):
            return float()
        else:
            return T2()

    # Builds the return value for both the equal-width and explicit cases.
    def bin_edges_result(a,
                         bin_edges=None,
                         n_equal_bins=None,
                         first_edge=None,
                         last_edge=None):
        if n_equal_bins is not None:
            bin_type = type(get_bin_type(first_edge, last_edge, a))

            return (linspace(first_edge,
                             last_edge,
                             n_equal_bins + 1,
                             endpoint=True,
                             dtype=bin_type), (first_edge, last_edge,
                                               n_equal_bins))
        else:
            return bin_edges, None

    if isinstance(bins, str):
        bin_name = bins
        # if `bins` is a string for an automatic method,
        # this will replace it with the number of bins calculated
        if bin_name not in ('stone', 'auto', 'doane', 'fd', 'rice', 'scott',
                            'sqrt', 'sturges'):
            raise ValueError(f"{bin_name} is not a valid estimator for `bins`")
        if weights is not None:
            compile_error(
                "Automated estimation of the number of bins is not supported for weighted data"
            )

        tmp1, tmp2 = _get_outer_edges(a, range)
        if not (tmp1 == 0 + 0j and tmp2 == 1 + 0j):
            first_edge, last_edge = tmp1, tmp2
        else:
            first_edge, last_edge = 0, 1

        # truncate the range if needed
        if range is not None:
            keep = (a >= first_edge)
            keep &= (a <= last_edge)
            if not logical_and.reduce(keep):
                a = a[keep]

        if a.size == 0:
            n_equal_bins = 1
        else:
            # Do not call selectors on empty arrays.
            # One chain: the estimator names are mutually exclusive.
            if bin_name == 'stone':
                width = _hist_bin_stone(a, (first_edge, last_edge), histogram)
            elif bin_name == 'auto':
                width = _hist_bin_auto(a, (first_edge, last_edge))
            elif bin_name == 'fd':
                width = _hist_bin_fd(a, (first_edge, last_edge))
            elif bin_name == 'doane':
                width = _hist_bin_doane(a)
            elif bin_name == 'rice':
                width = _hist_bin_rice(a)
            elif bin_name == 'scott':
                width = _hist_bin_scott(a)
            elif bin_name == 'sqrt':
                width = _hist_bin_sqrt(a)
            elif bin_name == 'sturges':
                width = _hist_bin_sturges(a, (first_edge, last_edge))

            if width:
                n_equal_bins = int(
                    ceil(_unsigned_subtract(last_edge, first_edge) / width))
            else:
                # Width can be zero for some estimators, e.g. FD when the IQR of the data is zero.
                n_equal_bins = 1

        return bin_edges_result(a,
                                n_equal_bins=n_equal_bins,
                                first_edge=first_edge,
                                last_edge=last_edge)

    bins = asarray(bins)

    if bins.ndim == 0:
        # Scalar bin count: equal-width bins over the outer edges.
        n_equal_bins = bins.item()
        if n_equal_bins < 1:
            raise ValueError('`bins` must be positive, when an integer')

        tmp1, tmp2 = _get_outer_edges(a, range)
        if not (tmp1 == 0 + 0j and tmp2 == 1 + 0j):
            first_edge, last_edge = tmp1, tmp2
        else:
            first_edge, last_edge = 0, 1
        return bin_edges_result(a,
                                n_equal_bins=n_equal_bins,
                                first_edge=first_edge,
                                last_edge=last_edge)

    if bins.ndim == 1:
        # Explicit edges: only validate the ordering.
        bin_edges = asarray(bins)

        i = 0
        while i < bin_edges.size - 1:  # range is shadowed by function arg
            if bin_edges._ptr((i, ))[0] > bin_edges._ptr((i + 1, ))[0]:
                raise ValueError(
                    '`bins` must increase monotonically, when an array')
            i += 1

        return bin_edges_result(a, bin_edges=bin_edges)

def _search_sorted_inclusive(a, v):
    """Like ``searchsorted``, but the last item of ``v`` is located with
    ``'right'`` while all earlier items use ``'left'``."""
    leading = a.searchsorted(v[:-1], 'left')
    closing = a.searchsorted(v[-1:], 'right')
    return concatenate((leading, closing))

def _histogram_fast(a, bins=10, range=None):
    """
    Unweighted, non-density histogram computed in a single pass.

    Parameters:
        a: Array-like data.
        bins: A positive ``int`` bin count, or a 1-D array of
            non-decreasing bin edges.
        range: Optional ``(min, max)`` pair; when ``None`` the data
            min/max is used (or 0 and 1 for empty data).

    Returns:
        ``(hist, edges)`` where ``hist`` is an ``int`` array of counts.

    Raises:
        ValueError: A non-positive bin count, a non-finite or inverted
            range, or edges that are not non-decreasing.
    """
    # Adapted from Numba's implementation
    # https://github.com/numba/numba/blob/main/numba/np/old_arraymath.py

    # Returns a zero *value* of the working type: float32/float16 data
    # keep their own type, everything else uses float. For array `bins`
    # the decision is made on the bins' element type.
    def intermediate_type(bins, dtype: type):
        if bins is None or isinstance(bins, int):
            if dtype is float32 or dtype is float16:
                return util.zero(dtype)
            else:
                return util.zero(float)
        else:
            return intermediate_type(None, type(asarray(bins).data[0]))

    a = asarray(a)
    T = type(intermediate_type(bins, a.dtype))

    if range is None:
        # Determine the range from the data, then re-enter with it.
        if a.size == 0:
            bin_min = util.cast(0, T)
            bin_max = util.cast(1, T)
        else:
            inf = util.inf(T)
            bin_min = inf
            bin_max = -inf
            for idx in util.multirange(a.shape):
                v = util.cast(a._ptr(idx)[0], T)
                if bin_min > v:
                    bin_min = v
                if bin_max < v:
                    bin_max = v
        return _histogram_fast(a, bins, range=(bin_min, bin_max))

    if isinstance(bins, int):
        if bins <= 0:
            raise ValueError("`bins` must be positive, when an integer")

        bin_min, bin_max = range
        bin_min = util.cast(bin_min, T)
        bin_max = util.cast(bin_max, T)

        if not (util.isfinite(bin_min) and util.isfinite(bin_max)):
            raise ValueError(f"supplied range of [{bin_min}, {bin_max}] is not finite")

        if not bin_min <= bin_max:
            raise ValueError("max must be larger than min in range parameter.")

        hist = zeros(bins, int)

        # Widen a zero-width range so the ratio below is finite.
        if bin_min == bin_max:
            bin_min -= T(0.5)
            bin_max += T(0.5)

        bin_ratio = util.cast(bins, T) / (bin_max - bin_min)
        for idx in util.multirange(a.shape):
            v = util.cast(a._ptr(idx)[0], T)
            b = int(util.floor((v - bin_min) * bin_ratio))
            if 0 <= b < bins:
                hist.data[b] += 1
            elif v == bin_max:
                # The right edge belongs to the last bin.
                hist.data[bins - 1] += 1

        bins_array = linspace(float(bin_min), float(bin_max), bins + 1, dtype=T)
        return hist, bins_array
    else:
        bins = asarray(bins)
        if bins.ndim != 1:
            compile_error("`bins` must be 1d, when an array")

        nbins = len(bins) - 1
        i = 0
        while i < nbins:
            # Note this also catches NaNs
            if not bins._ptr((i, ))[0] <= bins._ptr((i + 1, ))[0]:
                raise ValueError("`bins` must increase monotonically, when an array")
            i += 1

        bin_min = util.cast(bins._ptr((0, ))[0], T)
        bin_max = util.cast(bins._ptr((nbins, ))[0], T)
        hist = zeros(nbins, int)

        if nbins > 0:
            for idx in util.multirange(a.shape):
                v = util.cast(a._ptr(idx)[0], T)
                if not bin_min <= v <= bin_max:
                    # Value is out of bounds, ignore (also catches NaNs)
                    continue
                # Bisect in bins[:-1]
                lo = 0
                hi = nbins - 1
                while lo < hi:
                    # Note the `+ 1` is necessary to avoid an infinite
                    # loop where mid = lo => lo = mid
                    mid = (lo + hi + 1) >> 1
                    if v < util.cast(bins._ptr((mid, ))[0], T):
                        hi = mid - 1
                    else:
                        lo = mid
                hist.data[lo] += 1

        return hist, bins

def histogram(a,
              bins=10,
              range=None,
              density: Static[int] = False,
              weights=None):
    """
    Compute the histogram of a dataset.

    Parameters:
        a: Array-like data; flattened before binning.
        bins: Bin count, array of bin edges, or the name of a bin-width
            estimator.
        range: Optional ``(min, max)`` pair limiting the binned values.
        density: When true, counts are divided by the bin widths and by
            the total count.
        weights: Optional array-like with the same shape as ``a``.

    Returns:
        ``(hist, bin_edges)``.

    Errors raised by ``_ravel_and_check_weights``, ``_get_bin_edges`` and
    ``_histogram_fast`` propagate to the caller.
    """

    # Zero accumulator: int counts without weights, otherwise the weights'
    # dtype.
    def return_zeros(size, weights):
        if weights is None:
            return zeros(size, dtype=int)
        else:
            return zeros(size, dtype=weights.dtype)

    def histogram_result(n, bin_edges, density: Static[int] = False):
        if density:
            db = array(_diff(bin_edges), float)
            return (n / db / n.sum(), bin_edges)
        else:
            return (n, bin_edges)

    a = asarray(a)

    # Plain counting of non-complex data takes the single-pass path.
    if (not density and
        weights is None and
        not isinstance(bins, str) and
        not (a.dtype is complex or a.dtype is complex64)):
        return _histogram_fast(a, bins=bins, range=range)

    a, weights = _ravel_and_check_weights(a, weights)

    bin_edges, uniform_bins = _get_bin_edges(a, bins, range, weights,
                                             histogram)

    # We set a block size, as this allows us to iterate over chunks when computing histograms, to minimize memory usage.
    BLOCK = 65536

    # The fast path uses bincount, but that only works for certain types of weight
    simple_weights1: Static[int] = weights is None
    # `simple_weights2` is only defined when weights were given; the
    # condition below reads it only after `simple_weights1` is false.
    if isinstance(weights, ndarray):
        simple_weights2: Static[int] = (weights.dtype is float
                                        or weights.dtype is complex
                                        or weights.dtype is complex64)

    if uniform_bins is not None and (simple_weights1 or simple_weights2):
        # Fast algorithm for equal bins
        # We now convert values of a to bin indices, under the assumption of equal bin widths (which is valid here).
        first_edge, last_edge, n_equal_bins = uniform_bins

        # Initialize empty histogram
        n = return_zeros(n_equal_bins, weights)

        # Pre-compute histogram scaling factor
        norm_numerator = n_equal_bins
        norm_denom = _unsigned_subtract(last_edge, first_edge)

        # We iterate over blocks
        for i in _range(0, len(a), BLOCK):
            tmp_a = a[i:i + BLOCK]
            if weights is None:
                tmp_w = None
            else:
                tmp_w = weights[i:i + BLOCK]

            # Only include values in the right range
            keep = (tmp_a >= first_edge) & (tmp_a <= last_edge)
            if not logical_and.reduce(keep):
                tmp_a = tmp_a[keep]
                if tmp_w is not None:
                    tmp_w = tmp_w[keep]

            # This cast ensures no type promotions occur below, which gh-10322 make unpredictable
            tmp_a = tmp_a.astype(bin_edges.dtype, copy=False)

            # Compute the bin indices, and for values that lie exactly on last_edge we need to subtract one
            f_indices = ((_unsigned_subtract(tmp_a, first_edge) / norm_denom) *
                         norm_numerator)
            indices = f_indices.astype(int)
            indices[indices == n_equal_bins] -= 1

            # Correct indices that the scaled computation placed one bin
            # too high or too low relative to the actual edges.
            decrement = tmp_a < bin_edges[indices]
            indices[decrement] -= 1
            # The last bin includes the right edge. The other bins do not.
            increment = ((tmp_a >= bin_edges[indices + 1])
                         & (indices != n_equal_bins - 1))
            indices[increment] += 1

            if weights is None:
                n += bincount(indices, weights=tmp_w,
                              minlength=n_equal_bins).astype(int)
            elif weights.dtype is complex or weights.dtype is complex64:
                # bincount converts weights to float, so real and imaginary
                # parts are accumulated separately.
                n.real += bincount(indices,
                                   weights=tmp_w.real,
                                   minlength=n_equal_bins)
                n.imag += bincount(indices,
                                   weights=tmp_w.imag,
                                   minlength=n_equal_bins)
            else:
                n += bincount(indices, weights=tmp_w,
                              minlength=n_equal_bins).astype(weights.dtype)
    else:
        # Compute via cumulative histogram
        cum_n = return_zeros(bin_edges.shape, weights)
        if weights is None:
            for i in _range(0, len(a), BLOCK):
                sa = sort(a[i:i + BLOCK])
                cum_n += _search_sorted_inclusive(sa, bin_edges)
        else:
            zero = return_zeros(1, weights)
            for i in _range(0, len(a), BLOCK):
                tmp_a = a[i:i + BLOCK]
                tmp_w = weights[i:i + BLOCK]
                sorting_index = argsort(tmp_a)
                sa = tmp_a[sorting_index]
                sw = tmp_w[sorting_index]
                # Cumulative weights with a leading zero, indexed by the
                # position of each edge in the sorted block.
                cw = concatenate((zero, sw.cumsum()))
                bin_index = _search_sorted_inclusive(sa, bin_edges)
                cum_n += cw[bin_index]

        # Per-bin totals are the differences of the cumulative totals.
        r = _diff(cum_n)
        return histogram_result(r, bin_edges, density)

    return histogram_result(n, bin_edges, density)

def histogram_bin_edges(a, bins=10, range=None, weights=None):
    """Return only the bin edges that ``histogram`` would use for the
    given data, ``bins``, ``range`` and ``weights``."""
    flat_a, flat_w = _ravel_and_check_weights(a, weights)
    edges_and_layout = _get_bin_edges(flat_a, bins, range, flat_w, histogram)
    return edges_and_layout[0]