Add autoShape() speed profiling (#2459)
* Add autoShape() speed profiling * Update common.py * Create README.md * Update hubconf.py * cleanuippull/2460/head
parent
747c2653ee
commit
569757ecc0
|
@ -108,11 +108,11 @@ To run **batched inference** with YOLOv5 and [PyTorch Hub](https://github.com/ul
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
|
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
|
||||||
|
|
||||||
# Images
|
# Images
|
||||||
dir = 'https://github.com/ultralytics/yolov5/raw/master/data/images/'
|
dir = 'https://github.com/ultralytics/yolov5/raw/master/data/images/'
|
||||||
imgs = [dir + f for f in ('zidane.jpg', 'bus.jpg')] # batched list of images
|
imgs = [dir + f for f in ('zidane.jpg', 'bus.jpg')] # batch of images
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
results = model(imgs)
|
results = model(imgs)
|
||||||
|
|
|
@ -51,7 +51,7 @@ def create(name, pretrained, channels, classes, autoshape):
|
||||||
raise Exception(s) from e
|
raise Exception(s) from e
|
||||||
|
|
||||||
|
|
||||||
def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True):
|
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True):
|
||||||
"""YOLOv5-small model from https://github.com/ultralytics/yolov5
|
"""YOLOv5-small model from https://github.com/ultralytics/yolov5
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
@ -65,7 +65,7 @@ def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True):
|
||||||
return create('yolov5s', pretrained, channels, classes, autoshape)
|
return create('yolov5s', pretrained, channels, classes, autoshape)
|
||||||
|
|
||||||
|
|
||||||
def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True):
|
def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True):
|
||||||
"""YOLOv5-medium model from https://github.com/ultralytics/yolov5
|
"""YOLOv5-medium model from https://github.com/ultralytics/yolov5
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
@ -79,7 +79,7 @@ def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True):
|
||||||
return create('yolov5m', pretrained, channels, classes, autoshape)
|
return create('yolov5m', pretrained, channels, classes, autoshape)
|
||||||
|
|
||||||
|
|
||||||
def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True):
|
def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True):
|
||||||
"""YOLOv5-large model from https://github.com/ultralytics/yolov5
|
"""YOLOv5-large model from https://github.com/ultralytics/yolov5
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
@ -93,7 +93,7 @@ def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True):
|
||||||
return create('yolov5l', pretrained, channels, classes, autoshape)
|
return create('yolov5l', pretrained, channels, classes, autoshape)
|
||||||
|
|
||||||
|
|
||||||
def yolov5x(pretrained=False, channels=3, classes=80, autoshape=True):
|
def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True):
|
||||||
"""YOLOv5-xlarge model from https://github.com/ultralytics/yolov5
|
"""YOLOv5-xlarge model from https://github.com/ultralytics/yolov5
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
|
@ -12,6 +12,7 @@ from PIL import Image
|
||||||
from utils.datasets import letterbox
|
from utils.datasets import letterbox
|
||||||
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
|
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
|
||||||
from utils.plots import color_list, plot_one_box
|
from utils.plots import color_list, plot_one_box
|
||||||
|
from utils.torch_utils import time_synchronized
|
||||||
|
|
||||||
|
|
||||||
def autopad(k, p=None): # kernel, padding
|
def autopad(k, p=None): # kernel, padding
|
||||||
|
@ -190,6 +191,7 @@ class autoShape(nn.Module):
|
||||||
# torch: = torch.zeros(16,3,720,1280) # BCHW
|
# torch: = torch.zeros(16,3,720,1280) # BCHW
|
||||||
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
||||||
|
|
||||||
|
t = [time_synchronized()]
|
||||||
p = next(self.model.parameters()) # for device and type
|
p = next(self.model.parameters()) # for device and type
|
||||||
if isinstance(imgs, torch.Tensor): # torch
|
if isinstance(imgs, torch.Tensor): # torch
|
||||||
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
|
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
|
||||||
|
@ -216,22 +218,25 @@ class autoShape(nn.Module):
|
||||||
x = np.stack(x, 0) if n > 1 else x[0][None] # stack
|
x = np.stack(x, 0) if n > 1 else x[0][None] # stack
|
||||||
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
|
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
|
||||||
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
|
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
|
||||||
|
t.append(time_synchronized())
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
y = self.model(x, augment, profile)[0] # forward
|
y = self.model(x, augment, profile)[0] # forward
|
||||||
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
|
t.append(time_synchronized())
|
||||||
|
|
||||||
# Post-process
|
# Post-process
|
||||||
|
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
scale_coords(shape1, y[i][:, :4], shape0[i])
|
scale_coords(shape1, y[i][:, :4], shape0[i])
|
||||||
|
t.append(time_synchronized())
|
||||||
|
|
||||||
return Detections(imgs, y, files, self.names)
|
return Detections(imgs, y, files, t, self.names, x.shape)
|
||||||
|
|
||||||
|
|
||||||
class Detections:
|
class Detections:
|
||||||
# detections class for YOLOv5 inference results
|
# detections class for YOLOv5 inference results
|
||||||
def __init__(self, imgs, pred, files, names=None):
|
def __init__(self, imgs, pred, files, times, names=None, shape=None):
|
||||||
super(Detections, self).__init__()
|
super(Detections, self).__init__()
|
||||||
d = pred[0].device # device
|
d = pred[0].device # device
|
||||||
gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
|
gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
|
||||||
|
@ -244,6 +249,8 @@ class Detections:
|
||||||
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
||||||
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
||||||
self.n = len(self.pred)
|
self.n = len(self.pred)
|
||||||
|
self.t = ((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
|
||||||
|
self.s = shape # inference BCHW shape
|
||||||
|
|
||||||
def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
|
def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
|
||||||
colors = color_list()
|
colors = color_list()
|
||||||
|
@ -271,6 +278,7 @@ class Detections:
|
||||||
|
|
||||||
def print(self):
|
def print(self):
|
||||||
self.display(pprint=True) # print results
|
self.display(pprint=True) # print results
|
||||||
|
print(f'Speed: %.1f/%.1f/%.1f ms pre-process/inference/NMS per image at shape {tuple(self.s)}' % tuple(self.t))
|
||||||
|
|
||||||
def show(self):
|
def show(self):
|
||||||
self.display(show=True) # show results
|
self.display(show=True) # show results
|
||||||
|
|
Loading…
Reference in New Issue