Merge branch 'ClashLuke-patch-1'
commit a49b020eff

@@ -157,11 +157,7 @@ class Kron(torch.optim.Optimizer):
         self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
-        if deterministic:
-            # Use a Generator to try to be more deterministic across resume (save/load)
-            self.torch_rng = torch.Generator().manual_seed(1337)
-        else:
-            self.torch_rng = None
+        self.deterministic = deterministic

         # make compile optional (for bwd compat)
         if has_dynamo:
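
Note: this hunk drops the persistent CPU `torch.Generator` and keeps only the Python `random.Random` plus a `deterministic` flag. A minimal standalone sketch of the two seeding styles (variable names here are illustrative, not the optimizer's):

import random
import torch

# Old approach (removed above): one long-lived CPU generator created in __init__.
persistent_gen = torch.Generator().manual_seed(1337)

# New approach: a Python RNG is the only persistent seed source; device-local
# generators are created on demand and seeded from it (see the last hunk below).
seed_source = random.Random(1337)
on_demand_gen = torch.Generator(device="cpu")
on_demand_gen.manual_seed(seed_source.randint(0, 2 ** 31))
noise = torch.randn(4, generator=on_demand_gen)
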
@@ -178,7 +174,6 @@ class Kron(torch.optim.Optimizer):
     def __getstate__(self):
         _dict = super().__getstate__()
         _dict["rng"] = self.rng
-        _dict["torch_rng"] = self.torch_rng
         return _dict

     def state_dict(self) -> Dict[str, Any]:
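
For context on `__getstate__`: `random.Random` pickles cleanly, so keeping only the Python RNG in the pickled state is enough to preserve determinism across pickle/deepcopy of the optimizer object. A small self-contained sketch (hypothetical class, not Kron):

import pickle
import random

class TinyStateful:
    # Hypothetical stand-in: only the Python RNG rides along in the pickled state.
    def __init__(self, seed: int = 1337):
        self.rng = random.Random(seed)

    def __getstate__(self):
        _dict = self.__dict__.copy()
        _dict["rng"] = self.rng  # random.Random supports pickling
        return _dict

    def __setstate__(self, state):
        self.__dict__.update(state)

obj = TinyStateful()
obj.rng.random()                               # advance the RNG
clone = pickle.loads(pickle.dumps(obj))
assert clone.rng.random() == obj.rng.random()  # identical next draw
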
@@ -187,9 +182,6 @@ class Kron(torch.optim.Optimizer):

         # Add the generator state
         optimizer_state['rng_state'] = self.rng.getstate()
-        if self.torch_rng is not None:
-            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()

         return optimizer_state

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
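
Two snapshot APIs are involved here: `random.Random.getstate()`/`setstate()` for the Python RNG that is kept, and `torch.Generator.get_state()`/`set_state()` for the generator state this change stops persisting. A quick round-trip sketch of both:

import random
import torch

# Python RNG snapshot: a plain, picklable tuple.
py_rng = random.Random(1337)
snap = py_rng.getstate()
first = py_rng.random()
py_rng.setstate(snap)
assert py_rng.random() == first

# torch.Generator snapshot: a uint8 tensor blob.
gen = torch.Generator().manual_seed(1337)
blob = gen.get_state()
x = torch.randn(3, generator=gen)
gen.set_state(blob)
assert torch.equal(torch.randn(3, generator=gen), x)
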
@@ -197,8 +189,6 @@ class Kron(torch.optim.Optimizer):
         rng_states = {}
         if 'rng_state' in state_dict:
             rng_states['rng_state'] = state_dict.pop('rng_state')
-        if 'torch_rng_state' in state_dict:
-            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')

         # Load the optimizer state
         super().load_state_dict(state_dict)
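
The pop-before-delegate pattern above (split off the custom key, hand the rest to the base `Optimizer.load_state_dict`, then apply the custom entry) end to end in a toy subclass. This is an illustration built on `torch.optim.SGD`, not the actual Kron implementation:

import random
from typing import Any, Dict

import torch

class RngSGD(torch.optim.SGD):
    """Toy optimizer that checkpoints a Python RNG alongside its usual state."""

    def __init__(self, params, lr: float = 0.1, seed: int = 1337):
        super().__init__(params, lr=lr)
        self.rng = random.Random(seed)

    def state_dict(self) -> Dict[str, Any]:
        sd = super().state_dict()
        sd["rng_state"] = self.rng.getstate()
        return sd

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        rng_state = state_dict.pop("rng_state", None)  # split off the custom key
        super().load_state_dict(state_dict)            # base class sees only what it expects
        if rng_state is not None:
            self.rng.setstate(rng_state)               # then restore the RNG

param = torch.nn.Parameter(torch.zeros(2))
opt = RngSGD([param])
ckpt = opt.state_dict()
opt.rng.random()           # drift the RNG after checkpointing
opt.load_state_dict(ckpt)  # RNG is rewound to the checkpointed state
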
@@ -207,8 +197,6 @@ class Kron(torch.optim.Optimizer):
         # Restore the RNG state if it exists
         if 'rng_state' in rng_states:
             self.rng.setstate(rng_states['rng_state'])
-        if 'torch_rng_state' in rng_states:
-            self.torch_rng.set_state(rng_states['torch_rng_state'])

     def __setstate__(self, state):
         super().__setstate__(state)

@@ -317,15 +305,17 @@ class Kron(torch.optim.Optimizer):
                 if do_update:
                     exprA, exprGs, _ = exprs
                     Q = state["Q"]
-                    if self.torch_rng is None:
-                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
+                    if self.deterministic:
+                        torch_rng = torch.Generator(device=debiased_momentum.device)
+                        torch_rng.manual_seed(self.rng.randint(0, 2 ** 31))
                     else:
-                        # Restoring generator state to device is messy. For now,
-                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
-                        # FIXME Need a better approach
+                        torch_rng = None
                     V = torch.randn(
-                        debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
-                    V = V.to(debiased_momentum.device)
+                        debiased_momentum.shape,
+                        generator=torch_rng,
+                        dtype=precond_dtype,
+                        device=debiased_momentum.device,
+                    )
                     G = debiased_momentum if momentum_into_precond_update else grad

                     A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
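
Net effect of the per-step seeding in this hunk: each update draws its probe noise directly on the parameter's device from a Generator seeded by the checkpointed Python RNG, so resuming from a saved `rng_state` replays the same noise sequence as an uninterrupted run. A small sketch of that resume property (the helper name is made up):

import random
import torch

def step_noise(rng: random.Random, shape, device: str = "cpu"):
    # Fresh device-local generator each step, seeded from the Python RNG.
    gen = torch.Generator(device=device)
    gen.manual_seed(rng.randint(0, 2 ** 31))
    return torch.randn(shape, generator=gen, device=device)

rng = random.Random(1337)
full_run = [step_noise(rng, (2,)) for _ in range(4)]

rng = random.Random(1337)
for _ in range(2):
    step_noise(rng, (2,))        # first two steps of a second run
snapshot = rng.getstate()        # "checkpoint" taken after step 2

resumed = random.Random()
resumed.setstate(snapshot)       # "resume" from the checkpoint
tail = [step_noise(resumed, (2,)) for _ in range(2)]
assert all(torch.equal(a, b) for a, b in zip(full_run[2:], tail))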