Update Adopt to include clipping for stability, separate wd so no param decay if update not taken on first step

cautious_optim
Ross Wightman 2024-11-26 10:42:01 -08:00 committed by Ross Wightman
parent 444c506ce3
commit e5aea357b1
1 changed files with 50 additions and 65 deletions

View File

@ -13,7 +13,7 @@ Modified for reduced dependencies on PyTorch internals from original at: https:/
""" """
from typing import cast, List, Optional, Tuple, Union from typing import cast, Callable, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
@ -64,6 +64,7 @@ class Adopt(Optimizer):
lr: Union[float, Tensor] = 1e-3, lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999), betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6, eps: float = 1e-6,
clip_exp: Optional[float] = 0.333,
weight_decay: float = 0.0, weight_decay: float = 0.0,
decoupled: bool = False, decoupled: bool = False,
*, *,
@ -95,6 +96,7 @@ class Adopt(Optimizer):
betas=betas, betas=betas,
eps=eps, eps=eps,
weight_decay=weight_decay, weight_decay=weight_decay,
clip_exp=clip_exp,
decoupled=decoupled, decoupled=decoupled,
maximize=maximize, maximize=maximize,
foreach=foreach, foreach=foreach,
@ -111,6 +113,7 @@ class Adopt(Optimizer):
group.setdefault("foreach", None) group.setdefault("foreach", None)
group.setdefault("capturable", False) group.setdefault("capturable", False)
group.setdefault("differentiable", False) group.setdefault("differentiable", False)
group.setdefault("clip_exp", None)
for p in group["params"]: for p in group["params"]:
p_state = self.state.get(p, []) p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@ -141,9 +144,7 @@ class Adopt(Optimizer):
has_complex |= torch.is_complex(p) has_complex |= torch.is_complex(p)
params_with_grad.append(p) params_with_grad.append(p)
if p.grad.is_sparse: if p.grad.is_sparse:
raise RuntimeError( raise RuntimeError("ADOPT does not support sparse gradients")
"ADOPT does not support sparse gradients"
)
grads.append(p.grad) grads.append(p.grad)
state = self.state[p] state = self.state[p]
@ -153,36 +154,24 @@ class Adopt(Optimizer):
# Deliberately host `step` on CPU if both capturable and fused are off. # Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA. # This is because kernel launches are costly on CUDA and XLA.
state["step"] = ( state["step"] = (
torch.zeros( torch.zeros((), dtype=_get_scalar_dtype(), device=p.grad.device)
(),
dtype=_get_scalar_dtype(),
device=p.grad.device,
)
if group["capturable"] if group["capturable"]
else torch.tensor(0.0, dtype=_get_scalar_dtype()) else torch.tensor(0.0, dtype=_get_scalar_dtype())
) )
# Exponential moving average of gradient values # Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like( state["exp_avg"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
p.grad, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like( state["exp_avg_sq"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
p.grad, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"]) exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"]) exp_avg_sqs.append(state["exp_avg_sq"])
if group["differentiable"] and state["step"].requires_grad: if group["differentiable"] and state["step"].requires_grad:
raise RuntimeError( raise RuntimeError("`requires_grad` is not supported for `step` in differentiable mode")
"`requires_grad` is not supported for `step` in differentiable mode"
)
# Foreach without capturable does not support a tensor lr # Foreach without capturable does not support a tensor lr
if group["foreach"] and torch.is_tensor(group["lr"]) and not group["capturable"]: if group["foreach"] and torch.is_tensor(group["lr"]) and not group["capturable"]:
raise RuntimeError( raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
state_steps.append(state["step"]) state_steps.append(state["step"])
return has_complex return has_complex
@ -231,6 +220,7 @@ class Adopt(Optimizer):
beta2=beta2, beta2=beta2,
lr=group["lr"], lr=group["lr"],
weight_decay=group["weight_decay"], weight_decay=group["weight_decay"],
clip_exp=group["clip_exp"],
decoupled=group["decoupled"], decoupled=group["decoupled"],
eps=group["eps"], eps=group["eps"],
maximize=group["maximize"], maximize=group["maximize"],
@ -258,6 +248,7 @@ def _single_tensor_adopt(
beta2: float, beta2: float,
lr: Union[float, Tensor], lr: Union[float, Tensor],
weight_decay: float, weight_decay: float,
clip_exp: Optional[float],
decoupled: bool, decoupled: bool,
eps: float, eps: float,
maximize: bool, maximize: bool,
@ -282,20 +273,12 @@ def _single_tensor_adopt(
if capturable and not _is_compiling(): if capturable and not _is_compiling():
from torch.optim.optimizer import _get_capturable_supported_devices from torch.optim.optimizer import _get_capturable_supported_devices
capturable_supported_devices = _get_capturable_supported_devices() capturable_supported_devices = _get_capturable_supported_devices()
assert ( assert param.device.type == step_t.device.type and param.device.type in capturable_supported_devices,\
param.device.type == step_t.device.type f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
# update step # update step
step_t += 1 step_t += 1
if weight_decay != 0:
if decoupled:
param.add_(param, alpha=-lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param): if torch.is_complex(param):
grad = torch.view_as_real(grad) grad = torch.view_as_real(grad)
if exp_avg is not None: if exp_avg is not None:
@ -304,17 +287,25 @@ def _single_tensor_adopt(
exp_avg_sq = torch.view_as_real(exp_avg_sq) exp_avg_sq = torch.view_as_real(exp_avg_sq)
param = torch.view_as_real(param) param = torch.view_as_real(param)
if weight_decay != 0 and not decoupled:
grad = grad.add(param, alpha=weight_decay)
step = step_t if capturable or differentiable else _get_value(step_t) step = step_t if capturable or differentiable else _get_value(step_t)
if step == 1: if step == 1:
exp_avg_sq.addcmul_(grad, grad.conj()) exp_avg_sq.addcmul_(grad, grad.conj())
continue continue
denom = torch.clamp(exp_avg_sq.sqrt(), eps) if weight_decay != 0 and decoupled:
if step == 2: param.add_(param, alpha=-lr * weight_decay)
exp_avg.addcdiv_(grad, denom)
else:
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
normed_grad = grad.div(denom)
if clip_exp is not None:
clip_val = (step - 1) ** clip_exp
normed_grad.clamp_(-clip_val, clip_val)
exp_avg.lerp_(normed_grad, 1 - beta1)
param.add_(exp_avg, alpha=-lr) param.add_(exp_avg, alpha=-lr)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
@ -334,6 +325,7 @@ def _multi_tensor_adopt(
beta2: float, beta2: float,
lr: Union[float, Tensor], lr: Union[float, Tensor],
weight_decay: float, weight_decay: float,
clip_exp: Optional[float],
decoupled: bool, decoupled: bool,
eps: float, eps: float,
maximize: bool, maximize: bool,
@ -355,8 +347,7 @@ def _multi_tensor_adopt(
supports_xla=False supports_xla=False
) )
assert all( assert all(
p.device.type == step.device.type p.device.type == step.device.type and p.device.type in capturable_supported_devices
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps) for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
@ -382,9 +373,7 @@ def _multi_tensor_adopt(
# Handle complex parameters # Handle complex parameters
if has_complex: if has_complex:
_view_as_real( _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
)
if maximize: if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
@ -394,44 +383,38 @@ def _multi_tensor_adopt(
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload. # wrapped it once now. The alpha is required to assure we go to the right overload.
if not _is_compiling() and device_state_steps[0].is_cpu: if not _is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_( torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else: else:
torch._foreach_add_(device_state_steps, 1) torch._foreach_add_(device_state_steps, 1)
if weight_decay != 0: if weight_decay != 0 and not decoupled:
if decoupled:
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
else:
# Re-use the intermediate memory (device_grads) already allocated for maximize # Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize: if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay) torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else: else:
device_grads = torch._foreach_add( # type: ignore[assignment] device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
device_grads, device_params, alpha=weight_decay
)
if device_state_steps[0] == 1: if device_state_steps[0] == 1:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads) torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
continue continue
if weight_decay != 0 and decoupled:
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps) torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
if device_state_steps[0] == 2: if clip_exp is not None:
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt) clip_val = (device_state_steps[0] - 1) ** clip_exp
else: torch._foreach_maximum_(normed_grad, -clip_val)
torch._foreach_mul_(device_exp_avgs, beta1) torch._foreach_minimum_(normed_grad, clip_val)
torch._foreach_addcdiv_(
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
)
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr) torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2) torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_( torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2)
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
)
#@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) # FIXME internal context mgr, can't use #@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) # FIXME internal context mgr, can't use
@ -454,6 +437,7 @@ def adopt(
beta2: float, beta2: float,
lr: Union[float, Tensor], lr: Union[float, Tensor],
weight_decay: float, weight_decay: float,
clip_exp: Optional[float],
decoupled: bool, decoupled: bool,
eps: float, eps: float,
maximize: bool, maximize: bool,
@ -490,6 +474,7 @@ def adopt(
beta2=beta2, beta2=beta2,
lr=lr, lr=lr,
weight_decay=weight_decay, weight_decay=weight_decay,
clip_exp=clip_exp,
decoupled=decoupled, decoupled=decoupled,
eps=eps, eps=eps,
maximize=maximize, maximize=maximize,