Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Simplifying EMA...

commit 9214ca0716
parent 80cd31f21f
timm/utils/model.py
@@ -6,10 +6,7 @@ from .model_ema import ModelEma
 
 
 def unwrap_model(model):
-    if isinstance(model, ModelEma):
-        return unwrap_model(model.ema)
-    else:
-        return model.module if hasattr(model, 'module') else model
+    return model.module if hasattr(model, 'module') else model
 
 
 def get_state_dict(model, unwrap_fn=unwrap_model):
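With ModelEma reworked below to hold its averaged copy in a `module` attribute (like DataParallel and DistributedDataParallel), the single `hasattr(model, 'module')` check is enough and the recursive EMA branch can go. A minimal sketch of the simplified helper in isolation; the toy model below is illustrative and not part of the commit:

import torch.nn as nn

def unwrap_model(model):
    # wrappers (DataParallel, DDP, the reworked ModelEma) expose the
    # underlying network as `.module`; plain models pass through untouched
    return model.module if hasattr(model, 'module') else model

net = nn.Linear(4, 2)                 # toy model
wrapped = nn.DataParallel(net)        # stores `net` as `wrapped.module`

assert unwrap_model(net) is net       # no wrapper: identity
assert unwrap_model(wrapped) is net   # wrapper stripped via `.module`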
timm/utils/model_ema.py
@@ -2,16 +2,13 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
-import logging
-from collections import OrderedDict
 from copy import deepcopy
 
 import torch
-
-_logger = logging.getLogger(__name__)
+import torch.nn as nn
 
 
-class ModelEma:
+class ModelEma(nn.Module):
     """ Model Exponential Moving Average
     Keep a moving average of everything in the model state_dict (parameters and buffers).
 
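Subclassing nn.Module is what makes the rest of the simplification possible: the EMA wrapper now participates in the standard state_dict machinery, so the hand-rolled checkpoint remapping removed in the next hunk (and with it the logging and OrderedDict imports above) is no longer needed. A toy sketch of the round trip, assuming nothing beyond plain PyTorch; TinyEma is an illustrative stand-in, not the real class:

from copy import deepcopy
import torch.nn as nn

class TinyEma(nn.Module):
    def __init__(self, model):
        super().__init__()
        # the averaged copy is registered as a submodule named 'module'
        self.module = deepcopy(model).eval()

ema = TinyEma(nn.Linear(4, 2))
state = ema.state_dict()        # keys come out as 'module.weight', 'module.bias'
ema.load_state_dict(state)      # plain nn.Module round trip, no key remapping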
@@ -32,46 +29,20 @@ class ModelEma:
     GPU assignment and distributed training wrappers.
     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """
-    def __init__(self, model, decay=0.9999, device='', resume=''):
+    def __init__(self, model, decay=0.9999, device=None):
+        super(ModelEma, self).__init__()
         # make a copy of the model for accumulating moving average of weights
-        self.ema = deepcopy(model)
-        self.ema.eval()
+        self.module = deepcopy(model)
+        self.module.eval()
         self.decay = decay
         self.device = device  # perform ema on different device from model if set
-        if device:
-            self.ema.to(device=device)
-        self.ema_has_module = hasattr(self.ema, 'module')
-        if resume:
-            self._load_checkpoint(resume)
-        for p in self.ema.parameters():
-            p.requires_grad_(False)
-
-    def _load_checkpoint(self, checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
-        assert isinstance(checkpoint, dict)
-        if 'state_dict_ema' in checkpoint:
-            new_state_dict = OrderedDict()
-            for k, v in checkpoint['state_dict_ema'].items():
-                # ema model may have been wrapped by DataParallel, and need module prefix
-                if self.ema_has_module:
-                    name = 'module.' + k if not k.startswith('module') else k
-                else:
-                    name = k
-                new_state_dict[name] = v
-            self.ema.load_state_dict(new_state_dict)
-            _logger.info("Loaded state_dict_ema")
-        else:
-            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
+        if device is not None:
+            self.module.to(device=device)
 
     def update(self, model):
-        # correct a mismatch in state dict keys
-        needs_module = hasattr(model, 'module') and not self.ema_has_module
         with torch.no_grad():
-            msd = model.state_dict()
-            for k, ema_v in self.ema.state_dict().items():
-                if needs_module:
-                    k = 'module.' + k
-                model_v = msd[k].detach()
+            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                assert ema_v.shape == model_v.shape
                 if self.device:
                     model_v = model_v.to(device=self.device)
                 ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
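The new update() walks the two state_dicts in lockstep and applies the usual exponential moving average, ema = decay * ema + (1 - decay) * model, to every parameter and buffer. A hedged usage sketch with a toy model and synthetic data; only the ModelEma constructor, update() and the `module` attribute come from the diff above, and the import path is assumed:

import torch
import torch.nn as nn
from timm.utils import ModelEma   # assumed import path for the class above

model = nn.Linear(8, 1)
ema = ModelEma(model, decay=0.9999)                 # deep copy of model, kept in eval mode
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(10):
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)                               # ema_v = d * ema_v + (1 - d) * model_v

eval_model = ema.module                             # averaged weights live here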
train.py (2 changed lines)
@@ -568,7 +568,7 @@ def main():
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
             ema_eval_metrics = validate(
-                model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
+                model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
             eval_metrics = ema_eval_metrics
 
         if lr_scheduler is not None:
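On the training side, the only call-site change is the attribute that names the averaged copy: validation of the EMA weights now goes through model_ema.module instead of model_ema.ema. A small, self-contained sketch of that evaluation pattern with an illustrative accuracy metric and loader; everything here except the `.module` attribute is an assumption, not taken from train.py:

import torch

def evaluate(module, loader):
    # evaluate the (already eval-mode) EMA copy without touching its weights
    module.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for x, y in loader:
            correct += (module(x).argmax(dim=1) == y).sum().item()
            total += y.numel()
    return correct / max(total, 1)

# ema_acc = evaluate(model_ema.module, loader_eval)   # EMA weights, as in the hunk above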