mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Change flattening behaviour in Kron

parent cdbafd9057
commit 31831f5948
@@ -94,7 +94,8 @@ class Kron(torch.optim.Optimizer):
         mu_dtype: Dtype of the momentum accumulator.
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
-        flatten_dim: Flatten dim >= 2 instead of relying on expressions
+        flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
+        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """

@@ -114,7 +115,8 @@ class Kron(torch.optim.Optimizer):
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
-            flatten_dim: bool = False,
+            flatten: bool = False,
+            flatten_start_end: Tuple[int, int] = (2, -1),
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
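
For orientation, a minimal usage sketch of the renamed arguments introduced here; the import path and model are illustrative assumptions rather than part of this commit, and opt_einsum must be installed since the constructor checks has_opt_einsum:

import torch
from timm.optim.kron import Kron  # assumed module path for this optimizer

model = torch.nn.Conv2d(16, 32, kernel_size=3)

# flatten=True collapses the dims in flatten_start_end (default (2, -1))
# of higher-rank gradients before the preconditioner is applied.
opt = Kron(model.parameters(), flatten=True, flatten_start_end=(2, -1))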
@@ -141,7 +143,8 @@ class Kron(torch.optim.Optimizer):
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
-            flatten_dim=flatten_dim,
+            flatten=flatten,
+            flatten_start_end=flatten_start_end,
         )
         super(Kron, self).__init__(params, defaults)

@@ -229,8 +232,11 @@ class Kron(torch.optim.Optimizer):

                 grad = p.grad
                 state = self.state[p]
-                if group['flatten_dim']:
-                    grad = grad.view(grad.size(0), -1)
+                flattened = False
+                if group['flatten']:
+                    grad = safe_flatten(grad, *group["flatten_start_end"])
+                    flattened = True

                 if len(state) == 0:
                     state["step"] = 0
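
To make the behaviour change concrete: the old path always collapsed everything after dim 0, while the new path only flattens the configured range (default (2, -1)). A standalone sketch with a made-up 4D gradient:

import torch

grad = torch.randn(32, 16, 3, 3)   # e.g. a conv weight gradient

old = grad.view(grad.size(0), -1)  # previous behaviour: shape (32, 144)
new = grad.flatten(2, -1)          # what safe_flatten(grad, 2, -1) does here: shape (32, 16, 9)

print(old.shape, new.shape)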
@@ -341,7 +347,7 @@ class Kron(torch.optim.Optimizer):

                 # RMS of pre_grad should be 1.0, so let's cap at 1.1
                 pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
-                if group['flatten_dim']:
+                if flattened:
                     pre_grad = pre_grad.view(p.shape)

                 # Apply weight decay
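
The unchanged clamp line above caps the RMS of the preconditioned gradient at roughly 1.1; a step-by-step sketch of the same arithmetic on a throwaway tensor (values are arbitrary):

import torch

pre_grad = torch.randn(32, 16, 9) * 5.0           # deliberately oversized update

rms = pre_grad.square().mean().sqrt()             # root-mean-square of the update
scale = torch.clamp(1.1 / (rms + 1e-8), max=1.0)  # < 1 only when rms exceeds 1.1
pre_grad = pre_grad * scale                       # RMS now capped at ~1.1

print(rms.item(), pre_grad.square().mean().sqrt().item())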
@@ -361,6 +367,20 @@ class Kron(torch.optim.Optimizer):
         return loss


+def safe_flatten(tensor, start_dim=0, end_dim=-1):
+    ndim = tensor.ndim
+
+    # Convert negative end_dim to positive and clip to end
+    end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)
+
+    # If tensor has fewer dims than start_dim or start > end, return tensor as is
+    if ndim <= start_dim or start_dim > end_dim:
+        return tensor
+
+    # Now safe to flatten
+    return tensor.flatten(start_dim, end_dim)
+
+
 def _init_Q_exprs(
     t,
     scale,