Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
With this update one can tune the kind of logs generated by timm; training and inference traces are unchanged.
318 lines · 12 KiB · Python
from copy import deepcopy

import torch
import math
import os
import re
import shutil
import glob
import csv
import operator
import logging
import numpy as np
from collections import OrderedDict

try:
    from apex import amp
    has_apex = True
except ImportError:
    amp = None
    has_apex = False

from torch import distributed as dist

logger = logging.getLogger(__name__)


def unwrap_model(model):
    if isinstance(model, ModelEma):
        return unwrap_model(model.ema)
    else:
        return model.module if hasattr(model, 'module') else model


def get_state_dict(model):
    return unwrap_model(model).state_dict()


class CheckpointSaver:
    def __init__(
            self,
            checkpoint_prefix='checkpoint',
            recovery_prefix='recovery',
            checkpoint_dir='',
            recovery_dir='',
            decreasing=False,
            max_history=10):

        # state
        self.checkpoint_files = []  # (filename, metric) tuples in order of decreasing betterness
        self.best_epoch = None
        self.best_metric = None
        self.curr_recovery_file = ''
        self.last_recovery_file = ''

        # config
        self.checkpoint_dir = checkpoint_dir
        self.recovery_dir = recovery_dir
        self.save_prefix = checkpoint_prefix
        self.recovery_prefix = recovery_prefix
        self.extension = '.pth.tar'
        self.decreasing = decreasing  # a lower metric is better if True
        self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
        self.max_history = max_history
        assert self.max_history >= 1

    def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
        assert epoch >= 0
        tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
        last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
        self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric, use_amp)
        if os.path.exists(last_save_path):
            os.unlink(last_save_path)  # required for Windows support.
        os.rename(tmp_save_path, last_save_path)
        worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
        if (len(self.checkpoint_files) < self.max_history
                or metric is None or self.cmp(metric, worst_file[1])):
            if len(self.checkpoint_files) >= self.max_history:
                self._cleanup_checkpoints(1)
            filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
            save_path = os.path.join(self.checkpoint_dir, filename)
            os.link(last_save_path, save_path)
            self.checkpoint_files.append((save_path, metric))
            self.checkpoint_files = sorted(
                self.checkpoint_files, key=lambda x: x[1],
                reverse=not self.decreasing)  # sort in descending order if a lower metric is not better

            checkpoints_str = "Current checkpoints:\n"
            for c in self.checkpoint_files:
                checkpoints_str += ' {}\n'.format(c)
            logger.info(checkpoints_str)

            if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
                self.best_epoch = epoch
                self.best_metric = metric
                best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
                if os.path.exists(best_save_path):
                    os.unlink(best_save_path)
                os.link(last_save_path, best_save_path)

        return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

    def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
        save_state = {
            'epoch': epoch,
            'arch': args.model,
            'state_dict': get_state_dict(model),
            'optimizer': optimizer.state_dict(),
            'args': args,
            'version': 2,  # version < 2 increments epoch before save
        }
        if use_amp and 'state_dict' in amp.__dict__:
            save_state['amp'] = amp.state_dict()
        if model_ema is not None:
            save_state['state_dict_ema'] = get_state_dict(model_ema)
        if metric is not None:
            save_state['metric'] = metric
        torch.save(save_state, save_path)

    def _cleanup_checkpoints(self, trim=0):
        trim = min(len(self.checkpoint_files), trim)
        delete_index = self.max_history - trim
        if delete_index <= 0 or len(self.checkpoint_files) <= delete_index:
            return
        to_delete = self.checkpoint_files[delete_index:]
        for d in to_delete:
            try:
                logger.debug("Cleaning checkpoint: {}".format(d))
                os.remove(d[0])
            except Exception as e:
                logger.error("Exception '{}' while deleting checkpoint".format(e))
        self.checkpoint_files = self.checkpoint_files[:delete_index]

    def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0):
        assert epoch >= 0
        filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
        save_path = os.path.join(self.recovery_dir, filename)
        self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp)
        if os.path.exists(self.last_recovery_file):
            try:
                logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
                os.remove(self.last_recovery_file)
            except Exception as e:
                logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
        self.last_recovery_file = self.curr_recovery_file
        self.curr_recovery_file = save_path

    def find_recovery(self):
        recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
        files = glob.glob(recovery_path + '*' + self.extension)
        files = sorted(files)
        if len(files):
            return files[0]
        else:
            return ''


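# Example (hedged sketch, not part of the original file): typical use of CheckpointSaver in a
# training loop. `train_one_epoch`, `validate`, `args.epochs`, and the 'top1' metric key are
# hypothetical stand-ins for the caller's own code; only the CheckpointSaver calls reflect the
# API defined above.
def _example_checkpoint_saver_usage(model, optimizer, args, train_one_epoch, validate):
    saver = CheckpointSaver(checkpoint_dir='./output', decreasing=False, max_history=5)
    for epoch in range(args.epochs):
        train_one_epoch(model, optimizer)
        eval_metrics = validate(model)  # assumed to return a dict with a 'top1' entry
        saver.save_checkpoint(model, optimizer, args, epoch, metric=eval_metrics['top1'])
    return saver.best_metric, saver.best_epoch

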
class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


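# Example (hedged sketch): tracking a running loss with AverageMeter. Each update is weighted by
# the batch size `n`, so `avg` is the per-sample mean over everything seen so far. `loader`,
# `model`, and `loss_fn` are hypothetical placeholders.
def _example_average_meter_usage(loader, model, loss_fn):
    losses = AverageMeter()
    for input, target in loader:
        loss = loss_fn(model(input), target)
        losses.update(loss.item(), input.size(0))
    return losses.avg

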
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].view(-1).float().sum(0) * 100. / batch_size for k in topk]


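# Example (hedged sketch): computing top-1/top-5 accuracy on a random batch. Shapes follow the
# convention `accuracy` assumes: logits of shape (batch, num_classes) and integer class targets.
def _example_accuracy_usage():
    output = torch.randn(8, 1000)          # fake logits: 8 samples, 1000 classes
    target = torch.randint(0, 1000, (8,))  # fake integer labels
    top1, top5 = accuracy(output, target, topk=(1, 5))
    return top1.item(), top5.item()

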
def get_outdir(path, *paths, inc=False):
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    elif inc:
        count = 1
        outdir_inc = outdir + '-' + str(count)
        while os.path.exists(outdir_inc):
            count = count + 1
            outdir_inc = outdir + '-' + str(count)
            assert count < 100
        outdir = outdir_inc
        os.makedirs(outdir)
    return outdir


def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
    rowd = OrderedDict(epoch=epoch)
    rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
    with open(filename, mode='a') as cf:
        dw = csv.DictWriter(cf, fieldnames=rowd.keys())
        if write_header:  # first iteration (epoch == 1 can't be used)
            dw.writeheader()
        dw.writerow(rowd)


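# Example (hedged sketch): appending one row of train/eval metrics per epoch to a CSV summary.
# The metric values and output directory are hypothetical; the header is only written when
# `write_header` is True, typically on the first epoch.
def _example_update_summary_usage(output_dir):
    train_metrics = {'loss': 1.23}
    eval_metrics = {'loss': 1.10, 'top1': 71.5}
    update_summary(
        epoch=0, train_metrics=train_metrics, eval_metrics=eval_metrics,
        filename=os.path.join(output_dir, 'summary.csv'), write_header=True)

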
def natural_key(string_):
    """See http://www.codinghorror.com/blog/archives/001018.html"""
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def reduce_tensor(tensor, n):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt


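# Example (hedged sketch): averaging a per-rank scalar (e.g. a batch loss) across all processes.
# Only meaningful after torch.distributed has been initialized; `world_size` would normally come
# from the launcher or the training args.
def _example_reduce_tensor_usage(loss, world_size):
    reduced_loss = reduce_tensor(loss.data, world_size)
    return reduced_loss.item()

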
def distribute_bn(model, world_size, reduce=False):
    # ensure every node has the same running bn stats
    for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
        if ('running_mean' in bn_name) or ('running_var' in bn_name):
            if reduce:
                # average bn stats across whole group
                torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                bn_buf /= float(world_size)
            else:
                # broadcast bn stats from rank 0 to whole group
                torch.distributed.broadcast(bn_buf, 0)


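# Example (hedged sketch): synchronizing BatchNorm running stats across ranks once per epoch.
# `dist_bn` mimics a command-line choice between reducing (averaging) and broadcasting from
# rank 0; the surrounding condition is illustrative only.
def _example_distribute_bn_usage(model, world_size, dist_bn='reduce'):
    if world_size > 1 and dist_bn in ('broadcast', 'reduce'):
        distribute_bn(model, world_size, reduce=(dist_bn == 'reduce'))

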
class ModelEma:
    """ Model Exponential Moving Average
    Keep a moving average of everything in the model state_dict (parameters and buffers).

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes to perform well.
    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc., which use
    RMSprop with a short 2.4-3 epoch decay period and a slow LR decay rate of .96-.99, require EMA
    smoothing of weights to match results. Pay attention to the decay constant you are using
    relative to your update count per epoch.

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
    """
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # ema model may have been wrapped by DataParallel, and need module prefix
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            logger.info("Loaded state_dict_ema")
        else:
            logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)


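# Example (hedged sketch): keeping an EMA copy alongside the live model. The EMA is updated after
# each optimizer step, and evaluation would then run against `model_ema.ema` (the smoothed copy).
# The training-step internals below are hypothetical placeholders.
def _example_model_ema_usage(model, optimizer, loader, loss_fn):
    model_ema = ModelEma(model, decay=0.9998)
    for input, target in loader:
        loss = loss_fn(model(input), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model_ema.update(model)  # EMA tracks the just-updated weights
    return model_ema.ema         # evaluate this smoothed copy

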
class FormatterNoInfo(logging.Formatter):
    def __init__(self, fmt='%(levelname)s: %(message)s'):
        logging.Formatter.__init__(self, fmt)

    def format(self, record):
        if record.levelno == logging.INFO:
            return str(record.getMessage())
        return logging.Formatter.format(self, record)


def setup_default_logging(default_level=logging.INFO):
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(FormatterNoInfo())
    logging.root.addHandler(console_handler)
    logging.root.setLevel(default_level)
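

# Example (hedged sketch): this is the hook the synced commit message refers to for tuning the
# kind of logs timm emits. Passing a different `default_level` (e.g. logging.DEBUG to surface the
# checkpoint/recovery cleanup messages above, or logging.WARNING to silence INFO output) changes
# log verbosity only; training and inference behaviour is unaffected.
def _example_logging_setup(verbose=False):
    setup_default_logging(default_level=logging.DEBUG if verbose else logging.INFO)
    logger.info("INFO records are printed without a level prefix (see FormatterNoInfo)")
    logger.debug("DEBUG records are only emitted when verbose=True")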