Improve row/col dim var name

This commit is contained in:
Ross Wightman 2024-11-08 08:36:11 -08:00 committed by Ross Wightman
parent e7b0480381
commit 36a45e5d94

View File

@ -139,11 +139,11 @@ class AdafactorBigVision(Optimizer):
)
if factored_dims is not None:
d1, d0 = factored_dims
dc, dr = factored_dims
row_shape = list(p.grad.shape)
row_shape[d0] = 1
row_shape[dr] = 1
col_shape = list(p.grad.shape)
col_shape[d1] = 1
col_shape[dc] = 1
state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
else:
@ -226,12 +226,12 @@ def _single_tensor_adafactor(
# NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
if exp_avg_sq is None:
# factorized second moment
d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)
dc, dr = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
exp_avg_sq_r.lerp_(grad_sqr.mean(dim=dr, keepdim=True), one_minus_beta2_t)
exp_avg_sq_c.lerp_(grad_sqr.mean(dim=dc, keepdim=True), one_minus_beta2_t)
reduced_d1 = d1 - 1 if d1 > d0 else d1
row_col_mean = exp_avg_sq_r.mean(dim=reduced_d1, keepdim=True)
reduce_dc = dc - 1 if dc > dr else dc
row_col_mean = exp_avg_sq_r.mean(dim=reduce_dc, keepdim=True)
row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
col_factor = exp_avg_sq_c.rsqrt()