diff --git a/timm/optim/adopt.py b/timm/optim/adopt.py
index 61320207..5cc7c18b 100644
--- a/timm/optim/adopt.py
+++ b/timm/optim/adopt.py
@@ -13,7 +13,7 @@ Modified for reduced dependencies on PyTorch internals from original at: https:/

 """

-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Callable, List, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -64,6 +64,7 @@ class Adopt(Optimizer):
         lr: Union[float, Tensor] = 1e-3,
         betas: Tuple[float, float] = (0.9, 0.9999),
         eps: float = 1e-6,
+        clip_exp: Optional[float] = 0.333,
         weight_decay: float = 0.0,
         decoupled: bool = False,
         *,
@@ -95,6 +96,7 @@ class Adopt(Optimizer):
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            clip_exp=clip_exp,
             decoupled=decoupled,
             maximize=maximize,
             foreach=foreach,
@@ -111,6 +113,7 @@ class Adopt(Optimizer):
             group.setdefault("foreach", None)
             group.setdefault("capturable", False)
             group.setdefault("differentiable", False)
+            group.setdefault("clip_exp", None)
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@@ -141,9 +144,7 @@ class Adopt(Optimizer):
             has_complex |= torch.is_complex(p)
             params_with_grad.append(p)
             if p.grad.is_sparse:
-                raise RuntimeError(
-                    "ADOPT does not support sparse gradients"
-                )
+                raise RuntimeError("ADOPT does not support sparse gradients")
             grads.append(p.grad)

             state = self.state[p]
@@ -153,36 +154,24 @@ class Adopt(Optimizer):
                 # Deliberately host `step` on CPU if both capturable and fused are off.
                 # This is because kernel launches are costly on CUDA and XLA.
                 state["step"] = (
-                    torch.zeros(
-                        (),
-                        dtype=_get_scalar_dtype(),
-                        device=p.grad.device,
-                    )
+                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.grad.device)
                     if group["capturable"]
                     else torch.tensor(0.0, dtype=_get_scalar_dtype())
                 )
                 # Exponential moving average of gradient values
-                state["exp_avg"] = torch.zeros_like(
-                    p.grad, memory_format=torch.preserve_format
-                )
+                state["exp_avg"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
                 # Exponential moving average of squared gradient values
-                state["exp_avg_sq"] = torch.zeros_like(
-                    p.grad, memory_format=torch.preserve_format
-                )
+                state["exp_avg_sq"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)

             exp_avgs.append(state["exp_avg"])
             exp_avg_sqs.append(state["exp_avg_sq"])

             if group["differentiable"] and state["step"].requires_grad:
-                raise RuntimeError(
-                    "`requires_grad` is not supported for `step` in differentiable mode"
-                )
+                raise RuntimeError("`requires_grad` is not supported for `step` in differentiable mode")

             # Foreach without capturable does not support a tensor lr
             if group["foreach"] and torch.is_tensor(group["lr"]) and not group["capturable"]:
-                raise RuntimeError(
-                    "lr as a Tensor is not supported for capturable=False and foreach=True"
-                )
+                raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")

             state_steps.append(state["step"])
         return has_complex
@@ -231,6 +220,7 @@ class Adopt(Optimizer):
                 beta2=beta2,
                 lr=group["lr"],
                 weight_decay=group["weight_decay"],
+                clip_exp=group["clip_exp"],
                 decoupled=group["decoupled"],
                 eps=group["eps"],
                 maximize=group["maximize"],
@@ -258,6 +248,7 @@ def _single_tensor_adopt(
     beta2: float,
     lr: Union[float, Tensor],
     weight_decay: float,
+    clip_exp: Optional[float],
     decoupled: bool,
     eps: float,
     maximize: bool,
@@ -282,20 +273,12 @@ def _single_tensor_adopt(
         if capturable and not _is_compiling():
             from torch.optim.optimizer import _get_capturable_supported_devices
             capturable_supported_devices = _get_capturable_supported_devices()
-            assert (
-                param.device.type == step_t.device.type
-                and param.device.type in capturable_supported_devices
-            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
+            assert param.device.type == step_t.device.type and param.device.type in capturable_supported_devices,\
+                f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

         # update step
         step_t += 1

-        if weight_decay != 0:
-            if decoupled:
-                param.add_(param, alpha=-lr * weight_decay)
-            else:
-                grad = grad.add(param, alpha=weight_decay)
-
         if torch.is_complex(param):
             grad = torch.view_as_real(grad)
             if exp_avg is not None:
@@ -304,17 +287,25 @@ def _single_tensor_adopt(
                 exp_avg = torch.view_as_real(exp_avg)
             if exp_avg_sq is not None:
                 exp_avg_sq = torch.view_as_real(exp_avg_sq)
             param = torch.view_as_real(param)

+        if weight_decay != 0 and not decoupled:
+            grad = grad.add(param, alpha=weight_decay)
+
         step = step_t if capturable or differentiable else _get_value(step_t)
         if step == 1:
             exp_avg_sq.addcmul_(grad, grad.conj())
             continue

-        denom = torch.clamp(exp_avg_sq.sqrt(), eps)
-        if step == 2:
-            exp_avg.addcdiv_(grad, denom)
-        else:
-            exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
+        if weight_decay != 0 and decoupled:
+            param.add_(param, alpha=-lr * weight_decay)
+        denom = torch.clamp(exp_avg_sq.sqrt(), eps)
+        normed_grad = grad.div(denom)
+
+        if clip_exp is not None:
+            clip_val = (step - 1) ** clip_exp
+            normed_grad.clamp_(-clip_val, clip_val)
+
+        exp_avg.lerp_(normed_grad, 1 - beta1)
         param.add_(exp_avg, alpha=-lr)
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
@@ -334,6 +325,7 @@ def _multi_tensor_adopt(
     beta2: float,
     lr: Union[float, Tensor],
     weight_decay: float,
+    clip_exp: Optional[float],
     decoupled: bool,
     eps: float,
     maximize: bool,
@@ -355,8 +347,7 @@ def _multi_tensor_adopt(
             supports_xla=False
         )
         assert all(
-            p.device.type == step.device.type
-            and p.device.type in capturable_supported_devices
+            p.device.type == step.device.type and p.device.type in capturable_supported_devices
             for p, step in zip(params, state_steps)
         ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

@@ -382,9 +373,7 @@ def _multi_tensor_adopt(

         # Handle complex parameters
         if has_complex:
-            _view_as_real(
-                device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
-            )
+            _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)

         if maximize:
             device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]
@@ -394,44 +383,38 @@ def _multi_tensor_adopt(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
         if not _is_compiling() and device_state_steps[0].is_cpu:
-            torch._foreach_add_(
-                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
-            )
+            torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
         else:
             torch._foreach_add_(device_state_steps, 1)

-        if weight_decay != 0:
-            if decoupled:
-                torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+        if weight_decay != 0 and not decoupled:
+            # Re-use the intermediate memory (device_grads) already allocated for maximize
+            if maximize:
+                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
             else:
-                # Re-use the intermediate memory (device_grads) already allocated for maximize
-                if maximize:
-                    torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
-                else:
-                    device_grads = torch._foreach_add(  # type: ignore[assignment]
-                        device_grads, device_params, alpha=weight_decay
-                    )
+                device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)

         if device_state_steps[0] == 1:
             torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
             continue

+        if weight_decay != 0 and decoupled:
+            torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+
         exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
-        exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
+        torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
+        normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)

-        if device_state_steps[0] == 2:
-            torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
-        else:
-            torch._foreach_mul_(device_exp_avgs, beta1)
-            torch._foreach_addcdiv_(
-                device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
-            )
+        if clip_exp is not None:
+            clip_val = (device_state_steps[0] - 1) ** clip_exp
+            torch._foreach_maximum_(normed_grad, -clip_val)
+            torch._foreach_minimum_(normed_grad, clip_val)
+        torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
         torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
+
         torch._foreach_mul_(device_exp_avg_sqs, beta2)
-        torch._foreach_addcmul_(
-            device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
-        )
+        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2)


 #@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)  # FIXME internal context mgr, can't use
@@ -454,6 +437,7 @@ def adopt(
     beta2: float,
     lr: Union[float, Tensor],
     weight_decay: float,
+    clip_exp: Optional[float],
     decoupled: bool,
     eps: float,
     maximize: bool,
@@ -490,6 +474,7 @@ def adopt(
         beta2=beta2,
         lr=lr,
         weight_decay=weight_decay,
+        clip_exp=clip_exp,
         decoupled=decoupled,
         eps=eps,
         maximize=maximize,
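
Usage note (not part of the patch): a minimal sketch of how the new `clip_exp` argument is passed through, based only on the `Adopt` constructor signature shown in the diff above; the model and single training step below are illustrative placeholders, not code from the repository.

    import torch
    from timm.optim.adopt import Adopt

    model = torch.nn.Linear(10, 2)
    # Per the diff, clip_exp bounds the normalized gradient to +/- (step - 1) ** clip_exp
    # before the first-moment update; pass clip_exp=None to disable the clipping path.
    optimizer = Adopt(model.parameters(), lr=1e-3, clip_exp=0.333, decoupled=True)

    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()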