New `smart_inference_mode()` conditional decorator (#8957)
New smart_inference_mode()pull/8958/head
parent
6aed0a7c00
commit
dc38cd03f4
|
@ -44,10 +44,10 @@ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
|
||||||
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
|
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
|
||||||
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
|
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
|
||||||
from utils.plots import Annotator, colors, save_one_box
|
from utils.plots import Annotator, colors, save_one_box
|
||||||
from utils.torch_utils import select_device, time_sync
|
from utils.torch_utils import select_device, smart_inference_mode, time_sync
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@smart_inference_mode()
|
||||||
def run(
|
def run(
|
||||||
weights=ROOT / 'yolov5s.pt', # model.pt path(s)
|
weights=ROOT / 'yolov5s.pt', # model.pt path(s)
|
||||||
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam
|
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam
|
||||||
|
|
|
@ -69,7 +69,7 @@ from models.yolo import Detect
|
||||||
from utils.dataloaders import LoadImages
|
from utils.dataloaders import LoadImages
|
||||||
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, check_yaml,
|
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, check_yaml,
|
||||||
colorstr, file_size, print_args, url2file)
|
colorstr, file_size, print_args, url2file)
|
||||||
from utils.torch_utils import select_device
|
from utils.torch_utils import select_device, smart_inference_mode
|
||||||
|
|
||||||
|
|
||||||
def export_formats():
|
def export_formats():
|
||||||
|
@ -455,7 +455,7 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
|
||||||
LOGGER.info(f'\n{prefix} export failure: {e}')
|
LOGGER.info(f'\n{prefix} export failure: {e}')
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@smart_inference_mode()
|
||||||
def run(
|
def run(
|
||||||
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
||||||
weights=ROOT / 'yolov5s.pt', # weights path
|
weights=ROOT / 'yolov5s.pt', # weights path
|
||||||
|
|
|
@ -25,7 +25,7 @@ from utils.dataloaders import exif_transpose, letterbox
|
||||||
from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
|
from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
|
||||||
make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
|
make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
|
||||||
from utils.plots import Annotator, colors, save_one_box
|
from utils.plots import Annotator, colors, save_one_box
|
||||||
from utils.torch_utils import copy_attr, time_sync
|
from utils.torch_utils import copy_attr, smart_inference_mode, time_sync
|
||||||
|
|
||||||
|
|
||||||
def autopad(k, p=None): # kernel, padding
|
def autopad(k, p=None): # kernel, padding
|
||||||
|
@ -578,7 +578,7 @@ class AutoShape(nn.Module):
|
||||||
m.anchor_grid = list(map(fn, m.anchor_grid))
|
m.anchor_grid = list(map(fn, m.anchor_grid))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@torch.no_grad()
|
@smart_inference_mode()
|
||||||
def forward(self, imgs, size=640, augment=False, profile=False):
|
def forward(self, imgs, size=640, augment=False, profile=False):
|
||||||
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
|
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
|
||||||
# file: imgs = 'data/images/zidane.jpg' # str or PosixPath
|
# file: imgs = 'data/images/zidane.jpg' # str or PosixPath
|
||||||
|
|
|
@ -76,12 +76,12 @@ class Detect(nn.Module):
|
||||||
|
|
||||||
return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
|
return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
|
||||||
|
|
||||||
def _make_grid(self, nx=20, ny=20, i=0):
|
def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
|
||||||
d = self.anchors[i].device
|
d = self.anchors[i].device
|
||||||
t = self.anchors[i].dtype
|
t = self.anchors[i].dtype
|
||||||
shape = 1, self.na, ny, nx, 2 # grid shape
|
shape = 1, self.na, ny, nx, 2 # grid shape
|
||||||
y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
|
y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
|
||||||
if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
|
if torch_1_10: # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
|
||||||
yv, xv = torch.meshgrid(y, x, indexing='ij')
|
yv, xv = torch.meshgrid(y, x, indexing='ij')
|
||||||
else:
|
else:
|
||||||
yv, xv = torch.meshgrid(y, x)
|
yv, xv = torch.meshgrid(y, x)
|
||||||
|
|
|
@ -34,6 +34,14 @@ except ImportError:
|
||||||
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
|
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
|
||||||
|
|
||||||
|
|
||||||
|
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
|
||||||
|
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
|
||||||
|
def decorate(fn):
|
||||||
|
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
|
||||||
|
|
||||||
|
return decorate
|
||||||
|
|
||||||
|
|
||||||
def smart_DDP(model):
|
def smart_DDP(model):
|
||||||
# Model DDP creation with checks
|
# Model DDP creation with checks
|
||||||
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
||||||
|
@ -364,9 +372,9 @@ class ModelEMA:
|
||||||
for p in self.ema.parameters():
|
for p in self.ema.parameters():
|
||||||
p.requires_grad_(False)
|
p.requires_grad_(False)
|
||||||
|
|
||||||
|
@smart_inference_mode()
|
||||||
def update(self, model):
|
def update(self, model):
|
||||||
# Update EMA parameters
|
# Update EMA parameters
|
||||||
with torch.no_grad():
|
|
||||||
self.updates += 1
|
self.updates += 1
|
||||||
d = self.decay(self.updates)
|
d = self.decay(self.updates)
|
||||||
|
|
||||||
|
|
4
val.py
4
val.py
|
@ -42,7 +42,7 @@ from utils.general import (LOGGER, check_dataset, check_img_size, check_requirem
|
||||||
scale_coords, xywh2xyxy, xyxy2xywh)
|
scale_coords, xywh2xyxy, xyxy2xywh)
|
||||||
from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
|
from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
|
||||||
from utils.plots import output_to_target, plot_images, plot_val_study
|
from utils.plots import output_to_target, plot_images, plot_val_study
|
||||||
from utils.torch_utils import select_device, time_sync
|
from utils.torch_utils import select_device, smart_inference_mode, time_sync
|
||||||
|
|
||||||
|
|
||||||
def save_one_txt(predn, save_conf, shape, file):
|
def save_one_txt(predn, save_conf, shape, file):
|
||||||
|
@ -93,7 +93,7 @@ def process_batch(detections, labels, iouv):
|
||||||
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
|
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@smart_inference_mode()
|
||||||
def run(
|
def run(
|
||||||
data,
|
data,
|
||||||
weights=None, # model.pt path(s)
|
weights=None, # model.pt path(s)
|
||||||
|
|
Loading…
Reference in New Issue