mirror of https://github.com/JosephKJ/OWOD.git
Updates
parent
2d33beb021
commit
0f842492a2
|
@ -1,4 +1,4 @@
|
|||
### Towards Open World Object Detection
|
||||
## Towards Open World Object Detection
|
||||
#### CVPR 2021
|
||||
|
||||
Humans have a natural instinct to identify unknown object instances in their environments. The intrinsic curiosity about these unknown instances aids in learning about them, when the corresponding knowledge is eventually available. This motivates us to propose a novel computer vision problem called: Open World Object Detection, where a model is tasked to:
|
||||
|
|
|
@ -35,7 +35,7 @@ def compute_prob(x, distribution):
|
|||
def update_label_based_on_energy(logits, classes, unk_dist, known_dist):
|
||||
unknown_class_index = 80
|
||||
cls = classes
|
||||
lse = torch.logsumexp(logits[:, :10], dim=1)
|
||||
lse = torch.logsumexp(logits[:, :5], dim=1)
|
||||
for i, energy in enumerate(lse):
|
||||
p_unk = compute_prob(energy, unk_dist)
|
||||
p_known = compute_prob(energy, known_dist)
|
||||
|
@ -47,12 +47,15 @@ def update_label_based_on_energy(logits, classes, unk_dist, known_dist):
|
|||
return cls
|
||||
|
||||
# Get image
|
||||
fnum = '451953'
|
||||
fnum = '348006'
|
||||
file_name = '000000' + fnum
|
||||
im = cv2.imread("/home/fk1/workspace/OWOD/datasets/VOC2007/JPEGImages/" + file_name + ".jpg")
|
||||
model = '/home/fk1/workspace/OWOD/output/t1_THRESHOLD_AUTOLABEL_UNK/model_final.pth'
|
||||
# model = '/home/fk1/workspace/OWOD/output/old/t1_20_class/model_0009999.pth'
|
||||
# model = '/home/fk1/workspace/OWOD/output/t1_THRESHOLD_AUTOLABEL_UNK/model_final.pth'
|
||||
# model = '/home/fk1/workspace/OWOD/output/t1_clustering_with_save/model_final.pth'
|
||||
# model = '/home/fk1/workspace/OWOD/output/old/t1_20_class/model_final.pth'
|
||||
# model = '/home/fk1/workspace/OWOD/output/t2_ft/model_final.pth'
|
||||
# model = '/home/fk1/workspace/OWOD/output/t3_ft/model_final.pth'
|
||||
model = '/home/fk1/workspace/OWOD/output/t4_ft/model_final.pth'
|
||||
cfg_file = '/home/fk1/workspace/OWOD/configs/OWOD/t1/t1_test.yaml'
|
||||
|
||||
|
||||
|
@ -60,7 +63,13 @@ cfg_file = '/home/fk1/workspace/OWOD/configs/OWOD/t1/t1_test.yaml'
|
|||
cfg = get_cfg()
|
||||
cfg.merge_from_file(cfg_file)
|
||||
cfg.MODEL.WEIGHTS = model
|
||||
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.27
|
||||
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.61
|
||||
# cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.8
|
||||
cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.4
|
||||
|
||||
# POSITIVE_FRACTION: 0.25
|
||||
# NMS_THRESH_TEST: 0.5
|
||||
# SCORE_THRESH_TEST: 0.05
|
||||
# cfg.MODEL.ROI_HEADS.NUM_CLASSES = 21
|
||||
|
||||
predictor = DefaultPredictor(cfg)
|
||||
|
|
Loading…
Reference in New Issue