New `smart_inference_mode()` conditional decorator (#8957)

pull/8958/head
Glenn Jocher 2022-08-13 20:38:51 +02:00 committed by GitHub
parent 6aed0a7c00
commit dc38cd03f4
6 changed files with 26 additions and 18 deletions

detect.py

```diff
@@ -44,10 +44,10 @@ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
 from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                            increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
-from utils.torch_utils import select_device, time_sync
+from utils.torch_utils import select_device, smart_inference_mode, time_sync
 
 
-@torch.no_grad()
+@smart_inference_mode()
 def run(
         weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
         source=ROOT / 'data/images',  # file/dir/URL/glob, 0 for webcam
```

export.py

```diff
@@ -69,7 +69,7 @@ from models.yolo import Detect
 from utils.dataloaders import LoadImages
 from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, check_yaml,
                            colorstr, file_size, print_args, url2file)
-from utils.torch_utils import select_device
+from utils.torch_utils import select_device, smart_inference_mode
 
 
 def export_formats():
@@ -455,7 +455,7 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
         LOGGER.info(f'\n{prefix} export failure: {e}')
 
 
-@torch.no_grad()
+@smart_inference_mode()
 def run(
         data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
         weights=ROOT / 'yolov5s.pt',  # weights path
```

models/common.py

```diff
@@ -25,7 +25,7 @@ from utils.dataloaders import exif_transpose, letterbox
 from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
                            make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
-from utils.torch_utils import copy_attr, time_sync
+from utils.torch_utils import copy_attr, smart_inference_mode, time_sync
 
 
 def autopad(k, p=None):  # kernel, padding
@@ -578,7 +578,7 @@ class AutoShape(nn.Module):
                 m.anchor_grid = list(map(fn, m.anchor_grid))
         return self
 
-    @torch.no_grad()
+    @smart_inference_mode()
     def forward(self, imgs, size=640, augment=False, profile=False):
         # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
         #   file:       imgs = 'data/images/zidane.jpg'  # str or PosixPath
```

models/yolo.py

```diff
@@ -76,12 +76,12 @@ class Detect(nn.Module):
         return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
 
-    def _make_grid(self, nx=20, ny=20, i=0):
+    def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
         d = self.anchors[i].device
         t = self.anchors[i].dtype
         shape = 1, self.na, ny, nx, 2  # grid shape
         y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
-        if check_version(torch.__version__, '1.10.0'):  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
+        if torch_1_10:  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
             yv, xv = torch.meshgrid(y, x, indexing='ij')
         else:
             yv, xv = torch.meshgrid(y, x)
```
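
The `_make_grid()` change moves the `check_version()` call into a default argument, so the torch version is evaluated once when the method is defined rather than on every grid rebuild. A minimal sketch of the same pattern outside the repo, assuming only PyTorch is installed (`version_at_least` below is a hypothetical stand-in for `utils.general.check_version`):

```python
import torch
from pkg_resources import parse_version


def version_at_least(minimum):
    # Hypothetical stand-in for utils.general.check_version:
    # True if the installed torch is at least `minimum`.
    return parse_version(torch.__version__) >= parse_version(minimum)


# The default argument is evaluated once, at function definition time,
# so the version check is not repeated on every call.
def make_grid(nx=20, ny=20, torch_1_10=version_at_least('1.10.0')):
    y, x = torch.arange(ny), torch.arange(nx)
    if torch_1_10:  # torch>=1.10 accepts (and warns without) an explicit indexing=
        yv, xv = torch.meshgrid(y, x, indexing='ij')
    else:  # older torch has no indexing= argument and behaves like 'ij'
        yv, xv = torch.meshgrid(y, x)
    return torch.stack((xv, yv), 2)  # (ny, nx, 2) grid of x, y coordinates
```

The usual caution about mutable default arguments does not apply here: the cached value is an immutable bool, and caching it once is exactly the point.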

utils/torch_utils.py

```diff
@@ -34,6 +34,14 @@ except ImportError:
 warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
 
 
+def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
+    # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
+    def decorate(fn):
+        return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
+    return decorate
+
+
 def smart_DDP(model):
     # Model DDP creation with checks
     assert not check_version(torch.__version__, '1.12.0', pinned=True), \
@@ -364,9 +372,9 @@ class ModelEMA:
         for p in self.ema.parameters():
             p.requires_grad_(False)
 
+    @smart_inference_mode()
     def update(self, model):
         # Update EMA parameters
-        with torch.no_grad():
         self.updates += 1
         d = self.decay(self.updates)
```
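
Both `torch.no_grad()` and `torch.inference_mode()` work as decorators, which is all `smart_inference_mode()` relies on: it picks whichever the installed torch supports and applies it to the wrapped function. A short usage sketch, assuming it is run from the repository root so that `utils.torch_utils` is importable (the `predict` function is only illustrative):

```python
import torch

from utils.torch_utils import smart_inference_mode  # added in this commit


@smart_inference_mode()
def predict(model, x):
    # Executes under torch.inference_mode() on torch>=1.9.0 and under
    # torch.no_grad() on older versions; either way no autograd graph is built.
    return model(x)


model = torch.nn.Linear(4, 2)
out = predict(model, torch.randn(1, 4))
print(out.requires_grad)  # False: gradients are not tracked inside the call
```

On torch>=1.9.0 the outputs are inference tensors, which skip version counting and view tracking but cannot later take part in autograd-recorded computations; that trade-off is fine for pure inference paths such as `detect.py`, `val.py`, `AutoShape.forward()` and the EMA update, which is why the explicit `with torch.no_grad():` block in `ModelEMA.update()` can be dropped in favour of the decorator.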

val.py

```diff
@@ -42,7 +42,7 @@ from utils.general import (LOGGER, check_dataset, check_img_size, check_requirem
                            scale_coords, xywh2xyxy, xyxy2xywh)
 from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
 from utils.plots import output_to_target, plot_images, plot_val_study
-from utils.torch_utils import select_device, time_sync
+from utils.torch_utils import select_device, smart_inference_mode, time_sync
 
 
 def save_one_txt(predn, save_conf, shape, file):
@@ -93,7 +93,7 @@ def process_batch(detections, labels, iouv):
     return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
 
 
-@torch.no_grad()
+@smart_inference_mode()
 def run(
         data,
         weights=None,  # model.pt path(s)
```