Change start/end args

This commit is contained in:
Ross Wightman 2025-01-30 13:13:49 -08:00
parent 3be8b1abe4
commit 5940cc167f

View File

@ -116,7 +116,8 @@ class Kron(torch.optim.Optimizer):
precond_dtype: Optional[torch.dtype] = None, precond_dtype: Optional[torch.dtype] = None,
decoupled_decay: bool = False, decoupled_decay: bool = False,
flatten: bool = False, flatten: bool = False,
flatten_start_end: Tuple[int, int] = (2, -1), flatten_start_dim: int = 2,
flatten_end_dim: int = -1,
deterministic: bool = False, deterministic: bool = False,
): ):
if not has_opt_einsum: if not has_opt_einsum:
@ -144,7 +145,8 @@ class Kron(torch.optim.Optimizer):
precond_dtype=precond_dtype, precond_dtype=precond_dtype,
decoupled_decay=decoupled_decay, decoupled_decay=decoupled_decay,
flatten=flatten, flatten=flatten,
flatten_start_end=flatten_start_end, flatten_start_dim=flatten_start_dim,
flatten_end_dim=flatten_end_dim,
) )
super(Kron, self).__init__(params, defaults) super(Kron, self).__init__(params, defaults)
@ -235,7 +237,7 @@ class Kron(torch.optim.Optimizer):
flattened = False flattened = False
if group['flatten']: if group['flatten']:
grad = safe_flatten(grad, *group["flatten_start_end"]) grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
flattened = True flattened = True
if len(state) == 0: if len(state) == 0: