Streaming --save-txt bug fix (#1672)

* Streaming --save-txt bug fix

* cleanup
pull/1677/head
Glenn Jocher 2020-12-11 15:45:32 -08:00 committed by GitHub
parent bc52ea2d5f
commit 54043a9fa4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 11 deletions

View File

@ -81,12 +81,12 @@ def detect(save_img=False):
# Process detections
for i, det in enumerate(pred): # detections per image
if webcam: # batch_size >= 1
p, s, im0 = Path(path[i]), '%g: ' % i, im0s[i].copy()
p, s, im0, frame = Path(path[i]), '%g: ' % i, im0s[i].copy(), dataset.count
else:
p, s, im0 = Path(path), '', im0s
p, s, im0, frame = Path(path), '', im0s, getattr(dataset, 'frame', 0)
save_path = str(save_dir / p.name)
txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')
s += '%gx%g ' % img.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
if len(det):
@ -96,7 +96,7 @@ def detect(save_img=False):
# Print results
for c in det[:, -1].unique():
n = (det[:, -1] == c).sum() # detections per class
s += '%g %ss, ' % (n, names[int(c)]) # add to string
s += f'{n} {names[int(c)]}s, ' # add to string
# Write results
for *xyxy, conf, cls in reversed(det):
@ -107,11 +107,11 @@ def detect(save_img=False):
f.write(('%g ' * len(line)).rstrip() % line + '\n')
if save_img or view_img: # Add bbox to image
label = '%s %.2f' % (names[int(cls)], conf)
label = f'{names[int(cls)]} {conf:.2f}'
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
# Print time (inference + NMS)
print('%sDone. (%.3fs)' % (s, t2 - t1))
print(f'{s}Done. ({t2 - t1:.3f}s)')
# Stream results
if view_img:
@ -121,9 +121,9 @@ def detect(save_img=False):
# Save results (image with detections)
if save_img:
if dataset.mode == 'images':
if dataset.mode == 'image':
cv2.imwrite(save_path, im0)
else:
else: # 'video'
if vid_path != save_path: # new video
vid_path = save_path
if isinstance(vid_writer, cv2.VideoWriter):
@ -140,7 +140,7 @@ def detect(save_img=False):
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
print(f"Results saved to {save_dir}{s}")
print('Done. (%.3fs)' % (time.time() - t0))
print(f'Done. ({time.time() - t0:.3f}s)')
if __name__ == '__main__':

View File

@ -138,7 +138,7 @@ class LoadImages: # for inference
self.files = images + videos
self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv
self.mode = 'images'
self.mode = 'image'
if any(videos):
self.new_video(videos[0]) # new video
else:
@ -256,7 +256,7 @@ class LoadWebcam: # for inference
class LoadStreams: # multiple IP or RTSP cameras
def __init__(self, sources='streams.txt', img_size=640):
self.mode = 'images'
self.mode = 'stream'
self.img_size = img_size
if os.path.isfile(sources):