Remove adafactor_bv numpy dep, hack fix for loading optimizer state w/ half-precision momentum (needs a better fix)

Ross Wightman 2024-11-04 14:54:41 -08:00 committed by Ross Wightman
parent 19090ea966
commit 0b5ae49251


@@ -1,6 +1,5 @@
 from typing import List, Optional, Tuple, Union
-import numpy as np
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
@@ -10,6 +9,7 @@ def _get_scalar_dtype():
     """Get the scalar dtype that the optimizer uses for state"""
     return torch.float64
 def _factored_dims(
         shape: Tuple[int, ...],
         factored: bool,
@@ -30,18 +30,15 @@ def _factored_dims(
     """
     if not factored or len(shape) < 2:
         return None
-    sorted_dims = np.argsort(shape)
-    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
+    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
+    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
         return None
-    return int(sorted_dims[-2]), int(sorted_dims[-1])
+    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])
 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
     """
     def __init__(
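Note: the sorted() expression introduced above is a pure-Python stand-in for np.argsort. It orders (size, index) pairs ascending, so the last two entries carry the indices of the two largest dimensions, i.e. the axes chosen for factored second-moment estimates. A small illustration, not part of the commit; the variable names and example shape are made up:

    # Pure-Python argsort over a parameter shape, as used in _factored_dims
    shape = (8, 1024, 768)
    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
    # sorted_dims == [(8, 0), (768, 2), (1024, 1)]  -- ascending by dim size
    row_dim, col_dim = int(sorted_dims[-2][1]), int(sorted_dims[-1][1])
    assert (row_dim, col_dim) == (2, 1)  # the two largest axes get factored stats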
@@ -95,6 +92,12 @@ class AdafactorBigVision(Optimizer):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
+                if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
+                    # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
+                    # the momentum to float32 (it's half precision in the state_dict), need to
+                    # look into this further. Better to override _process_value_according_to_param_policy?
+                    p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype'])
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
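Note: the re-cast added above works around torch.optim.Optimizer.load_state_dict casting per-parameter state tensors to the parameter's dtype, which silently promotes the half-precision momentum buffer to float32 on load. A rough sketch of the round-trip this is meant to preserve; the import path and the momentum / momentum_dtype constructor kwargs are assumptions taken from the optimizer's defaults and may differ in this exact revision:

    import torch
    from timm.optim import AdafactorBigVision  # assumed export location

    p = torch.nn.Parameter(torch.zeros(4, 4))
    opt = AdafactorBigVision([p], lr=1e-3, momentum=0.9, momentum_dtype=torch.bfloat16)
    p.grad = torch.randn_like(p)
    opt.step()

    sd = opt.state_dict()
    # the momentum buffer is kept and saved in half precision
    assert sd['state'][0]['exp_avg'].dtype == torch.bfloat16

    opt2 = AdafactorBigVision([p], lr=1e-3, momentum=0.9, momentum_dtype=torch.bfloat16)
    opt2.load_state_dict(sd)
    # without the re-cast, load_state_dict would hand back a float32 exp_avg here
    assert opt2.state[p]['exp_avg'].dtype == torch.bfloat16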
@@ -181,6 +184,7 @@
         return loss
 def _single_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
@@ -262,6 +266,7 @@ def _single_tensor_adafactor(
         # Update parameters
         param.add_(update, alpha=-1.0)
 def _multi_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],