From 5940cc167f7426054a7f69c4358c7f2ac2655d5d Mon Sep 17 00:00:00 2001
From: Ross Wightman <rwightman@gmail.com>
Date: Thu, 30 Jan 2025 13:13:49 -0800
Subject: [PATCH] Change start/end args

---
 timm/optim/kron.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index 9f4e4965..25c1b047 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -116,7 +116,8 @@ class Kron(torch.optim.Optimizer):
         precond_dtype: Optional[torch.dtype] = None,
         decoupled_decay: bool = False,
         flatten: bool = False,
-        flatten_start_end: Tuple[int, int] = (2, -1),
+        flatten_start_dim: int = 2,
+        flatten_end_dim: int = -1,
         deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -144,7 +145,8 @@ class Kron(torch.optim.Optimizer):
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
             flatten=flatten,
-            flatten_start_end=flatten_start_end,
+            flatten_start_dim=flatten_start_dim,
+            flatten_end_dim=flatten_end_dim,
         )
         super(Kron, self).__init__(params, defaults)
 
@@ -235,7 +237,7 @@ class Kron(torch.optim.Optimizer):
 
                 flattened = False
                 if group['flatten']:
-                    grad = safe_flatten(grad, *group["flatten_start_end"])
+                    grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
                     flattened = True
 
                 if len(state) == 0: