More additions to Kron
parent 5f10450235
commit 9ab5464e4d
timm/optim
@@ -15,15 +15,6 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
 import numpy as np
 import torch

-try:
-    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-    import opt_einsum
-    opt_einsum.enabled = True
-    opt_einsum.strategy = "auto-hq"
-    import torch.backends.opt_einsum
-    has_opt_einsum = True
-except ImportError:
-    has_opt_einsum = False

 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
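Note: the block removed above is re-added inside `Kron.__init__` in a later hunk, so the opt_einsum check now happens when the optimizer is constructed rather than at import time. As background only, a minimal sketch of enabling PyTorch's opt_einsum-backed einsum path via `torch.backends.opt_einsum` (assumes PyTorch 1.13+; illustrative, not the commit's exact code, which assigns the flags on the `opt_einsum` module itself):

    import torch

    try:
        import opt_einsum  # third-party package; torch only uses it when importable
        import torch.backends.opt_einsum
        torch.backends.opt_einsum.enabled = True        # let torch.einsum optimize contraction order
        torch.backends.opt_einsum.strategy = "auto-hq"  # spend more search time for a cheaper path
    except ImportError:
        pass  # torch.einsum falls back to left-to-right contraction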
@@ -67,19 +58,20 @@ class Kron(torch.optim.Optimizer):
         params: Iterable of parameters to optimize or dicts defining parameter groups.
         lr: Learning rate.
         momentum: Momentum parameter.
-        weight_decay: Weight decay (L2 penalty).
+        weight_decay: Weight decay.
         preconditioner_update_probability: Probability of updating the preconditioner.
             If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
         max_size_triangular: Max size for dim's preconditioner to be triangular.
         min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners.
-        memory_save_mode: 'one_diag', or 'all_diag', None is default
+        memory_save_mode: 'one_diag', 'smart_one_diag', or 'all_diag', None is default
             to set all preconditioners to be triangular, 'one_diag' sets the largest
             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners to be diagonal.
         momentum_into_precond_update: whether to send momentum into preconditioner
             update instead of raw gradients.
         mu_dtype: Dtype of the momentum accumulator.
         precond_dtype: Dtype of the preconditioner.
-        decoupled_decay: AdamW style decoupled-decay.
+        decoupled_decay: AdamW style decoupled weight decay
         flatten_dim: Flatten dim >= 2 instead of relying on expressions
+        deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """

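The docstring above covers all constructor arguments. A minimal usage sketch based on that signature (the `timm.optim.kron` module path, model, and hyperparameter values are placeholders, not taken from the commit):

    import torch
    import torch.nn as nn
    from timm.optim.kron import Kron  # path assumed from the timm/optim diff location

    model = nn.Linear(512, 256)
    opt = Kron(
        model.parameters(),
        lr=1e-3,
        momentum=0.9,
        weight_decay=1e-2,
        decoupled_decay=True,               # AdamW-style decoupled weight decay
        memory_save_mode="smart_one_diag",  # mode newly documented in this commit
        flatten_dim=True,                   # flatten dims >= 2 before preconditioning
        deterministic=True,                 # flag added in this commit
    )

    loss = model(torch.randn(8, 512)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()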
@@ -97,10 +89,18 @@ class Kron(torch.optim.Optimizer):
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
             flatten_dim: bool = False,
+            deterministic: bool = False,
     ):
-        if not has_opt_einsum:
+        try:
+            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+            import opt_einsum
+            opt_einsum.enabled = True
+            opt_einsum.strategy = "auto-hq"
+            import torch.backends.opt_einsum
+        except ImportError:
+            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")

         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= momentum < 1.0:
@@ -122,10 +122,11 @@ class Kron(torch.optim.Optimizer):
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
             flatten_dim=flatten_dim,
         )
         super(Kron, self).__init__(params, defaults)

-        self._param_exprs = {}
+        self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
+        if deterministic:
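The `if deterministic:` body is cut off by the hunk boundary and is not shown in this diff. Judging from `load_state_dict` below, which restores both `self.rng` and a `self.torch_rng`, a plausible sketch is a seeded `torch.Generator` (purely an assumption, not the commit's code):

    # hypothetical continuation, inside Kron.__init__ after self.rng = random.Random(1337)
    if deterministic:
        # seeded CPU generator so sampled preconditioner updates can be
        # reproduced across save / load (resume)
        self.torch_rng = torch.Generator()
        self.torch_rng.manual_seed(self.rng.randint(0, 2 ** 31 - 1))
    else:
        self.torch_rng = None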
@@ -165,20 +166,21 @@ class Kron(torch.optim.Optimizer):

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Extract and remove the RNG state from the state dict
-        rng_state = state_dict.pop('rng_state', None)
-        torch_rng_state = state_dict.pop('torch_rng_state', None)
+        rng_states = {}
+        if 'rng_state' in state_dict:
+            rng_states['rng_state'] = state_dict.pop('rng_state')
+        if 'torch_rng_state' in state_dict:
+            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')

         # Load the optimizer state
         super().load_state_dict(state_dict)
+        state_dict.update(rng_states)  # add back

         # Restore the RNG state if it exists
-        if rng_state is not None:
-            self.rng.setstate(rng_state)
-            state_dict['rng_state'] = rng_state  # put it back if caller still using state_dict
-        if torch_rng_state is not None:
-            if self.torch_rng is not None:
-                self.torch_rng.set_state(torch_rng_state)
-            state_dict['torch_rng_state'] = torch_rng_state  # put it back if caller still using state_dict
+        if 'rng_state' in rng_states:
+            self.rng.setstate(rng_states['rng_state'])
+        if 'torch_rng_state' in rng_states:
+            self.torch_rng.set_state(rng_states['torch_rng_state'])

     def __setstate__(self, state):
         super().__setstate__(state)
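The rewritten `load_state_dict` pops the RNG entries before handing the dict to the base-class loader, then calls `state_dict.update(rng_states)` so a dict still held by the caller keeps its keys. A small save/resume round trip under that assumption (model, file name, and hyperparameters are placeholders):

    import torch
    import torch.nn as nn
    from timm.optim.kron import Kron  # path assumed

    model = nn.Linear(16, 16)
    opt = Kron(model.parameters(), lr=1e-3, deterministic=True)
    model(torch.randn(2, 16)).sum().backward()
    opt.step()

    sd = opt.state_dict()           # assumed to carry 'rng_state' / 'torch_rng_state'
    torch.save(sd, "kron_opt.pth")

    opt2 = Kron(model.parameters(), lr=1e-3, deterministic=True)
    opt2.load_state_dict(torch.load("kron_opt.pth"))
    # the caller's dict is left intact: the RNG keys are added back after loading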
@@ -208,13 +210,16 @@ class Kron(torch.optim.Optimizer):

                 grad = p.grad
                 state = self.state[p]
+                if group['flatten_dim']:
+                    grad = grad.view(grad.size(0), -1)

                 if len(state) == 0:
                     state["step"] = 0
                     state["update_counter"] = 0
-                    state["momentum_buffer"] = torch.zeros_like(p, dtype=mu_dtype or p.dtype)
+                    state["momentum_buffer"] = torch.zeros_like(grad, dtype=mu_dtype or grad.dtype)
+                    # init Q and einsum expressions on first step
                     state["Q"], exprs = _init_Q_exprs(
-                        p,
+                        grad,
                         group["precond_init_scale"],
                         group["max_size_triangular"],
                         group["min_ndim_triangular"],
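With `flatten_dim=True` the gradient is viewed as 2D before any state (momentum buffer, Q factors) is created, so a layer needs only one Kronecker factor per remaining dimension. A quick illustration of the reshape used above, in plain PyTorch:

    import torch

    grad = torch.randn(64, 32, 3, 3)    # e.g. a conv weight gradient
    flat = grad.view(grad.size(0), -1)  # same reshape as in the diff
    print(flat.shape)                   # torch.Size([64, 288])
    # preconditioning then deals with factors for dims 64 and 288
    # instead of separate factors for 64, 32, 3 and 3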
@@ -234,8 +239,9 @@ class Kron(torch.optim.Optimizer):
                     total_precond_size += precond_size
                     total_precond_mb += precond_mb
                 elif p not in self._param_exprs:
+                    # init only the einsum expressions, called after state load, Q are loaded from state_dict
                     exprs = _init_Q_exprs(
-                        p,
+                        grad,
                         group["precond_init_scale"],
                         group["max_size_triangular"],
                         group["min_ndim_triangular"],
@@ -245,6 +251,7 @@ class Kron(torch.optim.Optimizer):
                     )
                     self._param_exprs[p] = exprs
                 else:
+                    # retrieve cached expressions
                     exprs = self._param_exprs[p]

                 # update preconditioners all together deterministically
@@ -315,6 +322,8 @@ class Kron(torch.optim.Optimizer):

                 # RMS of pre_grad should be 1.0, so let's cap at 1.1
                 pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
+                if group['flatten_dim']:
+                    pre_grad = pre_grad.view(p.shape)

                 # Apply weight decay
                 if group["weight_decay"] != 0:
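The clamp above rescales the preconditioned gradient so its RMS cannot exceed 1.1 (PSGD aims for a unit-RMS update). Restated as a standalone helper for clarity (the function name is ours, not the commit's):

    import torch

    def _clip_update_rms_(u: torch.Tensor, max_rms: float = 1.1, eps: float = 1e-8) -> torch.Tensor:
        # scale = min(1, max_rms / RMS(u)); a no-op when RMS(u) <= max_rms
        rms = u.square().mean().sqrt()
        u.mul_(torch.clamp(max_rms / (rms + eps), max=1.0))
        return u

    u = torch.randn(1000) * 5.0        # RMS around 5
    _clip_update_rms_(u)
    print(u.square().mean().sqrt())    # ~1.1 after clipping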
@@ -369,9 +378,10 @@ def _init_Q_exprs(
         dim_diag = [False for _ in shape]
         dim_diag[rev_sorted_dims[0]] = True
     elif memory_save_mode == "smart_one_diag":
-        dim_diag = [False for _ in shape]
+        # addition proposed by Lucas Nestler
         rev_sorted_dims = np.argsort(shape)[::-1]
         sorted_shape = sorted(shape)
+        dim_diag = [False for _ in shape]
         if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
             dim_diag[rev_sorted_dims[0]] = True
     elif memory_save_mode == "all_diag":
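Under 'smart_one_diag' the largest dimension is made diagonal only when it is strictly larger than the second largest, so square-ish layers keep fully triangular factors. A standalone sketch of that rule (the function name is illustrative, not from the commit):

    import numpy as np

    def smart_one_diag(shape):
        # mirrors the branch above: diagonalize the largest dim only when it
        # strictly dominates the second largest
        rev_sorted_dims = np.argsort(shape)[::-1]
        sorted_shape = sorted(shape)
        dim_diag = [False for _ in shape]
        if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
            dim_diag[rev_sorted_dims[0]] = True
        return dim_diag

    print(smart_one_diag((4096, 1024)))  # [True, False]  -> large output dim goes diagonal
    print(smart_one_diag((1024, 1024)))  # [False, False] -> square layer stays triangular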