Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Some Lookahead cleanup and fixes
This commit is contained in:
parent fac58f609a
commit ba3c97c3ad
@@ -13,37 +13,40 @@ class Lookahead(Optimizer):
             raise ValueError(f'Invalid slow update rate: {alpha}')
         if not 1 <= k:
             raise ValueError(f'Invalid lookahead steps: {k}')
-        self.alpha = alpha
-        self.k = k
+        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
         self.base_optimizer = base_optimizer
         self.param_groups = self.base_optimizer.param_groups
         self.defaults = base_optimizer.defaults
+        self.defaults.update(defaults)
         self.state = defaultdict(dict)
-        for group in self.param_groups:
-            group["step_counter"] = 0
+        # manually add our defaults to the param groups
+        for name, default in defaults.items():
+            for group in self.param_groups:
+                group.setdefault(name, default)
 
-    def update_slow_weights(self, group):
+    def update_slow(self, group):
         for fast_p in group["params"]:
             if fast_p.grad is None:
                 continue
             param_state = self.state[fast_p]
-            if "slow_buffer" not in param_state:
-                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
-                param_state["slow_buffer"].copy_(fast_p.data)
-            slow = param_state["slow_buffer"]
-            slow.add_(self.alpha, fast_p.data - slow)
+            if 'slow_buffer' not in param_state:
+                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
+                param_state['slow_buffer'].copy_(fast_p.data)
+            slow = param_state['slow_buffer']
+            slow.add_(group['lookahead_alpha'], fast_p.data - slow)
             fast_p.data.copy_(slow)
 
     def sync_lookahead(self):
         for group in self.param_groups:
-            self.update_slow_weights(group)
+            self.update_slow(group)
 
     def step(self, closure=None):
+        #assert id(self.param_groups) == id(self.base_optimizer.param_groups)
         loss = self.base_optimizer.step(closure)
         for group in self.param_groups:
-            group['step_counter'] += 1
-            if group['step_counter'] % self.k == 0:
-                self.update_slow_weights(group)
+            group['lookahead_step'] += 1
+            if group['lookahead_step'] % group['lookahead_k'] == 0:
+                self.update_slow(group)
         return loss
 
     def state_dict(self):
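In practice, this first hunk moves the lookahead hyper-parameters into each param group as lookahead_alpha, lookahead_k and lookahead_step, so the slow-weight sync is driven entirely by per-group values rather than wrapper attributes. A minimal usage sketch of the wrapper after this change; the import path and the toy training loop are illustrative assumptions, not part of this commit:

import torch
from timm.optim.lookahead import Lookahead  # assumed module path

model = torch.nn.Linear(10, 2)
base = torch.optim.SGD(model.parameters(), lr=0.1)
opt = Lookahead(base, alpha=0.5, k=6)  # copies lookahead_* defaults into every param group

for _ in range(12):
    base.zero_grad()  # param_groups are shared between wrapper and base optimizer
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    opt.step()  # inner optimizer step; slow weights sync every k-th step per group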
@@ -52,37 +55,36 @@ class Lookahead(Optimizer):
             (id(k) if isinstance(k, torch.Tensor) else k): v
             for k, v in self.state.items()
         }
-        fast_state = fast_state_dict["state"]
-        param_groups = fast_state_dict["param_groups"]
+        fast_state = fast_state_dict['state']
+        param_groups = fast_state_dict['param_groups']
         return {
-            "state": fast_state,
-            "slow_state": slow_state,
-            "param_groups": param_groups,
+            'state': fast_state,
+            'slow_state': slow_state,
+            'param_groups': param_groups,
         }
 
     def load_state_dict(self, state_dict):
-        if 'slow_state' not in state_dict:
-            print('Loading state_dict from optimizer without Lookahead applied')
-            state_dict['slow_state'] = defaultdict(dict)
-        slow_state_dict = {
-            "state": state_dict["slow_state"],
-            "param_groups": state_dict["param_groups"],
-        }
         fast_state_dict = {
-            "state": state_dict["state"],
-            "param_groups": state_dict["param_groups"],
+            'state': state_dict['state'],
+            'param_groups': state_dict['param_groups'],
         }
-        super(Lookahead, self).load_state_dict(slow_state_dict)
         self.base_optimizer.load_state_dict(fast_state_dict)
 
-    def add_param_group(self, param_group):
-        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
-        This can be useful when fine tuning a pre-trained network as frozen
-        layers can be made trainable and added to the :class:`Optimizer` as
-        training progresses.
-        Args:
-            param_group (dict): Specifies what Tensors should be optimized along
-            with group specific optimization options.
-        """
-        param_group['step_counter'] = 0
-        self.base_optimizer.add_param_group(param_group)
+        # We want to restore the slow state, but share param_groups reference
+        # with base_optimizer. This is a bit redundant but least code
+        slow_state_new = False
+        if 'slow_state' not in state_dict:
+            print('Loading state_dict from optimizer without Lookahead applied.')
+            state_dict['slow_state'] = defaultdict(dict)
+            slow_state_new = True
+        slow_state_dict = {
+            'state': state_dict['slow_state'],
+            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
+        }
+        super(Lookahead, self).load_state_dict(slow_state_dict)
+        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
+        if slow_state_new:
+            # reapply defaults to catch missing lookahead specific ones
+            for name, default in self.defaults.items():
+                for group in self.param_groups:
+                    group.setdefault(name, default)
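The second hunk reworks load_state_dict to restore the inner optimizer state first, then the slow buffers, and to re-share param_groups with the base optimizer; when the checkpoint comes from an optimizer without Lookahead, slow state is initialized empty and the lookahead defaults are re-applied. A round-trip sketch, continuing the assumed example above:

ckpt = opt.state_dict()  # keys: 'state', 'slow_state', 'param_groups'

base2 = torch.optim.SGD(model.parameters(), lr=0.1)
opt2 = Lookahead(base2, alpha=0.5, k=6)
opt2.load_state_dict(ckpt)  # restores inner-optimizer and slow-buffer state

# Loading a checkpoint saved from the bare SGD also works: 'slow_state' is
# missing, so it is created empty, a message is printed, and the lookahead_*
# defaults are re-applied to every param group.
opt3 = Lookahead(torch.optim.SGD(model.parameters(), lr=0.1), alpha=0.5, k=6)
opt3.load_state_dict(base.state_dict())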
timm/optim/nvnovograd.py (Normal file)