From b3a83b81d6137d32ffc51417ae5a6851e1bf1f77 Mon Sep 17 00:00:00 2001 From: Ross Wightman <rwightman@gmail.com> Date: Mon, 27 Jan 2025 16:00:58 -0800 Subject: [PATCH] Prep Kron for merge, add detail to attributions note, README. --- README.md | 6 ++++++ timm/optim/kron.py | 33 +++++++++++++++++++++++++-------- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 84e3cdfe..125428d2 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,11 @@ ## What's New +## Jan 27, 2025 +* Add Kron Optimizer (PSGD w/ Kronecker-factored preconditioner) + * Code from https://github.com/evanatyourservice/kron_torch + * See also https://sites.google.com/site/lixilinx/home/psgd + ## Jan 19, 2025 * Fix loading of LeViT safetensor weights, remove conversion code which should have been deactivated * Add 'SO150M' ViT weights trained with SBB recipes, decent results, but not optimal shape for ImageNet-12k/1k pretrain/ft @@ -461,6 +466,7 @@ Included optimizers available via `timm.optim.create_optimizer_v2` factory metho * `adamp` and `sgdp` by [Naver ClovAI](https://github.com/clovaai) - https://arxiv.org/abs/2006.08217 * `adan` an implementation of Adan adapted from https://github.com/sail-sg/Adan - https://arxiv.org/abs/2208.06677 * `adopt` ADOPT adapted from https://github.com/iShohei220/adopt - https://arxiv.org/abs/2411.02853 +* `kron` PSGD w/ Kronecker-factored preconditioner from https://github.com/evanatyourservice/kron_torch - https://sites.google.com/site/lixilinx/home/psgd * `lamb` an implementation of Lamb and LambC (w/ trust-clipping) cleaned up and modified to support use with XLA - https://arxiv.org/abs/1904.00962 * `laprop` optimizer from https://github.com/Z-T-WANG/LaProp-Optimizer - https://arxiv.org/abs/2002.04839 * `lars` an implementation of LARS and LARC (w/ trust-clipping) - https://arxiv.org/abs/1708.03888 diff --git a/timm/optim/kron.py b/timm/optim/kron.py index 7f1fcd47..e01c9885 100644 --- a/timm/optim/kron.py +++ 
b/timm/optim/kron.py @@ -1,9 +1,22 @@ -""" PyTorch Implementation of the Kron PSGD optimizer +""" PyTorch Implementation of the Kron (PSGD) optimizer -FIXME attribution -* https://github.com/evanatyourservice/kron_torch (direct source) -* https://github.com/lixilinx/psgd_torch (original) -* https://github.com/ClashLuke/HeavyBall (added improvements) +This is a PSGD optimizer using a Kronecker-factored preconditioner. + +This impl was adapted from https://github.com/evanatyourservice/kron_torch +by Evan Walters, licensed CC-BY-4.0. + +Contributions to the above were also made by +* Lucas Nestler, who added it to his https://github.com/ClashLuke/HeavyBall implementation. +* Omead Pooladzandi https://github.com/opooladz + +The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li + +This `timm` impl +* works with a wider variety of torch versions +* fixes some checkpoint save/restore (resume) issues +* adds decoupled weight-decay option +* has some refactoring, cleanup of args, default/group items +* warning about not having opt_einsum (unusable without) """ import logging @@ -30,6 +43,8 @@ try: except AttributeError: has_dynamo = False +from ._types import ParamsT + _logger = logging.getLogger(__name__) @@ -85,7 +100,7 @@ class Kron(torch.optim.Optimizer): def __init__( self, - params, + params: ParamsT, lr: float = 0.001, momentum: float = 0.9, weight_decay: float = 0.0, @@ -94,6 +109,8 @@ min_ndim_triangular: int = 2, memory_save_mode: Optional[str] = None, momentum_into_precond_update: bool = True, + precond_lr: float = 0.1, + precond_init_scale: float = 1.0, mu_dtype: Optional[torch.dtype] = None, precond_dtype: Optional[torch.dtype] = None, decoupled_decay: bool = False, @@ -119,8 +136,8 @@ min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode, momentum_into_precond_update=momentum_into_precond_update, - precond_lr=0.1, # precond lr hardcoded to 0.1 -
precond_init_scale=1.0, # precond init scale hardcoded to 1.0 + precond_lr=precond_lr, + precond_init_scale=precond_init_scale, mu_dtype=mu_dtype, precond_dtype=precond_dtype, decoupled_decay=decoupled_decay,