Update adafactor comments / attrib

Ross Wightman 2024-11-12 09:30:26 -08:00
parent d73e8e7531
commit e6d72ed1b7
2 changed files with 15 additions and 2 deletions

timm/optim/adafactor.py

@@ -2,8 +2,9 @@
 Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
-Original header/copyright below.
+Modified by Ross Wightman to fix some issues with factorization dims for non nn.Linear layers
+Original header/copyright below.
 """
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
@@ -96,7 +97,7 @@ class Adafactor(torch.optim.Optimizer):
             # nD convs in torch are ND + 2 dim weights with leading in/out chs
             factored = 0, 1
         elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
-            # if the criteria above didn't match, check trailing dims
+            # if the criteria above didn't match, test trailing dims for eligibility
            factored = ndim - 2, ndim - 1

         return factored, use_first_moment
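
For context, the comment tweaked above sits in the shape heuristic that decides which two dimensions of a parameter get factored second-moment estimates. Below is a minimal, standalone sketch of the trailing-dims branch visible in this hunk; the function name, the default threshold, and the None fallback are illustrative assumptions rather than timm's exact helper, and the real code also special-cases nD conv weights by factoring the leading in/out channel dims (0, 1).

import torch

def choose_factored_dims(param_shape, min_size_to_factor=32):
    # Hypothetical sketch, not timm's exact code: factor the trailing two dims
    # only when both are large enough to be worth factoring.
    ndim = len(param_shape)
    if ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
        return ndim - 2, ndim - 1
    # small or 1D params (biases, norm scales) keep a full, unfactored second moment
    return None

print(choose_factored_dims(torch.empty(768, 3072).shape))  # -> (0, 1)
print(choose_factored_dims(torch.empty(768).shape))        # -> None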

timm/optim/adafactor_bv.py

@@ -1,3 +1,12 @@
+""" Adafactor (Big Vision variant) for PyTorch
+
+Adapted from the implementation in big vision: https://github.com/google-research/big_vision
+
+Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560
+
+Adaptation and PyTorch modifications by Ross Wightman
+
+"""
 from typing import List, Optional, Tuple, Union

 import torch
@@ -39,6 +48,8 @@ def _factored_dims(
 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
+
+    Adapted from https://github.com/google-research/big_vision by Ross Wightman
     """

     def __init__(
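
As a quick orientation for readers of this diff, the class above plugs into the standard torch.optim.Optimizer interface. The following is a minimal usage sketch only; the import path and the lr keyword are assumptions based on common timm/PyTorch conventions, so check the actual constructor signature and defaults in adafactor_bv.py before relying on them.

import torch
import torch.nn as nn
from timm.optim import AdafactorBigVision  # import path assumed; adjust to your timm version

model = nn.Linear(768, 1000)
# keyword arguments here are illustrative assumptions, not the documented defaults
opt = AdafactorBigVision(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 768)).sum()
loss.backward()
opt.step()
opt.zero_grad()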
@@ -292,4 +303,5 @@ def _multi_tensor_adafactor(
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
 ):
+    # FIXME TODO
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'
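
The stub above reserves a future foreach=True code path, where per-parameter Python loops are replaced by multi-tensor torch._foreach_* ops that operate on whole lists of tensors per call. The sketch below only illustrates that general pattern with a trivial SGD-style update; it is not Adafactor's update rule and says nothing about how _multi_tensor_adafactor will eventually be implemented.

import torch

def multi_tensor_sgd_sketch(params, grads, lr=0.1):
    # Illustrative multi-tensor ("foreach") pattern only -- NOT Adafactor.
    # A single torch._foreach_add_ call updates every tensor in the list,
    # avoiding a Python-level loop over parameters.
    torch._foreach_add_(params, grads, alpha=-lr)

params = [torch.randn(3, 3) for _ in range(4)]
grads = [torch.randn(3, 3) for _ in range(4)]
multi_tensor_sgd_sketch(params, grads, lr=0.1)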