Merge branch 'ClashLuke-patch-1'

commit a49b020eff

@@ -157,11 +157,7 @@ class Kron(torch.optim.Optimizer):
         self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
-        if deterministic:
-            # Use a Generator to try to be more deterministic across resume (save/load)
-            self.torch_rng = torch.Generator().manual_seed(1337)
-        else:
-            self.torch_rng = None
+        self.deterministic = deterministic

         # make compile optional (for bwd compat)
         if has_dynamo:
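Note (illustration, not part of the patch): the hunk above keeps only the picklable
random.Random as persistent RNG state instead of a stored torch.Generator, whose
byte-tensor state is tied to a device and was awkward to save and restore. A minimal
sketch of the property the new code relies on:

    import random

    rng = random.Random(1337)
    rng.random()                    # advance the stream
    saved = rng.getstate()          # plain Python tuple: picklable, device-free

    resumed = random.Random()
    resumed.setstate(saved)
    assert resumed.random() == rng.random()  # both streams continue in lockstep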
@@ -178,7 +174,6 @@ class Kron(torch.optim.Optimizer):
     def __getstate__(self):
         _dict = super().__getstate__()
         _dict["rng"] = self.rng
-        _dict["torch_rng"] = self.torch_rng
         return _dict

     def state_dict(self) -> Dict[str, Any]:
@@ -187,9 +182,6 @@ class Kron(torch.optim.Optimizer):

         # Add the generator state
         optimizer_state['rng_state'] = self.rng.getstate()
-        if self.torch_rng is not None:
-            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
-
         return optimizer_state

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -197,9 +189,7 @@ class Kron(torch.optim.Optimizer):
         rng_states = {}
         if 'rng_state' in state_dict:
             rng_states['rng_state'] = state_dict.pop('rng_state')
-        if 'torch_rng_state' in state_dict:
-            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')

         # Load the optimizer state
         super().load_state_dict(state_dict)
         state_dict.update(rng_states)  # add back
@@ -207,8 +197,6 @@ class Kron(torch.optim.Optimizer):
         # Restore the RNG state if it exists
         if 'rng_state' in rng_states:
             self.rng.setstate(rng_states['rng_state'])
-        if 'torch_rng_state' in rng_states:
-            self.torch_rng.set_state(rng_states['torch_rng_state'])

     def __setstate__(self, state):
         super().__setstate__(state)
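Note (hypothetical usage, not part of the patch): with torch_rng gone, the Python
RNG is the only extra state that state_dict()/load_state_dict() need to carry, so a
save/resume round trip (opt and opt2 are assumed Kron instances) reduces to:

    import torch

    sd = opt.state_dict()                 # now includes 'rng_state' only
    torch.save(sd, "opt.pt")

    # later, on a freshly constructed optimizer
    opt2.load_state_dict(torch.load("opt.pt"))
    # opt2.rng resumes the exact stream opt was on, so every
    # per-step seed drawn after the resume matches the original run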
@@ -317,15 +305,17 @@ class Kron(torch.optim.Optimizer):
             if do_update:
                 exprA, exprGs, _ = exprs
                 Q = state["Q"]
-                if self.torch_rng is None:
-                    V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
+                if self.deterministic:
+                    torch_rng = torch.Generator(device=debiased_momentum.device)
+                    torch_rng.manual_seed(self.rng.randint(0, 2 ** 31))
                 else:
-                    # Restoring generator state to device is messy. For now,
-                    # we keep RNG on CPU, but this slows the optimizer down quite a bit.
-                    # FIXME Need a better approach
-                    V = torch.randn(
-                        debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
-                    V = V.to(debiased_momentum.device)
+                    torch_rng = None
+                V = torch.randn(
+                    debiased_momentum.shape,
+                    generator=torch_rng,
+                    dtype=precond_dtype,
+                    device=debiased_momentum.device,
+                )
                 G = debiased_momentum if momentum_into_precond_update else grad

                 A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
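Note (sketch under assumptions, not part of the patch): the rewritten branch seeds a
fresh per-step torch.Generator from self.rng, so the probe vector V is sampled
directly on the parameter's device (removing the old CPU round trip flagged by the
deleted FIXME) while reproducibility across save/load is carried entirely by the
Python RNG state. Roughly:

    import random
    import torch

    rng = random.Random(1337)              # stands in for self.rng
    device, dtype = "cpu", torch.float32   # assumed for the demo

    def draw_v(shape):
        g = torch.Generator(device=device)
        g.manual_seed(rng.randint(0, 2 ** 31))
        return torch.randn(shape, generator=g, dtype=dtype, device=device)

    v1 = draw_v((4, 4))

    # simulate a resume: restore only the Python RNG state
    rng.setstate(random.Random(1337).getstate())
    assert torch.equal(v1, draw_v((4, 4)))  # same seed -> identical noise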