pytorch-image-models/timm/optim/kron.py

548 lines
21 KiB
Python

""" PyTorch Implementation of the Kron (PSGD) optimizer
This is a PSGD optimizer using a Kronecker-factored preconditioner.
This impl was adapted from https://github.com/evanatyourservice/kron_torch
by Evan Walters, licensed CC-BY-4.0.
Contributions to above also made by
* Lucas Nestler, added to his https://github.com/ClashLuke/HeavyBall implementation.
* Omead Pooladzandi https://github.com/opooladz
The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li
This `timm` impl
* works with a wider variety of torch versions
* fixes some checkpoint save/restore (resume issues)
* adds decoupled weight-decay option
* has some refactoring, cleanup of args, default/group items
* warning about not having opt_einsum (unusable without)
"""
import logging
import string
import random
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Union
import numpy as np
import torch
try:
# NOTE opt_einsum needed to avoid blowing up memory with einsum ops
import opt_einsum
import torch.backends.opt_einsum
torch.backends.opt_einsum.enabled = True
torch.backends.opt_einsum.strategy = "auto-hq"
has_opt_einsum = True
except ImportError:
has_opt_einsum = False
try:
torch._dynamo.config.cache_size_limit = 1_000_000
has_dynamo = True
except AttributeError:
has_dynamo = False
from ._types import ParamsT
_logger = logging.getLogger(__name__)
def precond_update_prob_schedule(
        n: float,
        max_prob: float = 1.0,
        min_prob: float = 0.03,
        decay: float = 0.001,
        flat_start: float = 500,
) -> torch.Tensor:
    """Anneal preconditioner update probability during beginning of training.

    PSGD benefits from more preconditioner updates at the beginning of training,
    but once the preconditioner is learned the update probability can drop low.

    This schedule is an exponential anneal with a flat start. Default settings keep
    update probability at `max_prob` (1.0) for the first `flat_start` (500) steps and
    then exponentially anneal down to `min_prob` (0.03), which is reached at roughly
    4000 steps.

    Args:
        n: Current step count (Python number or tensor).
        max_prob: Upper bound of the probability, used during the flat start.
        min_prob: Lower bound the probability decays towards.
        decay: Exponential decay rate applied per step after `flat_start`.
        flat_start: Number of initial steps during which the probability stays at `max_prob`.

    Returns:
        A float32 tensor holding the update probability, clamped to `[min_prob, max_prob]`.
    """
    # as_tensor accepts both Python numbers and tensors without the copy-construct warning
    n = torch.as_tensor(n, dtype=torch.float32)
    # For n < flat_start the exponent is positive; the clamp below caps the value at max_prob.
    prob = max_prob * torch.exp(-decay * (n - flat_start))
    prob.clamp_(min=min_prob, max=max_prob)
    return prob
class Kron(torch.optim.Optimizer):
    """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate.
        momentum: Momentum parameter.
        weight_decay: Weight decay.
        preconditioner_update_probability: Probability of updating the preconditioner.
            If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
            A callable is invoked with the parameter's step count and must return the probability.
        max_size_triangular: Max size for dim's preconditioner to be triangular.
        min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners.
        memory_save_mode: 'one_diag', 'smart_one_diag', or 'all_diag', None is default
            to set all preconditioners to be triangular, 'one_diag' sets the largest
            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners to be diagonal.
        momentum_into_precond_update: whether to send momentum into preconditioner
            update instead of raw gradients.
        precond_lr: Multiplier applied to every preconditioner update step.
        precond_init_scale: Scale handed to the preconditioner initialisation.
        mu_dtype: Dtype of the momentum accumulator.
        precond_dtype: Dtype of the preconditioner.
        decoupled_decay: AdamW style decoupled weight decay
        flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
        deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
    """

    def __init__(
            self,
            params: ParamsT,
            lr: float = 0.001,
            momentum: float = 0.9,
            weight_decay: float = 0.0,
            preconditioner_update_probability: Optional[Union[Callable, float]] = None,
            max_size_triangular: int = 2048,
            min_ndim_triangular: int = 2,
            memory_save_mode: Optional[str] = None,
            momentum_into_precond_update: bool = True,
            precond_lr: float = 0.1,
            precond_init_scale: float = 1.0,
            mu_dtype: Optional[torch.dtype] = None,
            precond_dtype: Optional[torch.dtype] = None,
            decoupled_decay: bool = False,
            flatten: bool = False,
            flatten_start_end: Tuple[int, int] = (2, -1),
            deterministic: bool = False,
    ):
        if not has_opt_einsum:
            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")

        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"Invalid beta parameter: {momentum}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            preconditioner_update_probability=preconditioner_update_probability,
            max_size_triangular=max_size_triangular,
            min_ndim_triangular=min_ndim_triangular,
            memory_save_mode=memory_save_mode,
            momentum_into_precond_update=momentum_into_precond_update,
            precond_lr=precond_lr,
            precond_init_scale=precond_init_scale,
            mu_dtype=mu_dtype,
            precond_dtype=precond_dtype,
            decoupled_decay=decoupled_decay,
            flatten=flatten,
            flatten_start_end=flatten_start_end,
        )
        super(Kron, self).__init__(params, defaults)

        self._param_exprs = {}  # cache for einsum expr
        self.rng = random.Random(1337)
        if deterministic:
            # Use a Generator to try to be more deterministic across resume (save/load)
            self.torch_rng = torch.Generator().manual_seed(1337)
        else:
            self.torch_rng = None

        self._init_runtime_helpers()

    def _init_runtime_helpers(self) -> None:
        """Create the attributes that `__getstate__` does not persist (epsilon and helper handles)."""
        self._tiny = torch.finfo(torch.bfloat16).tiny

        # make compile optional (for bwd compat)
        if has_dynamo:
            self._calc_A_and_conjB = torch.compile(_calc_A_and_conjB, fullgraph=True, dynamic=False)
            self._q_terms = torch.compile(_q_terms, fullgraph=True, dynamic=False)
            self._precond_grad = torch.compile(_precond_grad, fullgraph=True, dynamic=False)
            self._balance_Q = torch.compile(_balance_Q, fullgraph=True, dynamic=False)
        else:
            self._calc_A_and_conjB = _calc_A_and_conjB
            self._q_terms = _q_terms
            self._precond_grad = _precond_grad
            self._balance_Q = _balance_Q

    def __getstate__(self):
        """Return the base optimizer state extended with both RNG objects."""
        _dict = super().__getstate__()
        _dict["rng"] = self.rng
        _dict["torch_rng"] = self.torch_rng
        return _dict

    def state_dict(self) -> Dict[str, Any]:
        """Return the optimizer state dict with the RNG states added under extra keys."""
        # Get the optimizer's state dict
        optimizer_state = super().state_dict()

        # Add the generator state
        optimizer_state['rng_state'] = self.rng.getstate()
        if self.torch_rng is not None:
            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
        return optimizer_state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load optimizer state and, when present, the saved RNG states.

        The RNG entries are removed from `state_dict` only for the duration of the
        base-class load and are always put back afterwards, so the caller's dict is
        left unchanged even if loading fails.
        """
        # Extract and remove the RNG state from the state dict
        rng_states = {}
        if 'rng_state' in state_dict:
            rng_states['rng_state'] = state_dict.pop('rng_state')
        if 'torch_rng_state' in state_dict:
            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')

        try:
            # Load the optimizer state
            super().load_state_dict(state_dict)
        finally:
            state_dict.update(rng_states)  # add back

        # Restore the RNG state if it exists
        if 'rng_state' in rng_states:
            self.rng.setstate(rng_states['rng_state'])
        # A checkpoint written in deterministic mode may be loaded into an optimizer that
        # has no torch Generator; in that case there is nothing to restore into.
        if 'torch_rng_state' in rng_states and self.torch_rng is not None:
            self.torch_rng.set_state(rng_states['torch_rng_state'])

    def __setstate__(self, state):
        """Restore state, drop cached einsum expressions and rebuild missing helpers."""
        super().__setstate__(state)
        # Expressions are rebuilt lazily in `step` for every parameter missing from the cache.
        self._param_exprs = {}
        # `__getstate__` does not carry `_tiny` or the helper handles, so an instance that
        # did not go through `__init__` (unpickled / deep-copied) needs them recreated.
        if "_tiny" not in self.__dict__:
            self._init_runtime_helpers()

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss.

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        total_momentum_size = 0
        total_momentum_mb = 0
        total_precond_size = 0
        total_precond_mb = 0

        for group in self.param_groups:
            mu_dtype = group.get("mu_dtype")
            precond_dtype = group.get("precond_dtype", torch.float32)
            momentum_into_precond_update = group.get("momentum_into_precond_update", True)
            group_update_prob = group.get("preconditioner_update_probability", None)
            if group_update_prob is None:
                group_update_prob = precond_update_prob_schedule

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                flattened = False
                if group['flatten']:
                    grad = safe_flatten(grad, *group["flatten_start_end"])
                    flattened = True

                if len(state) == 0:
                    state["step"] = 0
                    state["update_counter"] = 0
                    state["momentum_buffer"] = torch.zeros_like(grad, dtype=mu_dtype or grad.dtype)
                    # init Q and einsum expressions on first step
                    state["Q"], exprs = _init_Q_exprs(
                        grad,
                        group["precond_init_scale"],
                        group["max_size_triangular"],
                        group["min_ndim_triangular"],
                        group["memory_save_mode"],
                        dtype=precond_dtype,
                    )
                    self._param_exprs[p] = exprs

                    # Accumulate sizes for log
                    momentum_size = state["momentum_buffer"].numel()
                    momentum_mb = momentum_size * state["momentum_buffer"].element_size() / 2**20
                    total_momentum_size += momentum_size
                    total_momentum_mb += momentum_mb

                    precond_size = sum(q.numel() for q in state["Q"])
                    precond_mb = sum(q.numel() * q.element_size() for q in state["Q"]) / 2**20
                    total_precond_size += precond_size
                    total_precond_mb += precond_mb
                elif p not in self._param_exprs:
                    # init only the einsum expressions, called after state load, Q are loaded from state_dict
                    exprs = _init_Q_exprs(
                        grad,
                        group["precond_init_scale"],
                        group["max_size_triangular"],
                        group["min_ndim_triangular"],
                        group["memory_save_mode"],
                        dtype=precond_dtype,
                        init_q=False,
                    )
                    self._param_exprs[p] = exprs
                else:
                    # retrieve cached expressions
                    exprs = self._param_exprs[p]

                # Resolve the update probability for this parameter. The group setting is
                # kept in its own variable so a callable schedule is evaluated with every
                # parameter's own step count rather than only for the first parameter.
                if callable(group_update_prob):
                    update_prob = group_update_prob(state["step"])
                else:
                    update_prob = group_update_prob
                state["update_counter"] += 1
                do_update = state["update_counter"] >= 1 / update_prob
                if do_update:
                    state["update_counter"] = 0

                state["step"] += 1

                # Update momentum buffer
                beta = group["momentum"]
                bias_correction = 1 - beta ** state["step"]
                momentum_buffer = state["momentum_buffer"]
                momentum_buffer.mul_(group["momentum"]).add_(grad, alpha=1 - group["momentum"])

                # Restore momentum dtype
                if mu_dtype is not None:
                    momentum_buffer.copy_(momentum_buffer.to(dtype=mu_dtype))
                debiased_momentum = (momentum_buffer / bias_correction).to(dtype=precond_dtype)

                # Balance preconditioners roughly every 100 updates
                # (the Python RNG is advanced on every call, whether or not an update happens)
                balance = self.rng.random() < 0.01 and do_update
                if grad.dim() > 1 and balance:
                    self._balance_Q(state["Q"])

                # Update preconditioner
                if do_update:
                    exprA, exprGs, _ = exprs
                    Q = state["Q"]
                    if self.torch_rng is None:
                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
                    else:
                        # Restoring generator state to device is messy. For now,
                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
                        # FIXME Need a better approach
                        V = torch.randn(
                            debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
                        V = V.to(debiased_momentum.device)
                    G = debiased_momentum if momentum_into_precond_update else grad

                    A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)

                    terms = self._q_terms(exprGs, A, conjB)

                    for q, (term1, term2) in zip(Q, terms):
                        tmp = term1 - term2
                        tmp *= group["precond_lr"]
                        if q.dim() < 2:
                            # diagonal factor: elementwise update, normalised by the max-abs entry
                            tmp *= q
                            tmp /= (term1 + term2).norm(float("inf")) + self._tiny
                        else:
                            # triangular factor: keep the upper triangle, normalise by a
                            # lower bound of the norm of (term1 + term2)
                            tmp = torch.triu(tmp)
                            tmp /= _norm_lower_bound(term1 + term2) + self._tiny
                            tmp @= q
                        q.sub_(tmp)

                # Precondition gradients
                pre_grad = self._precond_grad(
                    state["Q"],
                    exprs,
                    debiased_momentum,
                ).to(dtype=p.dtype)

                # RMS of pre_grad should be 1.0, so let's cap at 1.1
                pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
                if flattened:
                    pre_grad = pre_grad.view(p.shape)

                # Apply weight decay
                if group["weight_decay"] != 0:
                    if group["decoupled_decay"]:
                        p.mul_(1. - group["lr"] * group["weight_decay"])
                    else:
                        pre_grad.add_(p, alpha=group["weight_decay"])

                # Update parameters
                p.add_(pre_grad, alpha=-group["lr"])

        # Sizes are only accumulated when state is first created, so this logs once per init
        if total_momentum_size > 0:
            _logger.info(f"PSGD Momentum buffer size: {total_momentum_size} elements, {total_momentum_mb:.2f} MB")
            _logger.info(f"PSGD Preconditioners size: {total_precond_size} elements, {total_precond_mb:.2f} MB")

        return loss
def safe_flatten(tensor, start_dim=0, end_dim=-1):
    """Flatten `tensor` over `[start_dim, end_dim]`, returning it untouched when the range is not usable.

    A negative `end_dim` counts from the back; an `end_dim` past the last axis is
    clipped to the last axis.
    """
    rank = tensor.ndim

    # Resolve the end axis to a non-negative index no larger than the last axis
    last_axis = end_dim if end_dim >= 0 else rank + end_dim
    last_axis = min(last_axis, rank - 1)

    # Nothing to flatten: start axis does not exist or the range is empty / inverted
    if start_dim >= rank or last_axis < start_dim:
        return tensor

    return tensor.flatten(start_dim, last_axis)
def _init_Q_exprs(
        t,
        scale,
        max_size,
        min_ndim_triangular,
        memory_save_mode,
        dtype=None,
        init_q=True,
):
    """For a scalar or tensor t, we initialize its preconditioner Q and
    reusable einsum expressions for updating Q and preconditioning gradient.

    Args:
        t: Tensor whose shape (and device / dtype) drive the factor layout.
        scale: Overall initial scale; each of the `len(shape)` factors gets `scale ** (1 / len(shape))`.
        max_size: Dims larger than this get a diagonal (vector) factor instead of a matrix.
        min_ndim_triangular: Tensors with fewer dims than this use diagonal factors only.
        memory_save_mode: None, 'one_diag', 'smart_one_diag' or 'all_diag'.
        dtype: Dtype of the created factors, defaults to `t.dtype`.
        init_q: When False only the expressions are built and no factor tensors are allocated.

    Returns:
        `[Q, (exprA, exprGs, exprP)]` when `init_q` is True, otherwise `(exprA, exprGs, exprP)`.

    Raises:
        ValueError: If `t` has more than 13 dims or `memory_save_mode` is not recognised.
    """
    letters = string.ascii_lowercase + string.ascii_uppercase

    dtype = dtype if dtype is not None else t.dtype
    shape = t.shape

    Q = []
    if len(shape) == 0:  # scalar
        if init_q:
            Q.append(scale * torch.ones_like(t, dtype=dtype))
        exprA = ",->"
        exprGs = [",->"]
        exprP = ",,->"
    else:  # tensor
        # Each dim uses up to three letters at offsets i, i + 13 and i + 26 of the 52 available
        if len(shape) > 13:
            raise ValueError(f"Got tensor with dim {len(t.shape)}; Einstein runs out of letters!")

        scale = scale ** (1 / len(shape))

        if memory_save_mode is None:
            dim_diag = [False for _ in shape]
        elif memory_save_mode == "one_diag":
            rev_sorted_dims = np.argsort(shape)[::-1]
            dim_diag = [False for _ in shape]
            dim_diag[rev_sorted_dims[0]] = True
        elif memory_save_mode == "smart_one_diag":
            # addition proposed by Lucas Nestler
            rev_sorted_dims = np.argsort(shape)[::-1]
            sorted_shape = sorted(shape)
            dim_diag = [False for _ in shape]
            # only go diagonal when one dim is strictly the largest
            if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
                dim_diag[rev_sorted_dims[0]] = True
        elif memory_save_mode == "all_diag":
            dim_diag = [True for _ in shape]
        else:
            raise ValueError(
                f"Invalid memory_save_mode: {memory_save_mode}, "
                "must be one of [None, 'one_diag', 'smart_one_diag', 'all_diag']")

        piece1A, piece2A, piece3A = ([], "", "")
        exprGs = []
        piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
        for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
            if (
                size == 1
                or size > max_size
                or len(shape) < min_ndim_triangular
                or dim_d
            ):
                # use diagonal matrix as preconditioner for this dim
                if init_q:
                    Q.append(scale * torch.ones(size, dtype=dtype, device=t.device))

                piece1A.append(letters[i])
                piece2A = piece2A + letters[i]
                piece3A = piece3A + letters[i]

                piece1 = "".join([letters[i + 13] if j == i else letters[j] for j in range(len(shape))])
                subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
                exprGs.append(subscripts)

                piece1P.append(letters[i + 13])
                piece2P.append(letters[i + 13])
                piece3P = piece3P + letters[i + 13]
                piece4P = piece4P + letters[i + 13]
            else:
                # use triangular matrix as preconditioner for this dim
                if init_q:
                    Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))

                piece1A.append(letters[i] + letters[i + 13])
                piece2A = piece2A + letters[i + 13]
                piece3A = piece3A + letters[i]

                piece1 = "".join([letters[i + 13] if j == i else letters[j] for j in range(len(shape))])
                piece2 = "".join([letters[i + 26] if j == i else letters[j] for j in range(len(shape))])
                subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
                exprGs.append(subscripts)

                a, b, c = (letters[i], letters[i + 13], letters[i + 26])
                piece1P.append(a + b)
                piece2P.append(a + c)
                piece3P = piece3P + c
                piece4P = piece4P + b

        exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
        exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P

    exprGs = tuple(exprGs)
    if init_q:
        return [Q, (exprA, exprGs, exprP)]
    else:
        return exprA, exprGs, exprP
def _lb(A, max_abs):
    """Compute the bound used by `_norm_lower_bound` for a matrix `A` with max-abs value `max_abs`.

    `A` is first divided by `max_abs`, the row or column with the largest squared
    magnitude sum is picked, and the result is scaled back by `max_abs`.
    """
    scaled = A / max_abs
    sq_mag = torch.real(scaled * scaled.conj())

    # Largest column / row in terms of summed squared magnitude
    col_best, col_idx = torch.max(sq_mag.sum(dim=0), 0)
    row_best, row_idx = torch.max(sq_mag.sum(dim=1), 0)

    if col_best > row_best:
        probe = scaled[:, col_idx].conj() @ scaled
        unit = probe / torch.linalg.vector_norm(probe)
        return max_abs * torch.linalg.vector_norm(unit @ scaled.H)

    probe = scaled @ scaled[row_idx].conj()
    unit = probe / torch.linalg.vector_norm(probe)
    return max_abs * torch.linalg.vector_norm(scaled.H @ unit)
def _norm_lower_bound(A):
    """Cheap lower bound for the spectral norm of A.

    Falls back to the max-abs value itself (i.e. zero) when `A` has no non-zero entry.
    """
    largest = A.norm(float("inf"))
    bound = _lb(A, largest)
    # For an all-zero A the division inside `_lb` is not meaningful; pick `largest` instead.
    return torch.where(largest > 0, bound, largest)
def _solve_triangular_right(X, A):
    """X @ inv(A)

    `A` is treated as upper triangular. The solve runs in float32 and the result is
    cast back to the dtype `X` came in with.
    """
    in_dtype = X.dtype
    X32 = X.to(dtype=torch.float32)
    A32 = A.to(dtype=torch.float32)

    # Collapse all leading dims so the solver sees a 2D right-hand side
    rows = X32.reshape(-1, X32.size(-1))
    solved = torch.linalg.solve_triangular(A32, rows, upper=True, left=False)
    return solved.reshape_as(X32).to(dtype=in_dtype)
def _balance_Q(Q_in):
    """Rescale every factor in `Q_in` in place so all share the same max-abs value.

    The common target is the geometric mean of the factors' max-abs values.
    """
    inf_norms = torch.stack([factor.norm(float("inf")) for factor in Q_in])
    target = inf_norms.prod() ** (1 / len(Q_in))
    rescale = target / inf_norms
    for idx, factor in enumerate(Q_in):
        factor.mul_(rescale[idx])
def _precond_grad(Q, exprs, G):
    """Precondition gradient G with preconditioner Q.

    The last entry of `exprs` is the einsum expression; it receives the conjugated
    factors, then the factors themselves, then `G`.
    """
    expr_p = exprs[-1]
    conj_factors = [q.conj() for q in Q]
    return torch.einsum(expr_p, *conj_factors, *Q, G)
def _calc_A_and_conjB(exprA, G, Q, V):
    """Build the two tensors consumed by `_q_terms`.

    `A` is the einsum of the factors in `Q` with `G` under `exprA`. `conjB` starts as
    the conjugate of `V` with its first axis rotated to the end and is then divided
    (diagonal factor) or right-solved (matrix factor) by every factor in turn.
    """
    A = torch.einsum(exprA, *Q, G)

    ndim = G.dim()
    last = ndim - 1
    axes = tuple(range(ndim))
    conjB = torch.permute(V.conj(), axes[1:] + axes[:1])

    for axis, q in enumerate(Q):
        if q.dim() < 2:
            conjB = conjB / q
        else:
            conjB = _solve_triangular_right(conjB, q)
        # Bring the next axis to be processed into the trailing position
        if axis < last:
            conjB = torch.transpose(conjB, axis, last)
    return A, conjB
def _q_terms(exprGs, A, conjB):
    """Return one `(term1, term2)` pair per expression in `exprGs`.

    `term1` contracts `A` with its conjugate, `term2` contracts the conjugate of
    `conjB` with `conjB`, both under the same einsum expression.
    """
    A_conj = A.conj()
    B = conjB.conj()
    return [
        (torch.einsum(expr, A, A_conj), torch.einsum(expr, B, conjB))
        for expr in exprGs
    ]