diff --git a/test.py b/test.py
index ecd45f5f4..9f484c809 100644
--- a/test.py
+++ b/test.py
@@ -272,7 +272,6 @@ def test(data,
     if not training:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
         print(f"Results saved to {save_dir}{s}")
-    model.float()  # for training
     maps = np.zeros(nc) + map
     for i, c in enumerate(ap_class):
         maps[c] = ap[i]
diff --git a/train.py b/train.py
index 8533667fe..7aa57fa99 100644
--- a/train.py
+++ b/train.py
@@ -31,7 +31,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
 from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
-from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
+from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
 
 logger = logging.getLogger(__name__)
 
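`is_parallel` is imported for the checkpoint change below, where the raw module is unwrapped from its DP/DDP wrapper before saving. For reference, its definition in utils/torch_utils.py is essentially:

```python
import torch.nn as nn

def is_parallel(model):
    # True if the model is wrapped in DataParallel or DistributedDataParallel
    return type(model) in (nn.parallel.DataParallel,
                           nn.parallel.DistributedDataParallel)
```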
@@ -136,6 +136,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                                id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
     loggers = {'wandb': wandb}  # loggers dict
 
+    # EMA (created before the resume block so saved EMA state can be restored below)
+    ema = ModelEMA(model) if rank in [-1, 0] else None
+
     # Resume
     start_epoch, best_fitness = 0, 0.0
     if pretrained:
@@ -144,6 +147,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             optimizer.load_state_dict(ckpt['optimizer'])
             best_fitness = ckpt['best_fitness']
 
+        # EMA
+        if ema and ckpt.get('ema'):
+            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())  # stored FP16 -> FP32
+            ema.updates = ckpt['ema'][1]  # exact update count, for the decay ramp
+
         # Results
         if ckpt.get('training_results') is not None:
             results_file.write_text(ckpt['training_results'])  # write results.txt
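The resume path now round-trips EMA state through the checkpoint instead of rebuilding it. The same logic isolated as a helper (a sketch, assuming the new save format where `ckpt['ema']` is a `(FP16 EMA model, update count)` tuple; the name `load_ema_state` is ours):

```python
def load_ema_state(ema, ckpt):
    """Restore a utils.torch_utils.ModelEMA from the new checkpoint format:
    ckpt['ema'] = (FP16 EMA nn.Module, int update count). Sketch only."""
    if ema and ckpt.get('ema'):  # older checkpoints have no 'ema' key
        # the EMA model was halved for disk; cast back to FP32 before copying
        ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
        ema.updates = ckpt['ema'][1]
```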
@@ -173,9 +181,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
         logger.info('Using SyncBatchNorm()')
 
-    # EMA
-    ema = ModelEMA(model) if rank in [-1, 0] else None
-
     # DDP mode
     if cuda and rank != -1:
         model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
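Moving the ModelEMA construction above the resume block (and above SyncBN/DDP conversion) is what makes the new restore possible: `ema` must exist before `ckpt['ema']` is read. It also means the EMA is initialized from the bare module rather than a wrapper. A condensed stand-in for the new ordering (DataParallel used here since DDP needs a process group):

```python
import torch.nn as nn
from copy import deepcopy

model = nn.Linear(8, 8)           # 1. build the bare model
ema = deepcopy(model).eval()      # 2. EMA copy taken from the bare module
# 3. ...resume restores optimizer/EMA/epoch state here...
wrapped = nn.DataParallel(model)  # 4. parallel wrapping happens last
assert wrapped.module is model    # the bare module stays reachable via .module
```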
@@ -191,7 +196,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
 
     # Process 0
     if rank in [-1, 0]:
-        ema.updates = start_epoch * nb // accumulate  # set EMA updates
         testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                        world_size=opt.world_size, workers=opt.workers,
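The `ema.updates = start_epoch * nb // accumulate` estimate is dropped because the exact counter is now restored from `ckpt['ema'][1]`; the estimate could drift whenever `nb` or `accumulate` differed from the original run. The counter matters because ModelEMA ramps its decay with the number of updates, roughly as below (0.9999 and 2000 match its defaults in utils/torch_utils.py, to the best of our reading):

```python
import math

def ema_decay(updates, decay=0.9999, tau=2000):
    # ModelEMA's warm-up schedule (approximately): starts at 0, so a fresh
    # EMA copies the model outright, then asymptotes toward `decay`.
    return decay * (1 - math.exp(-updates / tau))

print(ema_decay(0))       # 0.0
print(ema_decay(10_000))  # ~0.9932
```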
@@ -335,8 +339,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         # DDP process 0 or single-GPU
         if rank in [-1, 0]:
             # mAP
-            if ema:
-                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
+            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])  # ema always exists for rank -1/0
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
                 results, maps, times = test.test(opt.data,
@@ -378,8 +381,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': ema.ema,
-                        'optimizer': None if final_epoch else optimizer.state_dict(),
+                        'model': (model.module if is_parallel(model) else model).half(),  # FP16, de-parallelized; cast in place
+                        'ema': (ema.ema.half(), ema.updates),  # FP16 EMA weights + update count
+                        'optimizer': optimizer.state_dict(),  # kept every epoch; strip_optimizer() drops it later
                         'wandb_id': wandb_run.id if wandb else None}
 
                 # Save last, best and delete
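The checkpoint now carries both the raw (de-parallelized) model and the EMA, each halved to FP16 to shrink the file, plus the EMA update counter. The optimizer state is saved every epoch, including the final one; strip_optimizer() below is now responsible for dropping it from release weights. Loading one of these checkpoints looks like this (path hypothetical):

```python
import torch

ckpt = torch.load('last.pt', map_location='cpu')  # hypothetical path
model_fp16 = ckpt['model']            # bare nn.Module in FP16
ema_fp16, ema_updates = ckpt['ema']   # (nn.Module in FP16, int)
optimizer_state = ckpt['optimizer']   # now present even at the final epoch
```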
@@ -387,6 +391,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 if best_fitness == fi:
                     torch.save(ckpt, best)
                 del ckpt
+
+            model.float(); ema.ema.float()  # half() above cast in place; restore FP32 for the next epoch
+
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
 
diff --git a/utils/general.py b/utils/general.py
index 3b5f4629b..e5bbc50c6 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -484,8 +484,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
 def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
     # Strip optimizer from 'f' to finalize training, optionally save as 's'
     x = torch.load(f, map_location=torch.device('cpu'))
-    for key in 'optimizer', 'training_results', 'wandb_id':
-        x[key] = None
+    for k in 'optimizer', 'training_results', 'wandb_id', 'ema':  # keys
+        x[k] = None
     x['epoch'] = -1
     x['model'].half()  # to FP16
     for p in x['model'].parameters():
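strip_optimizer() adds 'ema' to its list of training-only keys, so finalized weights no longer ship a second full copy of the network. Typical use after training (second path hypothetical):

```python
from utils.general import strip_optimizer

strip_optimizer('weights/best.pt')  # strip in place (the default path)
strip_optimizer('weights/best.pt', 'weights/best_release.pt')  # save stripped copy
```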