Remove adafactorbv numpy dep, hack fix for loading optimizer state w/ half prec momentum (need better one)

This commit is contained in:
Ross Wightman 2024-11-04 14:54:41 -08:00 committed by Ross Wightman
parent 19090ea966
commit 0b5ae49251

View File

@ -1,6 +1,5 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from torch import Tensor
from torch.optim import Optimizer
@ -10,6 +9,7 @@ def _get_scalar_dtype():
"""Get the scalar dtype that the optimizer uses for state"""
return torch.float64
def _factored_dims(
shape: Tuple[int, ...],
factored: bool,
@ -30,18 +30,15 @@ def _factored_dims(
"""
if not factored or len(shape) < 2:
return None
sorted_dims = np.argsort(shape)
if shape[sorted_dims[-2]] < min_dim_size_to_factor:
sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
return None
return int(sorted_dims[-2]), int(sorted_dims[-1])
return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])
class AdafactorBigVision(Optimizer):
"""
PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
"""
def __init__(
@ -95,6 +92,12 @@ class AdafactorBigVision(Optimizer):
if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
# FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
# the momentum to float32 (it's half precision in the state_dict), need to
# look into this further. Better to override _process_value_according_to_param_policy?
p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype'])
@torch.no_grad()
def step(self, closure=None):
loss = None
@ -181,6 +184,7 @@ class AdafactorBigVision(Optimizer):
return loss
def _single_tensor_adafactor(
params: List[Tensor],
grads: List[Tensor],
@ -262,6 +266,7 @@ def _single_tensor_adafactor(
# Update parameters
param.add_(update, alpha=-1.0)
def _multi_tensor_adafactor(
params: List[Tensor],
grads: List[Tensor],