diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py
index 9e7bddf3..de0b881c 100644
--- a/timm/utils/cuda.py
+++ b/timm/utils/cuda.py
@@ -17,12 +17,22 @@ from .clip_grad import dispatch_clip_grad
 class ApexScaler:
     state_dict_key = "amp"
 
-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(
+            self,
+            loss,
+            optimizer,
+            clip_grad=None,
+            clip_mode='norm',
+            parameters=None,
+            create_graph=False,
+            need_update=True,
+    ):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
-        if clip_grad is not None:
-            dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
-        optimizer.step()
+        if need_update:
+            if clip_grad is not None:
+                dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
+            optimizer.step()
 
     def state_dict(self):
         if 'state_dict' in amp.__dict__:
@@ -39,14 +49,24 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()
 
-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(
+            self,
+            loss,
+            optimizer,
+            clip_grad=None,
+            clip_mode='norm',
+            parameters=None,
+            create_graph=False,
+            need_update=True,
+    ):
         self._scaler.scale(loss).backward(create_graph=create_graph)
-        if clip_grad is not None:
-            assert parameters is not None
-            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
-            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
-        self._scaler.step(optimizer)
-        self._scaler.update()
+        if need_update:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
+            self._scaler.step(optimizer)
+            self._scaler.update()
 
     def state_dict(self):
         return self._scaler.state_dict()
diff --git a/train.py b/train.py
index 83663cc6..ec344a64 100755
--- a/train.py
+++ b/train.py
@@ -136,6 +136,8 @@ group.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
 group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N',
+                   help='The number of steps to accumulate gradients (default: 1)')
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
 group.add_argument('--fast-norm', default=False, action='store_true',
@@ -146,13 +148,12 @@ group.add_argument('--head-init-scale', default=None, type=float,
 group.add_argument('--head-init-bias', default=None, type=float,
                    help='Head initialization bias value')
 
+# scripting / codegen
 scripting_group = group.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                              help='torch.jit.script the full model')
 scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                              help="Enable compilation w/ specified backend (default: inductor).")
-scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
-                             help="Enable AOT Autograd support.")
 
 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
@@ -334,6 +335,8 @@ group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
+group.add_argument('--synchronize-step', action='store_true', default=False,
+                   help='torch.cuda.synchronize() end of each step')
 group.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 group.add_argument('--no-prefetcher', action='store_true', default=False,
@@ -379,6 +382,7 @@ def main():
         torch.backends.cudnn.benchmark = True
 
     args.prefetcher = not args.no_prefetcher
+    args.grad_accum_steps = max(1, args.grad_accum_steps)
     device = utils.init_distributed_device(args)
     if args.distributed:
         _logger.info(
@@ -483,20 +487,13 @@ def main():
                 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
 
     if args.torchscript:
+        assert not args.torchcompile
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
-    elif args.torchcompile:
-        # FIXME dynamo might need move below DDP wrapping? TBD
-        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
-        torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
-    elif args.aot_autograd:
-        assert has_functorch, "functorch is needed for --aot-autograd"
-        model = memory_efficient_fusion(model)
 
     if not args.lr:
-        global_batch_size = args.batch_size * args.world_size
+        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
         batch_ratio = global_batch_size / args.lr_base_size
         if not args.lr_base_scale:
             on = args.opt.lower()
@@ -507,7 +504,7 @@ def main():
         if utils.is_primary(args):
             _logger.info(
                 f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
-                f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
+                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
 
     optimizer = create_optimizer_v2(
         model,
@@ -573,6 +570,11 @@ def main():
             model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
         # NOTE: EMA model does not need to be wrapped by DDP
 
+    if args.torchcompile:
+        # torch compile should be done after DDP
+        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
+        model = torch.compile(model, backend=args.torchcompile)
+
     # create the train and eval datasets
     if args.data and not args.data_dir:
         args.data_dir = args.data
@@ -738,7 +740,7 @@ def main():
                 "Metrics not being logged to wandb, try `pip install wandb`")
 
     # setup learning rate schedule and starting epoch
-    updates_per_epoch = len(loader_train)
+    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
     lr_scheduler, num_epochs = create_scheduler_v2(
         optimizer,
         **scheduler_kwargs(args),
@@ -852,7 +854,7 @@ def train_one_epoch(
         amp_autocast=suppress,
         loss_scaler=None,
         model_ema=None,
-        mixup_fn=None
+        mixup_fn=None,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -861,19 +863,30 @@ def train_one_epoch(
             mixup_fn.mixup_enabled = False
 
     second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
-    batch_time_m = utils.AverageMeter()
+    has_no_sync = hasattr(model, "no_sync")
+    update_time_m = utils.AverageMeter()
     data_time_m = utils.AverageMeter()
     losses_m = utils.AverageMeter()
 
     model.train()
 
-    end = time.time()
-    num_batches_per_epoch = len(loader)
-    last_idx = num_batches_per_epoch - 1
-    num_updates = epoch * num_batches_per_epoch
+    accum_steps = args.grad_accum_steps
+    last_accum_steps = len(loader) % accum_steps
+    updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps
+    num_updates = epoch * updates_per_epoch
+    last_batch_idx = len(loader) - 1
+    last_batch_idx_to_accum = len(loader) - last_accum_steps
+
+    data_start_time = update_start_time = time.time()
+    optimizer.zero_grad()
+    update_sample_count = 0
     for batch_idx, (input, target) in enumerate(loader):
-        last_batch = batch_idx == last_idx
-        data_time_m.update(time.time() - end)
+        last_batch = batch_idx == last_batch_idx
+        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
+        update_idx = batch_idx // accum_steps
+        if batch_idx >= last_batch_idx_to_accum:
+            accum_steps = last_accum_steps
+
         if not args.prefetcher:
             input, target = input.to(device), target.to(device)
             if mixup_fn is not None:
@@ -881,64 +894,84 @@ def train_one_epoch(
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
 
-        with amp_autocast():
-            output = model(input)
-            loss = loss_fn(output, target)
+        # multiply by accum steps to get equivalent for full update
+        data_time_m.update(accum_steps * (time.time() - data_start_time))
+
+        def _forward():
+            with amp_autocast():
+                output = model(input)
+                loss = loss_fn(output, target)
+            if accum_steps > 1:
+                loss /= accum_steps
+            return loss
+
+        def _backward(_loss):
+            if loss_scaler is not None:
+                loss_scaler(
+                    _loss,
+                    optimizer,
+                    clip_grad=args.clip_grad,
+                    clip_mode=args.clip_mode,
+                    parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
+                    create_graph=second_order,
+                    need_update=need_update,
+                )
+            else:
+                _loss.backward(create_graph=second_order)
+                if need_update:
+                    if args.clip_grad is not None:
+                        utils.dispatch_clip_grad(
+                            model_parameters(model, exclude_head='agc' in args.clip_mode),
+                            value=args.clip_grad,
+                            mode=args.clip_mode,
+                        )
+                    optimizer.step()
+
+        if has_no_sync and not need_update:
+            with model.no_sync():
+                loss = _forward()
+                _backward(loss)
+        else:
+            loss = _forward()
+            _backward(loss)
 
         if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(loss.item() * accum_steps, input.size(0))
+        update_sample_count += input.size(0)
+
+        if not need_update:
+            data_start_time = time.time()
+            continue
+
+        num_updates += 1
         optimizer.zero_grad()
-        if loss_scaler is not None:
-            loss_scaler(
-                loss, optimizer,
-                clip_grad=args.clip_grad,
-                clip_mode=args.clip_mode,
-                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order
-            )
-        else:
-            loss.backward(create_graph=second_order)
-            if args.clip_grad is not None:
-                utils.dispatch_clip_grad(
-                    model_parameters(model, exclude_head='agc' in args.clip_mode),
-                    value=args.clip_grad,
-                    mode=args.clip_mode
-                )
-            optimizer.step()
-
         if model_ema is not None:
             model_ema.update(model)
 
-        torch.cuda.synchronize()
+        if args.synchronize_step and device.type == 'cuda':
+            torch.cuda.synchronize()
+        time_now = time.time()
+        update_time_m.update(time.time() - update_start_time)
+        update_start_time = time_now
 
-        num_updates += 1
-        batch_time_m.update(time.time() - end)
-        if last_batch or batch_idx % args.log_interval == 0:
+        if update_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
 
             if args.distributed:
                 reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-                losses_m.update(reduced_loss.item(), input.size(0))
+                losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
+                update_sample_count *= args.world_size
 
             if utils.is_primary(args):
                 _logger.info(
-                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
-                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
-                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
-                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
-                    'LR: {lr:.3e} '
-                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
-                        epoch,
-                        batch_idx, len(loader),
-                        100. * batch_idx / last_idx,
-                        loss=losses_m,
-                        batch_time=batch_time_m,
-                        rate=input.size(0) * args.world_size / batch_time_m.val,
-                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
-                        lr=lr,
-                        data_time=data_time_m)
+                    f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
+                    f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)] '
+                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
+                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
+                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
+                    f'LR: {lr:.3e} '
+                    f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
                 )
 
             if args.save_images and output_dir:
@@ -950,13 +983,14 @@ def train_one_epoch(
                 )
 
         if saver is not None and args.recovery_interval and (
-                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
-            saver.save_recovery(epoch, batch_idx=batch_idx)
+                (update_idx + 1) % args.recovery_interval == 0):
+            saver.save_recovery(epoch, batch_idx=update_idx)
 
         if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
 
-        end = time.time()
+        update_sample_count = 0
+        data_start_time = time.time()
         # end for
 
     if hasattr(optimizer, 'sync_lookahead'):
@@ -1025,16 +1059,11 @@ def validate(
             if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                 log_name = 'Test' + log_suffix
                 _logger.info(
-                    '{0}: [{1:>4d}/{2}] '
-                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
-                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
-                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
-                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
-                        log_name, batch_idx, last_idx,
-                        batch_time=batch_time_m,
-                        loss=losses_m,
-                        top1=top1_m,
-                        top5=top5_m)
+                    f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
+                    f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) '
+                    f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) '
+                    f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) '
+                    f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
                 )
 
     metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
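
For reference, the core accumulation pattern this patch introduces (scale the loss by the number of accumulation steps and defer the optimizer step until need_update) reduces to the simplified sketch below; model, loader, optimizer, and loss_fn are assumed to exist, and the AMP, gradient clipping, and DDP no_sync() handling from the patch above is omitted:

    accum_steps = 4  # e.g. --grad-accum-steps 4
    optimizer.zero_grad()
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == len(loader) - 1
        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
        loss = loss_fn(model(input), target)
        # divide so the accumulated gradient matches a single large-batch update
        (loss / accum_steps).backward()
        if need_update:
            optimizer.step()
            optimizer.zero_grad()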