mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
AutoShape explicit arguments fix (#9443)
* AutoShape explicit arguments fix Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update common.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
03f2ca8eff
commit
2ac4b634c7
@ -633,7 +633,7 @@ class AutoShape(nn.Module):
|
||||
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
||||
if isinstance(ims, torch.Tensor): # torch
|
||||
with amp.autocast(autocast):
|
||||
return self.model(ims.to(p.device).type_as(p), augment, profile) # inference
|
||||
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
|
||||
|
||||
# Pre-process
|
||||
n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
|
||||
@ -662,7 +662,7 @@ class AutoShape(nn.Module):
|
||||
with amp.autocast(autocast):
|
||||
# Inference
|
||||
with dt[1]:
|
||||
y = self.model(x, augment, profile) # forward
|
||||
y = self.model(x, augment=augment) # forward
|
||||
|
||||
# Post-process
|
||||
with dt[2]:
|
||||
@ -696,7 +696,7 @@ class Detections:
|
||||
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
||||
self.n = len(self.pred) # number of images (batch size)
|
||||
self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
|
||||
self.s = shape # inference BCHW shape
|
||||
self.s = tuple(shape) # inference BCHW shape
|
||||
|
||||
def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
|
||||
crops = []
|
||||
@ -726,7 +726,7 @@ class Detections:
|
||||
|
||||
im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
|
||||
if pprint:
|
||||
print(s.rstrip(', '))
|
||||
LOGGER.info(s.rstrip(', '))
|
||||
if show:
|
||||
im.show(self.files[i]) # show
|
||||
if save:
|
||||
@ -743,7 +743,7 @@ class Detections:
|
||||
|
||||
def print(self):
|
||||
self.display(pprint=True) # print results
|
||||
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)
|
||||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t)
|
||||
|
||||
def show(self, labels=True):
|
||||
self.display(show=True, labels=labels) # show results
|
||||
|
Loading…
x
Reference in New Issue
Block a user