Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00

Merge pull request #1784 from huggingface/wip-voidbag-accumulate-grad

Accumulate gradients (adding to #1659)

Commit 2aabaef039
timm/utils/cuda.py

@@ -17,12 +17,22 @@ from .clip_grad import dispatch_clip_grad

 class ApexScaler:
     state_dict_key = "amp"

-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(
+            self,
+            loss,
+            optimizer,
+            clip_grad=None,
+            clip_mode='norm',
+            parameters=None,
+            create_graph=False,
+            need_update=True,
+    ):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
-        if clip_grad is not None:
-            dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
-        optimizer.step()
+        if need_update:
+            if clip_grad is not None:
+                dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
+            optimizer.step()

     def state_dict(self):
         if 'state_dict' in amp.__dict__:
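The contract this PR adds to both scalers: backward runs on every call, while gradient clipping and `optimizer.step()` fire only when `need_update=True`, i.e. on the last micro-batch of an accumulation window. A minimal caller-side sketch of that contract (`loader`, `model`, `loss_fn`, and `optimizer` are placeholders, not code from this PR):

```python
# Sketch: driving a patched ApexScaler/NativeScaler with need_update.
accum_steps = 4
for batch_idx, (input, target) in enumerate(loader):
    loss = loss_fn(model(input), target) / accum_steps  # normalize per micro-batch
    need_update = (batch_idx + 1) % accum_steps == 0
    loss_scaler(loss, optimizer, need_update=need_update)  # backward always; step only when True
    if need_update:
        optimizer.zero_grad()  # reset after the consolidated step
```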
@@ -39,14 +49,24 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()

-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(
+            self,
+            loss,
+            optimizer,
+            clip_grad=None,
+            clip_mode='norm',
+            parameters=None,
+            create_graph=False,
+            need_update=True,
+    ):
         self._scaler.scale(loss).backward(create_graph=create_graph)
-        if clip_grad is not None:
-            assert parameters is not None
-            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
-            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
-        self._scaler.step(optimizer)
-        self._scaler.update()
+        if need_update:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
+            self._scaler.step(optimizer)
+            self._scaler.update()

     def state_dict(self):
         return self._scaler.state_dict()
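For the native path the ordering matters: `scale(loss).backward()` may run several times to accumulate scaled gradients, but `unscale_`, `step`, and `update` must each happen once per optimizer update. A standalone, runnable sketch of the same ordering using only the public `torch.cuda.amp` API (assumes a CUDA device; the tiny model and data are illustrative):

```python
import torch

device = torch.device('cuda')
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
accum_steps = 4

for i in range(8):
    x = torch.randn(16, 10, device=device)
    y = torch.randint(0, 2, (16,), device=device)
    with torch.autocast(device_type='cuda'):
        loss = torch.nn.functional.cross_entropy(model(x), y) / accum_steps
    scaler.scale(loss).backward()       # scaled grads accumulate in .grad
    if (i + 1) % accum_steps == 0:      # the need_update condition
        scaler.unscale_(optimizer)      # unscale once so clipping sees true grads
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)          # skips the step on inf/nan grads
        scaler.update()
        optimizer.zero_grad()
```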
train.py (181 changed lines)
@@ -136,6 +136,8 @@ group.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
 group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N',
+                   help='The number of steps to accumulate gradients (default: 1)')
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
 group.add_argument('--fast-norm', default=False, action='store_true',
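The new flag multiplies the effective batch size without increasing per-step memory; each process still holds one micro-batch at a time. Worked example (values illustrative):

```python
batch_size = 64          # --batch-size, per process
world_size = 8           # number of distributed processes
grad_accum_steps = 4     # --grad-accum-steps
effective_global_batch = batch_size * world_size * grad_accum_steps  # 2048
```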
@@ -146,13 +148,12 @@ group.add_argument('--head-init-scale', default=None, type=float,
 group.add_argument('--head-init-bias', default=None, type=float,
                    help='Head initialization bias value')

 # scripting / codegen
 scripting_group = group.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                              help='torch.jit.script the full model')
 scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                              help="Enable compilation w/ specified backend (default: inductor).")
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd support.")

 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
@@ -334,6 +335,8 @@ group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
+group.add_argument('--synchronize-step', action='store_true', default=False,
+                   help='torch.cuda.synchronize() end of each step')
 group.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 group.add_argument('--no-prefetcher', action='store_true', default=False,
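CUDA kernels launch asynchronously, so a host-side timestamp taken right after a step can under-report the true step time; `--synchronize-step` trades a little throughput for honest timing. A minimal sketch of the pattern (`run_one_step`, `synchronize_step`, and `device` are placeholders):

```python
import time
import torch

start = time.time()
run_one_step()                      # placeholder: forward/backward/optimizer step
if synchronize_step and device.type == 'cuda':
    torch.cuda.synchronize()        # block until queued kernels finish
elapsed = time.time() - start       # now measures actual GPU work
```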
@@ -379,6 +382,7 @@ def main():
         torch.backends.cudnn.benchmark = True

     args.prefetcher = not args.no_prefetcher
+    args.grad_accum_steps = max(1, args.grad_accum_steps)
     device = utils.init_distributed_device(args)
     if args.distributed:
         _logger.info(
@@ -483,20 +487,13 @@ def main():
             'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

     if args.torchscript:
         assert not args.torchcompile
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
-    elif args.torchcompile:
-        # FIXME dynamo might need move below DDP wrapping? TBD
-        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
-        torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)

     if not args.lr:
-        global_batch_size = args.batch_size * args.world_size
+        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
         batch_ratio = global_batch_size / args.lr_base_size
         if not args.lr_base_scale:
             on = args.opt.lower()
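With accumulation folded into `global_batch_size`, the auto-computed learning rate now reflects the effective batch rather than the per-step one. A worked example of the scaling arithmetic, as a sketch of what this block computes (the 0.1 / 256 base values mirror the surrounding defaults; exact numbers are illustrative):

```python
lr_base, lr_base_size = 0.1, 256
global_batch_size = 64 * 8 * 4                    # batch_size * world_size * grad_accum_steps
batch_ratio = global_batch_size / lr_base_size    # 8.0
lr_linear = lr_base * batch_ratio                 # 0.8 with linear scaling
lr_sqrt = lr_base * batch_ratio ** 0.5            # ~0.283 with sqrt scaling
```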
@@ -507,7 +504,7 @@ def main():
         if utils.is_primary(args):
             _logger.info(
                 f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
-                f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
+                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')

     optimizer = create_optimizer_v2(
         model,
@@ -573,6 +570,11 @@ def main():
             model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
         # NOTE: EMA model does not need to be wrapped by DDP

+    if args.torchcompile:
+        # torch compile should be done after DDP
+        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
+        model = torch.compile(model, backend=args.torchcompile)
+
     # create the train and eval datasets
     if args.data and not args.data_dir:
         args.data_dir = args.data
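This resolves the FIXME removed above: compilation now happens after DDP wrapping, so dynamo traces the DDP-wrapped module rather than being wrapped by it. A sketch of the resulting order (assumes an initialized process group and a CUDA `device`):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as NativeDDP

model = NativeDDP(model, device_ids=[device])      # wrap for gradient all-reduce first
model = torch.compile(model, backend='inductor')   # then compile the wrapped module
```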
@@ -738,7 +740,7 @@ def main():
                 "Metrics not being logged to wandb, try `pip install wandb`")

     # setup learning rate schedule and starting epoch
-    updates_per_epoch = len(loader_train)
+    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
     lr_scheduler, num_epochs = create_scheduler_v2(
         optimizer,
         **scheduler_kwargs(args),
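`(n + k - 1) // k` is ceiling division: a trailing partial accumulation window still counts as one optimizer update, so the scheduler sees the right number of updates per epoch. For example:

```python
len_loader, accum = 10, 4
updates_per_epoch = (len_loader + accum - 1) // accum  # 3: batches 0-3, 4-7, and 8-9
```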
@@ -852,7 +854,7 @@ def train_one_epoch(
         amp_autocast=suppress,
         loss_scaler=None,
         model_ema=None,
-        mixup_fn=None
+        mixup_fn=None,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -861,19 +863,30 @@ def train_one_epoch(
             mixup_fn.mixup_enabled = False

     second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
-    batch_time_m = utils.AverageMeter()
+    has_no_sync = hasattr(model, "no_sync")
+    update_time_m = utils.AverageMeter()
     data_time_m = utils.AverageMeter()
     losses_m = utils.AverageMeter()

     model.train()

-    end = time.time()
-    num_batches_per_epoch = len(loader)
-    last_idx = num_batches_per_epoch - 1
-    num_updates = epoch * num_batches_per_epoch
+    accum_steps = args.grad_accum_steps
+    last_accum_steps = len(loader) % accum_steps
+    updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps
+    num_updates = epoch * updates_per_epoch
+    last_batch_idx = len(loader) - 1
+    last_batch_idx_to_accum = len(loader) - last_accum_steps
+
+    data_start_time = update_start_time = time.time()
+    optimizer.zero_grad()
+    update_sample_count = 0
     for batch_idx, (input, target) in enumerate(loader):
-        last_batch = batch_idx == last_idx
-        data_time_m.update(time.time() - end)
+        last_batch = batch_idx == last_batch_idx
+        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
+        update_idx = batch_idx // accum_steps
+        if batch_idx >= last_batch_idx_to_accum:
+            accum_steps = last_accum_steps
+
         if not args.prefetcher:
             input, target = input.to(device), target.to(device)
             if mixup_fn is not None:
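The new bookkeeping shrinks the final window when the loader length is not a multiple of `accum_steps`, so no batch is dropped. A runnable trace of the same indexing for 10 batches and 4 accumulation steps (standalone sketch, same arithmetic as above):

```python
n, accum_steps = 10, 4
last_accum_steps = n % accum_steps               # 2 batches in the final short window
last_batch_idx = n - 1
last_batch_idx_to_accum = n - last_accum_steps   # 8
for batch_idx in range(n):
    last_batch = batch_idx == last_batch_idx
    need_update = last_batch or (batch_idx + 1) % accum_steps == 0
    if batch_idx >= last_batch_idx_to_accum:
        accum_steps = last_accum_steps           # shrink window for the tail
    print(batch_idx, need_update)                # True at 3, 7 and 9 -> 3 updates
```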
@@ -881,64 +894,84 @@ def train_one_epoch(
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)

-        with amp_autocast():
-            output = model(input)
-            loss = loss_fn(output, target)
+        # multiply by accum steps to get equivalent for full update
+        data_time_m.update(accum_steps * (time.time() - data_start_time))
+
+        def _forward():
+            with amp_autocast():
+                output = model(input)
+                loss = loss_fn(output, target)
+            if accum_steps > 1:
+                loss /= accum_steps
+            return loss
+
+        def _backward(_loss):
+            if loss_scaler is not None:
+                loss_scaler(
+                    _loss,
+                    optimizer,
+                    clip_grad=args.clip_grad,
+                    clip_mode=args.clip_mode,
+                    parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
+                    create_graph=second_order,
+                    need_update=need_update,
+                )
+            else:
+                _loss.backward(create_graph=second_order)
+                if need_update:
+                    if args.clip_grad is not None:
+                        utils.dispatch_clip_grad(
+                            model_parameters(model, exclude_head='agc' in args.clip_mode),
+                            value=args.clip_grad,
+                            mode=args.clip_mode,
+                        )
+                    optimizer.step()
+
+        if has_no_sync and not need_update:
+            with model.no_sync():
+                loss = _forward()
+                _backward(loss)
+        else:
+            loss = _forward()
+            _backward(loss)
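`no_sync()` is DDP's context manager that skips the gradient all-reduce for that backward pass, so cross-rank communication happens only on the `need_update` micro-batch; the `hasattr(model, 'no_sync')` guard keeps the same path working for non-DDP models. A minimal standalone sketch of the guard (`model`, `need_update`, and `forward_and_backward` are placeholders):

```python
from contextlib import nullcontext

# Sync gradients across ranks only on the step that will call optimizer.step().
sync_ctx = model.no_sync() if hasattr(model, 'no_sync') and not need_update else nullcontext()
with sync_ctx:
    loss = forward_and_backward()   # placeholder: forward pass + loss.backward()
```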
         if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(loss.item() * accum_steps, input.size(0))
+        update_sample_count += input.size(0)

+        if not need_update:
+            data_start_time = time.time()
+            continue
+
+        num_updates += 1
         optimizer.zero_grad()
-        if loss_scaler is not None:
-            loss_scaler(
-                loss, optimizer,
-                clip_grad=args.clip_grad,
-                clip_mode=args.clip_mode,
-                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order
-            )
-        else:
-            loss.backward(create_graph=second_order)
-            if args.clip_grad is not None:
-                utils.dispatch_clip_grad(
-                    model_parameters(model, exclude_head='agc' in args.clip_mode),
-                    value=args.clip_grad,
-                    mode=args.clip_mode
-                )
-            optimizer.step()

         if model_ema is not None:
             model_ema.update(model)

-        torch.cuda.synchronize()
+        if args.synchronize_step and device.type == 'cuda':
+            torch.cuda.synchronize()
+        time_now = time.time()
+        update_time_m.update(time.time() - update_start_time)
+        update_start_time = time_now

-        num_updates += 1
-        batch_time_m.update(time.time() - end)
-        if last_batch or batch_idx % args.log_interval == 0:
+        if update_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)

             if args.distributed:
                 reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-                losses_m.update(reduced_loss.item(), input.size(0))
+                losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
+                update_sample_count *= args.world_size

             if utils.is_primary(args):
                 _logger.info(
-                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
-                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
-                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
-                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
-                    'LR: {lr:.3e} '
-                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
-                        epoch,
-                        batch_idx, len(loader),
-                        100. * batch_idx / last_idx,
-                        loss=losses_m,
-                        batch_time=batch_time_m,
-                        rate=input.size(0) * args.world_size / batch_time_m.val,
-                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
-                        lr=lr,
-                        data_time=data_time_m)
+                    f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
+                    f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)] '
+                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
+                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
+                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
+                    f'LR: {lr:.3e} '
+                    f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
+                )

             if args.save_images and output_dir:
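Because `_forward()` divides the loss by `accum_steps`, the meter multiplies it back so logged values stay comparable with non-accumulating runs. In effect (illustrative arithmetic):

```python
accum_steps = 4
true_batch_loss = 2.0
scaled = true_batch_loss / accum_steps   # 0.5 flows into backward()
logged = scaled * accum_steps            # 2.0 is what losses_m records
```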
@@ -950,13 +983,14 @@ def train_one_epoch(
                 )

         if saver is not None and args.recovery_interval and (
-                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
-            saver.save_recovery(epoch, batch_idx=batch_idx)
+                (update_idx + 1) % args.recovery_interval == 0):
+            saver.save_recovery(epoch, batch_idx=update_idx)

         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

-        end = time.time()
+        update_sample_count = 0
+        data_start_time = time.time()
         # end for

     if hasattr(optimizer, 'sync_lookahead'):
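`num_updates` now counts optimizer updates rather than loader batches, so per-step LR schedules driven by `step_update` advance at the same rate as a non-accumulating run with the same effective batch size. Worked numbers (illustrative):

```python
# With 1000 batches/epoch and accum_steps=4 the scheduler advances 250 times
# per epoch instead of 1000; epoch boundaries in "update space" stay aligned.
loader_len, accum_steps, epoch = 1000, 4, 3
updates_per_epoch = (loader_len + accum_steps - 1) // accum_steps  # 250
num_updates = epoch * updates_per_epoch                            # 750 updates so far
```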
@@ -1025,16 +1059,11 @@ def validate(
             if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                 log_name = 'Test' + log_suffix
                 _logger.info(
-                    '{0}: [{1:>4d}/{2}] '
-                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
-                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
-                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
-                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
-                        log_name, batch_idx, last_idx,
-                        batch_time=batch_time_m,
-                        loss=losses_m,
-                        top1=top1_m,
-                        top5=top5_m)
+                    f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
+                    f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) '
+                    f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) '
+                    f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) '
+                    f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
+                )

     metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])