diff --git a/timm/optim/adafactor_bv.py b/timm/optim/adafactor_bv.py
index 58b18032..ea8a4afa 100644
--- a/timm/optim/adafactor_bv.py
+++ b/timm/optim/adafactor_bv.py
@@ -1,6 +1,5 @@
 from typing import List, Optional, Tuple, Union

-import numpy as np
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
@@ -10,38 +9,36 @@ def _get_scalar_dtype():
     """Get the scalar dtype that the optimizer uses for state"""
     return torch.float64

+
 def _factored_dims(
-    shape: Tuple[int, ...],
-    factored: bool,
-    min_dim_size_to_factor: int
+        shape: Tuple[int, ...],
+        factored: bool,
+        min_dim_size_to_factor: int
 ) -> Optional[tuple[int, int]]:
-  """Whether to use a factored second moment estimator.
+    """Whether to use a factored second moment estimator.

-  This function returns a tuple with the two largest axes to reduce over.
-  If no two dimensions have size >= min_dim_size_to_factor, return None.
+    This function returns a tuple with the two largest axes to reduce over.
+    If no two dimensions have size >= min_dim_size_to_factor, return None.

-  Args:
-    shape: an input shape
-    factored: whether to use factored second-moment estimator for > 2d vars.
-    min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.
+    Args:
+        shape: an input shape
+        factored: whether to use factored second-moment estimator for > 2d vars.
+        min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.

-  Returns:
-    None or a tuple of ints
-  """
-  if not factored or len(shape) < 2:
-    return None
-  sorted_dims = np.argsort(shape)
-  if shape[sorted_dims[-2]] < min_dim_size_to_factor:
-    return None
-  return int(sorted_dims[-2]), int(sorted_dims[-1])
+    Returns:
+        None or a tuple of ints
+    """
+    if not factored or len(shape) < 2:
+        return None
+    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
+    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
+        return None
+    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])


 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
-
-
-
     """

     def __init__(
@@ -95,6 +92,12 @@ class AdafactorBigVision(Optimizer):
             if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                 p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())

+            if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
+                # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
+                # the momentum to float32 (it's half precision in the state_dict), need to
+                # look into this further. Better to override _process_value_according_to_param_policy?
+                p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype'])
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
@@ -181,6 +184,7 @@ class AdafactorBigVision(Optimizer):

         return loss

+
 def _single_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
@@ -262,24 +266,25 @@ def _single_tensor_adafactor(
         # Update parameters
         param.add_(update, alpha=-1.0)

+
 def _multi_tensor_adafactor(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avg_sq_rs: List[Optional[Tensor]],
-    exp_avg_sq_cs: List[Optional[Tensor]],
-    exp_avg_sqs: List[Optional[Tensor]],
-    exp_avgs: List[Optional[Tensor]],
-    state_steps: List[Tensor],
-    *,
-    beta2_decay: float,
-    beta2_cap: float,
-    min_dim_size_to_factor: int,
-    eps: float,
-    lr: float,
-    weight_decay: float,
-    momentum: Optional[float],
-    momentum_dtype: Union[str, torch.dtype],
-    clipping_threshold: Optional[float],
-    unscaled_wd: bool,
+        params: List[Tensor],
+        grads: List[Tensor],
+        exp_avg_sq_rs: List[Optional[Tensor]],
+        exp_avg_sq_cs: List[Optional[Tensor]],
+        exp_avg_sqs: List[Optional[Tensor]],
+        exp_avgs: List[Optional[Tensor]],
+        state_steps: List[Tensor],
+        *,
+        beta2_decay: float,
+        beta2_cap: float,
+        min_dim_size_to_factor: int,
+        eps: float,
+        lr: float,
+        weight_decay: float,
+        momentum: Optional[float],
+        momentum_dtype: Union[str, torch.dtype],
+        clipping_threshold: Optional[float],
+        unscaled_wd: bool,
 ):
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'
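
Note: the main functional change above is dropping the numpy dependency by replacing `np.argsort(shape)` with a sort over `(size, index)` pairs. Below is a standalone sketch (not part of the patch) of the new logic with a few illustrative checks; the helper name and test shapes are hypothetical.

```python
from typing import Optional, Tuple


def factored_dims_sketch(
        shape: Tuple[int, ...],
        factored: bool,
        min_dim_size_to_factor: int,
) -> Optional[Tuple[int, int]]:
    """Mirror of the patched _factored_dims: pick the two largest dims to factor over."""
    if not factored or len(shape) < 2:
        return None
    # Sort (size, index) pairs ascending; the last two entries are the two
    # largest dims, with ties broken deterministically by dimension index.
    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
        return None
    return sorted_dims[-2][1], sorted_dims[-1][1]


# 4d conv weight: dims 0 (256) and 3 (512) are the two largest -> factor over those.
assert factored_dims_sketch((256, 128, 3, 512), True, 32) == (0, 3)
# Second-largest dim (10) is below the threshold -> no factoring, keep a full 2nd moment.
assert factored_dims_sketch((768, 10), True, 32) is None
# Factoring disabled or fewer than 2 dims -> None.
assert factored_dims_sketch((768,), True, 32) is None
```

One behavioral nuance versus `np.argsort`: for tied dimension sizes, the tuple sort breaks ties by dimension index, so the result is deterministic; either way the two largest axes are selected.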
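The `exp_avg` cast in the second hunk compensates for the base `Optimizer.load_state_dict` upcasting half-precision momentum to float32, as the FIXME notes. A hypothetical round-trip sketch of the intended behavior follows; it assumes `AdafactorBigVision` is exported from `timm.optim` and that the constructor accepts the `momentum` / `momentum_dtype` arguments implied by the functional signature above.

```python
import torch
from timm.optim import AdafactorBigVision  # assumed export path

model = torch.nn.Linear(512, 256)
opt = AdafactorBigVision(model.parameters(), lr=1e-3, momentum=0.9, momentum_dtype=torch.bfloat16)

model(torch.randn(8, 512)).sum().backward()
opt.step()  # creates exp_avg momentum state in bfloat16

state = opt.state_dict()  # exp_avg is serialized in half precision

opt2 = AdafactorBigVision(model.parameters(), lr=1e-3, momentum=0.9, momentum_dtype=torch.bfloat16)
opt2.load_state_dict(state)  # base Optimizer upcasts exp_avg to float32...

for p_state in opt2.state.values():
    # ...and the patched state handling casts it back to momentum_dtype.
    assert p_state['exp_avg'].dtype == torch.bfloat16
```

If the FIXME's suggestion of overriding `_process_value_according_to_param_policy` is adopted later, this round-trip check should still hold.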