Change flattening behaviour in Kron

Ross Wightman 2025-01-30 13:07:20 -08:00 committed by Ross Wightman
parent cdbafd9057
commit 31831f5948


@@ -94,7 +94,8 @@ class Kron(torch.optim.Optimizer):
         mu_dtype: Dtype of the momentum accumulator.
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
-        flatten_dim: Flatten dim >= 2 instead of relying on expressions
+        flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
+        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
@@ -114,7 +115,8 @@ class Kron(torch.optim.Optimizer):
         mu_dtype: Optional[torch.dtype] = None,
         precond_dtype: Optional[torch.dtype] = None,
         decoupled_decay: bool = False,
-        flatten_dim: bool = False,
+        flatten: bool = False,
+        flatten_start_end: Tuple[int, int] = (2, -1),
         deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -141,7 +143,8 @@ class Kron(torch.optim.Optimizer):
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
-            flatten_dim=flatten_dim,
+            flatten=flatten,
+            flatten_start_end=flatten_start_end,
         )
         super(Kron, self).__init__(params, defaults)
@@ -229,8 +232,11 @@ class Kron(torch.optim.Optimizer):
                 grad = p.grad
                 state = self.state[p]
 
-                if group['flatten_dim']:
-                    grad = grad.view(grad.size(0), -1)
+                flattened = False
+                if group['flatten']:
+                    grad = safe_flatten(grad, *group["flatten_start_end"])
+                    flattened = True
 
                 if len(state) == 0:
                     state["step"] = 0
@ -341,7 +347,7 @@ class Kron(torch.optim.Optimizer):
# RMS of pre_grad should be 1.0, so let's cap at 1.1 # RMS of pre_grad should be 1.0, so let's cap at 1.1
pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0)) pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
if group['flatten_dim']: if flattened:
pre_grad = pre_grad.view(p.shape) pre_grad = pre_grad.view(p.shape)
# Apply weight decay # Apply weight decay
@@ -361,6 +367,20 @@
         return loss
 
 
+def safe_flatten(tensor, start_dim=0, end_dim=-1):
+    ndim = tensor.ndim
+
+    # Convert negative end_dim to positive and clip to end
+    end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)
+
+    # If tensor has fewer dims than start_dim or start > end, return tensor as is
+    if ndim <= start_dim or start_dim > end_dim:
+        return tensor
+
+    # Now safe to flatten
+    return tensor.flatten(start_dim, end_dim)
+
+
 def _init_Q_exprs(
     t,
     scale,
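A quick sketch of how the new safe_flatten helper behaves with the (2, -1) range used in step(); tensors of rank 2 or lower pass through unchanged. The import path is assumed from this diff.

import torch
from timm.optim.kron import safe_flatten  # module path assumed

x1 = torch.randn(10)           # rank 1: ndim <= start_dim, returned as is
x2 = torch.randn(8, 4)         # rank 2: start_dim > clipped end_dim, returned as is
x4 = torch.randn(8, 4, 3, 3)   # rank 4: dims 2..-1 flattened

print(safe_flatten(x1, 2, -1).shape)   # torch.Size([10])
print(safe_flatten(x2, 2, -1).shape)   # torch.Size([8, 4])
print(safe_flatten(x4, 2, -1).shape)   # torch.Size([8, 4, 9])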