mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Cleanup Apex vs native AMP scaler state save/load. Cleanup CheckpointSaver a bit.
This commit is contained in:
parent
80c9d9cc72
commit
9c297ec67d
@ -48,30 +48,41 @@ def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
|
|||||||
model.load_state_dict(state_dict, strict=strict)
|
model.load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
|
|
||||||
def resume_checkpoint(model, checkpoint_path):
|
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
|
||||||
other_state = {}
|
|
||||||
resume_epoch = None
|
resume_epoch = None
|
||||||
if os.path.isfile(checkpoint_path):
|
if os.path.isfile(checkpoint_path):
|
||||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||||
|
if log_info:
|
||||||
|
_logger.info('Restoring model state from checkpoint...')
|
||||||
new_state_dict = OrderedDict()
|
new_state_dict = OrderedDict()
|
||||||
for k, v in checkpoint['state_dict'].items():
|
for k, v in checkpoint['state_dict'].items():
|
||||||
name = k[7:] if k.startswith('module') else k
|
name = k[7:] if k.startswith('module') else k
|
||||||
new_state_dict[name] = v
|
new_state_dict[name] = v
|
||||||
model.load_state_dict(new_state_dict)
|
model.load_state_dict(new_state_dict)
|
||||||
if 'optimizer' in checkpoint:
|
|
||||||
other_state['optimizer'] = checkpoint['optimizer']
|
if optimizer is not None and 'optimizer' in checkpoint:
|
||||||
if 'amp' in checkpoint:
|
if log_info:
|
||||||
other_state['amp'] = checkpoint['amp']
|
_logger.info('Restoring optimizer state from checkpoint...')
|
||||||
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
|
|
||||||
|
if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
|
||||||
|
if log_info:
|
||||||
|
_logger.info('Restoring AMP loss scaler state from checkpoint...')
|
||||||
|
loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
|
||||||
|
|
||||||
if 'epoch' in checkpoint:
|
if 'epoch' in checkpoint:
|
||||||
resume_epoch = checkpoint['epoch']
|
resume_epoch = checkpoint['epoch']
|
||||||
if 'version' in checkpoint and checkpoint['version'] > 1:
|
if 'version' in checkpoint and checkpoint['version'] > 1:
|
||||||
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
|
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
|
||||||
|
|
||||||
|
if log_info:
|
||||||
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(checkpoint)
|
model.load_state_dict(checkpoint)
|
||||||
|
if log_info:
|
||||||
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
||||||
return other_state, resume_epoch
|
return resume_epoch
|
||||||
else:
|
else:
|
||||||
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||||
raise FileNotFoundError()
|
raise FileNotFoundError()
|
||||||
|
@ -37,20 +37,67 @@ def unwrap_model(model):
|
|||||||
return model.module if hasattr(model, 'module') else model
|
return model.module if hasattr(model, 'module') else model
|
||||||
|
|
||||||
|
|
||||||
def get_state_dict(model):
|
def get_state_dict(model, unwrap_fn=unwrap_model):
|
||||||
return unwrap_model(model).state_dict()
|
return unwrap_fn(model).state_dict()
|
||||||
|
|
||||||
|
|
||||||
|
class ApexScaler:
|
||||||
|
state_dict_key = "amp"
|
||||||
|
|
||||||
|
def __call__(self, loss, optimizer):
|
||||||
|
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||||
|
scaled_loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
if 'state_dict' in amp.__dict__:
|
||||||
|
return amp.state_dict()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
if 'load_state_dict' in amp.__dict__:
|
||||||
|
amp.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class NativeScaler:
|
||||||
|
state_dict_key = "amp_scaler"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
|
def __call__(self, loss, optimizer):
|
||||||
|
self._scaler.scale(loss).backward()
|
||||||
|
self._scaler.step(optimizer)
|
||||||
|
self._scaler.update()
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return self._scaler.state_dict()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
self._scaler.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
class CheckpointSaver:
|
class CheckpointSaver:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
args=None,
|
||||||
|
model_ema=None,
|
||||||
|
amp_scaler=None,
|
||||||
checkpoint_prefix='checkpoint',
|
checkpoint_prefix='checkpoint',
|
||||||
recovery_prefix='recovery',
|
recovery_prefix='recovery',
|
||||||
checkpoint_dir='',
|
checkpoint_dir='',
|
||||||
recovery_dir='',
|
recovery_dir='',
|
||||||
decreasing=False,
|
decreasing=False,
|
||||||
max_history=10,
|
max_history=10,
|
||||||
save_amp=False):
|
unwrap_fn=unwrap_model):
|
||||||
|
|
||||||
|
# objects to save state_dicts of
|
||||||
|
self.model = model
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.args = args
|
||||||
|
self.model_ema = model_ema
|
||||||
|
self.amp_scaler = amp_scaler
|
||||||
|
|
||||||
# state
|
# state
|
||||||
self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness
|
self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness
|
||||||
@ -68,14 +115,14 @@ class CheckpointSaver:
|
|||||||
self.decreasing = decreasing # a lower metric is better if True
|
self.decreasing = decreasing # a lower metric is better if True
|
||||||
self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
|
self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
|
||||||
self.max_history = max_history
|
self.max_history = max_history
|
||||||
self.save_apex_amp = save_amp # save APEX amp state
|
self.unwrap_fn = unwrap_fn
|
||||||
assert self.max_history >= 1
|
assert self.max_history >= 1
|
||||||
|
|
||||||
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
|
def save_checkpoint(self, epoch, metric=None):
|
||||||
assert epoch >= 0
|
assert epoch >= 0
|
||||||
tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
|
tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
|
||||||
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
|
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
|
||||||
self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric)
|
self._save(tmp_save_path, epoch, metric)
|
||||||
if os.path.exists(last_save_path):
|
if os.path.exists(last_save_path):
|
||||||
os.unlink(last_save_path) # required for Windows support.
|
os.unlink(last_save_path) # required for Windows support.
|
||||||
os.rename(tmp_save_path, last_save_path)
|
os.rename(tmp_save_path, last_save_path)
|
||||||
@ -107,19 +154,21 @@ class CheckpointSaver:
|
|||||||
|
|
||||||
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
|
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
|
||||||
|
|
||||||
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
|
def _save(self, save_path, epoch, metric=None):
|
||||||
save_state = {
|
save_state = {
|
||||||
'epoch': epoch,
|
'epoch': epoch,
|
||||||
'arch': args.model,
|
'arch': type(self.model).__name__.lower(),
|
||||||
'state_dict': get_state_dict(model),
|
'state_dict': get_state_dict(self.model, self.unwrap_fn),
|
||||||
'optimizer': optimizer.state_dict(),
|
'optimizer': self.optimizer.state_dict(),
|
||||||
'args': args,
|
|
||||||
'version': 2, # version < 2 increments epoch before save
|
'version': 2, # version < 2 increments epoch before save
|
||||||
}
|
}
|
||||||
if self.save_apex_amp and 'state_dict' in amp.__dict__:
|
if self.args is not None:
|
||||||
save_state['amp'] = amp.state_dict()
|
save_state['arch'] = self.args.model
|
||||||
if model_ema is not None:
|
save_state['args'] = self.args
|
||||||
save_state['state_dict_ema'] = get_state_dict(model_ema)
|
if self.amp_scaler is not None:
|
||||||
|
save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
|
||||||
|
if self.model_ema is not None:
|
||||||
|
save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
|
||||||
if metric is not None:
|
if metric is not None:
|
||||||
save_state['metric'] = metric
|
save_state['metric'] = metric
|
||||||
torch.save(save_state, save_path)
|
torch.save(save_state, save_path)
|
||||||
@ -138,11 +187,11 @@ class CheckpointSaver:
|
|||||||
_logger.error("Exception '{}' while deleting checkpoint".format(e))
|
_logger.error("Exception '{}' while deleting checkpoint".format(e))
|
||||||
self.checkpoint_files = self.checkpoint_files[:delete_index]
|
self.checkpoint_files = self.checkpoint_files[:delete_index]
|
||||||
|
|
||||||
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
|
def save_recovery(self, epoch, batch_idx=0):
|
||||||
assert epoch >= 0
|
assert epoch >= 0
|
||||||
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
|
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
|
||||||
save_path = os.path.join(self.recovery_dir, filename)
|
save_path = os.path.join(self.recovery_dir, filename)
|
||||||
self._save(save_path, model, optimizer, args, epoch, model_ema)
|
self._save(save_path, epoch)
|
||||||
if os.path.exists(self.last_recovery_file):
|
if os.path.exists(self.last_recovery_file):
|
||||||
try:
|
try:
|
||||||
_logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
|
_logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
|
||||||
@ -336,3 +385,16 @@ def add_bool_arg(parser, name, default=False, help=''):
|
|||||||
group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
|
group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
|
||||||
group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
|
group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
|
||||||
parser.set_defaults(**{dest_name: default})
|
parser.set_defaults(**{dest_name: default})
|
||||||
|
|
||||||
|
|
||||||
|
def set_jit_legacy():
|
||||||
|
""" Set JIT executor to legacy w/ support for op fusion
|
||||||
|
This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
|
||||||
|
in the JIT exectutor. These API are not supported so could change.
|
||||||
|
"""
|
||||||
|
#
|
||||||
|
assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
|
||||||
|
torch._C._jit_set_profiling_executor(False)
|
||||||
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||||
|
#torch._C._jit_set_texpr_fuser_enabled(True)
|
||||||
|
46
train.py
46
train.py
@ -20,7 +20,6 @@ import yaml
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision.utils
|
import torchvision.utils
|
||||||
from torch.nn.parallel import DistributedDataParallel as NativeDDP
|
from torch.nn.parallel import DistributedDataParallel as NativeDDP
|
||||||
@ -31,6 +30,7 @@ from timm.utils import *
|
|||||||
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
|
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
|
||||||
from timm.optim import create_optimizer
|
from timm.optim import create_optimizer
|
||||||
from timm.scheduler import create_scheduler
|
from timm.scheduler import create_scheduler
|
||||||
|
from timm.utils import ApexScaler, NativeScaler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
@ -264,23 +264,6 @@ def _parse_args():
|
|||||||
return args, args_text
|
return args, args_text
|
||||||
|
|
||||||
|
|
||||||
class ApexScaler:
|
|
||||||
def __call__(self, loss, optimizer):
|
|
||||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
|
||||||
scaled_loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
|
|
||||||
class NativeScaler:
|
|
||||||
def __init__(self):
|
|
||||||
self._scaler = torch.cuda.amp.GradScaler()
|
|
||||||
|
|
||||||
def __call__(self, loss, optimizer):
|
|
||||||
self._scaler.scale(loss).backward()
|
|
||||||
self._scaler.step(optimizer)
|
|
||||||
self._scaler.update()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
setup_default_logging()
|
setup_default_logging()
|
||||||
args, args_text = _parse_args()
|
args, args_text = _parse_args()
|
||||||
@ -389,20 +372,13 @@ def main():
|
|||||||
_logger.info('AMP not enabled. Training in float32.')
|
_logger.info('AMP not enabled. Training in float32.')
|
||||||
|
|
||||||
# optionally resume from a checkpoint
|
# optionally resume from a checkpoint
|
||||||
resume_state = {}
|
|
||||||
resume_epoch = None
|
resume_epoch = None
|
||||||
if args.resume:
|
if args.resume:
|
||||||
resume_state, resume_epoch = resume_checkpoint(model, args.resume)
|
resume_epoch = resume_checkpoint(
|
||||||
if resume_state and not args.no_resume_opt:
|
model, args.resume,
|
||||||
if 'optimizer' in resume_state:
|
optimizer=None if args.no_resume_opt else optimizer,
|
||||||
if args.local_rank == 0:
|
loss_scaler=None if args.no_resume_opt else loss_scaler,
|
||||||
_logger.info('Restoring optimizer state from checkpoint')
|
log_info=args.local_rank == 0)
|
||||||
optimizer.load_state_dict(resume_state['optimizer'])
|
|
||||||
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
|
|
||||||
if args.local_rank == 0:
|
|
||||||
_logger.info('Restoring NVIDIA AMP state from checkpoint')
|
|
||||||
amp.load_state_dict(resume_state['amp'])
|
|
||||||
del resume_state
|
|
||||||
|
|
||||||
model_ema = None
|
model_ema = None
|
||||||
if args.model_ema:
|
if args.model_ema:
|
||||||
@ -555,7 +531,9 @@ def main():
|
|||||||
])
|
])
|
||||||
output_dir = get_outdir(output_base, 'train', exp_name)
|
output_dir = get_outdir(output_base, 'train', exp_name)
|
||||||
decreasing = True if eval_metric == 'loss' else False
|
decreasing = True if eval_metric == 'loss' else False
|
||||||
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing, save_amp=use_amp == 'apex')
|
saver = CheckpointSaver(
|
||||||
|
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
|
||||||
|
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)
|
||||||
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
|
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
|
||||||
f.write(args_text)
|
f.write(args_text)
|
||||||
|
|
||||||
@ -594,8 +572,7 @@ def main():
|
|||||||
if saver is not None:
|
if saver is not None:
|
||||||
# save proper checkpoint with eval metric
|
# save proper checkpoint with eval metric
|
||||||
save_metric = eval_metrics[eval_metric]
|
save_metric = eval_metrics[eval_metric]
|
||||||
best_metric, best_epoch = saver.save_checkpoint(
|
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
|
||||||
model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
@ -688,8 +665,7 @@ def train_epoch(
|
|||||||
|
|
||||||
if saver is not None and args.recovery_interval and (
|
if saver is not None and args.recovery_interval and (
|
||||||
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
||||||
|
saver.save_recovery(epoch, batch_idx=batch_idx)
|
||||||
saver.save_recovery(model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)
|
|
||||||
|
|
||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
||||||
|
15
validate.py
15
validate.py
@ -21,7 +21,7 @@ from contextlib import suppress
|
|||||||
|
|
||||||
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
||||||
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
|
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
|
||||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
|
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
|
||||||
|
|
||||||
has_apex = False
|
has_apex = False
|
||||||
try:
|
try:
|
||||||
@ -102,19 +102,6 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
|
|||||||
help='Valid label indices txt file for validation of partial label space')
|
help='Valid label indices txt file for validation of partial label space')
|
||||||
|
|
||||||
|
|
||||||
def set_jit_legacy():
|
|
||||||
""" Set JIT executor to legacy w/ support for op fusion
|
|
||||||
This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
|
|
||||||
in the JIT exectutor. These API are not supported so could change.
|
|
||||||
"""
|
|
||||||
#
|
|
||||||
assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
|
|
||||||
torch._C._jit_set_profiling_executor(False)
|
|
||||||
torch._C._jit_set_profiling_mode(False)
|
|
||||||
torch._C._jit_override_can_fuse_on_gpu(True)
|
|
||||||
#torch._C._jit_set_texpr_fuser_enabled(True)
|
|
||||||
|
|
||||||
|
|
||||||
def validate(args):
|
def validate(args):
|
||||||
# might as well try to validate something
|
# might as well try to validate something
|
||||||
args.pretrained = args.pretrained or not args.checkpoint
|
args.pretrained = args.pretrained or not args.checkpoint
|
||||||
|
Loading…
x
Reference in New Issue
Block a user