`torch.empty()` for speed improvements (#9025)
`torch.empty()` for speed improvement Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/9026/head
parent
d40cd0d454
commit
61adf017f2
|
@ -531,7 +531,7 @@ class DetectMultiBackend(nn.Module):
|
|||
# Warmup model by running inference once
|
||||
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
|
||||
if any(warmup_types) and self.device.type != 'cpu':
|
||||
im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
||||
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
||||
for _ in range(2 if self.jit else 1): #
|
||||
self.forward(im) # warmup
|
||||
|
||||
|
@ -600,7 +600,7 @@ class AutoShape(nn.Module):
|
|||
|
||||
dt = (Profile(), Profile(), Profile())
|
||||
with dt[0]:
|
||||
p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device) # param
|
||||
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
|
||||
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
||||
if isinstance(ims, torch.Tensor): # torch
|
||||
with amp.autocast(autocast):
|
||||
|
|
|
@ -46,8 +46,8 @@ class Detect(nn.Module):
|
|||
self.no = nc + 5 # number of outputs per anchor
|
||||
self.nl = len(anchors) # number of detection layers
|
||||
self.na = len(anchors[0]) // 2 # number of anchors
|
||||
self.grid = [torch.zeros(1)] * self.nl # init grid
|
||||
self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
|
||||
self.grid = [torch.empty(1)] * self.nl # init grid
|
||||
self.anchor_grid = [torch.empty(1)] * self.nl # init anchor grid
|
||||
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
|
||||
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
||||
self.inplace = inplace # use inplace ops (e.g. slice assignment)
|
||||
|
@ -175,7 +175,7 @@ class DetectionModel(BaseModel):
|
|||
if isinstance(m, Detect):
|
||||
s = 256 # 2x min stride
|
||||
m.inplace = self.inplace
|
||||
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
|
||||
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.empty(1, ch, s, s))]) # forward
|
||||
check_anchor_order(m) # must be in pixel-space (not grid-space)
|
||||
m.anchors /= m.stride.view(-1, 1, 1)
|
||||
self.stride = m.stride
|
||||
|
|
|
@ -47,7 +47,7 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
|
|||
# Profile batch sizes
|
||||
batch_sizes = [1, 2, 4, 8, 16]
|
||||
try:
|
||||
img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
|
||||
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
|
||||
results = profile(img, model, n=3, device=device)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{prefix}{e}')
|
||||
|
|
|
@ -300,7 +300,7 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
|
|||
try:
|
||||
p = next(model.parameters()) # for device, type
|
||||
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz # expand
|
||||
im = torch.zeros((1, 3, *imgsz)).to(p.device).type_as(p) # input image
|
||||
im = torch.empty((1, 3, *imgsz)).to(p.device).type_as(p) # input image
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress jit trace warning
|
||||
tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
|
||||
|
|
|
@ -282,7 +282,7 @@ def model_info(model, verbose=False, imgsz=640):
|
|||
try: # FLOPs
|
||||
p = next(model.parameters())
|
||||
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
|
||||
im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
||||
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
||||
flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
|
||||
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
||||
fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs
|
||||
|
|
Loading…
Reference in New Issue