forward() pass through for ema model, flag for ema warmup, comment about warmup

Ross Wightman 2024-02-03 16:24:45 -08:00
parent 5e4a4b2adc
commit bee0471f91
2 changed files with 63 additions and 42 deletions

@@ -126,6 +126,9 @@ class ModelEmaV2(nn.Module):
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
class ModelEmaV3(nn.Module):
""" Model Exponential Moving Average V3
@@ -133,6 +136,13 @@ class ModelEmaV3(nn.Module):
Keep a moving average of everything in the model state_dict (parameters and buffers).
V3 of this module leverages for_each and in-place operations for faster performance.
Decay warmup based on code by @crowsonkb, her comments:
If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
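
The warmup behaviour described in the new docstring follows @crowsonkb's schedule: the decay grows with the step count as 1 - (1 + step / inv_gamma) ** -power and is then capped at the configured maximum decay. A minimal sketch of that schedule (standalone, not the module's exact get_decay code; the function name and the 0.9999 cap are only for illustration):

def warmup_decay(step: int, inv_gamma: float = 1.0, power: float = 2 / 3, max_decay: float = 0.9999) -> float:
    # Decay starts at 0 and approaches 1 as training progresses, then is clamped at max_decay.
    decay = 1.0 - (1.0 + step / inv_gamma) ** -power
    return min(decay, max_decay)

# Reproduces the figures quoted in the docstring for inv_gamma=1, power=2/3:
# ~0.999 at 31.6K steps, ~0.9999 at 1M steps.
for step in (0, 1_000, 31_600, 1_000_000):
    print(f'{step:>9d}: {warmup_decay(step):.5f}')
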
@@ -195,25 +205,12 @@ class ModelEmaV3(nn.Module):
@torch.no_grad()
def update(self, model, step: Optional[int] = None):
decay = self.get_decay(step)
if self.exclude_buffers:
# interpolate parameters
ema_params = tuple(self.module.parameters())
model_params = tuple(model.parameters())
if self.foreach:
if hasattr(torch, '_foreach_lerp_'):
torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
self.apply_update_no_buffers_(model, decay)
else:
torch._foreach_mul_(ema_params, scalar=decay)
torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
else:
for ema_p, model_p in zip(ema_params, model_params):
ema_p.lerp_(model_p, weight=1. - decay)
self.apply_update_(model, decay)
# copy buffers instead of EMA
for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
ema_b.copy_(model_b.to(device=self.device))
else:
def apply_update_(self, model, decay: float):
# interpolate parameters and buffers
if self.foreach:
ema_lerp_values = []
@@ -237,7 +234,27 @@ class ModelEmaV3(nn.Module):
else:
ema_v.copy_(model_v)
def apply_update_no_buffers_(self, model, decay: float):
# interpolate parameters, copy buffers
ema_params = tuple(self.module.parameters())
model_params = tuple(model.parameters())
if self.foreach:
if hasattr(torch, '_foreach_lerp_'):
torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
else:
torch._foreach_mul_(ema_params, scalar=decay)
torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
else:
for ema_p, model_p in zip(ema_params, model_params):
ema_p.lerp_(model_p, weight=1. - decay)
for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
ema_b.copy_(model_b.to(device=self.device))
@torch.no_grad()
def set(self, model):
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
ema_v.copy_(model_v.to(device=self.device))
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
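
With forward() now passing straight through to the wrapped module, the EMA copy behaves like the model itself: it can be validated, fed inputs, or handed to torch.compile without reaching into .module. A minimal usage sketch (toy model, data and hyper-parameters are placeholders; the timm.utils import path is assumed):

import torch
import torch.nn as nn
from timm.utils import ModelEmaV3  # assumed import path

model = nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model_ema = ModelEmaV3(model, decay=0.9998, use_warmup=True)

for step in range(10):
    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model_ema.update(model, step=step)  # passing step enables the decay warmup schedule

# forward() pass-through: call the EMA wrapper directly instead of model_ema.module
with torch.no_grad():
    ema_logits = model_ema(torch.randn(8, 10))

This pass-through is what lets the train.py changes below hand model_ema straight to validate() and torch.compile.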

@@ -349,11 +349,13 @@ group.add_argument('--split-bn', action='store_true',
# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
help='Enable tracking moving average of model weights.')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
help='Decay factor for model weights moving average (default: 0.9998)')
group.add_argument('--model-ema-warmup', action='store_true',
help='Enable warmup for model EMA decay.')
# Misc
group = parser.add_argument_group('Miscellaneous parameters')
@@ -601,11 +603,13 @@ def main():
model_ema = utils.ModelEmaV3(
model,
decay=args.model_ema_decay,
use_warmup=True,
use_warmup=args.model_ema_warmup,
device='cpu' if args.model_ema_force_cpu else None,
)
if args.resume:
load_checkpoint(model_ema.module, args.resume, use_ema=True)
if args.torchcompile:
model_ema = torch.compile(model_ema, backend=args.torchcompile)
# setup distributed training
if args.distributed:
@@ -885,7 +889,7 @@ def main():
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema.module,
model_ema,
loader_eval,
validate_loss_fn,
args,
@@ -1002,7 +1006,7 @@ def train_one_epoch(
if num_updates / num_updates_total > 0.25:
with torch.no_grad():
output_mesa = model_ema.module(input)
output_mesa = model_ema(input)
# loss_mesa = torch.nn.functional.binary_cross_entropy_with_logits(
# output,
@@ -1018,7 +1022,7 @@
(output_mesa / 5).log_softmax(-1).detach(),
log_target=True,
reduction='none').sum(-1).mean()
loss += 5 * loss_mesa
loss += 10 * loss_mesa
if accum_steps > 1:
loss /= accum_steps
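
The last two hunks lean on the same pass-through: model_ema(input) replaces model_ema.module(input) for the experimental MESA-style consistency term, and its weight is raised from 5x to 10x. As a rough sketch of what that term computes, factored into a helper (the function name and the temperature argument are illustrative; the diff hard-codes a temperature of 5):

import torch
import torch.nn.functional as F

def mesa_loss(output: torch.Tensor, output_mesa: torch.Tensor, temperature: float = 5.0) -> torch.Tensor:
    # KL divergence between temperature-softened student logits and detached EMA (teacher) logits.
    return F.kl_div(
        (output / temperature).log_softmax(-1),
        (output_mesa / temperature).log_softmax(-1).detach(),
        log_target=True,
        reduction='none',
    ).sum(-1).mean()

# Inside train_one_epoch, once more than 25% of updates have elapsed:
#     with torch.no_grad():
#         output_mesa = model_ema(input)
#     loss += 10 * mesa_loss(output, output_mesa)
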