# Copyright (C) 2022-2025 Exaloop Inc. from .ndarray import ndarray from .routines import asarray, isreal import ndmath as nx import util def _complex_type(dtype: type): if (dtype is float32 or dtype is u8 or dtype is i8 or dtype is u16 or dtype is i16 or dtype is complex64): return complex64() else: return complex() def _tocomplex(arr, op, C: type): return arr.map(lambda x: op(util.cast(x, C))) def _isreal(x): if isinstance(x, complex) or isinstance(x, complex64): return not bool(x.imag) return True def _real_lt_zero(x): T = type(x) return _isreal(x) and x < T(0) def _real_abs_gt_1(x): T = type(x) return _isreal(x) and abs(x) > T(1) def _unary_emath_op(x, op, cond): x = asarray(x) if x.dtype is complex or x.dtype is complex64: return op(x) ndim: Static[int] = x.ndim C = type(_complex_type(x.dtype)) F = type(util.to_float(util.zero(x.dtype))) O = Union[ndarray[F, ndim], ndarray[C, ndim]] if x._any(cond): return O(_tocomplex(x, op, C)) else: return O(op(x)) def sqrt(x): return _unary_emath_op(x, nx.sqrt, _real_lt_zero) def log(x): return _unary_emath_op(x, nx.log, _real_lt_zero) def log10(x): return _unary_emath_op(x, nx.log10, _real_lt_zero) def log2(x): return _unary_emath_op(x, nx.log2, _real_lt_zero) def arccos(x): return _unary_emath_op(x, nx.arccos, _real_abs_gt_1) def arcsin(x): return _unary_emath_op(x, nx.arcsin, _real_abs_gt_1) def arctanh(x): return _unary_emath_op(x, nx.arctanh, _real_abs_gt_1) def logn(n, x): return log(x) / log(n) # power