Merge branch 'patch-1' of github.com:ClashLuke/pytorch-image-models into ClashLuke-patch-1

so150m2
Ross Wightman 2025-01-31 12:43:28 -08:00
commit 875c19d0c9
1 changed files with 6 additions and 22 deletions

View File

@ -157,11 +157,7 @@ class Kron(torch.optim.Optimizer):
self._param_exprs = {} # cache for einsum expr
self._tiny = torch.finfo(torch.bfloat16).tiny
self.rng = random.Random(1337)
if deterministic:
# Use a Generator to try to be more deterministic across resume (save/load)
self.torch_rng = torch.Generator().manual_seed(1337)
else:
self.torch_rng = None
self.deterministic = deterministic
# make compile optional (for bwd compat)
if has_dynamo:
@ -178,7 +174,6 @@ class Kron(torch.optim.Optimizer):
def __getstate__(self):
_dict = super().__getstate__()
_dict["rng"] = self.rng
_dict["torch_rng"] = self.torch_rng
return _dict
def state_dict(self) -> Dict[str, Any]:
@ -187,9 +182,6 @@ class Kron(torch.optim.Optimizer):
# Add the generator state
optimizer_state['rng_state'] = self.rng.getstate()
if self.torch_rng is not None:
optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
return optimizer_state
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@ -197,9 +189,7 @@ class Kron(torch.optim.Optimizer):
rng_states = {}
if 'rng_state' in state_dict:
rng_states['rng_state'] = state_dict.pop('rng_state')
if 'torch_rng_state' in state_dict:
rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')
# Load the optimizer state
super().load_state_dict(state_dict)
state_dict.update(rng_states) # add back
@ -207,8 +197,6 @@ class Kron(torch.optim.Optimizer):
# Restore the RNG state if it exists
if 'rng_state' in rng_states:
self.rng.setstate(rng_states['rng_state'])
if 'torch_rng_state' in rng_states:
self.torch_rng.set_state(rng_states['torch_rng_state'])
def __setstate__(self, state):
super().__setstate__(state)
@ -317,15 +305,11 @@ class Kron(torch.optim.Optimizer):
if do_update:
exprA, exprGs, _ = exprs
Q = state["Q"]
if self.torch_rng is None:
V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
if self.deterministic:
torch_rng = torch.Generator(device=V.device).manual_seed(self.rng.randint(0, 2 ** 31))
else:
# Restoring generator state to device is messy. For now,
# we keep RNG on CPU, but this slows the optimizer down quite a bit.
# FIXME Need a better approach
V = torch.randn(
debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
V = V.to(debiased_momentum.device)
torch_rng = None
V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device=debiased_momentum.device)
G = debiased_momentum if momentum_into_precond_update else grad
A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)