Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Remove adafactorbv numpy dep, hack fix for loading optimizer state w/ half prec momentum (need better one)
parent 19090ea966
commit 0b5ae49251
@@ -1,6 +1,5 @@
 from typing import List, Optional, Tuple, Union
 
-import numpy as np
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
@@ -10,6 +9,7 @@ def _get_scalar_dtype():
     """Get the scalar dtype that the optimizer uses for state"""
     return torch.float64
 
+
 def _factored_dims(
         shape: Tuple[int, ...],
         factored: bool,
@@ -30,18 +30,15 @@ def _factored_dims(
     """
     if not factored or len(shape) < 2:
         return None
-    sorted_dims = np.argsort(shape)
-    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
+    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
+    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
         return None
-    return int(sorted_dims[-2]), int(sorted_dims[-1])
+    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])
 
 
 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
-
-
-
     """
 
     def __init__(
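Note (not part of the commit): a quick standalone sketch of why the sorted()-based replacement above matches np.argsort for this use. Sorting (size, index) pairs orders dimensions by size, and taking [1] from a pair recovers the dimension index, so sorted_dims[-2][1] and sorted_dims[-1][1] select the same two largest dimensions that np.argsort(shape)[-2] and [-1] did (tie-breaking for equal sizes may differ, which does not affect which sizes get factored). The shape below is an arbitrary example.

    # Illustrative check of the numpy-free replacement for np.argsort(shape)
    shape = (8, 64, 3, 128)

    # (size, index) pairs, sorted ascending by size
    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))  # [(3, 2), (8, 0), (64, 1), (128, 3)]

    second_largest_dim = sorted_dims[-2][1]  # index 1 (size 64)
    largest_dim = sorted_dims[-1][1]         # index 3 (size 128)
    assert (second_largest_dim, largest_dim) == (1, 3)  # same dims np.argsort would pick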
@@ -95,6 +92,12 @@ class AdafactorBigVision(Optimizer):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
 
+                if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
+                    # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
+                    # the momentum to float32 (it's half precision in the state_dict), need to
+                    # look into this further. Better to override _process_value_according_to_param_policy?
+                    p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype'])
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
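Note (not part of the commit): a minimal sketch of the save/load round trip the FIXME above is working around, assuming AdafactorBigVision is importable from timm.optim and accepts momentum and momentum_dtype keyword arguments (as self.defaults['momentum_dtype'] implies). torch.optim.Optimizer.load_state_dict casts floating-point state tensors to the parameter's dtype, so a bfloat16 exp_avg saved from a float32 model comes back as float32; the __setstate__ override then casts it back to momentum_dtype.

    # Sketch only; import path and constructor kwargs are assumptions noted above.
    import torch
    from timm.optim import AdafactorBigVision

    model = torch.nn.Linear(128, 128)  # float32 params
    opt = AdafactorBigVision(model.parameters(), momentum=0.9, momentum_dtype=torch.bfloat16)

    model(torch.randn(2, 128)).sum().backward()
    opt.step()  # creates 'exp_avg' momentum state in bfloat16

    state = opt.state_dict()    # 'exp_avg' saved as bfloat16
    opt.load_state_dict(state)  # load would upcast to float32; __setstate__ casts it back
    for p_state in opt.state.values():
        assert p_state['exp_avg'].dtype == torch.bfloat16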
@@ -181,6 +184,7 @@ class AdafactorBigVision(Optimizer):
 
         return loss
 
+
 def _single_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
@@ -262,6 +266,7 @@ def _single_tensor_adafactor(
         # Update parameters
         param.add_(update, alpha=-1.0)
 
+
 def _multi_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],