Change start/end args

kron_flatten
Ross Wightman 2025-01-30 13:13:49 -08:00
parent 3be8b1abe4
commit 5940cc167f
1 changed file with 5 additions and 3 deletions

View File

@@ -116,7 +116,8 @@ class Kron(torch.optim.Optimizer):
precond_dtype: Optional[torch.dtype] = None,
decoupled_decay: bool = False,
flatten: bool = False,
-flatten_start_end: Tuple[int, int] = (2, -1),
+flatten_start_dim: int = 2,
+flatten_end_dim: int = -1,
deterministic: bool = False,
):
if not has_opt_einsum:
@@ -144,7 +145,8 @@ class Kron(torch.optim.Optimizer):
precond_dtype=precond_dtype,
decoupled_decay=decoupled_decay,
flatten=flatten,
-flatten_start_end=flatten_start_end,
+flatten_start_dim=flatten_start_dim,
+flatten_end_dim=flatten_end_dim,
)
super(Kron, self).__init__(params, defaults)
@@ -235,7 +237,7 @@ class Kron(torch.optim.Optimizer):
flattened = False
if group['flatten']:
-grad = safe_flatten(grad, *group["flatten_start_end"])
+grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
flattened = True
if len(state) == 0: