PyTorch Hub and autoShape update (#1415)

* PyTorch Hub and autoShape update
* comment x for imgs
* reduce comment

parent 92c9b72832
commit f5429260ca
detect.py

@@ -89,7 +89,7 @@ def detect(save_img=False):
             txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
             s += '%gx%g ' % img.shape[2:]  # print string
             gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
-            if det is not None and len(det):
+            if len(det):
                 # Rescale boxes from img_size to im0 size
                 det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
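The unchanged context above includes `gn`, the whwh normalization gain built from the original image shape. A minimal sketch of how a gain like this is typically used to turn pixel xyxy detections into normalized xywh rows for the saved label files (the values and usage below are illustrative, not part of this diff):

import torch
from utils.general import xyxy2xywh  # box conversion helper (see utils/general.py below)

im0_shape = (720, 1280, 3)                            # original image shape, HWC (illustrative)
gn = torch.tensor(im0_shape)[[1, 0, 1, 0]]            # normalization gain whwh
xyxy = torch.tensor([[100., 200., 300., 400.]])       # one detection in pixels
xywh_norm = (xyxy2xywh(xyxy) / gn).view(-1).tolist()  # normalized xywh, ready for a label .txt line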
hubconf.py
@@ -5,15 +5,16 @@ Usage:
     model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
 """

-dependencies = ['torch', 'yaml']
 from pathlib import Path

 import torch
+from PIL import Image

 from models.yolo import Model
 from utils.general import set_logging
 from utils.google_utils import attempt_download

+dependencies = ['torch', 'yaml', 'pillow']
 set_logging()
@@ -41,7 +42,7 @@ def create(name, pretrained, channels, classes):
             model.load_state_dict(state_dict, strict=False)  # load
             if len(ckpt['model'].names) == classes:
                 model.names = ckpt['model'].names  # set class names attribute
-            # model = model.autoshape()  # for autoshaping of PIL/cv2/np inputs and NMS
+            # model = model.autoshape()  # for PIL/cv2/np inputs and NMS
         return model

     except Exception as e:
@@ -108,11 +109,10 @@ def yolov5x(pretrained=False, channels=3, classes=80):

 if __name__ == '__main__':
     model = create(name='yolov5s', pretrained=True, channels=3, classes=80)  # example
-    model = model.fuse().eval().autoshape()  # for autoshaping of PIL/cv2/np inputs and NMS
+    model = model.fuse().autoshape()  # for PIL/cv2/np inputs and NMS

     # Verify inference
-    from PIL import Image
-    img = Image.open('data/images/zidane.jpg')
-    y = model(img)
-    print(y[0].shape)
+    imgs = [Image.open(x) for x in Path('data/images').glob('*.jpg')]
+    results = model(imgs)
+    results.show()
+    results.print()
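Put together with the updated __main__ block above, Hub usage after this change looks roughly like the following sketch (the autoshape() wrapper accepts PIL/cv2/numpy/torch inputs and returns the Detections object introduced in models/common.py below; the image paths are the repo's sample images):

import torch
from PIL import Image

model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
model = model.fuse().autoshape()  # wrap for PIL/cv2/np inputs and NMS

imgs = [Image.open('data/images/zidane.jpg'), Image.open('data/images/bus.jpg')]  # batch of PIL images
results = model(imgs)   # Detections object
results.print()         # per-image summary strings
results.save()          # writes results0.jpg, results1.jpg, ...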
models/common.py

@@ -5,9 +5,11 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
+from PIL import Image, ImageDraw

 from utils.datasets import letterbox
-from utils.general import non_max_suppression, make_divisible, scale_coords
+from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
+from utils.plots import color_list


 def autopad(k, p=None):  # kernel, padding
@@ -125,47 +127,94 @@ class autoShape(nn.Module):

     def __init__(self, model):
         super(autoShape, self).__init__()
-        self.model = model
+        self.model = model.eval()

-    def forward(self, x, size=640, augment=False, profile=False):
+    def forward(self, imgs, size=640, augment=False, profile=False):
         # supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
-        #   opencv:    x = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
-        #   PIL:       x = Image.open('image.jpg')  # HWC x(720,1280,3)
-        #   numpy:     x = np.zeros((720,1280,3))  # HWC
-        #   torch:     x = torch.zeros(16,3,720,1280)  # BCHW
-        #   multiple:  x = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+        #   opencv:    imgs = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
+        #   PIL:       imgs = Image.open('image.jpg')  # HWC x(720,1280,3)
+        #   numpy:     imgs = np.zeros((720,1280,3))  # HWC
+        #   torch:     imgs = torch.zeros(16,3,720,1280)  # BCHW
+        #   multiple:  imgs = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

         p = next(self.model.parameters())  # for device and type
-        if isinstance(x, torch.Tensor):  # torch
-            return self.model(x.to(p.device).type_as(p), augment, profile)  # inference
+        if isinstance(imgs, torch.Tensor):  # torch
+            return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

         # Pre-process
-        if not isinstance(x, list):
-            x = [x]
+        if not isinstance(imgs, list):
+            imgs = [imgs]
         shape0, shape1 = [], []  # image and inference shapes
-        batch = range(len(x))  # batch size
+        batch = range(len(imgs))  # batch size
         for i in batch:
-            x[i] = np.array(x[i])  # to numpy
-            x[i] = x[i][:, :, :3] if x[i].ndim == 3 else np.tile(x[i][:, :, None], 3)  # enforce 3ch input
-            s = x[i].shape[:2]  # HWC
+            imgs[i] = np.array(imgs[i])  # to numpy
+            imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3)  # enforce 3ch input
+            s = imgs[i].shape[:2]  # HWC
             shape0.append(s)  # image shape
             g = (size / max(s))  # gain
             shape1.append([y * g for y in s])
         shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
-        x = [letterbox(x[i], new_shape=shape1, auto=False)[0] for i in batch]  # pad
+        x = [letterbox(imgs[i], new_shape=shape1, auto=False)[0] for i in batch]  # pad
         x = np.stack(x, 0) if batch[-1] else x[0][None]  # stack
         x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
         x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32

         # Inference
-        x = self.model(x, augment, profile)  # forward
-        x = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
+        with torch.no_grad():
+            y = self.model(x, augment, profile)[0]  # forward
+        y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS

         # Post-process
         for i in batch:
-            if x[i] is not None:
-                x[i][:, :4] = scale_coords(shape1, x[i][:, :4], shape0[i])
-        return x
+            if y[i] is not None:
+                y[i][:, :4] = scale_coords(shape1, y[i][:, :4], shape0[i])
+
+        return Detections(imgs, y, self.names)
+
+
+class Detections:
+    # detections class for YOLOv5 inference results
+    def __init__(self, imgs, pred, names=None):
+        super(Detections, self).__init__()
+        self.imgs = imgs  # list of images as numpy arrays
+        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
+        self.names = names  # class names
+        self.xyxy = pred  # xyxy pixels
+        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
+        gn = [torch.Tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.]) for im in imgs]  # normalization gains
+        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
+        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
+
+    def display(self, pprint=False, show=False, save=False):
+        colors = color_list()
+        for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
+            str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
+            if pred is not None:
+                for c in pred[:, -1].unique():
+                    n = (pred[:, -1] == c).sum()  # detections per class
+                    str += f'{n} {self.names[int(c)]}s, '  # add to string
+                if show or save:
+                    img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img  # from np
+                    for *box, conf, cls in pred:  # xyxy, confidence, class
+                        # str += '%s %.2f, ' % (names[int(cls)], conf)  # label
+                        ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10])  # plot
+            if save:
+                f = f'results{i}.jpg'
+                str += f"saved to '{f}'"
+                img.save(f)  # save
+            if show:
+                img.show(f'Image {i}')  # show
+            if pprint:
+                print(str)
+
+    def print(self):
+        self.display(pprint=True)  # print results
+
+    def show(self):
+        self.display(show=True)  # show results
+
+    def save(self):
+        self.display(save=True)  # save results


 class Flatten(nn.Module):
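The new Detections container stores the same boxes in four coordinate forms. A minimal self-contained sketch of what its __init__ computes for a single 720x1280 image (values illustrative):

import torch
from utils.general import xyxy2xywh

pred = [torch.tensor([[100., 200., 300., 400., 0.9, 0.]])]           # one image: (x1, y1, x2, y2, conf, cls)
im_shape = (720, 1280, 3)                                            # HWC shape of the numpy image

xyxy = pred                                                          # pixel corners
xywh = [xyxy2xywh(x) for x in pred]                                  # pixel center/size
gn = [torch.Tensor([*[im_shape[i] for i in [1, 0, 1, 0]], 1., 1.])]  # gains: w, h, w, h, 1, 1
xyxyn = [x / g for x, g in zip(xyxy, gn)]                            # corners normalized to 0-1
xywhn = [x / g for x, g in zip(xywh, gn)]                            # center/size normalized to 0-1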
test.py
@@ -126,7 +126,7 @@ def test(data,
             tcls = labels[:, 0].tolist() if nl else []  # target class
             seen += 1

-            if pred is None:
+            if len(pred) == 0:
                 if nl:
                     stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                 continue
utils/general.py

@@ -142,7 +142,7 @@ def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)

 def xyxy2xywh(x):
     # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
-    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
     y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
     y[:, 2] = x[:, 2] - x[:, 0]  # width
@@ -152,7 +152,7 @@ def xyxy2xywh(x):

 def xywh2xyxy(x):
     # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
-    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
     y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
     y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
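Switching from zeros_like to clone()/np.copy() means any columns beyond the first four survive the conversion, presumably so these helpers can be applied directly to full (xyxy, conf, cls) prediction rows, as the new Detections class does. A small sketch (values illustrative):

import torch
from utils.general import xyxy2xywh

pred = torch.tensor([[100., 200., 300., 400., 0.9, 0.]])  # x1, y1, x2, y2, conf, cls
out = xyxy2xywh(pred)
# -> [[200., 300., 200., 200., 0.9, 0.]]  box converted; conf and cls columns preserved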
@@ -280,7 +280,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
     multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)

     t = time.time()
-    output = [None] * prediction.shape[0]
+    output = [torch.zeros(0, 6)] * prediction.shape[0]
     for xi, x in enumerate(prediction):  # image index, image inference
         # Apply constraints
         # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
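With this change every element of the NMS output is a tensor, possibly with zero rows, rather than None; this is what lets detect.py and test.py above replace their None checks with plain length checks. A small sketch of the caller-side contract (values illustrative):

import torch

output = [torch.zeros(0, 6),                               # image with no detections
          torch.tensor([[10., 20., 110., 220., 0.9, 0.]])]  # image with one detection
for pred in output:
    if len(pred) == 0:        # empty tensor instead of None
        continue
    boxes, conf, cls = pred[:, :4], pred[:, 4], pred[:, 5]  # safe to index columns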