1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
codon/stdlib/numpy/random/philox.codon
A. R. Shajii b8c1eeed36
2025 updates (#619)
* 2025 updates

* Update ci.yml
2025-01-29 15:41:43 -05:00

199 lines
5.1 KiB
Python

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
from .seed import SeedSequence
from ..util import zext, itrunc
# 128-bit unsigned integer type; mulhilo64 uses it to hold a full
# 64-bit x 64-bit product.
u128 = UInt[128]
# Number of u64 outputs held in Philox.buffer. Philox.buffer_pos equal to
# this value means the buffer is exhausted and a new block must be generated.
PHILOX_BUFFER_SIZE: Static[int] = 4
# Round count that Philox.next64 passes to philox4x64_R.
philox4x64_rounds: Static[int] = 10
def mulhilo64(a: u64, b: u64):
    """Multiply two u64 values and return the product as ``(high, low)``.

    Both operands are widened to u128 with ``zext`` before multiplying.
    ``high`` is the product shifted right by 64 bits and ``low`` is the
    unshifted product, each narrowed back to u64 with ``itrunc``.
    """
    wide = zext(a, u128) * zext(b, u128)
    low = itrunc(wide, u64)
    high = itrunc(wide >> u128(64), u64)
    return high, low
def _philox4x64bumpkey(key: Tuple[u64, u64]):
    """Return the key for the next round.

    Each key word has its own fixed 64-bit increment added to it. The
    increments are assembled from two 32-bit halves.
    NOTE(review): the split is presumably there because the full 64-bit
    literals do not fit a signed 64-bit int literal -- confirm.
    """
    bump0 = (u64(0x9E3779B9) << u64(32)) | u64(0x7F4A7C15)
    bump1 = (u64(0xBB67AE85) << u64(32)) | u64(0x84CAA73B)
    k0, k1 = key
    return (k0 + bump0, k1 + bump1)
def _philox4x64round(ctr: Tuple[u64, u64, u64, u64],
                     key: Tuple[u64, u64]):
    """Apply a single Philox4x64 round to ``ctr`` under ``key``.

    Words 0 and 2 of the counter are each multiplied by a fixed 64-bit
    constant via ``mulhilo64``. The result interleaves the low halves with
    the high halves XOR-ed against the odd counter words and the key words.
    """
    mult0 = (u64(0xD2E7470E) << u64(32)) | u64(0xE14C6C93)
    mult1 = (u64(0xCA5A8263) << u64(32)) | u64(0x95121157)
    x0, x1, x2, x3 = ctr
    k0, k1 = key
    hi0, lo0 = mulhilo64(mult0, x0)
    hi1, lo1 = mulhilo64(mult1, x2)
    return (hi1 ^ x1 ^ k0, lo1, hi0 ^ x3 ^ k1, lo0)
def philox4x64_R(R: int,
                 ctr: Tuple[u64, u64, u64, u64],
                 key: Tuple[u64, u64]):
    """Run ``R`` Philox4x64 rounds over ``ctr`` and return the result.

    The first round uses ``key`` as given; every later round first advances
    the key with ``_philox4x64bumpkey``. ``R <= 0`` returns ``ctr``
    unchanged, and at most 16 rounds are applied however large ``R`` is.
    """
    if R <= 0:
        return ctr
    ctr = _philox4x64round(ctr, key)
    # One round has been applied already; cap the total at 16.
    for _ in range(min(R, 16) - 1):
        key = _philox4x64bumpkey(key)
        ctr = _philox4x64round(ctr, key)
    return ctr
class Philox:
    # Philox4x64 counter-based bit generator.
    #
    # ctr        -- 256-bit counter as four u64 words, least significant first.
    # key        -- 128-bit key as two u64 words.
    # buffer     -- the four u64 outputs of the most recently generated block.
    # buffer_pos -- index of the next unread word in `buffer`; a value of
    #               PHILOX_BUFFER_SIZE means the buffer is exhausted.
    # seed       -- SeedSequence the key was derived from; only assigned by
    #               the seed-based constructor.
    ctr: Tuple[u64, u64, u64, u64]
    key: Tuple[u64, u64]
    buffer_pos: int
    buffer: Tuple[u64, u64, u64, u64]
    seed: SeedSequence

    def __init__(self,
                 ctr: Tuple[u64, u64, u64, u64],
                 key: Tuple[u64, u64]):
        """Initialise from an explicit counter and key.

        The buffer is zero-filled and marked exhausted, so the first
        ``next64`` call generates a fresh block. ``seed`` is not assigned.
        """
        self.ctr = ctr
        self.key = key
        self.buffer_pos = PHILOX_BUFFER_SIZE
        self.buffer = (u64(0), u64(0), u64(0), u64(0))

    def __init__(self, seed):
        """Initialise from a seed.

        A ``SeedSequence`` supplies the two key words via
        ``generate_state(2, u64)``; the counter starts at zero. Any other
        value is first wrapped in ``SeedSequence`` and handled the same way.
        """
        if not isinstance(seed, SeedSequence):
            self.__init__(SeedSequence(seed))
        else:
            words = seed.generate_state(2, u64)
            self.__init__((u64(0), u64(0), u64(0), u64(0)),
                          (words[0], words[1]))
            self.seed = seed

    def __get_state__(self):
        """Return ``(ctr, key, buffer_pos, buffer)``."""
        return (self.ctr, self.key, self.buffer_pos, self.buffer)

    def __set_state__(self, state):
        """Restore from a ``(ctr, key, buffer_pos, buffer)`` tuple."""
        ctr, key, buffer_pos, buffer = state
        self.ctr = ctr
        self.key = key
        self.buffer_pos = buffer_pos
        self.buffer = buffer

    def next64(self):
        """Return the next u64 output.

        Unread buffered words are returned first. Once the buffer is
        exhausted the counter is incremented by one, a new block of four
        words is generated with ``philox4x64_rounds`` rounds, and the first
        word of that block is returned.
        """
        if self.buffer_pos < PHILOX_BUFFER_SIZE:
            # Read buffer[buffer_pos] through a raw pointer to a local copy
            # of the tuple.
            # NOTE(review): presumably done to index the tuple with a
            # runtime value -- confirm.
            buf = self.buffer
            out = Ptr[u64](__ptr__(buf).as_byte())[self.buffer_pos]
            self.buffer_pos += 1
            return out

        # 256-bit increment: a word that wraps to zero carries into the next.
        v0, v1, v2, v3 = self.ctr
        v0 += u64(1)
        if not v0:
            v1 += u64(1)
            if not v1:
                v2 += u64(1)
                if not v2:
                    v3 += u64(1)
        self.ctr = (v0, v1, v2, v3)

        self.buffer = philox4x64_R(philox4x64_rounds, self.ctr, self.key)
        self.buffer_pos = 1
        return self.buffer[0]

    def jump(self):
        """Advance the counter by 2**128 (add one to word 2, carry to word 3)."""
        v0, v1, v2, v3 = self.ctr
        v2 += u64(1)
        if not v2:
            v3 += u64(1)
        self.ctr = (v0, v1, v2, v3)

    def advance(self, step: Tuple[u64, u64, u64, u64]):
        """Add the 256-bit ``step`` to the counter, modulo 2**256.

        ``step`` is four u64 words, least significant first. A zero step
        leaves the counter unchanged. ``buffer`` and ``buffer_pos`` are not
        touched, so words already buffered are still returned by ``next64``
        before the new counter value takes effect.
        """
        v0, v1, v2, v3 = self.ctr

        # Word 0: nothing can carry in. Unsigned wraparound shows up as the
        # sum being smaller than the original word.
        v_orig = v0
        v0 += step[0]
        carry = v0 < v_orig

        # Words 1 and 2: apply the incoming carry, then add the step word;
        # either operation may wrap and produce the outgoing carry.
        if carry:
            v1 += u64(1)
            carry = not v1
        v_orig = v1
        v1 += step[1]
        carry = carry or v1 < v_orig

        if carry:
            v2 += u64(1)
            carry = not v2
        v_orig = v2
        v2 += step[2]
        carry = carry or v2 < v_orig

        # Word 3: overflow past the top word is discarded.
        if carry:
            v3 += u64(1)
        v3 += step[3]

        self.ctr = (v0, v1, v2, v3)

    def jump_inplace(self, jumps: int):
        """Apply ``jumps`` jumps at once, i.e. advance by ``jumps * 2**128``.

        ``jumps`` goes into counter word 2, the same word ``jump``
        increments, so ``jump_inplace(1)`` moves the counter exactly as far
        as ``jump()``.
        """
        self.advance((u64(0), u64(0), u64(jumps), u64(0)))