mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove adafactorbv numpy dep, hack fix for loading optimizer state w/ half prec momentum (need better one)
parent 91f0ea3338
commit 548fdb5d71
@@ -1,6 +1,5 @@
 from typing import List, Optional, Tuple, Union

-import numpy as np
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
@@ -10,38 +9,36 @@ def _get_scalar_dtype():
     """Get the scalar dtype that the optimizer uses for state"""
     return torch.float64


 def _factored_dims(
         shape: Tuple[int, ...],
         factored: bool,
         min_dim_size_to_factor: int
 ) -> Optional[tuple[int, int]]:
     """Whether to use a factored second moment estimator.

     This function returns a tuple with the two largest axes to reduce over.
     If no two dimensions have size >= min_dim_size_to_factor, return None.

     Args:
         shape: an input shape
         factored: whether to use factored second-moment estimator for > 2d vars.
         min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.

     Returns:
         None or a tuple of ints
     """
     if not factored or len(shape) < 2:
         return None
-    sorted_dims = np.argsort(shape)
-    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
+    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
+    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
         return None
-    return int(sorted_dims[-2]), int(sorted_dims[-1])
+    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])


 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.

     """

     def __init__(
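The numpy-free replacement above is easy to get subtly wrong around tie-breaking: sorted() over (size, index) tuples is deterministic (lower index wins a tie), while np.argsort's default sort makes no stability guarantee. A quick standalone check of the two versions on a few representative shapes (my sketch, not part of the commit; shapes chosen so the top two dims have no ties):

import numpy as np

def factored_dims_old(shape, factored=True, min_dim_size_to_factor=32):
    # pre-commit behavior, kept here only for comparison
    if not factored or len(shape) < 2:
        return None
    sorted_dims = np.argsort(shape)
    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
        return None
    return int(sorted_dims[-2]), int(sorted_dims[-1])

def factored_dims_new(shape, factored=True, min_dim_size_to_factor=32):
    # post-commit behavior: sort (size, index) pairs, take the two largest
    if not factored or len(shape) < 2:
        return None
    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
        return None
    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])

for shape in [(768, 3072), (3, 3, 256, 512), (1000,), (16, 24)]:
    assert factored_dims_old(shape) == factored_dims_new(shape), shape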
@@ -95,6 +92,12 @@ class AdafactorBigVision(Optimizer):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())

+                if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
+                    # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
+                    # the momentum to float32 (it's half precision in the state_dict), need to
+                    # look into this further. Better to override _process_value_according_to_param_policy?
+                    p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype'])
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
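The FIXME refers to torch.optim.Optimizer.load_state_dict's cast policy: floating-point state tensors (other than 'step') are cast to the dtype of their parameter on load, so momentum saved in half precision comes back as float32. A minimal standalone repro of the upcast (my sketch, not part of the commit; SGD's momentum buffer stands in for AdafactorBigVision's exp_avg, since the cast policy lives in the base Optimizer class):

import torch

p = torch.nn.Parameter(torch.zeros(4))  # float32 parameter
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)
opt.state[p]['momentum_buffer'] = torch.zeros(4, dtype=torch.bfloat16)

sd = opt.state_dict()  # buffer serialized as bfloat16
opt2 = torch.optim.SGD([p], lr=0.1, momentum=0.9)
opt2.load_state_dict(sd)

buf = opt2.state[p]['momentum_buffer']
print(buf.dtype)  # torch.float32: load_state_dict cast it to the param's dtype

# The __setstate__ hack above undoes this after the fact:
buf = buf.to(dtype=torch.bfloat16)

The cleaner fix the comment hints at would be overriding _process_value_according_to_param_policy so the momentum keeps momentum_dtype through the load itself, rather than casting back afterwards.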
@@ -181,6 +184,7 @@ class AdafactorBigVision(Optimizer):

         return loss

+
 def _single_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
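The _single_tensor_adafactor / _multi_tensor_adafactor pair follows torch.optim's usual kernel convention, where step() picks a kernel based on a foreach flag. A hedged sketch of that dispatch shape (my illustration, with a hypothetical _adafactor_dispatch name; the actual selection code in this file may differ):

def _adafactor_dispatch(params, grads, *, foreach=False, **kwargs):
    # Mirrors the torch.optim convention: one reference per-tensor loop and
    # one optional fused multi-tensor path, selected by a foreach flag.
    if foreach:
        func = _multi_tensor_adafactor  # currently just asserts (see the next hunk)
    else:
        func = _single_tensor_adafactor
    func(params, grads, **kwargs)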
@@ -262,24 +266,25 @@ def _single_tensor_adafactor(
         # Update parameters
         param.add_(update, alpha=-1.0)

+
 def _multi_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
         exp_avg_sq_rs: List[Optional[Tensor]],
         exp_avg_sq_cs: List[Optional[Tensor]],
         exp_avg_sqs: List[Optional[Tensor]],
         exp_avgs: List[Optional[Tensor]],
         state_steps: List[Tensor],
         *,
         beta2_decay: float,
         beta2_cap: float,
         min_dim_size_to_factor: int,
         eps: float,
         lr: float,
         weight_decay: float,
         momentum: Optional[float],
         momentum_dtype: Union[str, torch.dtype],
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
 ):
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'
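For context on the beta2_decay / beta2_cap pair in the signature above: Adafactor (Shazeer & Stern, 2018) grows the second-moment decay rate with the step count, and big_vision caps it. A sketch under the assumption this file uses the standard schedule (hypothetical _beta2_schedule helper; the exact expression and defaults here may differ):

def _beta2_schedule(step: float, beta2_decay: float = 0.8, beta2_cap: float = 0.999) -> float:
    # Adafactor-style decay: beta2_t = 1 - t**(-decay), bounded above by the cap.
    # Defaults shown are assumptions, matching the common big_vision choices.
    return min(beta2_cap, 1.0 - step ** (-beta2_decay))

for t in (1.0, 10.0, 100.0, 10_000.0):
    print(int(t), round(_beta2_schedule(t), 4))
# 1 0.0, 10 0.8415, 100 0.9749, 10000 0.999 (capped)

Early steps average the second moment aggressively, while late steps approach the cap, which is what makes a separate constant beta2 hyperparameter unnecessary.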