mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add back old ModelEma and rename new one to ModelEmaV2 to avoid compat breaks in dependant code. Shuffle train script, add a few comments, remove DataParallel support, support experimental torchscript training.
This commit is contained in:
parent
9214ca0716
commit
27bbc70d71
@ -6,5 +6,5 @@ from .log import setup_default_logging, FormatterNoInfo
|
||||
from .metrics import AverageMeter, accuracy
|
||||
from .misc import natural_key, add_bool_arg
|
||||
from .model import unwrap_model, get_state_dict
|
||||
from .model_ema import ModelEma
|
||||
from .model_ema import ModelEma, ModelEmaV2
|
||||
from .summary import update_summary, get_outdir
|
||||
|
@ -6,7 +6,10 @@ from .model_ema import ModelEma
|
||||
|
||||
|
||||
def unwrap_model(model):
|
||||
return model.module if hasattr(model, 'module') else model
|
||||
if isinstance(model, ModelEma):
|
||||
return unwrap_model(model.ema)
|
||||
else:
|
||||
return model.module if hasattr(model, 'module') else model
|
||||
|
||||
|
||||
def get_state_dict(model, unwrap_fn=unwrap_model):
|
||||
|
@ -2,15 +2,89 @@
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelEma:
|
||||
""" Model Exponential Moving Average (DEPRECATED)
|
||||
|
||||
class ModelEma(nn.Module):
|
||||
""" Model Exponential Moving Average
|
||||
Keep a moving average of everything in the model state_dict (parameters and buffers).
|
||||
This version is deprecated, it does not work with scripted models. Will be removed eventually.
|
||||
|
||||
This is intended to allow functionality like
|
||||
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||
|
||||
A smoothed version of the weights is necessary for some training schemes to perform well.
|
||||
E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
|
||||
RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
|
||||
smoothing of weights to match results. Pay attention to the decay constant you are using
|
||||
relative to your update count per epoch.
|
||||
|
||||
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
|
||||
disable validation of the EMA weights. Validation will have to be done manually in a separate
|
||||
process, or after the training stops converging.
|
||||
|
||||
This class is sensitive where it is initialized in the sequence of model init,
|
||||
GPU assignment and distributed training wrappers.
|
||||
"""
|
||||
def __init__(self, model, decay=0.9999, device='', resume=''):
|
||||
# make a copy of the model for accumulating moving average of weights
|
||||
self.ema = deepcopy(model)
|
||||
self.ema.eval()
|
||||
self.decay = decay
|
||||
self.device = device # perform ema on different device from model if set
|
||||
if device:
|
||||
self.ema.to(device=device)
|
||||
self.ema_has_module = hasattr(self.ema, 'module')
|
||||
if resume:
|
||||
self._load_checkpoint(resume)
|
||||
for p in self.ema.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
def _load_checkpoint(self, checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
assert isinstance(checkpoint, dict)
|
||||
if 'state_dict_ema' in checkpoint:
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in checkpoint['state_dict_ema'].items():
|
||||
# ema model may have been wrapped by DataParallel, and need module prefix
|
||||
if self.ema_has_module:
|
||||
name = 'module.' + k if not k.startswith('module') else k
|
||||
else:
|
||||
name = k
|
||||
new_state_dict[name] = v
|
||||
self.ema.load_state_dict(new_state_dict)
|
||||
_logger.info("Loaded state_dict_ema")
|
||||
else:
|
||||
_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
|
||||
|
||||
def update(self, model):
|
||||
# correct a mismatch in state dict keys
|
||||
needs_module = hasattr(model, 'module') and not self.ema_has_module
|
||||
with torch.no_grad():
|
||||
msd = model.state_dict()
|
||||
for k, ema_v in self.ema.state_dict().items():
|
||||
if needs_module:
|
||||
k = 'module.' + k
|
||||
model_v = msd[k].detach()
|
||||
if self.device:
|
||||
model_v = model_v.to(device=self.device)
|
||||
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
|
||||
|
||||
|
||||
class ModelEmaV2(nn.Module):
|
||||
""" Model Exponential Moving Average V2
|
||||
|
||||
Keep a moving average of everything in the model state_dict (parameters and buffers).
|
||||
V2 of this module is simpler, it does not match params/buffers based on name but simply
|
||||
iterates in order. It works with torchscript (JIT of full model).
|
||||
|
||||
This is intended to allow functionality like
|
||||
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||
@ -27,22 +101,20 @@ class ModelEma(nn.Module):
|
||||
|
||||
This class is sensitive where it is initialized in the sequence of model init,
|
||||
GPU assignment and distributed training wrappers.
|
||||
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
|
||||
"""
|
||||
def __init__(self, model, decay=0.9999, device=None):
|
||||
super(ModelEma, self).__init__()
|
||||
super(ModelEmaV2, self).__init__()
|
||||
# make a copy of the model for accumulating moving average of weights
|
||||
self.module = deepcopy(model)
|
||||
self.module.eval()
|
||||
self.decay = decay
|
||||
self.device = device # perform ema on different device from model if set
|
||||
if device is not None:
|
||||
if self.device is not None:
|
||||
self.module.to(device=device)
|
||||
|
||||
def update(self, model):
|
||||
with torch.no_grad():
|
||||
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
|
||||
assert ema_v.shape == model_v.shape
|
||||
if self.device:
|
||||
if self.device is not None:
|
||||
model_v = model_v.to(device=self.device)
|
||||
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
|
||||
|
136
train.py
136
train.py
@ -29,7 +29,7 @@ import torchvision.utils
|
||||
from torch.nn.parallel import DistributedDataParallel as NativeDDP
|
||||
|
||||
from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
|
||||
from timm.models import create_model, resume_checkpoint, convert_splitbn_model
|
||||
from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
|
||||
from timm.utils import *
|
||||
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
|
||||
from timm.optim import create_optimizer
|
||||
@ -230,8 +230,6 @@ parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
|
||||
help='how many batches to wait before writing recovery checkpoint')
|
||||
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
|
||||
help='how many training processes to use (default: 1)')
|
||||
parser.add_argument('--num-gpu', type=int, default=1,
|
||||
help='Number of GPUS to use')
|
||||
parser.add_argument('--save-images', action='store_true', default=False,
|
||||
help='save images of input bathes every log interval for debugging')
|
||||
parser.add_argument('--amp', action='store_true', default=False,
|
||||
@ -255,6 +253,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
|
||||
parser.add_argument("--local_rank", default=0, type=int)
|
||||
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
|
||||
help='use the multi-epochs-loader to save time at the beginning of every epoch')
|
||||
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||
help='convert model torchscript for inference')
|
||||
|
||||
|
||||
def _parse_args():
|
||||
@ -282,28 +282,36 @@ def main():
|
||||
args.distributed = False
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
args.distributed = int(os.environ['WORLD_SIZE']) > 1
|
||||
if args.distributed and args.num_gpu > 1:
|
||||
_logger.warning(
|
||||
'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
|
||||
args.num_gpu = 1
|
||||
|
||||
args.device = 'cuda:0'
|
||||
args.world_size = 1
|
||||
args.rank = 0 # global rank
|
||||
if args.distributed:
|
||||
args.num_gpu = 1
|
||||
args.device = 'cuda:%d' % args.local_rank
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
args.world_size = torch.distributed.get_world_size()
|
||||
args.rank = torch.distributed.get_rank()
|
||||
assert args.rank >= 0
|
||||
|
||||
if args.distributed:
|
||||
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
|
||||
% (args.rank, args.world_size))
|
||||
else:
|
||||
_logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
|
||||
_logger.info('Training with a single process on 1 GPUs.')
|
||||
assert args.rank >= 0
|
||||
|
||||
# resolve AMP arguments based on PyTorch / Apex availability
|
||||
use_amp = None
|
||||
if args.amp:
|
||||
# for backwards compat, `--amp` arg tries apex before native amp
|
||||
if has_apex:
|
||||
args.apex_amp = True
|
||||
elif has_native_amp:
|
||||
args.native_amp = True
|
||||
if args.apex_amp and has_apex:
|
||||
use_amp = 'apex'
|
||||
elif args.native_amp and has_native_amp:
|
||||
use_amp = 'native'
|
||||
elif args.apex_amp or args.native_amp:
|
||||
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
|
||||
"Install NVIDA apex or upgrade to PyTorch 1.6")
|
||||
|
||||
torch.manual_seed(args.seed + args.rank)
|
||||
|
||||
@ -327,44 +335,44 @@ def main():
|
||||
|
||||
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
||||
|
||||
# setup augmentation batch splits for contrastive loss or split bn
|
||||
num_aug_splits = 0
|
||||
if args.aug_splits > 0:
|
||||
assert args.aug_splits > 1, 'A split of 1 makes no sense'
|
||||
num_aug_splits = args.aug_splits
|
||||
|
||||
# enable split bn (separate bn stats per batch-portion)
|
||||
if args.split_bn:
|
||||
assert num_aug_splits > 1 or args.resplit
|
||||
model = convert_splitbn_model(model, max(num_aug_splits, 2))
|
||||
|
||||
use_amp = None
|
||||
if args.amp:
|
||||
# for backwards compat, `--amp` arg tries apex before native amp
|
||||
if has_apex:
|
||||
args.apex_amp = True
|
||||
elif has_native_amp:
|
||||
args.native_amp = True
|
||||
if args.apex_amp and has_apex:
|
||||
use_amp = 'apex'
|
||||
elif args.native_amp and has_native_amp:
|
||||
use_amp = 'native'
|
||||
elif args.apex_amp or args.native_amp:
|
||||
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
|
||||
"Install NVIDA apex or upgrade to PyTorch 1.6")
|
||||
# move model to GPU, enable channels last layout if set
|
||||
model.cuda()
|
||||
if args.channels_last:
|
||||
model = model.to(memory_format=torch.channels_last)
|
||||
|
||||
if args.num_gpu > 1:
|
||||
if use_amp == 'apex':
|
||||
_logger.warning(
|
||||
'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
|
||||
use_amp = None
|
||||
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||
assert not args.channels_last, "Channels last not supported with DP, use DDP."
|
||||
else:
|
||||
model.cuda()
|
||||
if args.channels_last:
|
||||
model = model.to(memory_format=torch.channels_last)
|
||||
# setup synchronized BatchNorm for distributed training
|
||||
if args.distributed and args.sync_bn:
|
||||
assert not args.split_bn
|
||||
if has_apex and use_amp != 'native':
|
||||
# Apex SyncBN preferred unless native amp is activated
|
||||
model = convert_syncbn_model(model)
|
||||
else:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
if args.local_rank == 0:
|
||||
_logger.info(
|
||||
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
|
||||
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
|
||||
|
||||
if args.torchscript:
|
||||
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
|
||||
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
|
||||
# FIXME I ran into a bug w/ AMP + torchscript + Linear layers
|
||||
model = torch.jit.script(model)
|
||||
|
||||
optimizer = create_optimizer(args, model)
|
||||
|
||||
# setup automatic mixed-precision (AMP) loss scaling and op casting
|
||||
amp_autocast = suppress # do nothing
|
||||
loss_scaler = None
|
||||
if use_amp == 'apex':
|
||||
@ -390,30 +398,17 @@ def main():
|
||||
loss_scaler=None if args.no_resume_opt else loss_scaler,
|
||||
log_info=args.local_rank == 0)
|
||||
|
||||
# setup exponential moving average of model weights, SWA could be used here too
|
||||
model_ema = None
|
||||
if args.model_ema:
|
||||
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
|
||||
model_ema = ModelEma(
|
||||
model,
|
||||
decay=args.model_ema_decay,
|
||||
device='cpu' if args.model_ema_force_cpu else '',
|
||||
resume=args.resume)
|
||||
model_ema = ModelEmaV2(
|
||||
model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
|
||||
if args.resume:
|
||||
load_checkpoint(model_ema.module, args.resume, use_ema=True)
|
||||
|
||||
# setup distributed training
|
||||
if args.distributed:
|
||||
if args.sync_bn:
|
||||
assert not args.split_bn
|
||||
try:
|
||||
if has_apex and use_amp != 'native':
|
||||
# Apex SyncBN preferred unless native amp is activated
|
||||
model = convert_syncbn_model(model)
|
||||
else:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
if args.local_rank == 0:
|
||||
_logger.info(
|
||||
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
|
||||
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
|
||||
except Exception as e:
|
||||
_logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
|
||||
if has_apex and use_amp != 'native':
|
||||
# Apex DDP preferred unless native amp is activated
|
||||
if args.local_rank == 0:
|
||||
@ -425,6 +420,7 @@ def main():
|
||||
model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
|
||||
# NOTE: EMA model does not need to be wrapped by DDP
|
||||
|
||||
# setup learning rate schedule and starting epoch
|
||||
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
|
||||
start_epoch = 0
|
||||
if args.start_epoch is not None:
|
||||
@ -438,12 +434,22 @@ def main():
|
||||
if args.local_rank == 0:
|
||||
_logger.info('Scheduled epochs: {}'.format(num_epochs))
|
||||
|
||||
# create the train and eval datasets
|
||||
train_dir = os.path.join(args.data, 'train')
|
||||
if not os.path.exists(train_dir):
|
||||
_logger.error('Training folder does not exist at: {}'.format(train_dir))
|
||||
exit(1)
|
||||
dataset_train = Dataset(train_dir)
|
||||
|
||||
eval_dir = os.path.join(args.data, 'val')
|
||||
if not os.path.isdir(eval_dir):
|
||||
eval_dir = os.path.join(args.data, 'validation')
|
||||
if not os.path.isdir(eval_dir):
|
||||
_logger.error('Validation folder does not exist at: {}'.format(eval_dir))
|
||||
exit(1)
|
||||
dataset_eval = Dataset(eval_dir)
|
||||
|
||||
# setup mixup / cutmix
|
||||
collate_fn = None
|
||||
mixup_fn = None
|
||||
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
|
||||
@ -458,9 +464,11 @@ def main():
|
||||
else:
|
||||
mixup_fn = Mixup(**mixup_args)
|
||||
|
||||
# wrap dataset in AugMix helper
|
||||
if num_aug_splits > 1:
|
||||
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
|
||||
|
||||
# create data loaders w/ augmentation pipeiine
|
||||
train_interpolation = args.train_interpolation
|
||||
if args.no_aug or not train_interpolation:
|
||||
train_interpolation = data_config['interpolation']
|
||||
@ -492,14 +500,6 @@ def main():
|
||||
use_multi_epochs_loader=args.use_multi_epochs_loader
|
||||
)
|
||||
|
||||
eval_dir = os.path.join(args.data, 'val')
|
||||
if not os.path.isdir(eval_dir):
|
||||
eval_dir = os.path.join(args.data, 'validation')
|
||||
if not os.path.isdir(eval_dir):
|
||||
_logger.error('Validation folder does not exist at: {}'.format(eval_dir))
|
||||
exit(1)
|
||||
dataset_eval = Dataset(eval_dir)
|
||||
|
||||
loader_eval = create_loader(
|
||||
dataset_eval,
|
||||
input_size=data_config['input_size'],
|
||||
@ -515,6 +515,7 @@ def main():
|
||||
pin_memory=args.pin_mem,
|
||||
)
|
||||
|
||||
# setup loss function
|
||||
if args.jsd:
|
||||
assert num_aug_splits > 1 # JSD only valid with aug splits set
|
||||
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
|
||||
@ -527,6 +528,7 @@ def main():
|
||||
train_loss_fn = nn.CrossEntropyLoss().cuda()
|
||||
validate_loss_fn = nn.CrossEntropyLoss().cuda()
|
||||
|
||||
# setup checkpoint saver and eval metric tracking
|
||||
eval_metric = args.eval_metric
|
||||
best_metric = None
|
||||
best_epoch = None
|
||||
@ -638,11 +640,11 @@ def train_epoch(
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
|
||||
optimizer.step()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
if model_ema is not None:
|
||||
model_ema.update(model)
|
||||
num_updates += 1
|
||||
|
||||
torch.cuda.synchronize()
|
||||
num_updates += 1
|
||||
batch_time_m.update(time.time() - end)
|
||||
if last_batch or batch_idx % args.log_interval == 0:
|
||||
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
|
||||
|
Loading…
x
Reference in New Issue
Block a user