Prep Kron for merge, add detail to attributions note, README.
parent
67ef6f0a92
commit
b3a83b81d6
|
@ -12,6 +12,11 @@
|
|||
|
||||
## What's New
|
||||
|
||||
## Jan 27, 2025
|
||||
* Add Kron Optimizer (PSGD w/ Kronecker-factored preconditioner)
|
||||
* Code from https://github.com/evanatyourservice/kron_torch
|
||||
* See also https://sites.google.com/site/lixilinx/home/psgd
|
||||
|
||||
## Jan 19, 2025
|
||||
* Fix loading of LeViT safetensor weights, remove conversion code which should have been deactivated
|
||||
* Add 'SO150M' ViT weights trained with SBB recipes, decent results, but not optimal shape for ImageNet-12k/1k pretrain/ft
|
||||
|
@ -461,6 +466,7 @@ Included optimizers available via `timm.optim.create_optimizer_v2` factory metho
|
|||
* `adamp` and `sgdp` by [Naver Clova AI](https://github.com/clovaai) - https://arxiv.org/abs/2006.08217
|
||||
* `adan` an implementation of Adan adapted from https://github.com/sail-sg/Adan - https://arxiv.org/abs/2208.06677
|
||||
* `adopt` ADOPT adapted from https://github.com/iShohei220/adopt - https://arxiv.org/abs/2411.02853
|
||||
* `kron` PSGD w/ Kronecker-factored preconditioner from https://github.com/evanatyourservice/kron_torch - https://sites.google.com/site/lixilinx/home/psgd
|
||||
* `lamb` an implementation of Lamb and LambC (w/ trust-clipping) cleaned up and modified to support use with XLA - https://arxiv.org/abs/1904.00962
|
||||
* `laprop` optimizer from https://github.com/Z-T-WANG/LaProp-Optimizer - https://arxiv.org/abs/2002.04839
|
||||
* `lars` an implementation of LARS and LARC (w/ trust-clipping) - https://arxiv.org/abs/1708.03888
|
||||
|
|
|
@ -1,9 +1,22 @@
|
|||
""" PyTorch Implementation of the Kron PSGD optimizer
|
||||
""" PyTorch Implementation of the Kron (PSGD) optimizer
|
||||
|
||||
FIXME attribution
|
||||
* https://github.com/evanatyourservice/kron_torch (direct source)
|
||||
* https://github.com/lixilinx/psgd_torch (original)
|
||||
* https://github.com/ClashLuke/HeavyBall (added improvements)
|
||||
This is a PSGD optimizer using a Kronecker-factored preconditioner.
|
||||
|
||||
This impl was adapted from https://github.com/evanatyourservice/kron_torch
|
||||
by Evan Walters, licensed CC-BY-4.0.
|
||||
|
||||
Contributions to the above were also made by
|
||||
* Lucas Nestler, added to his https://github.com/ClashLuke/HeavyBall implementation.
|
||||
* Omead Pooladzandi https://github.com/opooladz
|
||||
|
||||
The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li
|
||||
|
||||
This `timm` impl
|
||||
* works with a wider variety of torch versions
|
||||
* fixes some checkpoint save/restore (resume) issues
|
||||
* adds decoupled weight-decay option
|
||||
* has some refactoring, cleanup of args, default/group items
|
||||
* adds a warning about not having opt_einsum (unusable without it)
|
||||
|
||||
"""
|
||||
import logging
|
||||
|
@ -30,6 +43,8 @@ try:
|
|||
except AttributeError:
|
||||
has_dynamo = False
|
||||
|
||||
from ._types import ParamsT
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -85,7 +100,7 @@ class Kron(torch.optim.Optimizer):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
params: ParamsT,
|
||||
lr: float = 0.001,
|
||||
momentum: float = 0.9,
|
||||
weight_decay: float = 0.0,
|
||||
|
@ -94,6 +109,8 @@ class Kron(torch.optim.Optimizer):
|
|||
min_ndim_triangular: int = 2,
|
||||
memory_save_mode: Optional[str] = None,
|
||||
momentum_into_precond_update: bool = True,
|
||||
precond_lr: float = 0.1,
|
||||
precond_init_scale: float = 1.0,
|
||||
mu_dtype: Optional[torch.dtype] = None,
|
||||
precond_dtype: Optional[torch.dtype] = None,
|
||||
decoupled_decay: bool = False,
|
||||
|
@ -119,8 +136,8 @@ class Kron(torch.optim.Optimizer):
|
|||
min_ndim_triangular=min_ndim_triangular,
|
||||
memory_save_mode=memory_save_mode,
|
||||
momentum_into_precond_update=momentum_into_precond_update,
|
||||
precond_lr=0.1, # precond lr hardcoded to 0.1
|
||||
precond_init_scale=1.0, # precond init scale hardcoded to 1.0
|
||||
precond_lr=precond_lr,
|
||||
precond_init_scale=precond_init_scale,
|
||||
mu_dtype=mu_dtype,
|
||||
precond_dtype=precond_dtype,
|
||||
decoupled_decay=decoupled_decay,
|
||||
|
|
Loading…
Reference in New Issue