[Classify]: Allow inference on dirs and videos (#9003)
* allow image dirs
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update predict.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update dataloaders.py
* Update predict.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update predict.py
* Update predict.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent e83b422a69
commit 64e0757edf
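Note: the practical effect of this commit is that classify/predict.py accepts the same --source types as detect.py. Illustrative invocations, in the style of the script's own Usage docstring (paths are examples, not files guaranteed to ship with the repo):

    $ python classify/predict.py --weights yolov5s-cls.pt --source data/images/bus.jpg  # single image
    $ python classify/predict.py --weights yolov5s-cls.pt --source data/images          # directory
    $ python classify/predict.py --weights yolov5s-cls.pt --source 'data/images/*.jpg'  # glob
    $ python classify/predict.py --weights yolov5s-cls.pt --source vid.mp4              # video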
classify/predict.py

@@ -1,6 +1,6 @@
 # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
 """
-Run classification inference on images
+Run classification inference on file/dir/URL/glob

 Usage:
     $ python classify/predict.py --weights yolov5s-cls.pt --source data/images/bus.jpg
@@ -11,7 +11,6 @@ import os
 import sys
 from pathlib import Path

-import cv2
 import torch.nn.functional as F

 FILE = Path(__file__).resolve()
@@ -20,27 +19,31 @@ if str(ROOT) not in sys.path:
     sys.path.append(str(ROOT))  # add ROOT to PATH
 ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

 from classify.train import imshow_cls
 from models.common import DetectMultiBackend
 from utils.augmentations import classify_transforms
-from utils.general import LOGGER, check_requirements, colorstr, increment_path, print_args
+from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages
+from utils.general import LOGGER, check_file, check_requirements, colorstr, increment_path, print_args
 from utils.torch_utils import select_device, smart_inference_mode, time_sync


 @smart_inference_mode()
 def run(
         weights=ROOT / 'yolov5s-cls.pt',  # model.pt path(s)
-        source=ROOT / 'data/images/bus.jpg',  # file/dir/URL/glob, 0 for webcam
+        source=ROOT / 'data/images',  # file/dir/URL/glob
         imgsz=224,  # inference size
         device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
         half=False,  # use FP16 half-precision inference
         dnn=False,  # use OpenCV DNN for ONNX inference
         show=True,
         project=ROOT / 'runs/predict-cls',  # save to project/name
         name='exp',  # save to project/name
         exist_ok=False,  # existing project/name ok, do not increment
 ):
-    file = str(source)
+    source = str(source)
+    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
+    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
+    if is_url and is_file:
+        source = check_file(source)  # download

     seen, dt = 1, [0.0, 0.0, 0.0]
     device = select_device(device)
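Note: the new source handling in run() is a plain suffix/scheme test. A self-contained sketch of the same logic, with IMG_FORMATS/VID_FORMATS abbreviated for illustration (the full tuples live in utils/dataloaders.py) and describe_source a hypothetical helper, not repo API:

    from pathlib import Path

    IMG_FORMATS = ('bmp', 'jpeg', 'jpg', 'png')  # abbreviated for illustration
    VID_FORMATS = ('avi', 'mkv', 'mov', 'mp4')   # abbreviated for illustration

    def describe_source(source):
        # mirrors the checks added to run(); hypothetical helper for this sketch
        is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
        is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
        if is_url and is_file:
            return 'remote file (run() downloads it via check_file)'
        return 'local file' if is_file else 'dir/glob/stream'

    print(describe_source('data/images/bus.jpg'))   # local file
    print(describe_source('data/images'))           # dir/glob/stream
    print(describe_source('https://host/vid.mp4'))  # remote file (run() downloads it via check_file)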
@@ -48,37 +51,36 @@ def run(
     save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
     save_dir.mkdir(parents=True, exist_ok=True)  # make dir

-    # Transforms
-    transforms = classify_transforms(imgsz)
-
     # Load model
     model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half)
     model.warmup(imgsz=(1, 3, imgsz, imgsz))  # warmup
+    dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz))
+    for path, im, im0s, vid_cap, s in dataset:
+        # Image
+        t1 = time_sync()
+        im = im.unsqueeze(0).to(device)
+        im = im.half() if model.fp16 else im.float()
+        t2 = time_sync()
+        dt[0] += t2 - t1

-    # Image
-    t1 = time_sync()
-    im = cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB)
-    im = transforms(im).unsqueeze(0).to(device)
-    im = im.half() if model.fp16 else im.float()
-    t2 = time_sync()
-    dt[0] += t2 - t1
-
-    # Inference
-    results = model(im)
-    t3 = time_sync()
-    dt[1] += t3 - t2
+        # Inference
+        results = model(im)
+        t3 = time_sync()
+        dt[1] += t3 - t2

-    p = F.softmax(results, dim=1)  # probabilities
-    i = p.argsort(1, descending=True)[:, :5].squeeze()  # top 5 indices
-    dt[2] += time_sync() - t3
-    LOGGER.info(f"image 1/1 {file}: {imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i.tolist())}")
+        # Post-process
+        p = F.softmax(results, dim=1)  # probabilities
+        i = p.argsort(1, descending=True)[:, :5].squeeze().tolist()  # top 5 indices
+        dt[2] += time_sync() - t3
+        # if save:
+        #    imshow_cls(im, f=save_dir / Path(path).name, verbose=True)
+        seen += 1
+        LOGGER.info(f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}")

     # Print results
     t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
     shape = (1, 3, imgsz, imgsz)
     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
     if show:
         imshow_cls(im, f=save_dir / Path(file).name, verbose=True)
     LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
     return p
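Note: the post-processing is unchanged in substance, a softmax followed by a top-5 argsort, but it now runs once per dataset item and indexes with a plain Python list (.tolist()). The tensor manipulation in isolation, using torch only and stand-in logits:

    import torch
    import torch.nn.functional as F

    results = torch.randn(1, 10)  # stand-in logits for a 10-class model
    p = F.softmax(results, dim=1)  # probabilities, shape (1, 10)
    i = p.argsort(1, descending=True)[:, :5].squeeze().tolist()  # top-5 indices as plain ints
    print(i, [f'{p[0, j]:.2f}' for j in i])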
@@ -86,7 +88,7 @@ def run(
 def parse_opt():
     parser = argparse.ArgumentParser()
     parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)')
-    parser.add_argument('--source', type=str, default=ROOT / 'data/images/bus.jpg', help='file')
+    parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob')
     parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='train, val image size (pixels)')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
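Note: besides the CLI, run() can be driven programmatically with these new defaults. A minimal sketch, assuming the yolov5 repo root is on sys.path and yolov5s-cls.pt is present; 'data/videos/demo.mp4' is an illustrative path, not a repo file:

    from classify.predict import run

    p = run(weights='yolov5s-cls.pt', source='data/images', show=False)           # directory of images
    p = run(weights='yolov5s-cls.pt', source='data/videos/demo.mp4', show=False)  # video file
    print(p.shape)  # probabilities of the last item processed, e.g. torch.Size([1, 1000])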
utils/dataloaders.py

@@ -186,7 +186,7 @@ class _RepeatSampler:

 class LoadImages:
     # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
-    def __init__(self, path, img_size=640, stride=32, auto=True):
+    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None):
         files = []
         for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
             p = str(Path(p).resolve())
@@ -210,6 +210,7 @@ class LoadImages:
         self.video_flag = [False] * ni + [True] * nv
         self.mode = 'image'
         self.auto = auto
+        self.transforms = transforms  # optional
         if any(videos):
             self.new_video(videos[0])  # new video
         else:
@@ -229,7 +230,7 @@ class LoadImages:
         if self.video_flag[self.count]:
             # Read video
             self.mode = 'video'
-            ret_val, img0 = self.cap.read()
+            ret_val, im0 = self.cap.read()
             while not ret_val:
                 self.count += 1
                 self.cap.release()
@@ -237,7 +238,7 @@ class LoadImages:
                     raise StopIteration
                 path = self.files[self.count]
                 self.new_video(path)
-                ret_val, img0 = self.cap.read()
+                ret_val, im0 = self.cap.read()

             self.frame += 1
             s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
@@ -245,18 +246,18 @@ class LoadImages:
         else:
             # Read image
             self.count += 1
-            img0 = cv2.imread(path)  # BGR
-            assert img0 is not None, f'Image Not Found {path}'
+            im0 = cv2.imread(path)  # BGR
+            assert im0 is not None, f'Image Not Found {path}'
             s = f'image {self.count}/{self.nf} {path}: '

-        # Padded resize
-        img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
-
-        # Convert
-        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
-        img = np.ascontiguousarray(img)
+        if self.transforms:
+            im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB))  # classify transforms
+        else:
+            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
+            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
+            im = np.ascontiguousarray(im)  # contiguous

-        return path, img, img0, self.cap, s
+        return path, im, im0, self.cap, s

     def new_video(self, path):
         self.frame = 0
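Note: the new transforms hook is what lets one dataloader serve both detection and classification. A sketch of consuming the extended LoadImages, assuming the yolov5 repo root is on sys.path and data/images exists:

    from utils.augmentations import classify_transforms
    from utils.dataloaders import LoadImages

    dataset = LoadImages('data/images', img_size=224, transforms=classify_transforms(224))
    for path, im, im0, cap, s in dataset:
        # with transforms set, im is a torch tensor ready for the classifier;
        # with transforms=None, im stays the letterboxed CHW numpy array as before
        print(s, type(im).__name__, tuple(im.shape))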