Add gradient accumulation option to train.py
option: --iters-to-accum (iterations to accumulate)

Gradient accumulation improves training throughput (samples/s) by reducing the number of gradient synchronizations between nodes. This option can be helpful when the network is the bottleneck.

Signed-off-by: Taeksang Kim <voidbag@puzzle-ai.com>

pull/1659/head
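For context, the technique itself: under DistributedDataParallel, every loss.backward() normally triggers a gradient all-reduce, so accumulating over several micro-batches and synchronizing once per window cuts inter-node traffic roughly by the accumulation factor. A minimal sketch of the idea (illustrative names, not code from this PR; assumes a DDP-wrapped model):

import contextlib
import torch

def train_accum(model, loader, optimizer, loss_fn, iters_to_accum=4):
    # Gradient accumulation: step the optimizer once per window of micro-batches.
    optimizer.zero_grad()
    for i, (x, y) in enumerate(loader):
        need_step = (i + 1) % iters_to_accum == 0
        # Skip the inter-node all-reduce on non-step batches (DDP exposes no_sync()).
        ctx = model.no_sync() if not need_step and hasattr(model, 'no_sync') else contextlib.nullcontext()
        with ctx:
            loss = loss_fn(model(x), y) / iters_to_accum  # keep gradient magnitude comparable
            loss.backward()
        if need_step:
            optimizer.step()
            optimizer.zero_grad()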
parent 7a13be67a5
commit 7f29a46d44
timm/utils/cuda.py

@@ -17,12 +17,13 @@ from .clip_grad import dispatch_clip_grad
 class ApexScaler:
     state_dict_key = "amp"

-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, need_step=True):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
         if clip_grad is not None:
             dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
-        optimizer.step()
+        if need_step:
+            optimizer.step()

     def state_dict(self):
         if 'state_dict' in amp.__dict__:
@@ -39,14 +40,15 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()

-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, need_step=True):
         self._scaler.scale(loss).backward(create_graph=create_graph)
         if clip_grad is not None:
             assert parameters is not None
             self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
             dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
-        self._scaler.step(optimizer)
-        self._scaler.update()
+        if need_step:
+            self._scaler.step(optimizer)
+            self._scaler.update()

     def state_dict(self):
         return self._scaler.state_dict()
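A hypothetical caller sketch for the extended scaler interface above (loader, model, loss_fn, and iters_to_accum are assumed to exist): backward runs on every call, but the scaler only calls step() and update() on the closing micro-batch of each window, so intermediate calls just accumulate scaled gradients.

loss_scaler = NativeScaler()
for i, (x, y) in enumerate(loader):
    need_step = (i + 1) % iters_to_accum == 0
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(x), y) / iters_to_accum
    loss_scaler(loss, optimizer, need_step=need_step)  # backward always; step/update only when need_step
    if need_step:
        optimizer.zero_grad()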
train.py
@@ -130,6 +130,8 @@ group.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
 group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='Input batch size for training (default: 128)')
+group.add_argument('--iters-to-accum', type=int, default=1, metavar='N',
+                   help='The number of iterations to accumulate gradients (default: 1)')
 group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                    help='Validation batch size override (default: None)')
 group.add_argument('--channels-last', action='store_true', default=False,
@@ -399,6 +401,9 @@ def main():
     if args.amp_dtype == 'bfloat16':
         amp_dtype = torch.bfloat16

+    # reject non-positive accumulation counts
+    assert args.iters_to_accum > 0, 'The argument "iters-to-accum" must be greater than zero.'
+
     utils.random_seed(args.seed, args.rank)

     if args.fuser:
@@ -851,11 +856,23 @@ def train_one_epoch(
     model.train()

     end = time.time()
-    num_batches_per_epoch = len(loader)
-    last_idx = num_batches_per_epoch - 1
+    num_batches_per_epoch = (len(loader) + args.iters_to_accum - 1) // args.iters_to_accum
+    last_idx = len(loader) - 1
+    last_iters_to_accum = len(loader) % args.iters_to_accum
+    last_idx_to_accum = len(loader) - last_iters_to_accum
     num_updates = epoch * num_batches_per_epoch
+
+    optimizer.zero_grad()
+    num_step_samples = 0
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
+        iters_to_accum = args.iters_to_accum
+        if batch_idx >= last_idx_to_accum:
+            iters_to_accum = last_iters_to_accum
+        need_step = False
+        if (batch_idx + 1) % args.iters_to_accum == 0 or last_batch:
+            need_step = True
+
         data_time_m.update(time.time() - end)
         if not args.prefetcher:
             input, target = input.to(device), target.to(device)
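A quick worked example of the bookkeeping above, assuming len(loader) == 10 and --iters-to-accum 4:

# len(loader) = 10, args.iters_to_accum = 4
num_batches_per_epoch = (10 + 4 - 1) // 4  # = 3 optimizer updates per epoch
last_iters_to_accum = 10 % 4               # = 2 batches left over in the final window
last_idx_to_accum = 10 - 2                 # = 8; batches 8-9 divide their loss by 2, not 4
# need_step becomes True at batch_idx 3 and 7 (full windows) and at 9 (last_batch)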
@@ -864,82 +881,101 @@ def train_one_epoch(
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)

-        with amp_autocast():
-            output = model(input)
-            loss = loss_fn(output, target)
+        def _forward():
+            with amp_autocast():
+                output = model(input)
+                loss = loss_fn(output, target)
+            loss /= iters_to_accum
+            return loss
+
+        if need_step is not True and hasattr(model, "no_sync"):
+            with model.no_sync():
+                loss = _forward()
+        else:
+            loss = _forward()

         if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(loss.item() * iters_to_accum, input.size(0))

-        optimizer.zero_grad()
-        if loss_scaler is not None:
-            loss_scaler(
-                loss, optimizer,
-                clip_grad=args.clip_grad,
-                clip_mode=args.clip_mode,
-                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order
-            )
-        else:
-            loss.backward(create_graph=second_order)
-            if args.clip_grad is not None:
-                utils.dispatch_clip_grad(
-                    model_parameters(model, exclude_head='agc' in args.clip_mode),
-                    value=args.clip_grad,
-                    mode=args.clip_mode
-                )
-            optimizer.step()
-
-        if model_ema is not None:
-            model_ema.update(model)
-
-        torch.cuda.synchronize()
-
-        num_updates += 1
-        batch_time_m.update(time.time() - end)
-        if last_batch or batch_idx % args.log_interval == 0:
-            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
-            lr = sum(lrl) / len(lrl)
-
-            if args.distributed:
-                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-                losses_m.update(reduced_loss.item(), input.size(0))
-
-            if utils.is_primary(args):
-                _logger.info(
-                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
-                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
-                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
-                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
-                    'LR: {lr:.3e} '
-                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
-                        epoch,
-                        batch_idx, len(loader),
-                        100. * batch_idx / last_idx,
-                        loss=losses_m,
-                        batch_time=batch_time_m,
-                        rate=input.size(0) * args.world_size / batch_time_m.val,
-                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
-                        lr=lr,
-                        data_time=data_time_m)
-                )
-
-            if args.save_images and output_dir:
-                torchvision.utils.save_image(
-                    input,
-                    os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
-                    padding=0,
-                    normalize=True
-                )
-
-        if saver is not None and args.recovery_interval and (
-                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
-            saver.save_recovery(epoch, batch_idx=batch_idx)
-
-        if lr_scheduler is not None:
-            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
-
-        end = time.time()
+        def _backward():
+            if loss_scaler is not None:
+                loss_scaler(
+                    loss, optimizer,
+                    clip_grad=args.clip_grad,
+                    clip_mode=args.clip_mode,
+                    parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
+                    create_graph=second_order,
+                    need_step=need_step
+                )
+            else:
+                loss.backward(create_graph=second_order)
+                if args.clip_grad is not None:
+                    utils.dispatch_clip_grad(
+                        model_parameters(model, exclude_head='agc' in args.clip_mode),
+                        value=args.clip_grad,
+                        mode=args.clip_mode
+                    )
+                if need_step:
+                    optimizer.step()
+
+        num_step_samples += input.size(0)
+        if need_step is not True and hasattr(model, "no_sync"):
+            with model.no_sync():
+                _backward()
+        else:
+            _backward()
+        if need_step:
+            optimizer.zero_grad()
+            if model_ema is not None:
+                model_ema.update(model)
+
+            torch.cuda.synchronize()
+            num_updates += 1
+            batch_time_m.update(time.time() - end)
+
+            if (batch_idx // args.iters_to_accum) % args.log_interval == 0:
+                lrl = [param_group['lr'] for param_group in optimizer.param_groups]
+                lr = sum(lrl) / len(lrl)
+
+                if args.distributed:
+                    reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
+                    losses_m.update(reduced_loss.item() * iters_to_accum, input.size(0))
+
+                if utils.is_primary(args):
+                    _logger.info(
+                        'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
+                        'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
+                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
+                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
+                        'LR: {lr:.3e} '
+                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
+                            epoch,
+                            batch_idx, len(loader),
+                            100. * batch_idx / last_idx,
+                            loss=losses_m,
+                            batch_time=batch_time_m,
+                            rate=num_step_samples * args.world_size / batch_time_m.val,
+                            rate_avg=num_step_samples * args.world_size / batch_time_m.avg,
+                            lr=lr,
+                            data_time=data_time_m)
+                    )
+
+                if args.save_images and output_dir:
+                    torchvision.utils.save_image(
+                        input,
+                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
+                        padding=0,
+                        normalize=True
+                    )
+
+            if saver is not None and args.recovery_interval and (
+                    (batch_idx // args.iters_to_accum + 1) % args.recovery_interval == 0):
+                saver.save_recovery(epoch, batch_idx=batch_idx)
+
+            if lr_scheduler is not None:
+                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
+
+            num_step_samples = 0
+            end = time.time()
         # end for

     if hasattr(optimizer, 'sync_lookahead'):
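With the change in place, an invocation along these lines (dataset path and model name are placeholders) keeps the effective batch size at 32 × 4 = 128 while synchronizing gradients only every fourth iteration:

python train.py /data/imagenet --model resnet50 -b 32 --iters-to-accum 4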