mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
PyTorch Hub cv2 .save() .show() bug fix (#2831)
* PyTorch Hub cv2 .save() .show() bug fix cv2.rectangle() was failing on non-contiguous np array inputs. This checks for contiguous arrays and applies is necessary: ```python imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update ``` * Update plots.py ```python assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.' ``` * Update hubconf.py Expand CI tests to OpenCV image.
This commit is contained in:
parent
aff03be35a
commit
c15e25c40f
10
hubconf.py
10
hubconf.py
@ -124,13 +124,15 @@ if __name__ == '__main__':
|
||||
# model = custom(path_or_model='path/to/model.pt') # custom example
|
||||
|
||||
# Verify inference
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
imgs = [Image.open('data/images/bus.jpg'), # PIL
|
||||
'data/images/zidane.jpg', # filename
|
||||
'https://github.com/ultralytics/yolov5/raw/master/data/images/bus.jpg', # URI
|
||||
np.zeros((640, 480, 3))] # numpy
|
||||
imgs = ['data/images/zidane.jpg', # filename
|
||||
'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg', # URI
|
||||
cv2.imread('data/images/bus.jpg')[:, :, ::-1], # OpenCV
|
||||
Image.open('data/images/bus.jpg'), # PIL
|
||||
np.zeros((320, 640, 3))] # numpy
|
||||
|
||||
results = model(imgs) # batched inference
|
||||
results.print()
|
||||
|
@ -240,7 +240,7 @@ class autoShape(nn.Module):
|
||||
@torch.no_grad()
|
||||
def forward(self, imgs, size=640, augment=False, profile=False):
|
||||
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
|
||||
# filename: imgs = 'data/samples/zidane.jpg'
|
||||
# filename: imgs = 'data/images/zidane.jpg'
|
||||
# URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
|
||||
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
|
||||
# PIL: = Image.open('image.jpg') # HWC x(640,1280,3)
|
||||
@ -271,7 +271,7 @@ class autoShape(nn.Module):
|
||||
shape0.append(s) # image shape
|
||||
g = (size / max(s)) # gain
|
||||
shape1.append([y * g for y in s])
|
||||
imgs[i] = im # update
|
||||
imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
||||
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
|
||||
x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
|
||||
x = np.stack(x, 0) if n > 1 else x[0][None] # stack
|
||||
|
@ -54,32 +54,34 @@ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
|
||||
return filtfilt(b, a, data) # forward-backward filter
|
||||
|
||||
|
||||
def plot_one_box(x, img, color=None, label=None, line_thickness=3):
|
||||
# Plots one bounding box on image img
|
||||
tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
|
||||
def plot_one_box(x, im, color=None, label=None, line_thickness=3):
|
||||
# Plots one bounding box on image 'im' using OpenCV
|
||||
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
|
||||
tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1 # line/font thickness
|
||||
color = color or [random.randint(0, 255) for _ in range(3)]
|
||||
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
|
||||
cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
|
||||
cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
|
||||
if label:
|
||||
tf = max(tl - 1, 1) # font thickness
|
||||
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
|
||||
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
|
||||
cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
|
||||
cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
|
||||
cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
|
||||
cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
|
||||
|
||||
|
||||
def plot_one_box_PIL(box, img, color=None, label=None, line_thickness=None):
|
||||
img = Image.fromarray(img)
|
||||
draw = ImageDraw.Draw(img)
|
||||
line_thickness = line_thickness or max(int(min(img.size) / 200), 2)
|
||||
def plot_one_box_PIL(box, im, color=None, label=None, line_thickness=None):
|
||||
# Plots one bounding box on image 'im' using PIL
|
||||
im = Image.fromarray(im)
|
||||
draw = ImageDraw.Draw(im)
|
||||
line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
|
||||
draw.rectangle(box, width=line_thickness, outline=tuple(color)) # plot
|
||||
if label:
|
||||
fontsize = max(round(max(img.size) / 40), 12)
|
||||
fontsize = max(round(max(im.size) / 40), 12)
|
||||
font = ImageFont.truetype("Arial.ttf", fontsize)
|
||||
txt_width, txt_height = font.getsize(label)
|
||||
draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color))
|
||||
draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
|
||||
return np.asarray(img)
|
||||
return np.asarray(im)
|
||||
|
||||
|
||||
def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
|
||||
|
Loading…
x
Reference in New Issue
Block a user