mirror of https://github.com/JosephKJ/OWOD.git
Weibull distribution
parent
9321996ee8
commit
71480aa878
|
@ -24,10 +24,10 @@ SOLVER:
|
|||
MAX_ITER: 90000
|
||||
VERSION: 2
|
||||
OWOD:
|
||||
ENABLE_THRESHOLD_AUTOLABEL_UNK: False
|
||||
ENABLE_THRESHOLD_AUTOLABEL_UNK: True
|
||||
NUM_UNK_PER_IMAGE: 1
|
||||
ENABLE_UNCERTAINITY_AUTOLABEL_UNK: False
|
||||
ENABLE_CLUSTERING: False
|
||||
ENABLE_CLUSTERING: True
|
||||
CLUSTERING:
|
||||
ITEMS_PER_CLASS: 20
|
||||
START_ITER: 1000
|
||||
|
|
|
@ -2,7 +2,7 @@ _BASE_: "../Base-RCNN-C4-OWOD.yaml"
|
|||
MODEL:
|
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/expr_training_with_unk_with_clustering_Z_DIMENSION_256/model_final.pth"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1/model_final.pth"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_std_frcnn/model_final.pth"
|
||||
DATASETS:
|
||||
TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||
TEST: ('voc_coco_2007_test', ) # t1_voc_coco_2007_test, t1_voc_coco_2007_val
|
||||
|
@ -10,7 +10,7 @@ SOLVER:
|
|||
STEPS: (12000, 16000)
|
||||
MAX_ITER: 18000
|
||||
WARMUP_ITERS: 100
|
||||
OUTPUT_DIR: "./output/t1_std_frcnn"
|
||||
OUTPUT_DIR: "./output/t1_ENABLE_CLUSTERING_margin_5"
|
||||
OWOD:
|
||||
PREV_INTRODUCED_CLS: 0
|
||||
CUR_INTRODUCED_CLS: 20
|
|
@ -0,0 +1,14 @@
|
|||
_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
|
||||
MODEL:
|
||||
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_expr/model_final.pth"
|
||||
DATASETS:
|
||||
TRAIN: ('t1_voc_coco_2007_ft', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||
TEST: (, ) # voc_coco_2007_test
|
||||
SOLVER:
|
||||
STEPS: (12000, 16000)
|
||||
MAX_ITER: 18000
|
||||
WARMUP_ITERS: 100
|
||||
OUTPUT_DIR: "./output/t1_expr"
|
||||
OWOD:
|
||||
PREV_INTRODUCED_CLS: 0
|
||||
CUR_INTRODUCED_CLS: 20
|
|
@ -0,0 +1,15 @@
|
|||
_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
|
||||
MODEL:
|
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1/model_final.pth"
|
||||
DATASETS:
|
||||
TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||
TEST: ('voc_coco_2007_test', ) # voc_coco_2007_test
|
||||
SOLVER:
|
||||
STEPS: (12000, 16000)
|
||||
MAX_ITER: 18000
|
||||
WARMUP_ITERS: 100
|
||||
OUTPUT_DIR: "./output/t1_expr"
|
||||
OWOD:
|
||||
PREV_INTRODUCED_CLS: 0
|
||||
CUR_INTRODUCED_CLS: 20
|
|
@ -0,0 +1,15 @@
|
|||
_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
|
||||
MODEL:
|
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1/model_final.pth"
|
||||
DATASETS:
|
||||
TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||
TEST: (, ) # voc_coco_2007_test
|
||||
SOLVER:
|
||||
STEPS: (12000, 16000)
|
||||
MAX_ITER: 18000
|
||||
WARMUP_ITERS: 100
|
||||
OUTPUT_DIR: "./output/t1_expr"
|
||||
OWOD:
|
||||
PREV_INTRODUCED_CLS: 0
|
||||
CUR_INTRODUCED_CLS: 20
|
|
@ -0,0 +1,23 @@
|
|||
_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
|
||||
MODEL:
|
||||
# WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_THRESHOLD_AUTOLABEL_UNK/model_final.pth"
|
||||
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
|
||||
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_std_frcnn/model_final.pth"
|
||||
DATASETS:
|
||||
TRAIN: ('voc_coco_2007_val', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||
TEST: ('voc_coco_2007_val', ) # voc_coco_2007_test
|
||||
SOLVER:
|
||||
STEPS: (12000, 16000)
|
||||
MAX_ITER: 500
|
||||
WARMUP_ITERS: 0
|
||||
OUTPUT_DIR: "./output/t1_clustering_new_3"
|
||||
OWOD:
|
||||
PREV_INTRODUCED_CLS: 0
|
||||
CUR_INTRODUCED_CLS: 20
|
||||
COMPUTE_ENERGY: True
|
||||
ENERGY_SAVE_PATH: 'energy'
|
||||
SKIP_TRAINING_WHILE_EVAL: False
|
||||
|
||||
|
||||
#OUTPUT_DIR: "./output/t1_std_frcnn_energy"
|
|
@ -1,9 +1,10 @@
|
|||
_BASE_: "../Base-RCNN-C4-OWOD.yaml"
|
||||
MODEL:
|
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
# WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
|
||||
DATASETS:
|
||||
TRAIN: ('t2_train',)
|
||||
TEST: ('voc_2007_test', 't2_test', 't3_test_unk', 't4_test_unk')
|
||||
TRAIN: ('t2_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
|
||||
TEST: ('voc_coco_2007_test', )
|
||||
SOLVER:
|
||||
STEPS: (22000, 26000)
|
||||
MAX_ITER: 28000
|
||||
|
|
|
@ -613,6 +613,9 @@ _C.OWOD.CLUSTERING.Z_DIMENSION = 64
|
|||
|
||||
_C.OWOD.PREV_INTRODUCED_CLS = 0
|
||||
_C.OWOD.CUR_INTRODUCED_CLS = 20
|
||||
_C.OWOD.COMPUTE_ENERGY = False
|
||||
_C.OWOD.ENERGY_SAVE_PATH = ''
|
||||
_C.OWOD.SKIP_TRAINING_WHILE_EVAL = False
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# Misc options
|
||||
|
|
|
@ -263,7 +263,7 @@ def remove_prev_class_and_unk_instances(cfg, dataset_dicts):
|
|||
# For training data.
|
||||
prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
|
||||
curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
|
||||
valid_classes = range(prev_intro_cls, curr_intro_cls)
|
||||
valid_classes = range(prev_intro_cls, prev_intro_cls + curr_intro_cls)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Valid classes: " + str(valid_classes))
|
||||
|
|
|
@ -118,6 +118,9 @@ def default_setup(cfg, args):
|
|||
if comm.is_main_process() and output_dir:
|
||||
PathManager.mkdirs(output_dir)
|
||||
|
||||
if cfg.OWOD.COMPUTE_ENERGY:
|
||||
PathManager.mkdirs(os.path.join(output_dir, cfg.OWOD.ENERGY_SAVE_PATH))
|
||||
|
||||
rank = comm.get_rank()
|
||||
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
|
||||
logger = setup_logger(output_dir, distributed_rank=rank)
|
||||
|
|
|
@ -7,6 +7,9 @@ import numpy as np
|
|||
import time
|
||||
import weakref
|
||||
import torch
|
||||
import os
|
||||
from matplotlib import pyplot
|
||||
from reliability.Fitters import Fit_Weibull_3P
|
||||
|
||||
import detectron2.utils.comm as comm
|
||||
from detectron2.utils.events import EventStorage
|
||||
|
@ -138,6 +141,8 @@ class TrainerBase:
|
|||
try:
|
||||
self.before_train()
|
||||
for self.iter in range(start_iter, max_iter):
|
||||
if self.cfg.OWOD.SKIP_TRAINING_WHILE_EVAL:
|
||||
continue
|
||||
self.before_step()
|
||||
self.run_step()
|
||||
self.after_step()
|
||||
|
@ -152,13 +157,80 @@ class TrainerBase:
|
|||
self.after_train()
|
||||
|
||||
def before_train(self):
|
||||
if self.cfg.OWOD.SKIP_TRAINING_WHILE_EVAL:
|
||||
logger.info('Skipping training as cfg.OWOD.SKIP_TRAINING_WHILE_EVAL flag is set.')
|
||||
for h in self._hooks:
|
||||
h.before_train()
|
||||
|
||||
def after_train(self):
|
||||
self.storage.iter = self.iter
|
||||
for h in self._hooks:
|
||||
h.after_train()
|
||||
if self.cfg.OWOD.COMPUTE_ENERGY:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Going to analyse the energy files...")
|
||||
|
||||
self.analyse_energy()
|
||||
|
||||
for h in self._hooks:
|
||||
if 'EvalHook' not in str(type(h)):
|
||||
h.after_train()
|
||||
else:
|
||||
for h in self._hooks:
|
||||
h.after_train()
|
||||
|
||||
def analyse_energy(self):
|
||||
files = os.listdir(os.path.join(self.cfg.OUTPUT_DIR, self.cfg.OWOD.ENERGY_SAVE_PATH))
|
||||
unk = []
|
||||
known = []
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
for id, file in enumerate(files):
|
||||
path = os.path.join(self.cfg.OUTPUT_DIR, self.cfg.OWOD.ENERGY_SAVE_PATH, file)
|
||||
try:
|
||||
logits, classes = torch.load(path)
|
||||
except:
|
||||
logger.info('Not able to load ' + path + ". Continuing...")
|
||||
continue
|
||||
num_seen_classes = self.cfg.OWOD.PREV_INTRODUCED_CLS + self.cfg.OWOD.CUR_INTRODUCED_CLS
|
||||
lse = torch.logsumexp(logits[:, :num_seen_classes], dim=1)
|
||||
# lse = torch.logsumexp(logits[:, :-2], dim=1)
|
||||
|
||||
for i, cls in enumerate(classes):
|
||||
if cls == self.cfg.MODEL.ROI_HEADS.NUM_CLASSES:
|
||||
continue
|
||||
if cls == self.cfg.MODEL.ROI_HEADS.NUM_CLASSES-1:
|
||||
unk.append(lse[i].detach().cpu().tolist())
|
||||
else:
|
||||
known.append(lse[i].detach().cpu().tolist())
|
||||
|
||||
if id % 100 == 0:
|
||||
logger.info("Analysing " + str(id) + " / " + str(len(files)))
|
||||
if id == 10:
|
||||
break
|
||||
|
||||
logger.info('len(unk): ' + str(len(unk)))
|
||||
logger.info('len(known): '+ str(len(known)))
|
||||
|
||||
logger.info('Fitting Weibull distribution...')
|
||||
wb_dist_param = []
|
||||
wb_unk = Fit_Weibull_3P(failures=unk, show_probability_plot=False, print_results=False)
|
||||
wb_dist_param.append({"scale_unk": wb_unk.alpha, "shape_unk": wb_unk.beta, "shift_unk": wb_unk.gamma})
|
||||
|
||||
wb_known = Fit_Weibull_3P(failures=known, show_probability_plot=False, print_results=False)
|
||||
wb_dist_param.append(
|
||||
{"scale_known": wb_known.alpha, "shape_known": wb_known.beta, "shift_known": wb_known.gamma})
|
||||
|
||||
param_save_location = os.path.join(self.cfg.OUTPUT_DIR,
|
||||
'energy_dist_' + str(self.cfg.OWOD.PREV_INTRODUCED_CLS
|
||||
+ self.cfg.OWOD.CUR_INTRODUCED_CLS) + '.pkl')
|
||||
logger.info('Pickling the parameters to ' + param_save_location)
|
||||
torch.save(wb_dist_param, param_save_location)
|
||||
|
||||
logger.info('Plotting the computed energy values...')
|
||||
bins = np.linspace(2, 15, 500)
|
||||
pyplot.hist(known, bins, alpha=0.5, label='known')
|
||||
pyplot.hist(unk, bins, alpha=0.5, label='unk')
|
||||
pyplot.legend(loc='upper right')
|
||||
pyplot.savefig(os.path.join(self.cfg.OUTPUT_DIR, 'energy.png'))
|
||||
|
||||
def before_step(self):
|
||||
# Maintain the invariant that storage.iter == trainer.iter
|
||||
|
|
|
@ -46,6 +46,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
self.prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
|
||||
self.curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
|
||||
self.total_num_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES
|
||||
self.known_classes = self._class_names[:self.prev_intro_cls + self.curr_intro_cls]
|
||||
|
||||
def reset(self):
|
||||
self._predictions = defaultdict(list) # class name -> list of prediction strings
|
||||
|
@ -57,6 +58,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
boxes = instances.pred_boxes.tensor.numpy()
|
||||
scores = instances.scores.tolist()
|
||||
classes = instances.pred_classes.tolist()
|
||||
logits = instances.logits
|
||||
for box, score, cls in zip(boxes, scores, classes):
|
||||
xmin, ymin, xmax, ymax = box
|
||||
# The inverse of data loading logic in `datasets/pascal_voc.py`
|
||||
|
@ -108,6 +110,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
cls_name,
|
||||
ovthresh=thresh / 100.0,
|
||||
use_07_metric=self._is_2007,
|
||||
known_classes=self.known_classes
|
||||
)
|
||||
aps[thresh].append(ap * 100)
|
||||
# recs[thresh].append(rec * 100)
|
||||
|
@ -166,14 +169,27 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
|
|||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def parse_rec(filename):
|
||||
def parse_rec(filename, known_classes):
|
||||
"""Parse a PASCAL VOC xml file."""
|
||||
VOC_CLASS_NAMES_COCOFIED = [
|
||||
"airplane", "dining table", "motorcycle",
|
||||
"potted plant", "couch", "tv"
|
||||
]
|
||||
BASE_VOC_CLASS_NAMES = [
|
||||
"aeroplane", "diningtable", "motorbike",
|
||||
"pottedplant", "sofa", "tvmonitor"
|
||||
]
|
||||
with PathManager.open(filename) as f:
|
||||
tree = ET.parse(f)
|
||||
objects = []
|
||||
for obj in tree.findall("object"):
|
||||
obj_struct = {}
|
||||
obj_struct["name"] = obj.find("name").text
|
||||
cls_name = obj.find("name").text
|
||||
if cls_name in VOC_CLASS_NAMES_COCOFIED:
|
||||
cls_name = BASE_VOC_CLASS_NAMES[VOC_CLASS_NAMES_COCOFIED.index(cls_name)]
|
||||
if cls_name not in known_classes:
|
||||
cls_name = 'unknown'
|
||||
obj_struct["name"] = cls_name
|
||||
# obj_struct["pose"] = obj.find("pose").text
|
||||
# obj_struct["truncated"] = int(obj.find("truncated").text)
|
||||
obj_struct["difficult"] = int(obj.find("difficult").text)
|
||||
|
@ -221,7 +237,7 @@ def voc_ap(rec, prec, use_07_metric=False):
|
|||
return ap
|
||||
|
||||
|
||||
def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False):
|
||||
def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False, known_classes=None):
|
||||
"""rec, prec, ap = voc_eval(detpath,
|
||||
annopath,
|
||||
imagesetfile,
|
||||
|
@ -254,7 +270,7 @@ def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_me
|
|||
# load annots
|
||||
recs = {}
|
||||
for imagename in imagenames:
|
||||
recs[imagename] = parse_rec(annopath.format(imagename))
|
||||
recs[imagename] = parse_rec(annopath.format(imagename), tuple(known_classes))
|
||||
|
||||
# extract gt objects for this class
|
||||
class_recs = {}
|
||||
|
|
|
@ -47,7 +47,7 @@ Naming convention:
|
|||
"""
|
||||
|
||||
|
||||
def fast_rcnn_inference(boxes, scores, image_shapes, score_thresh, nms_thresh, topk_per_image):
|
||||
def fast_rcnn_inference(boxes, scores, image_shapes, predictions, score_thresh, nms_thresh, topk_per_image):
|
||||
"""
|
||||
Call `fast_rcnn_inference_single_image` for all images.
|
||||
|
||||
|
@ -75,15 +75,15 @@ def fast_rcnn_inference(boxes, scores, image_shapes, score_thresh, nms_thresh, t
|
|||
"""
|
||||
result_per_image = [
|
||||
fast_rcnn_inference_single_image(
|
||||
boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image
|
||||
boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image, prediction
|
||||
)
|
||||
for scores_per_image, boxes_per_image, image_shape in zip(scores, boxes, image_shapes)
|
||||
for scores_per_image, boxes_per_image, image_shape, prediction in zip(scores, boxes, image_shapes, predictions)
|
||||
]
|
||||
return [x[0] for x in result_per_image], [x[1] for x in result_per_image]
|
||||
|
||||
|
||||
def fast_rcnn_inference_single_image(
|
||||
boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image
|
||||
boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image, prediction
|
||||
):
|
||||
"""
|
||||
Single-image inference. Return bounding-box detection results by thresholding
|
||||
|
@ -96,12 +96,15 @@ def fast_rcnn_inference_single_image(
|
|||
Returns:
|
||||
Same as `fast_rcnn_inference`, but for only one image.
|
||||
"""
|
||||
logits = prediction
|
||||
valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
|
||||
if not valid_mask.all():
|
||||
boxes = boxes[valid_mask]
|
||||
scores = scores[valid_mask]
|
||||
logits = logits[valid_mask]
|
||||
|
||||
scores = scores[:, :-1]
|
||||
logits = logits[:, :-1]
|
||||
num_bbox_reg_classes = boxes.shape[1] // 4
|
||||
# Convert to Boxes to use the `clip` function ...
|
||||
boxes = Boxes(boxes.reshape(-1, 4))
|
||||
|
@ -119,17 +122,20 @@ def fast_rcnn_inference_single_image(
|
|||
else:
|
||||
boxes = boxes[filter_mask]
|
||||
scores = scores[filter_mask]
|
||||
logits = logits[filter_inds[:,0]]
|
||||
|
||||
# 2. Apply NMS for each class independently.
|
||||
keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
|
||||
if topk_per_image >= 0:
|
||||
keep = keep[:topk_per_image]
|
||||
boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]
|
||||
logits = logits[keep]
|
||||
|
||||
result = Instances(image_shape)
|
||||
result.pred_boxes = Boxes(boxes)
|
||||
result.scores = scores
|
||||
result.pred_classes = filter_inds[:, 1]
|
||||
result.logits = logits
|
||||
return result, filter_inds[:, 0]
|
||||
|
||||
|
||||
|
@ -734,6 +740,7 @@ class FastRCNNOutputLayers(nn.Module):
|
|||
boxes,
|
||||
scores,
|
||||
image_shapes,
|
||||
predictions,
|
||||
self.test_score_thresh,
|
||||
self.test_nms_thresh,
|
||||
self.test_topk_per_image,
|
||||
|
|
|
@ -3,6 +3,8 @@ import inspect
|
|||
import logging
|
||||
import numpy as np
|
||||
import heapq
|
||||
import os
|
||||
import shortuuid
|
||||
import operator
|
||||
import shortuuid
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
@ -385,6 +387,8 @@ class Res5ROIHeads(ROIHeads):
|
|||
sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
|
||||
self.mask_on = cfg.MODEL.MASK_ON
|
||||
self.enable_clustering = cfg.OWOD.ENABLE_CLUSTERING
|
||||
self.compute_energy_flag = cfg.OWOD.COMPUTE_ENERGY
|
||||
self.energy_save_path = os.path.join(cfg.OUTPUT_DIR, cfg.OWOD.ENERGY_SAVE_PATH)
|
||||
# fmt: on
|
||||
assert not cfg.MODEL.KEYPOINT_ON
|
||||
assert len(self.in_features) == 1
|
||||
|
@ -443,6 +447,13 @@ class Res5ROIHeads(ROIHeads):
|
|||
location = '/home/fk1/workspace/OWOD/output/features/' + shortuuid.uuid() + '.pkl'
|
||||
torch.save(data, location)
|
||||
|
||||
def compute_energy(self, predictions, proposals):
|
||||
gt_classes = torch.cat([p.gt_classes for p in proposals])
|
||||
logits = predictions[0]
|
||||
data = (logits, gt_classes)
|
||||
location = os.path.join(self.energy_save_path, shortuuid.uuid() + '.pkl')
|
||||
torch.save(data, location)
|
||||
|
||||
def forward(self, images, features, proposals, targets=None):
|
||||
"""
|
||||
See :meth:`ROIHeads.forward`.
|
||||
|
@ -466,6 +477,8 @@ class Res5ROIHeads(ROIHeads):
|
|||
if self.enable_clustering:
|
||||
self.box_predictor.update_feature_store(input_features, proposals)
|
||||
del features
|
||||
if self.compute_energy_flag:
|
||||
self.compute_energy(predictions, proposals)
|
||||
losses = self.box_predictor.losses(predictions, proposals, input_features)
|
||||
if self.mask_on:
|
||||
proposals, fg_selection_masks = select_foreground_proposals(
|
||||
|
|
Loading…
Reference in New Issue