diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index 533354ec..ad09ed9c 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -157,11 +157,7 @@ class Kron(torch.optim.Optimizer):
         self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
-        if deterministic:
-            # Use a Generator to try to be more deterministic across resume (save/load)
-            self.torch_rng = torch.Generator().manual_seed(1337)
-        else:
-            self.torch_rng = None
+        self.deterministic = deterministic
 
         # make compile optional (for bwd compat)
         if has_dynamo:
@@ -178,7 +174,6 @@ class Kron(torch.optim.Optimizer):
     def __getstate__(self):
         _dict = super().__getstate__()
         _dict["rng"] = self.rng
-        _dict["torch_rng"] = self.torch_rng
         return _dict
 
     def state_dict(self) -> Dict[str, Any]:
@@ -187,9 +182,6 @@ class Kron(torch.optim.Optimizer):
 
         # Add the generator state
         optimizer_state['rng_state'] = self.rng.getstate()
-        if self.torch_rng is not None:
-            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
-
         return optimizer_state
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -197,9 +189,7 @@ class Kron(torch.optim.Optimizer):
         rng_states = {}
         if 'rng_state' in state_dict:
            rng_states['rng_state'] = state_dict.pop('rng_state')
-        if 'torch_rng_state' in state_dict:
-            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')
-
+
         # Load the optimizer state
         super().load_state_dict(state_dict)
         state_dict.update(rng_states)  # add back
@@ -207,8 +197,6 @@ class Kron(torch.optim.Optimizer):
         # Restore the RNG state if it exists
         if 'rng_state' in rng_states:
             self.rng.setstate(rng_states['rng_state'])
-        if 'torch_rng_state' in rng_states:
-            self.torch_rng.set_state(rng_states['torch_rng_state'])
 
     def __setstate__(self, state):
         super().__setstate__(state)
@@ -317,15 +305,12 @@ class Kron(torch.optim.Optimizer):
                 if do_update:
                     exprA, exprGs, _ = exprs
                     Q = state["Q"]
-                    if self.torch_rng is None:
-                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
-                    else:
-                        # Restoring generator state to device is messy. For now,
-                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
-                        # FIXME Need a better approach
-                        V = torch.randn(
-                            debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
-                    V = V.to(debiased_momentum.device)
+                    if self.deterministic:
+                        # Seed a fresh on-device generator from the Python RNG; no CPU round-trip needed
+                        torch_rng = torch.Generator(device=debiased_momentum.device).manual_seed(self.rng.randint(0, 2 ** 31))
+                        V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device=debiased_momentum.device)
+                    else:
+                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
                     G = debiased_momentum if momentum_into_precond_update else grad
 
                     A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
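
Below the patch, a quick illustration of the RNG scheme it adopts. This is a standalone sketch, not timm code: sample_v is a hypothetical stand-in for the preconditioner-update sampling above, using only stock PyTorch APIs (torch.Generator, torch.randn) plus Python's random.Random. Reproducibility across save/load then reduces to checkpointing the Python RNG state, which is what the patched state_dict()/load_state_dict() keep doing.

import random

import torch

rng = random.Random(1337)  # the Python RNG the optimizer checkpoints
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def sample_v(shape, dtype=torch.float32):
    # Hypothetical helper: seed a fresh on-device Generator from the Python RNG,
    # so no tensors cross devices and no torch.Generator state is serialized.
    torch_rng = torch.Generator(device=device).manual_seed(rng.randint(0, 2 ** 31))
    return torch.randn(shape, generator=torch_rng, dtype=dtype, device=device)


# Round-trip check: restoring the Python RNG state reproduces the same V.
saved = rng.getstate()
v1 = sample_v((4, 4))
rng.setstate(saved)
v2 = sample_v((4, 4))
assert torch.equal(v1, v2)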