EMA FP32 assert classification bug fix (#9016)

* Return EMA float on classification val

* verbose val fix

* EMA check
Glenn Jocher 2022-08-18 14:06:15 +02:00 committed by GitHub
parent 529aafd737
commit 20049be2e7
5 changed files with 14 additions and 11 deletions

classify/val.py

@@ -116,7 +116,7 @@ def run(
     if verbose:  # all classes
         LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}")
         LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}")
-        for i, c in enumerate(model.names):
+        for i, c in model.names.items():
             aci = acc[targets == i]
             top1i, top5i = aci.mean(0).tolist()
             LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")
@@ -127,6 +127,7 @@ def run(
     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
     LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
     model.float()  # for training
+    return top1, top5, loss
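The verbose fix follows from `model.names` now being a dict keyed by class index (see the `attempt_load` change below): `enumerate(model.names)` would pair a counter with the dict's keys rather than its name values. A minimal sketch of the per-class loop, with hypothetical `names`, `targets`, and `acc` values standing in for the real val outputs:

```python
import torch

# Hypothetical stand-ins: 8 images, 3 classes; acc[i] = (top1_correct, top5_correct)
names = {0: 'cat', 1: 'dog', 2: 'bird'}           # dict names, as attempt_load now guarantees
targets = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2])  # ground-truth class per image
acc = torch.tensor([[1, 1], [0, 1], [1, 1], [1, 1], [0, 0], [1, 1], [0, 1], [1, 1]]).float()

top1, top5 = acc.mean(0).tolist()                 # accuracies over all images
for i, c in names.items():                        # dict iteration, not enumerate()
    aci = acc[targets == i]                       # rows belonging to class i
    top1i, top5i = aci.mean(0).tolist()
    print(f'{c:>8}{aci.shape[0]:>4}{top1i:>8.3g}{top5i:>8.3g}')
```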

export.py

@@ -599,7 +599,7 @@ def parse_opt():
     parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
     parser.add_argument('--include',
                         nargs='+',
-                        default=['torchscript', 'onnx'],
+                        default=['torchscript'],
                         help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
     opt = parser.parse_args()
     print_args(vars(opt))
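The narrowed default still accepts multiple formats thanks to `nargs='+'`. A standalone sketch of just this flag, not the full `parse_opt()` parser:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--include',
                    nargs='+',
                    default=['torchscript'],
                    help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')

print(parser.parse_args([]).include)                               # ['torchscript'] (new default)
print(parser.parse_args(['--include', 'onnx', 'tflite']).include)  # ['onnx', 'tflite']
```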

models/experimental.py

@@ -8,7 +8,6 @@ import numpy as np
 import numpy as np
 import torch
 import torch.nn as nn
-from models.common import Conv
 from utils.downloads import attempt_download
@@ -79,11 +78,16 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
     for w in weights if isinstance(weights, list) else [weights]:
         ckpt = torch.load(attempt_download(w), map_location='cpu')  # load
         ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
+
+        # Model compatibility updates
         if not hasattr(ckpt, 'stride'):
-            ckpt.stride = torch.tensor([32.])  # compatibility update for ResNet etc.
+            ckpt.stride = torch.tensor([32.])
+        if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
+            ckpt.names = dict(enumerate(ckpt.names))  # convert to dict
+
         model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval())  # model in eval mode
 
-    # Compatibility updates
+    # Module compatibility updates
     for m in model.modules():
         t = type(m)
         if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
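The list-to-dict conversion keeps older checkpoints, whose `names` attribute is still a plain list, compatible with the new dict-based lookups. A minimal sketch of the shim on a hypothetical legacy checkpoint object:

```python
class Ckpt:  # hypothetical legacy checkpoint with list names
    names = ['person', 'bicycle', 'car']

ckpt = Ckpt()
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
    ckpt.names = dict(enumerate(ckpt.names))  # convert to dict

print(ckpt.names)     # {0: 'person', 1: 'bicycle', 2: 'car'}
print(ckpt.names[2])  # 'car', same integer indexing as before
```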

train.py

@@ -107,8 +107,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     data_dict = data_dict or check_dataset(data)  # check if None
     train_path, val_path = data_dict['train'], data_dict['val']
     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
-    names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
-    assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
+    names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt')  # COCO dataset
 
     # Model
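With dict names throughout, the single-class branch now yields `{0: 'item'}` instead of `['item']`, and the separate length assert is dropped. The two branches in isolation, using a hypothetical 3-class `data_dict['names']`:

```python
data_names = {0: 'person', 1: 'bicycle', 2: 'car'}  # hypothetical data_dict['names']
for single_cls in (False, True):
    nc = 1 if single_cls else len(data_names)
    names = {0: 'item'} if single_cls and len(data_names) != 1 else data_names
    print(single_cls, nc, names)
# False 3 {0: 'person', 1: 'bicycle', 2: 'car'}
# True 1 {0: 'item'}
```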

utils/torch_utils.py

@@ -408,8 +408,6 @@ class ModelEMA:
     def __init__(self, model, decay=0.9999, tau=2000, updates=0):
         # Create EMA
         self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
-        # if next(model.parameters()).device.type != 'cpu':
-        #     self.ema.half()  # FP16 EMA
         self.updates = updates  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
         for p in self.ema.parameters():
@@ -423,9 +421,10 @@ class ModelEMA:
         msd = de_parallel(model).state_dict()  # model state_dict
         for k, v in self.ema.state_dict().items():
-            if v.dtype.is_floating_point:
+            if v.dtype.is_floating_point:  # true for FP16 and FP32
                 v *= d
-                v += (1 - d) * msd[k].detach()
+                v += (1 - d) * msd[k]
+        assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k].dtype} must be updated in FP32'
 
     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
         # Update EMA attributes
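The new assert pins the invariant the commit title names: with the commented FP16 path removed, both the EMA copy and the model weights must stay FP32 through the update `v = d*v + (1 - d)*msd[k]`. A self-contained sketch of one EMA step on a hypothetical `nn.Linear` model, reusing the decay ramp from `__init__`:

```python
import math
from copy import deepcopy

import torch
import torch.nn as nn

model = nn.Linear(4, 2)               # hypothetical training model
ema = deepcopy(model).eval().float()  # FP32 EMA copy
for p in ema.parameters():
    p.requires_grad_(False)

decay, tau, updates = 0.9999, 2000, 100
d = decay * (1 - math.exp(-updates / tau))  # ramped decay, small in early epochs

msd = model.state_dict()
for k, v in ema.state_dict().items():
    if v.dtype.is_floating_point:  # true for FP16 and FP32
        v *= d                     # in-place: updates the EMA model's own tensors
        v += (1 - d) * msd[k]
assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k].dtype} must be updated in FP32'
print(f'decay at {updates} updates: {d:.4f}')  # ~0.0488, far below the 0.9999 asymptote
```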