Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Add back old ModelEma and rename new one to ModelEmaV2 to avoid compat breaks in dependent code. Shuffle train script, add a few comments, remove DataParallel support, support experimental torchscript training.
parent 9214ca0716
commit 27bbc70d71
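A rough usage sketch of the rename described in the commit message above, using plain torch stand-ins for the model, data and optimizer; only ModelEmaV2, load_checkpoint(..., use_ema=True) and the update-after-optimizer-step ordering come from this commit, everything else is illustrative:

    import torch
    import torch.nn as nn
    from timm.utils import ModelEmaV2          # export added by this commit

    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # create the EMA copy once the model sits on its final device, before any DDP wrapper
    model_ema = ModelEmaV2(model, decay=0.99, device=None)   # device='cpu' keeps the copy off-GPU

    for _ in range(5):                          # toy loop on random data
        x, y = torch.randn(8, 10), torch.randn(8, 2)
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        model_ema.update(model)                 # smoothed weights accumulate in model_ema.module

Resuming EMA weights now goes through load_checkpoint(model_ema.module, resume_path, use_ema=True) from timm.models rather than the old ModelEma(resume=...) constructor argument, as the train.py hunk further down shows.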
@@ -6,5 +6,5 @@ from .log import setup_default_logging, FormatterNoInfo
 from .metrics import AverageMeter, accuracy
 from .misc import natural_key, add_bool_arg
 from .model import unwrap_model, get_state_dict
-from .model_ema import ModelEma
+from .model_ema import ModelEma, ModelEmaV2
 from .summary import update_summary, get_outdir

@@ -6,7 +6,10 @@ from .model_ema import ModelEma
 
 
 def unwrap_model(model):
-    return model.module if hasattr(model, 'module') else model
+    if isinstance(model, ModelEma):
+        return unwrap_model(model.ema)
+    else:
+        return model.module if hasattr(model, 'module') else model
 
 
 def get_state_dict(model, unwrap_fn=unwrap_model):

@@ -2,15 +2,89 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+import logging
+from collections import OrderedDict
 from copy import deepcopy
 
 import torch
 import torch.nn as nn
 
+_logger = logging.getLogger(__name__)
+
 
-class ModelEma(nn.Module):
-    """ Model Exponential Moving Average
+class ModelEma:
+    """ Model Exponential Moving Average (DEPRECATED)
+
     Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This version is deprecated, it does not work with scripted models. Will be removed eventually.
+
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
+    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
+    smoothing of weights to match results. Pay attention to the decay constant you are using
+    relative to your update count per epoch.
+
+    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+    disable validation of the EMA weights. Validation will have to be done manually in a separate
+    process, or after the training stops converging.
+
+    This class is sensitive where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+    def __init__(self, model, decay=0.9999, device='', resume=''):
+        # make a copy of the model for accumulating moving average of weights
+        self.ema = deepcopy(model)
+        self.ema.eval()
+        self.decay = decay
+        self.device = device  # perform ema on different device from model if set
+        if device:
+            self.ema.to(device=device)
+        self.ema_has_module = hasattr(self.ema, 'module')
+        if resume:
+            self._load_checkpoint(resume)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def _load_checkpoint(self, checkpoint_path):
+        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        assert isinstance(checkpoint, dict)
+        if 'state_dict_ema' in checkpoint:
+            new_state_dict = OrderedDict()
+            for k, v in checkpoint['state_dict_ema'].items():
+                # ema model may have been wrapped by DataParallel, and need module prefix
+                if self.ema_has_module:
+                    name = 'module.' + k if not k.startswith('module') else k
+                else:
+                    name = k
+                new_state_dict[name] = v
+            self.ema.load_state_dict(new_state_dict)
+            _logger.info("Loaded state_dict_ema")
+        else:
+            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
+
+    def update(self, model):
+        # correct a mismatch in state dict keys
+        needs_module = hasattr(model, 'module') and not self.ema_has_module
+        with torch.no_grad():
+            msd = model.state_dict()
+            for k, ema_v in self.ema.state_dict().items():
+                if needs_module:
+                    k = 'module.' + k
+                model_v = msd[k].detach()
+                if self.device:
+                    model_v = model_v.to(device=self.device)
+                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
+
+
+class ModelEmaV2(nn.Module):
+    """ Model Exponential Moving Average V2
+
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    V2 of this module is simpler, it does not match params/buffers based on name but simply
+    iterates in order. It works with torchscript (JIT of full model).
 
     This is intended to allow functionality like
     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
@@ -27,22 +101,20 @@ class ModelEma(nn.Module):
 
     This class is sensitive where it is initialized in the sequence of model init,
     GPU assignment and distributed training wrappers.
-    I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """
     def __init__(self, model, decay=0.9999, device=None):
-        super(ModelEma, self).__init__()
+        super(ModelEmaV2, self).__init__()
         # make a copy of the model for accumulating moving average of weights
         self.module = deepcopy(model)
        self.module.eval()
         self.decay = decay
         self.device = device  # perform ema on different device from model if set
-        if device is not None:
+        if self.device is not None:
             self.module.to(device=device)
 
     def update(self, model):
         with torch.no_grad():
             for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
-                assert ema_v.shape == model_v.shape
-                if self.device:
+                if self.device is not None:
                     model_v = model_v.to(device=self.device)
                 ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

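Both classes above apply the same update rule, ema = decay * ema + (1 - decay) * value. A small self-contained check of the docstring's warning about weighing the decay constant against the number of updates per epoch (the batch and epoch counts below are made up for illustration):

    decay = 0.9999
    ema, target = 0.0, 1.0                 # pretend a weight jumps from 0 to 1
    updates_per_epoch = 1250               # hypothetical: ~320k images at batch size 256
    for step in range(1, 20 * updates_per_epoch + 1):
        ema = decay * ema + (1 - decay) * target
        if step % (5 * updates_per_epoch) == 0:
            print(f'epoch {step // updates_per_epoch:2d}: ema = {ema:.2f}')
    # prints roughly 0.46, 0.71, 0.85, 0.92: the time constant is 1 / (1 - decay)
    # = 10000 updates, so with few updates per epoch a decay of 0.9999 leaves the
    # EMA lagging far behind the live weights for many epochs.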
train.py (136 changed lines)
@@ -29,7 +29,7 @@ import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
 from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, resume_checkpoint, convert_splitbn_model
+from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer

@@ -230,8 +230,6 @@ parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
-parser.add_argument('--num-gpu', type=int, default=1,
-                    help='Number of GPUS to use')
 parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,

@@ -255,6 +253,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
 parser.add_argument("--local_rank", default=0, type=int)
 parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
+parser.add_argument('--torchscript', dest='torchscript', action='store_true',
+                    help='convert model torchscript for inference')
 
 
 def _parse_args():

@@ -282,28 +282,36 @@ def main():
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
-        if args.distributed and args.num_gpu > 1:
-            _logger.warning(
-                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
-            args.num_gpu = 1
 
     args.device = 'cuda:0'
     args.world_size = 1
     args.rank = 0  # global rank
     if args.distributed:
-        args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
         torch.distributed.init_process_group(backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
         args.rank = torch.distributed.get_rank()
-    assert args.rank >= 0
-
-    if args.distributed:
         _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                      % (args.rank, args.world_size))
     else:
-        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
+        _logger.info('Training with a single process on 1 GPUs.')
+    assert args.rank >= 0
 
+    # resolve AMP arguments based on PyTorch / Apex availability
+    use_amp = None
+    if args.amp:
+        # for backwards compat, `--amp` arg tries apex before native amp
+        if has_apex:
+            args.apex_amp = True
+        elif has_native_amp:
+            args.native_amp = True
+    if args.apex_amp and has_apex:
+        use_amp = 'apex'
+    elif args.native_amp and has_native_amp:
+        use_amp = 'native'
+    elif args.apex_amp or args.native_amp:
+        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
+                        "Install NVIDA apex or upgrade to PyTorch 1.6")
 
     torch.manual_seed(args.seed + args.rank)

@@ -327,44 +335,44 @@ def main():
 
     data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
 
+    # setup augmentation batch splits for contrastive loss or split bn
     num_aug_splits = 0
     if args.aug_splits > 0:
         assert args.aug_splits > 1, 'A split of 1 makes no sense'
         num_aug_splits = args.aug_splits
 
+    # enable split bn (separate bn stats per batch-portion)
     if args.split_bn:
         assert num_aug_splits > 1 or args.resplit
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
 
-    use_amp = None
-    if args.amp:
-        # for backwards compat, `--amp` arg tries apex before native amp
-        if has_apex:
-            args.apex_amp = True
-        elif has_native_amp:
-            args.native_amp = True
-    if args.apex_amp and has_apex:
-        use_amp = 'apex'
-    elif args.native_amp and has_native_amp:
-        use_amp = 'native'
-    elif args.apex_amp or args.native_amp:
-        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
-                        "Install NVIDA apex or upgrade to PyTorch 1.6")
-
-    if args.num_gpu > 1:
-        if use_amp == 'apex':
-            _logger.warning(
-                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
-            use_amp = None
-        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
-        assert not args.channels_last, "Channels last not supported with DP, use DDP."
-    else:
-        model.cuda()
-        if args.channels_last:
-            model = model.to(memory_format=torch.channels_last)
+    # move model to GPU, enable channels last layout if set
+    model.cuda()
+    if args.channels_last:
+        model = model.to(memory_format=torch.channels_last)
+
+    # setup synchronized BatchNorm for distributed training
+    if args.distributed and args.sync_bn:
+        assert not args.split_bn
+        if has_apex and use_amp != 'native':
+            # Apex SyncBN preferred unless native amp is activated
+            model = convert_syncbn_model(model)
+        else:
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        if args.local_rank == 0:
+            _logger.info(
+                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
+                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
+
+    if args.torchscript:
+        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
+        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
+        # FIXME I ran into a bug w/ AMP + torchscript + Linear layers
+        model = torch.jit.script(model)
 
     optimizer = create_optimizer(args, model)
 
+    # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing
     loss_scaler = None
     if use_amp == 'apex':

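For orientation, a minimal stand-alone version of the setup order the hunk above enforces: move the model to the GPU, optionally switch to channels-last, then torch.jit.script it before the optimizer is created. The toy module and SGD settings are placeholders, not the training script itself; the hunk's own asserts rule out APEX AMP and SyncBatchNorm when scripting:

    import torch
    import torch.nn as nn

    model = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))

    if torch.cuda.is_available():
        model.cuda()
        model = model.to(memory_format=torch.channels_last)   # optional channels-last layout

    model = torch.jit.script(model)    # experimental torchscript training path

    # optimizer is built afterwards, against the (possibly scripted) model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)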
@@ -390,30 +398,17 @@ def main():
             loss_scaler=None if args.no_resume_opt else loss_scaler,
             log_info=args.local_rank == 0)
 
+    # setup exponential moving average of model weights, SWA could be used here too
     model_ema = None
     if args.model_ema:
         # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
-        model_ema = ModelEma(
-            model,
-            decay=args.model_ema_decay,
-            device='cpu' if args.model_ema_force_cpu else '',
-            resume=args.resume)
+        model_ema = ModelEmaV2(
+            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
+        if args.resume:
+            load_checkpoint(model_ema.module, args.resume, use_ema=True)
 
+    # setup distributed training
     if args.distributed:
-        if args.sync_bn:
-            assert not args.split_bn
-            try:
-                if has_apex and use_amp != 'native':
-                    # Apex SyncBN preferred unless native amp is activated
-                    model = convert_syncbn_model(model)
-                else:
-                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-                if args.local_rank == 0:
-                    _logger.info(
-                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
-                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
-            except Exception as e:
-                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
         if has_apex and use_amp != 'native':
             # Apex DDP preferred unless native amp is activated
             if args.local_rank == 0:

@@ -425,6 +420,7 @@ def main():
             model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
         # NOTE: EMA model does not need to be wrapped by DDP
 
+    # setup learning rate schedule and starting epoch
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     start_epoch = 0
     if args.start_epoch is not None:

@@ -438,12 +434,22 @@ def main():
     if args.local_rank == 0:
         _logger.info('Scheduled epochs: {}'.format(num_epochs))
 
+    # create the train and eval datasets
     train_dir = os.path.join(args.data, 'train')
     if not os.path.exists(train_dir):
         _logger.error('Training folder does not exist at: {}'.format(train_dir))
         exit(1)
     dataset_train = Dataset(train_dir)
 
+    eval_dir = os.path.join(args.data, 'val')
+    if not os.path.isdir(eval_dir):
+        eval_dir = os.path.join(args.data, 'validation')
+        if not os.path.isdir(eval_dir):
+            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
+            exit(1)
+    dataset_eval = Dataset(eval_dir)
+
+    # setup mixup / cutmix
     collate_fn = None
     mixup_fn = None
     mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None

||||||
@ -458,9 +464,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
mixup_fn = Mixup(**mixup_args)
|
mixup_fn = Mixup(**mixup_args)
|
||||||
|
|
||||||
|
# wrap dataset in AugMix helper
|
||||||
if num_aug_splits > 1:
|
if num_aug_splits > 1:
|
||||||
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
|
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
|
||||||
|
|
||||||
|
# create data loaders w/ augmentation pipeiine
|
||||||
train_interpolation = args.train_interpolation
|
train_interpolation = args.train_interpolation
|
||||||
if args.no_aug or not train_interpolation:
|
if args.no_aug or not train_interpolation:
|
||||||
train_interpolation = data_config['interpolation']
|
train_interpolation = data_config['interpolation']
|
||||||
@@ -492,14 +500,6 @@ def main():
         use_multi_epochs_loader=args.use_multi_epochs_loader
     )
 
-    eval_dir = os.path.join(args.data, 'val')
-    if not os.path.isdir(eval_dir):
-        eval_dir = os.path.join(args.data, 'validation')
-        if not os.path.isdir(eval_dir):
-            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
-            exit(1)
-    dataset_eval = Dataset(eval_dir)
-
     loader_eval = create_loader(
         dataset_eval,
         input_size=data_config['input_size'],

@@ -515,6 +515,7 @@ def main():
         pin_memory=args.pin_mem,
     )
 
+    # setup loss function
     if args.jsd:
         assert num_aug_splits > 1  # JSD only valid with aug splits set
         train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()

@@ -527,6 +528,7 @@ def main():
         train_loss_fn = nn.CrossEntropyLoss().cuda()
     validate_loss_fn = nn.CrossEntropyLoss().cuda()
 
+    # setup checkpoint saver and eval metric tracking
     eval_metric = args.eval_metric
     best_metric = None
     best_epoch = None

@@ -638,11 +640,11 @@ def train_epoch(
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
             optimizer.step()
 
-        torch.cuda.synchronize()
         if model_ema is not None:
             model_ema.update(model)
-        num_updates += 1
 
+        torch.cuda.synchronize()
+        num_updates += 1
         batch_time_m.update(time.time() - end)
         if last_batch or batch_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]