Improved Profile() inference timing (#9024)

* Improved `Profile()` class

* Update predict.py

* Update val.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update val.py

* Update AutoShape

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2022-08-18 19:55:38 +02:00 committed by GitHub
parent c0e7a776cd
commit d40cd0d454
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 132 additions and 133 deletions

View File

@ -22,8 +22,8 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.common import DetectMultiBackend from models.common import DetectMultiBackend
from utils.augmentations import classify_transforms from utils.augmentations import classify_transforms
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages
from utils.general import LOGGER, check_file, check_requirements, colorstr, increment_path, print_args from utils.general import LOGGER, Profile, check_file, check_requirements, colorstr, increment_path, print_args
from utils.torch_utils import select_device, smart_inference_mode, time_sync from utils.torch_utils import select_device, smart_inference_mode
@smart_inference_mode() @smart_inference_mode()
@ -44,7 +44,7 @@ def run(
if is_url and is_file: if is_url and is_file:
source = check_file(source) # download source = check_file(source) # download
seen, dt = 1, [0.0, 0.0, 0.0] dt = Profile(), Profile(), Profile()
device = select_device(device) device = select_device(device)
# Directories # Directories
@ -55,30 +55,27 @@ def run(
model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half) model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half)
model.warmup(imgsz=(1, 3, imgsz, imgsz)) # warmup model.warmup(imgsz=(1, 3, imgsz, imgsz)) # warmup
dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz)) dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz))
for path, im, im0s, vid_cap, s in dataset: for seen, (path, im, im0s, vid_cap, s) in enumerate(dataset):
# Image # Image
t1 = time_sync() with dt[0]:
im = im.unsqueeze(0).to(device) im = im.unsqueeze(0).to(device)
im = im.half() if model.fp16 else im.float() im = im.half() if model.fp16 else im.float()
t2 = time_sync()
dt[0] += t2 - t1
# Inference # Inference
results = model(im) with dt[1]:
t3 = time_sync() results = model(im)
dt[1] += t3 - t2
# Post-process # Post-process
p = F.softmax(results, dim=1) # probabilities with dt[2]:
i = p.argsort(1, descending=True)[:, :5].squeeze().tolist() # top 5 indices p = F.softmax(results, dim=1) # probabilities
dt[2] += time_sync() - t3 i = p.argsort(1, descending=True)[:, :5].squeeze().tolist() # top 5 indices
# if save: # if save:
# imshow_cls(im, f=save_dir / Path(path).name, verbose=True) # imshow_cls(im, f=save_dir / Path(path).name, verbose=True)
seen += 1 LOGGER.info(
LOGGER.info(f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}") f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}, {dt[1].dt * 1E3:.1f}ms")
# Print results # Print results
t = tuple(x / seen * 1E3 for x in dt) # speeds per image t = tuple(x.t / (seen + 1) * 1E3 for x in dt) # speeds per image
shape = (1, 3, imgsz, imgsz) shape = (1, 3, imgsz, imgsz)
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t) LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

View File

@ -23,8 +23,8 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.common import DetectMultiBackend from models.common import DetectMultiBackend
from utils.dataloaders import create_classification_dataloader from utils.dataloaders import create_classification_dataloader
from utils.general import LOGGER, check_img_size, check_requirements, colorstr, increment_path, print_args from utils.general import LOGGER, Profile, check_img_size, check_requirements, colorstr, increment_path, print_args
from utils.torch_utils import select_device, smart_inference_mode, time_sync from utils.torch_utils import select_device, smart_inference_mode
@smart_inference_mode() @smart_inference_mode()
@ -83,27 +83,24 @@ def run(
workers=workers) workers=workers)
model.eval() model.eval()
pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0] pred, targets, loss, dt = [], [], 0, (Profile(), Profile(), Profile())
n = len(dataloader) # number of batches n = len(dataloader) # number of batches
action = 'validating' if dataloader.dataset.root.stem == 'val' else 'testing' action = 'validating' if dataloader.dataset.root.stem == 'val' else 'testing'
desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}" desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}"
bar = tqdm(dataloader, desc, n, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}', position=0) bar = tqdm(dataloader, desc, n, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}', position=0)
with torch.cuda.amp.autocast(enabled=device.type != 'cpu'): with torch.cuda.amp.autocast(enabled=device.type != 'cpu'):
for images, labels in bar: for images, labels in bar:
t1 = time_sync() with dt[0]:
images, labels = images.to(device, non_blocking=True), labels.to(device) images, labels = images.to(device, non_blocking=True), labels.to(device)
t2 = time_sync()
dt[0] += t2 - t1
y = model(images) with dt[1]:
t3 = time_sync() y = model(images)
dt[1] += t3 - t2
pred.append(y.argsort(1, descending=True)[:, :5]) with dt[2]:
targets.append(labels) pred.append(y.argsort(1, descending=True)[:, :5])
if criterion: targets.append(labels)
loss += criterion(y, labels) if criterion:
dt[2] += time_sync() - t3 loss += criterion(y, labels)
loss /= n loss /= n
pred, targets = torch.cat(pred), torch.cat(targets) pred, targets = torch.cat(pred), torch.cat(targets)
@ -122,7 +119,7 @@ def run(
LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}") LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")
# Print results # Print results
t = tuple(x / len(dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image t = tuple(x.t / len(dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
shape = (1, 3, imgsz, imgsz) shape = (1, 3, imgsz, imgsz)
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t) LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

View File

@ -41,10 +41,10 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.common import DetectMultiBackend from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh) increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, smart_inference_mode, time_sync from utils.torch_utils import select_device, smart_inference_mode
@smart_inference_mode() @smart_inference_mode()
@ -107,26 +107,23 @@ def run(
# Run inference # Run inference
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup
seen, windows, dt = 0, [], [0.0, 0.0, 0.0] seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
for path, im, im0s, vid_cap, s in dataset: for path, im, im0s, vid_cap, s in dataset:
t1 = time_sync() with dt[0]:
im = torch.from_numpy(im).to(device) im = torch.from_numpy(im).to(device)
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32 im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0 im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3: if len(im.shape) == 3:
im = im[None] # expand for batch dim im = im[None] # expand for batch dim
t2 = time_sync()
dt[0] += t2 - t1
# Inference # Inference
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False with dt[1]:
pred = model(im, augment=augment, visualize=visualize) visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
t3 = time_sync() pred = model(im, augment=augment, visualize=visualize)
dt[1] += t3 - t2
# NMS # NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) with dt[2]:
dt[2] += time_sync() - t3 pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
# Second-stage classifier (optional) # Second-stage classifier (optional)
# pred = utils.general.apply_classifier(pred, classifier_model, im, im0s) # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
@ -201,10 +198,10 @@ def run(
vid_writer[i].write(im0) vid_writer[i].write(im0)
# Print time (inference-only) # Print time (inference-only)
LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)') LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
# Print results # Print results
t = tuple(x / seen * 1E3 for x in dt) # speeds per image t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t) LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
if save_txt or save_img: if save_txt or save_img:
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''

View File

@ -21,10 +21,11 @@ from PIL import Image
from torch.cuda import amp from torch.cuda import amp
from utils.dataloaders import exif_transpose, letterbox from utils.dataloaders import exif_transpose, letterbox
from utils.general import (LOGGER, ROOT, check_requirements, check_suffix, check_version, colorstr, increment_path, from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh, yaml_load) increment_path, make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh,
yaml_load)
from utils.plots import Annotator, colors, save_one_box from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import copy_attr, smart_inference_mode, time_sync from utils.torch_utils import copy_attr, smart_inference_mode
def autopad(k, p=None): # kernel, padding def autopad(k, p=None): # kernel, padding
@ -587,9 +588,9 @@ class AutoShape(nn.Module):
return self return self
@smart_inference_mode() @smart_inference_mode()
def forward(self, imgs, size=640, augment=False, profile=False): def forward(self, ims, size=640, augment=False, profile=False):
# Inference from various sources. For height=640, width=1280, RGB images example inputs are: # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
# file: imgs = 'data/images/zidane.jpg' # str or PosixPath # file: ims = 'data/images/zidane.jpg' # str or PosixPath
# URI: = 'https://ultralytics.com/images/zidane.jpg' # URI: = 'https://ultralytics.com/images/zidane.jpg'
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3) # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3) # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
@ -597,65 +598,65 @@ class AutoShape(nn.Module):
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values) # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
t = [time_sync()] dt = (Profile(), Profile(), Profile())
p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device) # for device, type with dt[0]:
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device) # param
if isinstance(imgs, torch.Tensor): # torch autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
with amp.autocast(autocast): if isinstance(ims, torch.Tensor): # torch
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference with amp.autocast(autocast):
return self.model(ims.to(p.device).type_as(p), augment, profile) # inference
# Pre-process # Pre-process
n, imgs = (len(imgs), list(imgs)) if isinstance(imgs, (list, tuple)) else (1, [imgs]) # number, list of images n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
shape0, shape1, files = [], [], [] # image and inference shapes, filenames shape0, shape1, files = [], [], [] # image and inference shapes, filenames
for i, im in enumerate(imgs): for i, im in enumerate(ims):
f = f'image{i}' # filename f = f'image{i}' # filename
if isinstance(im, (str, Path)): # filename or uri if isinstance(im, (str, Path)): # filename or uri
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
im = np.asarray(exif_transpose(im)) im = np.asarray(exif_transpose(im))
elif isinstance(im, Image.Image): # PIL Image elif isinstance(im, Image.Image): # PIL Image
im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
files.append(Path(f).with_suffix('.jpg').name) files.append(Path(f).with_suffix('.jpg').name)
if im.shape[0] < 5: # image in CHW if im.shape[0] < 5: # image in CHW
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
s = im.shape[:2] # HWC s = im.shape[:2] # HWC
shape0.append(s) # image shape shape0.append(s) # image shape
g = (size / max(s)) # gain g = (size / max(s)) # gain
shape1.append([y * g for y in s]) shape1.append([y * g for y in s])
imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)] # inf shape shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)] # inf shape
x = [letterbox(im, shape1, auto=False)[0] for im in imgs] # pad x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32 x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
t.append(time_sync())
with amp.autocast(autocast): with amp.autocast(autocast):
# Inference # Inference
y = self.model(x, augment, profile) # forward with dt[1]:
t.append(time_sync()) y = self.model(x, augment, profile) # forward
# Post-process # Post-process
y = non_max_suppression(y if self.dmb else y[0], with dt[2]:
self.conf, y = non_max_suppression(y if self.dmb else y[0],
self.iou, self.conf,
self.classes, self.iou,
self.agnostic, self.classes,
self.multi_label, self.agnostic,
max_det=self.max_det) # NMS self.multi_label,
for i in range(n): max_det=self.max_det) # NMS
scale_coords(shape1, y[i][:, :4], shape0[i]) for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])
t.append(time_sync()) return Detections(ims, y, files, dt, self.names, x.shape)
return Detections(imgs, y, files, t, self.names, x.shape)
class Detections: class Detections:
# YOLOv5 detections class for inference results # YOLOv5 detections class for inference results
def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None): def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
super().__init__() super().__init__()
d = pred[0].device # device d = pred[0].device # device
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
self.imgs = imgs # list of images as numpy arrays self.ims = ims # list of images as numpy arrays
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
self.names = names # class names self.names = names # class names
self.files = files # image filenames self.files = files # image filenames
@ -665,12 +666,12 @@ class Detections:
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
self.n = len(self.pred) # number of images (batch size) self.n = len(self.pred) # number of images (batch size)
self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms) self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
self.s = shape # inference BCHW shape self.s = shape # inference BCHW shape
def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')): def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
crops = [] crops = []
for i, (im, pred) in enumerate(zip(self.imgs, self.pred)): for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
if pred.shape[0]: if pred.shape[0]:
for c in pred[:, -1].unique(): for c in pred[:, -1].unique():
@ -705,7 +706,7 @@ class Detections:
if i == self.n - 1: if i == self.n - 1:
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}") LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
if render: if render:
self.imgs[i] = np.asarray(im) self.ims[i] = np.asarray(im)
if crop: if crop:
if save: if save:
LOGGER.info(f'Saved results to {save_dir}\n') LOGGER.info(f'Saved results to {save_dir}\n')
@ -728,7 +729,7 @@ class Detections:
def render(self, labels=True): def render(self, labels=True):
self.display(render=True, labels=labels) # render results self.display(render=True, labels=labels) # render results
return self.imgs return self.ims
def pandas(self): def pandas(self):
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0]) # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
@ -743,9 +744,9 @@ class Detections:
def tolist(self): def tolist(self):
# return a list of Detections objects, i.e. 'for result in results.tolist():' # return a list of Detections objects, i.e. 'for result in results.tolist():'
r = range(self.n) # iterable r = range(self.n) # iterable
x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r] x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
# for d in x: # for d in x:
# for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']: # for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
# setattr(d, k, getattr(d, k)[0]) # pop out of list # setattr(d, k, getattr(d, k)[0]) # pop out of list
return x return x

View File

@ -141,16 +141,26 @@ CONFIG_DIR = user_config_dir() # Ultralytics settings dir
class Profile(contextlib.ContextDecorator): class Profile(contextlib.ContextDecorator):
# Usage: @Profile() decorator or 'with Profile():' context manager # YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
def __init__(self, t=0.0):
self.t = t
self.cuda = torch.cuda.is_available()
def __enter__(self): def __enter__(self):
self.start = time.time() self.start = self.time()
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
print(f'Profile results: {time.time() - self.start:.5f}s') self.dt = self.time() - self.start # delta-time
self.t += self.dt # accumulate dt
def time(self):
if self.cuda:
torch.cuda.synchronize()
return time.time()
class Timeout(contextlib.ContextDecorator): class Timeout(contextlib.ContextDecorator):
# Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager # YOLOv5 Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True): def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
self.seconds = int(seconds) self.seconds = int(seconds)
self.timeout_message = timeout_msg self.timeout_message = timeout_msg

31
val.py
View File

@ -37,7 +37,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.common import DetectMultiBackend from models.common import DetectMultiBackend
from utils.callbacks import Callbacks from utils.callbacks import Callbacks
from utils.dataloaders import create_dataloader from utils.dataloaders import create_dataloader
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_yaml, from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_yaml,
coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args, coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
scale_coords, xywh2xyxy, xyxy2xywh) scale_coords, xywh2xyxy, xyxy2xywh)
from utils.metrics import ConfusionMatrix, ap_per_class, box_iou from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
@ -187,26 +187,24 @@ def run(
names = dict(enumerate(names)) names = dict(enumerate(names))
class_map = coco80_to_coco91_class() if is_coco else list(range(1000)) class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95') s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 dt, p, r, f1, mp, mr, map50, map = (Profile(), Profile(), Profile()), 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
loss = torch.zeros(3, device=device) loss = torch.zeros(3, device=device)
jdict, stats, ap, ap_class = [], [], [], [] jdict, stats, ap, ap_class = [], [], [], []
callbacks.run('on_val_start') callbacks.run('on_val_start')
pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
for batch_i, (im, targets, paths, shapes) in enumerate(pbar): for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
callbacks.run('on_val_batch_start') callbacks.run('on_val_batch_start')
t1 = time_sync() with dt[0]:
if cuda: if cuda:
im = im.to(device, non_blocking=True) im = im.to(device, non_blocking=True)
targets = targets.to(device) targets = targets.to(device)
im = im.half() if half else im.float() # uint8 to fp16/32 im = im.half() if half else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0 im /= 255 # 0 - 255 to 0.0 - 1.0
nb, _, height, width = im.shape # batch size, channels, height, width nb, _, height, width = im.shape # batch size, channels, height, width
t2 = time_sync()
dt[0] += t2 - t1
# Inference # Inference
out, train_out = model(im) if training else model(im, augment=augment, val=True) # inference, loss outputs with dt[1]:
dt[1] += time_sync() - t2 out, train_out = model(im) if training else model(im, augment=augment, val=True) # inference, loss outputs
# Loss # Loss
if compute_loss: if compute_loss:
@ -215,9 +213,8 @@ def run(
# NMS # NMS
targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
t3 = time_sync() with dt[2]:
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls) out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
dt[2] += time_sync() - t3
# Metrics # Metrics
for si, pred in enumerate(out): for si, pred in enumerate(out):
@ -284,7 +281,7 @@ def run(
LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i])) LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
# Print speeds # Print speeds
t = tuple(x / seen * 1E3 for x in dt) # speeds per image t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
if not training: if not training:
shape = (batch_size, 3, imgsz, imgsz) shape = (batch_size, 3, imgsz, imgsz)
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t) LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)