mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Improved Profile()
inference timing (#9024)
* Improved `Profile()` class * Update predict.py * Update val.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update val.py * Update AutoShape Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c0e7a776cd
commit
d40cd0d454
@ -22,8 +22,8 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
||||
from models.common import DetectMultiBackend
|
||||
from utils.augmentations import classify_transforms
|
||||
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages
|
||||
from utils.general import LOGGER, check_file, check_requirements, colorstr, increment_path, print_args
|
||||
from utils.torch_utils import select_device, smart_inference_mode, time_sync
|
||||
from utils.general import LOGGER, Profile, check_file, check_requirements, colorstr, increment_path, print_args
|
||||
from utils.torch_utils import select_device, smart_inference_mode
|
||||
|
||||
|
||||
@smart_inference_mode()
|
||||
@ -44,7 +44,7 @@ def run(
|
||||
if is_url and is_file:
|
||||
source = check_file(source) # download
|
||||
|
||||
seen, dt = 1, [0.0, 0.0, 0.0]
|
||||
dt = Profile(), Profile(), Profile()
|
||||
device = select_device(device)
|
||||
|
||||
# Directories
|
||||
@ -55,30 +55,27 @@ def run(
|
||||
model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half)
|
||||
model.warmup(imgsz=(1, 3, imgsz, imgsz)) # warmup
|
||||
dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz))
|
||||
for path, im, im0s, vid_cap, s in dataset:
|
||||
for seen, (path, im, im0s, vid_cap, s) in enumerate(dataset):
|
||||
# Image
|
||||
t1 = time_sync()
|
||||
im = im.unsqueeze(0).to(device)
|
||||
im = im.half() if model.fp16 else im.float()
|
||||
t2 = time_sync()
|
||||
dt[0] += t2 - t1
|
||||
with dt[0]:
|
||||
im = im.unsqueeze(0).to(device)
|
||||
im = im.half() if model.fp16 else im.float()
|
||||
|
||||
# Inference
|
||||
results = model(im)
|
||||
t3 = time_sync()
|
||||
dt[1] += t3 - t2
|
||||
with dt[1]:
|
||||
results = model(im)
|
||||
|
||||
# Post-process
|
||||
p = F.softmax(results, dim=1) # probabilities
|
||||
i = p.argsort(1, descending=True)[:, :5].squeeze().tolist() # top 5 indices
|
||||
dt[2] += time_sync() - t3
|
||||
# if save:
|
||||
# imshow_cls(im, f=save_dir / Path(path).name, verbose=True)
|
||||
seen += 1
|
||||
LOGGER.info(f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}")
|
||||
with dt[2]:
|
||||
p = F.softmax(results, dim=1) # probabilities
|
||||
i = p.argsort(1, descending=True)[:, :5].squeeze().tolist() # top 5 indices
|
||||
# if save:
|
||||
# imshow_cls(im, f=save_dir / Path(path).name, verbose=True)
|
||||
LOGGER.info(
|
||||
f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}, {dt[1].dt * 1E3:.1f}ms")
|
||||
|
||||
# Print results
|
||||
t = tuple(x / seen * 1E3 for x in dt) # speeds per image
|
||||
t = tuple(x.t / (seen + 1) * 1E3 for x in dt) # speeds per image
|
||||
shape = (1, 3, imgsz, imgsz)
|
||||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
|
||||
|
@ -23,8 +23,8 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
||||
|
||||
from models.common import DetectMultiBackend
|
||||
from utils.dataloaders import create_classification_dataloader
|
||||
from utils.general import LOGGER, check_img_size, check_requirements, colorstr, increment_path, print_args
|
||||
from utils.torch_utils import select_device, smart_inference_mode, time_sync
|
||||
from utils.general import LOGGER, Profile, check_img_size, check_requirements, colorstr, increment_path, print_args
|
||||
from utils.torch_utils import select_device, smart_inference_mode
|
||||
|
||||
|
||||
@smart_inference_mode()
|
||||
@ -83,27 +83,24 @@ def run(
|
||||
workers=workers)
|
||||
|
||||
model.eval()
|
||||
pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
|
||||
pred, targets, loss, dt = [], [], 0, (Profile(), Profile(), Profile())
|
||||
n = len(dataloader) # number of batches
|
||||
action = 'validating' if dataloader.dataset.root.stem == 'val' else 'testing'
|
||||
desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}"
|
||||
bar = tqdm(dataloader, desc, n, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}', position=0)
|
||||
with torch.cuda.amp.autocast(enabled=device.type != 'cpu'):
|
||||
for images, labels in bar:
|
||||
t1 = time_sync()
|
||||
images, labels = images.to(device, non_blocking=True), labels.to(device)
|
||||
t2 = time_sync()
|
||||
dt[0] += t2 - t1
|
||||
with dt[0]:
|
||||
images, labels = images.to(device, non_blocking=True), labels.to(device)
|
||||
|
||||
y = model(images)
|
||||
t3 = time_sync()
|
||||
dt[1] += t3 - t2
|
||||
with dt[1]:
|
||||
y = model(images)
|
||||
|
||||
pred.append(y.argsort(1, descending=True)[:, :5])
|
||||
targets.append(labels)
|
||||
if criterion:
|
||||
loss += criterion(y, labels)
|
||||
dt[2] += time_sync() - t3
|
||||
with dt[2]:
|
||||
pred.append(y.argsort(1, descending=True)[:, :5])
|
||||
targets.append(labels)
|
||||
if criterion:
|
||||
loss += criterion(y, labels)
|
||||
|
||||
loss /= n
|
||||
pred, targets = torch.cat(pred), torch.cat(targets)
|
||||
@ -122,7 +119,7 @@ def run(
|
||||
LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")
|
||||
|
||||
# Print results
|
||||
t = tuple(x / len(dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
|
||||
t = tuple(x.t / len(dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
|
||||
shape = (1, 3, imgsz, imgsz)
|
||||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
|
||||
|
35
detect.py
35
detect.py
@ -41,10 +41,10 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
||||
|
||||
from models.common import DetectMultiBackend
|
||||
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
|
||||
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
|
||||
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
|
||||
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
|
||||
from utils.plots import Annotator, colors, save_one_box
|
||||
from utils.torch_utils import select_device, smart_inference_mode, time_sync
|
||||
from utils.torch_utils import select_device, smart_inference_mode
|
||||
|
||||
|
||||
@smart_inference_mode()
|
||||
@ -107,26 +107,23 @@ def run(
|
||||
|
||||
# Run inference
|
||||
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup
|
||||
seen, windows, dt = 0, [], [0.0, 0.0, 0.0]
|
||||
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
|
||||
for path, im, im0s, vid_cap, s in dataset:
|
||||
t1 = time_sync()
|
||||
im = torch.from_numpy(im).to(device)
|
||||
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
t2 = time_sync()
|
||||
dt[0] += t2 - t1
|
||||
with dt[0]:
|
||||
im = torch.from_numpy(im).to(device)
|
||||
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
|
||||
# Inference
|
||||
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
|
||||
pred = model(im, augment=augment, visualize=visualize)
|
||||
t3 = time_sync()
|
||||
dt[1] += t3 - t2
|
||||
with dt[1]:
|
||||
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
|
||||
pred = model(im, augment=augment, visualize=visualize)
|
||||
|
||||
# NMS
|
||||
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
|
||||
dt[2] += time_sync() - t3
|
||||
with dt[2]:
|
||||
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
|
||||
|
||||
# Second-stage classifier (optional)
|
||||
# pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
|
||||
@ -201,10 +198,10 @@ def run(
|
||||
vid_writer[i].write(im0)
|
||||
|
||||
# Print time (inference-only)
|
||||
LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')
|
||||
LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
|
||||
|
||||
# Print results
|
||||
t = tuple(x / seen * 1E3 for x in dt) # speeds per image
|
||||
t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
|
||||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
|
||||
if save_txt or save_img:
|
||||
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
|
||||
|
115
models/common.py
115
models/common.py
@ -21,10 +21,11 @@ from PIL import Image
|
||||
from torch.cuda import amp
|
||||
|
||||
from utils.dataloaders import exif_transpose, letterbox
|
||||
from utils.general import (LOGGER, ROOT, check_requirements, check_suffix, check_version, colorstr, increment_path,
|
||||
make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh, yaml_load)
|
||||
from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
|
||||
increment_path, make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh,
|
||||
yaml_load)
|
||||
from utils.plots import Annotator, colors, save_one_box
|
||||
from utils.torch_utils import copy_attr, smart_inference_mode, time_sync
|
||||
from utils.torch_utils import copy_attr, smart_inference_mode
|
||||
|
||||
|
||||
def autopad(k, p=None): # kernel, padding
|
||||
@ -587,9 +588,9 @@ class AutoShape(nn.Module):
|
||||
return self
|
||||
|
||||
@smart_inference_mode()
|
||||
def forward(self, imgs, size=640, augment=False, profile=False):
|
||||
def forward(self, ims, size=640, augment=False, profile=False):
|
||||
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
|
||||
# file: imgs = 'data/images/zidane.jpg' # str or PosixPath
|
||||
# file: ims = 'data/images/zidane.jpg' # str or PosixPath
|
||||
# URI: = 'https://ultralytics.com/images/zidane.jpg'
|
||||
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
|
||||
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
|
||||
@ -597,65 +598,65 @@ class AutoShape(nn.Module):
|
||||
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
|
||||
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
||||
|
||||
t = [time_sync()]
|
||||
p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device) # for device, type
|
||||
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
||||
if isinstance(imgs, torch.Tensor): # torch
|
||||
with amp.autocast(autocast):
|
||||
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
|
||||
dt = (Profile(), Profile(), Profile())
|
||||
with dt[0]:
|
||||
p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device) # param
|
||||
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
||||
if isinstance(ims, torch.Tensor): # torch
|
||||
with amp.autocast(autocast):
|
||||
return self.model(ims.to(p.device).type_as(p), augment, profile) # inference
|
||||
|
||||
# Pre-process
|
||||
n, imgs = (len(imgs), list(imgs)) if isinstance(imgs, (list, tuple)) else (1, [imgs]) # number, list of images
|
||||
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
|
||||
for i, im in enumerate(imgs):
|
||||
f = f'image{i}' # filename
|
||||
if isinstance(im, (str, Path)): # filename or uri
|
||||
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
|
||||
im = np.asarray(exif_transpose(im))
|
||||
elif isinstance(im, Image.Image): # PIL Image
|
||||
im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
|
||||
files.append(Path(f).with_suffix('.jpg').name)
|
||||
if im.shape[0] < 5: # image in CHW
|
||||
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
|
||||
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
|
||||
s = im.shape[:2] # HWC
|
||||
shape0.append(s) # image shape
|
||||
g = (size / max(s)) # gain
|
||||
shape1.append([y * g for y in s])
|
||||
imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
||||
shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)] # inf shape
|
||||
x = [letterbox(im, shape1, auto=False)[0] for im in imgs] # pad
|
||||
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
||||
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
||||
t.append(time_sync())
|
||||
# Pre-process
|
||||
n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
|
||||
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
|
||||
for i, im in enumerate(ims):
|
||||
f = f'image{i}' # filename
|
||||
if isinstance(im, (str, Path)): # filename or uri
|
||||
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
|
||||
im = np.asarray(exif_transpose(im))
|
||||
elif isinstance(im, Image.Image): # PIL Image
|
||||
im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
|
||||
files.append(Path(f).with_suffix('.jpg').name)
|
||||
if im.shape[0] < 5: # image in CHW
|
||||
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
|
||||
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
|
||||
s = im.shape[:2] # HWC
|
||||
shape0.append(s) # image shape
|
||||
g = (size / max(s)) # gain
|
||||
shape1.append([y * g for y in s])
|
||||
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
||||
shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)] # inf shape
|
||||
x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
|
||||
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
||||
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
||||
|
||||
with amp.autocast(autocast):
|
||||
# Inference
|
||||
y = self.model(x, augment, profile) # forward
|
||||
t.append(time_sync())
|
||||
with dt[1]:
|
||||
y = self.model(x, augment, profile) # forward
|
||||
|
||||
# Post-process
|
||||
y = non_max_suppression(y if self.dmb else y[0],
|
||||
self.conf,
|
||||
self.iou,
|
||||
self.classes,
|
||||
self.agnostic,
|
||||
self.multi_label,
|
||||
max_det=self.max_det) # NMS
|
||||
for i in range(n):
|
||||
scale_coords(shape1, y[i][:, :4], shape0[i])
|
||||
with dt[2]:
|
||||
y = non_max_suppression(y if self.dmb else y[0],
|
||||
self.conf,
|
||||
self.iou,
|
||||
self.classes,
|
||||
self.agnostic,
|
||||
self.multi_label,
|
||||
max_det=self.max_det) # NMS
|
||||
for i in range(n):
|
||||
scale_coords(shape1, y[i][:, :4], shape0[i])
|
||||
|
||||
t.append(time_sync())
|
||||
return Detections(imgs, y, files, t, self.names, x.shape)
|
||||
return Detections(ims, y, files, dt, self.names, x.shape)
|
||||
|
||||
|
||||
class Detections:
|
||||
# YOLOv5 detections class for inference results
|
||||
def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
|
||||
def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
|
||||
super().__init__()
|
||||
d = pred[0].device # device
|
||||
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations
|
||||
self.imgs = imgs # list of images as numpy arrays
|
||||
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
|
||||
self.ims = ims # list of images as numpy arrays
|
||||
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
|
||||
self.names = names # class names
|
||||
self.files = files # image filenames
|
||||
@ -665,12 +666,12 @@ class Detections:
|
||||
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
||||
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
||||
self.n = len(self.pred) # number of images (batch size)
|
||||
self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
|
||||
self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
|
||||
self.s = shape # inference BCHW shape
|
||||
|
||||
def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
|
||||
crops = []
|
||||
for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
|
||||
for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
|
||||
s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
|
||||
if pred.shape[0]:
|
||||
for c in pred[:, -1].unique():
|
||||
@ -705,7 +706,7 @@ class Detections:
|
||||
if i == self.n - 1:
|
||||
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
|
||||
if render:
|
||||
self.imgs[i] = np.asarray(im)
|
||||
self.ims[i] = np.asarray(im)
|
||||
if crop:
|
||||
if save:
|
||||
LOGGER.info(f'Saved results to {save_dir}\n')
|
||||
@ -728,7 +729,7 @@ class Detections:
|
||||
|
||||
def render(self, labels=True):
|
||||
self.display(render=True, labels=labels) # render results
|
||||
return self.imgs
|
||||
return self.ims
|
||||
|
||||
def pandas(self):
|
||||
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
|
||||
@ -743,9 +744,9 @@ class Detections:
|
||||
def tolist(self):
|
||||
# return a list of Detections objects, i.e. 'for result in results.tolist():'
|
||||
r = range(self.n) # iterable
|
||||
x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
|
||||
x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
|
||||
# for d in x:
|
||||
# for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
|
||||
# for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
|
||||
# setattr(d, k, getattr(d, k)[0]) # pop out of list
|
||||
return x
|
||||
|
||||
|
@ -141,16 +141,26 @@ CONFIG_DIR = user_config_dir() # Ultralytics settings dir
|
||||
|
||||
|
||||
class Profile(contextlib.ContextDecorator):
|
||||
# Usage: @Profile() decorator or 'with Profile():' context manager
|
||||
# YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
|
||||
def __init__(self, t=0.0):
|
||||
self.t = t
|
||||
self.cuda = torch.cuda.is_available()
|
||||
|
||||
def __enter__(self):
|
||||
self.start = time.time()
|
||||
self.start = self.time()
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
print(f'Profile results: {time.time() - self.start:.5f}s')
|
||||
self.dt = self.time() - self.start # delta-time
|
||||
self.t += self.dt # accumulate dt
|
||||
|
||||
def time(self):
|
||||
if self.cuda:
|
||||
torch.cuda.synchronize()
|
||||
return time.time()
|
||||
|
||||
|
||||
class Timeout(contextlib.ContextDecorator):
|
||||
# Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
|
||||
# YOLOv5 Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
|
||||
def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
|
||||
self.seconds = int(seconds)
|
||||
self.timeout_message = timeout_msg
|
||||
|
31
val.py
31
val.py
@ -37,7 +37,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
||||
from models.common import DetectMultiBackend
|
||||
from utils.callbacks import Callbacks
|
||||
from utils.dataloaders import create_dataloader
|
||||
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_yaml,
|
||||
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_yaml,
|
||||
coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
|
||||
scale_coords, xywh2xyxy, xyxy2xywh)
|
||||
from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
|
||||
@ -187,26 +187,24 @@ def run(
|
||||
names = dict(enumerate(names))
|
||||
class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
|
||||
s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
|
||||
dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
|
||||
dt, p, r, f1, mp, mr, map50, map = (Profile(), Profile(), Profile()), 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
|
||||
loss = torch.zeros(3, device=device)
|
||||
jdict, stats, ap, ap_class = [], [], [], []
|
||||
callbacks.run('on_val_start')
|
||||
pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
|
||||
for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
|
||||
callbacks.run('on_val_batch_start')
|
||||
t1 = time_sync()
|
||||
if cuda:
|
||||
im = im.to(device, non_blocking=True)
|
||||
targets = targets.to(device)
|
||||
im = im.half() if half else im.float() # uint8 to fp16/32
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
nb, _, height, width = im.shape # batch size, channels, height, width
|
||||
t2 = time_sync()
|
||||
dt[0] += t2 - t1
|
||||
with dt[0]:
|
||||
if cuda:
|
||||
im = im.to(device, non_blocking=True)
|
||||
targets = targets.to(device)
|
||||
im = im.half() if half else im.float() # uint8 to fp16/32
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
nb, _, height, width = im.shape # batch size, channels, height, width
|
||||
|
||||
# Inference
|
||||
out, train_out = model(im) if training else model(im, augment=augment, val=True) # inference, loss outputs
|
||||
dt[1] += time_sync() - t2
|
||||
with dt[1]:
|
||||
out, train_out = model(im) if training else model(im, augment=augment, val=True) # inference, loss outputs
|
||||
|
||||
# Loss
|
||||
if compute_loss:
|
||||
@ -215,9 +213,8 @@ def run(
|
||||
# NMS
|
||||
targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels
|
||||
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
|
||||
t3 = time_sync()
|
||||
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
|
||||
dt[2] += time_sync() - t3
|
||||
with dt[2]:
|
||||
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
|
||||
|
||||
# Metrics
|
||||
for si, pred in enumerate(out):
|
||||
@ -284,7 +281,7 @@ def run(
|
||||
LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
|
||||
|
||||
# Print speeds
|
||||
t = tuple(x / seen * 1E3 for x in dt) # speeds per image
|
||||
t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
|
||||
if not training:
|
||||
shape = (batch_size, 3, imgsz, imgsz)
|
||||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
|
||||
|
Loading…
x
Reference in New Issue
Block a user