From e025328f96c078b5f6e78a73f7a66a73e4ee584a Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 31 Jan 2025 17:26:14 +0100 Subject: [PATCH] simplify RNG --- timm/optim/kron.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/timm/optim/kron.py b/timm/optim/kron.py index ad09ed9c..322bc391 100644 --- a/timm/optim/kron.py +++ b/timm/optim/kron.py @@ -305,12 +305,11 @@ class Kron(torch.optim.Optimizer): if do_update: exprA, exprGs, _ = exprs Q = state["Q"] - if self.deterministic is None: + if self.deterministic: torch_rng = torch.Generator(device=V.device).manual_seed(self.rng.randint(0, 2 ** 31)) - V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device='cpu') - V = V.to(debiased_momentum.device) else: - V = torch.randn_like(debiased_momentum, dtype=precond_dtype) + torch_rng = None + V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device=debiased_momentum.device) G = debiased_momentum if momentum_into_precond_update else grad A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)