mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
commit
aff194f42c
@ -29,7 +29,7 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
|
|||||||
|
|
||||||
|
|
||||||
def resume_checkpoint(model, checkpoint_path):
|
def resume_checkpoint(model, checkpoint_path):
|
||||||
optimizer_state = None
|
other_state = {}
|
||||||
resume_epoch = None
|
resume_epoch = None
|
||||||
if os.path.isfile(checkpoint_path):
|
if os.path.isfile(checkpoint_path):
|
||||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||||
@ -40,7 +40,9 @@ def resume_checkpoint(model, checkpoint_path):
|
|||||||
new_state_dict[name] = v
|
new_state_dict[name] = v
|
||||||
model.load_state_dict(new_state_dict)
|
model.load_state_dict(new_state_dict)
|
||||||
if 'optimizer' in checkpoint:
|
if 'optimizer' in checkpoint:
|
||||||
optimizer_state = checkpoint['optimizer']
|
other_state['optimizer'] = checkpoint['optimizer']
|
||||||
|
if 'amp' in checkpoint:
|
||||||
|
other_state['amp'] = checkpoint['amp']
|
||||||
if 'epoch' in checkpoint:
|
if 'epoch' in checkpoint:
|
||||||
resume_epoch = checkpoint['epoch']
|
resume_epoch = checkpoint['epoch']
|
||||||
if 'version' in checkpoint and checkpoint['version'] > 1:
|
if 'version' in checkpoint and checkpoint['version'] > 1:
|
||||||
@ -49,7 +51,7 @@ def resume_checkpoint(model, checkpoint_path):
|
|||||||
else:
|
else:
|
||||||
model.load_state_dict(checkpoint)
|
model.load_state_dict(checkpoint)
|
||||||
logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
||||||
return optimizer_state, resume_epoch
|
return other_state, resume_epoch
|
||||||
else:
|
else:
|
||||||
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
|
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||||
raise FileNotFoundError()
|
raise FileNotFoundError()
|
||||||
|
@ -3,5 +3,6 @@ from .rmsprop_tf import RMSpropTF
|
|||||||
from .adamw import AdamW
|
from .adamw import AdamW
|
||||||
from .radam import RAdam
|
from .radam import RAdam
|
||||||
from .novograd import NovoGrad
|
from .novograd import NovoGrad
|
||||||
|
from .nvnovograd import NvNovoGrad
|
||||||
from .lookahead import Lookahead
|
from .lookahead import Lookahead
|
||||||
from .optim_factory import create_optimizer
|
from .optim_factory import create_optimizer
|
||||||
|
@ -13,37 +13,40 @@ class Lookahead(Optimizer):
|
|||||||
raise ValueError(f'Invalid slow update rate: {alpha}')
|
raise ValueError(f'Invalid slow update rate: {alpha}')
|
||||||
if not 1 <= k:
|
if not 1 <= k:
|
||||||
raise ValueError(f'Invalid lookahead steps: {k}')
|
raise ValueError(f'Invalid lookahead steps: {k}')
|
||||||
self.alpha = alpha
|
defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
|
||||||
self.k = k
|
|
||||||
self.base_optimizer = base_optimizer
|
self.base_optimizer = base_optimizer
|
||||||
self.param_groups = self.base_optimizer.param_groups
|
self.param_groups = self.base_optimizer.param_groups
|
||||||
self.defaults = base_optimizer.defaults
|
self.defaults = base_optimizer.defaults
|
||||||
|
self.defaults.update(defaults)
|
||||||
self.state = defaultdict(dict)
|
self.state = defaultdict(dict)
|
||||||
for group in self.param_groups:
|
# manually add our defaults to the param groups
|
||||||
group["step_counter"] = 0
|
for name, default in defaults.items():
|
||||||
|
for group in self.param_groups:
|
||||||
|
group.setdefault(name, default)
|
||||||
|
|
||||||
def update_slow_weights(self, group):
|
def update_slow(self, group):
|
||||||
for fast_p in group["params"]:
|
for fast_p in group["params"]:
|
||||||
if fast_p.grad is None:
|
if fast_p.grad is None:
|
||||||
continue
|
continue
|
||||||
param_state = self.state[fast_p]
|
param_state = self.state[fast_p]
|
||||||
if "slow_buffer" not in param_state:
|
if 'slow_buffer' not in param_state:
|
||||||
param_state["slow_buffer"] = torch.empty_like(fast_p.data)
|
param_state['slow_buffer'] = torch.empty_like(fast_p.data)
|
||||||
param_state["slow_buffer"].copy_(fast_p.data)
|
param_state['slow_buffer'].copy_(fast_p.data)
|
||||||
slow = param_state["slow_buffer"]
|
slow = param_state['slow_buffer']
|
||||||
slow.add_(self.alpha, fast_p.data - slow)
|
slow.add_(group['lookahead_alpha'], fast_p.data - slow)
|
||||||
fast_p.data.copy_(slow)
|
fast_p.data.copy_(slow)
|
||||||
|
|
||||||
def sync_lookahead(self):
|
def sync_lookahead(self):
|
||||||
for group in self.param_groups:
|
for group in self.param_groups:
|
||||||
self.update_slow_weights(group)
|
self.update_slow(group)
|
||||||
|
|
||||||
def step(self, closure=None):
|
def step(self, closure=None):
|
||||||
|
#assert id(self.param_groups) == id(self.base_optimizer.param_groups)
|
||||||
loss = self.base_optimizer.step(closure)
|
loss = self.base_optimizer.step(closure)
|
||||||
for group in self.param_groups:
|
for group in self.param_groups:
|
||||||
group['step_counter'] += 1
|
group['lookahead_step'] += 1
|
||||||
if group['step_counter'] % self.k == 0:
|
if group['lookahead_step'] % group['lookahead_k'] == 0:
|
||||||
self.update_slow_weights(group)
|
self.update_slow(group)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
@ -52,37 +55,36 @@ class Lookahead(Optimizer):
|
|||||||
(id(k) if isinstance(k, torch.Tensor) else k): v
|
(id(k) if isinstance(k, torch.Tensor) else k): v
|
||||||
for k, v in self.state.items()
|
for k, v in self.state.items()
|
||||||
}
|
}
|
||||||
fast_state = fast_state_dict["state"]
|
fast_state = fast_state_dict['state']
|
||||||
param_groups = fast_state_dict["param_groups"]
|
param_groups = fast_state_dict['param_groups']
|
||||||
return {
|
return {
|
||||||
"state": fast_state,
|
'state': fast_state,
|
||||||
"slow_state": slow_state,
|
'slow_state': slow_state,
|
||||||
"param_groups": param_groups,
|
'param_groups': param_groups,
|
||||||
}
|
}
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
if 'slow_state' not in state_dict:
|
|
||||||
print('Loading state_dict from optimizer without Lookahead applied')
|
|
||||||
state_dict['slow_state'] = defaultdict(dict)
|
|
||||||
slow_state_dict = {
|
|
||||||
"state": state_dict["slow_state"],
|
|
||||||
"param_groups": state_dict["param_groups"],
|
|
||||||
}
|
|
||||||
fast_state_dict = {
|
fast_state_dict = {
|
||||||
"state": state_dict["state"],
|
'state': state_dict['state'],
|
||||||
"param_groups": state_dict["param_groups"],
|
'param_groups': state_dict['param_groups'],
|
||||||
}
|
}
|
||||||
super(Lookahead, self).load_state_dict(slow_state_dict)
|
|
||||||
self.base_optimizer.load_state_dict(fast_state_dict)
|
self.base_optimizer.load_state_dict(fast_state_dict)
|
||||||
|
|
||||||
def add_param_group(self, param_group):
|
# We want to restore the slow state, but share param_groups reference
|
||||||
r"""Add a param group to the :class:`Optimizer` s `param_groups`.
|
# with base_optimizer. This is a bit redundant but least code
|
||||||
This can be useful when fine tuning a pre-trained network as frozen
|
slow_state_new = False
|
||||||
layers can be made trainable and added to the :class:`Optimizer` as
|
if 'slow_state' not in state_dict:
|
||||||
training progresses.
|
print('Loading state_dict from optimizer without Lookahead applied.')
|
||||||
Args:
|
state_dict['slow_state'] = defaultdict(dict)
|
||||||
param_group (dict): Specifies what Tensors should be optimized along
|
slow_state_new = True
|
||||||
with group specific optimization options.
|
slow_state_dict = {
|
||||||
"""
|
'state': state_dict['slow_state'],
|
||||||
param_group['step_counter'] = 0
|
'param_groups': state_dict['param_groups'], # this is pointless but saves code
|
||||||
self.base_optimizer.add_param_group(param_group)
|
}
|
||||||
|
super(Lookahead, self).load_state_dict(slow_state_dict)
|
||||||
|
self.param_groups = self.base_optimizer.param_groups # make both ref same container
|
||||||
|
if slow_state_new:
|
||||||
|
# reapply defaults to catch missing lookahead specific ones
|
||||||
|
for name, default in self.defaults.items():
|
||||||
|
for group in self.param_groups:
|
||||||
|
group.setdefault(name, default)
|
||||||
|
118
timm/optim/nvnovograd.py
Normal file
118
timm/optim/nvnovograd.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
""" Nvidia NovoGrad Optimizer.
|
||||||
|
Original impl by Nvidia from Jasper example:
|
||||||
|
- https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
|
||||||
|
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
|
||||||
|
- https://arxiv.org/abs/1905.11286
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.optim.optimizer import Optimizer
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class NvNovoGrad(Optimizer):
|
||||||
|
"""
|
||||||
|
Implements Novograd algorithm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params (iterable): iterable of parameters to optimize or dicts defining
|
||||||
|
parameter groups
|
||||||
|
lr (float, optional): learning rate (default: 1e-3)
|
||||||
|
betas (Tuple[float, float], optional): coefficients used for computing
|
||||||
|
running averages of gradient and its square (default: (0.95, 0.98))
|
||||||
|
eps (float, optional): term added to the denominator to improve
|
||||||
|
numerical stability (default: 1e-8)
|
||||||
|
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||||
|
grad_averaging: gradient averaging
|
||||||
|
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
||||||
|
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||||
|
(default: False)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
|
||||||
|
weight_decay=0, grad_averaging=False, amsgrad=False):
|
||||||
|
if not 0.0 <= lr:
|
||||||
|
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||||
|
if not 0.0 <= eps:
|
||||||
|
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||||
|
if not 0.0 <= betas[0] < 1.0:
|
||||||
|
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||||
|
if not 0.0 <= betas[1] < 1.0:
|
||||||
|
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||||
|
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
grad_averaging=grad_averaging,
|
||||||
|
amsgrad=amsgrad)
|
||||||
|
|
||||||
|
super(NvNovoGrad, self).__init__(params, defaults)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super(NvNovoGrad, self).__setstate__(state)
|
||||||
|
for group in self.param_groups:
|
||||||
|
group.setdefault('amsgrad', False)
|
||||||
|
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""Performs a single optimization step.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
closure (callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
"""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group['params']:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
grad = p.grad.data
|
||||||
|
if grad.is_sparse:
|
||||||
|
raise RuntimeError('Sparse gradients are not supported.')
|
||||||
|
amsgrad = group['amsgrad']
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
# State initialization
|
||||||
|
if len(state) == 0:
|
||||||
|
state['step'] = 0
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state['exp_avg'] = torch.zeros_like(p.data)
|
||||||
|
# Exponential moving average of squared gradient values
|
||||||
|
state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
|
||||||
|
if amsgrad:
|
||||||
|
# Maintains max of all exp. moving avg. of sq. grad. values
|
||||||
|
state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
|
||||||
|
|
||||||
|
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||||
|
if amsgrad:
|
||||||
|
max_exp_avg_sq = state['max_exp_avg_sq']
|
||||||
|
beta1, beta2 = group['betas']
|
||||||
|
|
||||||
|
state['step'] += 1
|
||||||
|
|
||||||
|
norm = torch.sum(torch.pow(grad, 2))
|
||||||
|
|
||||||
|
if exp_avg_sq == 0:
|
||||||
|
exp_avg_sq.copy_(norm)
|
||||||
|
else:
|
||||||
|
exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)
|
||||||
|
|
||||||
|
if amsgrad:
|
||||||
|
# Maintains the maximum of all 2nd moment running avg. till now
|
||||||
|
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
|
||||||
|
# Use the max. for normalizing running avg. of gradient
|
||||||
|
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
|
||||||
|
else:
|
||||||
|
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||||
|
|
||||||
|
grad.div_(denom)
|
||||||
|
if group['weight_decay'] != 0:
|
||||||
|
grad.add_(group['weight_decay'], p.data)
|
||||||
|
if group['grad_averaging']:
|
||||||
|
grad.mul_(1 - beta1)
|
||||||
|
exp_avg.mul_(beta1).add_(grad)
|
||||||
|
|
||||||
|
p.data.add_(-group['lr'], exp_avg)
|
||||||
|
|
||||||
|
return loss
|
@ -1,5 +1,11 @@
|
|||||||
|
import torch
|
||||||
from torch import optim as optim
|
from torch import optim as optim
|
||||||
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, Lookahead
|
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead
|
||||||
|
try:
|
||||||
|
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
|
||||||
|
has_apex = True
|
||||||
|
except ImportError:
|
||||||
|
has_apex = False
|
||||||
|
|
||||||
|
|
||||||
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
|
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
|
||||||
@ -20,9 +26,10 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
|
|||||||
def create_optimizer(args, model, filter_bias_and_bn=True):
|
def create_optimizer(args, model, filter_bias_and_bn=True):
|
||||||
opt_lower = args.opt.lower()
|
opt_lower = args.opt.lower()
|
||||||
weight_decay = args.weight_decay
|
weight_decay = args.weight_decay
|
||||||
if opt_lower == 'adamw' or opt_lower == 'radam':
|
if 'adamw' in opt_lower or 'radam' in opt_lower:
|
||||||
# compensate for the way current AdamW and RAdam optimizers
|
# Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
|
||||||
# apply the weight-decay
|
# I don't believe they follow the paper or original Torch7 impl which schedules weight
|
||||||
|
# decay based on the ratio of current_lr/initial_lr
|
||||||
weight_decay /= args.lr
|
weight_decay /= args.lr
|
||||||
if weight_decay and filter_bias_and_bn:
|
if weight_decay and filter_bias_and_bn:
|
||||||
parameters = add_weight_decay(model, weight_decay)
|
parameters = add_weight_decay(model, weight_decay)
|
||||||
@ -30,12 +37,14 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
|
|||||||
else:
|
else:
|
||||||
parameters = model.parameters()
|
parameters = model.parameters()
|
||||||
|
|
||||||
|
if 'fused' in opt_lower:
|
||||||
|
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
|
||||||
|
|
||||||
opt_split = opt_lower.split('_')
|
opt_split = opt_lower.split('_')
|
||||||
opt_lower = opt_split[-1]
|
opt_lower = opt_split[-1]
|
||||||
if opt_lower == 'sgd':
|
if opt_lower == 'sgd':
|
||||||
optimizer = optim.SGD(
|
optimizer = optim.SGD(
|
||||||
parameters, lr=args.lr,
|
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
|
||||||
momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
|
|
||||||
elif opt_lower == 'adam':
|
elif opt_lower == 'adam':
|
||||||
optimizer = optim.Adam(
|
optimizer = optim.Adam(
|
||||||
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||||
@ -61,6 +70,22 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
|
|||||||
momentum=args.momentum, weight_decay=weight_decay)
|
momentum=args.momentum, weight_decay=weight_decay)
|
||||||
elif opt_lower == 'novograd':
|
elif opt_lower == 'novograd':
|
||||||
optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||||
|
elif opt_lower == 'nvnovograd':
|
||||||
|
optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||||
|
elif opt_lower == 'fusedsgd':
|
||||||
|
optimizer = FusedSGD(
|
||||||
|
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
|
||||||
|
elif opt_lower == 'fusedadam':
|
||||||
|
optimizer = FusedAdam(
|
||||||
|
parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps)
|
||||||
|
elif opt_lower == 'fusedadamw':
|
||||||
|
optimizer = FusedAdam(
|
||||||
|
parameters, lr=args.lr, adam_w_mode=True, weight_decay=weight_decay, eps=args.opt_eps)
|
||||||
|
elif opt_lower == 'fusedlamb':
|
||||||
|
optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||||
|
elif opt_lower == 'fusednovograd':
|
||||||
|
optimizer = FusedNovoGrad(
|
||||||
|
parameters, lr=args.lr, betas=(0.95, 0.98), weight_decay=weight_decay, eps=args.opt_eps)
|
||||||
else:
|
else:
|
||||||
assert False and "Invalid optimizer"
|
assert False and "Invalid optimizer"
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
@ -11,6 +11,12 @@ import operator
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
try:
|
||||||
|
from apex import amp
|
||||||
|
has_apex = True
|
||||||
|
except ImportError:
|
||||||
|
amp = None
|
||||||
|
has_apex = False
|
||||||
|
|
||||||
from torch import distributed as dist
|
from torch import distributed as dist
|
||||||
|
|
||||||
@ -50,7 +56,7 @@ class CheckpointSaver:
|
|||||||
self.max_history = max_history
|
self.max_history = max_history
|
||||||
assert self.max_history >= 1
|
assert self.max_history >= 1
|
||||||
|
|
||||||
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
|
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
|
||||||
assert epoch >= 0
|
assert epoch >= 0
|
||||||
worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
|
worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
|
||||||
if (len(self.checkpoint_files) < self.max_history
|
if (len(self.checkpoint_files) < self.max_history
|
||||||
@ -59,7 +65,7 @@ class CheckpointSaver:
|
|||||||
self._cleanup_checkpoints(1)
|
self._cleanup_checkpoints(1)
|
||||||
filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
|
filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
|
||||||
save_path = os.path.join(self.checkpoint_dir, filename)
|
save_path = os.path.join(self.checkpoint_dir, filename)
|
||||||
self._save(save_path, model, optimizer, args, epoch, model_ema, metric)
|
self._save(save_path, model, optimizer, args, epoch, model_ema, metric, use_amp)
|
||||||
self.checkpoint_files.append((save_path, metric))
|
self.checkpoint_files.append((save_path, metric))
|
||||||
self.checkpoint_files = sorted(
|
self.checkpoint_files = sorted(
|
||||||
self.checkpoint_files, key=lambda x: x[1],
|
self.checkpoint_files, key=lambda x: x[1],
|
||||||
@ -77,7 +83,7 @@ class CheckpointSaver:
|
|||||||
|
|
||||||
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
|
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
|
||||||
|
|
||||||
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
|
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
|
||||||
save_state = {
|
save_state = {
|
||||||
'epoch': epoch,
|
'epoch': epoch,
|
||||||
'arch': args.model,
|
'arch': args.model,
|
||||||
@ -86,6 +92,8 @@ class CheckpointSaver:
|
|||||||
'args': args,
|
'args': args,
|
||||||
'version': 2, # version < 2 increments epoch before save
|
'version': 2, # version < 2 increments epoch before save
|
||||||
}
|
}
|
||||||
|
if use_amp and 'state_dict' in amp.__dict__:
|
||||||
|
save_state['amp'] = amp.state_dict()
|
||||||
if model_ema is not None:
|
if model_ema is not None:
|
||||||
save_state['state_dict_ema'] = get_state_dict(model_ema)
|
save_state['state_dict_ema'] = get_state_dict(model_ema)
|
||||||
if metric is not None:
|
if metric is not None:
|
||||||
@ -106,11 +114,11 @@ class CheckpointSaver:
|
|||||||
logging.error("Exception '{}' while deleting checkpoint".format(e))
|
logging.error("Exception '{}' while deleting checkpoint".format(e))
|
||||||
self.checkpoint_files = self.checkpoint_files[:delete_index]
|
self.checkpoint_files = self.checkpoint_files[:delete_index]
|
||||||
|
|
||||||
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
|
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0):
|
||||||
assert epoch >= 0
|
assert epoch >= 0
|
||||||
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
|
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
|
||||||
save_path = os.path.join(self.recovery_dir, filename)
|
save_path = os.path.join(self.recovery_dir, filename)
|
||||||
self._save(save_path, model, optimizer, args, epoch, model_ema)
|
self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp)
|
||||||
if os.path.exists(self.last_recovery_file):
|
if os.path.exists(self.last_recovery_file):
|
||||||
try:
|
try:
|
||||||
logging.debug("Cleaning recovery: {}".format(self.last_recovery_file))
|
logging.debug("Cleaning recovery: {}".format(self.last_recovery_file))
|
||||||
|
30
train.py
30
train.py
@ -38,6 +38,8 @@ parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH'
|
|||||||
help='Initialize model from this checkpoint (default: none)')
|
help='Initialize model from this checkpoint (default: none)')
|
||||||
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
||||||
help='Resume full model and optimizer state from checkpoint (default: none)')
|
help='Resume full model and optimizer state from checkpoint (default: none)')
|
||||||
|
parser.add_argument('--no-resume-opt', action='store_true', default=False,
|
||||||
|
help='prevent resume of optimizer state when resuming model')
|
||||||
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
|
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
|
||||||
help='number of label classes (default: 1000)')
|
help='number of label classes (default: 1000)')
|
||||||
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
|
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
|
||||||
@ -189,12 +191,6 @@ def main():
|
|||||||
|
|
||||||
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
||||||
|
|
||||||
# optionally resume from a checkpoint
|
|
||||||
optimizer_state = None
|
|
||||||
resume_epoch = None
|
|
||||||
if args.resume:
|
|
||||||
optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)
|
|
||||||
|
|
||||||
if args.num_gpu > 1:
|
if args.num_gpu > 1:
|
||||||
if args.amp:
|
if args.amp:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
@ -205,8 +201,6 @@ def main():
|
|||||||
model.cuda()
|
model.cuda()
|
||||||
|
|
||||||
optimizer = create_optimizer(args, model)
|
optimizer = create_optimizer(args, model)
|
||||||
if optimizer_state is not None:
|
|
||||||
optimizer.load_state_dict(optimizer_state)
|
|
||||||
|
|
||||||
use_amp = False
|
use_amp = False
|
||||||
if has_apex and args.amp:
|
if has_apex and args.amp:
|
||||||
@ -216,6 +210,22 @@ def main():
|
|||||||
logging.info('NVIDIA APEX {}. AMP {}.'.format(
|
logging.info('NVIDIA APEX {}. AMP {}.'.format(
|
||||||
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
|
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
|
||||||
|
|
||||||
|
# optionally resume from a checkpoint
|
||||||
|
resume_state = {}
|
||||||
|
resume_epoch = None
|
||||||
|
if args.resume:
|
||||||
|
resume_state, resume_epoch = resume_checkpoint(model, args.resume)
|
||||||
|
if resume_state and not args.no_resume_opt:
|
||||||
|
if 'optimizer' in resume_state:
|
||||||
|
if args.local_rank == 0:
|
||||||
|
logging.info('Restoring Optimizer state from checkpoint')
|
||||||
|
optimizer.load_state_dict(resume_state['optimizer'])
|
||||||
|
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
|
||||||
|
if args.local_rank == 0:
|
||||||
|
logging.info('Restoring NVIDIA AMP state from checkpoint')
|
||||||
|
amp.load_state_dict(resume_state['amp'])
|
||||||
|
resume_state = None
|
||||||
|
|
||||||
model_ema = None
|
model_ema = None
|
||||||
if args.model_ema:
|
if args.model_ema:
|
||||||
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
|
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
|
||||||
@ -363,7 +373,7 @@ def main():
|
|||||||
save_metric = eval_metrics[eval_metric]
|
save_metric = eval_metrics[eval_metric]
|
||||||
best_metric, best_epoch = saver.save_checkpoint(
|
best_metric, best_epoch = saver.save_checkpoint(
|
||||||
model, optimizer, args,
|
model, optimizer, args,
|
||||||
epoch=epoch, model_ema=model_ema, metric=save_metric)
|
epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
@ -456,7 +466,7 @@ def train_epoch(
|
|||||||
if saver is not None and args.recovery_interval and (
|
if saver is not None and args.recovery_interval and (
|
||||||
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
||||||
saver.save_recovery(
|
saver.save_recovery(
|
||||||
model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)
|
model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
|
||||||
|
|
||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user