diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py
index d2eb67ba..5d30be41 100644
--- a/timm/optim/__init__.py
+++ b/timm/optim/__init__.py
@@ -4,6 +4,7 @@ from .adahessian import Adahessian
 from .adamp import AdamP
 from .adamw import AdamW
 from .adan import Adan
+from .adafactor_bv import AdafactorBigVision
 from .lamb import Lamb
 from .lars import Lars
 from .lookahead import Lookahead
diff --git a/timm/optim/adafactor_bv.py b/timm/optim/adafactor_bv.py
new file mode 100644
index 00000000..d320f63d
--- /dev/null
+++ b/timm/optim/adafactor_bv.py
@@ -0,0 +1,288 @@
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.optim import Optimizer
+
+
+def _get_scalar_dtype():
+    """Get the scalar dtype that the optimizer uses for state"""
+    return torch.float64
+
+
+def _factored_dims(
+        shape: Tuple[int, ...],
+        factored: bool,
+        min_dim_size_to_factor: int
+) -> Optional[Tuple[int, int]]:
+    """Whether to use a factored second moment estimator.
+
+    This function returns a tuple with the two largest axes to reduce over.
+    If no two dimensions have size >= min_dim_size_to_factor, return None.
+
+    Args:
+        shape: an input shape
+        factored: whether to use factored second-moment estimator for > 2d vars.
+        min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.
+
+    Returns:
+        None or a tuple of ints
+    """
+    if not factored or len(shape) < 2:
+        return None
+    sorted_dims = np.argsort(shape)
+    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
+        return None
+    return int(sorted_dims[-2]), int(sorted_dims[-1])
+
+
+class AdafactorBigVision(Optimizer):
+    """
+    PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
+
+    Follows the Adafactor variant used in the big_vision codebase
+    (https://github.com/google-research/big_vision):
+      * the second-moment decay rate follows a step-dependent schedule, beta2_t = min(beta2_cap, 1 - t**-decay_rate)
+      * the second moment is factored into row/col statistics only when two dims are >= min_dim_size_to_factor
+      * optional momentum is kept in a reduced-precision dtype (bfloat16 by default)
+      * updates can be RMS-clipped (clipping_threshold) and weight decay can be fully decoupled
+        from the learning rate (unscaled_wd=True)
+
+    NOTE: the multi-tensor (foreach=True) path is not implemented yet and will raise if selected.
+    """
+
+    def __init__(
+            self,
+            params,
+            lr: float = 1.0,
+            min_dim_size_to_factor: int = 32,
+            decay_rate: float = 0.8,
+            decay_offset: int = 0,
+            beta2_cap: float = 0.999,
+            momentum: Optional[float] = 0.9,
+            momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
+            eps: float = 1e-30,
+            weight_decay: float = 0.0,
+            clipping_threshold: Optional[float] = None,
+            unscaled_wd: bool = False,
+            *,
+            foreach: Optional[bool] = False,
+    ):
+        if isinstance(momentum_dtype, str):
+            if momentum_dtype == 'float16':
+                momentum_dtype = torch.float16
+            elif momentum_dtype == 'bfloat16':
+                momentum_dtype = torch.bfloat16
+            else:
+                assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
+                momentum_dtype = torch.float32
+
+        defaults = dict(
+            lr=lr,
+            min_dim_size_to_factor=min_dim_size_to_factor,
+            decay_rate=decay_rate,
+            decay_offset=decay_offset,
+            beta2_cap=beta2_cap,
+            momentum=momentum,
+            momentum_dtype=momentum_dtype,
+            eps=eps,
+            weight_decay=weight_decay,
+            clipping_threshold=clipping_threshold,
+            unscaled_wd=unscaled_wd,
+            foreach=foreach,
+        )
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('foreach', None)
+            for p in group['params']:
+                p_state = self.state.get(p, {})
+                if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
+                    p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
+
+    def _get_beta2(self, step: Tensor, decay_rate: float, beta2_cap: float) -> float:
+        """Computes beta2 according to the step schedule"""
+        t = float(step + 1)
+        return min(beta2_cap, 1.0 - t ** (-decay_rate))
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avg_sq_rs = []
+            exp_avg_sq_cs = []
+            exp_avg_sqs = []
+            state_steps = []
+            exp_avgs = []  # For momentum
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                if p.grad.is_sparse:
+                    raise RuntimeError("Sparse gradients not supported")
+
+                params_with_grad.append(p)
+                grads.append(p.grad)
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    # NOTE step on CPU, probably need some more thought to make capturable
+                    state['step'] = torch.tensor(0.0, dtype=_get_scalar_dtype())
+
+                    shape = p.grad.shape
+                    factored_dims = _factored_dims(
+                        shape,
+                        factored=True,
+                        min_dim_size_to_factor=group['min_dim_size_to_factor'],
+                    )
+
+                    if factored_dims is not None:
+                        # Factored second moment: keep row & col running stats with the reduced dim kept as 1
+                        d1, d0 = factored_dims
+                        row_shape = list(p.grad.shape)
+                        row_shape[d0] = 1
+                        col_shape = list(p.grad.shape)
+                        col_shape[d1] = 1
+                        state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
+                        state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
+                    else:
+                        state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
+
+                    if group['momentum'] is not None:
+                        # Momentum buffer kept in the configured (usually reduced precision) dtype
+                        state['exp_avg'] = torch.zeros_like(p.grad, dtype=group['momentum_dtype'])
+
+                state_steps.append(state['step'])
+                exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
+                exp_avg_sq_cs.append(state.get('exp_avg_sq_c', None))
+                exp_avg_sqs.append(state.get('exp_avg_sq', None))
+                exp_avgs.append(state.get('exp_avg', None))
+
+            if group['foreach']:
+                func = _multi_tensor_adafactor
+            else:
+                func = _single_tensor_adafactor
+
+            func(
+                params=params_with_grad,
+                grads=grads,
+                exp_avg_sq_rs=exp_avg_sq_rs,
+                exp_avg_sq_cs=exp_avg_sq_cs,
+                exp_avg_sqs=exp_avg_sqs,
+                exp_avgs=exp_avgs,
+                state_steps=state_steps,
+                beta2_decay=group['decay_rate'],
+                beta2_cap=group['beta2_cap'],
+                min_dim_size_to_factor=group['min_dim_size_to_factor'],
+                eps=group['eps'],
+                lr=group['lr'],
+                weight_decay=group['weight_decay'],
+                momentum=group['momentum'],
+                momentum_dtype=group['momentum_dtype'],
+                clipping_threshold=group['clipping_threshold'],
+                unscaled_wd=group['unscaled_wd'],
+            )
+
+        return loss
+
+
+def _single_tensor_adafactor(
+        params: List[Tensor],
+        grads: List[Tensor],
+        exp_avg_sq_rs: List[Optional[Tensor]],
+        exp_avg_sq_cs: List[Optional[Tensor]],
+        exp_avg_sqs: List[Optional[Tensor]],
+        exp_avgs: List[Optional[Tensor]],
+        state_steps: List[Tensor],
+        *,
+        beta2_decay: float,
+        beta2_cap: float,
+        min_dim_size_to_factor: int,
+        eps: float,
+        lr: float,
+        weight_decay: float,
+        momentum: Optional[float],
+        momentum_dtype: Union[str, torch.dtype],
+        clipping_threshold: Optional[float],
+        unscaled_wd: bool,
+):
+    for i, param in enumerate(params):
+        grad = grads[i]
+        exp_avg_sq_r = exp_avg_sq_rs[i]
+        exp_avg_sq_c = exp_avg_sq_cs[i]
+        exp_avg_sq = exp_avg_sqs[i]
+        exp_avg = exp_avgs[i]
+        step_t = state_steps[i]
+
+        # Update step and compute the step-dependent beta2
+        step_t += 1
+        beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
+        one_minus_beta2_t = 1 - beta2_t
+
+        if exp_avg_sq is None:
+            # Handle factored second moment (row / col statistics)
+            d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
+            grad_sqr = torch.square(grad) + eps
+            exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
+            exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)
+
+            # NOTE the row/col accumulators keep their reduced dims (keepdim=True), so unlike the
+            # big vision / JAX impl there is no axis index shift; the row stats are averaged over d1 directly.
+            row_col_mean = exp_avg_sq_r.mean(dim=d1, keepdim=True)
+            row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
+            col_factor = exp_avg_sq_c.rsqrt()
+
+            update = grad * row_factor * col_factor
+        else:
+            # Handle non-factored
+            exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=one_minus_beta2_t)
+            update = grad * exp_avg_sq.add(eps).rsqrt_()
+
+        # Clip by RMS value: scale the update down if its RMS exceeds clipping_threshold
+        if clipping_threshold is not None:
+            denom = (update.norm(2) / ((update.numel() ** 0.5) * clipping_threshold)).clamp_(min=1.0)
+            update.div_(denom)
+
+        # Apply momentum (in different dtype)
+        if momentum is not None and exp_avg is not None:
+            if momentum_dtype != grad.dtype:
+                exp_avg.lerp_(update.to(momentum_dtype), 1 - momentum)  # ema
+                update = exp_avg.to(grad.dtype)
+            else:
+                exp_avg.lerp_(update, 1 - momentum)  # ema
+                update = exp_avg.clone()
+
+        # Scale by learning rate
+        update.mul_(lr)
+
+        # Perform weight decay
+        if weight_decay != 0:
+            if unscaled_wd:
+                # match big vision impl, 'fully decoupled' decay w/o LR scaling
+                param.mul_(1. - weight_decay)
+            else:
+                # match typical pytorch behaviour for decoupled decay, eg adamw where wd is scaled by LR
+                param.mul_(1. - lr * weight_decay)
+
+        # Update parameters
+        param.add_(update, alpha=-1.0)
+
+
+def _multi_tensor_adafactor(
+        params: List[Tensor],
+        grads: List[Tensor],
+        exp_avg_sq_rs: List[Optional[Tensor]],
+        exp_avg_sq_cs: List[Optional[Tensor]],
+        exp_avg_sqs: List[Optional[Tensor]],
+        exp_avgs: List[Optional[Tensor]],
+        state_steps: List[Tensor],
+        *,
+        beta2_decay: float,
+        beta2_cap: float,
+        min_dim_size_to_factor: int,
+        eps: float,
+        lr: float,
+        weight_decay: float,
+        momentum: Optional[float],
+        momentum_dtype: Union[str, torch.dtype],
+        clipping_threshold: Optional[float],
+        unscaled_wd: bool,
+):
+    assert False, 'multi-tensor fn (foreach=True) not implemented yet'
diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py
index 8187b55a..63a0b1b1 100644
--- a/timm/optim/optim_factory.py
+++ b/timm/optim/optim_factory.py
@@ -10,6 +10,7 @@ import torch.nn as nn
 import torch.optim as optim
 
 from timm.models import group_parameters
+from . import AdafactorBigVision
 
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
@@ -356,6 +357,8 @@ def create_optimizer_v2(
     elif opt_lower == 'lion':
         opt_args.pop('eps', None)
         optimizer = Lion(parameters, **opt_args)
+    elif opt_lower == 'adafactorbv':
+        optimizer = AdafactorBigVision(parameters, **opt_args)
 
     # second order
     elif opt_lower == 'adahessian':
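
Usage sketch (not part of the patch): a minimal example of how the new optimizer could be exercised once this lands, both constructed directly and via the 'adafactorbv' string registered in create_optimizer_v2 above. The toy model and hyper-parameter values below are placeholders for illustration only.

import torch
import torch.nn as nn

from timm.optim import AdafactorBigVision
from timm.optim.optim_factory import create_optimizer_v2

model = nn.Linear(128, 64)  # placeholder model

# Direct construction, exercising a few of the new knobs
opt = AdafactorBigVision(model.parameters(), lr=1.0, momentum=0.9, clipping_threshold=1.0)

# Or via the factory, using the 'adafactorbv' string added in optim_factory.py
opt = create_optimizer_v2(model, 'adafactorbv', lr=1.0, weight_decay=1e-4)

# One dummy optimization step
x, y = torch.randn(8, 128), torch.randn(8, 64)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()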