Mirror of https://github.com/huggingface/pytorch-image-models.git
Update adafactor comments / attrib
commit e6d72ed1b7
parent d73e8e7531
timm/optim/adafactor.py
@@ -2,8 +2,9 @@
 Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
 
-Original header/copyright below.
+Modified by Ross Wightman to fix some issues with factorization dims for non nn.Linear layers
 
+Original header/copyright below.
 """
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
@@ -96,7 +97,7 @@ class Adafactor(torch.optim.Optimizer):
             # nD convs in torch are ND + 2 dim weights with leading in/out chs
             factored = 0, 1
         elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
-            # if the criteria above didn't match, check trailing dims
+            # if the criteria above didn't match, test trailing dims for eligibility
             factored = ndim - 2, ndim - 1
 
         return factored, use_first_moment
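For context, the comment tweak above documents the dim-selection logic that decides which two dimensions Adafactor factors its second-moment estimate over. A minimal standalone sketch of that logic, reconstructed from the hunk context (the helper name and the exact condition of the first branch are assumptions, not the verbatim timm source):

# Sketch reconstructed from the hunk above; _pick_factored_dims and the
# first branch's condition are assumptions, only the bodies and the elif
# condition appear in the diff.
def _pick_factored_dims(param_shape, min_size_to_factor=32):
    ndim = len(param_shape)
    factored = None
    if ndim >= 3 and param_shape[0] > min_size_to_factor and param_shape[1] > min_size_to_factor:
        # nD convs in torch are ND + 2 dim weights with leading in/out chs
        factored = 0, 1
    elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
        # if the criteria above didn't match, test trailing dims for eligibility
        factored = ndim - 2, ndim - 1
    return factored

# e.g. a linear weight (4096, 1024) -> (0, 1); a conv weight
# (512, 256, 3, 3) -> (0, 1); a small bias (64,) -> None (unfactored).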
timm/optim/adafactor_bv.py
@@ -1,3 +1,12 @@
+""" Adafactor (Big Vision variant) for PyTorch
+
+Adapted from the implementation in big vision: https://github.com/google-research/big_vision
+
+Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560
+
+Adaptation and PyTorch modifications by Ross Wightman
+
+"""
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -39,6 +48,8 @@ def _factored_dims(
 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
+
+    Adapted from https://github.com/google-research/big_vision by Ross Wightman
     """
 
     def __init__(
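A hedged usage sketch for the class documented above; AdafactorBigVision is re-exported from timm's optim package in recent versions, but the lr value shown is an assumption (Adafactor variants often use relative step sizing rather than an Adam-style learning rate):

# Usage sketch; the lr value is an assumption, not a recommendation.
import torch
from timm.optim import AdafactorBigVision

model = torch.nn.Linear(128, 64)
opt = AdafactorBigVision(model.parameters(), lr=1.0)

loss = model(torch.randn(8, 128)).sum()
loss.backward()
opt.step()
opt.zero_grad()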
@@ -292,4 +303,5 @@ def _multi_tensor_adafactor(
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
 ):
+    # FIXME TODO
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'
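Note on the added FIXME: the multi-tensor function is a stub, so only the single-tensor path is usable. A sketch of what that implies for callers (the kwarg name is inferred from the assert message, not confirmed here):

# The assert above fires only when the multi-tensor path is selected; per the
# assert message that path is gated by a foreach flag, so the assumed-safe
# call leaves it off:
opt = AdafactorBigVision(model.parameters(), foreach=False)  # kwarg name inferred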