More additions to Kron

Author: Ross Wightman, 2025-01-27 13:09:09 -08:00
Parent: 5f10450235
Commit: 9ab5464e4d


@@ -15,15 +15,6 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
 import numpy as np
 import torch

-try:
-    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-    import opt_einsum
-    opt_einsum.enabled = True
-    opt_einsum.strategy = "auto-hq"
-    import torch.backends.opt_einsum
-    has_opt_einsum = True
-except ImportError:
-    has_opt_einsum = False

 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
@@ -67,19 +58,20 @@ class Kron(torch.optim.Optimizer):
         params: Iterable of parameters to optimize or dicts defining parameter groups.
         lr: Learning rate.
         momentum: Momentum parameter.
-        weight_decay: Weight decay (L2 penalty).
+        weight_decay: Weight decay.
         preconditioner_update_probability: Probability of updating the preconditioner.
             If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
         max_size_triangular: Max size for dim's preconditioner to be triangular.
         min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners.
-        memory_save_mode: 'one_diag', or 'all_diag', None is default
+        memory_save_mode: 'one_diag', 'smart_one_diag', or 'all_diag', None is default
             to set all preconditioners to be triangular, 'one_diag' sets the largest
             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners to be diagonal.
         momentum_into_precond_update: whether to send momentum into preconditioner
             update instead of raw gradients.
         mu_dtype: Dtype of the momentum accumulator.
         precond_dtype: Dtype of the preconditioner.
-        decoupled_decay: AdamW style decoupled-decay.
+        decoupled_decay: AdamW style decoupled weight decay
+        flatten_dim: Flatten dim >= 2 instead of relying on expressions
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
@@ -97,10 +89,18 @@ class Kron(torch.optim.Optimizer):
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
+            flatten_dim: bool = False,
             deterministic: bool = False,
     ):
-        if not has_opt_einsum:
+        try:
+            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+            import opt_einsum
+            opt_einsum.enabled = True
+            opt_einsum.strategy = "auto-hq"
+            import torch.backends.opt_einsum
+        except ImportError:
             warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
+
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= momentum < 1.0:
@@ -122,10 +122,11 @@ class Kron(torch.optim.Optimizer):
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
+            flatten_dim=flatten_dim,
         )
         super(Kron, self).__init__(params, defaults)

-        self._param_exprs = {}
+        self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
         if deterministic:
@ -165,20 +166,21 @@ class Kron(torch.optim.Optimizer):
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# Extract and remove the RNG state from the state dict # Extract and remove the RNG state from the state dict
rng_state = state_dict.pop('rng_state', None) rng_states = {}
torch_rng_state = state_dict.pop('torch_rng_state', None) if 'rng_state' in state_dict:
rng_states['rng_state'] = state_dict.pop('rng_state')
if 'torch_rng_state' in state_dict:
rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')
# Load the optimizer state # Load the optimizer state
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
state_dict.update(rng_states) # add back
# Restore the RNG state if it exists # Restore the RNG state if it exists
if rng_state is not None: if 'rng_state' in rng_states:
self.rng.setstate(rng_state) self.rng.setstate(rng_states['rng_state'])
state_dict['rng_state'] = rng_state # put it back if caller still using state_dict if 'torch_rng_state' in rng_states:
if torch_rng_state is not None: self.torch_rng.set_state(rng_states['torch_rng_state'])
if self.torch_rng is not None:
self.torch_rng.set_state(torch_rng_state)
state_dict['torch_rng_state'] = torch_rng_state # put it back if caller still using state_dict
def __setstate__(self, state): def __setstate__(self, state):
super().__setstate__(state) super().__setstate__(state)
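A sketch of the save/resume round trip the reworked `load_state_dict` targets. It assumes the matching `state_dict()` override (not shown in this hunk) stores the `'rng_state'` / `'torch_rng_state'` entries that are popped and re-added here; names other than `Kron` are illustrative.

```python
import torch.nn as nn

from timm.optim.kron import Kron  # assumed import path

model = nn.Linear(32, 10)
opt = Kron(model.parameters(), lr=1e-3, deterministic=True)

# ... training steps ...

state = opt.state_dict()            # assumed to carry 'rng_state' / 'torch_rng_state'
resumed = Kron(model.parameters(), lr=1e-3, deterministic=True)
resumed.load_state_dict(state)      # restores self.rng and self.torch_rng as well
# `state` still holds the RNG entries afterwards, because load_state_dict adds them back.
```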
@ -208,13 +210,16 @@ class Kron(torch.optim.Optimizer):
grad = p.grad grad = p.grad
state = self.state[p] state = self.state[p]
if group['flatten_dim']:
grad = grad.view(grad.size(0), -1)
if len(state) == 0: if len(state) == 0:
state["step"] = 0 state["step"] = 0
state["update_counter"] = 0 state["update_counter"] = 0
state["momentum_buffer"] = torch.zeros_like(p, dtype=mu_dtype or p.dtype) state["momentum_buffer"] = torch.zeros_like(grad, dtype=mu_dtype or grad.dtype)
# init Q and einsum expressions on first step
state["Q"], exprs = _init_Q_exprs( state["Q"], exprs = _init_Q_exprs(
p, grad,
group["precond_init_scale"], group["precond_init_scale"],
group["max_size_triangular"], group["max_size_triangular"],
group["min_ndim_triangular"], group["min_ndim_triangular"],
@@ -234,8 +239,9 @@
                     total_precond_size += precond_size
                     total_precond_mb += precond_mb
                 elif p not in self._param_exprs:
+                    # init only the einsum expressions, called after state load, Q are loaded from state_dict
                     exprs = _init_Q_exprs(
-                        p,
+                        grad,
                         group["precond_init_scale"],
                         group["max_size_triangular"],
                         group["min_ndim_triangular"],
@@ -245,6 +251,7 @@
                     )
                     self._param_exprs[p] = exprs
                 else:
+                    # retrieve cached expressions
                     exprs = self._param_exprs[p]

                 # update preconditioners all together deterministically
@@ -315,6 +322,8 @@
                 # RMS of pre_grad should be 1.0, so let's cap at 1.1
                 pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))

+                if group['flatten_dim']:
+                    pre_grad = pre_grad.view(p.shape)

                 # Apply weight decay
                 if group["weight_decay"] != 0:
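To illustrate what `flatten_dim` does around the update: the gradient is viewed as `(dim0, -1)` before preconditioning (earlier hunk) and viewed back to `p.shape` here, before weight decay and the parameter update. A standalone shape sketch, not the optimizer code itself:

```python
import torch

p = torch.randn(64, 3, 3, 3)             # e.g. a conv weight
grad = torch.randn_like(p)

# With flatten_dim=True the 4D grad is handled as a 64x27 matrix, so the
# preconditioner works with two Kronecker factors instead of four.
flat = grad.view(grad.size(0), -1)
print(flat.shape)                         # torch.Size([64, 27])

pre_grad = flat                           # stand-in for the preconditioned gradient
pre_grad = pre_grad.view(p.shape)         # viewed back before the update
print(pre_grad.shape)                     # torch.Size([64, 3, 3, 3])
```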
@@ -369,9 +378,10 @@ def _init_Q_exprs(
         dim_diag = [False for _ in shape]
         dim_diag[rev_sorted_dims[0]] = True
     elif memory_save_mode == "smart_one_diag":
-        dim_diag = [False for _ in shape]
+        # addition proposed by Lucas Nestler
         rev_sorted_dims = np.argsort(shape)[::-1]
         sorted_shape = sorted(shape)
+        dim_diag = [False for _ in shape]
         if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
             dim_diag[rev_sorted_dims[0]] = True
     elif memory_save_mode == "all_diag":
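The reordered `smart_one_diag` branch (credited to Lucas Nestler in the new comment) only demotes the largest dimension to a diagonal preconditioner when it is strictly larger than the next-largest one. A small standalone mirror of that selection rule:

```python
import numpy as np

def smart_one_diag_flags(shape):
    # Mirrors the 'smart_one_diag' branch above: mark the largest dim diagonal
    # only when it strictly dominates the second-largest dim.
    rev_sorted_dims = np.argsort(shape)[::-1]
    sorted_shape = sorted(shape)
    dim_diag = [False for _ in shape]
    if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
        dim_diag[rev_sorted_dims[0]] = True
    return dim_diag

print(smart_one_diag_flags((768, 3072)))   # [False, True]  -> wide dim gets a diagonal preconditioner
print(smart_one_diag_flags((1024, 1024)))  # [False, False] -> square layer stays fully triangular
```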