PyTorch Hub and autoShape update (#1415)

* PyTorch Hub and autoShape update

* comment x for imgs

* reduce comment
pull/1420/head
Glenn Jocher 2020-11-16 23:09:55 +01:00 committed by GitHub
parent 92c9b72832
commit f5429260ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 84 additions and 35 deletions

View File

@ -89,7 +89,7 @@ def detect(save_img=False):
txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '') txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
s += '%gx%g ' % img.shape[2:] # print string s += '%gx%g ' % img.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
if det is not None and len(det): if len(det):
# Rescale boxes from img_size to im0 size # Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

View File

@ -5,15 +5,16 @@ Usage:
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80) model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
""" """
dependencies = ['torch', 'yaml']
from pathlib import Path from pathlib import Path
import torch import torch
from PIL import Image
from models.yolo import Model from models.yolo import Model
from utils.general import set_logging from utils.general import set_logging
from utils.google_utils import attempt_download from utils.google_utils import attempt_download
dependencies = ['torch', 'yaml', 'pillow']
set_logging() set_logging()
@ -41,7 +42,7 @@ def create(name, pretrained, channels, classes):
model.load_state_dict(state_dict, strict=False) # load model.load_state_dict(state_dict, strict=False) # load
if len(ckpt['model'].names) == classes: if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute model.names = ckpt['model'].names # set class names attribute
# model = model.autoshape() # for autoshaping of PIL/cv2/np inputs and NMS # model = model.autoshape() # for PIL/cv2/np inputs and NMS
return model return model
except Exception as e: except Exception as e:
@ -108,11 +109,10 @@ def yolov5x(pretrained=False, channels=3, classes=80):
if __name__ == '__main__': if __name__ == '__main__':
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
model = model.fuse().eval().autoshape() # for autoshaping of PIL/cv2/np inputs and NMS model = model.fuse().autoshape() # for PIL/cv2/np inputs and NMS
# Verify inference # Verify inference
from PIL import Image imgs = [Image.open(x) for x in Path('data/images').glob('*.jpg')]
results = model(imgs)
img = Image.open('data/images/zidane.jpg') results.show()
y = model(img) results.print()
print(y[0].shape)

View File

@ -5,9 +5,11 @@ import math
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image, ImageDraw
from utils.datasets import letterbox from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
from utils.plots import color_list
def autopad(k, p=None): # kernel, padding def autopad(k, p=None): # kernel, padding
@ -125,47 +127,94 @@ class autoShape(nn.Module):
def __init__(self, model): def __init__(self, model):
super(autoShape, self).__init__() super(autoShape, self).__init__()
self.model = model self.model = model.eval()
def forward(self, x, size=640, augment=False, profile=False): def forward(self, imgs, size=640, augment=False, profile=False):
# supports inference from various sources. For height=720, width=1280, RGB images example inputs are: # supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
# opencv: x = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) # opencv: imgs = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# PIL: x = Image.open('image.jpg') # HWC x(720,1280,3) # PIL: imgs = Image.open('image.jpg') # HWC x(720,1280,3)
# numpy: x = np.zeros((720,1280,3)) # HWC # numpy: imgs = np.zeros((720,1280,3)) # HWC
# torch: x = torch.zeros(16,3,720,1280) # BCHW # torch: imgs = torch.zeros(16,3,720,1280) # BCHW
# multiple: x = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images # multiple: imgs = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
p = next(self.model.parameters()) # for device and type p = next(self.model.parameters()) # for device and type
if isinstance(x, torch.Tensor): # torch if isinstance(imgs, torch.Tensor): # torch
return self.model(x.to(p.device).type_as(p), augment, profile) # inference return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
# Pre-process # Pre-process
if not isinstance(x, list): if not isinstance(imgs, list):
x = [x] imgs = [imgs]
shape0, shape1 = [], [] # image and inference shapes shape0, shape1 = [], [] # image and inference shapes
batch = range(len(x)) # batch size batch = range(len(imgs)) # batch size
for i in batch: for i in batch:
x[i] = np.array(x[i]) # to numpy imgs[i] = np.array(imgs[i]) # to numpy
x[i] = x[i][:, :, :3] if x[i].ndim == 3 else np.tile(x[i][:, :, None], 3) # enforce 3ch input imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3) # enforce 3ch input
s = x[i].shape[:2] # HWC s = imgs[i].shape[:2] # HWC
shape0.append(s) # image shape shape0.append(s) # image shape
g = (size / max(s)) # gain g = (size / max(s)) # gain
shape1.append([y * g for y in s]) shape1.append([y * g for y in s])
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
x = [letterbox(x[i], new_shape=shape1, auto=False)[0] for i in batch] # pad x = [letterbox(imgs[i], new_shape=shape1, auto=False)[0] for i in batch] # pad
x = np.stack(x, 0) if batch[-1] else x[0][None] # stack x = np.stack(x, 0) if batch[-1] else x[0][None] # stack
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32 x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
# Inference # Inference
x = self.model(x, augment, profile) # forward with torch.no_grad():
x = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS y = self.model(x, augment, profile)[0] # forward
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
# Post-process # Post-process
for i in batch: for i in batch:
if x[i] is not None: if y[i] is not None:
x[i][:, :4] = scale_coords(shape1, x[i][:, :4], shape0[i]) y[i][:, :4] = scale_coords(shape1, y[i][:, :4], shape0[i])
return x
return Detections(imgs, y, self.names)
class Detections:
    """Container for YOLOv5 inference results returned by autoShape.

    Attributes:
        imgs: list of images as numpy arrays (indexed as HWC via ``.shape``).
        pred: list of per-image tensors, one row per detection laid out as
            (x1, y1, x2, y2, conf, cls).
        names: optional sequence of class names indexed by class id.
        xyxy / xywh: boxes in pixels (corner and center-size formats).
        xyxyn / xywhn: the same boxes divided by image width/height; the
            conf and cls columns are divided by 1 and so are left unchanged.
    """

    def __init__(self, imgs, pred, names=None):
        self.imgs = imgs  # list of images as numpy arrays
        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
        self.names = names  # class names
        self.xyxy = pred  # xyxy pixels
        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
        # Normalization gains (w, h, w, h, 1, 1). Each gain is created on the
        # device of its own prediction tensor: a CPU-only gain would raise a
        # device-mismatch error when dividing CUDA predictions.
        gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=x.device)
              for im, x in zip(imgs, pred)]
        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized

    def _summary(self, i, img, pred):
        """Return the one-line text summary for image ``i`` (0-based index)."""
        text = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
        if pred is not None:
            classes = pred[:, -1]
            for c in classes.unique():
                n = int((classes == c).sum())  # detections per class
                # Fall back to the numeric class id when no names were supplied
                label = self.names[int(c)] if self.names is not None else int(c)
                text += f'{n} {label}s, '  # add to string
        return text

    def display(self, pprint=False, show=False, save=False):
        """Print a summary of, show and/or save each image with its boxes drawn.

        Args:
            pprint: print one summary line per image.
            show: open each annotated image with PIL's viewer.
            save: write each annotated image to 'results{i}.jpg'.
        """
        render = show or save
        # The palette is only needed when boxes are drawn, so a text-only
        # print() does not depend on the plotting utilities.
        colors = color_list() if render else None
        for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
            text = self._summary(i, img, pred)
            if render:
                # Always convert before save/show, even when there are no
                # predictions, so .save()/.show() never run on a numpy array.
                if isinstance(img, np.ndarray):
                    img = Image.fromarray(img.astype(np.uint8))  # from np
                if pred is not None:
                    draw = ImageDraw.Draw(img)  # one drawing context per image
                    for *box, _conf, cls in pred.tolist():  # xyxy, confidence, class
                        draw.rectangle(box, width=4, outline=colors[int(cls) % len(colors)])  # plot
            if save:
                f = f'results{i}.jpg'
                text += f"saved to '{f}'"
                img.save(f)  # save
            if show:
                img.show(f'Image {i}')  # show
            if pprint:
                print(text)

    def print(self):
        """Print a text summary of the detections for every image."""
        self.display(pprint=True)  # print results

    def show(self):
        """Show every image with its detection boxes drawn."""
        self.display(show=True)  # show results

    def save(self):
        """Save every image with its detection boxes drawn."""
        self.display(save=True)  # save results
class Flatten(nn.Module): class Flatten(nn.Module):

View File

@ -126,7 +126,7 @@ def test(data,
tcls = labels[:, 0].tolist() if nl else [] # target class tcls = labels[:, 0].tolist() if nl else [] # target class
seen += 1 seen += 1
if pred is None: if len(pred) == 0:
if nl: if nl:
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls)) stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
continue continue

View File

@ -142,7 +142,7 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
def xyxy2xywh(x): def xyxy2xywh(x):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x) y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
y[:, 2] = x[:, 2] - x[:, 0] # width y[:, 2] = x[:, 2] - x[:, 0] # width
@ -152,7 +152,7 @@ def xyxy2xywh(x):
def xywh2xyxy(x): def xywh2xyxy(x):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x) y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
@ -280,7 +280,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
t = time.time() t = time.time()
output = [None] * prediction.shape[0] output = [torch.zeros(0, 6)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints # Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height