# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>

from ..routines import *

# Default number of u32 words in a SeedSequence pool; SeedSequence.__init__
# also rejects any pool_size smaller than this.
DEFAULT_POOL_SIZE = 4
# Starting hash constant for the pool-mixing pass in SeedSequence.__init__,
# and the multiplier hashmix() applies to the hash constant on every call.
INIT_A            = u32(0x43b0d7e5)
MULT_A            = u32(0x931e8875)
# Starting hash constant and per-word multiplier used by
# SeedSequence.generate_state() when producing output words.
INIT_B            = u32(0x8b51f9dd)
MULT_B            = u32(0x58f38ded)
# Multipliers applied by mix() to its first (x) and second (y) operand.
MIX_MULT_L        = u32(0xca01f9dd)
MIX_MULT_R        = u32(0x4973f715)
# Right-shift distance for the xor-shift steps; 4 * 8 // 2 evaluates to 16,
# i.e. half the bit width of a 4-byte word.
XSHIFT            = u32(4 * 8 // 2)
# All-ones 32-bit value (the largest value a u32 can hold).
MASK32            = u32(0xFFFFFFFF)

def _int_to_uint32_array(n: int):
    if n < 0:
        raise ValueError("expected non-negative integer")
    elif n < int(MASK32):
        return array((u32(n),))
    else:
        return array((u32(n & 0xFFFFFFFF), u32(n >> 32)))

def _coerce_to_uint32_array(x):
    """
    Convert a seed-like value into a one-dimensional array of u32 words.

    Handled inputs, in the order they are checked:
      - an ndarray whose elements are u32: a copy is returned;
      - a string: parsed as hexadecimal if it starts with '0x', as decimal
        if it consists only of digits, otherwise ValueError is raised;
      - an int: split into 32-bit words by _int_to_uint32_array();
      - a float: rejected through compile_error();
      - anything else: treated as a sequence whose items are coerced
        recursively and concatenated (an empty sequence yields an empty
        u32 array).

    Raises ValueError for an unrecognized seed string, and whatever
    _int_to_uint32_array() raises for a negative integer.
    """
    if isinstance(x, ndarray):
        # Only u32 arrays take the fast path; an ndarray with any other
        # element type falls through to the checks below.
        if type(x.data[0]) is u32:
            return x.copy()

    if isinstance(x, str):
        if x.startswith('0x'):
            return _int_to_uint32_array(int(x, base=16))
        elif x.isdigit():
            return _int_to_uint32_array(int(x))
        else:
            raise ValueError("unrecognized seed string")

    if isinstance(x, int):
        return _int_to_uint32_array(x)

    if isinstance(x, float):
        # NOTE(review): compile_error presumably aborts compilation with
        # this message rather than raising at run time -- confirm.
        compile_error("seed must be integer")

    # Remaining case: x is treated as a sized iterable of seed-like items.
    if len(x) == 0:
        return empty(0, dtype=u32)
    subseqs = [_coerce_to_uint32_array(v) for v in x]
    return concatenate(subseqs)

def hashmix(value: u32, hash_const: u32):
    """
    Hash one 32-bit word against a running hash constant.

    The word is xor-ed with the incoming constant, multiplied by the
    advanced constant (incoming constant times MULT_A), then xor-shifted
    right by XSHIFT. Returns the pair (hashed word, advanced constant);
    the caller is expected to feed the advanced constant into its next
    call.
    """
    next_const = hash_const * MULT_A
    hashed = (value ^ hash_const) * next_const
    return hashed ^ (hashed >> XSHIFT), next_const

def mix(x: u32, y: u32):
    """
    Combine two 32-bit words into one.

    Computes MIX_MULT_L * x - MIX_MULT_R * y in u32 arithmetic and then
    xor-shifts the difference right by XSHIFT.
    """
    combined = MIX_MULT_L * x - MIX_MULT_R * y
    return combined ^ (combined >> XSHIFT)

# TODO: use NumPy's default seeding scheme,
#       which falls back to Python's os.urandom()
def random_seed_time_pid():
    """
    Build a five-word u32 seed key from runtime clock and process values.

    Word layout:
      [0], [1]  low / high 32 bits of _C.seq_time() * 1000
      [2]       _C.seq_pid()
      [3], [4]  low / high 32 bits of _C.seq_time_monotonic()
    """
    # The three runtime queries are issued in this fixed order.
    wall = _C.seq_time() * 1000
    pid = _C.seq_pid()
    mono = _C.seq_time_monotonic()

    key = empty((5,), dtype=u32)
    key[0] = u32(wall & 0xFFFFFFFF)
    key[1] = u32(wall >> 32)
    key[2] = u32(pid)
    key[3] = u32(mono & 0xFFFFFFFF)
    key[4] = u32(mono >> 32)
    return key

class SeedSequence:
    # Mixes seed entropy (plus an optional spawn key) into a fixed-size pool
    # of u32 words, from which generate_state() derives seed words and
    # spawn() derives independent child sequences.

    # Coerced run entropy as u32 words, WITHOUT zero padding or spawn-key
    # words; this is what spawn() hands to children.
    entropy: ndarray[u32, 1]
    # Spawn key as given to __init__ (None when no key was supplied).
    _spawn_key: Optional[List[int]]
    # Number of u32 words in `pool`.
    pool_size: int
    # How many children spawn() has handed out so far; the next child's
    # spawn key ends with this index.
    n_children_spawned: int
    # The mixed entropy pool, filled in by __init__.
    pool: ndarray[u32, 1]

    @property
    def spawn_key(self) -> List[int]:
        """Return a copy of the spawn key, or an empty list if there is none."""
        if self._spawn_key is not None:
            return self._spawn_key.copy()
        else:
            return []

    def __init__(self,
                 entropy = None,
                 spawn_key: Optional[List[int]] = None,
                 pool_size: int = DEFAULT_POOL_SIZE,
                 n_children_spawned: int = 0):
        """
        Build the entropy pool.

        entropy: seed-like value accepted by _coerce_to_uint32_array(); when
            None, random_seed_time_pid() supplies the entropy instead.
        spawn_key: optional list of ints appended after the run entropy.
        pool_size: number of u32 words in the pool.
        n_children_spawned: index that spawn() starts numbering children at.

        Raises ValueError if pool_size is smaller than DEFAULT_POOL_SIZE.
        """
        if pool_size < DEFAULT_POOL_SIZE:
            raise ValueError(f"The size of the entropy pool should be at least {DEFAULT_POOL_SIZE}")

        if entropy is None:
            run_entropy = random_seed_time_pid()
        else:
            run_entropy = _coerce_to_uint32_array(entropy)

        spawn_entropy = _coerce_to_uint32_array(spawn_key) if spawn_key is not None else empty(0, dtype=u32)

        # assemble entropy: when a spawn key is present the run entropy is
        # zero-padded to the pool size first, so spawn-key words always land
        # after the first `pool_size` words.
        padded_entropy = run_entropy
        if len(spawn_entropy) > 0 and len(run_entropy) < pool_size:
            diff = pool_size - len(run_entropy)
            padded_entropy = concatenate((run_entropy, zeros(diff, dtype=u32)))
        entropy_array = concatenate((padded_entropy, spawn_entropy))

        # Keep only the run entropy (no padding, no spawn-key words). spawn()
        # passes this to children together with the extended spawn key; if
        # the assembled array were stored instead, a grandchild would mix in
        # its parent's spawn key twice.
        self.entropy = run_entropy
        self._spawn_key = spawn_key
        self.pool_size = pool_size
        self.n_children_spawned = n_children_spawned
        self.pool = zeros(pool_size, dtype=u32)

        # mix entropy
        mixer = self.pool
        hash_const = INIT_A

        # 1) Fill every pool slot from the leading entropy words, using a
        #    zero word once the entropy runs out.
        for i in range(len(mixer)):
            if i < len(entropy_array):
                mixer[i], hash_const = hashmix(entropy_array[i], hash_const)
            else:
                mixer[i], hash_const = hashmix(u32(0), hash_const)

        # 2) Mix every pool slot into every other pool slot.
        for i_src in range(len(mixer)):
            for i_dst in range(len(mixer)):
                if i_src != i_dst:
                    value, hash_const = hashmix(mixer[i_src], hash_const)
                    mixer[i_dst] = mix(mixer[i_dst], value)

        # 3) Fold any entropy words beyond the pool size into every slot.
        for i_src in range(len(mixer), len(entropy_array)):
            for i_dst in range(len(mixer)):
                value, hash_const = hashmix(entropy_array[i_src], hash_const)
                mixer[i_dst] = mix(mixer[i_dst], value)

    def generate_state(self, n_words: int, dtype: type = u32):
        """
        Derive `n_words` seed words of type `dtype` from the pool.

        dtype must be u32 or u64; any other type is rejected through
        compile_error(). For u64, twice as many u32 words are generated and
        the result is an n_words-long u64 array built over that buffer.
        The pool itself is not modified.
        """
        hash_const = INIT_B
        if dtype is u64:
            n_words *= 2
        elif dtype is not u32:
            compile_error("only support uint32 or uint64")

        state = zeros(n_words, dtype=u32)
        pool = self.pool
        npool = len(self.pool)

        # Walk the pool cyclically, hashing each word with a constant that
        # advances by MULT_B on every step.
        for i_dst in range(n_words):
            data_val = pool[i_dst % npool]
            data_val ^= hash_const
            hash_const *= MULT_B
            data_val *= hash_const
            data_val ^= data_val >> XSHIFT
            state[i_dst] = data_val

        if dtype is u64:
            # Reinterpret pairs of u32 words as u64 through a pointer cast.
            return ndarray((n_words // 2,), Ptr[u64](state.data.as_byte()))
        else:
            return state

    def spawn(self, n_children: int):
        """
        Create `n_children` child SeedSequence objects.

        Each child shares this sequence's entropy and pool size and gets this
        sequence's spawn key extended by a running child index. The index
        counter is advanced afterwards, so successive calls never hand out
        the same spawn key twice.
        """
        seqs = []
        sk: List[int] = self._spawn_key if self._spawn_key is not None else []
        for i in range(self.n_children_spawned,
                       self.n_children_spawned + n_children):
            seqs.append(SeedSequence(
                self.entropy,
                spawn_key=(sk + [i]),
                pool_size=self.pool_size
            ))
        # Record the indices just used; without this a second spawn() call
        # would reproduce the children of the first one.
        self.n_children_spawned += n_children
        return seqs