499 lines
19 KiB
Python
Raw Normal View History

2025-01-24 14:16:21 -08:00
""" PyTorch Implementation of the Kron PSGD optimizer
FIXME attribution
* https://github.com/evanatyourservice/kron_torch (direct source)
* https://github.com/lixilinx/psgd_torch (original)
* https://github.com/ClashLuke/HeavyBall (added improvements)
"""
import logging
2025-01-24 14:16:21 -08:00
import string
import random
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Union
2025-01-24 14:16:21 -08:00
import numpy as np
import torch
2025-01-24 14:16:21 -08:00
# Module-level logger, created first so the setup code below can report problems.
_logger = logging.getLogger(__name__)

try:
    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
    import opt_einsum  # noqa: F401  (imported only to detect availability)
    import torch.backends.opt_einsum

    # torch.einsum is configured through `torch.backends.opt_einsum`; assigning
    # `enabled` / `strategy` on the third-party `opt_einsum` package itself only
    # creates unused attributes there and leaves torch on its default strategy.
    torch.backends.opt_einsum.enabled = True
    try:
        torch.backends.opt_einsum.strategy = "auto-hq"
    except ValueError as err:
        # NOTE(review): torch documents only 'auto' / 'greedy' / 'optimal'. Should a
        # torch build reject other names, keep its default rather than fail on import.
        _logger.debug("Could not set opt_einsum strategy 'auto-hq': %s", err)
    has_opt_einsum = True
except ImportError:
    has_opt_einsum = False

try:
    # Raise the dynamo cache limit; AttributeError means this torch has no `_dynamo`.
    torch._dynamo.config.cache_size_limit = 1_000_000
    has_dynamo = True
except AttributeError:
    has_dynamo = False
2025-01-24 14:16:21 -08:00
def precond_update_prob_schedule(
    n: float,
    max_prob: float = 1.0,
    min_prob: float = 0.03,
    decay: float = 0.001,
    flat_start: float = 500,
) -> torch.Tensor:
    """Anneal preconditioner update probability during beginning of training.

    PSGD benefits from more preconditioner updates at the beginning of training,
    but once the preconditioner is learned the update probability can drop low.

    This schedule is an exponential anneal with a flat start. Default settings keep
    update probability at 1.0 for the first 500 steps (``flat_start``) then
    exponentially anneal down to ``min_prob`` by roughly 4000 steps. Default
    settings work very well for most models and training regimes.

    Args:
        n: Current step count.
        max_prob: Probability used during the flat start (upper clamp).
        min_prob: Floor the probability decays towards (lower clamp).
        decay: Exponential decay rate per step once past ``flat_start``.
        flat_start: Number of steps during which the probability stays at ``max_prob``.

    Returns:
        A float32 tensor holding ``max_prob * exp(-decay * (n - flat_start))``
        clamped to ``[min_prob, max_prob]``.
    """
    n = torch.tensor(n, dtype=torch.float32)
    prob = max_prob * torch.exp(-decay * (n - flat_start))
    # For n < flat_start the exponent is positive, so the upper clamp yields the flat start.
    prob.clamp_(min=min_prob, max=max_prob)
    return prob
class Kron(torch.optim.Optimizer):
    """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate.
        momentum: Momentum parameter.
        weight_decay: Weight decay (L2 penalty).
        preconditioner_update_probability: Probability of updating the preconditioner.
            May be a float or a callable mapping the step count to a probability.
            If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
        max_size_triangular: Max size for dim's preconditioner to be triangular.
        min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners.
        memory_save_mode: One of None, 'one_diag', 'smart_one_diag' or 'all_diag'. None (default)
            sets all preconditioners to be triangular, 'one_diag' sets the largest
            or last dim to be diagonal per layer, 'smart_one_diag' does the same but only when
            the largest dim is strictly larger than the second largest, and 'all_diag' sets all
            preconditioners to be diagonal.
        momentum_into_precond_update: whether to send momentum into preconditioner
            update instead of raw gradients.
        mu_dtype: Dtype of the momentum accumulator.
        precond_dtype: Dtype of the preconditioner.
        decoupled_decay: AdamW style decoupled-decay.
        deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work

    Raises:
        ValueError: If ``lr`` or ``weight_decay`` is negative, or ``momentum`` is outside ``[0, 1)``.
    """

    def __init__(
        self,
        params,
        lr: float = 0.001,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        preconditioner_update_probability: Optional[Union[Callable, float]] = None,
        max_size_triangular: int = 2048,
        min_ndim_triangular: int = 2,
        memory_save_mode: Optional[str] = None,
        momentum_into_precond_update: bool = True,
        mu_dtype: Optional[torch.dtype] = None,
        precond_dtype: Optional[torch.dtype] = None,
        decoupled_decay: bool = False,
        deterministic: bool = False,
    ):
        if not has_opt_einsum:
            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer." )

        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"Invalid beta parameter: {momentum}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            preconditioner_update_probability=preconditioner_update_probability,
            max_size_triangular=max_size_triangular,
            min_ndim_triangular=min_ndim_triangular,
            memory_save_mode=memory_save_mode,
            momentum_into_precond_update=momentum_into_precond_update,
            precond_lr=0.1,  # precond lr hardcoded to 0.1
            precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
            mu_dtype=mu_dtype,
            precond_dtype=precond_dtype,
            decoupled_decay=decoupled_decay,
        )
        super().__init__(params, defaults)

        # Per-parameter einsum expressions; rebuilt lazily in step() when missing.
        self._param_exprs = {}
        self._tiny = torch.finfo(torch.bfloat16).tiny
        # Python RNG drives the occasional preconditioner balancing in step().
        self.rng = random.Random(1337)
        if deterministic:
            # Use a Generator to try to be more deterministic across resume (save/load)
            self.torch_rng = torch.Generator().manual_seed(1337)
        else:
            self.torch_rng = None

        self._bind_kernels()

    def _bind_kernels(self) -> None:
        """Attach the math kernels to the instance, wrapped in torch.compile when available."""
        # make compile optional (for bwd compat)
        if has_dynamo:
            self._calc_A_and_conjB = torch.compile(_calc_A_and_conjB, fullgraph=True, dynamic=False)
            self._q_terms = torch.compile(_q_terms, fullgraph=True, dynamic=False)
            self._precond_grad = torch.compile(_precond_grad, fullgraph=True, dynamic=False)
            self._balance_Q = torch.compile(_balance_Q, fullgraph=True, dynamic=False)
        else:
            self._calc_A_and_conjB = _calc_A_and_conjB
            self._q_terms = _q_terms
            self._precond_grad = _precond_grad
            self._balance_Q = _balance_Q

    def __getstate__(self):
        """Return the picklable state, extended with both random number generators."""
        _dict = super().__getstate__()
        _dict["rng"] = self.rng
        _dict["torch_rng"] = self.torch_rng
        return _dict

    def state_dict(self) -> Dict[str, Any]:
        """Return the optimizer state dict plus the RNG states under 'rng_state' / 'torch_rng_state'."""
        # Get the optimizer's state dict
        optimizer_state = super().state_dict()

        # Add the generator state
        optimizer_state['rng_state'] = self.rng.getstate()
        if self.torch_rng is not None:
            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
        return optimizer_state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load optimizer state and restore the RNG states when the dict carries them.

        The RNG entries are removed while the base class loads the dict and are put
        back afterwards, so the caller's ``state_dict`` ends up with the same keys.
        """
        # Extract and remove the RNG state from the state dict
        rng_state = state_dict.pop('rng_state', None)
        torch_rng_state = state_dict.pop('torch_rng_state', None)

        # Load the optimizer state
        super().load_state_dict(state_dict)

        # Restore the RNG state if it exists
        if rng_state is not None:
            self.rng.setstate(rng_state)
            state_dict['rng_state'] = rng_state  # put it back if caller still using state_dict
        if torch_rng_state is not None:
            if self.torch_rng is not None:
                self.torch_rng.set_state(torch_rng_state)
            state_dict['torch_rng_state'] = torch_rng_state  # put it back if caller still using state_dict

    def __setstate__(self, state):
        """Restore state and reset everything that is derived rather than stored."""
        super().__setstate__(state)
        self._param_exprs = {}
        # `_tiny` and the kernel bindings are not part of __getstate__, so an instance
        # rebuilt from a pickle / deepcopy lacks them and step() would fail. Restore
        # them only when absent so an already initialised instance keeps its bindings.
        # NOTE(review): torch's Optimizer.load_state_dict is believed to route through
        # __setstate__ as well, which is why these are guarded - confirm.
        if "_tiny" not in self.__dict__:
            self._tiny = torch.finfo(torch.bfloat16).tiny
        if "_precond_grad" not in self.__dict__:
            self._bind_kernels()

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss;
                it is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        total_momentum_size = 0
        total_momentum_mb = 0
        total_precond_size = 0
        total_precond_mb = 0

        for group in self.param_groups:
            mu_dtype = group.get("mu_dtype")
            # NOTE: groups built from the constructor defaults always carry this key, so the
            # float32 fallback only applies to groups lacking it; the value may be None.
            precond_dtype = group.get("precond_dtype", torch.float32)
            momentum_into_precond_update = group.get("momentum_into_precond_update", True)
            group_update_prob = group.get("preconditioner_update_probability", None)
            if group_update_prob is None:
                group_update_prob = precond_update_prob_schedule

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                if len(state) == 0:
                    # First time this parameter is seen: create its state and preconditioners.
                    state["step"] = 0
                    state["update_counter"] = 0
                    state["momentum_buffer"] = torch.zeros_like(p, dtype=mu_dtype or p.dtype)
                    state["Q"], exprs = _init_Q_exprs(
                        p,
                        group["precond_init_scale"],
                        group["max_size_triangular"],
                        group["min_ndim_triangular"],
                        group["memory_save_mode"],
                        dtype=precond_dtype,
                    )
                    self._param_exprs[p] = exprs

                    # Accumulate sizes for log
                    momentum_size = state["momentum_buffer"].numel()
                    momentum_mb = momentum_size * state["momentum_buffer"].element_size() / 2**20
                    total_momentum_size += momentum_size
                    total_momentum_mb += momentum_mb

                    precond_size = sum(q.numel() for q in state["Q"])
                    precond_mb = sum(q.numel() * q.element_size() for q in state["Q"]) / 2**20
                    total_precond_size += precond_size
                    total_precond_mb += precond_mb
                elif p not in self._param_exprs:
                    # State exists (e.g. after loading) but the expressions were reset: rebuild them only.
                    exprs = _init_Q_exprs(
                        p,
                        group["precond_init_scale"],
                        group["max_size_triangular"],
                        group["min_ndim_triangular"],
                        group["memory_save_mode"],
                        dtype=precond_dtype,
                        init_q=False,
                    )
                    self._param_exprs[p] = exprs
                else:
                    exprs = self._param_exprs[p]

                # update preconditioners all together deterministically
                # Resolve the probability per parameter from the group-level setting, so a
                # callable schedule is evaluated with this parameter's own step count rather
                # than being replaced by the value computed for an earlier parameter.
                update_prob = group_update_prob
                if callable(update_prob):
                    update_prob = update_prob(state["step"])
                state["update_counter"] += 1
                do_update = state["update_counter"] >= 1 / update_prob
                if do_update:
                    state["update_counter"] = 0

                state["step"] += 1

                # Update momentum buffer
                beta = group["momentum"]
                bias_correction = 1 - beta ** state["step"]
                momentum_buffer = state["momentum_buffer"]
                momentum_buffer.mul_(beta).add_(grad, alpha=1 - beta)

                # Restore momentum dtype
                if mu_dtype is not None:
                    momentum_buffer.copy_(momentum_buffer.to(dtype=mu_dtype))
                debiased_momentum = (momentum_buffer / bias_correction).to(dtype=precond_dtype)

                # Balance preconditioners roughly every 100 updates
                # (the RNG is drawn for every parameter, whether or not an update happens)
                balance = self.rng.random() < 0.01 and do_update
                if grad.dim() > 1 and balance:
                    self._balance_Q(state["Q"])

                # Update preconditioner
                if do_update:
                    exprA, exprGs, _ = exprs
                    Q = state["Q"]

                    if self.torch_rng is None:
                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
                    else:
                        # Restoring generator state to device is messy. For now,
                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
                        # FIXME Need a better approach
                        V = torch.randn(
                            debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
                        V = V.to(debiased_momentum.device)

                    G = debiased_momentum if momentum_into_precond_update else grad

                    A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)

                    terms = self._q_terms(exprGs, A, conjB)

                    for q, (term1, term2) in zip(Q, terms):
                        tmp = term1 - term2
                        tmp *= group["precond_lr"]
                        if q.dim() < 2:
                            # diagonal preconditioner: elementwise update, normalised by max-abs
                            tmp *= q
                            tmp /= (term1 + term2).norm(float("inf")) + self._tiny
                        else:
                            # matrix preconditioner: keep the upper triangle, normalise by a
                            # spectral-norm lower bound
                            tmp = torch.triu(tmp)
                            tmp /= _norm_lower_bound(term1 + term2) + self._tiny
                            tmp @= q
                        q.sub_(tmp)

                # Precondition gradients
                pre_grad = self._precond_grad(
                    state["Q"],
                    exprs,
                    debiased_momentum,
                ).to(dtype=p.dtype)

                # RMS of pre_grad should be 1.0, so let's cap at 1.1
                pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))

                # Apply weight decay
                if group["weight_decay"] != 0:
                    if group["decoupled_decay"]:
                        p.mul_(1. - group["lr"] * group["weight_decay"])
                    else:
                        pre_grad.add_(p, alpha=group["weight_decay"])

                # Update parameters
                p.add_(pre_grad, alpha=-group["lr"])

        # Only non-zero on a step that initialised at least one parameter's state.
        if total_momentum_size > 0:
            _logger.info(f"PSGD Momentum buffer size: {total_momentum_size} elements, {total_momentum_mb:.2f} MB")
            _logger.info(f"PSGD Preconditioners size: {total_precond_size} elements, {total_precond_mb:.2f} MB")

        return loss
def _init_Q_exprs(
    t,
    scale,
    max_size,
    min_ndim_triangular,
    memory_save_mode,
    dtype=None,
    init_q=True,
):
    """For a scalar or tensor t, we initialize its preconditioner Q and
    reusable einsum expressions for updating Q and preconditioning gradient.

    Args:
        t: Tensor (or 0-dim scalar tensor) whose shape, dtype and device are used.
        scale: Overall initial scale; each of the ``t.dim()`` factors gets ``scale ** (1 / t.dim())``.
        max_size: Dims larger than this get a diagonal (1-D) factor instead of a matrix.
        min_ndim_triangular: Tensors with fewer dims than this get only diagonal factors.
        memory_save_mode: None (no dim forced diagonal), 'one_diag' (largest dim, last one on
            ties, is diagonal), 'smart_one_diag' (same, but only when the largest dim is strictly
            larger than the second largest) or 'all_diag' (every dim is diagonal).
        dtype: Dtype of the created factors; defaults to ``t.dtype``.
        init_q: When False, only the einsum expressions are built and no tensors are allocated.

    Returns:
        ``[Q, (exprA, exprGs, exprP)]`` when ``init_q`` is True, else ``(exprA, exprGs, exprP)``.
        ``Q`` is a list with one factor per dim; ``exprGs`` holds one expression per dim.

    Raises:
        ValueError: If ``t`` has more than 13 dims or ``memory_save_mode`` is not recognised.
    """
    letters = string.ascii_lowercase + string.ascii_uppercase

    dtype = dtype if dtype is not None else t.dtype
    shape = t.shape
    Q = []

    if len(shape) == 0:  # scalar
        if init_q:
            Q.append(scale * torch.ones_like(t, dtype=dtype))
        exprA = ",->"
        exprGs = [",->"]
        exprP = ",,->"
    else:  # tensor
        # Dim i uses letters[i], letters[i + 13] and letters[i + 26], so 13 dims is the limit.
        if len(shape) > 13:
            raise ValueError(f"Got tensor with dim {len(t.shape)}; Einstein runs out of letters!")

        scale = scale ** (1 / len(shape))

        if memory_save_mode is None:
            dim_diag = [False for _ in shape]
        elif memory_save_mode == "one_diag":
            rev_sorted_dims = np.argsort(shape)[::-1]
            dim_diag = [False for _ in shape]
            dim_diag[rev_sorted_dims[0]] = True
        elif memory_save_mode == "smart_one_diag":
            dim_diag = [False for _ in shape]
            rev_sorted_dims = np.argsort(shape)[::-1]
            sorted_shape = sorted(shape)
            # only force a diagonal factor when one dim is strictly the largest
            if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
                dim_diag[rev_sorted_dims[0]] = True
        elif memory_save_mode == "all_diag":
            dim_diag = [True for _ in shape]
        else:
            raise ValueError(
                f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
                f"[None, 'one_diag', 'smart_one_diag', 'all_diag']")

        # Pieces of the three kinds of einsum subscripts, assembled dim by dim.
        piece1A, piece2A, piece3A = ([], "", "")
        exprGs = []
        piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
        for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
            if (
                size == 1
                or size > max_size
                or len(shape) < min_ndim_triangular
                or dim_d
            ):
                # use diagonal matrix as preconditioner for this dim
                if init_q:
                    Q.append(scale * torch.ones(size, dtype=dtype, device=t.device))

                piece1A.append(letters[i])
                piece2A = piece2A + letters[i]
                piece3A = piece3A + letters[i]

                piece1 = "".join([letters[i + 13] if j == i else letters[j] for j in range(len(shape))])
                subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
                exprGs.append(subscripts)

                piece1P.append(letters[i + 13])
                piece2P.append(letters[i + 13])
                piece3P = piece3P + letters[i + 13]
                piece4P = piece4P + letters[i + 13]
            else:
                # use triangular matrix as preconditioner for this dim
                if init_q:
                    Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))

                piece1A.append(letters[i] + letters[i + 13])
                piece2A = piece2A + letters[i + 13]
                piece3A = piece3A + letters[i]

                piece1 = "".join([letters[i + 13] if j == i else letters[j] for j in range(len(shape))])
                piece2 = "".join([letters[i + 26] if j == i else letters[j] for j in range(len(shape))])
                subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
                exprGs.append(subscripts)

                a, b, c = (letters[i], letters[i + 13], letters[i + 26])
                piece1P.append(a + b)
                piece2P.append(a + c)
                piece3P = piece3P + c
                piece4P = piece4P + b

        exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
        exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P

    exprGs = tuple(exprGs)
    if init_q:
        return [Q, (exprA, exprGs, exprP)]
    else:
        return exprA, exprGs, exprP
2025-01-24 14:16:21 -08:00
def _lb(A, max_abs):
    """Estimate a norm lower bound for matrix ``A`` after normalising it by ``max_abs``.

    The column or row with the largest squared magnitude (whichever sum is bigger)
    seeds a single multiply by ``A`` / ``A.H``; the result is rescaled by ``max_abs``.
    """
    scaled = A / max_abs
    sq_mag = torch.real(scaled * scaled.conj())
    col_best, col_idx = torch.max(sq_mag.sum(dim=0), 0)
    row_best, row_idx = torch.max(sq_mag.sum(dim=1), 0)

    if col_best > row_best:
        # strongest column wins: probe from the left
        probe = scaled[:, col_idx].conj() @ scaled
        unit = probe / torch.linalg.vector_norm(probe)
        estimate = torch.linalg.vector_norm(unit @ scaled.H)
    else:
        # strongest row wins (also on ties): probe from the right
        probe = scaled @ scaled[row_idx].conj()
        unit = probe / torch.linalg.vector_norm(probe)
        estimate = torch.linalg.vector_norm(scaled.H @ unit)
    return max_abs * estimate
def _norm_lower_bound(A):
    """Cheap lower bound for the spectral norm of A."""
    largest = A.norm(float("inf"))
    estimate = _lb(A, largest)
    # When A has no non-zero entry the estimate is not usable; return `largest` (zero) instead.
    return torch.where(largest > 0, estimate, largest)
def _solve_triangular_right(X, A):
    """X @ inv(A)

    ``A`` is treated as upper triangular. The solve runs in float32 and the
    result is cast back to the dtype ``X`` came in with.
    """
    in_dtype = X.dtype
    X32 = X.to(dtype=torch.float32)
    A32 = A.to(dtype=torch.float32)
    # flatten all leading dims so the solve sees a plain 2-D right-hand side
    rows = X32.reshape(-1, X32.size(-1))
    solved = torch.linalg.solve_triangular(A32, rows, upper=True, left=False)
    return solved.reshape_as(X32).to(dtype=in_dtype)
2025-01-24 14:16:21 -08:00
def _balance_Q(Q_in):
    """Rescale the factors in ``Q_in`` in place so they share the same max-abs value.

    Every factor ends up with max-abs equal to the geometric mean of the original
    max-abs values, which leaves the product of the scales unchanged.
    """
    magnitudes = torch.stack([factor.norm(float("inf")) for factor in Q_in])
    target = magnitudes.prod() ** (1 / len(Q_in))
    multipliers = target / magnitudes
    for idx, factor in enumerate(Q_in):
        factor.mul_(multipliers[idx])
def _precond_grad(Q, exprs, G):
    """Precondition gradient G with preconditioner Q."""
    # The last expression takes the conjugated factors, then the factors, then G.
    conj_factors = [q.conj() for q in Q]
    operands = (*conj_factors, *Q, G)
    return torch.einsum(exprs[-1], *operands)
def _calc_A_and_conjB(exprA, G, Q, V):
    """Compute the pair ``(A, conjB)`` consumed by ``_q_terms``.

    ``A`` is the einsum of all factors in ``Q`` with ``G``. ``conjB`` starts from the
    conjugated ``V`` with its first axis rotated to the end and is then divided
    (1-D factor) or right-solved (matrix factor) by each factor in turn.
    """
    A = torch.einsum(exprA, *Q, G)

    ndim = G.dim()
    last = ndim - 1
    axes = tuple(range(ndim))
    conjB = torch.permute(V.conj(), axes[1:] + axes[:1])
    for axis, q in enumerate(Q):
        if q.dim() < 2:
            conjB = conjB / q
        else:
            conjB = _solve_triangular_right(conjB, q)
        # swap so the next factor's axis sits in the trailing position
        if axis < last:
            conjB = torch.transpose(conjB, axis, last)
    return A, conjB
def _q_terms(exprGs, A, conjB):
    """Return one ``(term1, term2)`` pair per expression in ``exprGs``.

    ``term1`` contracts ``A`` with its conjugate and ``term2`` contracts the
    conjugate of ``conjB`` with ``conjB``, both using the same expression.
    """
    A_conj = A.conj()
    B = conjB.conj()
    return [
        (torch.einsum(expr, A, A_conj), torch.einsum(expr, B, conjB))
        for expr in exprGs
    ]