Add exponential moving average for model weights + few other additions and cleanup

* ModelEma class added to track an EMA set of weights for the model being trained
* EMA handling added to train, validation and clean_checkpoint scripts
* Add multi checkpoint or multi-model validation support to validate.py
* Add syncbn option (APEX) to train script for experimentation
* Cleanup interface of CheckpointSaver while adding ema functionality
This commit is contained in:
Ross Wightman 2019-06-07 15:39:36 -07:00
parent ff99625603
commit 9bcd65181b
6 changed files with 258 additions and 84 deletions

View File

@ -9,6 +9,8 @@ parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)') help='path to latest checkpoint (default: none)')
parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH', parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH',
help='output path') help='output path')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
def main(): def main():
@ -24,8 +26,13 @@ def main():
checkpoint = torch.load(args.checkpoint, map_location='cpu') checkpoint = torch.load(args.checkpoint, map_location='cpu')
new_state_dict = OrderedDict() new_state_dict = OrderedDict()
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if isinstance(checkpoint, dict):
state_dict = checkpoint['state_dict'] state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict'
if state_dict_key in checkpoint:
state_dict = checkpoint[state_dict_key]
else:
print("Error: No state_dict found in checkpoint {}.".format(args.checkpoint))
exit(1)
else: else:
state_dict = checkpoint state_dict = checkpoint
for k, v in state_dict.items(): for k, v in state_dict.items():

View File

@ -4,22 +4,24 @@ import os
from collections import OrderedDict from collections import OrderedDict
def load_checkpoint(model, checkpoint_path): def load_checkpoint(model, checkpoint_path, use_ema=False):
if checkpoint_path and os.path.isfile(checkpoint_path): if checkpoint_path and os.path.isfile(checkpoint_path):
print("=> Loading checkpoint '{}'".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: state_dict_key = ''
if isinstance(checkpoint, dict):
state_dict_key = 'state_dict'
if use_ema and 'state_dict_ema' in checkpoint:
state_dict_key = 'state_dict_ema'
if state_dict_key and state_dict_key in checkpoint:
new_state_dict = OrderedDict() new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items(): for k, v in checkpoint[state_dict_key].items():
if k.startswith('module'): # strip `module.` prefix
name = k[7:] # remove `module.` name = k[7:] if k.startswith('module') else k
else:
name = k
new_state_dict[name] = v new_state_dict[name] = v
model.load_state_dict(new_state_dict) model.load_state_dict(new_state_dict)
else: else:
model.load_state_dict(checkpoint) model.load_state_dict(checkpoint)
print("=> Loaded checkpoint '{}'".format(checkpoint_path)) print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
else: else:
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError() raise FileNotFoundError()
@ -28,27 +30,24 @@ def load_checkpoint(model, checkpoint_path):
def resume_checkpoint(model, checkpoint_path, start_epoch=None): def resume_checkpoint(model, checkpoint_path, start_epoch=None):
optimizer_state = None optimizer_state = None
if os.path.isfile(checkpoint_path): if os.path.isfile(checkpoint_path):
print("=> loading checkpoint '{}'".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
new_state_dict = OrderedDict() new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items(): for k, v in checkpoint['state_dict'].items():
if k.startswith('module'): name = k[7:] if k.startswith('module') else k
name = k[7:] # remove `module.`
else:
name = k
new_state_dict[name] = v new_state_dict[name] = v
model.load_state_dict(new_state_dict) model.load_state_dict(new_state_dict)
if 'optimizer' in checkpoint: if 'optimizer' in checkpoint:
optimizer_state = checkpoint['optimizer'] optimizer_state = checkpoint['optimizer']
print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else: else:
model.load_state_dict(checkpoint) model.load_state_dict(checkpoint)
start_epoch = 0 if start_epoch is None else start_epoch start_epoch = 0 if start_epoch is None else start_epoch
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
return optimizer_state, start_epoch return optimizer_state, start_epoch
else: else:
print("=> No checkpoint found at '{}'".format(checkpoint_path)) print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError() raise FileNotFoundError()

View File

@ -89,7 +89,7 @@ class RMSpropTF(Optimizer):
state['step'] += 1 state['step'] += 1
if group['weight_decay'] != 0: if group['weight_decay'] != 0:
if group['decoupled_decay']: if 'decoupled_decay' in group and group['decoupled_decay']:
p.data.add_(-group['weight_decay'], p.data) p.data.add_(-group['weight_decay'], p.data)
else: else:
grad = grad.add(group['weight_decay'], p.data) grad = grad.add(group['weight_decay'], p.data)
@ -109,7 +109,7 @@ class RMSpropTF(Optimizer):
if group['momentum'] > 0: if group['momentum'] > 0:
buf = state['momentum_buffer'] buf = state['momentum_buffer']
# Tensorflow accumulates the LR scaling in the momentum buffer # Tensorflow accumulates the LR scaling in the momentum buffer
if group['lr_in_momentum']: if 'lr_in_momentum' in group and group['lr_in_momentum']:
buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg) buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
p.data.add_(-buf) p.data.add_(-buf)
else: else:

104
train.py
View File

@ -6,12 +6,13 @@ from datetime import datetime
try: try:
from apex import amp from apex import amp
from apex.parallel import DistributedDataParallel as DDP from apex.parallel import DistributedDataParallel as DDP
from apex.parallel import convert_syncbn_model
has_apex = True has_apex = True
except ImportError: except ImportError:
has_apex = False has_apex = False
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
from models import create_model, resume_checkpoint from models import create_model, resume_checkpoint, load_checkpoint
from utils import * from utils import *
from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from optim import create_optimizer from optim import create_optimizer
@ -91,11 +92,17 @@ parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)') help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None, parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)') help='BatchNorm epsilon override (if not None)')
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
parser.add_argument('--seed', type=int, default=42, metavar='S', parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)') help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N', parser.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status') help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N', parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint') help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 1)') help='how many training processes to use (default: 1)')
@ -109,6 +116,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging') help='save images of input bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False, parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA amp for mixed precision training') help='use NVIDIA amp for mixed precision training')
parser.add_argument('--sync-bn', action='store_true',
help='enabling apex sync BN.')
parser.add_argument('--no-prefetcher', action='store_true', default=False, parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher') help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH', parser.add_argument('--output', default='', type=str, metavar='PATH',
@ -131,31 +140,28 @@ def main():
args.device = 'cuda:0' args.device = 'cuda:0'
args.world_size = 1 args.world_size = 1
r = -1 args.rank = 0 # global rank
if args.distributed: if args.distributed:
args.num_gpu = 1 args.num_gpu = 1
args.device = 'cuda:%d' % args.local_rank args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank) torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', torch.distributed.init_process_group(
init_method='env://') backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size() args.world_size = torch.distributed.get_world_size()
r = torch.distributed.get_rank() args.rank = torch.distributed.get_rank()
assert args.rank >= 0
if args.distributed: if args.distributed:
print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (r, args.world_size)) % (args.rank, args.world_size))
else: else:
print('Training with a single process on %d GPUs.' % args.num_gpu) print('Training with a single process on %d GPUs.' % args.num_gpu)
# FIXME seed handling for multi-process distributed? torch.manual_seed(args.seed + args.rank)
torch.manual_seed(args.seed)
output_dir = '' output_dir = ''
if args.local_rank == 0: if args.local_rank == 0:
if args.output: output_base = args.output if args.output else './output'
output_base = args.output
else:
output_base = './output'
exp_name = '-'.join([ exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"), datetime.now().strftime("%Y%m%d-%H%M%S"),
args.model, args.model,
@ -191,6 +197,8 @@ def main():
args.amp = False args.amp = False
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else: else:
if args.distributed and args.sync_bn and has_apex:
model = convert_syncbn_model(model)
model.cuda() model.cuda()
optimizer = create_optimizer(args, model) optimizer = create_optimizer(args, model)
@ -205,8 +213,20 @@ def main():
use_amp = False use_amp = False
print('AMP disabled') print('AMP disabled')
model_ema = None
if args.model_ema:
model_ema = ModelEma(
model,
decay=args.model_ema_decay,
device='cpu' if args.model_ema_force_cpu else '',
resume=args.resume)
if args.distributed: if args.distributed:
model = DDP(model, delay_allreduce=True) model = DDP(model, delay_allreduce=True)
if model_ema is not None and not args.model_ema_force_cpu:
# must also distribute EMA model to allow validation
model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
model_ema.ema_has_module = True
lr_scheduler, num_epochs = create_scheduler(args, optimizer) lr_scheduler, num_epochs = create_scheduler(args, optimizer)
if start_epoch > 0: if start_epoch > 0:
@ -273,6 +293,7 @@ def main():
eval_metric = args.eval_metric eval_metric = args.eval_metric
saver = None saver = None
if output_dir: if output_dir:
# only set if process is rank 0
decreasing = True if eval_metric == 'loss' else False decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
best_metric = None best_metric = None
@ -284,10 +305,15 @@ def main():
train_metrics = train_epoch( train_metrics = train_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args, epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp) lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
use_amp=use_amp, model_ema=model_ema)
eval_metrics = validate( eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
model, loader_eval, validate_loss_fn, args)
if model_ema is not None and not args.model_ema_force_cpu:
ema_eval_metrics = validate(
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
eval_metrics = ema_eval_metrics
if lr_scheduler is not None: if lr_scheduler is not None:
lr_scheduler.step(epoch, eval_metrics[eval_metric]) lr_scheduler.step(epoch, eval_metrics[eval_metric])
@ -298,15 +324,12 @@ def main():
if saver is not None: if saver is not None:
# save proper checkpoint with eval metric # save proper checkpoint with eval metric
best_metric, best_epoch = saver.save_checkpoint({ save_metric = eval_metrics[eval_metric]
'epoch': epoch + 1, best_metric, best_epoch = saver.save_checkpoint(
'arch': args.model, model, optimizer, args,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'args': args,
},
epoch=epoch + 1, epoch=epoch + 1,
metric=eval_metrics[eval_metric]) model_ema=model_ema,
metric=save_metric)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
@ -316,7 +339,7 @@ def main():
def train_epoch( def train_epoch(
epoch, model, loader, optimizer, loss_fn, args, epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir='', use_amp=False): lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
if args.prefetcher and args.mixup > 0 and loader.mixup_enabled: if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
@ -359,6 +382,8 @@ def train_epoch(
optimizer.step() optimizer.step()
torch.cuda.synchronize() torch.cuda.synchronize()
if model_ema is not None:
model_ema.update(model)
num_updates += 1 num_updates += 1
batch_time_m.update(time.time() - end) batch_time_m.update(time.time() - end)
@ -394,18 +419,11 @@ def train_epoch(
padding=0, padding=0,
normalize=True) normalize=True)
if args.local_rank == 0 and ( if saver is not None and args.recovery_interval and (
saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0): last_batch or (batch_idx + 1) % args.recovery_interval == 0):
save_epoch = epoch + 1 if last_batch else epoch save_epoch = epoch + 1 if last_batch else epoch
saver.save_recovery({ saver.save_recovery(
'epoch': save_epoch, model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)
'arch': args.model,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'args': args,
},
epoch=save_epoch,
batch_idx=batch_idx)
if lr_scheduler is not None: if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@ -415,7 +433,7 @@ def train_epoch(
return OrderedDict([('loss', losses_m.avg)]) return OrderedDict([('loss', losses_m.avg)])
def validate(model, loader, loss_fn, args): def validate(model, loader, loss_fn, args, log_suffix=''):
batch_time_m = AverageMeter() batch_time_m = AverageMeter()
losses_m = AverageMeter() losses_m = AverageMeter()
prec1_m = AverageMeter() prec1_m = AverageMeter()
@ -461,12 +479,13 @@ def validate(model, loader, loss_fn, args):
batch_time_m.update(time.time() - end) batch_time_m.update(time.time() - end)
end = time.time() end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
print('Test: [{0}/{1}]\t' log_name = 'Test' + log_suffix
print('{0}: [{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss {loss.val:.4f} ({loss.avg:.4f}) ' 'Loss {loss.val:.4f} ({loss.avg:.4f}) '
'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) ' 'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format( 'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
batch_idx, last_idx, log_name, batch_idx, last_idx,
batch_time=batch_time_m, loss=losses_m, batch_time=batch_time_m, loss=losses_m,
top1=prec1_m, top5=prec5_m)) top1=prec1_m, top5=prec5_m))
@ -475,12 +494,5 @@ def validate(model, loader, loss_fn, args):
return metrics return metrics
def reduce_tensor(tensor, n):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= n
return rt
if __name__ == '__main__': if __name__ == '__main__':
main() main()

120
utils.py
View File

@ -1,6 +1,9 @@
from copy import deepcopy
import torch import torch
import math import math
import os import os
import re
import shutil import shutil
import glob import glob
import csv import csv
@ -8,6 +11,15 @@ import operator
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict
from torch import distributed as dist
def get_state_dict(model):
if isinstance(model, ModelEma):
return get_state_dict(model.ema)
else:
return model.module.state_dict() if getattr(model, 'module') else model.state_dict()
class CheckpointSaver: class CheckpointSaver:
def __init__( def __init__(
@ -39,17 +51,16 @@ class CheckpointSaver:
self.max_history = max_history self.max_history = max_history
assert self.max_history >= 1 assert self.max_history >= 1
def save_checkpoint(self, state, epoch, metric=None): def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
assert epoch >= 0
worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
if len(self.checkpoint_files) < self.max_history or self.cmp(metric, worst_file[1]): if (len(self.checkpoint_files) < self.max_history
or metric is None or self.cmp(metric, worst_file[1])):
if len(self.checkpoint_files) >= self.max_history: if len(self.checkpoint_files) >= self.max_history:
self._cleanup_checkpoints(1) self._cleanup_checkpoints(1)
filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
save_path = os.path.join(self.checkpoint_dir, filename) save_path = os.path.join(self.checkpoint_dir, filename)
if metric is not None: self._save(save_path, model, optimizer, args, epoch, model_ema, metric)
state['metric'] = metric
torch.save(state, save_path)
self.checkpoint_files.append((save_path, metric)) self.checkpoint_files.append((save_path, metric))
self.checkpoint_files = sorted( self.checkpoint_files = sorted(
self.checkpoint_files, key=lambda x: x[1], self.checkpoint_files, key=lambda x: x[1],
@ -67,6 +78,20 @@ class CheckpointSaver:
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
save_state = {
'epoch': epoch,
'arch': args.model,
'state_dict': get_state_dict(model),
'optimizer': optimizer.state_dict(),
'args': args
}
if model_ema is not None:
save_state['state_dict_ema'] = get_state_dict(model_ema)
if metric is not None:
save_state['metric'] = metric
torch.save(save_state, save_path)
def _cleanup_checkpoints(self, trim=0): def _cleanup_checkpoints(self, trim=0):
trim = min(len(self.checkpoint_files), trim) trim = min(len(self.checkpoint_files), trim)
delete_index = self.max_history - trim delete_index = self.max_history - trim
@ -82,10 +107,11 @@ class CheckpointSaver:
print('Exception (%s) while deleting checkpoint' % str(e)) print('Exception (%s) while deleting checkpoint' % str(e))
self.checkpoint_files = self.checkpoint_files[:delete_index] self.checkpoint_files = self.checkpoint_files[:delete_index]
def save_recovery(self, state, epoch, batch_idx): def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
assert epoch >= 0
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
save_path = os.path.join(self.recovery_dir, filename) save_path = os.path.join(self.recovery_dir, filename)
torch.save(state, save_path) self._save(save_path, model, optimizer, args, epoch, model_ema)
if os.path.exists(self.last_recovery_file): if os.path.exists(self.last_recovery_file):
try: try:
if self.verbose: if self.verbose:
@ -165,3 +191,81 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa
if write_header: # first iteration (epoch == 1 can't be used) if write_header: # first iteration (epoch == 1 can't be used)
dw.writeheader() dw.writeheader()
dw.writerow(rowd) dw.writerow(rowd)
def natural_key(string_):
"""See http://www.codinghorror.com/blog/archives/001018.html"""
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def reduce_tensor(tensor, n):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= n
return rt
class ModelEma:
""" Model Exponential Moving Average
Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
smoothing of weights to match results. Pay attention to the decay constant you are using
relative to your update count per epoch.
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
disable validation of the EMA weights. Validation will have to be done manually in a separate
process, or after the training stops converging.
This class is sensitive where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
"""
def __init__(self, model, decay=0.9999, device='', resume=''):
# make a copy of the model for accumulating moving average of weights
self.ema = deepcopy(model)
self.ema.eval()
self.decay = decay
self.device = device # perform ema on different device from model if set
if device:
self.ema.to(device=device)
self.ema_has_module = hasattr(self.ema, 'module')
if resume:
self._load_checkpoint(resume)
for p in self.ema.parameters():
p.requires_grad_(False)
def _load_checkpoint(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path)
assert isinstance(checkpoint, dict)
if 'state_dict_ema' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict_ema'].items():
# ema model may have been wrapped by DataParallel, and need module prefix
if self.ema_has_module:
name = 'module.' + k if not k.startswith('module') else k
else:
name = k
new_state_dict[name] = v
self.ema.load_state_dict(new_state_dict)
print("=> loaded state_dict_ema")
else:
print("=> failed to find state_dict_ema, starting from loaded model weights)")
def update(self, model):
# correct a mismatch in state dict keys
needs_module = hasattr(model, 'module') and not self.ema_has_module
with torch.no_grad():
msd = model.state_dict()
for k, ema_v in self.ema.state_dict().items():
if needs_module:
k = 'module.' + k
model_v = msd[k].detach()
if self.device:
model_v = model_v.to(device=self.device)
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

View File

@ -4,14 +4,17 @@ from __future__ import print_function
import argparse import argparse
import os import os
import csv
import glob
import time import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.parallel import torch.nn.parallel
from collections import OrderedDict
from models import create_model, apply_test_time_pool from models import create_model, apply_test_time_pool, load_checkpoint
from data import Dataset, create_loader, resolve_data_config from data import Dataset, create_loader, resolve_data_config
from utils import accuracy, AverageMeter from utils import accuracy, AverageMeter, natural_key
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
@ -46,21 +49,26 @@ parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
help='disable test time pool') help='disable test time pool')
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
help='Use Tensorflow preprocessing pipeline (require CPU TF installed') help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
def main(): def validate(args):
args = parser.parse_args()
# create model # create model
model = create_model( model = create_model(
args.model, args.model,
num_classes=args.num_classes, num_classes=args.num_classes,
in_chans=3, in_chans=3,
pretrained=args.pretrained, pretrained=args.pretrained)
checkpoint_path=args.checkpoint)
print('Model %s created, param count: %d' % if args.checkpoint and not args.pretrained:
(args.model, sum([m.numel() for m in model.parameters()]))) load_checkpoint(model, args.checkpoint, args.use_ema)
else:
args.pretrained = True # might as well try to validate something...
param_count = sum([m.numel() for m in model.parameters()])
print('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(model, args) data_config = resolve_data_config(model, args)
model, test_time_pool = apply_test_time_pool(model, data_config, args) model, test_time_pool = apply_test_time_pool(model, data_config, args)
@ -120,8 +128,52 @@ def main():
rate_avg=input.size(0) / batch_time.avg, rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5)) loss=losses, top1=top1, top5=top5))
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( results = OrderedDict(
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
param_count=round(param_count / 1e6, 2))
print(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
return results
def main():
args = parser.parse_args()
if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
# FIXME just an example list, need to add model name collections for
# batch testing of various pretrained combinations by arg string
models = ['tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3']
model_cfgs = [(n, '') for n in models]
elif os.path.isdir(args.checkpoint):
# validate all checkpoints in a path with same model
checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
checkpoints += glob.glob(args.checkpoint + '/*.pth')
model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
else:
model_cfgs = []
if len(model_cfgs):
header_written = False
with open('./results-all.csv', mode='w') as cf:
for m, c in model_cfgs:
args.model = m
args.checkpoint = c
result = OrderedDict(model=args.model)
result.update(validate(args))
if args.checkpoint:
result['checkpoint'] = args.checkpoint
dw = csv.DictWriter(cf, fieldnames=result.keys())
if not header_written:
dw.writeheader()
header_written = True
dw.writerow(result)
cf.flush()
else:
validate(args)
if __name__ == '__main__': if __name__ == '__main__':