Update Adopt to include clipping for stability; apply weight decay separately so parameters don't decay when the update is not taken on the first step

cautious_optim
Ross Wightman 2024-11-26 10:42:01 -08:00 committed by Ross Wightman
parent 444c506ce3
commit e5aea357b1
1 changed file with 50 additions and 65 deletions
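The core change: the normalized gradient is clamped to the range ±(step - 1) ** clip_exp before it enters the first-moment EMA, and decoupled weight decay is now applied only on steps that actually take an update (the very first step only seeds the second-moment estimate). Below is a minimal single-parameter sketch of the new update, distilled from the diff that follows; names mirror the diff, and capturable/foreach/complex-tensor handling is omitted.

```python
import torch

def adopt_step(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
               betas=(0.9, 0.9999), eps=1e-6, clip_exp=0.333,
               weight_decay=0.0, decoupled=False):
    beta1, beta2 = betas
    if weight_decay != 0 and not decoupled:
        # Coupled decay folds the penalty into the gradient before normalization
        grad = grad.add(param, alpha=weight_decay)
    if step == 1:
        # First step only seeds the second-moment estimate; no parameter update,
        # and (after this commit) no decoupled weight decay either.
        exp_avg_sq.addcmul_(grad, grad.conj())
        return
    if weight_decay != 0 and decoupled:
        # Decoupled decay, applied only on steps that actually update the parameter
        param.add_(param, alpha=-lr * weight_decay)
    denom = torch.clamp(exp_avg_sq.sqrt(), eps)
    normed_grad = grad.div(denom)
    if clip_exp is not None:
        clip_val = (step - 1) ** clip_exp          # clip threshold grows with the step count
        normed_grad.clamp_(-clip_val, clip_val)
    exp_avg.lerp_(normed_grad, 1 - beta1)          # EMA of the (clipped) normalized gradient
    param.add_(exp_avg, alpha=-lr)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```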


@@ -13,7 +13,7 @@ Modified for reduced dependencies on PyTorch internals from original at: https:/
"""
from typing import cast, List, Optional, Tuple, Union
from typing import cast, Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor
@@ -64,6 +64,7 @@ class Adopt(Optimizer):
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6,
clip_exp: Optional[float] = 0.333,
weight_decay: float = 0.0,
decoupled: bool = False,
*,
@@ -95,6 +96,7 @@ class Adopt(Optimizer):
betas=betas,
eps=eps,
weight_decay=weight_decay,
clip_exp=clip_exp,
decoupled=decoupled,
maximize=maximize,
foreach=foreach,
@@ -111,6 +113,7 @@ class Adopt(Optimizer):
group.setdefault("foreach", None)
group.setdefault("capturable", False)
group.setdefault("differentiable", False)
group.setdefault("clip_exp", None)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
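Loading optimizer state saved before this change goes through `__setstate__`, so param groups that predate `clip_exp` pick up a default of `None` and clipping stays off for resumed runs. A hedged sketch of that behavior; the `timm.optim.adopt` import path is an assumption:

```python
import torch
from timm.optim.adopt import Adopt  # import path assumed; adjust to your timm version

model = torch.nn.Linear(4, 4)
opt = Adopt(model.parameters(), lr=1e-3, clip_exp=0.333)

# Simulate a state_dict saved before this commit: param groups had no "clip_exp" key.
old_sd = opt.state_dict()
for g in old_sd["param_groups"]:
    g.pop("clip_exp", None)

opt.load_state_dict(old_sd)             # load_state_dict routes through __setstate__
print(opt.param_groups[0]["clip_exp"])  # None -> clipping disabled for the resumed run
```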
@@ -141,9 +144,7 @@
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError(
"ADOPT does not support sparse gradients"
)
raise RuntimeError("ADOPT does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
@@ -153,36 +154,24 @@
# Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = (
torch.zeros(
(),
dtype=_get_scalar_dtype(),
device=p.grad.device,
)
torch.zeros((), dtype=_get_scalar_dtype(), device=p.grad.device)
if group["capturable"]
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p.grad, memory_format=torch.preserve_format
)
state["exp_avg"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p.grad, memory_format=torch.preserve_format
)
state["exp_avg_sq"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if group["differentiable"] and state["step"].requires_grad:
raise RuntimeError(
"`requires_grad` is not supported for `step` in differentiable mode"
)
raise RuntimeError("`requires_grad` is not supported for `step` in differentiable mode")
# Foreach without capturable does not support a tensor lr
if group["foreach"] and torch.is_tensor(group["lr"]) and not group["capturable"]:
raise RuntimeError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")
state_steps.append(state["step"])
return has_complex
@@ -231,6 +220,7 @@ class Adopt(Optimizer):
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
clip_exp=group["clip_exp"],
decoupled=group["decoupled"],
eps=group["eps"],
maximize=group["maximize"],
@@ -258,6 +248,7 @@ def _single_tensor_adopt(
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
clip_exp: Optional[float],
decoupled: bool,
eps: float,
maximize: bool,
@@ -282,20 +273,12 @@ def _single_tensor_adopt(
if capturable and not _is_compiling():
from torch.optim.optimizer import _get_capturable_supported_devices
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
assert param.device.type == step_t.device.type and param.device.type in capturable_supported_devices,\
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
# update step
step_t += 1
if weight_decay != 0:
if decoupled:
param.add_(param, alpha=-lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param):
grad = torch.view_as_real(grad)
if exp_avg is not None:
@@ -304,17 +287,25 @@ def _single_tensor_adopt(
exp_avg_sq = torch.view_as_real(exp_avg_sq)
param = torch.view_as_real(param)
if weight_decay != 0 and not decoupled:
grad = grad.add(param, alpha=weight_decay)
step = step_t if capturable or differentiable else _get_value(step_t)
if step == 1:
exp_avg_sq.addcmul_(grad, grad.conj())
continue
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
if step == 2:
exp_avg.addcdiv_(grad, denom)
else:
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
if weight_decay != 0 and decoupled:
param.add_(param, alpha=-lr * weight_decay)
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
normed_grad = grad.div(denom)
if clip_exp is not None:
clip_val = (step - 1) ** clip_exp
normed_grad.clamp_(-clip_val, clip_val)
exp_avg.lerp_(normed_grad, 1 - beta1)
param.add_(exp_avg, alpha=-lr)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
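With the default `clip_exp = 0.333`, the clip threshold starts tight and relaxes as the step count grows, so early noisy normalized gradients are bounded while later updates are essentially unclipped. A quick look at the schedule:

```python
# Growth of the clip threshold (step - 1) ** clip_exp with the default clip_exp = 0.333
for step in (2, 10, 1_000, 100_000):
    print(step, (step - 1) ** 0.333)
# step 2      -> 1.00  (normalized grad limited to [-1, 1] right after the first step)
# step 10     -> ~2.08
# step 1000   -> ~9.97
# step 100000 -> ~46.2
```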
@@ -334,6 +325,7 @@ def _multi_tensor_adopt(
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
clip_exp: Optional[float],
decoupled: bool,
eps: float,
maximize: bool,
@@ -355,8 +347,7 @@ def _multi_tensor_adopt(
supports_xla=False
)
assert all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
p.device.type == step.device.type and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
@@ -382,9 +373,7 @@ def _multi_tensor_adopt(
# Handle complex parameters
if has_complex:
_view_as_real(
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
)
_view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
@@ -394,44 +383,38 @@ def _multi_tensor_adopt(
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not _is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
else:
torch._foreach_add_(device_state_steps, 1)
if weight_decay != 0:
if decoupled:
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
if weight_decay != 0 and not decoupled:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
if device_state_steps[0] == 1:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
continue
if weight_decay != 0 and decoupled:
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
if device_state_steps[0] == 2:
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
else:
torch._foreach_mul_(device_exp_avgs, beta1)
torch._foreach_addcdiv_(
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
)
if clip_exp is not None:
clip_val = (device_state_steps[0] - 1) ** clip_exp
torch._foreach_maximum_(normed_grad, -clip_val)
torch._foreach_minimum_(normed_grad, clip_val)
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
)
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2)
#@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) # FIXME internal context mgr, can't use
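The foreach path expresses the same clamp as an elementwise maximum against `-clip_val` followed by a minimum against `clip_val`, applied across the whole list of normalized gradients at once. A toy illustration using the same internal `torch._foreach_*` ops (note these are private PyTorch APIs):

```python
import torch

# Clamp a list of normalized gradients to [-clip_val, clip_val] the way the
# foreach path does it: maximum against -clip_val, then minimum against clip_val.
normed_grads = [torch.tensor([-5.0, 0.5, 3.0]), torch.tensor([2.5, -0.1, -4.0])]
clip_val = 2.0
torch._foreach_maximum_(normed_grads, -clip_val)
torch._foreach_minimum_(normed_grads, clip_val)
print(normed_grads)  # [tensor([-2.0, 0.5, 2.0]), tensor([2.0, -0.1, -2.0])]
```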
@@ -454,6 +437,7 @@ def adopt(
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
clip_exp: Optional[float],
decoupled: bool,
eps: float,
maximize: bool,
@@ -490,6 +474,7 @@ def adopt(
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
clip_exp=clip_exp,
decoupled=decoupled,
eps=eps,
maximize=maximize,
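For reference, a hedged usage sketch of the updated constructor arguments; the `timm.optim.adopt` import path is an assumption and may differ by timm version:

```python
import torch
from timm.optim.adopt import Adopt  # import path assumed

model = torch.nn.Linear(10, 2)
optimizer = Adopt(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.9999),
    eps=1e-6,
    clip_exp=0.333,     # default; set to None to disable the clipping added in this commit
    weight_decay=1e-2,
    decoupled=True,     # decoupled decay is now skipped on the first (no-update) step
)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()        # first step only seeds exp_avg_sq; params start moving on step 2
optimizer.zero_grad()
```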