mirror of https://github.com/JosephKJ/OWOD.git
Weibull distribution
parent
9321996ee8
commit
71480aa878
|
@ -24,10 +24,10 @@ SOLVER:
|
||||||
MAX_ITER: 90000
|
MAX_ITER: 90000
|
||||||
VERSION: 2
|
VERSION: 2
|
||||||
OWOD:
|
OWOD:
|
||||||
ENABLE_THRESHOLD_AUTOLABEL_UNK: False
|
ENABLE_THRESHOLD_AUTOLABEL_UNK: True
|
||||||
NUM_UNK_PER_IMAGE: 1
|
NUM_UNK_PER_IMAGE: 1
|
||||||
ENABLE_UNCERTAINITY_AUTOLABEL_UNK: False
|
ENABLE_UNCERTAINITY_AUTOLABEL_UNK: False
|
||||||
ENABLE_CLUSTERING: False
|
ENABLE_CLUSTERING: True
|
||||||
CLUSTERING:
|
CLUSTERING:
|
||||||
ITEMS_PER_CLASS: 20
|
ITEMS_PER_CLASS: 20
|
||||||
START_ITER: 1000
|
START_ITER: 1000
|
||||||
|
|
|
@ -2,7 +2,7 @@ _BASE_: "../Base-RCNN-C4-OWOD.yaml"
|
||||||
MODEL:
|
MODEL:
|
||||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/expr_training_with_unk_with_clustering_Z_DIMENSION_256/model_final.pth"
|
# WEIGHTS: "/home/fk1/workspace/OWOD/output/expr_training_with_unk_with_clustering_Z_DIMENSION_256/model_final.pth"
|
||||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1/model_final.pth"
|
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_std_frcnn/model_final.pth"
|
||||||
DATASETS:
|
DATASETS:
|
||||||
TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||||
TEST: ('voc_coco_2007_test', ) # t1_voc_coco_2007_test, t1_voc_coco_2007_val
|
TEST: ('voc_coco_2007_test', ) # t1_voc_coco_2007_test, t1_voc_coco_2007_val
|
||||||
|
@ -10,7 +10,7 @@ SOLVER:
|
||||||
STEPS: (12000, 16000)
|
STEPS: (12000, 16000)
|
||||||
MAX_ITER: 18000
|
MAX_ITER: 18000
|
||||||
WARMUP_ITERS: 100
|
WARMUP_ITERS: 100
|
||||||
OUTPUT_DIR: "./output/t1_std_frcnn"
|
OUTPUT_DIR: "./output/t1_ENABLE_CLUSTERING_margin_5"
|
||||||
OWOD:
|
OWOD:
|
||||||
PREV_INTRODUCED_CLS: 0
|
PREV_INTRODUCED_CLS: 0
|
||||||
CUR_INTRODUCED_CLS: 20
|
CUR_INTRODUCED_CLS: 20
|
|
@ -0,0 +1,14 @@
|
||||||
|
_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
|
||||||
|
MODEL:
|
||||||
|
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_expr/model_final.pth"
|
||||||
|
DATASETS:
|
||||||
|
TRAIN: ('t1_voc_coco_2007_ft', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||||
|
# No evaluation dataset for this run. An empty tuple must be written as `()`;
# `(, )` is not a valid tuple literal and would be read as a plain string.
TEST: () # voc_coco_2007_test
|
||||||
|
SOLVER:
|
||||||
|
STEPS: (12000, 16000)
|
||||||
|
MAX_ITER: 18000
|
||||||
|
WARMUP_ITERS: 100
|
||||||
|
OUTPUT_DIR: "./output/t1_expr"
|
||||||
|
OWOD:
|
||||||
|
PREV_INTRODUCED_CLS: 0
|
||||||
|
CUR_INTRODUCED_CLS: 20
|
|
@ -0,0 +1,15 @@
|
||||||
|
_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
|
||||||
|
MODEL:
|
||||||
|
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||||
|
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1/model_final.pth"
|
||||||
|
DATASETS:
|
||||||
|
TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||||
|
TEST: ('voc_coco_2007_test', ) # voc_coco_2007_test
|
||||||
|
SOLVER:
|
||||||
|
STEPS: (12000, 16000)
|
||||||
|
MAX_ITER: 18000
|
||||||
|
WARMUP_ITERS: 100
|
||||||
|
OUTPUT_DIR: "./output/t1_expr"
|
||||||
|
OWOD:
|
||||||
|
PREV_INTRODUCED_CLS: 0
|
||||||
|
CUR_INTRODUCED_CLS: 20
|
|
@ -0,0 +1,15 @@
|
||||||
|
_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
|
||||||
|
MODEL:
|
||||||
|
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||||
|
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1/model_final.pth"
|
||||||
|
DATASETS:
|
||||||
|
TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||||
|
# No evaluation dataset for this run. An empty tuple must be written as `()`;
# `(, )` is not a valid tuple literal and would be read as a plain string.
TEST: () # voc_coco_2007_test
|
||||||
|
SOLVER:
|
||||||
|
STEPS: (12000, 16000)
|
||||||
|
MAX_ITER: 18000
|
||||||
|
WARMUP_ITERS: 100
|
||||||
|
OUTPUT_DIR: "./output/t1_expr"
|
||||||
|
OWOD:
|
||||||
|
PREV_INTRODUCED_CLS: 0
|
||||||
|
CUR_INTRODUCED_CLS: 20
|
|
@ -0,0 +1,23 @@
|
||||||
|
_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
|
||||||
|
MODEL:
|
||||||
|
# WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||||
|
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_THRESHOLD_AUTOLABEL_UNK/model_final.pth"
|
||||||
|
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
|
||||||
|
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_std_frcnn/model_final.pth"
|
||||||
|
DATASETS:
|
||||||
|
TRAIN: ('voc_coco_2007_val', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||||
|
TEST: ('voc_coco_2007_val', ) # voc_coco_2007_test
|
||||||
|
SOLVER:
|
||||||
|
STEPS: (12000, 16000)
|
||||||
|
MAX_ITER: 500
|
||||||
|
WARMUP_ITERS: 0
|
||||||
|
OUTPUT_DIR: "./output/t1_clustering_new_3"
|
||||||
|
OWOD:
|
||||||
|
PREV_INTRODUCED_CLS: 0
|
||||||
|
CUR_INTRODUCED_CLS: 20
|
||||||
|
COMPUTE_ENERGY: True
|
||||||
|
ENERGY_SAVE_PATH: 'energy'
|
||||||
|
SKIP_TRAINING_WHILE_EVAL: False
|
||||||
|
|
||||||
|
|
||||||
|
#OUTPUT_DIR: "./output/t1_std_frcnn_energy"
|
|
@ -1,9 +1,10 @@
|
||||||
_BASE_: "../Base-RCNN-C4-OWOD.yaml"
|
_BASE_: "../Base-RCNN-C4-OWOD.yaml"
|
||||||
MODEL:
|
MODEL:
|
||||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
# WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||||
|
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
|
||||||
DATASETS:
|
DATASETS:
|
||||||
TRAIN: ('t2_train',)
|
TRAIN: ('t2_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||||
TEST: ('voc_2007_test', 't2_test', 't3_test_unk', 't4_test_unk')
|
TEST: ('voc_coco_2007_test', )
|
||||||
SOLVER:
|
SOLVER:
|
||||||
STEPS: (22000, 26000)
|
STEPS: (22000, 26000)
|
||||||
MAX_ITER: 28000
|
MAX_ITER: 28000
|
||||||
|
|
|
@ -613,6 +613,9 @@ _C.OWOD.CLUSTERING.Z_DIMENSION = 64
|
||||||
|
|
||||||
_C.OWOD.PREV_INTRODUCED_CLS = 0
|
_C.OWOD.PREV_INTRODUCED_CLS = 0
|
||||||
_C.OWOD.CUR_INTRODUCED_CLS = 20
|
_C.OWOD.CUR_INTRODUCED_CLS = 20
|
||||||
|
_C.OWOD.COMPUTE_ENERGY = False
|
||||||
|
_C.OWOD.ENERGY_SAVE_PATH = ''
|
||||||
|
_C.OWOD.SKIP_TRAINING_WHILE_EVAL = False
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------------- #
|
||||||
# Misc options
|
# Misc options
|
||||||
|
|
|
@ -263,7 +263,7 @@ def remove_prev_class_and_unk_instances(cfg, dataset_dicts):
|
||||||
# For training data.
|
# For training data.
|
||||||
prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
|
prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
|
||||||
curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
|
curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
|
||||||
valid_classes = range(prev_intro_cls, curr_intro_cls)
|
valid_classes = range(prev_intro_cls, prev_intro_cls + curr_intro_cls)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.info("Valid classes: " + str(valid_classes))
|
logger.info("Valid classes: " + str(valid_classes))
|
||||||
|
|
|
@ -118,6 +118,9 @@ def default_setup(cfg, args):
|
||||||
if comm.is_main_process() and output_dir:
|
if comm.is_main_process() and output_dir:
|
||||||
PathManager.mkdirs(output_dir)
|
PathManager.mkdirs(output_dir)
|
||||||
|
|
||||||
|
if cfg.OWOD.COMPUTE_ENERGY:
|
||||||
|
PathManager.mkdirs(os.path.join(output_dir, cfg.OWOD.ENERGY_SAVE_PATH))
|
||||||
|
|
||||||
rank = comm.get_rank()
|
rank = comm.get_rank()
|
||||||
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
|
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
|
||||||
logger = setup_logger(output_dir, distributed_rank=rank)
|
logger = setup_logger(output_dir, distributed_rank=rank)
|
||||||
|
|
|
@ -7,6 +7,9 @@ import numpy as np
|
||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
import torch
|
import torch
|
||||||
|
import os
|
||||||
|
from matplotlib import pyplot
|
||||||
|
from reliability.Fitters import Fit_Weibull_3P
|
||||||
|
|
||||||
import detectron2.utils.comm as comm
|
import detectron2.utils.comm as comm
|
||||||
from detectron2.utils.events import EventStorage
|
from detectron2.utils.events import EventStorage
|
||||||
|
@ -138,6 +141,8 @@ class TrainerBase:
|
||||||
try:
|
try:
|
||||||
self.before_train()
|
self.before_train()
|
||||||
for self.iter in range(start_iter, max_iter):
|
for self.iter in range(start_iter, max_iter):
|
||||||
|
if self.cfg.OWOD.SKIP_TRAINING_WHILE_EVAL:
|
||||||
|
continue
|
||||||
self.before_step()
|
self.before_step()
|
||||||
self.run_step()
|
self.run_step()
|
||||||
self.after_step()
|
self.after_step()
|
||||||
|
@ -152,13 +157,80 @@ class TrainerBase:
|
||||||
self.after_train()
|
self.after_train()
|
||||||
|
|
||||||
def before_train(self):
|
def before_train(self):
|
||||||
|
if self.cfg.OWOD.SKIP_TRAINING_WHILE_EVAL:
|
||||||
|
logger.info('Skipping training as cfg.OWOD.SKIP_TRAINING_WHILE_EVAL flag is set.')
|
||||||
for h in self._hooks:
|
for h in self._hooks:
|
||||||
h.before_train()
|
h.before_train()
|
||||||
|
|
||||||
def after_train(self):
|
def after_train(self):
|
||||||
self.storage.iter = self.iter
|
self.storage.iter = self.iter
|
||||||
for h in self._hooks:
|
if self.cfg.OWOD.COMPUTE_ENERGY:
|
||||||
h.after_train()
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info("Going to analyse the energy files...")
|
||||||
|
|
||||||
|
self.analyse_energy()
|
||||||
|
|
||||||
|
for h in self._hooks:
|
||||||
|
if 'EvalHook' not in str(type(h)):
|
||||||
|
h.after_train()
|
||||||
|
else:
|
||||||
|
for h in self._hooks:
|
||||||
|
h.after_train()
|
||||||
|
|
||||||
|
def analyse_energy(self, max_files=None):
    """Fit Weibull distributions to the energies of known and unknown instances.

    Reads the files found in ``<OUTPUT_DIR>/<OWOD.ENERGY_SAVE_PATH>``. Each file is
    expected to hold a ``(logits, classes)`` pair written with ``torch.save``; files
    that cannot be loaded or unpacked are logged and skipped.

    The energy of an instance is the log-sum-exp of its logits over the first
    ``PREV_INTRODUCED_CLS + CUR_INTRODUCED_CLS`` columns. Instances labelled
    ``MODEL.ROI_HEADS.NUM_CLASSES`` are ignored, instances labelled
    ``NUM_CLASSES - 1`` are collected as unknown, and all others as known.

    Side effects:
        * Saves the fitted parameters (a list of two dicts, unknown first) with
          ``torch.save`` to ``<OUTPUT_DIR>/energy_dist_<num_seen_classes>.pkl``.
        * Saves a histogram of both energy sets to ``<OUTPUT_DIR>/energy.png``.

    Args:
        max_files: Optional cap on how many files (in sorted name order) are read.
            ``None`` (default) reads every file. The hard-coded stop after the
            11th file that used to live in the loop looked like a debugging
            leftover (it made the every-100-files progress message unreachable)
            and has been replaced by this opt-in limit.
    """
    logger = logging.getLogger(__name__)

    energy_dir = os.path.join(self.cfg.OUTPUT_DIR, self.cfg.OWOD.ENERGY_SAVE_PATH)
    # Sorted so that the processing order (and any ``max_files`` cut) is reproducible.
    files = sorted(os.listdir(energy_dir))
    if max_files is not None:
        files = files[:max_files]

    # Loop invariants.
    num_seen_classes = self.cfg.OWOD.PREV_INTRODUCED_CLS + self.cfg.OWOD.CUR_INTRODUCED_CLS
    ignored_label = self.cfg.MODEL.ROI_HEADS.NUM_CLASSES  # presumably background -- confirm
    unk_label = ignored_label - 1

    unk = []
    known = []
    for file_idx, file_name in enumerate(files):
        path = os.path.join(energy_dir, file_name)
        try:
            # NOTE(review): torch.load unpickles; only point this at files the
            # trainer itself produced.
            logits, classes = torch.load(path)
        except Exception:
            # Best effort: a corrupt or foreign file must not abort the analysis,
            # but Ctrl-C / SystemExit are no longer swallowed.
            logger.info('Not able to load %s. Continuing...', path)
            continue

        lse = torch.logsumexp(logits[:, :num_seen_classes], dim=1).detach().cpu()
        classes = torch.as_tensor(classes).detach().cpu()

        # Split all instances of the file at once instead of one tensor
        # round-trip per instance; boolean masking keeps the original order.
        is_unk = classes == unk_label
        is_known = (classes != ignored_label) & ~is_unk
        unk.extend(lse[is_unk].tolist())
        known.extend(lse[is_known].tolist())

        if file_idx % 100 == 0:
            logger.info("Analysing %d / %d", file_idx, len(files))

    logger.info('len(unk): %d', len(unk))
    logger.info('len(known): %d', len(known))

    logger.info('Fitting Weibull distribution...')
    wb_dist_param = []
    wb_unk = Fit_Weibull_3P(failures=unk, show_probability_plot=False, print_results=False)
    wb_dist_param.append({"scale_unk": wb_unk.alpha, "shape_unk": wb_unk.beta, "shift_unk": wb_unk.gamma})

    wb_known = Fit_Weibull_3P(failures=known, show_probability_plot=False, print_results=False)
    wb_dist_param.append(
        {"scale_known": wb_known.alpha, "shape_known": wb_known.beta, "shift_known": wb_known.gamma})

    param_save_location = os.path.join(self.cfg.OUTPUT_DIR,
                                       'energy_dist_' + str(num_seen_classes) + '.pkl')
    logger.info('Pickling the parameters to %s', param_save_location)
    torch.save(wb_dist_param, param_save_location)

    logger.info('Plotting the computed energy values...')
    bins = np.linspace(2, 15, 500)
    # Draw on a dedicated figure and always close it, so repeated calls neither
    # overlay earlier plots nor leak figures.
    fig = pyplot.figure()
    try:
        pyplot.hist(known, bins, alpha=0.5, label='known')
        pyplot.hist(unk, bins, alpha=0.5, label='unk')
        pyplot.legend(loc='upper right')
        pyplot.savefig(os.path.join(self.cfg.OUTPUT_DIR, 'energy.png'))
    finally:
        pyplot.close(fig)
|
||||||
|
|
||||||
def before_step(self):
|
def before_step(self):
|
||||||
# Maintain the invariant that storage.iter == trainer.iter
|
# Maintain the invariant that storage.iter == trainer.iter
|
||||||
|
|
|
@ -46,6 +46,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
||||||
self.prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
|
self.prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
|
||||||
self.curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
|
self.curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
|
||||||
self.total_num_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES
|
self.total_num_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES
|
||||||
|
self.known_classes = self._class_names[:self.prev_intro_cls + self.curr_intro_cls]
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self._predictions = defaultdict(list) # class name -> list of prediction strings
|
self._predictions = defaultdict(list) # class name -> list of prediction strings
|
||||||
|
@ -57,6 +58,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
||||||
boxes = instances.pred_boxes.tensor.numpy()
|
boxes = instances.pred_boxes.tensor.numpy()
|
||||||
scores = instances.scores.tolist()
|
scores = instances.scores.tolist()
|
||||||
classes = instances.pred_classes.tolist()
|
classes = instances.pred_classes.tolist()
|
||||||
|
logits = instances.logits
|
||||||
for box, score, cls in zip(boxes, scores, classes):
|
for box, score, cls in zip(boxes, scores, classes):
|
||||||
xmin, ymin, xmax, ymax = box
|
xmin, ymin, xmax, ymax = box
|
||||||
# The inverse of data loading logic in `datasets/pascal_voc.py`
|
# The inverse of data loading logic in `datasets/pascal_voc.py`
|
||||||
|
@ -108,6 +110,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
||||||
cls_name,
|
cls_name,
|
||||||
ovthresh=thresh / 100.0,
|
ovthresh=thresh / 100.0,
|
||||||
use_07_metric=self._is_2007,
|
use_07_metric=self._is_2007,
|
||||||
|
known_classes=self.known_classes
|
||||||
)
|
)
|
||||||
aps[thresh].append(ap * 100)
|
aps[thresh].append(ap * 100)
|
||||||
# recs[thresh].append(rec * 100)
|
# recs[thresh].append(rec * 100)
|
||||||
|
@ -166,14 +169,27 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def parse_rec(filename):
|
def parse_rec(filename, known_classes):
|
||||||
"""Parse a PASCAL VOC xml file."""
|
"""Parse a PASCAL VOC xml file."""
|
||||||
|
VOC_CLASS_NAMES_COCOFIED = [
|
||||||
|
"airplane", "dining table", "motorcycle",
|
||||||
|
"potted plant", "couch", "tv"
|
||||||
|
]
|
||||||
|
BASE_VOC_CLASS_NAMES = [
|
||||||
|
"aeroplane", "diningtable", "motorbike",
|
||||||
|
"pottedplant", "sofa", "tvmonitor"
|
||||||
|
]
|
||||||
with PathManager.open(filename) as f:
|
with PathManager.open(filename) as f:
|
||||||
tree = ET.parse(f)
|
tree = ET.parse(f)
|
||||||
objects = []
|
objects = []
|
||||||
for obj in tree.findall("object"):
|
for obj in tree.findall("object"):
|
||||||
obj_struct = {}
|
obj_struct = {}
|
||||||
obj_struct["name"] = obj.find("name").text
|
cls_name = obj.find("name").text
|
||||||
|
if cls_name in VOC_CLASS_NAMES_COCOFIED:
|
||||||
|
cls_name = BASE_VOC_CLASS_NAMES[VOC_CLASS_NAMES_COCOFIED.index(cls_name)]
|
||||||
|
if cls_name not in known_classes:
|
||||||
|
cls_name = 'unknown'
|
||||||
|
obj_struct["name"] = cls_name
|
||||||
# obj_struct["pose"] = obj.find("pose").text
|
# obj_struct["pose"] = obj.find("pose").text
|
||||||
# obj_struct["truncated"] = int(obj.find("truncated").text)
|
# obj_struct["truncated"] = int(obj.find("truncated").text)
|
||||||
obj_struct["difficult"] = int(obj.find("difficult").text)
|
obj_struct["difficult"] = int(obj.find("difficult").text)
|
||||||
|
@ -221,7 +237,7 @@ def voc_ap(rec, prec, use_07_metric=False):
|
||||||
return ap
|
return ap
|
||||||
|
|
||||||
|
|
||||||
def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False):
|
def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False, known_classes=None):
|
||||||
"""rec, prec, ap = voc_eval(detpath,
|
"""rec, prec, ap = voc_eval(detpath,
|
||||||
annopath,
|
annopath,
|
||||||
imagesetfile,
|
imagesetfile,
|
||||||
|
@ -254,7 +270,7 @@ def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_me
|
||||||
# load annots
|
# load annots
|
||||||
recs = {}
|
recs = {}
|
||||||
for imagename in imagenames:
|
for imagename in imagenames:
|
||||||
recs[imagename] = parse_rec(annopath.format(imagename))
|
recs[imagename] = parse_rec(annopath.format(imagename), tuple(known_classes))
|
||||||
|
|
||||||
# extract gt objects for this class
|
# extract gt objects for this class
|
||||||
class_recs = {}
|
class_recs = {}
|
||||||
|
|
|
@ -47,7 +47,7 @@ Naming convention:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def fast_rcnn_inference(boxes, scores, image_shapes, score_thresh, nms_thresh, topk_per_image):
|
def fast_rcnn_inference(boxes, scores, image_shapes, predictions, score_thresh, nms_thresh, topk_per_image):
|
||||||
"""
|
"""
|
||||||
Call `fast_rcnn_inference_single_image` for all images.
|
Call `fast_rcnn_inference_single_image` for all images.
|
||||||
|
|
||||||
|
@ -75,15 +75,15 @@ def fast_rcnn_inference(boxes, scores, image_shapes, score_thresh, nms_thresh, t
|
||||||
"""
|
"""
|
||||||
result_per_image = [
|
result_per_image = [
|
||||||
fast_rcnn_inference_single_image(
|
fast_rcnn_inference_single_image(
|
||||||
boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image
|
boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image, prediction
|
||||||
)
|
)
|
||||||
for scores_per_image, boxes_per_image, image_shape in zip(scores, boxes, image_shapes)
|
for scores_per_image, boxes_per_image, image_shape, prediction in zip(scores, boxes, image_shapes, predictions)
|
||||||
]
|
]
|
||||||
return [x[0] for x in result_per_image], [x[1] for x in result_per_image]
|
return [x[0] for x in result_per_image], [x[1] for x in result_per_image]
|
||||||
|
|
||||||
|
|
||||||
def fast_rcnn_inference_single_image(
|
def fast_rcnn_inference_single_image(
|
||||||
boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image
|
boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image, prediction
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Single-image inference. Return bounding-box detection results by thresholding
|
Single-image inference. Return bounding-box detection results by thresholding
|
||||||
|
@ -96,12 +96,15 @@ def fast_rcnn_inference_single_image(
|
||||||
Returns:
|
Returns:
|
||||||
Same as `fast_rcnn_inference`, but for only one image.
|
Same as `fast_rcnn_inference`, but for only one image.
|
||||||
"""
|
"""
|
||||||
|
logits = prediction
|
||||||
valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
|
valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
|
||||||
if not valid_mask.all():
|
if not valid_mask.all():
|
||||||
boxes = boxes[valid_mask]
|
boxes = boxes[valid_mask]
|
||||||
scores = scores[valid_mask]
|
scores = scores[valid_mask]
|
||||||
|
logits = logits[valid_mask]
|
||||||
|
|
||||||
scores = scores[:, :-1]
|
scores = scores[:, :-1]
|
||||||
|
logits = logits[:, :-1]
|
||||||
num_bbox_reg_classes = boxes.shape[1] // 4
|
num_bbox_reg_classes = boxes.shape[1] // 4
|
||||||
# Convert to Boxes to use the `clip` function ...
|
# Convert to Boxes to use the `clip` function ...
|
||||||
boxes = Boxes(boxes.reshape(-1, 4))
|
boxes = Boxes(boxes.reshape(-1, 4))
|
||||||
|
@ -119,17 +122,20 @@ def fast_rcnn_inference_single_image(
|
||||||
else:
|
else:
|
||||||
boxes = boxes[filter_mask]
|
boxes = boxes[filter_mask]
|
||||||
scores = scores[filter_mask]
|
scores = scores[filter_mask]
|
||||||
|
logits = logits[filter_inds[:,0]]
|
||||||
|
|
||||||
# 2. Apply NMS for each class independently.
|
# 2. Apply NMS for each class independently.
|
||||||
keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
|
keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
|
||||||
if topk_per_image >= 0:
|
if topk_per_image >= 0:
|
||||||
keep = keep[:topk_per_image]
|
keep = keep[:topk_per_image]
|
||||||
boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]
|
boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]
|
||||||
|
logits = logits[keep]
|
||||||
|
|
||||||
result = Instances(image_shape)
|
result = Instances(image_shape)
|
||||||
result.pred_boxes = Boxes(boxes)
|
result.pred_boxes = Boxes(boxes)
|
||||||
result.scores = scores
|
result.scores = scores
|
||||||
result.pred_classes = filter_inds[:, 1]
|
result.pred_classes = filter_inds[:, 1]
|
||||||
|
result.logits = logits
|
||||||
return result, filter_inds[:, 0]
|
return result, filter_inds[:, 0]
|
||||||
|
|
||||||
|
|
||||||
|
@ -734,6 +740,7 @@ class FastRCNNOutputLayers(nn.Module):
|
||||||
boxes,
|
boxes,
|
||||||
scores,
|
scores,
|
||||||
image_shapes,
|
image_shapes,
|
||||||
|
predictions,
|
||||||
self.test_score_thresh,
|
self.test_score_thresh,
|
||||||
self.test_nms_thresh,
|
self.test_nms_thresh,
|
||||||
self.test_topk_per_image,
|
self.test_topk_per_image,
|
||||||
|
|
|
@ -3,6 +3,8 @@ import inspect
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import heapq
|
import heapq
|
||||||
|
import os
|
||||||
|
import shortuuid
|
||||||
import operator
|
import operator
|
||||||
import shortuuid
|
import shortuuid
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
@ -385,6 +387,8 @@ class Res5ROIHeads(ROIHeads):
|
||||||
sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
|
sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
|
||||||
self.mask_on = cfg.MODEL.MASK_ON
|
self.mask_on = cfg.MODEL.MASK_ON
|
||||||
self.enable_clustering = cfg.OWOD.ENABLE_CLUSTERING
|
self.enable_clustering = cfg.OWOD.ENABLE_CLUSTERING
|
||||||
|
self.compute_energy_flag = cfg.OWOD.COMPUTE_ENERGY
|
||||||
|
self.energy_save_path = os.path.join(cfg.OUTPUT_DIR, cfg.OWOD.ENERGY_SAVE_PATH)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
assert not cfg.MODEL.KEYPOINT_ON
|
assert not cfg.MODEL.KEYPOINT_ON
|
||||||
assert len(self.in_features) == 1
|
assert len(self.in_features) == 1
|
||||||
|
@ -443,6 +447,13 @@ class Res5ROIHeads(ROIHeads):
|
||||||
location = '/home/fk1/workspace/OWOD/output/features/' + shortuuid.uuid() + '.pkl'
|
location = '/home/fk1/workspace/OWOD/output/features/' + shortuuid.uuid() + '.pkl'
|
||||||
torch.save(data, location)
|
torch.save(data, location)
|
||||||
|
|
||||||
|
def compute_energy(self, predictions, proposals):
    """Dump one batch of logits and ground-truth classes for later energy analysis.

    Writes the pair ``(logits, gt_classes)`` with ``torch.save`` to a uniquely
    named ``.pkl`` file inside ``self.energy_save_path``.

    Args:
        predictions: Indexable whose element 0 is the logits tensor that gets
            saved (presumably the box predictor's output -- confirm with caller).
        proposals: Iterable of objects exposing a ``gt_classes`` tensor; the
            tensors are concatenated in iteration order.
    """
    gt_classes = torch.cat([p.gt_classes for p in proposals])
    logits = predictions[0]
    # Save detached CPU copies: the dump then carries no autograd state and can
    # be loaded on a machine without the original device.
    data = (logits.detach().cpu(), gt_classes.detach().cpu())
    location = os.path.join(self.energy_save_path, shortuuid.uuid() + '.pkl')
    torch.save(data, location)
|
||||||
|
|
||||||
def forward(self, images, features, proposals, targets=None):
|
def forward(self, images, features, proposals, targets=None):
|
||||||
"""
|
"""
|
||||||
See :meth:`ROIHeads.forward`.
|
See :meth:`ROIHeads.forward`.
|
||||||
|
@ -466,6 +477,8 @@ class Res5ROIHeads(ROIHeads):
|
||||||
if self.enable_clustering:
|
if self.enable_clustering:
|
||||||
self.box_predictor.update_feature_store(input_features, proposals)
|
self.box_predictor.update_feature_store(input_features, proposals)
|
||||||
del features
|
del features
|
||||||
|
if self.compute_energy_flag:
|
||||||
|
self.compute_energy(predictions, proposals)
|
||||||
losses = self.box_predictor.losses(predictions, proposals, input_features)
|
losses = self.box_predictor.losses(predictions, proposals, input_features)
|
||||||
if self.mask_on:
|
if self.mask_on:
|
||||||
proposals, fg_selection_masks = select_foreground_proposals(
|
proposals, fg_selection_masks = select_foreground_proposals(
|
||||||
|
|
Loading…
Reference in New Issue