From 6367267298142f372552ac8dcffb3795727c8054 Mon Sep 17 00:00:00 2001
From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com>
Date: Fri, 31 Jan 2025 17:23:53 +0100
Subject: [PATCH 1/2] unify RNG

---
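Note: this unifies the optimizer's randomness behind the existing Python RNG
(`self.rng`, a `random.Random` seeded with 1337). When `deterministic` is set,
a short-lived torch Generator is seeded from `self.rng` at each preconditioner
update, so `state_dict()` only needs to carry `rng_state` and save/load
round-trips stay reproducible without serializing torch generator state. A
rough standalone sketch of the idea (`sample_noise` and `ref` are illustrative
names, not part of the patch):

    import random
    import torch

    # Single checkpointed RNG; persisted via rng.getstate()/rng.setstate().
    rng = random.Random(1337)

    def sample_noise(ref: torch.Tensor, deterministic: bool) -> torch.Tensor:
        # `ref` plays the role of debiased_momentum in the optimizer.
        if deterministic:
            # Seed a throwaway Generator on the target device from the Python
            # RNG, so no torch RNG state needs to live in checkpoints.
            g = torch.Generator(device=ref.device).manual_seed(rng.randint(0, 2 ** 31))
        else:
            g = None  # fall back to torch's global RNG
        return torch.randn(ref.shape, generator=g, dtype=ref.dtype, device=ref.device)
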
 timm/optim/kron.py | 27 ++++++---------------------
 1 file changed, 6 insertions(+), 21 deletions(-)

diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index 533354ec..ad09ed9c 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -157,11 +157,7 @@ class Kron(torch.optim.Optimizer):
         self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
-        if deterministic:
-            # Use a Generator to try to be more deterministic across resume (save/load)
-            self.torch_rng = torch.Generator().manual_seed(1337)
-        else:
-            self.torch_rng = None
+        self.deterministic = deterministic
 
         # make compile optional (for bwd compat)
         if has_dynamo:
@@ -178,7 +174,6 @@ class Kron(torch.optim.Optimizer):
     def __getstate__(self):
         _dict = super().__getstate__()
         _dict["rng"] = self.rng
-        _dict["torch_rng"] = self.torch_rng
         return _dict
 
     def state_dict(self) -> Dict[str, Any]:
@@ -187,9 +182,6 @@ class Kron(torch.optim.Optimizer):
 
         # Add the generator state
         optimizer_state['rng_state'] = self.rng.getstate()
-        if self.torch_rng is not None:
-            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
-
         return optimizer_state
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -197,9 +189,7 @@ class Kron(torch.optim.Optimizer):
         rng_states = {}
         if 'rng_state' in state_dict:
             rng_states['rng_state'] = state_dict.pop('rng_state')
-        if 'torch_rng_state' in state_dict:
-            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')
 
         # Load the optimizer state
         super().load_state_dict(state_dict)
         state_dict.update(rng_states)  # add back
@@ -207,8 +197,6 @@ class Kron(torch.optim.Optimizer):
         # Restore the RNG state if it exists
         if 'rng_state' in rng_states:
             self.rng.setstate(rng_states['rng_state'])
-        if 'torch_rng_state' in rng_states:
-            self.torch_rng.set_state(rng_states['torch_rng_state'])
 
     def __setstate__(self, state):
         super().__setstate__(state)
@@ -317,15 +305,12 @@ class Kron(torch.optim.Optimizer):
                 if do_update:
                     exprA, exprGs, _ = exprs
                     Q = state["Q"]
-                    if self.torch_rng is None:
-                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
-                    else:
-                        # Restoring generator state to device is messy. For now,
-                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
-                        # FIXME Need a better approach
-                        V = torch.randn(
-                            debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
+                    if self.deterministic is None:
+                        torch_rng = torch.Generator(device=V.device).manual_seed(self.rng.randint(0, 2 ** 31))
+                        V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device='cpu')
                         V = V.to(debiased_momentum.device)
+                    else:
+                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
                     G = debiased_momentum if momentum_into_precond_update else grad
 
                     A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)

From e025328f96c078b5f6e78a73f7a66a73e4ee584a Mon Sep 17 00:00:00 2001
From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com>
Date: Fri, 31 Jan 2025 17:26:14 +0100
Subject: [PATCH 2/2] simplify RNG

---
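Note: torch.randn treats generator=None as "use the global RNG", so the
deterministic and non-deterministic branches collapse into a single call, and
the noise is sampled directly on the momentum's device instead of going
through a CPU tensor first. Roughly, with illustrative names (`t`, `seed`):

    g = torch.Generator(device=t.device).manual_seed(seed) if deterministic else None
    v = torch.randn(t.shape, generator=g, dtype=torch.float32, device=t.device)
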
 timm/optim/kron.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index ad09ed9c..322bc391 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -305,12 +305,11 @@ class Kron(torch.optim.Optimizer):
                 if do_update:
                     exprA, exprGs, _ = exprs
                     Q = state["Q"]
-                    if self.deterministic is None:
-                        torch_rng = torch.Generator(device=V.device).manual_seed(self.rng.randint(0, 2 ** 31))
-                        V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device='cpu')
-                        V = V.to(debiased_momentum.device)
+                    if self.deterministic:
+                        torch_rng = torch.Generator(device=debiased_momentum.device).manual_seed(self.rng.randint(0, 2 ** 31))
                     else:
-                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
+                        torch_rng = None
+                    V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device=debiased_momentum.device)
                     G = debiased_momentum if momentum_into_precond_update else grad
 
                     A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)