Unified output names (#283)

pull/368/head
tripleMu 2022-07-29 09:58:57 +08:00 committed by GitHub
parent 5f7d38b12a
commit 6bacefff5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 4 deletions

View File

@ -111,22 +111,22 @@ class RegisterNMS(object):
# NMS Outputs # NMS Outputs
output_num_detections = gs.Variable( output_num_detections = gs.Variable(
name="num_detections", name="num_dets",
dtype=np.int32, dtype=np.int32,
shape=[self.batch_size, 1], shape=[self.batch_size, 1],
) # A scalar indicating the number of valid detections per batch image. ) # A scalar indicating the number of valid detections per batch image.
output_boxes = gs.Variable( output_boxes = gs.Variable(
name="detection_boxes", name="det_boxes",
dtype=dtype_output, dtype=dtype_output,
shape=[self.batch_size, detections_per_img, 4], shape=[self.batch_size, detections_per_img, 4],
) )
output_scores = gs.Variable( output_scores = gs.Variable(
name="detection_scores", name="det_scores",
dtype=dtype_output, dtype=dtype_output,
shape=[self.batch_size, detections_per_img], shape=[self.batch_size, detections_per_img],
) )
output_labels = gs.Variable( output_labels = gs.Variable(
name="detection_classes", name="det_classes",
dtype=np.int32, dtype=np.int32,
shape=[self.batch_size, detections_per_img], shape=[self.batch_size, detections_per_img],
) )