mirror of https://github.com/JosephKJ/OWOD.git
parent
ea9f7d15f8
commit
81516e5543
|
@ -28,6 +28,7 @@ OWOD:
|
|||
NUM_UNK_PER_IMAGE: 1
|
||||
ENABLE_UNCERTAINITY_AUTOLABEL_UNK: False
|
||||
ENABLE_CLUSTERING: True
|
||||
FEATURE_STORE_SAVE_PATH: 'feature_store'
|
||||
CLUSTERING:
|
||||
ITEMS_PER_CLASS: 20
|
||||
START_ITER: 1000
|
||||
|
|
|
@ -1,10 +1,18 @@
|
|||
_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
|
||||
MODEL:
|
||||
# WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
|
||||
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_std_frcnn/model_final.pth"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_THRESHOLD_AUTOLABEL_UNK/model_final.pth"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
|
||||
ROI_HEADS:
|
||||
POSITIVE_FRACTION: 0.25
|
||||
NMS_THRESH_TEST: 0.5
|
||||
SCORE_THRESH_TEST: 0.05
|
||||
TEST:
|
||||
DETECTIONS_PER_IMAGE: 50
|
||||
DATASETS:
|
||||
TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||
TEST: ('voc_coco_2007_test', ) # voc_coco_2007_test
|
||||
TEST: ('voc_coco_2007_test', 't1_voc_coco_2007_known_test') # voc_coco_2007_test
|
||||
SOLVER:
|
||||
STEPS: (12000, 16000)
|
||||
MAX_ITER: 18000
|
||||
|
@ -12,4 +20,8 @@ SOLVER:
|
|||
OUTPUT_DIR: "./output/t1_clustering_new_4"
|
||||
OWOD:
|
||||
PREV_INTRODUCED_CLS: 0
|
||||
CUR_INTRODUCED_CLS: 20
|
||||
CUR_INTRODUCED_CLS: 20
|
||||
|
||||
# POSITIVE_FRACTION: 0.25
|
||||
# NMS_THRESH_TEST: 0.5
|
||||
# SCORE_THRESH_TEST: 0.05
|
|
@ -2,14 +2,15 @@ _BASE_: "../../Base-RCNN-C4-OWOD.yaml"
|
|||
MODEL:
|
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1/model_final.pth"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_std_frcnn/model_final.pth"
|
||||
DATASETS:
|
||||
TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||
TEST: (, ) # voc_coco_2007_test
|
||||
TEST: ('voc_coco_2007_test', ) # voc_coco_2007_test, t1_voc_coco_2007_test, t1_voc_coco_2007_val
|
||||
SOLVER:
|
||||
STEPS: (12000, 16000)
|
||||
MAX_ITER: 18000
|
||||
WARMUP_ITERS: 100
|
||||
OUTPUT_DIR: "./output/t1_expr"
|
||||
OUTPUT_DIR: "./output/t1_clustering_with_save"
|
||||
OWOD:
|
||||
PREV_INTRODUCED_CLS: 0
|
||||
CUR_INTRODUCED_CLS: 20
|
|
@ -1,7 +1,8 @@
|
|||
_BASE_: "../Base-RCNN-C4-OWOD.yaml"
|
||||
MODEL:
|
||||
# WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
|
||||
WEIGHTS: "/home/fk1/workspace/OWOD/output/t2/model_final.pth"
|
||||
DATASETS:
|
||||
TRAIN: ('t2_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||
TEST: ('voc_coco_2007_test', )
|
||||
|
@ -12,4 +13,10 @@ SOLVER:
|
|||
OUTPUT_DIR: "./output/t2"
|
||||
OWOD:
|
||||
PREV_INTRODUCED_CLS: 20
|
||||
CUR_INTRODUCED_CLS: 20
|
||||
CUR_INTRODUCED_CLS: 20
|
||||
CLUSTERING:
|
||||
ITEMS_PER_CLASS: 20
|
||||
START_ITER: 20000
|
||||
UPDATE_MU_ITER: 3000
|
||||
MOMENTUM: 0.99
|
||||
Z_DIMENSION: 128
|
File diff suppressed because it is too large
Load Diff
|
@ -616,6 +616,7 @@ _C.OWOD.CUR_INTRODUCED_CLS = 20
|
|||
_C.OWOD.COMPUTE_ENERGY = False
|
||||
_C.OWOD.ENERGY_SAVE_PATH = ''
|
||||
_C.OWOD.SKIP_TRAINING_WHILE_EVAL = False
|
||||
_C.OWOD.FEATURE_STORE_SAVE_PATH = ''
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# Misc options
|
||||
|
|
|
@ -215,6 +215,7 @@ def register_all_pascal_voc(root):
|
|||
("voc_2012_train", "VOC2012", "train"),
|
||||
("voc_2012_val", "VOC2012", "val"),
|
||||
("t1_voc_coco_2007_train", "VOC2007", "t1_train"),
|
||||
("t1_voc_coco_2007_known_test", "VOC2007", "t1_known_test"),
|
||||
("voc_coco_2007_test", "VOC2007", "all_task_test"),
|
||||
("voc_coco_2007_val", "VOC2007", "all_task_val"),
|
||||
("t1_voc_coco_2007_ft", "VOC2007", "t1_ft"),
|
||||
|
|
|
@ -121,6 +121,9 @@ def default_setup(cfg, args):
|
|||
if cfg.OWOD.COMPUTE_ENERGY:
|
||||
PathManager.mkdirs(os.path.join(output_dir, cfg.OWOD.ENERGY_SAVE_PATH))
|
||||
|
||||
if cfg.OWOD.ENABLE_CLUSTERING:
|
||||
PathManager.mkdirs(os.path.join(output_dir, cfg.OWOD.FEATURE_STORE_SAVE_PATH))
|
||||
|
||||
rank = comm.get_rank()
|
||||
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
|
||||
logger = setup_logger(output_dir, distributed_rank=rank)
|
||||
|
|
|
@ -4,8 +4,10 @@
|
|||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import xml.etree.ElementTree as ET
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import OrderedDict, defaultdict
|
||||
from functools import lru_cache
|
||||
import torch
|
||||
|
@ -19,6 +21,7 @@ from detectron2.utils import comm
|
|||
|
||||
from .evaluator import DatasetEvaluator
|
||||
|
||||
np.set_printoptions(threshold=sys.maxsize)
|
||||
|
||||
class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
||||
"""
|
||||
|
@ -42,7 +45,8 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
self._image_set_path = os.path.join(meta.dirname, "ImageSets", "Main", meta.split + ".txt")
|
||||
self._class_names = meta.thing_classes
|
||||
assert meta.year in [2007, 2012], meta.year
|
||||
self._is_2007 = meta.year == 2007
|
||||
self._is_2007 = False
|
||||
# self._is_2007 = meta.year == 2007
|
||||
self._cpu_device = torch.device("cpu")
|
||||
self._logger = logging.getLogger(__name__)
|
||||
if cfg is not None:
|
||||
|
@ -54,7 +58,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
self.known_classes = self._class_names[:self.num_seen_classes]
|
||||
|
||||
param_save_location = os.path.join(cfg.OUTPUT_DIR,'energy_dist_' + str(self.num_seen_classes) + '.pkl')
|
||||
|
||||
self.energy_distribution_loaded = False
|
||||
if os.path.isfile(param_save_location) and os.access(param_save_location, os.R_OK):
|
||||
self._logger.info('Loading energy distribution from ' + param_save_location)
|
||||
params = torch.load(param_save_location)
|
||||
|
@ -103,7 +107,6 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
cls[i] = -100
|
||||
else:
|
||||
if cls[i] != self.unknown_class_index:
|
||||
# cls[i] = -100
|
||||
cls[i] = self.unknown_class_index
|
||||
return cls
|
||||
|
||||
|
@ -127,6 +130,28 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}"
|
||||
)
|
||||
|
||||
def compute_avg_precision_at_many_recall_level(self, precisions, recalls):
|
||||
precs = {}
|
||||
for r in range(1, 10):
|
||||
r = r/10
|
||||
p = self.compute_avg_precision_at_a_recall_level(precisions, recalls, recall_level=r)
|
||||
precs[r] = p
|
||||
return precs
|
||||
|
||||
def compute_avg_precision_at_a_recall_level(self, precisions, recalls, recall_level=0.5):
|
||||
precs = {}
|
||||
for iou, recall in recalls.items():
|
||||
prec = []
|
||||
for cls_id, rec in enumerate(recall):
|
||||
if cls_id in range(self.num_seen_classes) and len(rec)>0:
|
||||
p = precisions[iou][cls_id][min(range(len(rec)), key=lambda i: abs(rec[i] - recall_level))]
|
||||
prec.append(p)
|
||||
if len(prec) > 0:
|
||||
precs[iou] = np.mean(prec)
|
||||
else:
|
||||
precs[iou] = 0
|
||||
return precs
|
||||
|
||||
def evaluate(self):
|
||||
"""
|
||||
Returns:
|
||||
|
@ -152,8 +177,12 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
res_file_template = os.path.join(dirname, "{}.txt")
|
||||
|
||||
aps = defaultdict(list) # iou -> ap per class
|
||||
# recs = defaultdict(list)
|
||||
# precs = defaultdict(list)
|
||||
recs = defaultdict(list)
|
||||
precs = defaultdict(list)
|
||||
all_recs = defaultdict(list)
|
||||
all_precs = defaultdict(list)
|
||||
unk_det_as_knowns = defaultdict(list)
|
||||
num_unks = defaultdict(list)
|
||||
|
||||
for cls_id, cls_name in enumerate(self._class_names):
|
||||
lines = predictions.get(cls_id, [""])
|
||||
|
@ -161,47 +190,75 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
with open(res_file_template.format(cls_name), "w") as f:
|
||||
f.write("\n".join(lines))
|
||||
|
||||
for thresh in range(50, 100, 5):
|
||||
rec, prec, ap = voc_eval(
|
||||
res_file_template,
|
||||
self._anno_file_template,
|
||||
self._image_set_path,
|
||||
cls_name,
|
||||
ovthresh=thresh / 100.0,
|
||||
use_07_metric=self._is_2007,
|
||||
known_classes=self.known_classes
|
||||
)
|
||||
aps[thresh].append(ap * 100)
|
||||
# recs[thresh].append(rec * 100)
|
||||
# precs[thresh].append(prec * 100)
|
||||
# for thresh in range(50, 100, 5):
|
||||
thresh = 50
|
||||
rec, prec, ap, unk_det_as_known, num_unk = voc_eval(
|
||||
res_file_template,
|
||||
self._anno_file_template,
|
||||
self._image_set_path,
|
||||
cls_name,
|
||||
ovthresh=thresh / 100.0,
|
||||
use_07_metric=self._is_2007,
|
||||
known_classes=self.known_classes
|
||||
)
|
||||
aps[thresh].append(ap * 100)
|
||||
unk_det_as_knowns[thresh].append(unk_det_as_known)
|
||||
num_unks[thresh].append(num_unk)
|
||||
all_precs[thresh].append(prec)
|
||||
all_recs[thresh].append(rec)
|
||||
try:
|
||||
recs[thresh].append(rec[-1] * 100)
|
||||
precs[thresh].append(prec[-1] * 100)
|
||||
except:
|
||||
recs[thresh].append(0)
|
||||
precs[thresh].append(0)
|
||||
|
||||
avg_precision = self.compute_avg_precision_at_many_recall_level(all_precs, all_recs)
|
||||
self._logger.info('avg_precision: ' + str(avg_precision))
|
||||
|
||||
ret = OrderedDict()
|
||||
mAP = {iou: np.mean(x) for iou, x in aps.items()}
|
||||
ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
|
||||
ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50]}
|
||||
|
||||
total_num_unk_det_as_known = {iou: np.sum(x) for iou, x in unk_det_as_knowns.items()}
|
||||
total_num_unk = num_unks[50][0]
|
||||
self._logger.info('Absolute OSE (total_num_unk_det_as_known): ' + str(total_num_unk_det_as_known))
|
||||
self._logger.info('total_num_unk ' + str(total_num_unk))
|
||||
|
||||
# Extra logging of class-wise APs
|
||||
avg_precs = list(np.mean([x for _, x in aps.items()], axis=0))
|
||||
self._logger.info(self._class_names)
|
||||
self._logger.info("AP__: " + str(['%.1f' % x for x in avg_precs]))
|
||||
# self._logger.info("AP__: " + str(['%.1f' % x for x in avg_precs]))
|
||||
self._logger.info("AP50: " + str(['%.1f' % x for x in aps[50]]))
|
||||
self._logger.info("AP75: " + str(['%.1f' % x for x in aps[75]]))
|
||||
self._logger.info("Precisions50: " + str(['%.1f' % x for x in precs[50]]))
|
||||
self._logger.info("Recall50: " + str(['%.1f' % x for x in recs[50]]))
|
||||
# self._logger.info("AP75: " + str(['%.1f' % x for x in aps[75]]))
|
||||
|
||||
if self.prev_intro_cls > 0:
|
||||
self._logger.info("Prev class AP__: " + str(np.mean(avg_precs[:self.prev_intro_cls])))
|
||||
# self._logger.info("\nPrev class AP__: " + str(np.mean(avg_precs[:self.prev_intro_cls])))
|
||||
self._logger.info("Prev class AP50: " + str(np.mean(aps[50][:self.prev_intro_cls])))
|
||||
self._logger.info("Prev class AP75: " + str(np.mean(aps[75][:self.prev_intro_cls])))
|
||||
self._logger.info("Prev class Precisions50: " + str(np.mean(precs[50][:self.prev_intro_cls])))
|
||||
self._logger.info("Prev class Recall50: " + str(np.mean(recs[50][:self.prev_intro_cls])))
|
||||
|
||||
self._logger.info("Current class AP__: " + str(np.mean(avg_precs[self.prev_intro_cls:self.curr_intro_cls])))
|
||||
self._logger.info("Current class AP50: " + str(np.mean(aps[50][self.prev_intro_cls:self.curr_intro_cls])))
|
||||
self._logger.info("Current class AP75: " + str(np.mean(aps[75][self.prev_intro_cls:self.curr_intro_cls])))
|
||||
# self._logger.info("Prev class AP75: " + str(np.mean(aps[75][:self.prev_intro_cls])))
|
||||
|
||||
self._logger.info("Known AP__: " + str(np.mean(avg_precs[:self.prev_intro_cls + self.curr_intro_cls])))
|
||||
# self._logger.info("\nCurrent class AP__: " + str(np.mean(avg_precs[self.prev_intro_cls:self.curr_intro_cls])))
|
||||
self._logger.info("Current class AP50: " + str(np.mean(aps[50][self.prev_intro_cls:self.prev_intro_cls + self.curr_intro_cls])))
|
||||
self._logger.info("Current class Precisions50: " + str(np.mean(precs[50][self.prev_intro_cls:self.prev_intro_cls + self.curr_intro_cls])))
|
||||
self._logger.info("Current class Recall50: " + str(np.mean(recs[50][self.prev_intro_cls:self.prev_intro_cls + self.curr_intro_cls])))
|
||||
# self._logger.info("Current class AP75: " + str(np.mean(aps[75][self.prev_intro_cls:self.curr_intro_cls])))
|
||||
|
||||
# self._logger.info("\nKnown AP__: " + str(np.mean(avg_precs[:self.prev_intro_cls + self.curr_intro_cls])))
|
||||
self._logger.info("Known AP50: " + str(np.mean(aps[50][:self.prev_intro_cls + self.curr_intro_cls])))
|
||||
self._logger.info("Known AP75: " + str(np.mean(aps[75][:self.prev_intro_cls + self.curr_intro_cls])))
|
||||
self._logger.info("Known Precisions50: " + str(np.mean(precs[50][:self.prev_intro_cls + self.curr_intro_cls])))
|
||||
self._logger.info("Known Recall50: " + str(np.mean(recs[50][:self.prev_intro_cls + self.curr_intro_cls])))
|
||||
# self._logger.info("Known AP75: " + str(np.mean(aps[75][:self.prev_intro_cls + self.curr_intro_cls])))
|
||||
|
||||
self._logger.info("Unknown AP__: " + str(avg_precs[-1]))
|
||||
# self._logger.info("\nUnknown AP__: " + str(avg_precs[-1]))
|
||||
self._logger.info("Unknown AP50: " + str(aps[50][-1]))
|
||||
self._logger.info("Unknown AP75: " + str(aps[75][-1]))
|
||||
self._logger.info("Unknown Precisions50: " + str(precs[50][-1]))
|
||||
self._logger.info("Unknown Recall50: " + str(recs[50][-1]))
|
||||
# self._logger.info("Unknown AP75: " + str(aps[75][-1]))
|
||||
|
||||
# self._logger.info("R__: " + str(['%.1f' % x for x in list(np.mean([x for _, x in recs.items()], axis=0))]))
|
||||
# self._logger.info("R50: " + str(['%.1f' % x for x in recs[50]]))
|
||||
|
@ -362,6 +419,10 @@ def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_me
|
|||
nd = len(image_ids)
|
||||
tp = np.zeros(nd)
|
||||
fp = np.zeros(nd)
|
||||
|
||||
# if 'unknown' not in classname:
|
||||
# return tp, fp, 0
|
||||
|
||||
for d in range(nd):
|
||||
R = class_recs[image_ids[d]]
|
||||
bb = BB[d, :].astype(float)
|
||||
|
@ -407,6 +468,93 @@ def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_me
|
|||
# avoid divide by zero in case the first detection matches a difficult
|
||||
# ground truth
|
||||
prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
|
||||
# plot_pr_curve(prec, rec, classname+'.png')
|
||||
ap = voc_ap(rec, prec, use_07_metric)
|
||||
|
||||
return rec, prec, ap
|
||||
# print('tp: ' + str(tp[-1]))
|
||||
# print('fp: ' + str(fp[-1]))
|
||||
# print('tp: ')
|
||||
# print(tp)
|
||||
# print('fp: ')
|
||||
# print(fp)
|
||||
|
||||
|
||||
'''
|
||||
Computing Open-Set Error (OSE)
|
||||
==============================
|
||||
|
||||
OSE = # of unknown objects classified as known objects of class 'classname'
|
||||
--------------------------------------------------------------------
|
||||
# of unknown objects
|
||||
'''
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Finding GT of unknown objects
|
||||
unknown_class_recs = {}
|
||||
n_unk = 0
|
||||
for imagename in imagenames:
|
||||
R = [obj for obj in recs[imagename] if obj["name"] == 'unknown']
|
||||
bbox = np.array([x["bbox"] for x in R])
|
||||
difficult = np.array([x["difficult"] for x in R]).astype(np.bool)
|
||||
det = [False] * len(R)
|
||||
n_unk = n_unk + sum(~difficult)
|
||||
unknown_class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det}
|
||||
|
||||
if classname == 'unknown':
|
||||
return rec, prec, ap, 0, n_unk
|
||||
|
||||
# Go down each detection and see if it has an overlap with an unknown object.
|
||||
# If so, it is an unknown object that was classified as known.
|
||||
is_unk = np.zeros(nd)
|
||||
for d in range(nd):
|
||||
R = unknown_class_recs[image_ids[d]]
|
||||
bb = BB[d, :].astype(float)
|
||||
ovmax = -np.inf
|
||||
BBGT = R["bbox"].astype(float)
|
||||
|
||||
if BBGT.size > 0:
|
||||
# compute overlaps
|
||||
# intersection
|
||||
ixmin = np.maximum(BBGT[:, 0], bb[0])
|
||||
iymin = np.maximum(BBGT[:, 1], bb[1])
|
||||
ixmax = np.minimum(BBGT[:, 2], bb[2])
|
||||
iymax = np.minimum(BBGT[:, 3], bb[3])
|
||||
iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
|
||||
ih = np.maximum(iymax - iymin + 1.0, 0.0)
|
||||
inters = iw * ih
|
||||
|
||||
# union
|
||||
uni = (
|
||||
(bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
|
||||
+ (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
|
||||
- inters
|
||||
)
|
||||
|
||||
overlaps = inters / uni
|
||||
ovmax = np.max(overlaps)
|
||||
jmax = np.argmax(overlaps)
|
||||
|
||||
if ovmax > ovthresh:
|
||||
is_unk[d] = 1.0
|
||||
|
||||
is_unk = np.sum(is_unk)
|
||||
# OSE = is_unk / n_unk
|
||||
# logger.info('Number of unknowns detected knowns (for class '+ classname + ') is ' + str(is_unk))
|
||||
# logger.info("Num of unknown instances: " + str(n_unk))
|
||||
# logger.info('OSE: ' + str(OSE))
|
||||
|
||||
return rec, prec, ap, is_unk, n_unk
|
||||
|
||||
|
||||
def plot_pr_curve(precision, recall, filename, base_path='/home/fk1/workspace/OWOD/output/plots/'):
|
||||
fig, ax = plt.subplots()
|
||||
ax.step(recall, precision, color='r', alpha=0.99, where='post')
|
||||
ax.fill_between(recall, precision, alpha=0.2, color='b', step='post')
|
||||
plt.xlabel('Recall')
|
||||
plt.ylabel('Precision')
|
||||
plt.ylim([0.0, 1.05])
|
||||
plt.xlim([0.0, 1.0])
|
||||
plt.savefig(base_path + filename)
|
||||
|
||||
# print(precision)
|
||||
# print(recall)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import logging
|
||||
from typing import Dict, Union
|
||||
import torch
|
||||
import os
|
||||
import math
|
||||
import shortuuid
|
||||
from fvcore.nn import giou_loss, smooth_l1_loss
|
||||
|
@ -414,6 +415,9 @@ class FastRCNNOutputLayers(nn.Module):
|
|||
enable_clustering,
|
||||
prev_intro_cls,
|
||||
curr_intro_cls,
|
||||
max_iterations,
|
||||
output_dir,
|
||||
feat_store_path,
|
||||
num_classes: int,
|
||||
test_score_thresh: float = 0.0,
|
||||
test_nms_thresh: float = 0.5,
|
||||
|
@ -472,9 +476,6 @@ class FastRCNNOutputLayers(nn.Module):
|
|||
self.clustering_update_mu_iter = clustering_update_mu_iter
|
||||
self.clustering_momentum = clustering_momentum
|
||||
|
||||
self.feature_store = Store(num_classes + 1, clustering_items_per_class)
|
||||
self.means = [None for _ in range(num_classes + 1)]
|
||||
|
||||
self.hingeloss = nn.HingeEmbeddingLoss(2)
|
||||
self.enable_clustering = enable_clustering
|
||||
|
||||
|
@ -484,6 +485,22 @@ class FastRCNNOutputLayers(nn.Module):
|
|||
self.invalid_class_range = list(range(self.seen_classes, self.num_classes-1))
|
||||
logging.getLogger(__name__).info("Invalid class range: " + str(self.invalid_class_range))
|
||||
|
||||
self.max_iterations = max_iterations
|
||||
self.feature_store_is_stored = False
|
||||
self.output_dir = output_dir
|
||||
self.feat_store_path = feat_store_path
|
||||
self.feature_store_save_loc = os.path.join(self.output_dir, self.feat_store_path, 'feat.pt')
|
||||
|
||||
if os.path.isfile(self.feature_store_save_loc):
|
||||
self.feature_store = torch.load(self.feature_store_save_loc)
|
||||
logging.getLogger(__name__).info('Loaded feature store from ' + self.feature_store_save_loc)
|
||||
else:
|
||||
logging.getLogger(__name__).info('Feature store not found in ' +
|
||||
self.feature_store_save_loc + '. Creating new feature store.')
|
||||
self.feature_store = Store(num_classes + 1, clustering_items_per_class)
|
||||
self.means = [None for _ in range(num_classes + 1)]
|
||||
|
||||
|
||||
# self.ae_model = AE(input_size, clustering_z_dimension)
|
||||
# self.ae_model.apply(Xavier)
|
||||
|
||||
|
@ -508,7 +525,10 @@ class FastRCNNOutputLayers(nn.Module):
|
|||
"clustering_z_dimension": cfg.OWOD.CLUSTERING.Z_DIMENSION,
|
||||
"enable_clustering" : cfg.OWOD.ENABLE_CLUSTERING,
|
||||
"prev_intro_cls" : cfg.OWOD.PREV_INTRODUCED_CLS,
|
||||
"curr_intro_cls" : cfg.OWOD.CUR_INTRODUCED_CLS
|
||||
"curr_intro_cls" : cfg.OWOD.CUR_INTRODUCED_CLS,
|
||||
"max_iterations" : cfg.SOLVER.MAX_ITER,
|
||||
"output_dir" : cfg.OUTPUT_DIR,
|
||||
"feat_store_path" : cfg.OWOD.FEATURE_STORE_SAVE_PATH,
|
||||
# fmt: on
|
||||
}
|
||||
|
||||
|
@ -536,75 +556,16 @@ class FastRCNNOutputLayers(nn.Module):
|
|||
gt_classes = torch.cat([p.gt_classes for p in proposals])
|
||||
self.feature_store.add(features, gt_classes)
|
||||
|
||||
storage = get_event_storage()
|
||||
|
||||
if storage.iter == self.max_iterations and self.feature_store_is_stored is False:
|
||||
logging.getLogger(__name__).info('Saving image store at iteration ' + str(storage.iter) + ' to ' + self.feature_store_save_loc)
|
||||
torch.save(self.feature_store, self.feature_store_save_loc)
|
||||
self.feature_store_is_stored = True
|
||||
|
||||
# self.feature_store.add(F.normalize(features, dim=0), gt_classes)
|
||||
# self.feature_store.add(self.ae_model.encoder(features), gt_classes)
|
||||
|
||||
# def clstr_loss(self, input_features, proposals):
|
||||
# """
|
||||
# Get the foreground input_features, generate distributions for the class,
|
||||
# get probability of each feature from each distribution;
|
||||
# Compute loss: if belonging to a class -> likelihood should be higher
|
||||
# else -> lower
|
||||
# :param input_features:
|
||||
# :param proposals:
|
||||
# :return:
|
||||
# """
|
||||
# loss = 0
|
||||
# gt_classes = torch.cat([p.gt_classes for p in proposals])
|
||||
# mask = gt_classes != self.num_classes
|
||||
# fg_features = input_features[mask]
|
||||
# classes = gt_classes[mask]
|
||||
# # fg_features = self.ae_model.encoder(fg_features)
|
||||
#
|
||||
# # Distribution per class
|
||||
# log_prob = [None for _ in range(self.num_classes + 1)]
|
||||
# # https://github.com/pytorch/pytorch/issues/23780
|
||||
# for cls_index, mu in enumerate(self.means):
|
||||
# if mu is not None:
|
||||
# dist = Normal(loc=mu.cuda(), scale=torch.ones_like(mu.cuda()))
|
||||
# log_prob[cls_index] = dist.log_prob(fg_features).mean(dim=1)
|
||||
# # log_prob[cls_index] = torch.distributions.multivariate_normal. \
|
||||
# # MultivariateNormal(mu.cuda(), torch.eye(len(mu)).cuda()).log_prob(fg_features)
|
||||
# # MultivariateNormal(mu, torch.eye(len(mu))).log_prob(fg_features.cpu())
|
||||
# # MultivariateNormal(mu[:2], torch.eye(len(mu[:2]))).log_prob(fg_features[:,:2].cpu())
|
||||
# else:
|
||||
# log_prob[cls_index] = torch.zeros((len(fg_features))).cuda()
|
||||
#
|
||||
# log_prob = torch.stack(log_prob).T # num_of_fg_proposals x num_of_classes
|
||||
# for i, p in enumerate(log_prob):
|
||||
# weight = torch.ones_like(p) * -1
|
||||
# weight[classes[i]] = 1
|
||||
# p = p * weight
|
||||
# loss += p.mean()
|
||||
# return loss
|
||||
|
||||
# def clstr_loss_l2(self, input_features, proposals):
|
||||
# """
|
||||
# Get the foreground input_features, generate distributions for the class,
|
||||
# get probability of each feature from each distribution;
|
||||
# Compute loss: if belonging to a class -> likelihood should be higher
|
||||
# else -> lower
|
||||
# :param input_features:
|
||||
# :param proposals:
|
||||
# :return:
|
||||
# """
|
||||
# loss = 0
|
||||
# gt_classes = torch.cat([p.gt_classes for p in proposals])
|
||||
# mask = gt_classes != self.num_classes
|
||||
# fg_features = input_features[mask]
|
||||
# classes = gt_classes[mask]
|
||||
# fg_features = self.ae_model.encoder(fg_features)
|
||||
#
|
||||
# for index, feature in enumerate(fg_features):
|
||||
# for cls_index, mu in enumerate(self.means):
|
||||
# if mu is not None and feature is not None:
|
||||
# mu = mu.cuda()
|
||||
# if classes[index] == cls_index:
|
||||
# loss -= F.mse_loss(feature, mu)
|
||||
# else:
|
||||
# loss += F.mse_loss(feature, mu)
|
||||
#
|
||||
# return loss
|
||||
|
||||
def clstr_loss_l2_cdist(self, input_features, proposals):
|
||||
"""
|
||||
|
@ -822,3 +783,70 @@ class FastRCNNOutputLayers(nn.Module):
|
|||
num_inst_per_image = [len(p) for p in proposals]
|
||||
probs = F.softmax(scores, dim=-1)
|
||||
return probs.split(num_inst_per_image, dim=0)
|
||||
|
||||
# def clstr_loss(self, input_features, proposals):
|
||||
# """
|
||||
# Get the foreground input_features, generate distributions for the class,
|
||||
# get probability of each feature from each distribution;
|
||||
# Compute loss: if belonging to a class -> likelihood should be higher
|
||||
# else -> lower
|
||||
# :param input_features:
|
||||
# :param proposals:
|
||||
# :return:
|
||||
# """
|
||||
# loss = 0
|
||||
# gt_classes = torch.cat([p.gt_classes for p in proposals])
|
||||
# mask = gt_classes != self.num_classes
|
||||
# fg_features = input_features[mask]
|
||||
# classes = gt_classes[mask]
|
||||
# # fg_features = self.ae_model.encoder(fg_features)
|
||||
#
|
||||
# # Distribution per class
|
||||
# log_prob = [None for _ in range(self.num_classes + 1)]
|
||||
# # https://github.com/pytorch/pytorch/issues/23780
|
||||
# for cls_index, mu in enumerate(self.means):
|
||||
# if mu is not None:
|
||||
# dist = Normal(loc=mu.cuda(), scale=torch.ones_like(mu.cuda()))
|
||||
# log_prob[cls_index] = dist.log_prob(fg_features).mean(dim=1)
|
||||
# # log_prob[cls_index] = torch.distributions.multivariate_normal. \
|
||||
# # MultivariateNormal(mu.cuda(), torch.eye(len(mu)).cuda()).log_prob(fg_features)
|
||||
# # MultivariateNormal(mu, torch.eye(len(mu))).log_prob(fg_features.cpu())
|
||||
# # MultivariateNormal(mu[:2], torch.eye(len(mu[:2]))).log_prob(fg_features[:,:2].cpu())
|
||||
# else:
|
||||
# log_prob[cls_index] = torch.zeros((len(fg_features))).cuda()
|
||||
#
|
||||
# log_prob = torch.stack(log_prob).T # num_of_fg_proposals x num_of_classes
|
||||
# for i, p in enumerate(log_prob):
|
||||
# weight = torch.ones_like(p) * -1
|
||||
# weight[classes[i]] = 1
|
||||
# p = p * weight
|
||||
# loss += p.mean()
|
||||
# return loss
|
||||
|
||||
# def clstr_loss_l2(self, input_features, proposals):
|
||||
# """
|
||||
# Get the foreground input_features, generate distributions for the class,
|
||||
# get probability of each feature from each distribution;
|
||||
# Compute loss: if belonging to a class -> likelihood should be higher
|
||||
# else -> lower
|
||||
# :param input_features:
|
||||
# :param proposals:
|
||||
# :return:
|
||||
# """
|
||||
# loss = 0
|
||||
# gt_classes = torch.cat([p.gt_classes for p in proposals])
|
||||
# mask = gt_classes != self.num_classes
|
||||
# fg_features = input_features[mask]
|
||||
# classes = gt_classes[mask]
|
||||
# fg_features = self.ae_model.encoder(fg_features)
|
||||
#
|
||||
# for index, feature in enumerate(fg_features):
|
||||
# for cls_index, mu in enumerate(self.means):
|
||||
# if mu is not None and feature is not None:
|
||||
# mu = mu.cuda()
|
||||
# if classes[index] == cls_index:
|
||||
# loss -= F.mse_loss(feature, mu)
|
||||
# else:
|
||||
# loss += F.mse_loss(feature, mu)
|
||||
#
|
||||
# return loss
|
Loading…
Reference in New Issue