Merge pull request #1784 from huggingface/wip-voidbag-accumulate-grad

Accumulate gradients (adding to #1659)
Ross Wightman 2023-04-20 08:15:28 -07:00 committed by GitHub
commit 2aabaef039
2 changed files with 136 additions and 87 deletions


@@ -17,12 +17,22 @@ from .clip_grad import dispatch_clip_grad
 class ApexScaler:
     state_dict_key = "amp"
 
-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(
+            self,
+            loss,
+            optimizer,
+            clip_grad=None,
+            clip_mode='norm',
+            parameters=None,
+            create_graph=False,
+            need_update=True,
+    ):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
-        if clip_grad is not None:
-            dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
-        optimizer.step()
+        if need_update:
+            if clip_grad is not None:
+                dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
+            optimizer.step()
 
     def state_dict(self):
         if 'state_dict' in amp.__dict__:
@@ -39,14 +49,24 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()
 
-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(
+            self,
+            loss,
+            optimizer,
+            clip_grad=None,
+            clip_mode='norm',
+            parameters=None,
+            create_graph=False,
+            need_update=True,
+    ):
         self._scaler.scale(loss).backward(create_graph=create_graph)
-        if clip_grad is not None:
-            assert parameters is not None
-            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
-            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
-        self._scaler.step(optimizer)
-        self._scaler.update()
+        if need_update:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
+            self._scaler.step(optimizer)
+            self._scaler.update()
 
     def state_dict(self):
         return self._scaler.state_dict()
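Both scalers gain the same `need_update` flag: `backward()` still runs on every micro-batch, while unscaling, clipping, and the optimizer step (plus `GradScaler.update()` on the native path) are deferred to the last micro-batch of each accumulation window. A minimal sketch of the intended call pattern, assuming hypothetical `model`, `loader`, `loss_fn`, and `optimizer` objects on a CUDA device and an illustrative clip value:

# Sketch only: driving NativeScaler with need_update for gradient accumulation.
# model / loader / loss_fn / optimizer are placeholders, not from this PR.
import torch
from timm.utils import NativeScaler

scaler = NativeScaler()
accum_steps = 4  # hypothetical accumulation factor

for i, (input, target) in enumerate(loader):
    with torch.autocast(device_type='cuda'):
        loss = loss_fn(model(input), target) / accum_steps  # keep update magnitude equivalent
    need_update = (i + 1) % accum_steps == 0
    # backward() runs every micro-batch; unscale/clip/step/update only when need_update is True
    scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters(), need_update=need_update)
    if need_update:
        optimizer.zero_grad()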

train.py

@@ -136,6 +136,8 @@ group.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
 group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N',
+                   help='The number of steps to accumulate gradients (default: 1)')
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
 group.add_argument('--fast-norm', default=False, action='store_true',
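The new flag composes with `--batch-size`: per-device memory is governed by the micro-batch size, while the optimizer sees `batch-size × grad-accum-steps` samples per update. A hypothetical invocation (dataset path and model chosen purely for illustration):

python train.py /data/imagenet --model resnet50 --batch-size 64 --grad-accum-steps 4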
@@ -146,13 +148,12 @@ group.add_argument('--head-init-scale', default=None, type=float,
 group.add_argument('--head-init-bias', default=None, type=float,
                    help='Head initialization bias value')
 
 # scripting / codegen
 scripting_group = group.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                              help='torch.jit.script the full model')
 scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                              help="Enable compilation w/ specified backend (default: inductor).")
-scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
-                             help="Enable AOT Autograd support.")
 
 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
@@ -334,6 +335,8 @@ group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
+group.add_argument('--synchronize-step', action='store_true', default=False,
+                   help='torch.cuda.synchronize() end of each step')
 group.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 group.add_argument('--no-prefetcher', action='store_true', default=False,
@@ -379,6 +382,7 @@ def main():
         torch.backends.cudnn.benchmark = True
 
     args.prefetcher = not args.no_prefetcher
+    args.grad_accum_steps = max(1, args.grad_accum_steps)
     device = utils.init_distributed_device(args)
     if args.distributed:
         _logger.info(
@@ -483,20 +487,13 @@
             'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
 
     if args.torchscript:
+        assert not args.torchcompile
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
-    elif args.torchcompile:
-        # FIXME dynamo might need move below DDP wrapping? TBD
-        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
-        torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
-    elif args.aot_autograd:
-        assert has_functorch, "functorch is needed for --aot-autograd"
-        model = memory_efficient_fusion(model)
 
     if not args.lr:
-        global_batch_size = args.batch_size * args.world_size
+        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
         batch_ratio = global_batch_size / args.lr_base_size
         if not args.lr_base_scale:
             on = args.opt.lower()
@@ -507,7 +504,7 @@
         if utils.is_primary(args):
             _logger.info(
                 f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
-                f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
+                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
 
     optimizer = create_optimizer_v2(
         model,
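Accumulation now feeds the automatic LR derivation, so the same micro-batch with more accumulation steps yields a proportionally larger auto-computed LR. A worked sketch with illustrative numbers (0.1 and 256 are the script's `--lr-base`/`--lr-base-size` defaults):

# Illustrative numbers only: 4 processes, micro-batch 64, 4 accumulation steps.
batch_size, world_size, grad_accum_steps = 64, 4, 4
lr_base, lr_base_size = 0.1, 256

global_batch_size = batch_size * world_size * grad_accum_steps  # 1024
batch_ratio = global_batch_size / lr_base_size                  # 4.0
lr_linear = lr_base * batch_ratio                               # 0.4 with 'linear' scaling
lr_sqrt = lr_base * batch_ratio ** 0.5                          # 0.2 with 'sqrt' scaling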
@@ -573,6 +570,11 @@
             model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
         # NOTE: EMA model does not need to be wrapped by DDP
 
+    if args.torchcompile:
+        # torch compile should be done after DDP
+        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
+        model = torch.compile(model, backend=args.torchcompile)
+
     # create the train and eval datasets
     if args.data and not args.data_dir:
         args.data_dir = args.data
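Compilation now happens after the DDP wrap, so `torch.compile()` sees the already-wrapped module. The same ordering in isolation (a sketch; process-group setup and the surrounding flags elided):

# Sketch: wrap with DDP first, then compile the wrapped module.
import torch
from torch.nn.parallel import DistributedDataParallel as NativeDDP

model = NativeDDP(model, device_ids=[device])     # DDP wrap first
model = torch.compile(model, backend='inductor')  # then compile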
@@ -738,7 +740,7 @@
                 "Metrics not being logged to wandb, try `pip install wandb`")
 
     # setup learning rate schedule and starting epoch
-    updates_per_epoch = len(loader_train)
+    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
     lr_scheduler, num_epochs = create_scheduler_v2(
         optimizer,
         **scheduler_kwargs(args),
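`updates_per_epoch` switches from counting loader batches to counting optimizer updates, using ceiling division so a partial trailing window still counts as one update. For example, with illustrative numbers:

len_loader, accum = 1003, 4
updates_per_epoch = (len_loader + accum - 1) // accum  # 251: the last update covers only 3 batches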
@@ -852,7 +854,7 @@
         amp_autocast=suppress,
         loss_scaler=None,
         model_ema=None,
-        mixup_fn=None
+        mixup_fn=None,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -861,19 +863,30 @@
             mixup_fn.mixup_enabled = False
 
     second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
-    batch_time_m = utils.AverageMeter()
+    has_no_sync = hasattr(model, "no_sync")
+    update_time_m = utils.AverageMeter()
     data_time_m = utils.AverageMeter()
     losses_m = utils.AverageMeter()
 
     model.train()
 
-    end = time.time()
-    num_batches_per_epoch = len(loader)
-    last_idx = num_batches_per_epoch - 1
-    num_updates = epoch * num_batches_per_epoch
+    accum_steps = args.grad_accum_steps
+    last_accum_steps = len(loader) % accum_steps
+    updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps
+    num_updates = epoch * updates_per_epoch
+    last_batch_idx = len(loader) - 1
+    last_batch_idx_to_accum = len(loader) - last_accum_steps
+
+    data_start_time = update_start_time = time.time()
+    optimizer.zero_grad()
+    update_sample_count = 0
     for batch_idx, (input, target) in enumerate(loader):
-        last_batch = batch_idx == last_idx
-        data_time_m.update(time.time() - end)
+        last_batch = batch_idx == last_batch_idx
+        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
+        update_idx = batch_idx // accum_steps
+        if batch_idx >= last_batch_idx_to_accum:
+            accum_steps = last_accum_steps
+
         if not args.prefetcher:
             input, target = input.to(device), target.to(device)
             if mixup_fn is not None:
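The bookkeeping above decides, per batch, whether an accumulation window is closing: `need_update` fires every `accum_steps` batches and unconditionally on the last batch, with `accum_steps` shrunk for the leftover tail so the final window stays consistent. A standalone trace of that logic with an illustrative loader length:

# Trace of the index bookkeeping with len(loader) = 10 and --grad-accum-steps 4.
accum_steps = 4
n = 10
last_accum_steps = n % accum_steps              # 2 batches left over at the end
last_batch_idx = n - 1
last_batch_idx_to_accum = n - last_accum_steps  # 8: tail window starts here
for batch_idx in range(n):
    last_batch = batch_idx == last_batch_idx
    need_update = last_batch or (batch_idx + 1) % accum_steps == 0
    if batch_idx >= last_batch_idx_to_accum:
        accum_steps = last_accum_steps          # shrink the final window
    print(batch_idx, need_update)
# need_update fires after batches 3, 7, and 9 -> 3 updates == (10 + 4 - 1) // 4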
@@ -881,64 +894,84 @@
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
 
-        with amp_autocast():
-            output = model(input)
-            loss = loss_fn(output, target)
+        # multiply by accum steps to get equivalent for full update
+        data_time_m.update(accum_steps * (time.time() - data_start_time))
+
+        def _forward():
+            with amp_autocast():
+                output = model(input)
+                loss = loss_fn(output, target)
+            if accum_steps > 1:
+                loss /= accum_steps
+            return loss
+
+        def _backward(_loss):
+            if loss_scaler is not None:
+                loss_scaler(
+                    _loss,
+                    optimizer,
+                    clip_grad=args.clip_grad,
+                    clip_mode=args.clip_mode,
+                    parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
+                    create_graph=second_order,
+                    need_update=need_update,
+                )
+            else:
+                _loss.backward(create_graph=second_order)
+                if need_update:
+                    if args.clip_grad is not None:
+                        utils.dispatch_clip_grad(
+                            model_parameters(model, exclude_head='agc' in args.clip_mode),
+                            value=args.clip_grad,
+                            mode=args.clip_mode,
+                        )
+                    optimizer.step()
+
+        if has_no_sync and not need_update:
+            with model.no_sync():
+                loss = _forward()
+                _backward(loss)
+        else:
+            loss = _forward()
+            _backward(loss)
 
         if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(loss.item() * accum_steps, input.size(0))
+        update_sample_count += input.size(0)
+
+        if not need_update:
+            data_start_time = time.time()
+            continue
 
+        num_updates += 1
         optimizer.zero_grad()
-        if loss_scaler is not None:
-            loss_scaler(
-                loss, optimizer,
-                clip_grad=args.clip_grad,
-                clip_mode=args.clip_mode,
-                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order
-            )
-        else:
-            loss.backward(create_graph=second_order)
-            if args.clip_grad is not None:
-                utils.dispatch_clip_grad(
-                    model_parameters(model, exclude_head='agc' in args.clip_mode),
-                    value=args.clip_grad,
-                    mode=args.clip_mode
-                )
-            optimizer.step()
-
         if model_ema is not None:
             model_ema.update(model)
 
-        torch.cuda.synchronize()
-        num_updates += 1
-        batch_time_m.update(time.time() - end)
-        if last_batch or batch_idx % args.log_interval == 0:
+        if args.synchronize_step and device.type == 'cuda':
+            torch.cuda.synchronize()
+        time_now = time.time()
+        update_time_m.update(time.time() - update_start_time)
+        update_start_time = time_now
+
+        if update_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
 
             if args.distributed:
                 reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-                losses_m.update(reduced_loss.item(), input.size(0))
+                losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
+                update_sample_count *= args.world_size
 
             if utils.is_primary(args):
                 _logger.info(
-                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
-                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
-                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
-                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
-                    'LR: {lr:.3e} '
-                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
-                        epoch,
-                        batch_idx, len(loader),
-                        100. * batch_idx / last_idx,
-                        loss=losses_m,
-                        batch_time=batch_time_m,
-                        rate=input.size(0) * args.world_size / batch_time_m.val,
-                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
-                        lr=lr,
-                        data_time=data_time_m)
+                    f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
+                    f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)] '
+                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
+                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
+                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
+                    f'LR: {lr:.3e} '
+                    f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
                 )
 
         if args.save_images and output_dir:
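Two pieces make the accumulated step behave like one large batch: the per-micro-batch loss is divided by `accum_steps` before `backward()` so the summed gradients match a single big batch, and under DDP the `no_sync()` context skips gradient all-reduce on non-final micro-batches. A standalone sketch of the same pattern (hypothetical `model`, `loader`, `loss_fn`, `optimizer`; distributed setup elided):

# Sketch: suppress DDP grad sync on accumulation micro-batches.
import contextlib

accum_steps = 4  # hypothetical
for i, (input, target) in enumerate(loader):
    need_update = (i + 1) % accum_steps == 0
    if hasattr(model, 'no_sync') and not need_update:
        ctx = model.no_sync()  # defer all-reduce; grads accumulate locally in param.grad
    else:
        ctx = contextlib.nullcontext()
    with ctx:
        loss = loss_fn(model(input), target) / accum_steps
        loss.backward()
    if need_update:
        optimizer.step()
        optimizer.zero_grad()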
@@ -950,13 +983,14 @@
                 )
 
         if saver is not None and args.recovery_interval and (
-                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
-            saver.save_recovery(epoch, batch_idx=batch_idx)
+                (update_idx + 1) % args.recovery_interval == 0):
+            saver.save_recovery(epoch, batch_idx=update_idx)
 
         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
 
-        end = time.time()
+        update_sample_count = 0
+        data_start_time = time.time()
 
         # end for
if hasattr(optimizer, 'sync_lookahead'):
@@ -1025,16 +1059,11 @@
             if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                 log_name = 'Test' + log_suffix
                 _logger.info(
-                    '{0}: [{1:>4d}/{2}] '
-                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
-                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
-                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
-                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
-                        log_name, batch_idx, last_idx,
-                        batch_time=batch_time_m,
-                        loss=losses_m,
-                        top1=top1_m,
-                        top5=top5_m)
+                    f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
+                    f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) '
+                    f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) '
+                    f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) '
+                    f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
                 )
 
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])