simplify RNG

pull/2433/head
Lucas Nestler 2025-01-31 17:26:14 +01:00 committed by GitHub
parent 6367267298
commit e025328f96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 4 deletions

View File

@ -305,12 +305,11 @@ class Kron(torch.optim.Optimizer):
if do_update:
exprA, exprGs, _ = exprs
Q = state["Q"]
if self.deterministic is None:
if self.deterministic:
torch_rng = torch.Generator(device=V.device).manual_seed(self.rng.randint(0, 2 ** 31))
V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device='cpu')
V = V.to(debiased_momentum.device)
else:
V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
torch_rng = None
V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device=debiased_momentum.device)
G = debiased_momentum if momentum_into_precond_update else grad
A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)