Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Remove adafactorbv numpy dep, hack fix for loading optimizer state w/ half prec momentum (need better one)
This commit is contained in:
parent 19090ea966
commit 0b5ae49251
@@ -1,6 +1,5 @@
 from typing import List, Optional, Tuple, Union
 
-import numpy as np
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
@@ -10,38 +9,36 @@ def _get_scalar_dtype():
     """Get the scalar dtype that the optimizer uses for state"""
     return torch.float64
 
 
 def _factored_dims(
         shape: Tuple[int, ...],
         factored: bool,
         min_dim_size_to_factor: int
 ) -> Optional[tuple[int, int]]:
     """Whether to use a factored second moment estimator.
 
     This function returns a tuple with the two largest axes to reduce over.
     If no two dimensions have size >= min_dim_size_to_factor, return None.
 
     Args:
         shape: an input shape
         factored: whether to use factored second-moment estimator for > 2d vars.
         min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.
 
     Returns:
         None or a tuple of ints
     """
     if not factored or len(shape) < 2:
         return None
-    sorted_dims = np.argsort(shape)
-    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
+    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
+    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
         return None
-    return int(sorted_dims[-2]), int(sorted_dims[-1])
+    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])
 
 
 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
 
     """
 
     def __init__(
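With np.argsort gone, the numpy import at the top of the file is no longer needed; the replacement sorts (size, index) pairs in plain Python and reads the axis indices back out. A minimal standalone sketch of that selection logic (not part of the commit; the helper name, default values, and example shapes are illustrative):

from typing import Optional, Tuple

def factored_dims_no_numpy(
        shape: Tuple[int, ...],
        factored: bool = True,
        min_dim_size_to_factor: int = 32,
) -> Optional[Tuple[int, int]]:
    # Same selection as the new _factored_dims: sort (size, index) pairs
    # ascending, then take the indices of the two largest sizes.
    if not factored or len(shape) < 2:
        return None
    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
        return None
    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])

# A conv-like weight factors over its two largest axes, the same indices
# np.argsort(shape)[-2:] would have produced.
print(factored_dims_no_numpy((512, 256, 3, 3)))  # (1, 0)
# If the second-largest dim is under the threshold, no factoring.
print(factored_dims_no_numpy((24, 16)))          # None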
@@ -95,6 +92,12 @@ class AdafactorBigVision(Optimizer):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
 
+                if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
+                    # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
+                    # the momentum to float32 (it's half precision in the state_dict), need to
+                    # look into this further. Better to override _process_value_according_to_param_policy?
+                    p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype'])
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
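The FIXME above points at the base Optimizer.load_state_dict, which (as the comment notes) casts floating-point state tensors to the parameter's dtype, so momentum saved in half precision comes back as float32; the new lines re-cast exp_avg to the configured momentum_dtype in __setstate__. A small sketch of the behaviour being worked around (not part of the commit; it uses plain torch.optim.SGD's momentum_buffer as a stand-in and bfloat16 as the example half-precision dtype):

import torch

param = torch.nn.Parameter(torch.randn(4, 4))
opt = torch.optim.SGD([param], lr=0.1, momentum=0.9)
param.sum().backward()
opt.step()

# Pretend a checkpoint stored the momentum buffer in half precision.
sd = opt.state_dict()
sd['state'][0]['momentum_buffer'] = sd['state'][0]['momentum_buffer'].to(torch.bfloat16)

# Reloading casts state tensors back to the param's dtype (float32 here),
# which is the upcast the commit's __setstate__ hack undoes for exp_avg.
opt2 = torch.optim.SGD([param], lr=0.1, momentum=0.9)
opt2.load_state_dict(sd)
print(opt2.state_dict()['state'][0]['momentum_buffer'].dtype)  # torch.float32 (expected)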
@@ -181,6 +184,7 @@ class AdafactorBigVision(Optimizer):
 
         return loss
 
+
 def _single_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
@@ -262,24 +266,25 @@ def _single_tensor_adafactor(
         # Update parameters
         param.add_(update, alpha=-1.0)
 
+
 def _multi_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
         exp_avg_sq_rs: List[Optional[Tensor]],
         exp_avg_sq_cs: List[Optional[Tensor]],
         exp_avg_sqs: List[Optional[Tensor]],
         exp_avgs: List[Optional[Tensor]],
         state_steps: List[Tensor],
         *,
         beta2_decay: float,
         beta2_cap: float,
         min_dim_size_to_factor: int,
         eps: float,
         lr: float,
         weight_decay: float,
         momentum: Optional[float],
         momentum_dtype: Union[str, torch.dtype],
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
 ):
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'