Change flatten start/end args: replace the `flatten_start_end` tuple with separate `flatten_start_dim` and `flatten_end_dim` int arguments

This commit is contained in:
Ross Wightman 2025-01-30 13:13:49 -08:00 committed by Ross Wightman
parent 31831f5948
commit 510bbd5389

View File

@@ -116,7 +116,8 @@ class Kron(torch.optim.Optimizer):
precond_dtype: Optional[torch.dtype] = None,
decoupled_decay: bool = False,
flatten: bool = False,
flatten_start_end: Tuple[int, int] = (2, -1),
flatten_start_dim: int = 2,
flatten_end_dim: int = -1,
deterministic: bool = False,
):
if not has_opt_einsum:
@@ -144,7 +145,8 @@ class Kron(torch.optim.Optimizer):
precond_dtype=precond_dtype,
decoupled_decay=decoupled_decay,
flatten=flatten,
flatten_start_end=flatten_start_end,
flatten_start_dim=flatten_start_dim,
flatten_end_dim=flatten_end_dim,
)
super(Kron, self).__init__(params, defaults)
@@ -235,7 +237,7 @@ class Kron(torch.optim.Optimizer):
flattened = False
if group['flatten']:
grad = safe_flatten(grad, *group["flatten_start_end"])
grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
flattened = True
if len(state) == 0: