Some Lookahead cleanup and fixes

Ross Wightman 2019-08-29 15:14:35 -07:00
parent fac58f609a
commit ba3c97c3ad
2 changed files with 42 additions and 40 deletions

timm/optim/lookahead.py

@@ -13,37 +13,40 @@ class Lookahead(Optimizer):
             raise ValueError(f'Invalid slow update rate: {alpha}')
         if not 1 <= k:
             raise ValueError(f'Invalid lookahead steps: {k}')
-        self.alpha = alpha
-        self.k = k
+        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
         self.base_optimizer = base_optimizer
         self.param_groups = self.base_optimizer.param_groups
         self.defaults = base_optimizer.defaults
+        self.defaults.update(defaults)
         self.state = defaultdict(dict)
-        for group in self.param_groups:
-            group["step_counter"] = 0
+        # manually add our defaults to the param groups
+        for name, default in defaults.items():
+            for group in self.param_groups:
+                group.setdefault(name, default)
 
-    def update_slow_weights(self, group):
+    def update_slow(self, group):
         for fast_p in group["params"]:
             if fast_p.grad is None:
                 continue
             param_state = self.state[fast_p]
-            if "slow_buffer" not in param_state:
-                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
-                param_state["slow_buffer"].copy_(fast_p.data)
-            slow = param_state["slow_buffer"]
-            slow.add_(self.alpha, fast_p.data - slow)
+            if 'slow_buffer' not in param_state:
+                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
+                param_state['slow_buffer'].copy_(fast_p.data)
+            slow = param_state['slow_buffer']
+            slow.add_(group['lookahead_alpha'], fast_p.data - slow)
             fast_p.data.copy_(slow)
 
     def sync_lookahead(self):
         for group in self.param_groups:
-            self.update_slow_weights(group)
+            self.update_slow(group)
 
     def step(self, closure=None):
+        #assert id(self.param_groups) == id(self.base_optimizer.param_groups)
         loss = self.base_optimizer.step(closure)
         for group in self.param_groups:
-            group['step_counter'] += 1
-            if group['step_counter'] % self.k == 0:
-                self.update_slow_weights(group)
+            group['lookahead_step'] += 1
+            if group['lookahead_step'] % group['lookahead_k'] == 0:
+                self.update_slow(group)
         return loss
 
     def state_dict(self):
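The first hunk moves the lookahead hyper-parameters out of instance attributes (`self.alpha`, `self.k`) and into per-group entries (`lookahead_alpha`, `lookahead_k`, `lookahead_step`), so each param group carries its own settings. The slow-weight update itself is unchanged: an in-place interpolation of the slow buffer toward the fast weights. A minimal standalone sketch of that interpolation (the tensor names are illustrative, not taken from the diff):

import torch

alpha = 0.5                            # example value for lookahead_alpha
fast = torch.randn(4)                  # stands in for fast_p.data after k base-optimizer steps
slow = torch.zeros_like(fast)          # stands in for param_state['slow_buffer']

slow.add_(fast - slow, alpha=alpha)    # slow <- slow + alpha * (fast - slow); keyword form of slow.add_(alpha, fast - slow)
fast.copy_(slow)                       # fast weights restart from the interpolated (slow) point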
@@ -52,37 +55,36 @@ class Lookahead(Optimizer):
             (id(k) if isinstance(k, torch.Tensor) else k): v
             for k, v in self.state.items()
         }
-        fast_state = fast_state_dict["state"]
-        param_groups = fast_state_dict["param_groups"]
+        fast_state = fast_state_dict['state']
+        param_groups = fast_state_dict['param_groups']
         return {
-            "state": fast_state,
-            "slow_state": slow_state,
-            "param_groups": param_groups,
+            'state': fast_state,
+            'slow_state': slow_state,
+            'param_groups': param_groups,
         }
 
     def load_state_dict(self, state_dict):
-        if 'slow_state' not in state_dict:
-            print('Loading state_dict from optimizer without Lookahead applied')
-            state_dict['slow_state'] = defaultdict(dict)
-        slow_state_dict = {
-            "state": state_dict["slow_state"],
-            "param_groups": state_dict["param_groups"],
-        }
         fast_state_dict = {
-            "state": state_dict["state"],
-            "param_groups": state_dict["param_groups"],
+            'state': state_dict['state'],
+            'param_groups': state_dict['param_groups'],
         }
-        super(Lookahead, self).load_state_dict(slow_state_dict)
         self.base_optimizer.load_state_dict(fast_state_dict)
 
-    def add_param_group(self, param_group):
-        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
-        This can be useful when fine tuning a pre-trained network as frozen
-        layers can be made trainable and added to the :class:`Optimizer` as
-        training progresses.
-        Args:
-            param_group (dict): Specifies what Tensors should be optimized along
-            with group specific optimization options.
-        """
-        param_group['step_counter'] = 0
-        self.base_optimizer.add_param_group(param_group)
+        # We want to restore the slow state, but share param_groups reference
+        # with base_optimizer. This is a bit redundant but least code
+        slow_state_new = False
+        if 'slow_state' not in state_dict:
+            print('Loading state_dict from optimizer without Lookahead applied.')
+            state_dict['slow_state'] = defaultdict(dict)
+            slow_state_new = True
+        slow_state_dict = {
+            'state': state_dict['slow_state'],
+            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
+        }
+        super(Lookahead, self).load_state_dict(slow_state_dict)
+        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
+        if slow_state_new:
+            # reapply defaults to catch missing lookahead specific ones
+            for name, default in self.defaults.items():
+                for group in self.param_groups:
+                    group.setdefault(name, default)
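Taken together, the class wraps a base optimizer and folds its fast weights into the slow buffers every `lookahead_k` steps. A rough usage sketch, assuming the import path matches this repo layout and using an arbitrary toy model and example hyper-parameter values:

import torch
from timm.optim.lookahead import Lookahead  # module path assumed from the repo layout

model = torch.nn.Linear(10, 2)
base = torch.optim.SGD(model.parameters(), lr=0.1)
opt = Lookahead(base, alpha=0.5, k=6)       # alpha/k here are just example values

for _ in range(100):
    base.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    opt.step()           # every k-th call also runs update_slow() on each param group

opt.sync_lookahead()     # final sync so the model is left on the slow (lookahead) weights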

timm/optim/nvnovograd.py Normal file