1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
codon/stdlib/numpy/random/mt19937.codon

262 lines
17 KiB
Python
Raw Normal View History

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
from .seed import SeedSequence
# Number of 32-bit words in the state table (the backing buffer holds N + 1
# words: one position word followed by the N state words).
N: Static[int] = 624
# Offset of the partner word used when a state word is regenerated.
M: Static[int] = 397
MATRIX_A = u32(0x9908b0df) # constant vector a; XORed in when the low bit of y is set
UPPER_MASK = u32(0x80000000) # most significant w-r bits (top bit of a word)
LOWER_MASK = u32(0x7fffffff) # least significant r bits (low 31 bits of a word)
# parameters for computing jump
W_SIZE: Static[int] = 32 # size of unsigned long; bits per coefficient word assumed by P_SIZE
# Coefficient scanning in horner1 starts at bit MEXP - 1.
MEXP : Static[int] = 19937
# Number of W_SIZE-bit words needed to hold MEXP coefficient bits (624).
P_SIZE: Static[int] = ((MEXP // W_SIZE) + 1)
# Single-bit mask that get_coef shifts into position.
LSB : Static[int] = 0x00000001
# QQ and LL are not referenced in the visible part of this file.
QQ : Static[int] = 7
LL : Static[int] = 128 # LL = 2^(QQ)
def get_coef(pf: Ptr[u64], deg: int):
    """
    Report whether coefficient bit `deg` is set in the packed table `pf`.

    Each `pf` entry contributes 32 bits: the entry index is `deg >> 5`
    and the bit within that entry is `deg & 0x1f`.
    """
    word = pf[deg >> 5]
    mask = u64(LSB << (deg & 0x1f))
    return bool(word & mask)
# Generator built around a single buffer of N + 1 u32 words plus the
# SeedSequence it was constructed from.
@tuple
class MT19937:
# data[0] is the current position (see `pos`); the N state words start at
# data + 1 (see `state`).
data: Ptr[u32]
# Seed sequence passed to (or created by) the constructor.
seed: SeedSequence
# Construct a generator from `seed`.
#   seed:   a SeedSequence, or a value that SeedSequence(...) accepts.
#   legacy: compile-time flag selecting `seed_legacy` instead of state derived
#           from the SeedSequence.
# Raises ValueError when legacy seeding is requested with a seed outside
# [0, 2**32 - 1].
# NOTE(review): indentation was lost in this copy of the file, so the nesting
# of the branches below cannot be read off the text -- confirm against upstream.
def __new__(seed, legacy: Static[int] = False):
# Anything that is not already a SeedSequence is wrapped in one and the
# constructor is re-entered; `legacy` is not forwarded on this call.
if not isinstance(seed, SeedSequence):
return MT19937(SeedSequence(seed))
else:
# Field-wise construction: a fresh N + 1 word buffer plus the seed sequence.
g = MT19937(Ptr[u32](N + 1), seed)
if legacy:
# NOTE(review): this check requires an int seed, yet an int seed appears to
# be diverted by the isinstance test above -- verify the intended nesting.
if not isinstance(seed, int):
compile_error("'seed' must be an int for legacy seeding")
if seed < 0 or seed > 0xffffffff:
raise ValueError("Seed must be between 0 and 2**32 - 1")
g.seed_legacy(seed)
else:
# Position N - 1: one stored word is consumed before genrand_int32
# regenerates the table (it regenerates once pos >= N).
g.data[0] = u32(N - 1)
val = seed.generate_state(N, u32)
key = g.state
# Word 0 is forced to 0x80000000; words 1..N-1 come from the generated
# values (val[0] is not used).
key[0] = u32(0x80000000)
for i in range(1, N):
key[i] = val[i]
return g
def seed_legacy(self, seed: int):
    """
    Fill all N state words from a single integer seed (truncated to 32
    bits) using the 1812433253 linear recurrence, then set the position
    to N so the next draw regenerates the table.
    """
    words = self.state
    cur = u32(seed)
    for idx in range(N):
        words[idx] = cur
        # u32 arithmetic already wraps modulo 2**32.
        cur = u32(1812433253) * (cur ^ (cur >> u32(30))) + u32(idx + 1)
    self.data[0] = u32(N)
@property
def pos(self):
    """Current position as an int, read from the first word of `data`."""
    raw = self.data[0]
    return int(raw)
@property
def state(self):
    """Pointer to the state words, which begin one word past `data`."""
    base = self.data
    return base + 1
def __get_state__(self):
    """Return a 1-tuple holding a fresh copy of the N + 1 word buffer."""
    from internal.gc import sizeof
    nbytes = (N + 1) * sizeof(u32)
    snapshot = Ptr[u32](N + 1)
    str.memcpy(snapshot.as_byte(), self.data.as_byte(), nbytes)
    return (snapshot,)
def __set_state__(self, state):
    """Copy N + 1 words from the buffer in `state[0]` into this generator."""
    from internal.gc import sizeof
    src = state[0]
    nbytes = (N + 1) * sizeof(u32)
    str.memcpy(self.data.as_byte(), src.as_byte(), nbytes)
def copy_state(self, state: MT19937):
    """
    Overwrite this generator's buffer with a copy of `state`'s buffer.

    All N + 1 words are copied, i.e. the position word and the N state
    words. The `seed` field is left untouched.
    """
    # Bring `sizeof` into scope locally, as the other methods that use it do.
    from internal.gc import sizeof
    nbytes = (N + 1) * sizeof(u32)
    str.memcpy(self.data.as_byte(), state.data.as_byte(), nbytes)
def genrand_int32(self):
    """
    Return the next 32-bit output word.

    When the position has reached N, all N state words are regenerated
    in place and the position is reset to 0 before a word is taken. The
    selected word is then passed through the shift/mask tempering steps.
    """
    # Two-entry table indexed by the low bit of `y`: 0 or MATRIX_A.
    twist_tbl = (u32(0), MATRIX_A)
    twist = Ptr[u32](__ptr__(twist_tbl).as_byte())
    words = self.state
    one = u32(1)

    if self.pos >= N:
        # Words whose partner lies M slots ahead.
        for k in range(N - M):
            y = (words[k] & UPPER_MASK) | (words[k + 1] & LOWER_MASK)
            words[k] = words[k + M] ^ (y >> one) ^ twist[int(y & one)]
        # Words whose partner index wraps around to the front of the table.
        for k in range(N - M, N - 1):
            y = (words[k] & UPPER_MASK) | (words[k + 1] & LOWER_MASK)
            words[k] = words[k + (M - N)] ^ (y >> one) ^ twist[int(y & one)]
        # Last word pairs with word 0.
        y = (words[N - 1] & UPPER_MASK) | (words[0] & LOWER_MASK)
        words[N - 1] = words[M - 1] ^ (y >> one) ^ twist[int(y & one)]
        self.data[0] = u32(0)

    idx = self.pos
    out = words[idx]
    self.data[0] = u32(idx + 1)

    # Tempering.
    out ^= (out >> u32(11))
    out ^= (out << u32(7)) & u32(0x9d2c5680)
    out ^= (out << u32(15)) & u32(0xefc60000)
    out ^= (out >> u32(18))
    return out
def genrand_res53(self):
    """
    Return a float built from two consecutive 32-bit draws: the top 27
    bits of the first and the top 26 bits of the second, combined into a
    53-bit integer and scaled by 2**-53.
    """
    hi = int(self.genrand_int32() >> u32(5))
    lo = int(self.genrand_int32() >> u32(6))
    combined = hi * 67108864.0 + lo
    return combined * (1.0 / 9007199254740992.0)
def random(self):
    """Return the next float; delegates to `genrand_res53`."""
    value = self.genrand_res53()
    return value
def init_u32(self, s: u32):
    """
    Fill all N state words from the 32-bit seed `s` using the 1812433253
    linear recurrence, then set the position to N.
    """
    words = self.state
    prev = s
    words[0] = prev
    for idx in range(1, N):
        prev = u32(1812433253) * (prev ^ (prev >> u32(30))) + u32(idx)
        words[idx] = prev
    self.data[0] = u32(N)
def init_array(self, init_key: Ptr[u32], key_length: int):
    """
    Initialise the state from an array of `key_length` 32-bit key words.

    The table is first filled by `init_u32(19650218)`, then mixed with
    the key over max(N, key_length) rounds, followed by N - 1 further
    rounds without the key. Word 0 is finally forced to 0x80000000.
    """
    words = self.state
    self.init_u32(u32(19650218))
    shift = u32(30)
    i = 1
    j = 0

    # First pass: fold the key words in, cycling through the key as needed.
    rounds = key_length if key_length > N else N
    for _ in range(rounds):
        prev = words[i - 1]
        words[i] = (words[i] ^ ((prev ^ (prev >> shift)) * u32(1664525))) + init_key[j] + u32(j)
        i += 1
        j += 1
        if i >= N:
            words[0] = words[N - 1]
            i = 1
        if j >= key_length:
            j = 0

    # Second pass: N - 1 rounds that subtract the word index instead.
    for _ in range(N - 1):
        prev = words[i - 1]
        words[i] = (words[i] ^ ((prev ^ (prev >> shift)) * u32(1566083941))) - u32(i)
        i += 1
        if i >= N:
            words[0] = words[N - 1]
            i = 1

    words[0] = u32(0x80000000)
def seed_cpython(self, s: int):
    """
    Seed via `init_array` using the low and high 32-bit halves of `s` as
    the key; the high half is included only when it is non-zero.
    """
    # NOTE(review): `s` is used as-is, so a negative value contributes its
    # two's-complement halves -- confirm callers pass a non-negative seed.
    low = u32(s & ((1 << 32) - 1))
    high = u32(s >> 32)
    key = (low, high)
    n_words = 2 if high else 1
    self.init_array(Ptr[u32](__ptr__(key).as_byte()), n_words)
# jump
def gen_next(self):
    """
    Regenerate only the state word at the current position and advance
    the position by one, wrapping from N - 1 back to 0. Does nothing
    when the position is already N or beyond.
    """
    twist_tbl = (u32(0), MATRIX_A)
    twist = Ptr[u32](__ptr__(twist_tbl).as_byte())
    words = self.state
    idx = self.pos
    one = u32(1)

    if idx >= N:
        return

    if idx == N - 1:
        # Last slot: pairs with word 0 and wraps the position.
        y = (words[N - 1] & UPPER_MASK) | (words[0] & LOWER_MASK)
        words[N - 1] = words[M - 1] ^ (y >> one) ^ twist[int(y & one)]
        self.data[0] = u32(0)
        return

    # Partner word is M slots ahead, wrapping to the front when needed.
    partner = idx + M if idx < N - M else idx + (M - N)
    y = (words[idx] & UPPER_MASK) | (words[idx + 1] & LOWER_MASK)
    words[idx] = words[partner] ^ (y >> one) ^ twist[int(y & one)]
    self.data[0] += u32(1)
def add_state(self, state2: MT19937):
    """
    XOR the N state words of `state2` into this generator's state words.

    The two tables are aligned by position: the word `i` slots after
    this generator's position is paired with the word `i` slots after
    `state2`'s position, each index wrapping around at N.
    """
    dst = self.state
    src = state2.state
    dst_pos = self.pos
    src_pos = state2.pos
    # Offsets to use once an index has run past the end of its table.
    dst_wrapped = dst_pos - N
    src_wrapped = src_pos - N

    idx = 0
    if src_pos >= dst_pos:
        # The source table wraps first.
        while idx < N - src_pos:
            dst[idx + dst_pos] ^= src[idx + src_pos]
            idx += 1
        while idx < N - dst_pos:
            dst[idx + dst_pos] ^= src[idx + src_wrapped]
            idx += 1
        while idx < N:
            dst[idx + dst_wrapped] ^= src[idx + src_wrapped]
            idx += 1
    else:
        # The destination table wraps first.
        while idx < N - dst_pos:
            dst[idx + dst_pos] ^= src[idx + src_pos]
            idx += 1
        while idx < N - src_pos:
            dst[idx + dst_wrapped] ^= src[idx + src_pos]
            idx += 1
        while idx < N:
            dst[idx + dst_wrapped] ^= src[idx + src_wrapped]
            idx += 1
def horner1(self, pf: Ptr[u64]):
    """
    Replace this generator's state with a combination driven by the
    coefficient bits of `pf`.

    The bits are walked from the highest set bit (searching down from
    MEXP - 1) to bit 0. A scratch generator starts as a copy of the
    current state; for each lower bit it is XORed with the current state
    when the bit is set (`add_state`) and advanced one step
    (`gen_next`). The scratch buffer is then copied back into `self`.

    `pf` must have at least one bit set at or below MEXP - 1; the
    initial search does not stop on its own otherwise.
    """
    from internal.gc import sizeof

    # Zero-filled scratch buffer: one position word plus N state words.
    scratch = __array__[u32](N + 1)
    str.memset(scratch.ptr.as_byte(), byte(0), (N + 1) * sizeof(u32))
    # Use the field-wise (data, seed) form, as `__new__` does. Passing only
    # the pointer would hand it to `__new__` as a seed, not a state buffer.
    temp = MT19937(scratch.ptr, self.seed)

    # Locate the highest set coefficient.
    i = MEXP - 1
    while not get_coef(pf, i):
        i -= 1

    if i > 0:
        temp.copy_state(self)
        temp.gen_next()
        i -= 1
        while i > 0:
            if get_coef(pf, i):
                temp.add_state(self)
            temp.gen_next()
            i -= 1
        # Bit 0 contributes without a trailing step.
        if get_coef(pf, 0):
            temp.add_state(self)
    elif i == 0:
        temp.copy_state(self)

    self.copy_state(temp)
def jump_state(self):
    """
    Apply one jump to the state: reset the position to 0 if it has
    reached N, then run `horner1` with the built-in coefficient table.
    """
    exhausted = self.pos >= N
    if exhausted:
        self.data[0] = u32(0)
    coefs = MT19937.poly_coef()
    self.horner1(coefs)
def jump_inplace(self, jumps: int):
    """Call `jump_state` `jumps` times; a count of zero or less does nothing."""
    remaining = jumps
    while remaining > 0:
        self.jump_state()
        remaining -= 1
def next32(self):
    """Return the next 32-bit word; delegates to `genrand_int32`."""
    word = self.genrand_int32()
    return word
def next_double(self):
    """Return the next float; delegates to `genrand_res53`."""
    value = self.genrand_res53()
    return value
# TODO: find a way to line-wrap the very long constant-array line below.
@pure
@llvm
def poly_coef() -> Ptr[u64]:
@pf = private unnamed_addr constant [624 x i64] [i64 1927166307, i64 3044056772, i64 2284297142, i64 2820929765, i64 651705945, i64 69149273, i64 3892165397, i64 2337412983, i64 1219880790, i64 3207074517, i64 3836784057, i64 189286826, i64 1049791363, i64 3916249550, i64 2942382547, i64 166392552, i64 861176918, i64 3246476411, i64 2302311555, i64 4273801148, i64 29196903, i64 1363664063, i64 3802562022, i64 2600400244, i64 3090369801, i64 4040416970, i64 1432485208, i64 3632558139, i64 4015816763, i64 3013316418, i64 551532385, i64 3592224467, i64 3479125595, i64 1195467127, i64 2391032553, i64 2393493419, i64 1482493632, i64 1625159565, i64 748389672, i64 4042774030, i64 2998615036, i64 3393119101, i64 2177492569, i64 2265897321, i64 2507383006, i64 3461498961, i64 2003319700, i64 1942857197, i64 1455226044, i64 4097545580, i64 529653268, i64 3204756480, i64 2486748289, i64 495294513, i64 3396001954, i64 2643963605, i64 2655404568, i64 3881604377, i64 624710790, i64 3443737948, i64 1941294296, i64 2139259604, i64 3368734020, i64 422436761, i64 3602810182, i64 1384691081, i64 3035786407, i64 2551797119, i64 537227499, i64 65486120, i64 642436100, i64 2023822537, i64 2515598203, i64 1122953367, i64 2882306242, i64 1743213032, i64 321965189, i64 336496623, i64 2436602518, i64 3556266590, i64 1055117829, i64 463541647, i64 743234441, i64 527083645, i64 2606668346, i64 2274046499, i64 2761475053, i64 2760669048, i64 2538258534, i64 487125077, i64 3365962306, i64 3604906217, i64 2714700608, i64 680709708, i64 2217161159, i64 1614899374, i64 3710119533, i64 3201300658, i64 3752620679, i64 2755041105, i64 3129723037, i64 1247297753, i64 2812642690, i64 4114340845, i64 3485092247, i64 2752814364, i64 3586551747, i64 4073138437, i64 3462966585, i64 2924318358, i64 4061374901, i64 3314086806, i64 2640385723, i64 744590670, i64 3007586513, i64 3959120371, i64 997207767, i64 3420235506, i64 2092400998, i64 3190305685, i64 60965738, i64 549507222, i64 3784354415, i64 
3209279509, i64 1238863299, i64 2605037827, i64 178570440, i64 1743491299, i64 4079686640, i64 2136795825, i64 3435430548, i64 1679732443, i64 1835708342, i64 2159367000, i64 1924487218, i64 4059723674, i64 996192116, i64 2308091645, i64 1336281586, i64 674600050, i64 1642572529, i64 1383973289, i64 2202960007, i64 3165481279, i64 3385474038, i64 2501318550, i64 2671842890, i64 3084085109, i64 3475033915, i64 1551329147, i64 4101397249, i64 1205851807, i64 3641536021, i64 3607635071, i64 1609126163, i64 2910426664, i64 3324508658, i64 4244311266, i64 254034382, i64 1258304384, i64 1914048768, i64 1358592011, i64 527610138, i64 3072108727, i64 4289413885, i64 1417001678, i64 2445445945, i64 896462712, i64 339855811, i64 3699378285, i64 2529457297, i64 3049459401, i64 2723472429, i64 2838633181, i64 2520397330, i64 3272339035, i64 1667003847, i64 3742634787, i64 942706520, i64 2301027215, i64 1907791250, i64 2306299096, i64 1021173342, i64 1539334516, i64 2907834628, i64 3199959207, i64 1556251860, i64 3642580275, i64 2355865416, i64 285806145, i64 867932457, i64 1177354172, i64 3291107470, i64 4022765061, i64 1613380116, i64 588147929, i64 650574324, i64 1236855601, i64 1371354511, i64 2085218212, i64 1203081931, i64 420526905, i64 1022192219, i64 2903287064, i64 2470845899, i64 3649873273, i64 2502333582, i64 3972385637, i64 4246356763, i64 199084157, i64 1567178788, i64 2107121836, i64 4293612856, i64 1902910177, i64 332397359, i64 83422598, i64 3614961721, i64 456321943, i64 2277615967, i64 2302518510, i64 3258315116, i64 2521897172, i64 3900282042, i64 4186973154, i64 3146532165, i64 2299685029, i64 3889120948, i64 1293301857, i64 187455105, i64 3395849230, i64 913321567, i64 3093513909, i64 1440944571, i64 1923481911, i64 338680924, i64 1204882963, i64 2739724491, i64 2886241328, i64 2408907774, i64 1299817192, i64 2474012871, i64 45400213, i64 553186784, i64 134558656, i64 2180943666, i64 2870807589, i64 76511085, i64 3053566760, i64 2516601415, i64 
4172865902, i64 1751297915, i64 1251975234, i64 2964780642, i64 1412975316, i64 2739978478, i64 2171013719,
ret ptr @pf