From 5940cc167f7426054a7f69c4358c7f2ac2655d5d Mon Sep 17 00:00:00 2001
From: Ross Wightman <rwightman@gmail.com>
Date: Thu, 30 Jan 2025 13:13:49 -0800
Subject: [PATCH] Change start/end args

---
 timm/optim/kron.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index 9f4e4965..25c1b047 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -116,7 +116,8 @@ class Kron(torch.optim.Optimizer):
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
             flatten: bool = False,
-            flatten_start_end: Tuple[int, int] = (2, -1),
+            flatten_start_dim: int = 2,
+            flatten_end_dim: int = -1,
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -144,7 +145,8 @@ class Kron(torch.optim.Optimizer):
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
             flatten=flatten,
-            flatten_start_end=flatten_start_end,
+            flatten_start_dim=flatten_start_dim,
+            flatten_end_dim=flatten_end_dim,
         )
         super(Kron, self).__init__(params, defaults)

@@ -235,7 +237,7 @@ class Kron(torch.optim.Optimizer):

                 flattened = False
                 if group['flatten']:
-                    grad = safe_flatten(grad, *group["flatten_start_end"])
+                    grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
                     flattened = True

                 if len(state) == 0:
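
For anyone updating call sites: the flatten range moves from one tuple argument to two int keyword arguments whose defaults match the old tuple, so default behavior is unchanged. A minimal before/after sketch of caller code (hypothetical usage, not part of the patch; `model` and the learning rate value are assumptions):

from timm.optim.kron import Kron

# Before this patch, the flatten range was a single tuple argument:
#   opt = Kron(model.parameters(), flatten=True, flatten_start_end=(2, -1))

# After this patch, start and end dims are separate keyword arguments whose
# defaults reproduce the old (2, -1) tuple:
opt = Kron(
    model.parameters(),       # assumes an existing nn.Module named `model`
    lr=1e-3,                  # illustrative value, not a recommendation
    flatten=True,             # flatten high-rank gradients before preconditioning
    flatten_start_dim=2,      # was flatten_start_end[0]
    flatten_end_dim=-1,       # was flatten_start_end[1]
)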