mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Improve row/col dim var name
This commit is contained in:
parent
e7b0480381
commit
36a45e5d94
@ -139,11 +139,11 @@ class AdafactorBigVision(Optimizer):
|
||||
)
|
||||
|
||||
if factored_dims is not None:
|
||||
d1, d0 = factored_dims
|
||||
dc, dr = factored_dims
|
||||
row_shape = list(p.grad.shape)
|
||||
row_shape[d0] = 1
|
||||
row_shape[dr] = 1
|
||||
col_shape = list(p.grad.shape)
|
||||
col_shape[d1] = 1
|
||||
col_shape[dc] = 1
|
||||
state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
|
||||
state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
|
||||
else:
|
||||
@ -226,12 +226,12 @@ def _single_tensor_adafactor(
|
||||
# NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
|
||||
if exp_avg_sq is None:
|
||||
# factorized second moment
|
||||
d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
|
||||
exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
|
||||
exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)
|
||||
dc, dr = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
|
||||
exp_avg_sq_r.lerp_(grad_sqr.mean(dim=dr, keepdim=True), one_minus_beta2_t)
|
||||
exp_avg_sq_c.lerp_(grad_sqr.mean(dim=dc, keepdim=True), one_minus_beta2_t)
|
||||
|
||||
reduced_d1 = d1 - 1 if d1 > d0 else d1
|
||||
row_col_mean = exp_avg_sq_r.mean(dim=reduced_d1, keepdim=True)
|
||||
reduce_dc = dc - 1 if dc > dr else dc
|
||||
row_col_mean = exp_avg_sq_r.mean(dim=reduce_dc, keepdim=True)
|
||||
row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
|
||||
col_factor = exp_avg_sq_c.rsqrt()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user