mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Change start/end args
This commit is contained in:
parent
31831f5948
commit
510bbd5389
@ -116,7 +116,8 @@ class Kron(torch.optim.Optimizer):
|
|||||||
precond_dtype: Optional[torch.dtype] = None,
|
precond_dtype: Optional[torch.dtype] = None,
|
||||||
decoupled_decay: bool = False,
|
decoupled_decay: bool = False,
|
||||||
flatten: bool = False,
|
flatten: bool = False,
|
||||||
flatten_start_end: Tuple[int, int] = (2, -1),
|
flatten_start_dim: int = 2,
|
||||||
|
flatten_end_dim: int = -1,
|
||||||
deterministic: bool = False,
|
deterministic: bool = False,
|
||||||
):
|
):
|
||||||
if not has_opt_einsum:
|
if not has_opt_einsum:
|
||||||
@ -144,7 +145,8 @@ class Kron(torch.optim.Optimizer):
|
|||||||
precond_dtype=precond_dtype,
|
precond_dtype=precond_dtype,
|
||||||
decoupled_decay=decoupled_decay,
|
decoupled_decay=decoupled_decay,
|
||||||
flatten=flatten,
|
flatten=flatten,
|
||||||
flatten_start_end=flatten_start_end,
|
flatten_start_dim=flatten_start_dim,
|
||||||
|
flatten_end_dim=flatten_end_dim,
|
||||||
)
|
)
|
||||||
super(Kron, self).__init__(params, defaults)
|
super(Kron, self).__init__(params, defaults)
|
||||||
|
|
||||||
@ -235,7 +237,7 @@ class Kron(torch.optim.Optimizer):
|
|||||||
|
|
||||||
flattened = False
|
flattened = False
|
||||||
if group['flatten']:
|
if group['flatten']:
|
||||||
grad = safe_flatten(grad, *group["flatten_start_end"])
|
grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
|
||||||
flattened = True
|
flattened = True
|
||||||
|
|
||||||
if len(state) == 0:
|
if len(state) == 0:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user