EMA FP32 assert classification bug fix (#9016)

* Return EMA float on classification val

* verbose val fix

* EMA check
This commit is contained in:
Glenn Jocher 2022-08-18 14:06:15 +02:00 committed by GitHub
parent 529aafd737
commit 20049be2e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 14 additions and 11 deletions

View File

@ -116,7 +116,7 @@ def run(
if verbose: # all classes if verbose: # all classes
LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}") LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}")
LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}") LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}")
for i, c in enumerate(model.names): for i, c in model.names.items():
aci = acc[targets == i] aci = acc[targets == i]
top1i, top5i = aci.mean(0).tolist() top1i, top5i = aci.mean(0).tolist()
LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}") LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")
@ -127,6 +127,7 @@ def run(
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t) LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
model.float() # for training
return top1, top5, loss return top1, top5, loss

View File

@ -599,7 +599,7 @@ def parse_opt():
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold') parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument('--include', parser.add_argument('--include',
nargs='+', nargs='+',
default=['torchscript', 'onnx'], default=['torchscript'],
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs') help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
opt = parser.parse_args() opt = parser.parse_args()
print_args(vars(opt)) print_args(vars(opt))

View File

@ -8,7 +8,6 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from models.common import Conv
from utils.downloads import attempt_download from utils.downloads import attempt_download
@ -79,11 +78,16 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
for w in weights if isinstance(weights, list) else [weights]: for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location='cpu') # load ckpt = torch.load(attempt_download(w), map_location='cpu') # load
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
if not hasattr(ckpt, 'stride'): if not hasattr(ckpt, 'stride'):
ckpt.stride = torch.tensor([32.]) # compatibility update for ResNet etc. ckpt.stride = torch.tensor([32.])
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
# Compatibility updates # Module compatibility updates
for m in model.modules(): for m in model.modules():
t = type(m) t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model): if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):

View File

@ -107,8 +107,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
data_dict = data_dict or check_dataset(data) # check if None data_dict = data_dict or check_dataset(data) # check if None
train_path, val_path = data_dict['train'], data_dict['val'] train_path, val_path = data_dict['train'], data_dict['val']
nc = 1 if single_cls else int(data_dict['nc']) # number of classes nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
# Model # Model

View File

@ -408,8 +408,6 @@ class ModelEMA:
def __init__(self, model, decay=0.9999, tau=2000, updates=0): def __init__(self, model, decay=0.9999, tau=2000, updates=0):
# Create EMA # Create EMA
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
# if next(model.parameters()).device.type != 'cpu':
# self.ema.half() # FP16 EMA
self.updates = updates # number of EMA updates self.updates = updates # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs) self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
for p in self.ema.parameters(): for p in self.ema.parameters():
@ -423,9 +421,10 @@ class ModelEMA:
msd = de_parallel(model).state_dict() # model state_dict msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items(): for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point: if v.dtype.is_floating_point: # true for FP16 and FP32
v *= d v *= d
v += (1 - d) * msd[k].detach() v += (1 - d) * msd[k]
assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32'
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes # Update EMA attributes