EMA FP32 assert classification bug fix (#9016)
* Return EMA float on classification val * verbose val fix * EMA checkpull/9009/head^2
parent
529aafd737
commit
20049be2e7
|
@ -116,7 +116,7 @@ def run(
|
|||
if verbose: # all classes
|
||||
LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}")
|
||||
LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}")
|
||||
for i, c in enumerate(model.names):
|
||||
for i, c in model.names.items():
|
||||
aci = acc[targets == i]
|
||||
top1i, top5i = aci.mean(0).tolist()
|
||||
LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")
|
||||
|
@ -127,6 +127,7 @@ def run(
|
|||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
|
||||
|
||||
model.float() # for training
|
||||
return top1, top5, loss
|
||||
|
||||
|
||||
|
|
|
@ -599,7 +599,7 @@ def parse_opt():
|
|||
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
|
||||
parser.add_argument('--include',
|
||||
nargs='+',
|
||||
default=['torchscript', 'onnx'],
|
||||
default=['torchscript'],
|
||||
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
|
||||
opt = parser.parse_args()
|
||||
print_args(vars(opt))
|
||||
|
|
|
@ -8,7 +8,6 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from models.common import Conv
|
||||
from utils.downloads import attempt_download
|
||||
|
||||
|
||||
|
@ -79,11 +78,16 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
|
|||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
if not hasattr(ckpt, 'stride'):
|
||||
ckpt.stride = torch.tensor([32.]) # compatibility update for ResNet etc.
|
||||
ckpt.stride = torch.tensor([32.])
|
||||
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
|
||||
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
|
||||
|
||||
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
||||
|
||||
# Compatibility updates
|
||||
# Module compatibility updates
|
||||
for m in model.modules():
|
||||
t = type(m)
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
|
||||
|
|
3
train.py
3
train.py
|
@ -107,8 +107,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|||
data_dict = data_dict or check_dataset(data) # check if None
|
||||
train_path, val_path = data_dict['train'], data_dict['val']
|
||||
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
||||
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
||||
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
|
||||
names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
||||
is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
|
||||
|
||||
# Model
|
||||
|
|
|
@ -408,8 +408,6 @@ class ModelEMA:
|
|||
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
||||
# Create EMA
|
||||
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
||||
# if next(model.parameters()).device.type != 'cpu':
|
||||
# self.ema.half() # FP16 EMA
|
||||
self.updates = updates # number of EMA updates
|
||||
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
||||
for p in self.ema.parameters():
|
||||
|
@ -423,9 +421,10 @@ class ModelEMA:
|
|||
|
||||
msd = de_parallel(model).state_dict() # model state_dict
|
||||
for k, v in self.ema.state_dict().items():
|
||||
if v.dtype.is_floating_point:
|
||||
if v.dtype.is_floating_point: # true for FP16 and FP32
|
||||
v *= d
|
||||
v += (1 - d) * msd[k].detach()
|
||||
v += (1 - d) * msd[k]
|
||||
assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32'
|
||||
|
||||
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
||||
# Update EMA attributes
|
||||
|
|
Loading…
Reference in New Issue