Update Adopt to include clipping for stability, separate wd so no param decay if update not taken on first step
parent
444c506ce3
commit
e5aea357b1
|
@ -13,7 +13,7 @@ Modified for reduced dependencies on PyTorch internals from original at: https:/
|
|||
|
||||
"""
|
||||
|
||||
from typing import cast, List, Optional, Tuple, Union
|
||||
from typing import cast, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -64,6 +64,7 @@ class Adopt(Optimizer):
|
|||
lr: Union[float, Tensor] = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.9999),
|
||||
eps: float = 1e-6,
|
||||
clip_exp: Optional[float] = 0.333,
|
||||
weight_decay: float = 0.0,
|
||||
decoupled: bool = False,
|
||||
*,
|
||||
|
@ -95,6 +96,7 @@ class Adopt(Optimizer):
|
|||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
clip_exp=clip_exp,
|
||||
decoupled=decoupled,
|
||||
maximize=maximize,
|
||||
foreach=foreach,
|
||||
|
@ -111,6 +113,7 @@ class Adopt(Optimizer):
|
|||
group.setdefault("foreach", None)
|
||||
group.setdefault("capturable", False)
|
||||
group.setdefault("differentiable", False)
|
||||
group.setdefault("clip_exp", None)
|
||||
for p in group["params"]:
|
||||
p_state = self.state.get(p, [])
|
||||
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
|
||||
|
@ -141,9 +144,7 @@ class Adopt(Optimizer):
|
|||
has_complex |= torch.is_complex(p)
|
||||
params_with_grad.append(p)
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"ADOPT does not support sparse gradients"
|
||||
)
|
||||
raise RuntimeError("ADOPT does not support sparse gradients")
|
||||
grads.append(p.grad)
|
||||
|
||||
state = self.state[p]
|
||||
|
@ -153,36 +154,24 @@ class Adopt(Optimizer):
|
|||
# Deliberately host `step` on CPU if both capturable and fused are off.
|
||||
# This is because kernel launches are costly on CUDA and XLA.
|
||||
state["step"] = (
|
||||
torch.zeros(
|
||||
(),
|
||||
dtype=_get_scalar_dtype(),
|
||||
device=p.grad.device,
|
||||
)
|
||||
torch.zeros((), dtype=_get_scalar_dtype(), device=p.grad.device)
|
||||
if group["capturable"]
|
||||
else torch.tensor(0.0, dtype=_get_scalar_dtype())
|
||||
)
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(
|
||||
p.grad, memory_format=torch.preserve_format
|
||||
)
|
||||
state["exp_avg"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p.grad, memory_format=torch.preserve_format
|
||||
)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
|
||||
|
||||
exp_avgs.append(state["exp_avg"])
|
||||
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||
|
||||
if group["differentiable"] and state["step"].requires_grad:
|
||||
raise RuntimeError(
|
||||
"`requires_grad` is not supported for `step` in differentiable mode"
|
||||
)
|
||||
raise RuntimeError("`requires_grad` is not supported for `step` in differentiable mode")
|
||||
|
||||
# Foreach without capturable does not support a tensor lr
|
||||
if group["foreach"] and torch.is_tensor(group["lr"]) and not group["capturable"]:
|
||||
raise RuntimeError(
|
||||
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")
|
||||
|
||||
state_steps.append(state["step"])
|
||||
return has_complex
|
||||
|
@ -231,6 +220,7 @@ class Adopt(Optimizer):
|
|||
beta2=beta2,
|
||||
lr=group["lr"],
|
||||
weight_decay=group["weight_decay"],
|
||||
clip_exp=group["clip_exp"],
|
||||
decoupled=group["decoupled"],
|
||||
eps=group["eps"],
|
||||
maximize=group["maximize"],
|
||||
|
@ -258,6 +248,7 @@ def _single_tensor_adopt(
|
|||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
weight_decay: float,
|
||||
clip_exp: Optional[float],
|
||||
decoupled: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
|
@ -282,20 +273,12 @@ def _single_tensor_adopt(
|
|||
if capturable and not _is_compiling():
|
||||
from torch.optim.optimizer import _get_capturable_supported_devices
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type == step_t.device.type
|
||||
and param.device.type in capturable_supported_devices
|
||||
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||
assert param.device.type == step_t.device.type and param.device.type in capturable_supported_devices,\
|
||||
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||
|
||||
# update step
|
||||
step_t += 1
|
||||
|
||||
if weight_decay != 0:
|
||||
if decoupled:
|
||||
param.add_(param, alpha=-lr * weight_decay)
|
||||
else:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
|
||||
if torch.is_complex(param):
|
||||
grad = torch.view_as_real(grad)
|
||||
if exp_avg is not None:
|
||||
|
@ -304,17 +287,25 @@ def _single_tensor_adopt(
|
|||
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
||||
param = torch.view_as_real(param)
|
||||
|
||||
if weight_decay != 0 and not decoupled:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
|
||||
step = step_t if capturable or differentiable else _get_value(step_t)
|
||||
if step == 1:
|
||||
exp_avg_sq.addcmul_(grad, grad.conj())
|
||||
continue
|
||||
|
||||
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
||||
if step == 2:
|
||||
exp_avg.addcdiv_(grad, denom)
|
||||
else:
|
||||
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
|
||||
if weight_decay != 0 and decoupled:
|
||||
param.add_(param, alpha=-lr * weight_decay)
|
||||
|
||||
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
||||
normed_grad = grad.div(denom)
|
||||
|
||||
if clip_exp is not None:
|
||||
clip_val = (step - 1) ** clip_exp
|
||||
normed_grad.clamp_(-clip_val, clip_val)
|
||||
|
||||
exp_avg.lerp_(normed_grad, 1 - beta1)
|
||||
param.add_(exp_avg, alpha=-lr)
|
||||
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
||||
|
@ -334,6 +325,7 @@ def _multi_tensor_adopt(
|
|||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
weight_decay: float,
|
||||
clip_exp: Optional[float],
|
||||
decoupled: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
|
@ -355,8 +347,7 @@ def _multi_tensor_adopt(
|
|||
supports_xla=False
|
||||
)
|
||||
assert all(
|
||||
p.device.type == step.device.type
|
||||
and p.device.type in capturable_supported_devices
|
||||
p.device.type == step.device.type and p.device.type in capturable_supported_devices
|
||||
for p, step in zip(params, state_steps)
|
||||
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||
|
||||
|
@ -382,9 +373,7 @@ def _multi_tensor_adopt(
|
|||
|
||||
# Handle complex parameters
|
||||
if has_complex:
|
||||
_view_as_real(
|
||||
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
|
||||
)
|
||||
_view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)
|
||||
|
||||
if maximize:
|
||||
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
||||
|
@ -394,44 +383,38 @@ def _multi_tensor_adopt(
|
|||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not _is_compiling() and device_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
|
||||
else:
|
||||
torch._foreach_add_(device_state_steps, 1)
|
||||
|
||||
if weight_decay != 0:
|
||||
if decoupled:
|
||||
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
|
||||
if weight_decay != 0 and not decoupled:
|
||||
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||
if maximize:
|
||||
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||
else:
|
||||
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||
if maximize:
|
||||
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||
else:
|
||||
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||
device_grads, device_params, alpha=weight_decay
|
||||
)
|
||||
device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
|
||||
|
||||
if device_state_steps[0] == 1:
|
||||
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
||||
continue
|
||||
|
||||
if weight_decay != 0 and decoupled:
|
||||
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
|
||||
|
||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
|
||||
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
|
||||
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
|
||||
|
||||
if device_state_steps[0] == 2:
|
||||
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
|
||||
else:
|
||||
torch._foreach_mul_(device_exp_avgs, beta1)
|
||||
torch._foreach_addcdiv_(
|
||||
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
|
||||
)
|
||||
if clip_exp is not None:
|
||||
clip_val = (device_state_steps[0] - 1) ** clip_exp
|
||||
torch._foreach_maximum_(normed_grad, -clip_val)
|
||||
torch._foreach_minimum_(normed_grad, clip_val)
|
||||
|
||||
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
|
||||
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
||||
|
||||
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||
torch._foreach_addcmul_(
|
||||
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
||||
)
|
||||
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2)
|
||||
|
||||
|
||||
#@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) # FIXME internal context mgr, can't use
|
||||
|
@ -454,6 +437,7 @@ def adopt(
|
|||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
weight_decay: float,
|
||||
clip_exp: Optional[float],
|
||||
decoupled: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
|
@ -490,6 +474,7 @@ def adopt(
|
|||
beta2=beta2,
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
clip_exp=clip_exp,
|
||||
decoupled=decoupled,
|
||||
eps=eps,
|
||||
maximize=maximize,
|
||||
|
|
Loading…
Reference in New Issue