simplify RNG
parent
6367267298
commit
e025328f96
|
@ -305,12 +305,11 @@ class Kron(torch.optim.Optimizer):
|
|||
if do_update:
|
||||
exprA, exprGs, _ = exprs
|
||||
Q = state["Q"]
|
||||
if self.deterministic is None:
|
||||
if self.deterministic:
|
||||
torch_rng = torch.Generator(device=V.device).manual_seed(self.rng.randint(0, 2 ** 31))
|
||||
V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device='cpu')
|
||||
V = V.to(debiased_momentum.device)
|
||||
else:
|
||||
V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
|
||||
torch_rng = None
|
||||
V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device=debiased_momentum.device)
|
||||
G = debiased_momentum if momentum_into_precond_update else grad
|
||||
|
||||
A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
|
||||
|
|
Loading…
Reference in New Issue