# Copyright (C) 2022-2025 Exaloop Inc.

from .seed import SeedSequence
from ..util import zext, itrunc

u128 = UInt[128]

# Number of 64-bit words produced by one Philox4x64 block.
PHILOX_BUFFER_SIZE: Static[int] = 4
# Number of rounds applied by `Philox.next64` when generating a block.
philox4x64_rounds: Static[int] = 10


def mulhilo64(a: u64, b: u64):
    """Multiply two 64-bit values and return ``(high, low)`` 64-bit halves.

    Both operands are zero-extended to 128 bits so the full product is
    available before it is split.
    """
    product = zext(a, u128) * zext(b, u128)
    return itrunc(product >> u128(64), u64), itrunc(product, u64)


def _philox4x64bumpkey(key: Tuple[u64, u64]):
    """Return the key for the next round.

    Each key word is advanced by a fixed 64-bit constant; the additions wrap
    modulo 2**64.
    """
    v0 = key[0]
    v1 = key[1]
    # 0x9E3779B97F4A7C15 and 0xBB67AE8584CAA73B, assembled from 32-bit halves.
    v0 += (u64(0x9E3779B9) << u64(32)) | u64(0x7F4A7C15)
    v1 += (u64(0xBB67AE85) << u64(32)) | u64(0x84CAA73B)
    return (v0, v1)


def _philox4x64round(ctr: Tuple[u64, u64, u64, u64], key: Tuple[u64, u64]):
    """Apply one Philox4x64 round to ``ctr`` using ``key``.

    Words 0 and 2 of the counter are multiplied by fixed constants; the high
    halves are mixed with the remaining counter words and the key, and the
    low halves are passed through in permuted positions.
    """
    # 0xD2E7470EE14C6C93 and 0xCA5A826395121157, assembled from 32-bit halves.
    c0 = (u64(0xD2E7470E) << u64(32)) | u64(0xE14C6C93)
    c1 = (u64(0xCA5A8263) << u64(32)) | u64(0x95121157)
    hi0, lo0 = mulhilo64(c0, ctr[0])
    hi1, lo1 = mulhilo64(c1, ctr[2])
    return (hi1 ^ ctr[1] ^ key[0], lo1, hi0 ^ ctr[3] ^ key[1], lo0)


def philox4x64_R(R: int, ctr: Tuple[u64, u64, u64, u64], key: Tuple[u64, u64]):
    """Run ``R`` Philox4x64 rounds over ``ctr`` and return the result.

    The first round uses ``key`` as given; every later round bumps the key
    first. At most 16 rounds are applied: values of ``R`` above 16 behave
    like 16, and ``R <= 0`` returns ``ctr`` unchanged.
    """
    # Deliberately unrolled: each guard adds exactly one more round.
    if R > 0:
        ctr = _philox4x64round(ctr, key)
    if R > 1:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 2:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 3:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 4:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 5:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 6:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 7:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 8:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 9:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 10:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 11:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 12:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 13:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 14:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    if R > 15:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    return ctr


class Philox:
    """State of a Philox4x64 generator.

    The counter is stored as four 64-bit limbs with ``ctr[0]`` the least
    significant (carries propagate from limb 0 upward). Each generated block
    yields ``PHILOX_BUFFER_SIZE`` words, which are handed out one at a time
    by `next64`.
    """

    # 256-bit block counter, least-significant limb first.
    ctr: Tuple[u64, u64, u64, u64]
    # 128-bit key.
    key: Tuple[u64, u64]
    # Index of the next unread word in `buffer`; PHILOX_BUFFER_SIZE means
    # the buffer is exhausted.
    buffer_pos: int
    # Most recently generated block.
    buffer: Tuple[u64, u64, u64, u64]
    # Seed sequence the key was derived from; only set by the seed
    # constructor.
    seed: SeedSequence

    def __init__(self, ctr: Tuple[u64, u64, u64, u64], key: Tuple[u64, u64]):
        """Initialize from an explicit counter and key.

        The buffer starts out marked as exhausted, so the first `next64`
        call generates a fresh block. `seed` is not assigned here.
        """
        self.ctr = ctr
        self.key = key
        self.buffer_pos = PHILOX_BUFFER_SIZE
        self.buffer = (u64(0), u64(0), u64(0), u64(0))

    def __init__(self, seed):
        """Initialize from a seed.

        Anything that is not already a `SeedSequence` is wrapped in one. The
        key is taken from two 64-bit words of the sequence's generated state
        and the counter starts at zero.
        """
        if not isinstance(seed, SeedSequence):
            self.__init__(SeedSequence(seed))
        else:
            ctr = (u64(0), u64(0), u64(0), u64(0))
            key = seed.generate_state(2, u64)
            key = (key[0], key[1])
            self.__init__(ctr, key)
            self.seed = seed

    def __get_state__(self):
        """Return ``(ctr, key, buffer_pos, buffer)``."""
        return (self.ctr, self.key, self.buffer_pos, self.buffer)

    def __set_state__(self, state):
        """Restore state from a ``(ctr, key, buffer_pos, buffer)`` tuple."""
        ctr, key, buffer_pos, buffer = state
        self.ctr = ctr
        self.key = key
        self.buffer_pos = buffer_pos
        self.buffer = buffer

    def next64(self):
        """Return the next 64-bit output word.

        Unread words in the buffer are served first. Once the buffer is
        exhausted the counter is incremented by one (with carry across the
        limbs), a new block is generated, and its first word is returned.
        """
        if self.buffer_pos < PHILOX_BUFFER_SIZE:
            # Tuples cannot be indexed with a runtime value, so view the
            # local copy of the buffer through a u64 pointer instead.
            buf = self.buffer
            out = Ptr[u64](__ptr__(buf).as_byte())[self.buffer_pos]
            self.buffer_pos += 1
            return out

        # Increment the 256-bit counter; a limb that wraps to zero carries
        # into the next one.
        v0, v1, v2, v3 = self.ctr
        v0 += u64(1)
        if not v0:
            v1 += u64(1)
            if not v1:
                v2 += u64(1)
                if not v2:
                    v3 += u64(1)
        self.ctr = (v0, v1, v2, v3)

        ct = philox4x64_R(philox4x64_rounds, self.ctr, self.key)
        self.buffer = ct
        # Word 0 is returned now, so the next unread word is index 1.
        self.buffer_pos = 1
        return self.buffer[0]

    def jump(self):
        """Add 2**128 to the counter.

        Limb 2 is incremented and a wrap carries into limb 3. The output
        buffer is left untouched.
        """
        v0, v1, v2, v3 = self.ctr
        v2 += u64(1)
        if not v2:
            # Use an explicit u64 like every other counter update so the
            # arithmetic stays in the counter's type.
            v3 += u64(1)
        self.ctr = (v0, v1, v2, v3)

    def advance(self, step: Tuple[u64, u64, u64, u64]):
        """Advance the counter by ``step`` (least-significant limb first).

        The addition is performed limb by limb with carry propagation and
        wraps modulo 2**256. The output buffer is left untouched.

        NOTE(review): limb 0 is incremented by one before ``step[0]`` is
        added, so the counter actually moves by ``step + 1`` — confirm this
        offset is intended before relying on exact positions.
        """
        v0, v1, v2, v3 = self.ctr

        v0 += u64(1)
        carry = not v0
        v_orig = v0
        v0 += step[0]
        # Unsigned wrap-around of the addition shows up as a smaller result.
        if v0 < v_orig and not carry:
            carry = True

        if carry:
            v1 += u64(1)
            carry = not v1
        v_orig = v1
        v1 += step[1]
        if v1 < v_orig and not carry:
            carry = True

        if carry:
            v2 += u64(1)
            carry = not v2
        v_orig = v2
        v2 += step[2]
        if v2 < v_orig and not carry:
            carry = True

        # A carry out of the top limb is discarded.
        if carry:
            v3 += u64(1)
        v3 += step[3]

        self.ctr = (v0, v1, v2, v3)

    def jump_inplace(self, jumps: int):
        """Advance the counter by ``jumps`` placed in limb 1.

        Equivalent to ``advance((0, jumps, 0, 0))``.

        NOTE(review): `jump` moves limb 2 (2**128) while this moves limb 1
        (2**64 per jump) — confirm the two are meant to differ.
        """
        self.advance((u64(0), u64(jumps), u64(0), u64(0)))