mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
simplify RNG
This commit is contained in:
parent
6367267298
commit
e025328f96
@ -305,12 +305,11 @@ class Kron(torch.optim.Optimizer):
|
|||||||
if do_update:
|
if do_update:
|
||||||
exprA, exprGs, _ = exprs
|
exprA, exprGs, _ = exprs
|
||||||
Q = state["Q"]
|
Q = state["Q"]
|
||||||
if self.deterministic is None:
|
if self.deterministic:
|
||||||
torch_rng = torch.Generator(device=V.device).manual_seed(self.rng.randint(0, 2 ** 31))
|
torch_rng = torch.Generator(device=V.device).manual_seed(self.rng.randint(0, 2 ** 31))
|
||||||
V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device='cpu')
|
|
||||||
V = V.to(debiased_momentum.device)
|
|
||||||
else:
|
else:
|
||||||
V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
|
torch_rng = None
|
||||||
|
V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device=debiased_momentum.device)
|
||||||
G = debiased_momentum if momentum_into_precond_update else grad
|
G = debiased_momentum if momentum_into_precond_update else grad
|
||||||
|
|
||||||
A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
|
A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user