diff --git a/timm/optim/adafactor.py b/timm/optim/adafactor.py
index 37871af1..4cefad18 100644
--- a/timm/optim/adafactor.py
+++ b/timm/optim/adafactor.py
@@ -2,8 +2,9 @@
 
 Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
 
-Original header/copyright below.
+Modified by Ross Wightman to fix some issues with factorization dims for non nn.Linear layers
 
+Original header/copyright below.
 """
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
@@ -96,7 +97,7 @@ class Adafactor(torch.optim.Optimizer):
             # nD convs in torch are ND + 2 dim weights with leading in/out chs
             factored = 0, 1
         elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
-            # if the criteria above didn't match, check trailing dims
+            # if the criteria above didn't match, test trailing dims for eligibility
             factored = ndim - 2, ndim - 1
 
         return factored, use_first_moment
diff --git a/timm/optim/adafactor_bv.py b/timm/optim/adafactor_bv.py
index 465f8f0e..3bb6e959 100644
--- a/timm/optim/adafactor_bv.py
+++ b/timm/optim/adafactor_bv.py
@@ -1,3 +1,12 @@
+""" Adafactor (Big Vision variant) for PyTorch
+
+Adapted from the implementation in big vision: https://github.com/google-research/big_vision
+
+Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560
+
+Adaptation and PyTorch modifications by Ross Wightman
+"""
+
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -39,6 +48,8 @@ def _factored_dims(
 
 class AdafactorBigVision(Optimizer):
     """ PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
+
+    Adapted from https://github.com/google-research/big_vision by Ross Wightman
     """
 
     def __init__(
@@ -292,4 +303,5 @@ def _multi_tensor_adafactor(
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
 ):
+    # FIXME TODO
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'
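
For context on the factorization-dims heuristic this patch touches, a minimal standalone sketch follows. The helper name _pick_factored_dims and the conv-branch condition (ndim >= 3 with the leading two dims checked against the size threshold) are assumptions, not taken verbatim from the file; only the branch bodies and the elif condition are visible in the hunk above.

from typing import Optional, Tuple


def _pick_factored_dims(
        param_shape: Tuple[int, ...],
        min_size_to_factor: int = 32,
) -> Optional[Tuple[int, int]]:
    # Return the pair of dims to factor the second-moment estimate over,
    # or None if the parameter should fall back to an unfactored estimate.
    ndim = len(param_shape)
    factored = None
    if ndim >= 3 and param_shape[0] > min_size_to_factor and param_shape[1] > min_size_to_factor:
        # nD convs in torch are ND + 2 dim weights with leading in/out chs
        # (assumed condition: factor over the leading in/out channel dims)
        factored = 0, 1
    elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
        # if the criteria above didn't match, test trailing dims for eligibility
        factored = ndim - 2, ndim - 1
    return factored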