From b3a83b81d6137d32ffc51417ae5a6851e1bf1f77 Mon Sep 17 00:00:00 2001
From: Ross Wightman <rwightman@gmail.com>
Date: Mon, 27 Jan 2025 16:00:58 -0800
Subject: [PATCH] Prep Kron for merge, add detail to attribution note and README.

---
 README.md          | 15 +++++++++++++++
 timm/optim/kron.py | 42 ++++++++++++++++++++++++++++++++++--------
 2 files changed, 49 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 84e3cdfe..125428d2 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,20 @@
 
 ## What's New
 
+## Jan 27, 2025
+* Add Kron Optimizer (PSGD w/ Kronecker-factored preconditioner)
+  * Code from https://github.com/evanatyourservice/kron_torch
+  * See also https://sites.google.com/site/lixilinx/home/psgd
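+
+  A minimal usage sketch via the `timm` optimizer factory (illustrative values; `opt_einsum` must be installed for this optimizer to be usable):
+  ```python
+  import timm
+  from timm.optim import create_optimizer_v2
+
+  model = timm.create_model('resnet18')
+  optimizer = create_optimizer_v2(model, opt='kron', lr=1e-3)
+  ```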
+
 ## Jan 19, 2025
 * Fix loading of LeViT safetensor weights, remove conversion code which should have been deactivated
 * Add 'SO150M' ViT weights trained with SBB recipes, decent results, but not optimal shape for ImageNet-12k/1k pretrain/ft
@@ -461,6 +466,7 @@ Included optimizers available via `timm.optim.create_optimizer_v2` factory metho
 * `adamp` and `sgdp` by [Naver ClovAI](https://github.com/clovaai) - https://arxiv.org/abs/2006.08217
 * `adan` an implementation of Adan adapted from https://github.com/sail-sg/Adan - https://arxiv.org/abs/2208.06677
 * `adopt` ADOPT adapted from https://github.com/iShohei220/adopt - https://arxiv.org/abs/2411.02853
+* `kron` PSGD w/ Kronecker-factored preconditioner from https://github.com/evanatyourservice/kron_torch - https://sites.google.com/site/lixilinx/home/psgd
 * `lamb` an implementation of Lamb and LambC (w/ trust-clipping) cleaned up and modified to support use with XLA - https://arxiv.org/abs/1904.00962
 * `laprop` optimizer from https://github.com/Z-T-WANG/LaProp-Optimizer - https://arxiv.org/abs/2002.04839
 * `lars` an implementation of LARS and LARC (w/ trust-clipping) - https://arxiv.org/abs/1708.03888
diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index 7f1fcd47..e01c9885 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -1,9 +1,31 @@
-""" PyTorch Implementation of the Kron PSGD optimizer
+""" PyTorch Implementation of the Kron (PSGD) optimizer
 
-FIXME attribution
-* https://github.com/evanatyourservice/kron_torch (direct source)
-* https://github.com/lixilinx/psgd_torch (original)
-* https://github.com/ClashLuke/HeavyBall (added improvements)
+This is a preconditioned stochastic gradient descent (PSGD) optimizer that uses a Kronecker-factored preconditioner.
+
+This impl was adapted from https://github.com/evanatyourservice/kron_torch
+by Evan Walters, licensed under CC BY 4.0.
+
+Contributions to the above were also made by:
+* Lucas Nestler, who also added it to his https://github.com/ClashLuke/HeavyBall implementation
+* Omead Pooladzandi (https://github.com/opooladz)
+
+The above work drew from https://github.com/lixilinx/psgd_torch, the original PSGD impl by Xi-Lin Li.
+
+This `timm` impl:
+* works with a wider variety of torch versions
+* fixes some checkpoint save/restore issues (for train resume)
+* adds a decoupled weight-decay option
+* refactors and cleans up args and default/group handling
+* warns when opt_einsum is not installed (the optimizer is unusable without it)
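+
+A minimal usage sketch (illustrative values; requires opt_einsum to be installed):
+
+    import torch
+    from timm.optim.kron import Kron
+
+    model = torch.nn.Linear(10, 2)
+    # decoupled_decay and precond_lr are exposed as args in this impl
+    optimizer = Kron(model.parameters(), lr=1e-3, weight_decay=1e-2, decoupled_decay=True, precond_lr=0.1)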
 
 """
 import logging
@@ -30,6 +43,8 @@ try:
 except AttributeError:
     has_dynamo = False
 
+from ._types import ParamsT
+
 _logger = logging.getLogger(__name__)
 
 
@@ -85,7 +100,7 @@ class Kron(torch.optim.Optimizer):
 
     def __init__(
         self,
-        params,
+        params: ParamsT,
         lr: float = 0.001,
         momentum: float = 0.9,
         weight_decay: float = 0.0,
@@ -94,6 +109,8 @@ class Kron(torch.optim.Optimizer):
         min_ndim_triangular: int = 2,
         memory_save_mode: Optional[str] = None,
         momentum_into_precond_update: bool = True,
+        precond_lr: float = 0.1,
+        precond_init_scale: float = 1.0,
         mu_dtype: Optional[torch.dtype] = None,
         precond_dtype: Optional[torch.dtype] = None,
         decoupled_decay: bool = False,
@@ -119,8 +136,8 @@ class Kron(torch.optim.Optimizer):
             min_ndim_triangular=min_ndim_triangular,
             memory_save_mode=memory_save_mode,
             momentum_into_precond_update=momentum_into_precond_update,
-            precond_lr=0.1,  # precond lr hardcoded to 0.1
-            precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+            precond_lr=precond_lr,
+            precond_init_scale=precond_init_scale,
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,