mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Improve row/col dim var name
This commit is contained in:
parent
e7b0480381
commit
36a45e5d94
@ -139,11 +139,11 @@ class AdafactorBigVision(Optimizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if factored_dims is not None:
|
if factored_dims is not None:
|
||||||
d1, d0 = factored_dims
|
dc, dr = factored_dims
|
||||||
row_shape = list(p.grad.shape)
|
row_shape = list(p.grad.shape)
|
||||||
row_shape[d0] = 1
|
row_shape[dr] = 1
|
||||||
col_shape = list(p.grad.shape)
|
col_shape = list(p.grad.shape)
|
||||||
col_shape[d1] = 1
|
col_shape[dc] = 1
|
||||||
state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
|
state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
|
||||||
state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
|
state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
|
||||||
else:
|
else:
|
||||||
@ -226,12 +226,12 @@ def _single_tensor_adafactor(
|
|||||||
# NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
|
# NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
|
||||||
if exp_avg_sq is None:
|
if exp_avg_sq is None:
|
||||||
# factorized second moment
|
# factorized second moment
|
||||||
d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
|
dc, dr = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
|
||||||
exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
|
exp_avg_sq_r.lerp_(grad_sqr.mean(dim=dr, keepdim=True), one_minus_beta2_t)
|
||||||
exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)
|
exp_avg_sq_c.lerp_(grad_sqr.mean(dim=dc, keepdim=True), one_minus_beta2_t)
|
||||||
|
|
||||||
reduced_d1 = d1 - 1 if d1 > d0 else d1
|
reduce_dc = dc - 1 if dc > dr else dc
|
||||||
row_col_mean = exp_avg_sq_r.mean(dim=reduced_d1, keepdim=True)
|
row_col_mean = exp_avg_sq_r.mean(dim=reduce_dc, keepdim=True)
|
||||||
row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
|
row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
|
||||||
col_factor = exp_avg_sq_c.rsqrt()
|
col_factor = exp_avg_sq_c.rsqrt()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user