mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
EMA FP32 assert classification bug fix (#9016)
* Return EMA float on classification val * verbose val fix * EMA check
This commit is contained in:
parent
529aafd737
commit
20049be2e7
@ -116,7 +116,7 @@ def run(
|
|||||||
if verbose: # all classes
|
if verbose: # all classes
|
||||||
LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}")
|
LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}")
|
||||||
LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}")
|
LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}")
|
||||||
for i, c in enumerate(model.names):
|
for i, c in model.names.items():
|
||||||
aci = acc[targets == i]
|
aci = acc[targets == i]
|
||||||
top1i, top5i = aci.mean(0).tolist()
|
top1i, top5i = aci.mean(0).tolist()
|
||||||
LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")
|
LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")
|
||||||
@ -127,6 +127,7 @@ def run(
|
|||||||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
|
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
|
||||||
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
|
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
|
||||||
|
|
||||||
|
model.float() # for training
|
||||||
return top1, top5, loss
|
return top1, top5, loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -599,7 +599,7 @@ def parse_opt():
|
|||||||
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
|
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
|
||||||
parser.add_argument('--include',
|
parser.add_argument('--include',
|
||||||
nargs='+',
|
nargs='+',
|
||||||
default=['torchscript', 'onnx'],
|
default=['torchscript'],
|
||||||
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
|
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print_args(vars(opt))
|
print_args(vars(opt))
|
||||||
|
@ -8,7 +8,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from models.common import Conv
|
|
||||||
from utils.downloads import attempt_download
|
from utils.downloads import attempt_download
|
||||||
|
|
||||||
|
|
||||||
@ -79,11 +78,16 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
|
|||||||
for w in weights if isinstance(weights, list) else [weights]:
|
for w in weights if isinstance(weights, list) else [weights]:
|
||||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
||||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||||
|
|
||||||
|
# Model compatibility updates
|
||||||
if not hasattr(ckpt, 'stride'):
|
if not hasattr(ckpt, 'stride'):
|
||||||
ckpt.stride = torch.tensor([32.]) # compatibility update for ResNet etc.
|
ckpt.stride = torch.tensor([32.])
|
||||||
|
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
|
||||||
|
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
|
||||||
|
|
||||||
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
||||||
|
|
||||||
# Compatibility updates
|
# Module compatibility updates
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
t = type(m)
|
t = type(m)
|
||||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
|
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
|
||||||
|
3
train.py
3
train.py
@ -107,8 +107,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|||||||
data_dict = data_dict or check_dataset(data) # check if None
|
data_dict = data_dict or check_dataset(data) # check if None
|
||||||
train_path, val_path = data_dict['train'], data_dict['val']
|
train_path, val_path = data_dict['train'], data_dict['val']
|
||||||
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
||||||
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
||||||
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
|
|
||||||
is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
|
is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
|
@ -408,8 +408,6 @@ class ModelEMA:
|
|||||||
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
||||||
# Create EMA
|
# Create EMA
|
||||||
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
||||||
# if next(model.parameters()).device.type != 'cpu':
|
|
||||||
# self.ema.half() # FP16 EMA
|
|
||||||
self.updates = updates # number of EMA updates
|
self.updates = updates # number of EMA updates
|
||||||
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
||||||
for p in self.ema.parameters():
|
for p in self.ema.parameters():
|
||||||
@ -423,9 +421,10 @@ class ModelEMA:
|
|||||||
|
|
||||||
msd = de_parallel(model).state_dict() # model state_dict
|
msd = de_parallel(model).state_dict() # model state_dict
|
||||||
for k, v in self.ema.state_dict().items():
|
for k, v in self.ema.state_dict().items():
|
||||||
if v.dtype.is_floating_point:
|
if v.dtype.is_floating_point: # true for FP16 and FP32
|
||||||
v *= d
|
v *= d
|
||||||
v += (1 - d) * msd[k].detach()
|
v += (1 - d) * msd[k]
|
||||||
|
assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32'
|
||||||
|
|
||||||
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
||||||
# Update EMA attributes
|
# Update EMA attributes
|
||||||
|
Loading…
x
Reference in New Issue
Block a user