mirror of https://github.com/JosephKJ/OWOD.git
Updates
parent
ae9cec1c8a
commit
857ac8c9ef
|
@ -110,15 +110,6 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
cls[i] = self.unknown_class_index
|
||||
return cls
|
||||
|
||||
def update_labels_based_on_softmax(self, logits, classes, thresold=0.9):
|
||||
cls = classes
|
||||
if len(logits) <= 0:
|
||||
return cls
|
||||
scores = torch.max(torch.nn.functional.softmax(logits[:, :self.num_seen_classes], dim=1), dim=1)[0]
|
||||
for i, s in enumerate(scores):
|
||||
if s < thresold:
|
||||
cls[i] = self.unknown_class_index
|
||||
return cls
|
||||
|
||||
def process(self, inputs, outputs):
|
||||
for input, output in zip(inputs, outputs):
|
||||
|
@ -128,8 +119,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
scores = instances.scores.tolist()
|
||||
classes = instances.pred_classes.tolist()
|
||||
logits = instances.logits
|
||||
# classes = self.update_label_based_on_energy(logits, classes)
|
||||
classes = self.update_labels_based_on_softmax(logits, classes)
|
||||
classes = self.update_label_based_on_energy(logits, classes)
|
||||
for box, score, cls in zip(boxes, scores, classes):
|
||||
if cls == -100:
|
||||
continue
|
||||
|
|
Loading…
Reference in New Issue