`Detect.inplace=False` for multithread-safe inference (#8801)

Detect.inplace=False for safe multithread inference
pull/8804/head
Glenn Jocher 2022-07-30 22:19:40 +02:00 committed by GitHub
parent 7921351b4e
commit 1e89807d9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 1 deletions

View File

@ -55,6 +55,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model.model.model[-1].inplace = False # Detect.inplace=False for safe multithread inference
model = AutoShape(model) # for file/URI/PIL/cv2/np inputs and NMS
if not verbose:
LOGGER.setLevel(logging.INFO) # reset to default

View File

@ -50,7 +50,7 @@ class Detect(nn.Module):
self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
self.inplace = inplace # use in-place ops (e.g. slice assignment)
self.inplace = inplace # use inplace ops (e.g. slice assignment)
def forward(self, x):
z = [] # inference output