Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00
Remove adafactorbv numpy dep, hack fix for loading optimizer state w/ half prec momentum (need better one)
This commit is contained in:
parent 19090ea966
commit 0b5ae49251
@@ -1,6 +1,5 @@
 from typing import List, Optional, Tuple, Union
 
-import numpy as np
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
@@ -10,6 +9,7 @@ def _get_scalar_dtype():
     """Get the scalar dtype that the optimizer uses for state"""
     return torch.float64
 
 
 def _factored_dims(
         shape: Tuple[int, ...],
         factored: bool,
@@ -30,18 +30,15 @@ def _factored_dims(
     """
     if not factored or len(shape) < 2:
         return None
-    sorted_dims = np.argsort(shape)
-    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
+    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
+    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
         return None
-    return int(sorted_dims[-2]), int(sorted_dims[-1])
+    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])
 
 
 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
     """
 
     def __init__(
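The _factored_dims change above drops the np.argsort call in favor of plain sorted() over (size, index) pairs. A minimal sketch, using an illustrative (128, 16, 1024) weight shape, of how the new helper picks the indices of the two largest dimensions; this standalone copy and the print call are for illustration, not part of the commit:

    from typing import Optional, Tuple

    def _factored_dims(
            shape: Tuple[int, ...],
            factored: bool,
            min_dim_size_to_factor: int,
    ) -> Optional[Tuple[int, int]]:
        # Sort (size, index) pairs ascending; the last two entries are the two
        # largest dims, and [1] recovers each dim's index in the original shape.
        if not factored or len(shape) < 2:
            return None
        sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
        if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
            return None
        return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])

    # Dims 0 and 2 are the two largest, and dim 0 (size 128) clears the
    # min_dim_size_to_factor threshold, so the second moment gets factored.
    print(_factored_dims((128, 16, 1024), factored=True, min_dim_size_to_factor=32))  # -> (0, 2)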
@@ -95,6 +92,12 @@ class AdafactorBigVision(Optimizer):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
 
+                if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
+                    # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
+                    # the momentum to float32 (it's half precision in the state_dict), need to
+                    # look into this further. Better to override _process_value_according_to_param_policy?
+                    p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype'])
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
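The exp_avg block added above is the "hack fix" from the commit message: the momentum buffer is stored in half precision, but loading a checkpoint appears to hand it back upcast to float32, so it is cast back to the configured momentum dtype when state is restored. A minimal standalone sketch of that cast, assuming bfloat16 momentum for illustration (the tensor and dtype below are examples, not values from the commit):

    import torch

    # Momentum as it sits in a saved optimizer state_dict: half precision.
    saved_exp_avg = torch.zeros(4, dtype=torch.bfloat16)

    # After load_state_dict the buffer can come back as float32 (the behaviour
    # the FIXME comment describes), simulated here with an explicit upcast...
    loaded_exp_avg = saved_exp_avg.float()

    # ...so the workaround casts it back down to the configured momentum dtype.
    momentum_dtype = torch.bfloat16  # stands in for self.defaults['momentum_dtype']
    restored_exp_avg = loaded_exp_avg.to(dtype=momentum_dtype)
    print(restored_exp_avg.dtype)  # torch.bfloat16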
@@ -181,6 +184,7 @@ class AdafactorBigVision(Optimizer):
         return loss
 
 
 def _single_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
@@ -262,6 +266,7 @@ def _single_tensor_adafactor(
         # Update parameters
         param.add_(update, alpha=-1.0)
 
 
 def _multi_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],