Change start/end args
parent 3be8b1abe4
commit 5940cc167f
@@ -116,7 +116,8 @@ class Kron(torch.optim.Optimizer):
         precond_dtype: Optional[torch.dtype] = None,
         decoupled_decay: bool = False,
         flatten: bool = False,
-        flatten_start_end: Tuple[int, int] = (2, -1),
+        flatten_start_dim: int = 2,
+        flatten_end_dim: int = -1,
         deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -144,7 +145,8 @@ class Kron(torch.optim.Optimizer):
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
             flatten=flatten,
-            flatten_start_end=flatten_start_end,
+            flatten_start_dim=flatten_start_dim,
+            flatten_end_dim=flatten_end_dim,
         )
         super(Kron, self).__init__(params, defaults)

@@ -235,7 +237,7 @@ class Kron(torch.optim.Optimizer):

                 flattened = False
                 if group['flatten']:
-                    grad = safe_flatten(grad, *group["flatten_start_end"])
+                    grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
                     flattened = True

                 if len(state) == 0:
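
Note: `safe_flatten` itself is not part of this diff. A minimal sketch of what such a helper presumably does, given how it is called above with separate `flatten_start_dim` / `flatten_end_dim` ints (this is an assumption about the helper, not the file's actual implementation):

    import torch

    def safe_flatten(tensor: torch.Tensor, start_dim: int = 0, end_dim: int = -1) -> torch.Tensor:
        # Assumed behavior: flatten dims [start_dim, end_dim], but tolerate
        # tensors with too few dimensions instead of raising.
        ndim = tensor.ndim
        end_dim = end_dim if end_dim >= 0 else ndim + end_dim  # normalize negative index
        if ndim <= start_dim or end_dim <= start_dim:
            return tensor  # nothing to flatten for low-rank tensors
        return tensor.flatten(start_dim, min(end_dim, ndim - 1))

With this change, callers pass two ints instead of a tuple; hypothetical usage against the signature above:

    opt = Kron(model.parameters(), flatten=True, flatten_start_dim=2, flatten_end_dim=-1)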