# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>

import bisect
import random
from math import frexp as _frexp, floor as _floor, fabs as _fabs, erf as _erf
from math import (
    gcd as _gcd,
    exp as _exp,
    log as _log,
    sqrt as _sqrt,
    tau as _tau,
    hypot as _hypot,
)

class StatisticsError(Static[Exception]):
    """
    Exception type raised by the functions in this module when the supplied
    data or arguments are not acceptable (for example, empty data).
    """
    def __init__(self, message: str = ""):
        # Forward the literal type name and the caller's message to the
        # base-class initializer.
        super().__init__("StatisticsError", message)

def median(data: List[T], T: type) -> float:
    """
    Compute the median (middle value) of numeric data.

    The input is copied into a sorted list, so the caller's list is left
    untouched. With an odd number of points the middle point is returned
    as a float; with an even number the two central points are averaged.

    Raises StatisticsError if data is empty.
    """
    ordered = sorted(data)
    size = len(ordered)
    if size == 0:
        raise StatisticsError("no median for empty data")

    mid = size // 2
    if size % 2 == 0:
        # Even count: interpolate between the two central values.
        return (ordered[mid - 1] + ordered[mid]) / 2
    return float(ordered[mid])

def median_low(data: List[T], T: type) -> float:
    """
    Compute the low median of numeric data.

    For an odd number of points this is the middle point; for an even
    number it is the smaller of the two central points. The result is
    always one of the data points, converted to float.

    Raises StatisticsError if data is empty.
    """
    ordered = sorted(data)
    size = len(ordered)
    if size == 0:
        raise StatisticsError("no median for empty data")
    # (size - 1) // 2 is the middle index for odd sizes and the lower of
    # the two central indices for even sizes.
    return float(ordered[(size - 1) // 2])

def median_high(data: List[T], T: type) -> float:
    """
    Compute the high median of data.

    For an odd number of points this is the middle point; for an even
    number it is the larger of the two central points. The result is
    always one of the data points, converted to float.

    Raises StatisticsError if data is empty.
    """
    ordered = sorted(data)
    if len(ordered) == 0:
        raise StatisticsError("no median for empty data")
    # Integer division lands on the middle index (odd size) or on the
    # upper of the two central indices (even size).
    upper_middle = len(ordered) // 2
    return float(ordered[upper_middle])

def _find_lteq(a: List[T], x: float, T: type):
    """
    Locate the leftmost value exactly equal to x.

    a must already be sorted. Returns the index of the first element that
    compares equal to x.

    Raises ValueError if no element of a equals x (previously this was
    signalled with a failing assert, which is not a reliable way to
    report a runtime condition).
    """
    i = bisect.bisect_left(a, x)
    # bisect_left returns the insertion point; x is present only if that
    # slot exists and actually holds x.
    if i != len(a) and a[i] == x:
        return i
    raise ValueError("value not found in sorted data")

def _find_rteq(a: List[T], l: int, x: float, T: type):
    """
    Locate the rightmost value exactly equal to x, searching a[l:].

    a must already be sorted. Returns the index of the last element at
    position >= l that compares equal to x.

    Raises ValueError if no element of a[l:] equals x.
    """
    i = bisect.bisect_right(a, x, lo=l)
    # bisect_right returns a position in l..len(a). Only i > l means that
    # some element inside the searched window is <= x; with i == l the
    # element at i - 1 lies outside the window (or wraps to a[-1] when
    # l == 0) and must not be inspected.
    if i > l and a[i - 1] == x:
        return i - 1
    raise ValueError("value not found in sorted data")

def median_grouped(data: List[T], interval: S = 1, T: type, S: type = int) -> float:
    """
    Compute the 50th percentile (median) of grouped continuous data.

    Each data point is treated as the midpoint of a class of width
    interval. The median is interpolated inside the class that contains
    the middle data point.

    Raises StatisticsError if data is empty. A single data point is
    returned as-is (converted to float).
    """
    ordered = sorted(data)
    count = len(ordered)
    if count == 0:
        raise StatisticsError("no median for empty data")
    if count == 1:
        return float(ordered[0])

    # Value at the midpoint and the lower limit of its class interval.
    midpoint = float(ordered[count // 2])
    lower_limit = midpoint - float(interval) / 2

    # First and last positions holding the midpoint value; the search for
    # the last one starts at the first, so first <= last.
    first = _find_lteq(ordered, midpoint)
    last = _find_rteq(ordered, first, midpoint)

    below = first             # number of points below the median class
    in_class = last - first + 1  # number of points inside the median class
    return lower_limit + interval * (count / 2 - below) / in_class

def mode(data: List[T], T: type) -> T:
    """
    Return the most common data point from discrete or nominal data.

    If several values share the highest frequency, the one that appears
    first in data is returned.

    Raises StatisticsError if data is empty.
    """
    if len(data) == 0:
        raise StatisticsError("no mode for empty data")

    best = data[0]
    best_count = 0
    # Counting uses only equality (list.count), so element types do not
    # need to be hashable or orderable.
    for candidate in data:
        occurrences = data.count(candidate)
        # Strict comparison keeps the first-seen value on ties.
        if occurrences > best_count:
            best_count = occurrences
            best = candidate
    return best

def multimode(data: List[T], T: type):
    """
    Return a list of the most frequently occurring values.

    Will return more than one result if there are multiple modes
    or an empty list if *data* is empty. The modes are returned in
    ascending order.
    """
    ordered = sorted(data)
    total = len(ordered)
    modes = []
    best = 0

    # Equal values are adjacent after sorting, so each distinct value is
    # counted once by measuring the length of its run. An empty input
    # never enters the loop and yields an empty result.
    start = 0
    while start < total:
        stop = start + 1
        while stop < total and ordered[stop] == ordered[start]:
            stop += 1
        run = stop - start
        if run > best:
            # New highest frequency: discard the modes collected so far.
            best = run
            modes = [ordered[start]]
        elif run == best:
            modes.append(ordered[start])
        start = stop
    return modes

def quantiles(
    data: List[T], n: int = 4, method: str = "exclusive", T: type
) -> List[float]:
    """
    Split *data* into *n* continuous intervals of equal probability.

    Returns the list of (n - 1) cut points that separate the intervals;
    each cut point is linearly interpolated between two neighbouring
    points of the sorted data. Use n=4 for quartiles (the default), n=10
    for deciles and n=100 for percentiles.

    With method "exclusive" (the default) *data* is treated as a sample.
    With method "inclusive" *data* is treated as a whole population: the
    minimum is the 0th percentile and the maximum the 100th percentile.

    Raises StatisticsError if n < 1 or if data has fewer than two points,
    and ValueError if method is neither "inclusive" nor "exclusive".
    """
    if n < 1:
        raise StatisticsError("n must be at least 1")
    ordered = sorted(data)
    size = len(ordered)
    if size < 2:
        raise StatisticsError("must have at least two data points")

    if method == "inclusive":
        span = size - 1
        cuts = []
        for k in range(1, n):
            idx = k * span // n
            frac = k * span - idx * n  # remainder, kept as an exact integer
            cuts.append((ordered[idx] * (n - frac) + ordered[idx + 1] * frac) / n)
        return cuts
    elif method == "exclusive":
        span = size + 1
        cuts = []
        for k in range(1, n):
            # Rescale k to span/n, then keep the index inside 1 .. size-1
            # so both neighbours exist.
            idx = max(1, min(k * span // n, size - 1))
            frac = k * span - idx * n  # exact integer math
            cuts.append((ordered[idx - 1] * (n - frac) + ordered[idx] * frac) / n)
        return cuts
    else:
        raise ValueError(f"Unknown method: {method}")

def _lcm(x: int, y: int):
    """
    Returns the lowest common multiple between x and y.

    The result is non-negative. If either argument is zero the result is
    zero. Computed from the greatest common divisor instead of searching
    upward one candidate at a time.
    """
    if x == 0 or y == 0:
        return 0
    # Divide before multiplying to keep the intermediate value small; the
    # division is exact because the gcd divides x.
    return abs(x // _gcd(x, y) * y)

def _sum(data: List[float]) -> float:
    """
    Return a compensated (Neumaier) sum of a list of floats.

    A running compensation term collects the low-order bits that are lost
    each time a value is added to the running total; it is added back at
    the end. An empty list sums to 0.0.

    TODO/CAVEATS
      - There is no ``start`` argument
      - Assumes input is floats
    """
    # Neumaier sum
    # https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Further_enhancements
    # https://www.mat.univie.ac.at/~neum/scan/01.pdf (German)
    total = 0.0
    compensation = 0.0
    for value in data:
        partial = total + value
        if abs(total) >= abs(value):
            # Low-order digits of value were lost in the addition.
            compensation += (total - partial) + value
        else:
            # Low-order digits of the running total were lost instead.
            compensation += (value - partial) + total
        total = partial
    return total + compensation

def mean(data: List[float]) -> float:
    """
    Compute the sample arithmetic mean of data.

    Raises StatisticsError if data is empty.

    TODO/CAVEATS
      - Assumes input is floats
      - Does not address NAN or INF
    """
    if len(data) < 1:
        raise StatisticsError("mean requires at least one data point")
    return _sum(data) / len(data)

'''
def fmean(data: List[float]) -> float:
    """
    Convert data to floats and compute the arithmetic mean.

    TODO/CAVEATS
    @jordan- fsum is not implemented in math.seq yet and the above
             mean(data) deals with only floats. Thus this function is passed for now.
    """
    pass
'''

def geometric_mean(data: List[float]) -> float:
    """
    Compute the geometric mean of data.

    The result is the exponential of the arithmetic mean of the natural
    logarithms of the data points.

    Raises a StatisticsError if the input dataset is empty.

    TODO/CAVEATS:
      - Assumes input is a list of floats
      - Uses mean instead of fmean for now
      - Does not handle data that contains a zero, or if it contains a negative value.
    """
    if len(data) < 1:
        raise StatisticsError("geometric mean requires a non-empty dataset")
    logs = [_log(value) for value in data]
    return _exp(mean(logs))

def _fail_neg(values: List[float], errmsg: str):
    """
    Generator that passes values through unchanged, one at a time.

    Raises StatisticsError(errmsg) as soon as a value below zero is
    reached; values before it have already been yielded.
    """
    for value in values:
        if not value < 0:
            yield value
            continue
        raise StatisticsError(errmsg)

def harmonic_mean(data: List[float]) -> float:
    """
    Return the harmonic mean of data.

    The harmonic mean, sometimes called the subcontrary mean, is the
    reciprocal of the arithmetic mean of the reciprocals of the data,
    and is often appropriate when averaging quantities which are rates
    or ratios.

    If a zero is reached while scanning the data, the result is 0.0.

    Raises StatisticsError if data is empty or if a negative value is
    encountered, including when data consists of a single negative value.
    """
    errmsg = "harmonic mean does not support negative values"
    n = len(data)
    if n < 1:
        raise StatisticsError("harmonic_mean requires at least one data point")

    if n == 1:
        # Return the value itself instead of 1 / (1 / x), which could
        # differ by a rounding error. The negativity check must still
        # apply on this shortcut.
        only = data[0]
        if only < 0:
            raise StatisticsError(errmsg)
        return only

    reciprocals = List[float](n)
    for x in _fail_neg(data, errmsg):
        if x == 0.0:
            # A zero makes the sum of reciprocals infinite.
            return 0.0
        reciprocals.append(1 / x)
    return n / _sum(reciprocals)

def _ss(data: List[float], c: float):
    """
    Return the sum of squared deviations of data about c.

    The raw sum of (x - c)**2 is corrected by subtracting the square of
    the summed deviations divided by the number of points; that
    correction is zero when c is exactly the mean and otherwise
    compensates for the offset of c from the mean.

    data must not be empty (the correction divides by len(data)).
    """
    deviations = [x - c for x in data]
    squared_total = _sum([d ** 2 for d in deviations])
    deviation_total = _sum(deviations)
    return squared_total - deviation_total ** 2 / len(data)

def pvariance(data: List[float], mu: Optional[float] = None):
    """
    Return the population variance of `data`.

    data should contain at least one value.
    The optional argument mu, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Raises StatisticsError if data is empty.

    TODO/CAVEATS:
      - Assumes input is a list of floats
    """
    # Validate the size before computing the mean so that empty input is
    # reported with this function's own message.
    n = len(data)
    if n < 1:
        raise StatisticsError("pvariance requires at least one data point")

    if mu is None:
        mu = mean(data)

    ss = _ss(data, mu)
    return ss / n

def pstdev(data: List[float], mu: Optional[float] = None):
    """
    Compute the population standard deviation of data, i.e. the square
    root of the population variance.

    The optional argument mu, if given, is used as the mean of the data;
    otherwise the mean is computed here and passed on.
    """
    if mu is None:
        mu = mean(data)
    return _sqrt(pvariance(data, mu))

def variance(data: List[float], xbar: Optional[float] = None):
    """
    Return the sample variance of data.

    data should contain at least two values.
    The optional argument xbar, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Raises StatisticsError if data has fewer than two values.
    """
    # Validate the size before computing the mean so that empty input is
    # reported with this function's own message.
    n = len(data)
    if n < 2:
        raise StatisticsError("variance requires at least two data points")

    if xbar is None:
        xbar = mean(data)

    ss = _ss(data, xbar)
    # Divide by n - 1 (Bessel's correction) for a sample variance.
    return ss / (n - 1)

def stdev(data, xbar: Optional[float] = None):
    """
    Compute the sample standard deviation of data, i.e. the square root
    of the sample variance.

    The optional argument xbar, if given, is used as the mean of the
    data; otherwise the mean is computed here and passed on.
    """
    if xbar is None:
        xbar = mean(data)
    return _sqrt(variance(data, xbar))

class NormalDist:
    """
    Normal distribution of a random variable, described by its mean (mu)
    and standard deviation (sigma).

    Instances can be built with no arguments (mu=0, sigma=1), with mu
    only (sigma=1), or with both mu and sigma. Two instances are equal
    exactly when their mu and sigma are equal, which is consistent with
    __hash__.
    """

    # Set to 1e-6 by _init; not read by any method of this class. Kept so
    # the attribute remains available to existing users.
    PRECISION: float
    _mu: float
    _sigma: float

    def _init(self, mu: float, sigma: float):
        """
        Shared initializer used by every __init__ overload.

        Raises StatisticsError if sigma is negative.
        """
        self.PRECISION = 1e-6
        if sigma < 0.0:
            raise StatisticsError("sigma must be non-negative")
        self._mu = mu
        self._sigma = sigma

    def __init__(self, mu, sigma):
        """
        Create a distribution with the given mean and standard deviation.
        """
        self._init(float(mu), float(sigma))

    def __init__(self, mu):
        """
        Create a distribution with the given mean and sigma = 1.
        """
        self._init(float(mu), 1.0)

    def __init__(self):
        """
        Create the standard normal distribution (mu = 0, sigma = 1).
        """
        self._init(0.0, 1.0)

    @property
    def mean(self):
        """
        Arithmetic mean of the normal distribution.
        """
        return self._mu

    @property
    def median(self):
        """
        Return the median of the normal distribution
        """
        return self._mu

    @property
    def mode(self):
        """
        Return the mode of the normal distribution

        The mode is the value x at which the probability density
        function (pdf) takes its maximum value.
        """
        return self._mu

    @property
    def stdev(self):
        """
        Standard deviation of the normal distribution.
        """
        return self._sigma

    @property
    def variance(self):
        """
        Square of the standard deviation.
        """
        return self._sigma ** 2.0

    def pdf(self, x):
        """
        Probability density function.  P(x <= X < x+dx) / dx

        Raises StatisticsError if sigma is zero.
        """
        variance = self._sigma ** 2.0
        if not variance:
            raise StatisticsError("pdf() not defined when sigma is zero")
        return _exp((x - self._mu) ** 2.0 / (-2.0 * variance)) / _sqrt(_tau * variance)

    def cdf(self, x):
        """
        Cumulative distribution function.  P(X <= x)

        Raises StatisticsError if sigma is zero.
        """
        if not self._sigma:
            raise StatisticsError("cdf() not defined when sigma is zero")
        return 0.5 * (1.0 + _erf((x - self._mu) / (self._sigma * _sqrt(2.0))))

    def _normal_dist_inv_cdf(self, p: float, mu: float, sigma: float):
        """
        Return the point x with P(X <= x) = p for a normal distribution
        with the given mu and sigma, using rational approximations.

        Callers are expected to pass 0 < p < 1 (inv_cdf checks this).

        Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
        Normal Distribution".  Applied Statistics. Blackwell Publishing. 37
        (3): 477–484. doi:10.2307/2347330. JSTOR 2347330.
        """
        q = p - 0.5
        num = 0.0
        den = 0.0
        if _fabs(q) <= 0.425:
            # Central region: rational approximation in r = 0.180625 - q*q.
            r = 0.180625 - q * q
            # Hash sum: 55.88319_28806_14901_4439
            num = (((((((2.5090809287301226727e+3 * r +
                    3.3430575583588128105e+4) * r +
                    6.7265770927008700853e+4) * r +
                    4.5921953931549871457e+4) * r +
                    1.3731693765509461125e+4) * r +
                    1.9715909503065514427e+3) * r +
                    1.3314166789178437745e+2) * r +
                    3.3871328727963666080e+0) * q
            den = (((((((5.2264952788528545610e+3 * r +
                    2.8729085735721942674e+4) * r +
                    3.9307895800092710610e+4) * r +
                    2.1213794301586595867e+4) * r +
                    5.3941960214247511077e+3) * r +
                    6.8718700749205790830e+2) * r +
                    4.2313330701600911252e+1) * r +
                    1.0)
            x = num / den
            return mu + (x * sigma)
        # Tails: work with the smaller of p and 1 - p and restore the sign
        # at the end.
        r = p if q <= 0.0 else 1.0 - p
        r = _sqrt(-_log(r))
        if r <= 5.0:
            r = r - 1.6
            # Hash sum: 49.33206_50330_16102_89036
            num = (((((((7.74545014278341407640e-4 * r +
                    2.27238449892691845833e-2) * r +
                    2.41780725177450611770e-1) * r +
                    1.27045825245236838258e+0) * r +
                    3.64784832476320460504e+0) * r +
                    5.76949722146069140550e+0) * r +
                    4.63033784615654529590e+0) * r +
                    1.42343711074968357734e+0)
            den = (((((((1.05075007164441684324e-9 * r +
                    5.47593808499534494600e-4) * r +
                    1.51986665636164571966e-2) * r +
                    1.48103976427480074590e-1) * r +
                    6.89767334985100004550e-1) * r +
                    1.67638483018380384940e+0) * r +
                    2.05319162663775882187e+0) * r +
                    1.0)
        else:
            r = r - 5.0
            # Hash sum: 47.52583_31754_92896_71629
            num = (((((((2.01033439929228813265e-7 * r +
                    2.71155556874348757815e-5) * r +
                    1.24266094738807843860e-3) * r +
                    2.65321895265761230930e-2) * r +
                    2.96560571828504891230e-1) * r +
                    1.78482653991729133580e+0) * r +
                    5.46378491116411436990e+0) * r +
                    6.65790464350110377720e+0)
            den = (((((((2.04426310338993978564e-15 * r +
                    1.42151175831644588870e-7) * r +
                    1.84631831751005468180e-5) * r +
                    7.86869131145613259100e-4) * r +
                    1.48753612908506148525e-2) * r +
                    1.36929880922735805310e-1) * r +
                    5.99832206555887937690e-1) * r +
                    1.0)
        x = num / den
        if q < 0.0:
            x = -x
        return mu + (x * sigma)

    def inv_cdf(self, p: float):
        """
        Inverse cumulative distribution function.  x : P(X <= x) = p

        Finds the value of the random variable such that the probability of
        the variable being less than or equal to that value equals the given
        probability.

        Raises StatisticsError if p is not strictly between 0 and 1, or
        if sigma is not positive.
        """
        if p <= 0.0 or p >= 1.0:
            raise StatisticsError("p must be in the range 0.0 < p < 1.0")
        if self._sigma <= 0.0:
            raise StatisticsError("cdf() not defined when sigma at or below zero")
        return self._normal_dist_inv_cdf(p, self._mu, self._sigma)

    def quantiles(self, n: int = 4):
        """
        Divide into *n* continuous intervals with equal probability.

        Returns a list of (n - 1) cut points separating the intervals.

        Set *n* to 4 for quartiles (the default).  Set *n* to 10 for deciles.
        Set *n* to 100 for percentiles which gives the 99 cuts points that
        separate the normal distribution into 100 equal sized groups.
        """
        return [self.inv_cdf(float(i) / float(n)) for i in range(1, n)]

    def overlap(self, other: NormalDist) -> float:
        """
        Compute the overlapping coefficient (OVL) between two normal distributions.

        Measures the agreement between two normal probability distributions.
        Returns a value between 0.0 and 1.0 giving the overlapping area in
        the two underlying probability density functions.

        Raises StatisticsError if either distribution has sigma equal to zero.
        """
        X, Y = self, other

        if (Y._sigma, Y._mu) < (X._sigma, X._mu):  # sort to assure commutativity
            X, Y = Y, X

        X_var, Y_var = X.variance, Y.variance

        if not X_var or not Y_var:
            raise StatisticsError("overlap() not defined when sigma is zero")

        dv = Y_var - X_var
        dm = _fabs(Y._mu - X._mu)

        if not dv:
            # Equal variances: the two density curves cross at one point.
            return 1.0 - _erf(dm / (2.0 * X._sigma * _sqrt(2.0)))

        # Unequal variances: the density curves cross at two points x1, x2.
        a = X._mu * Y_var - Y._mu * X_var
        b = X._sigma * Y._sigma * _sqrt(dm ** 2.0 + dv * _log(Y_var / X_var))
        x1 = (a + b) / dv
        x2 = (a - b) / dv

        return 1.0 - (_fabs(Y.cdf(x1) - X.cdf(x1)) + _fabs(Y.cdf(x2) - X.cdf(x2)))

    def samples(self, n: int):
        """
        Generate *n* samples for a given mean and standard deviation.
        """
        gauss = random.gauss
        return [gauss(self._mu, self._sigma) for i in range(n)]

    def from_samples(data: List[float]):
        """
        Make a normal distribution instance from sample data.

        mu is the mean of data and sigma is its sample standard deviation.

        TODO/CAVEATS:
          - Assumes input is a list of floats
          - Uses mean instead of fmean for now
        """
        xbar = mean(data)
        return NormalDist(xbar, stdev(data, xbar))

    def __add__(x1: NormalDist, x2: NormalDist):
        """
        Add another NormalDist instance: add both the means and the variances.
        Mathematically, this works only if the two distributions are
        independent or if they are jointly normally distributed.
        """
        return NormalDist(x1._mu + x2._mu, _hypot(x1._sigma, x2._sigma))

    def __add__(x1: NormalDist, x2: float):
        """
        Add a constant: translate mu by the constant, leaving sigma unchanged.
        """
        return NormalDist(x1._mu + x2, x1._sigma)

    def __sub__(x1: NormalDist, x2: NormalDist):
        """
        Subtract another NormalDist instance: subtract the means and add
        the variances.
        Mathematically, this works only if the two distributions are
        independent or if they are jointly normally distributed.
        """
        return NormalDist(x1._mu - x2._mu, _hypot(x1._sigma, x2._sigma))

    def __sub__(x1: NormalDist, x2: float):
        """
        Subtract a constant: translate mu by the constant, leaving sigma
        unchanged.
        """
        return NormalDist(x1._mu - x2, x1._sigma)

    def __mul__(x1: NormalDist, x2: float):
        """
        Multiply both mu and sigma by a constant.
        Used for rescaling, perhaps to change measurement units.
        Sigma is scaled with the absolute value of the constant.
        """
        return NormalDist(x1._mu * x2, x1._sigma * _fabs(x2))

    def __truediv__(x1: NormalDist, x2: float):
        """
        Divide both mu and sigma by a constant.
        Used for rescaling, perhaps to change measurement units.
        Sigma is scaled with the absolute value of the constant.
        """
        return NormalDist(x1._mu / x2, x1._sigma / _fabs(x2))

    def __pos__(x1: NormalDist):
        """
        Return a copy of the distribution.
        """
        return NormalDist(x1._mu, x1._sigma)

    def __neg__(x1: NormalDist):
        """
        Negate mu while leaving sigma unchanged.
        """
        return NormalDist(-x1._mu, x1._sigma)

    def __radd__(x1: NormalDist, x2: float):
        """
        Reflected addition of a constant (constant + distribution).
        """
        return x1 + x2

    def __rsub__(x1: NormalDist, x2: NormalDist):
        """
        Reflected subtraction with another NormalDist: the negation of
        (x1 - x2).
        """
        return -(x1 - x2)

    def __rsub__(x1: NormalDist, x2: float):
        """
        Reflected subtraction of the distribution from a constant
        (constant - distribution): the negation of (x1 - x2).
        """
        return -(x1 - x2)

    def __rmul__(x1: NormalDist, x2: float):
        """
        Reflected multiplication by a constant (constant * distribution).
        """
        return x1 * x2

    def __eq__(x1: NormalDist, x2: NormalDist):
        """
        Two NormalDist objects are equal if their mu and sigma are both equal.
        """
        return x1._mu == x2._mu and x1._sigma == x2._sigma

    def __hash__(self):
        """
        Hash of the (mu, sigma) pair, matching the definition of __eq__.
        """
        return hash((self._mu, self._sigma))

    def __repr__(self):
        return f"NormalDist(mu={self._mu}, sigma={self._sigma})"