Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00
forward() pass through for ema model, flag for ema warmup, comment about warmup
parent 5e4a4b2adc
commit bee0471f91
timm/utils/model_ema.py

@@ -126,6 +126,9 @@ class ModelEmaV2(nn.Module):
     def set(self, model):
         self._update(model, update_fn=lambda e, m: m)
 
+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
+
 
 class ModelEmaV3(nn.Module):
     """ Model Exponential Moving Average V3
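The new pass-through makes the EMA wrapper callable like the model it shadows, instead of requiring callers to reach into .module. A minimal sketch of the calling convention, assuming timm is installed (the toy Linear stands in for a real model):

    import torch
    import torch.nn as nn
    from timm.utils import ModelEmaV2

    model = nn.Linear(8, 2)        # stand-in for a real timm model
    model_ema = ModelEmaV2(model)

    x = torch.randn(4, 8)
    out_old = model_ema.module(x)  # old style: reach into the wrapped module
    out_new = model_ema(x)         # new style: forward() passes straight through
    assert torch.allclose(out_old, out_new)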
@@ -133,6 +136,13 @@ class ModelEmaV3(nn.Module):
     Keep a moving average of everything in the model state_dict (parameters and buffers).
     V3 of this module leverages for_each and in-place operations for faster performance.
 
+    Decay warmup based on code by @crowsonkb, her comments:
+      If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
+      good values for models you plan to train for a million or more steps (reaches decay
+      factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
+      you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
+      215.4k steps).
+
     This is intended to allow functionality like
     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
 
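Those milestone numbers come from crowsonkb's warmup schedule, in which the effective decay ramps up with the step count instead of starting at its final value. A sketch of that schedule under the docstring's hyperparameters (the formula matches her EMA warmup code; the function name here is illustrative):

    def warmup_decay(step: int, inv_gamma: float = 1., power: float = 2 / 3, max_decay: float = 0.9999) -> float:
        # decay ramps from 0 toward 1 as step grows, capped at max_decay
        decay = 1. - (1. + step / inv_gamma) ** -power
        return min(decay, max_decay)

    # reproduces the docstring milestones for inv_gamma=1, power=2/3
    print(warmup_decay(31_600))     # ~0.999
    print(warmup_decay(1_000_000))  # ~0.9999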
@@ -195,25 +205,12 @@ class ModelEmaV3(nn.Module):
     @torch.no_grad()
     def update(self, model, step: Optional[int] = None):
         decay = self.get_decay(step)
-
         if self.exclude_buffers:
-            # interpolate parameters
-            ema_params = tuple(self.module.parameters())
-            model_params = tuple(model.parameters())
-            if self.foreach:
-                if hasattr(torch, '_foreach_lerp_'):
-                    torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
-                else:
-                    torch._foreach_mul_(ema_params, scalar=decay)
-                    torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
-            else:
-                for ema_p, model_p in zip(ema_params, model_params):
-                    ema_p.lerp_(model_p, weight=1. - decay)
-
-            # copy buffers instead of EMA
-            for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
-                ema_b.copy_(model_b.to(device=self.device))
+            self.apply_update_no_buffers_(model, decay)
         else:
+            self.apply_update_(model, decay)
+
+    def apply_update_(self, model, decay: float):
         # interpolate parameters and buffers
         if self.foreach:
             ema_lerp_values = []
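Both code paths compute the same update, ema = decay * ema + (1 - decay) * model; the lerp and foreach variants just fuse it into fewer kernels. A standalone check that the fast path and the fallback agree:

    import torch

    decay = 0.9998
    ema = torch.randn(10)
    model = torch.randn(10)

    via_lerp = ema.clone().lerp_(model, 1. - decay)                      # lerp fast path
    via_mul_add = ema.clone().mul_(decay).add_(model, alpha=1. - decay)  # mul/add fallback
    assert torch.allclose(via_lerp, via_mul_add)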
@@ -237,7 +234,27 @@ class ModelEmaV3(nn.Module):
             else:
                 ema_v.copy_(model_v)
 
+    def apply_update_no_buffers_(self, model, decay: float):
+        # interpolate parameters, copy buffers
+        ema_params = tuple(self.module.parameters())
+        model_params = tuple(model.parameters())
+        if self.foreach:
+            if hasattr(torch, '_foreach_lerp_'):
+                torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
+            else:
+                torch._foreach_mul_(ema_params, scalar=decay)
+                torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
+        else:
+            for ema_p, model_p in zip(ema_params, model_params):
+                ema_p.lerp_(model_p, weight=1. - decay)
+
+        for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
+            ema_b.copy_(model_b.to(device=self.device))
+
     @torch.no_grad()
     def set(self, model):
         for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
             ema_v.copy_(model_v.to(device=self.device))
+
+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
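In this exclude_buffers path only the parameters get the exponential average; buffers are copied verbatim from the live model (BatchNorm's num_batches_tracked, for instance, is an integer count that cannot be meaningfully lerped). A toy sketch of that split, assuming a current PyTorch:

    import torch
    import torch.nn as nn
    from copy import deepcopy

    model = nn.BatchNorm1d(4)
    ema = deepcopy(model)
    decay = 0.99

    with torch.no_grad():
        # parameters (weight, bias): exponential moving average
        for ema_p, model_p in zip(ema.parameters(), model.parameters()):
            ema_p.lerp_(model_p, 1. - decay)
        # buffers (running_mean, running_var, num_batches_tracked): direct copy
        for ema_b, model_b in zip(ema.buffers(), model.buffers()):
            ema_b.copy_(model_b)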
train.py (16 lines changed)
@@ -349,11 +349,13 @@ group.add_argument('--split-bn', action='store_true',
 # Model Exponential Moving Average
 group = parser.add_argument_group('Model exponential moving average parameters')
 group.add_argument('--model-ema', action='store_true', default=False,
-                   help='Enable tracking moving average of model weights')
+                   help='Enable tracking moving average of model weights.')
 group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
 group.add_argument('--model-ema-decay', type=float, default=0.9998,
-                   help='decay factor for model weights moving average (default: 0.9998)')
+                   help='Decay factor for model weights moving average (default: 0.9998)')
+group.add_argument('--model-ema-warmup', action='store_true',
+                   help='Enable warmup for model EMA decay.')
 
 # Misc
 group = parser.add_argument_group('Miscellaneous parameters')
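The warmup flag defaults to off, so existing training commands keep their behavior. A minimal sketch of the flag wiring (a fresh ArgumentParser for illustration; argument names match the hunk):

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('Model exponential moving average parameters')
    group.add_argument('--model-ema', action='store_true', default=False)
    group.add_argument('--model-ema-decay', type=float, default=0.9998)
    group.add_argument('--model-ema-warmup', action='store_true')

    # warmup stays off unless explicitly requested
    args = parser.parse_args(['--model-ema', '--model-ema-warmup'])
    assert args.model_ema and args.model_ema_warmup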
@@ -601,11 +603,13 @@ def main():
         model_ema = utils.ModelEmaV3(
             model,
             decay=args.model_ema_decay,
-            use_warmup=True,
+            use_warmup=args.model_ema_warmup,
             device='cpu' if args.model_ema_force_cpu else None,
         )
         if args.resume:
             load_checkpoint(model_ema.module, args.resume, use_ema=True)
+        if args.torchcompile:
+            model_ema = torch.compile(model_ema, backend=args.torchcompile)
 
     # setup distributed training
     if args.distributed:
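With the flag threaded through main() and forward() passing through, the EMA wrapper can be built, updated per step, and handed to torch.compile like an ordinary module. A usage sketch assuming timm is installed (model choice and loop are illustrative):

    import torch
    import timm
    from timm.utils import ModelEmaV3

    model = timm.create_model('resnet18', num_classes=10)
    model_ema = ModelEmaV3(model, decay=0.9998, use_warmup=True)

    for step in range(10):
        # ... optimizer step on `model` would go here ...
        model_ema.update(model, step=step)  # decay warms up with the step count

    # the wrapper is directly callable thanks to forward()
    out = model_ema(torch.randn(2, 3, 224, 224))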
@@ -885,7 +889,7 @@ def main():
                     utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
 
                 ema_eval_metrics = validate(
-                    model_ema.module,
+                    model_ema,
                     loader_eval,
                     validate_loss_fn,
                     args,
@@ -1002,7 +1006,7 @@ def train_one_epoch(
 
             if num_updates / num_updates_total > 0.25:
                 with torch.no_grad():
-                    output_mesa = model_ema.module(input)
+                    output_mesa = model_ema(input)
 
                 # loss_mesa = torch.nn.functional.binary_cross_entropy_with_logits(
                 #     output,
@@ -1018,7 +1022,7 @@ def train_one_epoch(
                     (output_mesa / 5).log_softmax(-1).detach(),
                     log_target=True,
                     reduction='none').sum(-1).mean()
-                loss += 5 * loss_mesa
+                loss += 10 * loss_mesa
 
             if accum_steps > 1:
                 loss /= accum_steps
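The two train_one_epoch hunks touch an experimental "mesa" regularizer: after 25% of training, the model is distilled against its own EMA with a temperature-softened KL term, and this commit raises its weight from 5 to 10 (and calls model_ema directly via the new pass-through). A hedged reconstruction of that term with dummy logits; the temperature, log_target, and reduction come from the visible lines, the rest is assumption:

    import torch
    import torch.nn.functional as F

    output = torch.randn(8, 1000, requires_grad=True)  # live model logits (dummy)
    output_mesa = torch.randn(8, 1000)                 # EMA model logits (dummy)

    T = 5  # softening temperature from the visible lines
    loss_mesa = F.kl_div(
        (output / T).log_softmax(-1),
        (output_mesa / T).log_softmax(-1).detach(),  # EMA targets carry no gradient
        log_target=True,
        reduction='none',
    ).sum(-1).mean()
    loss = 10 * loss_mesa  # weight raised from 5 to 10 in this commit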