""" PyTorch Implementation of the Kron PSGD optimizer FIXME attribution * https://github.com/evanatyourservice/kron_torch (direct source) * https://github.com/lixilinx/psgd_torch (original) * https://github.com/ClashLuke/HeavyBall (added improvements) """ import logging import string import random import warnings from typing import Any, Callable, Dict, Optional, Tuple, Union import numpy as np import torch try: # NOTE opt_einsum needed to avoid blowing up memory with einsum ops import opt_einsum opt_einsum.enabled = True opt_einsum.strategy = "auto-hq" import torch.backends.opt_einsum has_opt_einsum = True except ImportError: has_opt_einsum = False try: torch._dynamo.config.cache_size_limit = 1_000_000 has_dynamo = True except AttributeError: has_dynamo = False _logger = logging.getLogger(__name__) def precond_update_prob_schedule( n: float, max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.001, flat_start: float = 500, ) -> torch.Tensor: """Anneal preconditioner update probability during beginning of training. PSGD benefits from more preconditioner updates at the beginning of training, but once the preconditioner is learned the update probability can drop low. This schedule is an exponential anneal with a flat start. Default settings keep update probability at 1.0 for 200 steps then exponentially anneal down to `min_prob` by 4000 steps. Default settings work very well for most models and training regimes. """ """Exponential anneal with flat start.""" n = torch.tensor(n, dtype=torch.float32) prob = max_prob * torch.exp(-decay * (n - flat_start)) prob.clamp_(min=min_prob, max=max_prob) return prob class Kron(torch.optim.Optimizer): """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch. Args: params: Iterable of parameters to optimize or dicts defining parameter groups. lr: Learning rate. momentum: Momentum parameter. weight_decay: Weight decay (L2 penalty). preconditioner_update_probability: Probability of updating the preconditioner. If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps. max_size_triangular: Max size for dim's preconditioner to be triangular. min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners. memory_save_mode: 'one_diag', or 'all_diag', None is default to set all preconditioners to be triangular, 'one_diag' sets the largest or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners to be diagonal. momentum_into_precond_update: whether to send momentum into preconditioner update instead of raw gradients. mu_dtype: Dtype of the momentum accumulator. precond_dtype: Dtype of the preconditioner. decoupled_decay: AdamW style decoupled-decay. deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work """ def __init__( self, params, lr: float = 0.001, momentum: float = 0.9, weight_decay: float = 0.0, preconditioner_update_probability: Optional[Union[Callable, float]] = None, max_size_triangular: int = 2048, min_ndim_triangular: int = 2, memory_save_mode: Optional[str] = None, momentum_into_precond_update: bool = True, mu_dtype: Optional[torch.dtype] = None, precond_dtype: Optional[torch.dtype] = None, decoupled_decay: bool = False, deterministic: bool = False, ): if not has_opt_einsum: warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer." ) if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= momentum < 1.0: raise ValueError(f"Invalid beta parameter: {momentum}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( lr=lr, momentum=momentum, weight_decay=weight_decay, preconditioner_update_probability=preconditioner_update_probability, max_size_triangular=max_size_triangular, min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode, momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1, # precond lr hardcoded to 0.1 precond_init_scale=1.0, # precond init scale hardcoded to 1.0 mu_dtype=mu_dtype, precond_dtype=precond_dtype, decoupled_decay=decoupled_decay, ) super(Kron, self).__init__(params, defaults) self._param_exprs = {} self._tiny = torch.finfo(torch.bfloat16).tiny self.rng = random.Random(1337) if deterministic: # Use a Generator to try to be more deterministic across resume (save/load) self.torch_rng = torch.Generator().manual_seed(1337) else: self.torch_rng = None # make compile optional (for bwd compat) if has_dynamo: self._calc_A_and_conjB = torch.compile(_calc_A_and_conjB, fullgraph=True, dynamic=False) self._q_terms = torch.compile(_q_terms, fullgraph=True, dynamic=False) self._precond_grad = torch.compile(_precond_grad, fullgraph=True, dynamic=False) self._balance_Q = torch.compile(_balance_Q, fullgraph=True, dynamic=False) else: self._calc_A_and_conjB = _calc_A_and_conjB self._q_terms = _q_terms self._precond_grad = _precond_grad self._balance_Q = _balance_Q def __getstate__(self): _dict = super().__getstate__() _dict["rng"] = self.rng _dict["torch_rng"] = self.torch_rng return _dict def state_dict(self) -> Dict[str, Any]: # Get the optimizer's state dict optimizer_state = super().state_dict() # Add the generator state optimizer_state['rng_state'] = self.rng.getstate() if self.torch_rng is not None: optimizer_state['torch_rng_state'] = self.torch_rng.get_state() return optimizer_state def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # Extract and remove the RNG state from the state dict rng_state = state_dict.pop('rng_state', None) torch_rng_state = state_dict.pop('torch_rng_state', None) # Load the optimizer state super().load_state_dict(state_dict) # Restore the RNG state if it exists if rng_state is not None: self.rng.setstate(rng_state) state_dict['rng_state'] = rng_state # put it back if caller still using state_dict if torch_rng_state is not None: if self.torch_rng is not None: self.torch_rng.set_state(torch_rng_state) state_dict['torch_rng_state'] = torch_rng_state # put it back if caller still using state_dict def __setstate__(self, state): super().__setstate__(state) self._param_exprs = {} @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: with torch.enable_grad(): loss = closure() total_momentum_size = 0 total_momentum_mb = 0 total_precond_size = 0 total_precond_mb = 0 for group in self.param_groups: mu_dtype = group.get("mu_dtype") precond_dtype = group.get("precond_dtype", torch.float32) momentum_into_precond_update = group.get("momentum_into_precond_update", True) update_prob = group.get("preconditioner_update_probability", None) for p in group["params"]: if p.grad is None: continue grad = p.grad state = self.state[p] if len(state) == 0: state["step"] = 0 state["update_counter"] = 0 state["momentum_buffer"] = torch.zeros_like(p, dtype=mu_dtype or p.dtype) state["Q"], exprs = _init_Q_exprs( p, group["precond_init_scale"], group["max_size_triangular"], group["min_ndim_triangular"], group["memory_save_mode"], dtype=precond_dtype, ) self._param_exprs[p] = exprs # Accumulate sizes for log momentum_size = state["momentum_buffer"].numel() momentum_mb = momentum_size * state["momentum_buffer"].element_size() / 2**20 total_momentum_size += momentum_size total_momentum_mb += momentum_mb precond_size = sum(q.numel() for q in state["Q"]) precond_mb = sum(q.numel() * q.element_size() for q in state["Q"]) / 2**20 total_precond_size += precond_size total_precond_mb += precond_mb elif p not in self._param_exprs: exprs = _init_Q_exprs( p, group["precond_init_scale"], group["max_size_triangular"], group["min_ndim_triangular"], group["memory_save_mode"], dtype=precond_dtype, init_q=False, ) self._param_exprs[p] = exprs else: exprs = self._param_exprs[p] # update preconditioners all together deterministically if update_prob is None: update_prob = precond_update_prob_schedule if callable(update_prob): update_prob = update_prob(state["step"]) state["update_counter"] += 1 do_update = state["update_counter"] >= 1 / update_prob if do_update: state["update_counter"] = 0 state["step"] += 1 # Update momentum buffer beta = group["momentum"] bias_correction = 1 - beta ** state["step"] momentum_buffer = state["momentum_buffer"] momentum_buffer.mul_(group["momentum"]).add_(grad, alpha=1 - group["momentum"]) # Restore momentum dtype if mu_dtype is not None: momentum_buffer.copy_(momentum_buffer.to(dtype=mu_dtype)) debiased_momentum = (momentum_buffer / bias_correction).to(dtype=precond_dtype) # Balance preconditioners roughly every 100 updates balance = self.rng.random() < 0.01 and do_update if grad.dim() > 1 and balance: self._balance_Q(state["Q"]) # Update preconditioner if do_update: exprA, exprGs, _ = exprs Q = state["Q"] if self.torch_rng is None: V = torch.randn_like(debiased_momentum, dtype=precond_dtype) else: # Restoring generator state to device is messy. For now, # we keep RNG on CPU, but this slows the optimizer down quite a bit. # FIXME Need a better approach V = torch.randn( debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu') V = V.to(debiased_momentum.device) G = debiased_momentum if momentum_into_precond_update else grad A, conjB = self._calc_A_and_conjB(exprA, G, Q, V) terms = self._q_terms(exprGs, A, conjB) for q, (term1, term2) in zip(Q, terms): tmp = term1 - term2 tmp *= group["precond_lr"] if q.dim() < 2: tmp *= q tmp /= (term1 + term2).norm(float("inf")) + self._tiny else: tmp = torch.triu(tmp) tmp /= _norm_lower_bound(term1 + term2) + self._tiny tmp @= q q.sub_(tmp) # Precondition gradients pre_grad = self._precond_grad( state["Q"], exprs, debiased_momentum, ).to(dtype=p.dtype) # RMS of pre_grad should be 1.0, so let's cap at 1.1 pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0)) # Apply weight decay if group["weight_decay"] != 0: if group["decoupled_decay"]: p.mul_(1. - group["lr"] * group["weight_decay"]) else: pre_grad.add_(p, alpha=group["weight_decay"]) # Update parameters p.add_(pre_grad, alpha=-group["lr"]) if total_momentum_size > 0: _logger.info(f"PSGD Momentum buffer size: {total_momentum_size} elements, {total_momentum_mb:.2f} MB") _logger.info(f"PSGD Preconditioners size: {total_precond_size} elements, {total_precond_mb:.2f} MB") return loss def _init_Q_exprs( t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None, init_q=True, ): """For a scalar or tensor t, we initialize its preconditioner Q and reusable einsum expressions for updating Q and preconditioning gradient. """ letters = string.ascii_lowercase + string.ascii_uppercase dtype = dtype if dtype is not None else t.dtype shape = t.shape Q = [] if len(shape) == 0: # scalar if init_q: Q.append(scale * torch.ones_like(t, dtype=dtype)) exprA = ",->" exprGs = [",->"] exprP = ",,->" else: # tensor if len(shape) > 13: raise ValueError(f"Got tensor with dim {len(t.shape)}; Einstein runs out of letters!") scale = scale ** (1 / len(shape)) if memory_save_mode is None: dim_diag = [False for _ in shape] elif memory_save_mode == "one_diag": rev_sorted_dims = np.argsort(shape)[::-1] dim_diag = [False for _ in shape] dim_diag[rev_sorted_dims[0]] = True elif memory_save_mode == "smart_one_diag": dim_diag = [False for _ in shape] rev_sorted_dims = np.argsort(shape)[::-1] sorted_shape = sorted(shape) if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]: dim_diag[rev_sorted_dims[0]] = True elif memory_save_mode == "all_diag": dim_diag = [True for _ in shape] else: raise ValueError( f"Invalid memory_save_mode: {memory_save_mode}, must be one of [None, 'one_diag', 'all_diag']") piece1A, piece2A, piece3A = ([], "", "") exprGs = [] piece1P, piece2P, piece3P, piece4P = ([], [], "", "") for i, (size, dim_d) in enumerate(zip(shape, dim_diag)): if ( size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d ): # use diagonal matrix as preconditioner for this dim if init_q: Q.append(scale * torch.ones(size, dtype=dtype, device=t.device)) piece1A.append(letters[i]) piece2A = piece2A + letters[i] piece3A = piece3A + letters[i] piece1 = "".join([letters[i + 13] if j == i else letters[j] for j in range(len(shape))]) subscripts = piece1 + "," + piece1 + "->" + letters[i + 13] exprGs.append(subscripts) piece1P.append(letters[i + 13]) piece2P.append(letters[i + 13]) piece3P = piece3P + letters[i + 13] piece4P = piece4P + letters[i + 13] else: # use triangular matrix as preconditioner for this dim if init_q: Q.append(scale * torch.eye(size, dtype=dtype, device=t.device)) piece1A.append(letters[i] + letters[i + 13]) piece2A = piece2A + letters[i + 13] piece3A = piece3A + letters[i] piece1 = "".join([letters[i + 13] if j == i else letters[j] for j in range(len(shape))]) piece2 = "".join([letters[i + 26] if j == i else letters[j] for j in range(len(shape))]) subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26] exprGs.append(subscripts) a, b, c = (letters[i], letters[i + 13], letters[i + 26]) piece1P.append(a + b) piece2P.append(a + c) piece3P = piece3P + c piece4P = piece4P + b exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P exprGs = tuple(exprGs) if init_q: return [Q, (exprA, exprGs, exprP)] else: return exprA, exprGs, exprP def _lb(A, max_abs): A = A / max_abs aa = torch.real(A * A.conj()) value0, i = torch.max(torch.sum(aa, dim=0), 0) value1, j = torch.max(torch.sum(aa, dim=1), 0) if value0 > value1: x = A[:, i].conj() @ A return max_abs * torch.linalg.vector_norm((x / torch.linalg.vector_norm(x)) @ A.H) else: x = A @ A[j].conj() return max_abs * torch.linalg.vector_norm(A.H @ (x / torch.linalg.vector_norm(x))) def _norm_lower_bound(A): """Cheap lower bound for the spectral norm of A.""" max_abs = A.norm(float("inf")) return torch.where(max_abs > 0, _lb(A, max_abs), max_abs) def _solve_triangular_right(X, A): """X @ inv(A)""" orig_dtype = X.dtype X = X.to(dtype=torch.float32) A = A.to(dtype=torch.float32) out = torch.linalg.solve_triangular(A, X.reshape(-1, X.size(-1)), upper=True, left=False).reshape_as(X) return out.to(dtype=orig_dtype) def _balance_Q(Q_in): norms = torch.stack([q.norm(float("inf")) for q in Q_in]) geometric_mean = norms.prod() ** (1 / len(Q_in)) norms = geometric_mean / norms for i, q in enumerate(Q_in): q.mul_(norms[i]) def _precond_grad(Q, exprs, G): """Precondition gradient G with preconditioner Q.""" return torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G) def _calc_A_and_conjB(exprA, G, Q, V): A = torch.einsum(exprA, *Q, G) order = G.dim() p = tuple(range(order)) conjB = torch.permute(V.conj(), p[1:] + p[:1]) for i, q in enumerate(Q): conjB = conjB / q if q.dim() < 2 else _solve_triangular_right(conjB, q) if i < order - 1: conjB = torch.transpose(conjB, i, order - 1) return A, conjB def _q_terms(exprGs, A, conjB): terms = [] for exprG in exprGs: term1 = torch.einsum(exprG, A, A.conj()) term2 = torch.einsum(exprG, conjB.conj(), conjB) terms.append((term1, term2)) return terms