diff --git a/utils/add_nms.py b/utils/add_nms.py index 927355e..0a1f797 100644 --- a/utils/add_nms.py +++ b/utils/add_nms.py @@ -111,22 +111,22 @@ class RegisterNMS(object): # NMS Outputs output_num_detections = gs.Variable( - name="num_detections", + name="num_dets", dtype=np.int32, shape=[self.batch_size, 1], ) # A scalar indicating the number of valid detections per batch image. output_boxes = gs.Variable( - name="detection_boxes", + name="det_boxes", dtype=dtype_output, shape=[self.batch_size, detections_per_img, 4], ) output_scores = gs.Variable( - name="detection_scores", + name="det_scores", dtype=dtype_output, shape=[self.batch_size, detections_per_img], ) output_labels = gs.Variable( - name="detection_classes", + name="det_classes", dtype=np.int32, shape=[self.batch_size, detections_per_img], )