Remove adafactorbv numpy dep, hack fix for loading optimizer state w/ half prec momentum (need better one)

Ross Wightman 2024-11-04 14:54:41 -08:00
parent 91f0ea3338
commit 548fdb5d71


@@ -1,6 +1,5 @@
 from typing import List, Optional, Tuple, Union
 
-import numpy as np
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
@@ -10,38 +9,36 @@ def _get_scalar_dtype():
     """Get the scalar dtype that the optimizer uses for state"""
     return torch.float64
 
 
 def _factored_dims(
         shape: Tuple[int, ...],
         factored: bool,
         min_dim_size_to_factor: int
 ) -> Optional[tuple[int, int]]:
     """Whether to use a factored second moment estimator.
 
     This function returns a tuple with the two largest axes to reduce over.
     If no two dimensions have size >= min_dim_size_to_factor, return None.
 
     Args:
       shape: an input shape
       factored: whether to use factored second-moment estimator for > 2d vars.
       min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.
 
     Returns:
       None or a tuple of ints
     """
     if not factored or len(shape) < 2:
         return None
-    sorted_dims = np.argsort(shape)
-    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
+    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
+    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
         return None
-    return int(sorted_dims[-2]), int(sorted_dims[-1])
+    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])
 
 
 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
     """
 
     def __init__(
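Note (not part of the commit): a minimal standalone sketch of why the pure-Python sort is a drop-in replacement for the old np.argsort path when picking the two largest axes. The helper names below are hypothetical, and only non-tied cases are checked, since argsort and a (size, index) sort may order tied axes differently.

    import numpy as np

    def factored_dims_numpy(shape, min_dim_size_to_factor=32):
        # previous behaviour: argsort the axis sizes, take the two largest axes
        sorted_dims = np.argsort(shape)
        if shape[sorted_dims[-2]] < min_dim_size_to_factor:
            return None
        return int(sorted_dims[-2]), int(sorted_dims[-1])

    def factored_dims_python(shape, min_dim_size_to_factor=32):
        # new behaviour: sort (size, index) pairs, take the indices of the two largest
        sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
        if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
            return None
        return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])

    for shape in [(512, 64, 3, 3), (1024, 768), (4, 4, 4)]:
        assert factored_dims_numpy(shape) == factored_dims_python(shape)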
@@ -95,6 +92,12 @@ class AdafactorBigVision(Optimizer):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
 
+                if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
+                    # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
+                    # the momentum to float32 (it's half precision in the state_dict), need to
+                    # look into this further. Better to override _process_value_according_to_param_policy?
+                    p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype'])
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
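Note (not part of the commit): the FIXME above can be exercised with a state_dict()/load_state_dict() round trip, since load_state_dict casts floating-point state to the parameter's dtype and then invokes __setstate__, where the hack restores the momentum dtype. A minimal sketch; the constructor kwargs and import path below are assumptions:

    import torch
    from timm.optim.adafactor_bv import AdafactorBigVision  # module path assumed

    p = torch.nn.Parameter(torch.randn(64, 64))
    opt = AdafactorBigVision([p], lr=1e-3, momentum=0.9, momentum_dtype=torch.bfloat16)
    p.grad = torch.randn_like(p)
    opt.step()  # first step creates the 'exp_avg' momentum buffer in bfloat16

    opt2 = AdafactorBigVision([p], lr=1e-3, momentum=0.9, momentum_dtype=torch.bfloat16)
    opt2.load_state_dict(opt.state_dict())
    # load_state_dict upcasts float state to the param dtype (float32); the hack above
    # casts 'exp_avg' back to the configured momentum_dtype afterwards
    assert opt2.state[p]['exp_avg'].dtype == torch.bfloat16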
@@ -181,6 +184,7 @@ class AdafactorBigVision(Optimizer):
         return loss
 
 
+
 def _single_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
@@ -262,24 +266,25 @@ def _single_tensor_adafactor(
         # Update parameters
         param.add_(update, alpha=-1.0)
 
+
 def _multi_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
         exp_avg_sq_rs: List[Optional[Tensor]],
         exp_avg_sq_cs: List[Optional[Tensor]],
         exp_avg_sqs: List[Optional[Tensor]],
         exp_avgs: List[Optional[Tensor]],
         state_steps: List[Tensor],
         *,
         beta2_decay: float,
         beta2_cap: float,
         min_dim_size_to_factor: int,
         eps: float,
         lr: float,
         weight_decay: float,
         momentum: Optional[float],
         momentum_dtype: Union[str, torch.dtype],
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
 ):
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'