Update Adopt to include clipping for stability; separate weight decay handling so parameters are not decayed when no update is taken on the first step
parent
444c506ce3
commit
e5aea357b1
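Two behavioral changes ride along with the code cleanup: the normalized gradient is clamped to +/- (step - 1) ** clip_exp before it enters the first-moment average, and weight decay is now applied after the first-step early exit, so parameters are not decayed on a step that takes no update. A rough feel for the default bound clip_exp = 0.333 (illustrative numbers only, computed from the formula in the diff below):

# How the clip bound on the normalized gradient grows with the step counter.
for step in (2, 10, 100, 1000):
    clip_val = (step - 1) ** 0.333
    print(f"step {step:4d}: clip to +/- {clip_val:.2f}")
# step    2: clip to +/- 1.00
# step   10: clip to +/- 2.08
# step  100: clip to +/- 4.62
# step 1000: clip to +/- 9.97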
@@ -13,7 +13,7 @@ Modified for reduced dependencies on PyTorch internals from original at: https:/
 
 """
 
-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Callable, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -64,6 +64,7 @@ class Adopt(Optimizer):
             lr: Union[float, Tensor] = 1e-3,
             betas: Tuple[float, float] = (0.9, 0.9999),
             eps: float = 1e-6,
+            clip_exp: Optional[float] = 0.333,
             weight_decay: float = 0.0,
             decoupled: bool = False,
             *,
@@ -95,6 +96,7 @@ class Adopt(Optimizer):
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            clip_exp=clip_exp,
             decoupled=decoupled,
             maximize=maximize,
             foreach=foreach,
@@ -111,6 +113,7 @@ class Adopt(Optimizer):
             group.setdefault("foreach", None)
             group.setdefault("capturable", False)
             group.setdefault("differentiable", False)
+            group.setdefault("clip_exp", None)
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@@ -141,9 +144,7 @@ class Adopt(Optimizer):
             has_complex |= torch.is_complex(p)
             params_with_grad.append(p)
             if p.grad.is_sparse:
-                raise RuntimeError(
-                    "ADOPT does not support sparse gradients"
-                )
+                raise RuntimeError("ADOPT does not support sparse gradients")
             grads.append(p.grad)

             state = self.state[p]
@@ -153,36 +154,24 @@ class Adopt(Optimizer):
                 # Deliberately host `step` on CPU if both capturable and fused are off.
                 # This is because kernel launches are costly on CUDA and XLA.
                 state["step"] = (
-                    torch.zeros(
-                        (),
-                        dtype=_get_scalar_dtype(),
-                        device=p.grad.device,
-                    )
+                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.grad.device)
                     if group["capturable"]
                     else torch.tensor(0.0, dtype=_get_scalar_dtype())
                 )
                 # Exponential moving average of gradient values
-                state["exp_avg"] = torch.zeros_like(
-                    p.grad, memory_format=torch.preserve_format
-                )
+                state["exp_avg"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
                 # Exponential moving average of squared gradient values
-                state["exp_avg_sq"] = torch.zeros_like(
-                    p.grad, memory_format=torch.preserve_format
-                )
+                state["exp_avg_sq"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)

             exp_avgs.append(state["exp_avg"])
             exp_avg_sqs.append(state["exp_avg_sq"])

             if group["differentiable"] and state["step"].requires_grad:
-                raise RuntimeError(
-                    "`requires_grad` is not supported for `step` in differentiable mode"
-                )
+                raise RuntimeError("`requires_grad` is not supported for `step` in differentiable mode")

             # Foreach without capturable does not support a tensor lr
             if group["foreach"] and torch.is_tensor(group["lr"]) and not group["capturable"]:
-                raise RuntimeError(
-                    "lr as a Tensor is not supported for capturable=False and foreach=True"
-                )
+                raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")

             state_steps.append(state["step"])
         return has_complex
@@ -231,6 +220,7 @@ class Adopt(Optimizer):
                 beta2=beta2,
                 lr=group["lr"],
                 weight_decay=group["weight_decay"],
+                clip_exp=group["clip_exp"],
                 decoupled=group["decoupled"],
                 eps=group["eps"],
                 maximize=group["maximize"],
@@ -258,6 +248,7 @@ def _single_tensor_adopt(
         beta2: float,
         lr: Union[float, Tensor],
         weight_decay: float,
+        clip_exp: Optional[float],
         decoupled: bool,
         eps: float,
         maximize: bool,
@@ -282,20 +273,12 @@ def _single_tensor_adopt(
         if capturable and not _is_compiling():
             from torch.optim.optimizer import _get_capturable_supported_devices
             capturable_supported_devices = _get_capturable_supported_devices()
-            assert (
-                param.device.type == step_t.device.type
-                and param.device.type in capturable_supported_devices
-            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
+            assert param.device.type == step_t.device.type and param.device.type in capturable_supported_devices,\
+                f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

         # update step
         step_t += 1

-        if weight_decay != 0:
-            if decoupled:
-                param.add_(param, alpha=-lr * weight_decay)
-            else:
-                grad = grad.add(param, alpha=weight_decay)
-
         if torch.is_complex(param):
             grad = torch.view_as_real(grad)
             if exp_avg is not None:
@@ -304,17 +287,25 @@ def _single_tensor_adopt(
                 exp_avg_sq = torch.view_as_real(exp_avg_sq)
             param = torch.view_as_real(param)

+        if weight_decay != 0 and not decoupled:
+            grad = grad.add(param, alpha=weight_decay)
+
         step = step_t if capturable or differentiable else _get_value(step_t)
         if step == 1:
             exp_avg_sq.addcmul_(grad, grad.conj())
             continue

-        denom = torch.clamp(exp_avg_sq.sqrt(), eps)
-        if step == 2:
-            exp_avg.addcdiv_(grad, denom)
-        else:
-            exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
+        if weight_decay != 0 and decoupled:
+            param.add_(param, alpha=-lr * weight_decay)

+        denom = torch.clamp(exp_avg_sq.sqrt(), eps)
+        normed_grad = grad.div(denom)
+
+        if clip_exp is not None:
+            clip_val = (step - 1) ** clip_exp
+            normed_grad.clamp_(-clip_val, clip_val)
+
+        exp_avg.lerp_(normed_grad, 1 - beta1)
         param.add_(exp_avg, alpha=-lr)

         exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
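In the single-tensor path above, the first-moment EMA is now taken over an explicitly materialized normed_grad (so the clip can be applied) via lerp_, the step == 2 special case is gone, and the decoupled decay sits below the step == 1 early continue, so a parameter whose update is skipped on the first step is no longer shrunk by weight decay. A condensed, self-contained restatement of the new per-parameter update (an illustrative sketch only: real-valued tensors, no maximize/capturable/foreach handling, hypothetical function name):

import torch

def adopt_step_sketch(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3, beta1=0.9,
                      beta2=0.9999, eps=1e-6, clip_exp=0.333, weight_decay=0.0, decoupled=True):
    # `step` is the 1-based step count after incrementing.
    if weight_decay != 0 and not decoupled:
        grad = grad.add(param, alpha=weight_decay)      # classic L2: fold decay into the gradient
    if step == 1:
        exp_avg_sq.addcmul_(grad, grad.conj())          # only seed the second moment, no param update
        return
    if weight_decay != 0 and decoupled:
        param.add_(param, alpha=-lr * weight_decay)     # decoupled decay, first applied on step 2
    denom = torch.clamp(exp_avg_sq.sqrt(), eps)
    normed_grad = grad / denom
    if clip_exp is not None:
        clip_val = (step - 1) ** clip_exp               # clip bound grows with step
        normed_grad.clamp_(-clip_val, clip_val)
    exp_avg.lerp_(normed_grad, 1 - beta1)               # EMA of the clipped, normalized gradient
    param.add_(exp_avg, alpha=-lr)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)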
@@ -334,6 +325,7 @@ def _multi_tensor_adopt(
         beta2: float,
         lr: Union[float, Tensor],
         weight_decay: float,
+        clip_exp: Optional[float],
         decoupled: bool,
         eps: float,
         maximize: bool,
@@ -355,8 +347,7 @@ def _multi_tensor_adopt(
             supports_xla=False
         )
         assert all(
-            p.device.type == step.device.type
-            and p.device.type in capturable_supported_devices
+            p.device.type == step.device.type and p.device.type in capturable_supported_devices
             for p, step in zip(params, state_steps)
         ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

@@ -382,9 +373,7 @@ def _multi_tensor_adopt(

         # Handle complex parameters
         if has_complex:
-            _view_as_real(
-                device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
-            )
+            _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)

         if maximize:
             device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]
|
||||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||||
if not _is_compiling() and device_state_steps[0].is_cpu:
|
if not _is_compiling() and device_state_steps[0].is_cpu:
|
||||||
torch._foreach_add_(
|
torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
|
||||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
torch._foreach_add_(device_state_steps, 1)
|
torch._foreach_add_(device_state_steps, 1)
|
||||||
|
|
||||||
if weight_decay != 0:
|
if weight_decay != 0 and not decoupled:
|
||||||
if decoupled:
|
|
||||||
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
|
|
||||||
else:
|
|
||||||
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||||
if maximize:
|
if maximize:
|
||||||
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||||
else:
|
else:
|
||||||
device_grads = torch._foreach_add( # type: ignore[assignment]
|
device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
|
||||||
device_grads, device_params, alpha=weight_decay
|
|
||||||
)
|
|
||||||
|
|
||||||
if device_state_steps[0] == 1:
|
if device_state_steps[0] == 1:
|
||||||
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if weight_decay != 0 and decoupled:
|
||||||
|
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
|
||||||
|
|
||||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||||
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
|
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
|
||||||
|
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
|
||||||
|
|
||||||
if device_state_steps[0] == 2:
|
if clip_exp is not None:
|
||||||
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
|
clip_val = (device_state_steps[0] - 1) ** clip_exp
|
||||||
else:
|
torch._foreach_maximum_(normed_grad, -clip_val)
|
||||||
torch._foreach_mul_(device_exp_avgs, beta1)
|
torch._foreach_minimum_(normed_grad, clip_val)
|
||||||
torch._foreach_addcdiv_(
|
|
||||||
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
|
|
||||||
)
|
|
||||||
|
|
||||||
|
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
|
||||||
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
||||||
|
|
||||||
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||||
torch._foreach_addcmul_(
|
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2)
|
||||||
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
#@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) # FIXME internal context mgr, can't use
|
#@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) # FIXME internal context mgr, can't use
|
||||||
|
@@ -454,6 +437,7 @@ def adopt(
         beta2: float,
         lr: Union[float, Tensor],
         weight_decay: float,
+        clip_exp: Optional[float],
         decoupled: bool,
         eps: float,
         maximize: bool,
|
||||||
beta2=beta2,
|
beta2=beta2,
|
||||||
lr=lr,
|
lr=lr,
|
||||||
weight_decay=weight_decay,
|
weight_decay=weight_decay,
|
||||||
|
clip_exp=clip_exp,
|
||||||
decoupled=decoupled,
|
decoupled=decoupled,
|
||||||
eps=eps,
|
eps=eps,
|
||||||
maximize=maximize,
|
maximize=maximize,
|
||||||
|
|
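For reference, a minimal usage sketch of the optimizer after this change. The import path and training loop are assumptions for illustration, not part of the patch; the argument names and defaults (clip_exp=0.333, decoupled=False) come from the diff above.

import torch
from timm.optim import Adopt  # assumed import path for the class modified above

model = torch.nn.Linear(10, 2)
optimizer = Adopt(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.9999),
    clip_exp=0.333,     # pass None to disable the new normalized-gradient clipping
    weight_decay=1e-2,
    decoupled=True,     # decay applied directly to params, and only once updates begin (step >= 2)
)

for _ in range(3):
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()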