Remove adafactorbv numpy dep, hack fix for loading optimizer state w/ half prec momentum (need better one)

Ross Wightman 2024-11-04 14:54:41 -08:00 committed by Ross Wightman
parent 19090ea966
commit 0b5ae49251


@@ -1,6 +1,5 @@
 from typing import List, Optional, Tuple, Union

-import numpy as np
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
@@ -10,38 +9,36 @@ def _get_scalar_dtype():
     """Get the scalar dtype that the optimizer uses for state"""
     return torch.float64


 def _factored_dims(
         shape: Tuple[int, ...],
         factored: bool,
         min_dim_size_to_factor: int
 ) -> Optional[tuple[int, int]]:
     """Whether to use a factored second moment estimator.

     This function returns a tuple with the two largest axes to reduce over.
     If no two dimensions have size >= min_dim_size_to_factor, return None.

     Args:
         shape: an input shape
         factored: whether to use factored second-moment estimator for > 2d vars.
         min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.

     Returns:
         None or a tuple of ints
     """
     if not factored or len(shape) < 2:
         return None
-    sorted_dims = np.argsort(shape)
-    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
+    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
+    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
         return None
-    return int(sorted_dims[-2]), int(sorted_dims[-1])
+    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])


 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
     """

     def __init__(
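The numpy-free replacement reproduces np.argsort by sorting (size, index) pairs: Python's tuple sort breaks ties on the lower index, so the two largest axes come back in the same order as before. A quick sketch of the behavior; the module path is assumed (the helper is private) and the shapes are illustrative:

from timm.optim.adafactor_bv import _factored_dims  # module path assumed

# 4d conv weight: the two largest dims are 256 (axis 0) and 128 (axis 1),
# so the factored second-moment accumulators reduce over axes (1, 0)
assert _factored_dims((256, 128, 3, 3), factored=True, min_dim_size_to_factor=32) == (1, 0)

# equal sizes tie-break on index order, the same result np.argsort gave here
assert _factored_dims((1024, 1024), factored=True, min_dim_size_to_factor=32) == (0, 1)

# no two dims reach the threshold -> no factoring
assert _factored_dims((16, 3, 7, 7), factored=True, min_dim_size_to_factor=32) is None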
@@ -95,6 +92,12 @@ class AdafactorBigVision(Optimizer):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())

+                if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
+                    # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
+                    # the momentum to float32 (it's half precision in the state_dict), need to
+                    # look into this further. Better to override _process_value_according_to_param_policy?
+                    p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype'])
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
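For context on the FIXME: the base torch.optim.Optimizer.load_state_dict casts floating-point state tensors to the dtype of the matching parameter (via _process_value_according_to_param_policy), so an exp_avg saved in bfloat16 comes back as float32 for float32 params; the cast added above restores momentum_dtype after loading. A minimal round-trip sketch; the import path, model, and hyperparameters are illustrative, not from the commit:

import torch
from timm.optim import AdafactorBigVision  # import path assumed

model = torch.nn.Linear(512, 512)  # float32 params
opt = AdafactorBigVision(model.parameters(), lr=1e-3, momentum=0.9, momentum_dtype=torch.bfloat16)
model(torch.randn(8, 512)).sum().backward()
opt.step()  # lazily creates exp_avg state in bfloat16

opt2 = AdafactorBigVision(model.parameters(), lr=1e-3, momentum=0.9, momentum_dtype=torch.bfloat16)
opt2.load_state_dict(opt.state_dict())

# without the hack, exp_avg would now be float32 (the param dtype);
# the cast brings it back to momentum_dtype
p = next(iter(opt2.state))
assert opt2.state[p]['exp_avg'].dtype == torch.bfloat16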
@@ -181,6 +184,7 @@
     return loss


+
 def _single_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
@@ -262,24 +266,25 @@ def _single_tensor_adafactor
         # Update parameters
         param.add_(update, alpha=-1.0)


+
 def _multi_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
         exp_avg_sq_rs: List[Optional[Tensor]],
         exp_avg_sq_cs: List[Optional[Tensor]],
         exp_avg_sqs: List[Optional[Tensor]],
         exp_avgs: List[Optional[Tensor]],
         state_steps: List[Tensor],
         *,
         beta2_decay: float,
         beta2_cap: float,
         min_dim_size_to_factor: int,
         eps: float,
         lr: float,
         weight_decay: float,
         momentum: Optional[float],
         momentum_dtype: Union[str, torch.dtype],
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
 ):
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'
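One practical consequence of the stub: the optimizer has to stay on the single-tensor path for now. Judging by the assert message, a foreach=True flag would dispatch step() here; the kwarg is inferred from that message rather than shown in this diff:

opt = AdafactorBigVision(model.parameters(), lr=1e-3)  # default single-tensor path
# AdafactorBigVision(model.parameters(), lr=1e-3, foreach=True) would raise the
# AssertionError above until the multi-tensor fn is implemented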