mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update adafactor comments / attrib
This commit is contained in:
parent
d73e8e7531
commit
e6d72ed1b7
@ -2,8 +2,9 @@
|
||||
|
||||
Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
|
||||
|
||||
Original header/copyright below.
|
||||
Modified by Ross Wightman to fix some issues with factorization dims for non nn.Linear layers
|
||||
|
||||
Original header/copyright below.
|
||||
"""
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
@ -96,7 +97,7 @@ class Adafactor(torch.optim.Optimizer):
|
||||
# nD convs in torch are ND + 2 dim weights with leading in/out chs
|
||||
factored = 0, 1
|
||||
elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
|
||||
# if the criteria above didn't match, check trailing dims
|
||||
# if the criteria above didn't match, test trailing dims for eligibility
|
||||
factored = ndim - 2, ndim - 1
|
||||
|
||||
return factored, use_first_moment
|
||||
|
@ -1,3 +1,12 @@
|
||||
""" Adafactor (Big Vision variant) for PyTorch
|
||||
|
||||
Adapted from the implementation in big vision: https://github.com/google-research/big_vision
|
||||
|
||||
Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560
|
||||
|
||||
Adaptation and PyTorch modifications by Ross Wightman
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -39,6 +48,8 @@ def _factored_dims(
|
||||
class AdafactorBigVision(Optimizer):
|
||||
"""
|
||||
PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
|
||||
|
||||
Adapted from https://github.com/google-research/big_vision by Ross Wightman
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -292,4 +303,5 @@ def _multi_tensor_adafactor(
|
||||
clipping_threshold: Optional[float],
|
||||
unscaled_wd: bool,
|
||||
):
|
||||
# FIXME TODO
|
||||
assert False, 'multi-tensor fn (foreach=True) not implemented yet'
|
||||
|
Loading…
x
Reference in New Issue
Block a user