# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>

from .seed import SeedSequence
from ..util import zext, itrunc

u128 = UInt[128]
PHILOX_BUFFER_SIZE: Static[int] = 4
philox4x64_rounds: Static[int] = 10

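# 64x64 -> 128-bit multiply: returns the (high, low) 64-bit halves of a * b.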
def mulhilo64(a: u64, b: u64):
    product = zext(a, u128) * zext(b, u128)
    return itrunc(product >> u128(64), u64), itrunc(product, u64)

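# Key-schedule "bump": adds the Philox Weyl constants (0x9E3779B97F4A7C15,
# the 64-bit golden ratio, and 0xBB67AE8584CAA73B, sqrt(3) - 1) to the two
# key words between rounds.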
def _philox4x64bumpkey(key: Tuple[u64, u64]):
    v0 = key[0]
    v1 = key[1]
    v0 += (u64(0x9E3779B9) << u64(32)) | u64(0x7F4A7C15)
    v1 += (u64(0xBB67AE85) << u64(32)) | u64(0x84CAA73B)
    return (v0, v1)

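# One Philox-4x64 round: multiply two counter words by the fixed round
# multipliers and mix the high/low product halves with the other two words
# and the key.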
def _philox4x64round(ctr: Tuple[u64, u64, u64, u64],
                     key: Tuple[u64, u64]):
    c0 = (u64(0xD2E7470E) << u64(32)) | u64(0xE14C6C93)
    c1 = (u64(0xCA5A8263) << u64(32)) | u64(0x95121157)
    hi0, lo0 = mulhilo64(c0, ctr[0])
    hi1, lo1 = mulhilo64(c1, ctr[2])
    return (hi1 ^ ctr[1] ^ key[0], lo1, hi0 ^ ctr[3] ^ key[1], lo0)

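# Apply R rounds (up to 16) of Philox-4x64, bumping the key between rounds.
# The loop is fully unrolled; the generator below uses philox4x64_rounds (10).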
def philox4x64_R(R: int,
                 ctr: Tuple[u64, u64, u64, u64],
                 key: Tuple[u64, u64]):
    if R > 0:
        ctr = _philox4x64round(ctr, key)

    if R > 1:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 2:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 3:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 4:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 5:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 6:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 7:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 8:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 9:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 10:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 11:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 12:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 13:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 14:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    if R > 15:
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)

    return ctr

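# Counter-based PRNG: a 256-bit counter (four u64 words) is encrypted with a
# 128-bit key using Philox-4x64-10. Each counter value yields a block of four
# 64-bit outputs, buffered and returned one at a time by next64().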
class Philox:
    ctr: Tuple[u64, u64, u64, u64]
    key: Tuple[u64, u64]
    buffer_pos: int
    buffer: Tuple[u64, u64, u64, u64]
    seed: SeedSequence

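    # Construct directly from a raw counter and key.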
    def __init__(self,
                 ctr: Tuple[u64, u64, u64, u64],
                 key: Tuple[u64, u64]):
        self.ctr = ctr
        self.key = key
        self.buffer_pos = PHILOX_BUFFER_SIZE
        self.buffer = (u64(0), u64(0), u64(0), u64(0))

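    # Construct from a seed: anything that is not already a SeedSequence is
    # wrapped in one, and the 128-bit key is drawn from it.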
    def __init__(self, seed):
        if not isinstance(seed, SeedSequence):
            self.__init__(SeedSequence(seed))
        else:
            ctr = (u64(0), u64(0), u64(0), u64(0))
            key = seed.generate_state(2, u64)
            key = (key[0], key[1])
            self.__init__(ctr, key)
            self.seed = seed

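    # Generator state is the (counter, key, buffer position, buffer) tuple.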
    def __get_state__(self):
        return (self.ctr, self.key, self.buffer_pos, self.buffer)

    def __set_state__(self, state):
        ctr, key, buffer_pos, buffer = state
        self.ctr = ctr
        self.key = key
        self.buffer_pos = buffer_pos
        self.buffer = buffer

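    # Return the next 64-bit output. Buffered words are served first; once the
    # buffer is exhausted, the 256-bit counter is incremented (with carry) and
    # Philox-4x64-10 is run to refill it.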
    def next64(self):
        out = u64(0)

        if self.buffer_pos < PHILOX_BUFFER_SIZE:
            buf = self.buffer
            out = Ptr[u64](__ptr__(buf).as_byte())[self.buffer_pos]
            self.buffer_pos += 1
            return out

        v0, v1, v2, v3 = self.ctr
        v0 += u64(1)
        if not v0:
            v1 += u64(1)
            if not v1:
                v2 += u64(1)
                if not v2:
                    v3 += u64(1)

        self.ctr = (v0, v1, v2, v3)
        ct = philox4x64_R(philox4x64_rounds, self.ctr, self.key)
        self.buffer = ct
        self.buffer_pos = 1
        return self.buffer[0]

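    # Bump the third counter word, i.e. advance the 256-bit counter by 2**128.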
    def jump(self):
        v0, v1, v2, v3 = self.ctr
        v2 += u64(1)
        if not v2:
            v3 += u64(1)
        self.ctr = (v0, v1, v2, v3)

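    # Add a 256-bit step to the counter (after a pre-increment of the low
    # word), propagating carries across the four u64 words.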
    def advance(self, step: Tuple[u64, u64, u64, u64]):
        v0, v1, v2, v3 = self.ctr
        carry = False

        v0 += u64(1)
        carry = not v0
        v_orig = v0
        v0 += step[0]
        if v0 < v_orig and not carry:
            carry = True

        if carry:
            v1 += u64(1)
            carry = not v1
        v_orig = v1
        v1 += step[1]
        if v1 < v_orig and not carry:
            carry = True

        if carry:
            v2 += u64(1)
            carry = not v2
        v_orig = v2
        v2 += step[2]
        if v2 < v_orig and not carry:
            carry = True

        if carry:
            v3 += u64(1)
            carry = not v3
        v3 += step[3]

        self.ctr = (v0, v1, v2, v3)

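    # Express a jump as a counter advance on the second 64-bit word.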
    def jump_inplace(self, jumps: int):
        self.advance((u64(0), u64(jumps), u64(0), u64(0)))
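
# A minimal usage sketch (illustrative only; not executed as part of this
# module):
#
#     rng = Philox(1234)   # any non-SeedSequence seed is wrapped in SeedSequence
#     x = rng.next64()     # draw one 64-bit value
#     rng.jump()           # skip far ahead for an independent stream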