diff --git a/timm/optim/adafactor.py b/timm/optim/adafactor.py
index 06057433..37871af1 100644
--- a/timm/optim/adafactor.py
+++ b/timm/optim/adafactor.py
@@ -38,16 +38,38 @@ class Adafactor(torch.optim.Optimizer):
             whether warm-up initialization is being used (default: False)
     """

-    def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
-                 decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
+    def __init__(
+            self,
+            params,
+            lr=None,
+            eps=1e-30,
+            eps_scale=1e-3,
+            clip_threshold=1.0,
+            decay_rate=-0.8,
+            betas=None,
+            weight_decay=0.0,
+            scale_parameter=True,
+            warmup_init=False,
+            min_dim_size_to_factor=32,
+    ):
         relative_step = not lr
         if warmup_init and not relative_step:
             raise ValueError('warmup_init requires relative_step=True')

         beta1 = None if betas is None else betas[0]  # make it compat with standard betas arg
-        defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
-                        beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
-                        relative_step=relative_step, warmup_init=warmup_init)
+        defaults = dict(
+            lr=lr,
+            eps=eps,
+            eps_scale=eps_scale,
+            clip_threshold=clip_threshold,
+            decay_rate=decay_rate,
+            beta1=beta1,
+            weight_decay=weight_decay,
+            scale_parameter=scale_parameter,
+            relative_step=relative_step,
+            warmup_init=warmup_init,
+            min_dim_size_to_factor=min_dim_size_to_factor,
+        )
         super(Adafactor, self).__init__(params, defaults)

     @staticmethod
@@ -62,20 +84,34 @@ class Adafactor(torch.optim.Optimizer):
         return param_group['lr']

     @staticmethod
-    def _get_options(param_group, param_shape):
-        factored = len(param_shape) >= 2
+    def _get_options(param_group, param_shape, min_size_to_factor=32):
         use_first_moment = param_group['beta1'] is not None
+        factored = None
+        ndim = len(param_shape)
+        # Use a simple heuristic to pick factorization row & col, note other PyTorch impl tend to
+        # always use -2, -1 BUT this will not pick correct dims for convolutions. This is a simple
+        # approach that should work in most cases, compare to the slightly more involved approach
+        # in AdafactorBigVision that sorts dims by size, please report if wrong dims chosen.
+        if ndim > 2 and param_shape[0] > min_size_to_factor and param_shape[1] > min_size_to_factor:
+            # nD convs in torch are ND + 2 dim weights with leading in/out chs
+            factored = 0, 1
+        elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
+            # if the criteria above didn't match, check trailing dims
+            factored = ndim - 2, ndim - 1
+
         return factored, use_first_moment

     @staticmethod
     def _rms(tensor):
         return tensor.norm(2) / (tensor.numel() ** 0.5)

-    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
-        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
-        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row):
+        # from our dim heuristic, always dim_col < dim_row, so col reduction dim for factored row = dim_col
+        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=dim_col, keepdim=True)).rsqrt_().unsqueeze(dim_row)
+        c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt()
         return torch.mul(r_factor, c_factor)

+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -99,7 +135,11 @@ class Adafactor(torch.optim.Optimizer):

                 state = self.state[p]

-                factored, use_first_moment = self._get_options(group, grad.shape)
+                factored_dims, use_first_moment = self._get_options(
+                    group,
+                    grad.shape,
+                    min_size_to_factor=group['min_dim_size_to_factor'],
+                )
                 # State Initialization
                 if len(state) == 0:
                     state['step'] = 0
@@ -107,9 +147,12 @@ class Adafactor(torch.optim.Optimizer):
                     if use_first_moment:
                         # Exponential moving average of gradient values
                         state['exp_avg'] = torch.zeros_like(grad)
-                    if factored:
-                        state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
-                        state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
+                    if factored_dims is not None:
+                        dim_col, dim_row = factored_dims
+                        def _remove_dim(shape, dim):
+                            return shape[:dim] + shape[dim + 1:]
+                        state['exp_avg_sq_row'] = torch.zeros(_remove_dim(grad.shape, dim_row)).to(grad)
+                        state['exp_avg_sq_col'] = torch.zeros(_remove_dim(grad.shape, dim_col)).to(grad)
                     else:
                         state['exp_avg_sq'] = torch.zeros_like(grad)

@@ -117,7 +160,7 @@ class Adafactor(torch.optim.Optimizer):
                 else:
                     if use_first_moment:
                         state['exp_avg'] = state['exp_avg'].to(grad)
-                    if factored:
+                    if factored_dims is not None:
                         state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
                         state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
                     else:
@@ -133,15 +176,16 @@ class Adafactor(torch.optim.Optimizer):

                 beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
                 update = grad ** 2 + group['eps']
-                if factored:
+                if factored_dims is not None:
+                    dim_col, dim_row = factored_dims
                     exp_avg_sq_row = state['exp_avg_sq_row']
                     exp_avg_sq_col = state['exp_avg_sq_col']

-                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
-                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
+                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=dim_row), alpha=1.0 - beta2t)
+                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=dim_col), alpha=1.0 - beta2t)

                     # Approximation of exponential moving average of square of gradient
-                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row)
                     update.mul_(grad)
                 else:
                     exp_avg_sq = state['exp_avg_sq']
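
Note on the new dim heuristic in _get_options: the standalone sketch below is not part of the patch; the helper name pick_factored_dims and the example shapes are illustrative assumptions. It mirrors the same logic to show which dims end up factored for common parameter shapes: the trailing pair for linear weights, the leading out/in channel dims for conv weights, and no factorization (plain exp_avg_sq fallback) when the candidate dims are not both larger than min_size_to_factor.

# Standalone sketch of the dim-selection heuristic above (illustrative only, not part of the patch).
# Returns the (dim_col, dim_row) pair that would be factored, or None for the unfactored fallback.
def pick_factored_dims(param_shape, min_size_to_factor=32):
    ndim = len(param_shape)
    if ndim > 2 and param_shape[0] > min_size_to_factor and param_shape[1] > min_size_to_factor:
        # conv-style weights: leading dims are out_chs, in_chs
        return 0, 1
    if ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
        return ndim - 2, ndim - 1
    return None

if __name__ == '__main__':
    # example shapes are assumptions chosen for illustration
    print(pick_factored_dims((1024, 768)))       # linear weight -> (0, 1)
    print(pick_factored_dims((256, 128, 3, 3)))  # 2d conv weight -> (0, 1), not the trailing kernel dims
    print(pick_factored_dims((16, 3, 7, 7)))     # small conv -> None, falls back to full exp_avg_sq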
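
Shape check for the factored branch: a minimal sketch, assuming a 4D conv-style gradient and an arbitrary beta2t value, neither of which comes from the patch. Variable names mirror the diff. It runs the row/col accumulator update and the _approx_sq_grad reconstruction for (dim_col, dim_row) = (0, 1), confirming the rank-1 style product broadcasts back to the gradient's full shape.

import torch

# minimal sketch, assuming a 4D conv-style gradient; mirrors the factored branch above
grad = torch.randn(256, 128, 3, 3)           # (out_chs, in_chs, kH, kW)
dim_col, dim_row = 0, 1                      # what the heuristic picks for conv weights
eps, beta2t = 1e-30, 0.999                   # beta2t value assumed for illustration

def _remove_dim(shape, dim):
    return shape[:dim] + shape[dim + 1:]

exp_avg_sq_row = torch.zeros(_remove_dim(grad.shape, dim_row))  # (256, 3, 3)
exp_avg_sq_col = torch.zeros(_remove_dim(grad.shape, dim_col))  # (128, 3, 3)

update = grad ** 2 + eps
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=dim_row), alpha=1.0 - beta2t)
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=dim_col), alpha=1.0 - beta2t)

# rank-1 style reconstruction, as in _approx_sq_grad
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=dim_col, keepdim=True)).rsqrt_().unsqueeze(dim_row)
c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt()
approx = torch.mul(r_factor, c_factor)       # broadcasts to (256, 128, 3, 3)

assert approx.shape == grad.shape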