Prep Kron for merge, add detail to attributions note, README.

Branch: update_test_workflow
Ross Wightman 2025-01-27 16:00:58 -08:00 committed by Ross Wightman
parent 67ef6f0a92
commit b3a83b81d6
2 changed files with 31 additions and 8 deletions

README.md

@@ -12,6 +12,11 @@
 ## What's New
 
+## Jan 27, 2025
+* Add Kron Optimizer (PSGD w/ Kronecker-factored preconditioner)
+* Code from https://github.com/evanatyourservice/kron_torch
+* See also https://sites.google.com/site/lixilinx/home/psgd
+
 ## Jan 19, 2025
 * Fix loading of LeViT safetensor weights, remove conversion code which should have been deactivated
 * Add 'SO150M' ViT weights trained with SBB recipes, decent results, but not optimal shape for ImageNet-12k/1k pretrain/ft
@@ -461,6 +466,7 @@ Included optimizers available via `timm.optim.create_optimizer_v2` factory method
 * `adamp` and `sgdp` by [Naver ClovAI](https://github.com/clovaai) - https://arxiv.org/abs/2006.08217
 * `adan` an implementation of Adan adapted from https://github.com/sail-sg/Adan - https://arxiv.org/abs/2208.06677
 * `adopt` ADOPT adapted from https://github.com/iShohei220/adopt - https://arxiv.org/abs/2411.02853
+* `kron` PSGD w/ Kronecker-factored preconditioner from https://github.com/evanatyourservice/kron_torch - https://sites.google.com/site/lixilinx/home/psgd
 * `lamb` an implementation of Lamb and LambC (w/ trust-clipping) cleaned up and modified to support use with XLA - https://arxiv.org/abs/1904.00962
 * `laprop` optimizer from https://github.com/Z-T-WANG/LaProp-Optimizer - https://arxiv.org/abs/2002.04839
 * `lars` an implementation of LARS and LARC (w/ trust-clipping) - https://arxiv.org/abs/1708.03888
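Since `kron` is registered with the factory named in the list above, it can be created by name. A minimal sketch, assuming a timm install that includes this commit; the model choice and hyper-parameter values are illustrative only:

```python
import timm
from timm.optim import create_optimizer_v2

# Any timm model works here; resnet18 is just a small stand-in.
model = timm.create_model('resnet18')

# opt='kron' selects the new PSGD Kron optimizer via the factory method.
optimizer = create_optimizer_v2(
    model,
    opt='kron',
    lr=1e-3,
    weight_decay=0.05,
)
```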

timm/optim/kron.py

@@ -1,9 +1,22 @@
-""" PyTorch Implementation of the Kron PSGD optimizer
+""" PyTorch Implementation of the Kron (PSGD) optimizer
 
-FIXME attribution
-* https://github.com/evanatyourservice/kron_torch (direct source)
-* https://github.com/lixilinx/psgd_torch (original)
-* https://github.com/ClashLuke/HeavyBall (added improvements)
+This is a PSGD optimizer using a Kronecker-factored preconditioner.
+
+This impl was adapted from https://github.com/evanatyourservice/kron_torch
+by Evan Walters, licensed CC-BY-4.0.
+
+Contributions to the above were also made by
+* Lucas Nestler, added to his https://github.com/ClashLuke/HeavyBall implementation.
+* Omead Pooladzandi https://github.com/opooladz
+
+The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li.
+
+This `timm` impl
+* works with a wider variety of torch versions
+* fixes some checkpoint save/restore (resume) issues
+* adds a decoupled weight-decay option
+* has some refactoring and cleanup of args and default/group items
+* warns when opt_einsum is not installed (unusable without it)
 """
 import logging
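The new docstring calls out a decoupled weight-decay option. A minimal sketch of constructing the class directly, assuming it lives at `timm.optim.kron.Kron` as in this diff; values are illustrative:

```python
import torch
# NOTE: per the docstring above, the optimizer is effectively unusable
# without opt_einsum installed; a warning is emitted when it is missing.
from timm.optim.kron import Kron

model = torch.nn.Linear(10, 2)

# decoupled_decay=True applies weight decay in the decoupled (AdamW-style)
# form added by this impl, rather than folding it into the update.
opt = Kron(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-2,
    decoupled_decay=True,
)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()
```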
@@ -30,6 +43,8 @@ try:
 except AttributeError:
     has_dynamo = False
 
+from ._types import ParamsT
+
 _logger = logging.getLogger(__name__)
@@ -85,7 +100,7 @@ class Kron(torch.optim.Optimizer):
     def __init__(
             self,
-            params,
+            params: ParamsT,
             lr: float = 0.001,
             momentum: float = 0.9,
             weight_decay: float = 0.0,
@@ -94,6 +109,8 @@ class Kron(torch.optim.Optimizer):
             min_ndim_triangular: int = 2,
             memory_save_mode: Optional[str] = None,
             momentum_into_precond_update: bool = True,
+            precond_lr: float = 0.1,
+            precond_init_scale: float = 1.0,
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
@@ -119,8 +136,8 @@ class Kron(torch.optim.Optimizer):
             min_ndim_triangular=min_ndim_triangular,
             memory_save_mode=memory_save_mode,
             momentum_into_precond_update=momentum_into_precond_update,
-            precond_lr=0.1,  # precond lr hardcoded to 0.1
-            precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+            precond_lr=precond_lr,
+            precond_init_scale=precond_init_scale,
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
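With this hunk, `precond_lr` and `precond_init_scale` are no longer hardcoded; they pass through `__init__` into the param-group defaults. A minimal sketch of overriding them at construction time, with illustrative (not recommended) values:

```python
import torch
from timm.optim.kron import Kron

model = torch.nn.Linear(10, 2)

# Previously fixed at 0.1 and 1.0, these are now constructor arguments
# stored in each param group's defaults dict.
opt = Kron(
    model.parameters(),
    lr=1e-3,
    precond_lr=0.05,         # preconditioner learning rate (default 0.1)
    precond_init_scale=0.5,  # initial preconditioner scale (default 1.0)
)

print(opt.param_groups[0]['precond_lr'])          # 0.05
print(opt.param_groups[0]['precond_init_scale'])  # 0.5
```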