Weibull distribution

pull/42/head
Joseph 2020-11-06 15:25:34 +05:30
parent 9321996ee8
commit 71480aa878
14 changed files with 200 additions and 18 deletions

@@ -24,10 +24,10 @@ SOLVER:
   MAX_ITER: 90000
 VERSION: 2
 OWOD:
-  ENABLE_THRESHOLD_AUTOLABEL_UNK: False
+  ENABLE_THRESHOLD_AUTOLABEL_UNK: True
   NUM_UNK_PER_IMAGE: 1
   ENABLE_UNCERTAINITY_AUTOLABEL_UNK: False
-  ENABLE_CLUSTERING: False
+  ENABLE_CLUSTERING: True
   CLUSTERING:
     ITEMS_PER_CLASS: 20
     START_ITER: 1000
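These two flips are the substance of the base-config change: Task 1 training now auto-labels unmatched high-objectness proposals as "unknown" (with NUM_UNK_PER_IMAGE: 1) and turns on the class-prototype clustering loss. As a rough illustration of what threshold-based auto-labelling means — a sketch of the idea only, with hypothetical names, not the fork's actual implementation:

import torch

def autolabel_unknowns(labels, objectness, unk_class_id, num_unk=1, bg_class_id=-1):
    """Relabel the top-`num_unk` highest-objectness background proposals as 'unknown'."""
    candidates = torch.where(labels == bg_class_id)[0]   # proposals matched to no GT box
    if candidates.numel() == 0:
        return labels
    order = objectness[candidates].argsort(descending=True)
    labels = labels.clone()
    labels[candidates[order[:num_unk]]] = unk_class_id
    return labels

# toy run: proposals 2 and 3 are background; 3 has the higher objectness
labels = torch.tensor([0, 2, -1, -1, 1])
objectness = torch.tensor([0.9, 0.8, 0.3, 0.7, 0.6])
print(autolabel_unknowns(labels, objectness, unk_class_id=20))  # -> [0, 2, -1, 20, 1]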

@@ -2,7 +2,7 @@ _BASE_: "../Base-RCNN-C4-OWOD.yaml"
 MODEL:
   WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
   # WEIGHTS: "/home/fk1/workspace/OWOD/output/expr_training_with_unk_with_clustering_Z_DIMENSION_256/model_final.pth"
-  # WEIGHTS: "/home/fk1/workspace/OWOD/output/t1/model_final.pth"
+  # WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_std_frcnn/model_final.pth"
 DATASETS:
   TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
   TEST: ('voc_coco_2007_test', ) # t1_voc_coco_2007_test, t1_voc_coco_2007_val
@@ -10,7 +10,7 @@ SOLVER:
   STEPS: (12000, 16000)
   MAX_ITER: 18000
   WARMUP_ITERS: 100
-OUTPUT_DIR: "./output/t1_std_frcnn"
+OUTPUT_DIR: "./output/t1_ENABLE_CLUSTERING_margin_5"
 OWOD:
   PREV_INTRODUCED_CLS: 0
   CUR_INTRODUCED_CLS: 20

@@ -0,0 +1,14 @@
+_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
+MODEL:
+  WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_expr/model_final.pth"
+DATASETS:
+  TRAIN: ('t1_voc_coco_2007_ft', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
+  TEST: (, ) # voc_coco_2007_test
+SOLVER:
+  STEPS: (12000, 16000)
+  MAX_ITER: 18000
+  WARMUP_ITERS: 100
+OUTPUT_DIR: "./output/t1_expr"
+OWOD:
+  PREV_INTRODUCED_CLS: 0
+  CUR_INTRODUCED_CLS: 20

@@ -0,0 +1,15 @@
+_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
+MODEL:
+  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+  # WEIGHTS: "/home/fk1/workspace/OWOD/output/t1/model_final.pth"
+DATASETS:
+  TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
+  TEST: ('voc_coco_2007_test', ) # voc_coco_2007_test
+SOLVER:
+  STEPS: (12000, 16000)
+  MAX_ITER: 18000
+  WARMUP_ITERS: 100
+OUTPUT_DIR: "./output/t1_expr"
+OWOD:
+  PREV_INTRODUCED_CLS: 0
+  CUR_INTRODUCED_CLS: 20

@@ -0,0 +1,15 @@
+_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
+MODEL:
+  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+  # WEIGHTS: "/home/fk1/workspace/OWOD/output/t1/model_final.pth"
+DATASETS:
+  TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
+  TEST: (, ) # voc_coco_2007_test
+SOLVER:
+  STEPS: (12000, 16000)
+  MAX_ITER: 18000
+  WARMUP_ITERS: 100
+OUTPUT_DIR: "./output/t1_expr"
+OWOD:
+  PREV_INTRODUCED_CLS: 0
+  CUR_INTRODUCED_CLS: 20

@@ -0,0 +1,23 @@
+_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
+MODEL:
+  # WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+  # WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_THRESHOLD_AUTOLABEL_UNK/model_final.pth"
+  WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
+  # WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_std_frcnn/model_final.pth"
+DATASETS:
+  TRAIN: ('voc_coco_2007_val', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
+  TEST: ('voc_coco_2007_val', ) # voc_coco_2007_test
+SOLVER:
+  STEPS: (12000, 16000)
+  MAX_ITER: 500
+  WARMUP_ITERS: 0
+OUTPUT_DIR: "./output/t1_clustering_new_3"
+OWOD:
+  PREV_INTRODUCED_CLS: 0
+  CUR_INTRODUCED_CLS: 20
+  COMPUTE_ENERGY: True
+  ENERGY_SAVE_PATH: 'energy'
+  SKIP_TRAINING_WHILE_EVAL: False
+#OUTPUT_DIR: "./output/t1_std_frcnn_energy"
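This new config is not a real training run: it drives the energy-collection pass introduced by this commit. It restores the clustering-trained Task 1 weights and "trains" for only 500 iterations, without warmup, over the Task 1 validation split, so that each forward pass dumps classification logits under OUTPUT_DIR/ENERGY_SAVE_PATH (the directory created by the engine/defaults.py change below). A small sketch of inspecting what such a pass leaves behind — paths follow this config and are illustrative only:

import os
import torch

output_dir = "./output/t1_clustering_new_3"
energy_dir = os.path.join(output_dir, "energy")   # OUTPUT_DIR / ENERGY_SAVE_PATH

files = sorted(os.listdir(energy_dir))
print(len(files), "energy dumps, one per iteration")

# Each file is a (logits, gt_classes) pair saved by Res5ROIHeads.compute_energy;
# shapes are roughly [num_proposals, num_classes + 1] and [num_proposals] (assumption).
logits, classes = torch.load(os.path.join(energy_dir, files[0]))
print(logits.shape, classes.shape)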

@@ -1,9 +1,10 @@
 _BASE_: "../Base-RCNN-C4-OWOD.yaml"
 MODEL:
-  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+  # WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+  WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
 DATASETS:
-  TRAIN: ('t2_train',)
-  TEST: ('voc_2007_test', 't2_test', 't3_test_unk', 't4_test_unk')
+  TRAIN: ('t2_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
+  TEST: ('voc_coco_2007_test', )
 SOLVER:
   STEPS: (22000, 26000)
   MAX_ITER: 28000

@@ -613,6 +613,9 @@ _C.OWOD.CLUSTERING.Z_DIMENSION = 64
 _C.OWOD.PREV_INTRODUCED_CLS = 0
 _C.OWOD.CUR_INTRODUCED_CLS = 20
+_C.OWOD.COMPUTE_ENERGY = False
+_C.OWOD.ENERGY_SAVE_PATH = ''
+_C.OWOD.SKIP_TRAINING_WHILE_EVAL = False

 # ---------------------------------------------------------------------------- #
 # Misc options
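These defaults do double duty: they document the new knobs and they register the keys, so that configs like the energy YAML above are allowed to set them — yacs rejects any key that was not declared in the defaults first. A self-contained sketch of that behaviour (an illustrative CfgNode, not the fork's full defaults.py):

from yacs.config import CfgNode as CN

_C = CN()
_C.OWOD = CN()
_C.OWOD.COMPUTE_ENERGY = False
_C.OWOD.ENERGY_SAVE_PATH = ''
_C.OWOD.SKIP_TRAINING_WHILE_EVAL = False

cfg = _C.clone()
# a YAML would be applied with cfg.merge_from_file(path); merge_from_list is equivalent
cfg.merge_from_list(["OWOD.COMPUTE_ENERGY", True, "OWOD.ENERGY_SAVE_PATH", "energy"])
print(cfg.OWOD.COMPUTE_ENERGY, cfg.OWOD.ENERGY_SAVE_PATH)

try:
    cfg.merge_from_list(["OWOD.NOT_REGISTERED", True])  # never declared above
except (AssertionError, KeyError) as e:
    print("rejected unregistered key:", e)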

@@ -263,7 +263,7 @@ def remove_prev_class_and_unk_instances(cfg, dataset_dicts):
     # For training data.
     prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
     curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
-    valid_classes = range(prev_intro_cls, curr_intro_cls)
+    valid_classes = range(prev_intro_cls, prev_intro_cls + curr_intro_cls)
     logger = logging.getLogger(__name__)
     logger.info("Valid classes: " + str(valid_classes))
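This one-line fix matters as soon as a task starts past class 0: the valid window is meant to be the CUR_INTRODUCED_CLS classes that begin after the PREV_INTRODUCED_CLS already-known ones. A quick check with the Task 2 numbers from the configs above:

# Task 2: 20 classes already introduced, 20 new ones in this task.
prev_intro_cls, curr_intro_cls = 20, 20

old = range(prev_intro_cls, curr_intro_cls)                   # range(20, 20) -> empty!
new = range(prev_intro_cls, prev_intro_cls + curr_intro_cls)  # range(20, 40) -> classes 20..39

print(list(old))  # []
print(list(new))  # [20, 21, ..., 39]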

@@ -118,6 +118,9 @@ def default_setup(cfg, args):
     if comm.is_main_process() and output_dir:
         PathManager.mkdirs(output_dir)
+
+    if cfg.OWOD.COMPUTE_ENERGY:
+        PathManager.mkdirs(os.path.join(output_dir, cfg.OWOD.ENERGY_SAVE_PATH))

     rank = comm.get_rank()
     setup_logger(output_dir, distributed_rank=rank, name="fvcore")
     logger = setup_logger(output_dir, distributed_rank=rank)

@@ -7,6 +7,9 @@ import numpy as np
 import time
 import weakref
 import torch
+import os
+from matplotlib import pyplot
+from reliability.Fitters import Fit_Weibull_3P

 import detectron2.utils.comm as comm
 from detectron2.utils.events import EventStorage
@@ -138,6 +141,8 @@ class TrainerBase:
         try:
             self.before_train()
             for self.iter in range(start_iter, max_iter):
+                if self.cfg.OWOD.SKIP_TRAINING_WHILE_EVAL:
+                    continue
                 self.before_step()
                 self.run_step()
                 self.after_step()
@@ -152,13 +157,80 @@ class TrainerBase:
             self.after_train()

     def before_train(self):
+        if self.cfg.OWOD.SKIP_TRAINING_WHILE_EVAL:
+            logger = logging.getLogger(__name__)  # local logger; none is in scope here
+            logger.info('Skipping training as cfg.OWOD.SKIP_TRAINING_WHILE_EVAL flag is set.')
         for h in self._hooks:
             h.before_train()

     def after_train(self):
         self.storage.iter = self.iter
-        for h in self._hooks:
-            h.after_train()
+        if self.cfg.OWOD.COMPUTE_ENERGY:
+            logger = logging.getLogger(__name__)
+            logger.info("Going to analyse the energy files...")
+            self.analyse_energy()
+            # skip EvalHook.after_train: the energy pass is not a real training run
+            for h in self._hooks:
+                if 'EvalHook' not in str(type(h)):
+                    h.after_train()
+        else:
+            for h in self._hooks:
+                h.after_train()

+    def analyse_energy(self):
+        files = os.listdir(os.path.join(self.cfg.OUTPUT_DIR, self.cfg.OWOD.ENERGY_SAVE_PATH))
+        unk = []
+        known = []
+        logger = logging.getLogger(__name__)
+        for id, file in enumerate(files):
+            path = os.path.join(self.cfg.OUTPUT_DIR, self.cfg.OWOD.ENERGY_SAVE_PATH, file)
+            try:
+                logits, classes = torch.load(path)
+            except:
+                logger.info('Not able to load ' + path + '. Continuing...')
+                continue
+            num_seen_classes = self.cfg.OWOD.PREV_INTRODUCED_CLS + self.cfg.OWOD.CUR_INTRODUCED_CLS
+            # energy score: logsumexp over the logits of the classes seen so far
+            lse = torch.logsumexp(logits[:, :num_seen_classes], dim=1)
+            # lse = torch.logsumexp(logits[:, :-2], dim=1)
+            for i, cls in enumerate(classes):
+                if cls == self.cfg.MODEL.ROI_HEADS.NUM_CLASSES:
+                    continue  # background proposals are ignored
+                if cls == self.cfg.MODEL.ROI_HEADS.NUM_CLASSES - 1:
+                    unk.append(lse[i].detach().cpu().tolist())
+                else:
+                    known.append(lse[i].detach().cpu().tolist())
+            if id % 100 == 0:
+                logger.info("Analysing " + str(id) + " / " + str(len(files)))
+            if id == 10:
+                break  # debugging cap: only the first 11 files are analysed
+        logger.info('len(unk): ' + str(len(unk)))
+        logger.info('len(known): ' + str(len(known)))
+        logger.info('Fitting Weibull distribution...')
+        wb_dist_param = []
+        wb_unk = Fit_Weibull_3P(failures=unk, show_probability_plot=False, print_results=False)
+        wb_dist_param.append({"scale_unk": wb_unk.alpha, "shape_unk": wb_unk.beta, "shift_unk": wb_unk.gamma})
+        wb_known = Fit_Weibull_3P(failures=known, show_probability_plot=False, print_results=False)
+        wb_dist_param.append(
+            {"scale_known": wb_known.alpha, "shape_known": wb_known.beta, "shift_known": wb_known.gamma})
+        param_save_location = os.path.join(self.cfg.OUTPUT_DIR,
+                                           'energy_dist_' + str(self.cfg.OWOD.PREV_INTRODUCED_CLS
+                                                                + self.cfg.OWOD.CUR_INTRODUCED_CLS) + '.pkl')
+        logger.info('Pickling the parameters to ' + param_save_location)
+        torch.save(wb_dist_param, param_save_location)
+        logger.info('Plotting the computed energy values...')
+        bins = np.linspace(2, 15, 500)
+        pyplot.hist(known, bins, alpha=0.5, label='known')
+        pyplot.hist(unk, bins, alpha=0.5, label='unk')
+        pyplot.legend(loc='upper right')
+        pyplot.savefig(os.path.join(self.cfg.OUTPUT_DIR, 'energy.png'))

     def before_step(self):
         # Maintain the invariant that storage.iter == trainer.iter
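The score being histogrammed and fitted here is an energy: the logsumexp of each proposal's seen-class logits, collected separately for ground-truth-known and ground-truth-unknown instances, with a three-parameter Weibull fitted to each set. The commit stops at pickling the two parameter sets. One plausible way they would be consumed at inference — an assumption, sketched with scipy.stats.weibull_min rather than the reliability package (reliability's alpha/beta/gamma map to scipy's scale/c/loc) — is a simple likelihood comparison on a new detection's energy:

import torch
from scipy.stats import weibull_min

# parameters as pickled by analyse_energy; the filename follows the commit's naming
wb_unk, wb_known = torch.load('energy_dist_20.pkl')

unk_pdf = weibull_min(c=wb_unk["shape_unk"], loc=wb_unk["shift_unk"], scale=wb_unk["scale_unk"])
known_pdf = weibull_min(c=wb_known["shape_known"], loc=wb_known["shift_known"], scale=wb_known["scale_known"])

def is_unknown(logits, num_seen_classes):
    """Score one detection's energy under both fitted densities (illustrative)."""
    energy = torch.logsumexp(logits[:num_seen_classes], dim=0).item()
    return unk_pdf.pdf(energy) > known_pdf.pdf(energy)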

@@ -46,6 +46,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
         self.prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
         self.curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
         self.total_num_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES
+        self.known_classes = self._class_names[:self.prev_intro_cls + self.curr_intro_cls]

     def reset(self):
         self._predictions = defaultdict(list)  # class name -> list of prediction strings
@@ -57,6 +58,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
             boxes = instances.pred_boxes.tensor.numpy()
             scores = instances.scores.tolist()
             classes = instances.pred_classes.tolist()
+            logits = instances.logits
             for box, score, cls in zip(boxes, scores, classes):
                 xmin, ymin, xmax, ymax = box
                 # The inverse of data loading logic in `datasets/pascal_voc.py`
@@ -108,6 +110,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
                     cls_name,
                     ovthresh=thresh / 100.0,
                     use_07_metric=self._is_2007,
+                    known_classes=self.known_classes
                 )
                 aps[thresh].append(ap * 100)
                 # recs[thresh].append(rec * 100)
@@ -166,14 +169,27 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
 @lru_cache(maxsize=None)
-def parse_rec(filename):
+def parse_rec(filename, known_classes):
     """Parse a PASCAL VOC xml file."""
+    VOC_CLASS_NAMES_COCOFIED = [
+        "airplane", "dining table", "motorcycle",
+        "potted plant", "couch", "tv"
+    ]
+    BASE_VOC_CLASS_NAMES = [
+        "aeroplane", "diningtable", "motorbike",
+        "pottedplant", "sofa", "tvmonitor"
+    ]
     with PathManager.open(filename) as f:
         tree = ET.parse(f)
     objects = []
     for obj in tree.findall("object"):
         obj_struct = {}
-        obj_struct["name"] = obj.find("name").text
+        cls_name = obj.find("name").text
+        if cls_name in VOC_CLASS_NAMES_COCOFIED:
+            cls_name = BASE_VOC_CLASS_NAMES[VOC_CLASS_NAMES_COCOFIED.index(cls_name)]
+        if cls_name not in known_classes:
+            cls_name = 'unknown'
+        obj_struct["name"] = cls_name
         # obj_struct["pose"] = obj.find("pose").text
         # obj_struct["truncated"] = int(obj.find("truncated").text)
         obj_struct["difficult"] = int(obj.find("difficult").text)
@@ -221,7 +237,7 @@ def voc_ap(rec, prec, use_07_metric=False):
     return ap

-def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False):
+def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False, known_classes=None):
     """rec, prec, ap = voc_eval(detpath,
                                 annopath,
                                 imagesetfile,
@@ -254,7 +270,7 @@ def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False, known_classes=None):
     # load annots
     recs = {}
     for imagename in imagenames:
-        recs[imagename] = parse_rec(annopath.format(imagename))
+        recs[imagename] = parse_rec(annopath.format(imagename), tuple(known_classes))

     # extract gt objects for this class
     class_recs = {}
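Two details here are easy to miss. Ground-truth names that COCO spells differently ("airplane", "couch", ...) are first mapped back to their VOC spellings, and any class outside the first PREV + CUR names collapses to 'unknown'. Also, parse_rec stays wrapped in lru_cache, so the new known_classes argument must be hashable — which is why the call site passes tuple(known_classes) rather than the list. The remapping, lifted out as a standalone function:

VOC_CLASS_NAMES_COCOFIED = ["airplane", "dining table", "motorcycle",
                            "potted plant", "couch", "tv"]
BASE_VOC_CLASS_NAMES = ["aeroplane", "diningtable", "motorbike",
                        "pottedplant", "sofa", "tvmonitor"]

def remap(cls_name, known_classes):
    # COCO spelling -> VOC spelling, then collapse anything unseen to 'unknown'
    if cls_name in VOC_CLASS_NAMES_COCOFIED:
        cls_name = BASE_VOC_CLASS_NAMES[VOC_CLASS_NAMES_COCOFIED.index(cls_name)]
    return cls_name if cls_name in known_classes else 'unknown'

known = ('aeroplane', 'bicycle', 'sofa')   # tuple, because parse_rec is lru_cache'd
print(remap('airplane', known))            # -> 'aeroplane'
print(remap('couch', known))               # -> 'sofa'
print(remap('zebra', known))               # -> 'unknown'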

@@ -47,7 +47,7 @@ Naming convention:
 """

-def fast_rcnn_inference(boxes, scores, image_shapes, score_thresh, nms_thresh, topk_per_image):
+def fast_rcnn_inference(boxes, scores, image_shapes, predictions, score_thresh, nms_thresh, topk_per_image):
     """
     Call `fast_rcnn_inference_single_image` for all images.
@@ -75,15 +75,15 @@ def fast_rcnn_inference(boxes, scores, image_shapes, score_thresh, nms_thresh, topk_per_image):
     """
     result_per_image = [
         fast_rcnn_inference_single_image(
-            boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image
+            boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image, prediction
         )
-        for scores_per_image, boxes_per_image, image_shape in zip(scores, boxes, image_shapes)
+        for scores_per_image, boxes_per_image, image_shape, prediction in zip(scores, boxes, image_shapes, predictions)
     ]
     return [x[0] for x in result_per_image], [x[1] for x in result_per_image]

 def fast_rcnn_inference_single_image(
-    boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image
+    boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image, prediction
 ):
     """
     Single-image inference. Return bounding-box detection results by thresholding
@@ -96,12 +96,15 @@ def fast_rcnn_inference_single_image(
     Returns:
         Same as `fast_rcnn_inference`, but for only one image.
     """
+    logits = prediction
     valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
     if not valid_mask.all():
         boxes = boxes[valid_mask]
         scores = scores[valid_mask]
+        logits = logits[valid_mask]

     scores = scores[:, :-1]
+    logits = logits[:, :-1]
     num_bbox_reg_classes = boxes.shape[1] // 4
     # Convert to Boxes to use the `clip` function ...
     boxes = Boxes(boxes.reshape(-1, 4))
@@ -119,17 +122,20 @@ def fast_rcnn_inference_single_image(
     else:
         boxes = boxes[filter_mask]
     scores = scores[filter_mask]
+    logits = logits[filter_inds[:, 0]]

     # 2. Apply NMS for each class independently.
     keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
     if topk_per_image >= 0:
         keep = keep[:topk_per_image]
     boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]
+    logits = logits[keep]

     result = Instances(image_shape)
     result.pred_boxes = Boxes(boxes)
     result.scores = scores
     result.pred_classes = filter_inds[:, 1]
+    result.logits = logits
     return result, filter_inds[:, 0]
@@ -734,6 +740,7 @@ class FastRCNNOutputLayers(nn.Module):
             boxes,
             scores,
             image_shapes,
+            predictions,
             self.test_score_thresh,
             self.test_nms_thresh,
             self.test_topk_per_image,
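The one subtle line in threading logits through inference is logits = logits[filter_inds[:, 0]]. After the score threshold, boxes and scores are flattened over (proposal, class) pairs, while logits still has one row per proposal; column 0 of filter_inds holds the proposal index, so indexing with it repeats each logits row once per surviving class and keeps everything aligned through NMS. A small shape walk-through:

import torch

R, K = 4, 3                      # proposals, foreground classes
scores = torch.rand(R, K)
logits = torch.rand(R, K + 1)    # one row of raw logits per proposal

filter_mask = scores > 0.5                 # R x K
filter_inds = filter_mask.nonzero()        # N x 2: (proposal index, class index)

flat_scores = scores[filter_mask]          # N
flat_logits = logits[filter_inds[:, 0]]    # N x (K+1): row repeated per surviving class

assert flat_scores.shape[0] == flat_logits.shape[0]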

@@ -3,6 +3,8 @@ import inspect
 import logging
 import numpy as np
 import heapq
+import os
+import shortuuid
 import operator
 import shortuuid
 from typing import Dict, List, Optional, Tuple, Union
@@ -385,6 +387,8 @@ class Res5ROIHeads(ROIHeads):
         sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
         self.mask_on = cfg.MODEL.MASK_ON
         self.enable_clustering = cfg.OWOD.ENABLE_CLUSTERING
+        self.compute_energy_flag = cfg.OWOD.COMPUTE_ENERGY
+        self.energy_save_path = os.path.join(cfg.OUTPUT_DIR, cfg.OWOD.ENERGY_SAVE_PATH)
         # fmt: on
         assert not cfg.MODEL.KEYPOINT_ON
         assert len(self.in_features) == 1
@@ -443,6 +447,13 @@ class Res5ROIHeads(ROIHeads):
         location = '/home/fk1/workspace/OWOD/output/features/' + shortuuid.uuid() + '.pkl'
         torch.save(data, location)

+    def compute_energy(self, predictions, proposals):
+        gt_classes = torch.cat([p.gt_classes for p in proposals])
+        logits = predictions[0]
+        data = (logits, gt_classes)
+        location = os.path.join(self.energy_save_path, shortuuid.uuid() + '.pkl')
+        torch.save(data, location)
+
     def forward(self, images, features, proposals, targets=None):
         """
         See :meth:`ROIHeads.forward`.
@@ -466,6 +477,8 @@ class Res5ROIHeads(ROIHeads):
             if self.enable_clustering:
                 self.box_predictor.update_feature_store(input_features, proposals)
             del features
+            if self.compute_energy_flag:
+                self.compute_energy(predictions, proposals)
             losses = self.box_predictor.losses(predictions, proposals, input_features)
             if self.mask_on:
                 proposals, fg_selection_masks = select_foreground_proposals(
proposals, fg_selection_masks = select_foreground_proposals( proposals, fg_selection_masks = select_foreground_proposals(