New `smart_inference_mode()` conditional decorator (#8957)
New smart_inference_mode()pull/8958/head
parent
6aed0a7c00
commit
dc38cd03f4
|
@ -44,10 +44,10 @@ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
|
|||
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
|
||||
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
|
||||
from utils.plots import Annotator, colors, save_one_box
|
||||
from utils.torch_utils import select_device, time_sync
|
||||
from utils.torch_utils import select_device, smart_inference_mode, time_sync
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@smart_inference_mode()
|
||||
def run(
|
||||
weights=ROOT / 'yolov5s.pt', # model.pt path(s)
|
||||
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam
|
||||
|
|
|
@ -69,7 +69,7 @@ from models.yolo import Detect
|
|||
from utils.dataloaders import LoadImages
|
||||
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, check_yaml,
|
||||
colorstr, file_size, print_args, url2file)
|
||||
from utils.torch_utils import select_device
|
||||
from utils.torch_utils import select_device, smart_inference_mode
|
||||
|
||||
|
||||
def export_formats():
|
||||
|
@ -455,7 +455,7 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
|
|||
LOGGER.info(f'\n{prefix} export failure: {e}')
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@smart_inference_mode()
|
||||
def run(
|
||||
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
||||
weights=ROOT / 'yolov5s.pt', # weights path
|
||||
|
|
|
@ -25,7 +25,7 @@ from utils.dataloaders import exif_transpose, letterbox
|
|||
from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
|
||||
make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
|
||||
from utils.plots import Annotator, colors, save_one_box
|
||||
from utils.torch_utils import copy_attr, time_sync
|
||||
from utils.torch_utils import copy_attr, smart_inference_mode, time_sync
|
||||
|
||||
|
||||
def autopad(k, p=None): # kernel, padding
|
||||
|
@ -578,7 +578,7 @@ class AutoShape(nn.Module):
|
|||
m.anchor_grid = list(map(fn, m.anchor_grid))
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
@smart_inference_mode()
|
||||
def forward(self, imgs, size=640, augment=False, profile=False):
|
||||
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
|
||||
# file: imgs = 'data/images/zidane.jpg' # str or PosixPath
|
||||
|
|
|
@ -76,12 +76,12 @@ class Detect(nn.Module):
|
|||
|
||||
return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
|
||||
|
||||
def _make_grid(self, nx=20, ny=20, i=0):
|
||||
def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
|
||||
d = self.anchors[i].device
|
||||
t = self.anchors[i].dtype
|
||||
shape = 1, self.na, ny, nx, 2 # grid shape
|
||||
y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
|
||||
if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
|
||||
if torch_1_10: # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
|
||||
yv, xv = torch.meshgrid(y, x, indexing='ij')
|
||||
else:
|
||||
yv, xv = torch.meshgrid(y, x)
|
||||
|
|
|
@ -34,6 +34,14 @@ except ImportError:
|
|||
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
|
||||
|
||||
|
||||
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
|
||||
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
|
||||
def decorate(fn):
|
||||
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def smart_DDP(model):
|
||||
# Model DDP creation with checks
|
||||
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
||||
|
@ -364,9 +372,9 @@ class ModelEMA:
|
|||
for p in self.ema.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
@smart_inference_mode()
|
||||
def update(self, model):
|
||||
# Update EMA parameters
|
||||
with torch.no_grad():
|
||||
self.updates += 1
|
||||
d = self.decay(self.updates)
|
||||
|
||||
|
|
4
val.py
4
val.py
|
@ -42,7 +42,7 @@ from utils.general import (LOGGER, check_dataset, check_img_size, check_requirem
|
|||
scale_coords, xywh2xyxy, xyxy2xywh)
|
||||
from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
|
||||
from utils.plots import output_to_target, plot_images, plot_val_study
|
||||
from utils.torch_utils import select_device, time_sync
|
||||
from utils.torch_utils import select_device, smart_inference_mode, time_sync
|
||||
|
||||
|
||||
def save_one_txt(predn, save_conf, shape, file):
|
||||
|
@ -93,7 +93,7 @@ def process_batch(detections, labels, iouv):
|
|||
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@smart_inference_mode()
|
||||
def run(
|
||||
data,
|
||||
weights=None, # model.pt path(s)
|
||||
|
|
Loading…
Reference in New Issue