New `DetectMultiBackend()` class (#5549)

* New `DetectMultiBackend()` class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* pb to pt fix

* Cleanup

* explicit apply_classifier path

* Cleanup2

* Cleanup3

* Cleanup4

* Cleanup5

* Cleanup6

* val.py MultiBackend inference

* warmup fix

* to device fix

* pt fix

* device fix

* Val cleanup

* COCO128 URL to assets

* half fix

* detect fix

* detect fix 2

* remove half from DetectMultiBackend

* training half handling

* training half handling 2

* training half handling 3

* Cleanup

* Fix CI error

* Add torchscript _extra_files

* Add TorchScript

* Add CoreML

* CoreML cleanup

* revert default to pt

* Add Usage examples

* Cleanup val

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher 2021-11-09 16:45:02 +01:00 committed by GitHub
parent 79bca2bf64
commit 3883261143
7 changed files with 201 additions and 166 deletions
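
The heart of the change: backend detection and model loading move out of detect.py and val.py into a single `DetectMultiBackend` class in models/common.py, so callers get one load call and one forward call regardless of weights format. A minimal usage sketch (assuming a local yolov5s.pt checkpoint; the other suffixes listed in the class below load through the same call):

    import torch

    from models.common import DetectMultiBackend
    from utils.general import non_max_suppression

    model = DetectMultiBackend('yolov5s.pt', device=torch.device('cpu'))  # or .torchscript.pt/.onnx/.tflite/...
    im = torch.zeros(1, 3, 640, 640)  # BCHW float input, 0.0 - 1.0
    pred = model(im)  # same call for every backend
    pred = non_max_suppression(pred)  # list of per-image detections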

data/coco128.yaml

@@ -27,4 +27,4 @@ names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 't
 # Download script/URL (optional)
-download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
+download: https://ultralytics.com/assets/coco128.zip

detect.py

@@ -14,12 +14,10 @@ Usage:
 import argparse
 import os
-import platform
 import sys
 from pathlib import Path

 import cv2
-import numpy as np
 import torch
 import torch.backends.cudnn as cudnn
@@ -29,13 +27,12 @@ if str(ROOT) not in sys.path:
     sys.path.append(str(ROOT))  # add ROOT to PATH
 ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

-from models.experimental import attempt_load
+from models.common import DetectMultiBackend
 from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
-from utils.general import (LOGGER, apply_classifier, check_file, check_img_size, check_imshow, check_requirements,
-                           check_suffix, colorstr, increment_path, non_max_suppression, print_args, scale_coords,
-                           strip_optimizer, xyxy2xywh)
+from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr,
+                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
-from utils.torch_utils import load_classifier, select_device, time_sync
+from utils.torch_utils import select_device, time_sync


 @torch.no_grad()
@@ -77,120 +74,45 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
     save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
     (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

-    # Initialize
-    device = select_device(device)
-    half &= device.type != 'cpu'  # half precision only supported on CUDA
-
     # Load model
-    w = str(weights[0] if isinstance(weights, list) else weights)
-    classify, suffix, suffixes = False, Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '']
-    check_suffix(w, suffixes)  # check weights have acceptable suffix
-    pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes)  # backend booleans
-    stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
-    if pt:
-        model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
-        stride = int(model.stride.max())  # model stride
-        names = model.module.names if hasattr(model, 'module') else model.names  # get class names
-        if half:
-            model.half()  # to FP16
-        if classify:  # second-stage classifier
-            modelc = load_classifier(name='resnet50', n=2)  # initialize
-            modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
-    elif onnx:
-        if dnn:
-            check_requirements(('opencv-python>=4.5.4',))
-            net = cv2.dnn.readNetFromONNX(w)
-        else:
-            check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
-            import onnxruntime
-            session = onnxruntime.InferenceSession(w, None)
-    else:  # TensorFlow models
-        import tensorflow as tf
-        if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
-            def wrap_frozen_graph(gd, inputs, outputs):
-                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped import
-                return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
-                               tf.nest.map_structure(x.graph.as_graph_element, outputs))
-
-            graph_def = tf.Graph().as_graph_def()
-            graph_def.ParseFromString(open(w, 'rb').read())
-            frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
-        elif saved_model:
-            model = tf.keras.models.load_model(w)
-        elif tflite:
-            if "edgetpu" in w:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
-                import tflite_runtime.interpreter as tflri
-                delegate = {'Linux': 'libedgetpu.so.1',  # install libedgetpu https://coral.ai/software/#edgetpu-runtime
-                            'Darwin': 'libedgetpu.1.dylib',
-                            'Windows': 'edgetpu.dll'}[platform.system()]
-                interpreter = tflri.Interpreter(model_path=w, experimental_delegates=[tflri.load_delegate(delegate)])
-            else:
-                interpreter = tf.lite.Interpreter(model_path=w)  # load TFLite model
-            interpreter.allocate_tensors()  # allocate
-            input_details = interpreter.get_input_details()  # inputs
-            output_details = interpreter.get_output_details()  # outputs
-            int8 = input_details[0]['dtype'] == np.uint8  # is TFLite quantized uint8 model
+    device = select_device(device)
+    model = DetectMultiBackend(weights, device=device, dnn=dnn)
+    stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx
     imgsz = check_img_size(imgsz, s=stride)  # check image size

+    # Half
+    half &= pt and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
+    if pt:
+        model.model.half() if half else model.model.float()
+
     # Dataloader
     if webcam:
         view_img = check_imshow()
         cudnn.benchmark = True  # set True to speed up constant image size inference
-        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
+        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt and not jit)
         bs = len(dataset)  # batch_size
     else:
-        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
+        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt and not jit)
         bs = 1  # batch_size
     vid_path, vid_writer = [None] * bs, [None] * bs

     # Run inference
     if pt and device.type != 'cpu':
-        model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters())))  # run once
+        model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.model.parameters())))  # warmup
     dt, seen = [0.0, 0.0, 0.0], 0
-    for path, img, im0s, vid_cap, s in dataset:
+    for path, im, im0s, vid_cap, s in dataset:
         t1 = time_sync()
-        if onnx:
-            img = img.astype('float32')
-        else:
-            img = torch.from_numpy(img).to(device)
-            img = img.half() if half else img.float()  # uint8 to fp16/32
-        img /= 255  # 0 - 255 to 0.0 - 1.0
-        if len(img.shape) == 3:
-            img = img[None]  # expand for batch dim
+        im = torch.from_numpy(im).to(device)
+        im = im.half() if half else im.float()  # uint8 to fp16/32
+        im /= 255  # 0 - 255 to 0.0 - 1.0
+        if len(im.shape) == 3:
+            im = im[None]  # expand for batch dim
         t2 = time_sync()
         dt[0] += t2 - t1

         # Inference
-        if pt:
-            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
-            pred = model(img, augment=augment, visualize=visualize)[0]
-        elif onnx:
-            if dnn:
-                net.setInput(img)
-                pred = torch.tensor(net.forward())
-            else:
-                pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
-        else:  # tensorflow model (tflite, pb, saved_model)
-            imn = img.permute(0, 2, 3, 1).cpu().numpy()  # image in numpy
-            if pb:
-                pred = frozen_func(x=tf.constant(imn)).numpy()
-            elif saved_model:
-                pred = model(imn, training=False).numpy()
-            elif tflite:
-                if int8:
-                    scale, zero_point = input_details[0]['quantization']
-                    imn = (imn / scale + zero_point).astype(np.uint8)  # de-scale
-                interpreter.set_tensor(input_details[0]['index'], imn)
-                interpreter.invoke()
-                pred = interpreter.get_tensor(output_details[0]['index'])
-                if int8:
-                    scale, zero_point = output_details[0]['quantization']
-                    pred = (pred.astype(np.float32) - zero_point) * scale  # re-scale
-            pred[..., 0] *= imgsz[1]  # x
-            pred[..., 1] *= imgsz[0]  # y
-            pred[..., 2] *= imgsz[1]  # w
-            pred[..., 3] *= imgsz[0]  # h
-            pred = torch.tensor(pred)
+        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
+        pred = model(im, augment=augment, visualize=visualize)
         t3 = time_sync()
         dt[1] += t3 - t2
@@ -199,8 +121,7 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
         dt[2] += time_sync() - t3

         # Second-stage classifier (optional)
-        if classify:
-            pred = apply_classifier(pred, modelc, img, im0s)
+        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

         # Process predictions
         for i, det in enumerate(pred):  # per image
@@ -212,15 +133,15 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
                 p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

             p = Path(p)  # to Path
-            save_path = str(save_dir / p.name)  # img.jpg
-            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
-            s += '%gx%g ' % img.shape[2:]  # print string
+            save_path = str(save_dir / p.name)  # im.jpg
+            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
+            s += '%gx%g ' % im.shape[2:]  # print string
             gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
             imc = im0.copy() if save_crop else im0  # for save_crop
             annotator = Annotator(im0, line_width=line_thickness, example=str(names))
             if len(det):
                 # Rescale boxes from img_size to im0 size
-                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
+                det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()

                 # Print results
                 for c in det[:, -1].unique():
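
For reference, the simplified loop above relies on LoadImages/LoadStreams to deliver a letterboxed, CHW, RGB uint8 array. A single-image sketch of the equivalent preprocessing (the image path is illustrative; letterbox is the repo helper also imported in models/common.py):

    import cv2
    import numpy as np
    import torch

    from utils.datasets import letterbox

    im0 = cv2.imread('data/images/bus.jpg')  # BGR HWC uint8
    im = letterbox(im0, 640, stride=32, auto=True)[0]  # resize and pad to a stride-multiple shape
    im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1])  # HWC BGR to CHW RGB
    im = torch.from_numpy(im).float() / 255  # uint8 to fp32, 0 - 255 to 0.0 - 1.0, as in the loop above
    im = im[None]  # expand for batch dim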

export.py

@@ -21,6 +21,7 @@ TensorFlow.js:
 """

 import argparse
+import json
 import os
 import subprocess
 import sys
@@ -54,7 +55,9 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
         f = file.with_suffix('.torchscript.pt')

         ts = torch.jit.trace(model, im, strict=False)
-        (optimize_for_mobile(ts) if optimize else ts).save(f)
+        d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
+        extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
+        (optimize_for_mobile(ts) if optimize else ts).save(f, _extra_files=extra_files)

         LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
     except Exception as e:
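
The metadata written here round-trips through torch.jit.load: pass a dict with the same key and an empty value, and PyTorch fills it in at load time. A minimal sketch (this is exactly what the new TorchScript branch in models/common.py below does):

    import json
    import torch

    extra_files = {'config.txt': ''}  # empty value asks torch.jit.load to populate it
    model = torch.jit.load('yolov5s.torchscript.pt', _extra_files=extra_files)
    d = json.loads(extra_files['config.txt'])  # e.g. {'shape': [1, 3, 640, 640], 'stride': 32, 'names': [...]}
    stride, names = int(d['stride']), d['names']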

models/common.py

@@ -3,11 +3,14 @@
 Common modules
 """

+import json
 import math
+import platform
 import warnings
 from copy import copy
 from pathlib import Path

+import cv2
 import numpy as np
 import pandas as pd
 import requests
@@ -17,7 +20,8 @@ from PIL import Image
 from torch.cuda import amp

 from utils.datasets import exif_transpose, letterbox
-from utils.general import LOGGER, colorstr, increment_path, make_divisible, non_max_suppression, scale_coords, xyxy2xywh
+from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible,
+                           non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
 from utils.torch_utils import time_sync
@@ -269,6 +273,128 @@ class Concat(nn.Module):
         return torch.cat(x, self.d)


+class DetectMultiBackend(nn.Module):
+    # YOLOv5 MultiBackend class for python inference on various backends
+    def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
+        # Usage:
+        #   PyTorch:          weights = *.pt
+        #   TorchScript:                *.torchscript.pt
+        #   CoreML:                     *.mlmodel
+        #   TensorFlow:                 *_saved_model
+        #   TensorFlow:                 *.pb
+        #   TensorFlow Lite:            *.tflite
+        #   ONNX Runtime:               *.onnx
+        #   OpenCV DNN:                 *.onnx with dnn=True
+        super().__init__()
+        w = str(weights[0] if isinstance(weights, list) else weights)
+        suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel']
+        check_suffix(w, suffixes)  # check weights have acceptable suffix
+        pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes)  # backend booleans
+        jit = pt and 'torchscript' in w.lower()
+        stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
+
+        if jit:  # TorchScript
+            LOGGER.info(f'Loading {w} for TorchScript inference...')
+            extra_files = {'config.txt': ''}  # model metadata
+            model = torch.jit.load(w, _extra_files=extra_files)
+            if extra_files['config.txt']:
+                d = json.loads(extra_files['config.txt'])  # extra_files dict
+                stride, names = int(d['stride']), d['names']
+        elif pt:  # PyTorch
+            from models.experimental import attempt_load  # scoped to avoid circular import
+            model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
+            stride = int(model.stride.max())  # model stride
+            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
+        elif coreml:  # CoreML *.mlmodel
+            import coremltools as ct
+            model = ct.models.MLModel(w)
+        elif dnn:  # ONNX OpenCV DNN
+            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
+            check_requirements(('opencv-python>=4.5.4',))
+            net = cv2.dnn.readNetFromONNX(w)
+        elif onnx:  # ONNX Runtime
+            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
+            check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
+            import onnxruntime
+            session = onnxruntime.InferenceSession(w, None)
+        else:  # TensorFlow model (TFLite, pb, saved_model)
+            import tensorflow as tf
+            if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+                def wrap_frozen_graph(gd, inputs, outputs):
+                    x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
+                    return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
+                                   tf.nest.map_structure(x.graph.as_graph_element, outputs))
+
+                LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
+                graph_def = tf.Graph().as_graph_def()
+                graph_def.ParseFromString(open(w, 'rb').read())
+                frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
+            elif saved_model:
+                LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...')
+                model = tf.keras.models.load_model(w)
+            elif tflite:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
+                if 'edgetpu' in w.lower():
+                    LOGGER.info(f'Loading {w} for TensorFlow Edge TPU inference...')
+                    import tflite_runtime.interpreter as tfli
+                    delegate = {'Linux': 'libedgetpu.so.1',  # install https://coral.ai/software/#edgetpu-runtime
+                                'Darwin': 'libedgetpu.1.dylib',
+                                'Windows': 'edgetpu.dll'}[platform.system()]
+                    interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)])
+                else:
+                    LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
+                    interpreter = tf.lite.Interpreter(model_path=w)  # load TFLite model
+                interpreter.allocate_tensors()  # allocate
+                input_details = interpreter.get_input_details()  # inputs
+                output_details = interpreter.get_output_details()  # outputs
+        self.__dict__.update(locals())  # assign all variables to self
+
+    def forward(self, im, augment=False, visualize=False, val=False):
+        # YOLOv5 MultiBackend inference
+        b, ch, h, w = im.shape  # batch, channel, height, width
+        if self.pt:  # PyTorch
+            y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
+            return y if val else y[0]
+        elif self.coreml:  # CoreML *.mlmodel
+            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            im = Image.fromarray((im[0] * 255).astype('uint8'))
+            # im = im.resize((192, 320), Image.ANTIALIAS)
+            y = self.model.predict({'image': im})  # coordinates are xywh normalized
+            box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
+            conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
+            y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
+        elif self.onnx:  # ONNX
+            im = im.cpu().numpy()  # torch to numpy
+            if self.dnn:  # ONNX OpenCV DNN
+                self.net.setInput(im)
+                y = self.net.forward()
+            else:  # ONNX Runtime
+                y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
+        else:  # TensorFlow model (TFLite, pb, saved_model)
+            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            if self.pb:
+                y = self.frozen_func(x=self.tf.constant(im)).numpy()
+            elif self.saved_model:
+                y = self.model(im, training=False).numpy()
+            elif self.tflite:
+                input, output = self.input_details[0], self.output_details[0]
+                int8 = input['dtype'] == np.uint8  # is TFLite quantized uint8 model
+                if int8:
+                    scale, zero_point = input['quantization']
+                    im = (im / scale + zero_point).astype(np.uint8)  # de-scale
+                self.interpreter.set_tensor(input['index'], im)
+                self.interpreter.invoke()
+                y = self.interpreter.get_tensor(output['index'])
+                if int8:
+                    scale, zero_point = output['quantization']
+                    y = (y.astype(np.float32) - zero_point) * scale  # re-scale
+            y[..., 0] *= w  # x
+            y[..., 1] *= h  # y
+            y[..., 2] *= w  # w
+            y[..., 3] *= h  # h
+        y = torch.tensor(y)
+        return (y, []) if val else y
+
+
 class AutoShape(nn.Module):
     # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
     conf = 0.25  # NMS confidence threshold
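
Since __init__ stashes all backend state with self.__dict__.update(locals()), switching backends is purely a matter of the weights suffix. A sketch, assuming the listed files were previously produced by export.py:

    import torch

    from models.common import DetectMultiBackend

    im = torch.zeros(1, 3, 640, 640)  # BCHW, 0.0 - 1.0
    for w in ('yolov5s.pt', 'yolov5s.torchscript.pt', 'yolov5s.onnx', 'yolov5s.tflite'):  # assumed exports
        model = DetectMultiBackend(w, device=torch.device('cpu'), dnn=False)  # dnn=True: OpenCV DNN for .onnx
        pred = model(im)  # detect.py-style: predictions only
        pred, _ = model(im, val=True)  # val.py-style: (predictions, loss outputs) tuple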

utils/general.py

@@ -785,7 +785,8 @@ def print_mutation(results, hyp, save_dir, bucket):

 def apply_classifier(x, model, img, im0):
-    # Apply a second stage classifier to yolo outputs
+    # Apply a second stage classifier to YOLO outputs
+    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
     im0 = [im0] if isinstance(im0, np.ndarray) else im0
     for i, d in enumerate(x):  # per image
         if d is not None and len(d):
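
With load_classifier() removed from utils/torch_utils.py (next file), wiring up the optional second stage is left to the caller. A sketch following the suggestion in the new comment above (the model choice and the pred/im/im0s variables are illustrative, taken from the detect.py loop):

    import torch
    import torchvision

    from utils.general import apply_classifier

    device = torch.device('cpu')
    classifier_model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    # after NMS inside the detect loop:
    # pred = apply_classifier(pred, classifier_model, im, im0s)  # keep boxes the classifier agrees with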

utils/torch_utils.py

@@ -17,7 +17,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision

 from utils.general import LOGGER
@@ -235,25 +234,6 @@ def model_info(model, verbose=False, img_size=640):
     LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")


-def load_classifier(name='resnet101', n=2):
-    # Loads a pretrained model reshaped to n-class output
-    model = torchvision.models.__dict__[name](pretrained=True)
-
-    # ResNet model properties
-    # input_size = [3, 224, 224]
-    # input_space = 'RGB'
-    # input_range = [0, 1]
-    # mean = [0.485, 0.456, 0.406]
-    # std = [0.229, 0.224, 0.225]
-
-    # Reshape output to n classes
-    filters = model.fc.weight.shape[1]
-    model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
-    model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
-    model.fc.out_features = n
-    return model
-
-
 def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
     # scales img(bs,3,y,x) by ratio constrained to gs-multiple
     if ratio == 1.0:

val.py

@@ -23,10 +23,10 @@ if str(ROOT) not in sys.path:
     sys.path.append(str(ROOT))  # add ROOT to PATH
 ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

-from models.experimental import attempt_load
+from models.common import DetectMultiBackend
 from utils.callbacks import Callbacks
 from utils.datasets import create_dataloader
-from utils.general import (LOGGER, box_iou, check_dataset, check_img_size, check_requirements, check_suffix, check_yaml,
+from utils.general import (LOGGER, box_iou, check_dataset, check_img_size, check_requirements, check_yaml,
                            coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
                            scale_coords, xywh2xyxy, xyxy2xywh)
 from utils.metrics import ConfusionMatrix, ap_per_class
@@ -100,6 +100,7 @@ def run(data,
         name='exp',  # save to project/name
         exist_ok=False,  # existing project/name ok, do not increment
         half=True,  # use FP16 half-precision inference
+        dnn=False,  # use OpenCV DNN for ONNX inference
         model=None,
         dataloader=None,
         save_dir=Path(''),
@@ -110,8 +111,10 @@ def run(data,
     # Initialize/load model and set device
     training = model is not None
     if training:  # called by train.py
-        device = next(model.parameters()).device  # get model device
+        device, pt = next(model.parameters()).device, True  # get model device, PyTorch model

+        half &= device.type != 'cpu'  # half precision only supported on CUDA
+        model.half() if half else model.float()
     else:  # called directly
         device = select_device(device, batch_size=batch_size)
@@ -120,22 +123,21 @@ def run(data,
     (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

     # Load model
-    check_suffix(weights, '.pt')
-    model = attempt_load(weights, map_location=device)  # load FP32 model
-    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
-    imgsz = check_img_size(imgsz, s=gs)  # check image size
-
-    # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
-    # if device.type != 'cpu' and torch.cuda.device_count() > 1:
-    #     model = nn.DataParallel(model)
+    model = DetectMultiBackend(weights, device=device, dnn=dnn)
+    stride, pt = model.stride, model.pt
+    imgsz = check_img_size(imgsz, s=stride)  # check image size
+    half &= pt and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
+    if pt:
+        model.model.half() if half else model.model.float()
+    else:
+        half = False
+        batch_size = 1  # export.py models default to batch-size 1
+        device = torch.device('cpu')
+        LOGGER.info(f'Forcing --batch-size 1 square inference shape(1,3,{imgsz},{imgsz}) for non-PyTorch backends')

     # Data
     data = check_dataset(data)  # check

-    # Half
-    half &= device.type != 'cpu'  # half precision only supported on CUDA
-    model.half() if half else model.float()
-
     # Configure
     model.eval()
     is_coco = isinstance(data.get('val'), str) and data['val'].endswith('coco/val2017.txt')  # COCO dataset
@@ -145,11 +147,11 @@ def run(data,
     # Dataloader
     if not training:
-        if device.type != 'cpu':
-            model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
+        if pt and device.type != 'cpu':
+            model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.model.parameters())))  # warmup
         pad = 0.0 if task == 'speed' else 0.5
         task = task if task in ('train', 'val', 'test') else 'val'  # path to train/val/test images
-        dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=pad, rect=True,
+        dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
                                        prefix=colorstr(f'{task}: '))[0]

     seen = 0
@@ -160,32 +162,33 @@ def run(data,
     dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
     loss = torch.zeros(3, device=device)
     jdict, stats, ap, ap_class = [], [], [], []
-    for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
+    for batch_i, (im, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
         t1 = time_sync()
-        img = img.to(device, non_blocking=True)
-        img = img.half() if half else img.float()  # uint8 to fp16/32
-        img /= 255  # 0 - 255 to 0.0 - 1.0
-        targets = targets.to(device)
-        nb, _, height, width = img.shape  # batch size, channels, height, width
+        if pt:
+            im = im.to(device, non_blocking=True)
+            targets = targets.to(device)
+        im = im.half() if half else im.float()  # uint8 to fp16/32
+        im /= 255  # 0 - 255 to 0.0 - 1.0
+        nb, _, height, width = im.shape  # batch size, channels, height, width
         t2 = time_sync()
         dt[0] += t2 - t1

-        # Run model
-        out, train_out = model(img, augment=augment)  # inference and training outputs
+        # Inference
+        out, train_out = model(im) if training else model(im, augment=augment, val=True)  # inference, loss outputs
         dt[1] += time_sync() - t2

-        # Compute loss
+        # Loss
         if compute_loss:
             loss += compute_loss([x.float() for x in train_out], targets)[1]  # box, obj, cls

-        # Run NMS
+        # NMS
         targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
         lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
         t3 = time_sync()
         out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
         dt[2] += time_sync() - t3

-        # Statistics per image
+        # Metrics
         for si, pred in enumerate(out):
             labels = targets[targets[:, 0] == si, 1:]
             nl = len(labels)
@@ -202,12 +205,12 @@ def run(data,
             if single_cls:
                 pred[:, 5] = 0
             predn = pred.clone()
-            scale_coords(img[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
+            scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred

             # Evaluate
             if nl:
                 tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
-                scale_coords(img[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
+                scale_coords(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
                 labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
                 correct = process_batch(predn, labelsn, iouv)
                 if plots:
@@ -221,16 +224,16 @@ def run(data,
                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
             if save_json:
                 save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
-            callbacks.run('on_val_image_end', pred, predn, path, names, img[si])
+            callbacks.run('on_val_image_end', pred, predn, path, names, im[si])

         # Plot images
         if plots and batch_i < 3:
             f = save_dir / f'val_batch{batch_i}_labels.jpg'  # labels
-            Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
+            Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start()
             f = save_dir / f'val_batch{batch_i}_pred.jpg'  # predictions
-            Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start()
+            Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()

-    # Compute statistics
+    # Compute metrics
     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
     if len(stats) and stats[0].any():
         p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
@@ -318,6 +321,7 @@ def parse_opt():
     parser.add_argument('--name', default='exp', help='save to project/name')
     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
     opt = parser.parse_args()
     opt.data = check_yaml(opt.data)  # check YAML
     opt.save_json |= opt.data.endswith('coco.yaml')
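
With the new --dnn flag and the forced batch-size-1 CPU path, exported models validate through the same entrypoint. A sketch of the programmatic equivalent (assumes data/coco128.yaml is available and yolov5s.onnx was produced by export.py):

    import val  # this repo's val.py

    val.run(data='data/coco128.yaml', weights='yolov5s.pt', imgsz=640)  # PyTorch: batched, CUDA/half capable
    val.run(data='data/coco128.yaml', weights='yolov5s.onnx', dnn=False)  # ONNX Runtime; dnn=True for OpenCV DNN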