mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add gradient accumulation option to train.py
Option: --iters-to-accum (iterations to accumulate). Gradient accumulation improves training throughput (samples/s): it reduces the number of gradient synchronizations between nodes, which helps when the network is the bottleneck.

Signed-off-by: Taeksang Kim <voidbag@puzzle-ai.com>
parent 7a13be67a5
commit 7f29a46d44
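For context, a minimal sketch of the pattern this commit implements (illustrative, not part of the diff; `loader`, `model`, `optimizer`, and `loss_fn` are assumed placeholders): scale each micro-batch loss down by the accumulation count, defer optimizer.step() until the window closes, and suppress DDP gradient all-reduce on intermediate backwards via no_sync().

    import contextlib

    iters_to_accum = 4  # accumulate gradients over 4 micro-batches per optimizer step
    optimizer.zero_grad()
    for i, (x, y) in enumerate(loader):
        need_step = (i + 1) % iters_to_accum == 0
        # On intermediate iterations, skip DDP's gradient all-reduce entirely.
        ctx = model.no_sync() if not need_step and hasattr(model, 'no_sync') else contextlib.nullcontext()
        with ctx:
            loss = loss_fn(model(x), y) / iters_to_accum  # keep gradient scale equal to one full batch
            loss.backward()  # gradients accumulate in param.grad across iterations
        if need_step:
            optimizer.step()
            optimizer.zero_grad()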
timm/utils/cuda.py

@@ -17,12 +17,13 @@ from .clip_grad import dispatch_clip_grad

 class ApexScaler:
     state_dict_key = "amp"

-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, need_step=True):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
         if clip_grad is not None:
             dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
-        optimizer.step()
+        if need_step:
+            optimizer.step()

     def state_dict(self):
         if 'state_dict' in amp.__dict__:
@@ -39,14 +40,15 @@ class NativeScaler:

     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()

-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, need_step=True):
         self._scaler.scale(loss).backward(create_graph=create_graph)
         if clip_grad is not None:
             assert parameters is not None
             self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
             dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
-        self._scaler.step(optimizer)
-        self._scaler.update()
+        if need_step:
+            self._scaler.step(optimizer)
+            self._scaler.update()

     def state_dict(self):
         return self._scaler.state_dict()
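A usage sketch of the new flag (illustrative only; `loader`, `model`, `optimizer`, and `loss_fn` are assumed to exist): backward() runs on every call, while step() and update() fire only when need_step is True, so gradients accumulate in place between updates.

    scaler = NativeScaler()
    iters_to_accum = 4
    optimizer.zero_grad()
    for i, (x, y) in enumerate(loader):
        need_step = (i + 1) % iters_to_accum == 0
        loss = loss_fn(model(x), y) / iters_to_accum
        # backward always; step/update only at window end, so the GradScaler's
        # scale factor stays constant within the accumulation window
        scaler(loss, optimizer, need_step=need_step)
        if need_step:
            optimizer.zero_grad()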
train.py (176 changed lines)
@@ -130,6 +130,8 @@ group.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
 group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='Input batch size for training (default: 128)')
+group.add_argument('--iters-to-accum', type=int, default=1, metavar='N',
+                   help='The number of iterations to accumulate gradients (default: 1)')
 group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                    help='Validation batch size override (default: None)')
 group.add_argument('--channels-last', action='store_true', default=False,
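With this option, each optimizer step sees gradients from batch_size × iters_to_accum samples per process: for example (hypothetical invocation), `-b 32 --iters-to-accum 4` approximates the optimization behavior of `-b 128` at roughly a quarter of the activation memory, though batch-norm statistics are still computed over 32-sample batches.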
@@ -399,6 +401,9 @@ def main():
     if args.amp_dtype == 'bfloat16':
         amp_dtype = torch.bfloat16

+    # check if iters_to_accum is smaller than or equal to 0
+    assert args.iters_to_accum > 0, 'The argument "iters-to-accum" must be greater than zero.'
+
     utils.random_seed(args.seed, args.rank)

     if args.fuser:
@@ -851,11 +856,23 @@ def train_one_epoch(
     model.train()

     end = time.time()
-    num_batches_per_epoch = len(loader)
-    last_idx = num_batches_per_epoch - 1
+    num_batches_per_epoch = (len(loader) + args.iters_to_accum - 1) // args.iters_to_accum
+    last_idx = len(loader) - 1
+    last_iters_to_accum = len(loader) % args.iters_to_accum
+    last_idx_to_accum = len(loader) - last_iters_to_accum
     num_updates = epoch * num_batches_per_epoch
+
+    optimizer.zero_grad()
+    num_step_samples = 0
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
+        iters_to_accum = args.iters_to_accum
+        if batch_idx >= last_idx_to_accum:
+            iters_to_accum = last_iters_to_accum
+        need_step = False
+        if (batch_idx + 1) % args.iters_to_accum == 0 or last_batch:
+            need_step = True
+
         data_time_m.update(time.time() - end)
         if not args.prefetcher:
             input, target = input.to(device), target.to(device)
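A worked example of the new bookkeeping, under assumed values len(loader) = 10 and --iters-to-accum 4 (illustrative, not part of the diff):

    loader_len, iters_to_accum = 10, 4
    num_batches_per_epoch = (loader_len + iters_to_accum - 1) // iters_to_accum  # ceil(10/4) = 3 updates/epoch
    last_iters_to_accum = loader_len % iters_to_accum     # 2: the final window is short
    last_idx_to_accum = loader_len - last_iters_to_accum  # 8: batches 8-9 divide loss by 2, not 4
    assert (num_batches_per_epoch, last_iters_to_accum, last_idx_to_accum) == (3, 2, 8)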
@@ -864,82 +881,101 @@ def train_one_epoch(
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)

-        with amp_autocast():
-            output = model(input)
-            loss = loss_fn(output, target)
+        def _forward():
+            with amp_autocast():
+                output = model(input)
+                loss = loss_fn(output, target)
+            loss /= iters_to_accum
+            return loss
+
+        if need_step is not True and hasattr(model, "no_sync"):
+            with model.no_sync():
+                loss = _forward()
+        else:
+            loss = _forward()

         if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(loss.item() * iters_to_accum, input.size(0))

-        optimizer.zero_grad()
-        if loss_scaler is not None:
-            loss_scaler(
-                loss, optimizer,
-                clip_grad=args.clip_grad,
-                clip_mode=args.clip_mode,
-                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order
-            )
-        else:
-            loss.backward(create_graph=second_order)
-            if args.clip_grad is not None:
-                utils.dispatch_clip_grad(
-                    model_parameters(model, exclude_head='agc' in args.clip_mode),
-                    value=args.clip_grad,
-                    mode=args.clip_mode
-                )
-            optimizer.step()
+        def _backward():
+            if loss_scaler is not None:
+                loss_scaler(
+                    loss, optimizer,
+                    clip_grad=args.clip_grad,
+                    clip_mode=args.clip_mode,
+                    parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
+                    create_graph=second_order,
+                    need_step=need_step
+                )
+            else:
+                loss.backward(create_graph=second_order)
+                if args.clip_grad is not None:
+                    utils.dispatch_clip_grad(
+                        model_parameters(model, exclude_head='agc' in args.clip_mode),
+                        value=args.clip_grad,
+                        mode=args.clip_mode
+                    )
+                if need_step:
+                    optimizer.step()

-        if model_ema is not None:
-            model_ema.update(model)
+        num_step_samples += input.size(0)
+        if need_step is not True and hasattr(model, "no_sync"):
+            with model.no_sync():
+                _backward()
+        else:
+            _backward()

-        torch.cuda.synchronize()
-        num_updates += 1
-        batch_time_m.update(time.time() - end)
-        if last_batch or batch_idx % args.log_interval == 0:
-            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
-            lr = sum(lrl) / len(lrl)
-
-            if args.distributed:
-                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-                losses_m.update(reduced_loss.item(), input.size(0))
-
-            if utils.is_primary(args):
-                _logger.info(
-                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
-                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
-                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
-                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
-                    'LR: {lr:.3e} '
-                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
-                        epoch,
-                        batch_idx, len(loader),
-                        100. * batch_idx / last_idx,
-                        loss=losses_m,
-                        batch_time=batch_time_m,
-                        rate=input.size(0) * args.world_size / batch_time_m.val,
-                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
-                        lr=lr,
-                        data_time=data_time_m)
-                )
-
-                if args.save_images and output_dir:
-                    torchvision.utils.save_image(
-                        input,
-                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
-                        padding=0,
-                        normalize=True
-                    )
-
-        if saver is not None and args.recovery_interval and (
-                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
-            saver.save_recovery(epoch, batch_idx=batch_idx)
-
-        if lr_scheduler is not None:
-            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
-
-        end = time.time()
+        if need_step:
+            optimizer.zero_grad()
+            if model_ema is not None:
+                model_ema.update(model)
+
+            torch.cuda.synchronize()
+            num_updates += 1
+            batch_time_m.update(time.time() - end)
+            if (batch_idx // args.iters_to_accum) % args.log_interval == 0:
+                lrl = [param_group['lr'] for param_group in optimizer.param_groups]
+                lr = sum(lrl) / len(lrl)
+
+                if args.distributed:
+                    reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
+                    losses_m.update(reduced_loss.item() * iters_to_accum, input.size(0))
+
+                if utils.is_primary(args):
+                    _logger.info(
+                        'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
+                        'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
+                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
+                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
+                        'LR: {lr:.3e} '
+                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
+                            epoch,
+                            batch_idx, len(loader),
+                            100. * batch_idx / last_idx,
+                            loss=losses_m,
+                            batch_time=batch_time_m,
+                            rate=num_step_samples * args.world_size / batch_time_m.val,
+                            rate_avg=num_step_samples * args.world_size / batch_time_m.avg,
+                            lr=lr,
+                            data_time=data_time_m)
+                    )
+
+                    if args.save_images and output_dir:
+                        torchvision.utils.save_image(
+                            input,
+                            os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
+                            padding=0,
+                            normalize=True
+                        )
+
+            if saver is not None and args.recovery_interval and (
+                    (batch_idx // args.iters_to_accum + 1) % args.recovery_interval == 0):
+                saver.save_recovery(epoch, batch_idx=batch_idx)
+
+            if lr_scheduler is not None:
+                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
+            num_step_samples = 0
+
+            end = time.time()
     # end for

     if hasattr(optimizer, 'sync_lookahead'):
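A note on the no_sync handling above: DistributedDataParallel decides during the forward pass whether the following backward will all-reduce gradients, so both _forward() and _backward() run under model.no_sync() on intermediate iterations; only the final micro-batch of each window pays for the cross-node reduction, which is the bandwidth saving the commit message describes. The hasattr(model, "no_sync") guard keeps the same code path working for non-distributed models.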