diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py
index 968f4f58..3e491675 100644
--- a/timm/utils/model_ema.py
+++ b/timm/utils/model_ema.py
@@ -126,6 +126,9 @@ class ModelEmaV2(nn.Module):
     def set(self, model):
         self._update(model, update_fn=lambda e, m: m)
 
+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
+
 
 class ModelEmaV3(nn.Module):
     """ Model Exponential Moving Average V3
@@ -133,6 +136,13 @@ class ModelEmaV3(nn.Module):
     Keep a moving average of everything in the model state_dict (parameters and buffers).
     V3 of this module leverages for_each and in-place operations for faster performance.
 
+    Decay warmup based on code by @crowsonkb, her comments:
+      If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
+      good values for models you plan to train for a million or more steps (reaches decay
+      factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
+      you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
+      215.4k steps).
+
     This is intended to allow functionality like
     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
 
@@ -195,49 +205,56 @@ class ModelEmaV3(nn.Module):
     @torch.no_grad()
     def update(self, model, step: Optional[int] = None):
         decay = self.get_decay(step)
-        if self.exclude_buffers:
-            # interpolate parameters
-            ema_params = tuple(self.module.parameters())
-            model_params = tuple(model.parameters())
-            if self.foreach:
-                if hasattr(torch, '_foreach_lerp_'):
-                    torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
-                else:
-                    torch._foreach_mul_(ema_params, scalar=decay)
-                    torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
-            else:
-                for ema_p, model_p in zip(ema_params, model_params):
-                    ema_p.lerp_(model_p, weight=1. - decay)
-
-            # copy buffers instead of EMA
-            for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
-                ema_b.copy_(model_b.to(device=self.device))
+        if self.exclude_buffers:
+            self.apply_update_no_buffers_(model, decay)
         else:
-            # interpolate parameters and buffers
-            if self.foreach:
-                ema_lerp_values = []
-                model_lerp_values = []
-                for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
-                    if ema_v.is_floating_point():
-                        ema_lerp_values.append(ema_v)
-                        model_lerp_values.append(model_v)
-                    else:
-                        ema_v.copy_(model_v)
-
-                if hasattr(torch, '_foreach_lerp_'):
-                    torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay)
-                else:
-                    torch._foreach_mul_(ema_lerp_values, scalar=decay)
-                    torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay)
-            else:
-                for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
-                    if ema_v.is_floating_point():
-                        ema_v.lerp_(model_v, weight=1. - decay)
-                    else:
-                        ema_v.copy_(model_v)
+            self.apply_update_(model, decay)
+
+    def apply_update_(self, model, decay: float):
+        # interpolate parameters and buffers
+        if self.foreach:
+            ema_lerp_values = []
+            model_lerp_values = []
+            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                if ema_v.is_floating_point():
+                    ema_lerp_values.append(ema_v)
+                    model_lerp_values.append(model_v)
+                else:
+                    ema_v.copy_(model_v)
+
+            if hasattr(torch, '_foreach_lerp_'):
+                torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay)
+            else:
+                torch._foreach_mul_(ema_lerp_values, scalar=decay)
+                torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay)
+        else:
+            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                if ema_v.is_floating_point():
+                    ema_v.lerp_(model_v, weight=1. - decay)
+                else:
+                    ema_v.copy_(model_v)
+
+    def apply_update_no_buffers_(self, model, decay: float):
+        # interpolate parameters, copy buffers
+        ema_params = tuple(self.module.parameters())
+        model_params = tuple(model.parameters())
+        if self.foreach:
+            if hasattr(torch, '_foreach_lerp_'):
+                torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
+            else:
+                torch._foreach_mul_(ema_params, scalar=decay)
+                torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
+        else:
+            for ema_p, model_p in zip(ema_params, model_params):
+                ema_p.lerp_(model_p, weight=1. - decay)
+
+        for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
+            ema_b.copy_(model_b.to(device=self.device))
 
     @torch.no_grad()
     def set(self, model):
         for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
             ema_v.copy_(model_v.to(device=self.device))
+
+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
\ No newline at end of file
diff --git a/train.py b/train.py
index a773d855..5e027229 100755
--- a/train.py
+++ b/train.py
@@ -349,11 +349,13 @@ group.add_argument('--split-bn', action='store_true',
 # Model Exponential Moving Average
 group = parser.add_argument_group('Model exponential moving average parameters')
 group.add_argument('--model-ema', action='store_true', default=False,
-                   help='Enable tracking moving average of model weights')
+                   help='Enable tracking moving average of model weights.')
 group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
 group.add_argument('--model-ema-decay', type=float, default=0.9998,
-                   help='decay factor for model weights moving average (default: 0.9998)')
+                   help='Decay factor for model weights moving average (default: 0.9998)')
+group.add_argument('--model-ema-warmup', action='store_true',
+                   help='Enable warmup for model EMA decay.')
 
 # Misc
 group = parser.add_argument_group('Miscellaneous parameters')
@@ -601,11 +603,13 @@ def main():
         model_ema = utils.ModelEmaV3(
             model,
             decay=args.model_ema_decay,
-            use_warmup=True,
+            use_warmup=args.model_ema_warmup,
             device='cpu' if args.model_ema_force_cpu else None,
         )
         if args.resume:
             load_checkpoint(model_ema.module, args.resume, use_ema=True)
+        if args.torchcompile:
+            model_ema = torch.compile(model_ema, backend=args.torchcompile)
 
     # setup distributed training
     if args.distributed:
@@ -885,7 +889,7 @@ def main():
                     utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
 
                 ema_eval_metrics = validate(
-                    model_ema.module,
+                    model_ema,
                     loader_eval,
                     validate_loss_fn,
                     args,
@@ -1002,7 +1006,7 @@ def train_one_epoch(
 
             if num_updates / num_updates_total > 0.25:
                 with torch.no_grad():
-                    output_mesa = model_ema.module(input)
+                    output_mesa = model_ema(input)
 
                 # loss_mesa = torch.nn.functional.binary_cross_entropy_with_logits(
                 #     output,
@@ -1018,7 +1022,7 @@ def train_one_epoch(
                     (output_mesa / 5).log_softmax(-1).detach(),
                     log_target=True,
                     reduction='none').sum(-1).mean()
-                loss += 5 * loss_mesa
+                loss += 10 * loss_mesa
 
             if accum_steps > 1:
                 loss /= accum_steps
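
For reference, the decay warmup described in the new ModelEmaV3 docstring follows @crowsonkb's 1 - (1 + step / inv_gamma) ** -power schedule. The short sketch below is illustrative only: warmup_decay, its defaults, and the clamp are assumptions used to sanity-check the numbers quoted in the docstring, not the get_decay() implementation this diff touches.

# Hypothetical helper (not part of this diff): EMA decay warmup curve, assuming
# decay = 1 - (1 + step / inv_gamma) ** -power, clamped to at most max_decay.
def warmup_decay(step: int, inv_gamma: float = 1.0, power: float = 2 / 3, max_decay: float = 0.9999) -> float:
    decay = 1.0 - (1.0 + step / inv_gamma) ** -power
    return max(0.0, min(decay, max_decay))

# Sanity-check the docstring numbers:
#   power=2/3 -> ~0.999 at ~31.6K steps, ~0.9999 at 1M steps
#   power=3/4 -> ~0.999 at 10K steps, ~0.9999 at ~215.4K steps
for power, steps in ((2 / 3, (31_623, 1_000_000)), (3 / 4, (10_000, 215_443))):
    print(power, [round(warmup_decay(s, power=power), 4) for s in steps])

With use_warmup now driven by the new --model-ema-warmup flag, warmup is opt-in: left off, ModelEmaV3 applies the fixed --model-ema-decay; turned on, the effective decay ramps up roughly along this curve before settling at the configured value.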