# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>

from .ndarray import ndarray
from .routines import asarray, isreal
import ndmath as nx
import util

def _complex_type(dtype: type):
    """Return a default-constructed value of the complex type paired with ``dtype``.

    ``float32``, the 8- and 16-bit integer types and ``complex64`` itself
    yield ``complex64()``; every other ``dtype`` yields ``complex()``.
    Callers take ``type(...)`` of the result to obtain the complex dtype.
    """
    # The checks stay directly in the `if` so the branch is chosen from the
    # dtype alone; the two branches return differently-typed values.
    if (dtype is complex64 or dtype is float32 or
        dtype is i8 or dtype is u8 or
        dtype is i16 or dtype is u16):
        return complex64()
    else:
        return complex()

def _tocomplex(arr, op, C: type):
    """Cast every element of ``arr`` to ``C`` and apply ``op`` to the cast value.

    Returns whatever ``arr.map`` produces for that per-element function.
    """
    cast_then_apply = lambda elem: op(util.cast(elem, C))
    return arr.map(cast_then_apply)

def _isreal(x):
    if isinstance(x, complex) or isinstance(x, complex64):
        return not bool(x.imag)
    return True

def _real_lt_zero(x):
    """Return whether scalar ``x`` is real and strictly less than zero.

    The zero it is compared against is built in ``x``'s own type.
    """
    ElemT = type(x)
    zero = ElemT(0)
    return _isreal(x) and x < zero

def _real_abs_gt_1(x):
    """Return whether scalar ``x`` is real with absolute value strictly above one.

    The one it is compared against is built in ``x``'s own type.
    """
    ElemT = type(x)
    one = ElemT(1)
    return _isreal(x) and abs(x) > one

def _unary_emath_op(x, op, cond):
    """Apply the unary function ``op`` to ``x``, switching to complex when needed.

    ``x`` is first converted with ``asarray``. If its dtype is already
    ``complex`` or ``complex64``, ``op(x)`` is returned directly. Otherwise
    ``cond`` is tested over the array via ``x._any``: if it holds, each
    element is cast to the matching complex type before ``op`` is applied
    (see ``_tocomplex``); if not, ``op`` is applied to the array as is.

    Parameters:
        x: input accepted by ``asarray``.
        op: function applied to the whole array, or to single complex
            elements on the complex path.
        cond: per-element predicate passed to ``x._any``.

    Returns:
        ``op(x)`` for complex input; otherwise a ``Union`` wrapping either
        the float-typed or the complex-typed result array.
    """
    x = asarray(x)

    # Already complex: no domain promotion is needed.
    if x.dtype is complex or x.dtype is complex64:
        return op(x)

    ndim: Static[int] = x.ndim
    # C: complex element type selected for this dtype by _complex_type.
    C = type(_complex_type(x.dtype))
    # F: type of util.to_float applied to this dtype's zero -- presumably the
    # element type op(x) yields for real input; confirm against util/ndmath.
    F = type(util.to_float(util.zero(x.dtype)))
    # Both branches below wrap their result in the same Union type, since
    # which one runs depends on the data rather than on the dtype.
    O = Union[ndarray[F, ndim], ndarray[C, ndim]]

    if x._any(cond):
        return O(_tocomplex(x, op, C))
    else:
        return O(op(x))

def sqrt(x):
    """Apply ``nx.sqrt`` to ``x``, computing in complex if any element is real and negative."""
    return _unary_emath_op(x, op=nx.sqrt, cond=_real_lt_zero)

def log(x):
    """Apply ``nx.log`` to ``x``, computing in complex if any element is real and negative."""
    return _unary_emath_op(x, op=nx.log, cond=_real_lt_zero)

def log10(x):
    """Apply ``nx.log10`` to ``x``, computing in complex if any element is real and negative."""
    return _unary_emath_op(x, op=nx.log10, cond=_real_lt_zero)

def log2(x):
    """Apply ``nx.log2`` to ``x``, computing in complex if any element is real and negative."""
    return _unary_emath_op(x, op=nx.log2, cond=_real_lt_zero)

def arccos(x):
    """Apply ``nx.arccos`` to ``x``, computing in complex if any real element has ``abs > 1``."""
    return _unary_emath_op(x, op=nx.arccos, cond=_real_abs_gt_1)

def arcsin(x):
    """Apply ``nx.arcsin`` to ``x``, computing in complex if any real element has ``abs > 1``."""
    return _unary_emath_op(x, op=nx.arcsin, cond=_real_abs_gt_1)

def arctanh(x):
    """Apply ``nx.arctanh`` to ``x``, computing in complex if any real element has ``abs > 1``."""
    return _unary_emath_op(x, op=nx.arctanh, cond=_real_abs_gt_1)

def logn(n, x):
    """Return the logarithm of ``x`` in base ``n``, computed as ``log(x) / log(n)``.

    Both logarithms use this module's ``log``.
    """
    log_of_x = log(x)
    log_of_base = log(n)
    return log_of_x / log_of_base


# power