`torch.empty()` for speed improvements (#9025)

`torch.empty()` for speed improvement

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Glenn Jocher 2022-08-18 20:12:33 +02:00 committed by GitHub
parent d40cd0d454
commit 61adf017f2
5 changed files with 8 additions and 8 deletions
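
Why the swap helps (a minimal sketch, not part of the commit): torch.zeros() allocates a tensor and then fills it with zeros, while torch.empty() only allocates, leaving the memory uninitialized. For dummy inputs whose values are never inspected (warmup, shape probing, profiling), the zero-fill is wasted work. The shape, iteration count, and timing helper below are illustrative assumptions; it falls back to CPU if no CUDA device is available.

import time

import torch


def time_alloc(fn, shape, n=100, device='cuda' if torch.cuda.is_available() else 'cpu'):
    # Average the cost of n allocations; synchronize so asynchronous CUDA work is counted
    if device == 'cuda':
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n):
        fn(*shape, device=device)
    if device == 'cuda':
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / n


shape = (1, 3, 640, 640)  # illustrative YOLOv5-sized input
print(f'torch.zeros: {time_alloc(torch.zeros, shape) * 1e6:.1f} us')  # allocate + fill with 0
print(f'torch.empty: {time_alloc(torch.empty, shape) * 1e6:.1f} us')  # allocate only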

models/common.py

@@ -531,7 +531,7 @@ class DetectMultiBackend(nn.Module):
         # Warmup model by running inference once
         warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
         if any(warmup_types) and self.device.type != 'cpu':
-            im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
+            im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
             for _ in range(2 if self.jit else 1):  #
                 self.forward(im)  # warmup
@@ -600,7 +600,7 @@ class AutoShape(nn.Module):
         dt = (Profile(), Profile(), Profile())
         with dt[0]:
-            p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device)  # param
+            p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device)  # param
             autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
             if isinstance(ims, torch.Tensor):  # torch
                 with amp.autocast(autocast):

models/yolo.py

@@ -46,8 +46,8 @@ class Detect(nn.Module):
         self.no = nc + 5  # number of outputs per anchor
         self.nl = len(anchors)  # number of detection layers
         self.na = len(anchors[0]) // 2  # number of anchors
-        self.grid = [torch.zeros(1)] * self.nl  # init grid
-        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
+        self.grid = [torch.empty(1)] * self.nl  # init grid
+        self.anchor_grid = [torch.empty(1)] * self.nl  # init anchor grid
         self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
         self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
         self.inplace = inplace  # use inplace ops (e.g. slice assignment)
@@ -175,7 +175,7 @@ class DetectionModel(BaseModel):
         if isinstance(m, Detect):
             s = 256  # 2x min stride
             m.inplace = self.inplace
-            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
+            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.empty(1, ch, s, s))])  # forward
             check_anchor_order(m)  # must be in pixel-space (not grid-space)
             m.anchors /= m.stride.view(-1, 1, 1)
             self.stride = m.stride
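
A hedged illustration of why an uninitialized probe input is safe in the stride hunk above: the dummy forward pass is consumed only for its output shapes, so the tensor's contents are never read. TinyBackbone below is a hypothetical stand-in, not YOLOv5 code; the probe size s mirrors the hunk.

import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    # Toy model producing feature maps at 1/8, 1/16 and 1/32 of the input resolution
    def __init__(self, ch=3):
        super().__init__()
        self.p3 = nn.Conv2d(ch, 8, 3, stride=8, padding=1)
        self.p4 = nn.Conv2d(ch, 8, 3, stride=16, padding=1)
        self.p5 = nn.Conv2d(ch, 8, 3, stride=32, padding=1)

    def forward(self, x):
        return [self.p3(x), self.p4(x), self.p5(x)]


s = 256  # probe size, as in the hunk above
m = TinyBackbone()
stride = torch.tensor([s / y.shape[-2] for y in m(torch.empty(1, 3, s, s))])
print(stride)  # strides 8, 16, 32 recovered purely from output shapes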

utils/autobatch.py

@@ -47,7 +47,7 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
     # Profile batch sizes
     batch_sizes = [1, 2, 4, 8, 16]
     try:
-        img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
+        img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
         results = profile(img, model, n=3, device=device)
     except Exception as e:
         LOGGER.warning(f'{prefix}{e}')

utils/loggers/__init__.py

@@ -300,7 +300,7 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
     try:
         p = next(model.parameters())  # for device, type
         imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz  # expand
-        im = torch.zeros((1, 3, *imgsz)).to(p.device).type_as(p)  # input image
+        im = torch.empty((1, 3, *imgsz)).to(p.device).type_as(p)  # input image
         with warnings.catch_warnings():
             warnings.simplefilter('ignore')  # suppress jit trace warning
             tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])

utils/torch_utils.py

@@ -282,7 +282,7 @@ def model_info(model, verbose=False, imgsz=640):
     try:  # FLOPs
         p = next(model.parameters())
         stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
-        im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
         flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
         imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
         fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs'  # 640x640 GFLOPs
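
A small, hedged sketch of the GFLOPs pattern in the hunk above (not the commit's code): profile once on a stride-sized uninitialized input, convert MACs to FLOPs, then scale quadratically to the target image size. It assumes the thop package is installed; the single-conv model is a toy stand-in.

from copy import deepcopy

import torch
import torch.nn as nn
import thop  # pip install thop


model = nn.Conv2d(3, 16, 3, padding=1)  # toy model
stride, imgsz = 32, 640
im = torch.empty((1, 3, stride, stride))  # values never matter for FLOP counting
macs, _ = thop.profile(deepcopy(model), inputs=(im,), verbose=False)  # deepcopy: thop mutates the model
flops = macs / 1E9 * 2  # MACs -> FLOPs, in G
print(f'{flops * (imgsz / stride) ** 2:.3f} GFLOPs at {imgsz}x{imgsz}')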