Mirror of https://github.com/huggingface/pytorch-image-models.git
Add exponential moving average for model weights + few other additions and cleanup
* ModelEma class added to track an EMA set of weights for the model being trained
* EMA handling added to train, validation and clean_checkpoint scripts
* Add multi checkpoint or multi-model validation support to validate.py
* Add syncbn option (APEX) to train script for experimentation
* Cleanup interface of CheckpointSaver while adding ema functionality
This commit is contained in:
parent ff99625603
commit 9bcd65181b
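Note on the headline change: an EMA of the model weights keeps a shadow copy of the state dict and, after every optimizer step, nudges each shadow tensor toward the live weights by a decay factor (0.9998 by default via --model-ema-decay). A minimal sketch of that update rule, simplified from the ModelEma class added to utils.py below (the helper name ema_update is illustrative, not part of the commit):

import torch

@torch.no_grad()
def ema_update(ema_state, model_state, decay=0.9998):
    # ema <- decay * ema + (1 - decay) * current weights, tensor by tensor;
    # assumes both state dicts hold tensors on the same device
    for k, ema_v in ema_state.items():
        ema_v.copy_(ema_v * decay + (1. - decay) * model_state[k].detach())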
@@ -9,6 +9,8 @@ parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH',
                     help='output path')
+parser.add_argument('--use-ema', dest='use_ema', action='store_true',
+                    help='use ema version of weights if present')


 def main():
@@ -24,8 +26,13 @@ def main():
     checkpoint = torch.load(args.checkpoint, map_location='cpu')

     new_state_dict = OrderedDict()
-    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-        state_dict = checkpoint['state_dict']
+    if isinstance(checkpoint, dict):
+        state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict'
+        if state_dict_key in checkpoint:
+            state_dict = checkpoint[state_dict_key]
+        else:
+            print("Error: No state_dict found in checkpoint {}.".format(args.checkpoint))
+            exit(1)
     else:
         state_dict = checkpoint
     for k, v in state_dict.items():
@@ -4,22 +4,24 @@ import os
 from collections import OrderedDict


-def load_checkpoint(model, checkpoint_path):
+def load_checkpoint(model, checkpoint_path, use_ema=False):
     if checkpoint_path and os.path.isfile(checkpoint_path):
         print("=> Loading checkpoint '{}'".format(checkpoint_path))
         checkpoint = torch.load(checkpoint_path)
-        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+        state_dict_key = ''
+        if isinstance(checkpoint, dict):
+            state_dict_key = 'state_dict'
+            if use_ema and 'state_dict_ema' in checkpoint:
+                state_dict_key = 'state_dict_ema'
+        if state_dict_key and state_dict_key in checkpoint:
             new_state_dict = OrderedDict()
-            for k, v in checkpoint['state_dict'].items():
-                if k.startswith('module'):
-                    name = k[7:] # remove `module.`
-                else:
-                    name = k
+            for k, v in checkpoint[state_dict_key].items():
+                # strip `module.` prefix
+                name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
         else:
             model.load_state_dict(checkpoint)
-        print("=> Loaded checkpoint '{}'".format(checkpoint_path))
+        print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
     else:
         print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
@@ -28,27 +30,24 @@ def load_checkpoint(model, checkpoint_path):
 def resume_checkpoint(model, checkpoint_path, start_epoch=None):
     optimizer_state = None
     if os.path.isfile(checkpoint_path):
         print("=> loading checkpoint '{}'".format(checkpoint_path))
         checkpoint = torch.load(checkpoint_path)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             new_state_dict = OrderedDict()
             for k, v in checkpoint['state_dict'].items():
-                if k.startswith('module'):
-                    name = k[7:] # remove `module.`
-                else:
-                    name = k
+                name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
             if 'optimizer' in checkpoint:
                 optimizer_state = checkpoint['optimizer']
-            print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
             start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
+            print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
         else:
             model.load_state_dict(checkpoint)
             start_epoch = 0 if start_epoch is None else start_epoch
             print("=> Loaded checkpoint '{}'".format(checkpoint_path))
         return optimizer_state, start_epoch
     else:
-        print("=> No checkpoint found at '{}'".format(checkpoint_path))
+        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
@@ -89,7 +89,7 @@ class RMSpropTF(Optimizer):
                 state['step'] += 1

                 if group['weight_decay'] != 0:
-                    if group['decoupled_decay']:
+                    if 'decoupled_decay' in group and group['decoupled_decay']:
                         p.data.add_(-group['weight_decay'], p.data)
                     else:
                         grad = grad.add(group['weight_decay'], p.data)
@@ -109,7 +109,7 @@ class RMSpropTF(Optimizer):
                 if group['momentum'] > 0:
                     buf = state['momentum_buffer']
                     # Tensorflow accumulates the LR scaling in the momentum buffer
-                    if group['lr_in_momentum']:
+                    if 'lr_in_momentum' in group and group['lr_in_momentum']:
                         buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
                         p.data.add_(-buf)
                     else:
train.py
@@ -6,12 +6,13 @@ from datetime import datetime
 try:
     from apex import amp
     from apex.parallel import DistributedDataParallel as DDP
+    from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
     has_apex = False

 from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
-from models import create_model, resume_checkpoint
+from models import create_model, resume_checkpoint, load_checkpoint
 from utils import *
 from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from optim import create_optimizer
@@ -91,11 +92,17 @@ parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+parser.add_argument('--model-ema', action='store_true', default=False,
+                    help='Enable tracking moving average of model weights')
+parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
+                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
+parser.add_argument('--model-ema-decay', type=float, default=0.9998,
+                    help='decay factor for model weights moving average (default: 0.9998)')
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
-parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
+parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
@@ -109,6 +116,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
+parser.add_argument('--sync-bn', action='store_true',
+                    help='enabling apex sync BN.')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -131,31 +140,28 @@ def main():

     args.device = 'cuda:0'
     args.world_size = 1
-    r = -1
+    args.rank = 0  # global rank
     if args.distributed:
         args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl',
-                                             init_method='env://')
+        torch.distributed.init_process_group(
+            backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
-        r = torch.distributed.get_rank()
+        args.rank = torch.distributed.get_rank()
+        assert args.rank >= 0

     if args.distributed:
         print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-              % (r, args.world_size))
+              % (args.rank, args.world_size))
     else:
         print('Training with a single process on %d GPUs.' % args.num_gpu)

     # FIXME seed handling for multi-process distributed?
-    torch.manual_seed(args.seed)
+    torch.manual_seed(args.seed + args.rank)

     output_dir = ''
     if args.local_rank == 0:
-        if args.output:
-            output_base = args.output
-        else:
-            output_base = './output'
+        output_base = args.output if args.output else './output'
         exp_name = '-'.join([
             datetime.now().strftime("%Y%m%d-%H%M%S"),
             args.model,
@@ -191,6 +197,8 @@ def main():
         args.amp = False
         model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
+        if args.distributed and args.sync_bn and has_apex:
+            model = convert_syncbn_model(model)
         model.cuda()

     optimizer = create_optimizer(args, model)
@@ -205,8 +213,20 @@ def main():
         use_amp = False
         print('AMP disabled')

+    model_ema = None
+    if args.model_ema:
+        model_ema = ModelEma(
+            model,
+            decay=args.model_ema_decay,
+            device='cpu' if args.model_ema_force_cpu else '',
+            resume=args.resume)
+
     if args.distributed:
         model = DDP(model, delay_allreduce=True)
+        if model_ema is not None and not args.model_ema_force_cpu:
+            # must also distribute EMA model to allow validation
+            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
+            model_ema.ema_has_module = True

     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     if start_epoch > 0:
@@ -273,6 +293,7 @@ def main():
     eval_metric = args.eval_metric
     saver = None
     if output_dir:
+        # only set if process is rank 0
         decreasing = True if eval_metric == 'loss' else False
         saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
     best_metric = None
@@ -284,10 +305,15 @@ def main():

             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
+                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                use_amp=use_amp, model_ema=model_ema)

-            eval_metrics = validate(
-                model, loader_eval, validate_loss_fn, args)
+            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
+
+            if model_ema is not None and not args.model_ema_force_cpu:
+                ema_eval_metrics = validate(
+                    model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
+                eval_metrics = ema_eval_metrics

             if lr_scheduler is not None:
                 lr_scheduler.step(epoch, eval_metrics[eval_metric])
@@ -298,15 +324,12 @@ def main():

             if saver is not None:
                 # save proper checkpoint with eval metric
-                best_metric, best_epoch = saver.save_checkpoint({
-                    'epoch': epoch + 1,
-                    'arch': args.model,
-                    'state_dict': model.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'args': args,
-                    },
+                save_metric = eval_metrics[eval_metric]
+                best_metric, best_epoch = saver.save_checkpoint(
+                    model, optimizer, args,
                     epoch=epoch + 1,
-                    metric=eval_metrics[eval_metric])
+                    model_ema=model_ema,
+                    metric=save_metric)

     except KeyboardInterrupt:
         pass
@@ -316,7 +339,7 @@ def main():

 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):

     if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
         if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
@@ -359,6 +382,8 @@ def train_epoch(
             optimizer.step()

         torch.cuda.synchronize()
+        if model_ema is not None:
+            model_ema.update(model)
         num_updates += 1

         batch_time_m.update(time.time() - end)
@@ -394,18 +419,11 @@ def train_epoch(
                     padding=0,
                     normalize=True)

-        if args.local_rank == 0 and (
-                saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+        if saver is not None and args.recovery_interval and (
+                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             save_epoch = epoch + 1 if last_batch else epoch
-            saver.save_recovery({
-                'epoch': save_epoch,
-                'arch': args.model,
-                'state_dict': model.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'args': args,
-                },
-                epoch=save_epoch,
-                batch_idx=batch_idx)
+            saver.save_recovery(
+                model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)

         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@@ -415,7 +433,7 @@ def train_epoch(
     return OrderedDict([('loss', losses_m.avg)])


-def validate(model, loader, loss_fn, args):
+def validate(model, loader, loss_fn, args, log_suffix=''):
     batch_time_m = AverageMeter()
     losses_m = AverageMeter()
     prec1_m = AverageMeter()
@@ -461,12 +479,13 @@ def validate(model, loader, loss_fn, args):
             batch_time_m.update(time.time() - end)
             end = time.time()
             if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
-                print('Test: [{0}/{1}]\t'
+                log_name = 'Test' + log_suffix
+                print('{0}: [{1}/{2}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                       'Loss {loss.val:.4f} ({loss.avg:.4f}) '
                       'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
                       'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
-                          batch_idx, last_idx,
+                          log_name, batch_idx, last_idx,
                           batch_time=batch_time_m, loss=losses_m,
                           top1=prec1_m, top5=prec5_m))

@@ -475,12 +494,5 @@ def validate(model, loader, loss_fn, args):
     return metrics


-def reduce_tensor(tensor, n):
-    rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    rt /= n
-    return rt
-
-
 if __name__ == '__main__':
     main()
utils.py
@@ -1,6 +1,9 @@
+from copy import deepcopy
+
 import torch
 import math
 import os
 import re
 import shutil
 import glob
 import csv
@@ -8,6 +11,15 @@ import operator
 import numpy as np
 from collections import OrderedDict

+from torch import distributed as dist
+
+
+def get_state_dict(model):
+    if isinstance(model, ModelEma):
+        return get_state_dict(model.ema)
+    else:
+        return model.module.state_dict() if getattr(model, 'module') else model.state_dict()
+

 class CheckpointSaver:
     def __init__(
@@ -39,17 +51,16 @@ class CheckpointSaver:
         self.max_history = max_history
         assert self.max_history >= 1

-    def save_checkpoint(self, state, epoch, metric=None):
+    def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
         assert epoch >= 0
         worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
-        if len(self.checkpoint_files) < self.max_history or self.cmp(metric, worst_file[1]):
+        if (len(self.checkpoint_files) < self.max_history
+                or metric is None or self.cmp(metric, worst_file[1])):
             if len(self.checkpoint_files) >= self.max_history:
                 self._cleanup_checkpoints(1)

             filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
             save_path = os.path.join(self.checkpoint_dir, filename)
-            if metric is not None:
-                state['metric'] = metric
-            torch.save(state, save_path)
+            self._save(save_path, model, optimizer, args, epoch, model_ema, metric)
             self.checkpoint_files.append((save_path, metric))
             self.checkpoint_files = sorted(
                 self.checkpoint_files, key=lambda x: x[1],
@@ -67,6 +78,20 @@ class CheckpointSaver:

         return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

+    def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
+        save_state = {
+            'epoch': epoch,
+            'arch': args.model,
+            'state_dict': get_state_dict(model),
+            'optimizer': optimizer.state_dict(),
+            'args': args
+        }
+        if model_ema is not None:
+            save_state['state_dict_ema'] = get_state_dict(model_ema)
+        if metric is not None:
+            save_state['metric'] = metric
+        torch.save(save_state, save_path)
+
     def _cleanup_checkpoints(self, trim=0):
         trim = min(len(self.checkpoint_files), trim)
         delete_index = self.max_history - trim
@@ -82,10 +107,11 @@ class CheckpointSaver:
                 print('Exception (%s) while deleting checkpoint' % str(e))
         self.checkpoint_files = self.checkpoint_files[:delete_index]

-    def save_recovery(self, state, epoch, batch_idx):
+    def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
         assert epoch >= 0
         filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
         save_path = os.path.join(self.recovery_dir, filename)
-        torch.save(state, save_path)
+        self._save(save_path, model, optimizer, args, epoch, model_ema)
         if os.path.exists(self.last_recovery_file):
             try:
                 if self.verbose:
@@ -165,3 +191,81 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa
     if write_header:  # first iteration (epoch == 1 can't be used)
         dw.writeheader()
     dw.writerow(rowd)
+
+
+def natural_key(string_):
+    """See http://www.codinghorror.com/blog/archives/001018.html"""
+    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
+
+
+def reduce_tensor(tensor, n):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+    rt /= n
+    return rt
+
+
+class ModelEma:
+    """ Model Exponential Moving Average
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
+    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
+    smoothing of weights to match results. Pay attention to the decay constant you are using
+    relative to your update count per epoch.
+
+    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+    disable validation of the EMA weights. Validation will have to be done manually in a separate
+    process, or after the training stops converging.
+
+    This class is sensitive where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
+    """
+    def __init__(self, model, decay=0.9999, device='', resume=''):
+        # make a copy of the model for accumulating moving average of weights
+        self.ema = deepcopy(model)
+        self.ema.eval()
+        self.decay = decay
+        self.device = device  # perform ema on different device from model if set
+        if device:
+            self.ema.to(device=device)
+        self.ema_has_module = hasattr(self.ema, 'module')
+        if resume:
+            self._load_checkpoint(resume)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def _load_checkpoint(self, checkpoint_path):
+        checkpoint = torch.load(checkpoint_path)
+        assert isinstance(checkpoint, dict)
+        if 'state_dict_ema' in checkpoint:
+            new_state_dict = OrderedDict()
+            for k, v in checkpoint['state_dict_ema'].items():
+                # ema model may have been wrapped by DataParallel, and need module prefix
+                if self.ema_has_module:
+                    name = 'module.' + k if not k.startswith('module') else k
+                else:
+                    name = k
+                new_state_dict[name] = v
+            self.ema.load_state_dict(new_state_dict)
+            print("=> loaded state_dict_ema")
+        else:
+            print("=> failed to find state_dict_ema, starting from loaded model weights)")
+
+    def update(self, model):
+        # correct a mismatch in state dict keys
+        needs_module = hasattr(model, 'module') and not self.ema_has_module
+        with torch.no_grad():
+            msd = model.state_dict()
+            for k, ema_v in self.ema.state_dict().items():
+                if needs_module:
+                    k = 'module.' + k
+                model_v = msd[k].detach()
+                if self.device:
+                    model_v = model_v.to(device=self.device)
+                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
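Taken together with the train.py changes above, the intended flow is roughly the sketch below (illustrative only; the real wiring lives in main()/train_epoch()/validate() above, and names such as loader_train, loader_eval, train_loss_fn, validate_loss_fn, args are assumed to exist as in train.py):

model_ema = ModelEma(model, decay=args.model_ema_decay)

for input, target in loader_train:
    output = model(input)
    loss = train_loss_fn(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model_ema.update(model)  # track the EMA of the weights after every update

# validate the smoothed copy; when EMA is enabled, its metrics drive the
# LR scheduler and checkpoint selection in main()
ema_metrics = validate(model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')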
validate.py
@@ -4,14 +4,17 @@ from __future__ import print_function

 import argparse
 import os
+import csv
+import glob
 import time
 import torch
 import torch.nn as nn
 import torch.nn.parallel
+from collections import OrderedDict

-from models import create_model, apply_test_time_pool
+from models import create_model, apply_test_time_pool, load_checkpoint
 from data import Dataset, create_loader, resolve_data_config
-from utils import accuracy, AverageMeter
+from utils import accuracy, AverageMeter, natural_key

 torch.backends.cudnn.benchmark = True

@@ -46,21 +49,26 @@ parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
                     help='disable test time pool')
 parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
+parser.add_argument('--use-ema', dest='use_ema', action='store_true',
+                    help='use ema version of weights if present')


-def main():
-    args = parser.parse_args()
+def validate(args):

     # create model
     model = create_model(
         args.model,
         num_classes=args.num_classes,
         in_chans=3,
-        pretrained=args.pretrained,
-        checkpoint_path=args.checkpoint)
+        pretrained=args.pretrained)

-    print('Model %s created, param count: %d' %
-          (args.model, sum([m.numel() for m in model.parameters()])))
+    if args.checkpoint and not args.pretrained:
+        load_checkpoint(model, args.checkpoint, args.use_ema)
+    else:
+        args.pretrained = True  # might as well try to validate something...
+
+    param_count = sum([m.numel() for m in model.parameters()])
+    print('Model %s created, param count: %d' % (args.model, param_count))

     data_config = resolve_data_config(model, args)
     model, test_time_pool = apply_test_time_pool(model, data_config, args)
@@ -120,8 +128,52 @@ def main():
                 rate_avg=input.size(0) / batch_time.avg,
                 loss=losses, top1=top1, top5=top5))

-    print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
-        top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
+    results = OrderedDict(
+        top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
+        top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
+        param_count=round(param_count / 1e6, 2))
+
+    print(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
+        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
+
+    return results
+
+
+def main():
+    args = parser.parse_args()
+    if args.model == 'all':
+        # validate all models in a list of names with pretrained checkpoints
+        args.pretrained = True
+        # FIXME just an example list, need to add model name collections for
+        # batch testing of various pretrained combinations by arg string
+        models = ['tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3']
+        model_cfgs = [(n, '') for n in models]
+    elif os.path.isdir(args.checkpoint):
+        # validate all checkpoints in a path with same model
+        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
+        checkpoints += glob.glob(args.checkpoint + '/*.pth')
+        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
+    else:
+        model_cfgs = []
+
+    if len(model_cfgs):
+        header_written = False
+        with open('./results-all.csv', mode='w') as cf:
+            for m, c in model_cfgs:
+                args.model = m
+                args.checkpoint = c
+                result = OrderedDict(model=args.model)
+                result.update(validate(args))
+                if args.checkpoint:
+                    result['checkpoint'] = args.checkpoint
+                dw = csv.DictWriter(cf, fieldnames=result.keys())
+                if not header_written:
+                    dw.writeheader()
+                    header_written = True
+                dw.writerow(result)
+                cf.flush()
+    else:
+        validate(args)


 if __name__ == '__main__':