- Saving feature store

- Implemented absolute Open-Set Error (OSE) and initial WI
pull/42/head
Joseph 2020-11-09 10:34:45 +05:30
parent ea9f7d15f8
commit 81516e5543
10 changed files with 5262 additions and 108 deletions

View File

@ -28,6 +28,7 @@ OWOD:
NUM_UNK_PER_IMAGE: 1
ENABLE_UNCERTAINITY_AUTOLABEL_UNK: False
ENABLE_CLUSTERING: True
FEATURE_STORE_SAVE_PATH: 'feature_store'
CLUSTERING:
ITEMS_PER_CLASS: 20
START_ITER: 1000

View File

@ -1,10 +1,18 @@
_BASE_: "../../Base-RCNN-C4-OWOD.yaml"
MODEL:
# WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_std_frcnn/model_final.pth"
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_THRESHOLD_AUTOLABEL_UNK/model_final.pth"
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
ROI_HEADS:
POSITIVE_FRACTION: 0.25
NMS_THRESH_TEST: 0.5
SCORE_THRESH_TEST: 0.05
TEST:
DETECTIONS_PER_IMAGE: 50
DATASETS:
TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
TEST: ('voc_coco_2007_test', ) # voc_coco_2007_test
TEST: ('voc_coco_2007_test', 't1_voc_coco_2007_known_test') # voc_coco_2007_test
SOLVER:
STEPS: (12000, 16000)
MAX_ITER: 18000
@ -12,4 +20,8 @@ SOLVER:
OUTPUT_DIR: "./output/t1_clustering_new_4"
OWOD:
PREV_INTRODUCED_CLS: 0
CUR_INTRODUCED_CLS: 20
CUR_INTRODUCED_CLS: 20
# POSITIVE_FRACTION: 0.25
# NMS_THRESH_TEST: 0.5
# SCORE_THRESH_TEST: 0.05

View File

@ -2,14 +2,15 @@ _BASE_: "../../Base-RCNN-C4-OWOD.yaml"
MODEL:
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1/model_final.pth"
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_std_frcnn/model_final.pth"
DATASETS:
TRAIN: ('t1_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
TEST: (, ) # voc_coco_2007_test
TEST: ('voc_coco_2007_test', ) # voc_coco_2007_test, t1_voc_coco_2007_test, t1_voc_coco_2007_val
SOLVER:
STEPS: (12000, 16000)
MAX_ITER: 18000
WARMUP_ITERS: 100
OUTPUT_DIR: "./output/t1_expr"
OUTPUT_DIR: "./output/t1_clustering_with_save"
OWOD:
PREV_INTRODUCED_CLS: 0
CUR_INTRODUCED_CLS: 20

View File

@ -1,7 +1,8 @@
_BASE_: "../Base-RCNN-C4-OWOD.yaml"
MODEL:
# WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
# WEIGHTS: "/home/fk1/workspace/OWOD/output/t1_ENABLE_CLUSTERING/model_final.pth"
WEIGHTS: "/home/fk1/workspace/OWOD/output/t2/model_final.pth"
DATASETS:
TRAIN: ('t2_voc_coco_2007_train', ) # t1_voc_coco_2007_train, t1_voc_coco_2007_ft
TEST: ('voc_coco_2007_test', )
@ -12,4 +13,10 @@ SOLVER:
OUTPUT_DIR: "./output/t2"
OWOD:
PREV_INTRODUCED_CLS: 20
CUR_INTRODUCED_CLS: 20
CUR_INTRODUCED_CLS: 20
CLUSTERING:
ITEMS_PER_CLASS: 20
START_ITER: 20000
UPDATE_MU_ITER: 3000
MOMENTUM: 0.99
Z_DIMENSION: 128

File diff suppressed because it is too large Load Diff

View File

@ -616,6 +616,7 @@ _C.OWOD.CUR_INTRODUCED_CLS = 20
_C.OWOD.COMPUTE_ENERGY = False
_C.OWOD.ENERGY_SAVE_PATH = ''
_C.OWOD.SKIP_TRAINING_WHILE_EVAL = False
_C.OWOD.FEATURE_STORE_SAVE_PATH = ''
# ---------------------------------------------------------------------------- #
# Misc options

View File

@ -215,6 +215,7 @@ def register_all_pascal_voc(root):
("voc_2012_train", "VOC2012", "train"),
("voc_2012_val", "VOC2012", "val"),
("t1_voc_coco_2007_train", "VOC2007", "t1_train"),
("t1_voc_coco_2007_known_test", "VOC2007", "t1_known_test"),
("voc_coco_2007_test", "VOC2007", "all_task_test"),
("voc_coco_2007_val", "VOC2007", "all_task_val"),
("t1_voc_coco_2007_ft", "VOC2007", "t1_ft"),

View File

@ -121,6 +121,9 @@ def default_setup(cfg, args):
if cfg.OWOD.COMPUTE_ENERGY:
PathManager.mkdirs(os.path.join(output_dir, cfg.OWOD.ENERGY_SAVE_PATH))
if cfg.OWOD.ENABLE_CLUSTERING:
PathManager.mkdirs(os.path.join(output_dir, cfg.OWOD.FEATURE_STORE_SAVE_PATH))
rank = comm.get_rank()
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
logger = setup_logger(output_dir, distributed_rank=rank)

View File

@ -4,8 +4,10 @@
import logging
import numpy as np
import os
import sys
import tempfile
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from collections import OrderedDict, defaultdict
from functools import lru_cache
import torch
@ -19,6 +21,7 @@ from detectron2.utils import comm
from .evaluator import DatasetEvaluator
np.set_printoptions(threshold=sys.maxsize)
class PascalVOCDetectionEvaluator(DatasetEvaluator):
"""
@ -42,7 +45,8 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
self._image_set_path = os.path.join(meta.dirname, "ImageSets", "Main", meta.split + ".txt")
self._class_names = meta.thing_classes
assert meta.year in [2007, 2012], meta.year
self._is_2007 = meta.year == 2007
self._is_2007 = False
# self._is_2007 = meta.year == 2007
self._cpu_device = torch.device("cpu")
self._logger = logging.getLogger(__name__)
if cfg is not None:
@ -54,7 +58,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
self.known_classes = self._class_names[:self.num_seen_classes]
param_save_location = os.path.join(cfg.OUTPUT_DIR,'energy_dist_' + str(self.num_seen_classes) + '.pkl')
self.energy_distribution_loaded = False
if os.path.isfile(param_save_location) and os.access(param_save_location, os.R_OK):
self._logger.info('Loading energy distribution from ' + param_save_location)
params = torch.load(param_save_location)
@ -103,7 +107,6 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
cls[i] = -100
else:
if cls[i] != self.unknown_class_index:
# cls[i] = -100
cls[i] = self.unknown_class_index
return cls
@ -127,6 +130,28 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}"
)
def compute_avg_precision_at_many_recall_level(self, precisions, recalls):
    """Evaluate the per-IoU average precision at recall levels 0.1 .. 0.9.

    Delegates to ``compute_avg_precision_at_a_recall_level`` once for each
    recall level in steps of 0.1.

    Args:
        precisions: forwarded unchanged to the single-level helper.
        recalls: forwarded unchanged to the single-level helper.

    Returns:
        dict mapping each recall level (``1/10`` .. ``9/10``, in increasing
        order) to whatever the single-level helper returned for that level.
    """
    return {
        tenth / 10: self.compute_avg_precision_at_a_recall_level(
            precisions, recalls, recall_level=tenth / 10
        )
        for tenth in range(1, 10)
    }
def compute_avg_precision_at_a_recall_level(self, precisions, recalls, recall_level=0.5):
    """Average the seen-class precision at the curve point nearest a recall level.

    For every key of ``recalls`` (the IoU threshold), each of the first
    ``self.num_seen_classes`` classes with a non-empty recall curve
    contributes the precision at the index whose recall is closest to
    ``recall_level``. On ties the earliest index is used.

    Args:
        precisions: mapping of IoU threshold -> per-class sequence of
            precision curves, indexed in the same class order as ``recalls``.
        recalls: mapping of IoU threshold -> per-class sequence of recall
            curves (array-likes of floats).
        recall_level: target recall at which precision is sampled.

    Returns:
        dict mapping each IoU threshold to the mean of the sampled
        precisions, or ``0`` when no seen class has a non-empty curve.
    """
    precs = {}
    for iou, recall in recalls.items():
        sampled = []
        for cls_id, rec in enumerate(recall):
            # Classes are ordered with the seen ones first, so stop at the
            # first class index beyond the seen range.
            if cls_id >= self.num_seen_classes:
                break
            if len(rec) == 0:
                continue
            # Vectorised nearest-recall lookup; argmin returns the first
            # minimum, matching a left-to-right linear scan.
            gap = np.abs(np.asarray(rec, dtype=float) - recall_level)
            nearest = int(np.argmin(gap))
            sampled.append(precisions[iou][cls_id][nearest])
        precs[iou] = np.mean(sampled) if sampled else 0
    return precs
def evaluate(self):
"""
Returns:
@ -152,8 +177,12 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
res_file_template = os.path.join(dirname, "{}.txt")
aps = defaultdict(list) # iou -> ap per class
# recs = defaultdict(list)
# precs = defaultdict(list)
recs = defaultdict(list)
precs = defaultdict(list)
all_recs = defaultdict(list)
all_precs = defaultdict(list)
unk_det_as_knowns = defaultdict(list)
num_unks = defaultdict(list)
for cls_id, cls_name in enumerate(self._class_names):
lines = predictions.get(cls_id, [""])
@ -161,47 +190,75 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
with open(res_file_template.format(cls_name), "w") as f:
f.write("\n".join(lines))
for thresh in range(50, 100, 5):
rec, prec, ap = voc_eval(
res_file_template,
self._anno_file_template,
self._image_set_path,
cls_name,
ovthresh=thresh / 100.0,
use_07_metric=self._is_2007,
known_classes=self.known_classes
)
aps[thresh].append(ap * 100)
# recs[thresh].append(rec * 100)
# precs[thresh].append(prec * 100)
# for thresh in range(50, 100, 5):
thresh = 50
rec, prec, ap, unk_det_as_known, num_unk = voc_eval(
res_file_template,
self._anno_file_template,
self._image_set_path,
cls_name,
ovthresh=thresh / 100.0,
use_07_metric=self._is_2007,
known_classes=self.known_classes
)
aps[thresh].append(ap * 100)
unk_det_as_knowns[thresh].append(unk_det_as_known)
num_unks[thresh].append(num_unk)
all_precs[thresh].append(prec)
all_recs[thresh].append(rec)
try:
recs[thresh].append(rec[-1] * 100)
precs[thresh].append(prec[-1] * 100)
except:
recs[thresh].append(0)
precs[thresh].append(0)
avg_precision = self.compute_avg_precision_at_many_recall_level(all_precs, all_recs)
self._logger.info('avg_precision: ' + str(avg_precision))
ret = OrderedDict()
mAP = {iou: np.mean(x) for iou, x in aps.items()}
ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50]}
total_num_unk_det_as_known = {iou: np.sum(x) for iou, x in unk_det_as_knowns.items()}
total_num_unk = num_unks[50][0]
self._logger.info('Absolute OSE (total_num_unk_det_as_known): ' + str(total_num_unk_det_as_known))
self._logger.info('total_num_unk ' + str(total_num_unk))
# Extra logging of class-wise APs
avg_precs = list(np.mean([x for _, x in aps.items()], axis=0))
self._logger.info(self._class_names)
self._logger.info("AP__: " + str(['%.1f' % x for x in avg_precs]))
# self._logger.info("AP__: " + str(['%.1f' % x for x in avg_precs]))
self._logger.info("AP50: " + str(['%.1f' % x for x in aps[50]]))
self._logger.info("AP75: " + str(['%.1f' % x for x in aps[75]]))
self._logger.info("Precisions50: " + str(['%.1f' % x for x in precs[50]]))
self._logger.info("Recall50: " + str(['%.1f' % x for x in recs[50]]))
# self._logger.info("AP75: " + str(['%.1f' % x for x in aps[75]]))
if self.prev_intro_cls > 0:
self._logger.info("Prev class AP__: " + str(np.mean(avg_precs[:self.prev_intro_cls])))
# self._logger.info("\nPrev class AP__: " + str(np.mean(avg_precs[:self.prev_intro_cls])))
self._logger.info("Prev class AP50: " + str(np.mean(aps[50][:self.prev_intro_cls])))
self._logger.info("Prev class AP75: " + str(np.mean(aps[75][:self.prev_intro_cls])))
self._logger.info("Prev class Precisions50: " + str(np.mean(precs[50][:self.prev_intro_cls])))
self._logger.info("Prev class Recall50: " + str(np.mean(recs[50][:self.prev_intro_cls])))
self._logger.info("Current class AP__: " + str(np.mean(avg_precs[self.prev_intro_cls:self.curr_intro_cls])))
self._logger.info("Current class AP50: " + str(np.mean(aps[50][self.prev_intro_cls:self.curr_intro_cls])))
self._logger.info("Current class AP75: " + str(np.mean(aps[75][self.prev_intro_cls:self.curr_intro_cls])))
# self._logger.info("Prev class AP75: " + str(np.mean(aps[75][:self.prev_intro_cls])))
self._logger.info("Known AP__: " + str(np.mean(avg_precs[:self.prev_intro_cls + self.curr_intro_cls])))
# self._logger.info("\nCurrent class AP__: " + str(np.mean(avg_precs[self.prev_intro_cls:self.curr_intro_cls])))
self._logger.info("Current class AP50: " + str(np.mean(aps[50][self.prev_intro_cls:self.prev_intro_cls + self.curr_intro_cls])))
self._logger.info("Current class Precisions50: " + str(np.mean(precs[50][self.prev_intro_cls:self.prev_intro_cls + self.curr_intro_cls])))
self._logger.info("Current class Recall50: " + str(np.mean(recs[50][self.prev_intro_cls:self.prev_intro_cls + self.curr_intro_cls])))
# self._logger.info("Current class AP75: " + str(np.mean(aps[75][self.prev_intro_cls:self.curr_intro_cls])))
# self._logger.info("\nKnown AP__: " + str(np.mean(avg_precs[:self.prev_intro_cls + self.curr_intro_cls])))
self._logger.info("Known AP50: " + str(np.mean(aps[50][:self.prev_intro_cls + self.curr_intro_cls])))
self._logger.info("Known AP75: " + str(np.mean(aps[75][:self.prev_intro_cls + self.curr_intro_cls])))
self._logger.info("Known Precisions50: " + str(np.mean(precs[50][:self.prev_intro_cls + self.curr_intro_cls])))
self._logger.info("Known Recall50: " + str(np.mean(recs[50][:self.prev_intro_cls + self.curr_intro_cls])))
# self._logger.info("Known AP75: " + str(np.mean(aps[75][:self.prev_intro_cls + self.curr_intro_cls])))
self._logger.info("Unknown AP__: " + str(avg_precs[-1]))
# self._logger.info("\nUnknown AP__: " + str(avg_precs[-1]))
self._logger.info("Unknown AP50: " + str(aps[50][-1]))
self._logger.info("Unknown AP75: " + str(aps[75][-1]))
self._logger.info("Unknown Precisions50: " + str(precs[50][-1]))
self._logger.info("Unknown Recall50: " + str(recs[50][-1]))
# self._logger.info("Unknown AP75: " + str(aps[75][-1]))
# self._logger.info("R__: " + str(['%.1f' % x for x in list(np.mean([x for _, x in recs.items()], axis=0))]))
# self._logger.info("R50: " + str(['%.1f' % x for x in recs[50]]))
@ -362,6 +419,10 @@ def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_me
nd = len(image_ids)
tp = np.zeros(nd)
fp = np.zeros(nd)
# if 'unknown' not in classname:
# return tp, fp, 0
for d in range(nd):
R = class_recs[image_ids[d]]
bb = BB[d, :].astype(float)
@ -407,6 +468,93 @@ def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_me
# avoid divide by zero in case the first detection matches a difficult
# ground truth
prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
# plot_pr_curve(prec, rec, classname+'.png')
ap = voc_ap(rec, prec, use_07_metric)
return rec, prec, ap
# print('tp: ' + str(tp[-1]))
# print('fp: ' + str(fp[-1]))
# print('tp: ')
# print(tp)
# print('fp: ')
# print(fp)
'''
Computing Open-Set Error (OSE)
==============================
OSE = # of unknown objects classified as known objects of class 'classname'
--------------------------------------------------------------------
# of unknown objects
'''
logger = logging.getLogger(__name__)
# Finding GT of unknown objects
unknown_class_recs = {}
n_unk = 0
for imagename in imagenames:
R = [obj for obj in recs[imagename] if obj["name"] == 'unknown']
bbox = np.array([x["bbox"] for x in R])
difficult = np.array([x["difficult"] for x in R]).astype(np.bool)
det = [False] * len(R)
n_unk = n_unk + sum(~difficult)
unknown_class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det}
if classname == 'unknown':
return rec, prec, ap, 0, n_unk
# Go down each detection and see if it has an overlap with an unknown object.
# If so, it is an unknown object that was classified as known.
is_unk = np.zeros(nd)
for d in range(nd):
R = unknown_class_recs[image_ids[d]]
bb = BB[d, :].astype(float)
ovmax = -np.inf
BBGT = R["bbox"].astype(float)
if BBGT.size > 0:
# compute overlaps
# intersection
ixmin = np.maximum(BBGT[:, 0], bb[0])
iymin = np.maximum(BBGT[:, 1], bb[1])
ixmax = np.minimum(BBGT[:, 2], bb[2])
iymax = np.minimum(BBGT[:, 3], bb[3])
iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
ih = np.maximum(iymax - iymin + 1.0, 0.0)
inters = iw * ih
# union
uni = (
(bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
+ (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
- inters
)
overlaps = inters / uni
ovmax = np.max(overlaps)
jmax = np.argmax(overlaps)
if ovmax > ovthresh:
is_unk[d] = 1.0
is_unk = np.sum(is_unk)
# OSE = is_unk / n_unk
# logger.info('Number of unknowns detected knowns (for class '+ classname + ') is ' + str(is_unk))
# logger.info("Num of unknown instances: " + str(n_unk))
# logger.info('OSE: ' + str(OSE))
return rec, prec, ap, is_unk, n_unk
def plot_pr_curve(precision, recall, filename, base_path='/home/fk1/workspace/OWOD/output/plots/'):
    """Save a step-style precision-recall curve as an image file.

    Args:
        precision: sequence of precision values (y axis).
        recall: sequence of recall values (x axis), same length as
            ``precision``.
        filename: name of the image file to write, e.g. ``'person.png'``.
        base_path: directory the image is written into. It is created if it
            does not exist yet.
    """
    if base_path:
        # savefig does not create missing parent directories.
        os.makedirs(base_path, exist_ok=True)
    fig, ax = plt.subplots()
    try:
        ax.step(recall, precision, color='r', alpha=0.99, where='post')
        ax.fill_between(recall, precision, alpha=0.2, color='b', step='post')
        # Configure this figure's axes directly instead of relying on
        # pyplot's implicit "current axes" state.
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_ylim([0.0, 1.05])
        ax.set_xlim([0.0, 1.0])
        fig.savefig(os.path.join(base_path, filename))
    finally:
        # Release the figure; otherwise one figure per call stays alive in
        # pyplot's registry when plotting a curve for every class.
        plt.close(fig)

View File

@ -2,6 +2,7 @@
import logging
from typing import Dict, Union
import torch
import os
import math
import shortuuid
from fvcore.nn import giou_loss, smooth_l1_loss
@ -414,6 +415,9 @@ class FastRCNNOutputLayers(nn.Module):
enable_clustering,
prev_intro_cls,
curr_intro_cls,
max_iterations,
output_dir,
feat_store_path,
num_classes: int,
test_score_thresh: float = 0.0,
test_nms_thresh: float = 0.5,
@ -472,9 +476,6 @@ class FastRCNNOutputLayers(nn.Module):
self.clustering_update_mu_iter = clustering_update_mu_iter
self.clustering_momentum = clustering_momentum
self.feature_store = Store(num_classes + 1, clustering_items_per_class)
self.means = [None for _ in range(num_classes + 1)]
self.hingeloss = nn.HingeEmbeddingLoss(2)
self.enable_clustering = enable_clustering
@ -484,6 +485,22 @@ class FastRCNNOutputLayers(nn.Module):
self.invalid_class_range = list(range(self.seen_classes, self.num_classes-1))
logging.getLogger(__name__).info("Invalid class range: " + str(self.invalid_class_range))
self.max_iterations = max_iterations
self.feature_store_is_stored = False
self.output_dir = output_dir
self.feat_store_path = feat_store_path
self.feature_store_save_loc = os.path.join(self.output_dir, self.feat_store_path, 'feat.pt')
if os.path.isfile(self.feature_store_save_loc):
self.feature_store = torch.load(self.feature_store_save_loc)
logging.getLogger(__name__).info('Loaded feature store from ' + self.feature_store_save_loc)
else:
logging.getLogger(__name__).info('Feature store not found in ' +
self.feature_store_save_loc + '. Creating new feature store.')
self.feature_store = Store(num_classes + 1, clustering_items_per_class)
self.means = [None for _ in range(num_classes + 1)]
# self.ae_model = AE(input_size, clustering_z_dimension)
# self.ae_model.apply(Xavier)
@ -508,7 +525,10 @@ class FastRCNNOutputLayers(nn.Module):
"clustering_z_dimension": cfg.OWOD.CLUSTERING.Z_DIMENSION,
"enable_clustering" : cfg.OWOD.ENABLE_CLUSTERING,
"prev_intro_cls" : cfg.OWOD.PREV_INTRODUCED_CLS,
"curr_intro_cls" : cfg.OWOD.CUR_INTRODUCED_CLS
"curr_intro_cls" : cfg.OWOD.CUR_INTRODUCED_CLS,
"max_iterations" : cfg.SOLVER.MAX_ITER,
"output_dir" : cfg.OUTPUT_DIR,
"feat_store_path" : cfg.OWOD.FEATURE_STORE_SAVE_PATH,
# fmt: on
}
@ -536,75 +556,16 @@ class FastRCNNOutputLayers(nn.Module):
gt_classes = torch.cat([p.gt_classes for p in proposals])
self.feature_store.add(features, gt_classes)
storage = get_event_storage()
if storage.iter == self.max_iterations and self.feature_store_is_stored is False:
logging.getLogger(__name__).info('Saving image store at iteration ' + str(storage.iter) + ' to ' + self.feature_store_save_loc)
torch.save(self.feature_store, self.feature_store_save_loc)
self.feature_store_is_stored = True
# self.feature_store.add(F.normalize(features, dim=0), gt_classes)
# self.feature_store.add(self.ae_model.encoder(features), gt_classes)
# def clstr_loss(self, input_features, proposals):
# """
# Get the foreground input_features, generate distributions for the class,
# get probability of each feature from each distribution;
# Compute loss: if belonging to a class -> likelihood should be higher
# else -> lower
# :param input_features:
# :param proposals:
# :return:
# """
# loss = 0
# gt_classes = torch.cat([p.gt_classes for p in proposals])
# mask = gt_classes != self.num_classes
# fg_features = input_features[mask]
# classes = gt_classes[mask]
# # fg_features = self.ae_model.encoder(fg_features)
#
# # Distribution per class
# log_prob = [None for _ in range(self.num_classes + 1)]
# # https://github.com/pytorch/pytorch/issues/23780
# for cls_index, mu in enumerate(self.means):
# if mu is not None:
# dist = Normal(loc=mu.cuda(), scale=torch.ones_like(mu.cuda()))
# log_prob[cls_index] = dist.log_prob(fg_features).mean(dim=1)
# # log_prob[cls_index] = torch.distributions.multivariate_normal. \
# # MultivariateNormal(mu.cuda(), torch.eye(len(mu)).cuda()).log_prob(fg_features)
# # MultivariateNormal(mu, torch.eye(len(mu))).log_prob(fg_features.cpu())
# # MultivariateNormal(mu[:2], torch.eye(len(mu[:2]))).log_prob(fg_features[:,:2].cpu())
# else:
# log_prob[cls_index] = torch.zeros((len(fg_features))).cuda()
#
# log_prob = torch.stack(log_prob).T # num_of_fg_proposals x num_of_classes
# for i, p in enumerate(log_prob):
# weight = torch.ones_like(p) * -1
# weight[classes[i]] = 1
# p = p * weight
# loss += p.mean()
# return loss
# def clstr_loss_l2(self, input_features, proposals):
# """
# Get the foreground input_features, generate distributions for the class,
# get probability of each feature from each distribution;
# Compute loss: if belonging to a class -> likelihood should be higher
# else -> lower
# :param input_features:
# :param proposals:
# :return:
# """
# loss = 0
# gt_classes = torch.cat([p.gt_classes for p in proposals])
# mask = gt_classes != self.num_classes
# fg_features = input_features[mask]
# classes = gt_classes[mask]
# fg_features = self.ae_model.encoder(fg_features)
#
# for index, feature in enumerate(fg_features):
# for cls_index, mu in enumerate(self.means):
# if mu is not None and feature is not None:
# mu = mu.cuda()
# if classes[index] == cls_index:
# loss -= F.mse_loss(feature, mu)
# else:
# loss += F.mse_loss(feature, mu)
#
# return loss
def clstr_loss_l2_cdist(self, input_features, proposals):
"""
@ -822,3 +783,70 @@ class FastRCNNOutputLayers(nn.Module):
num_inst_per_image = [len(p) for p in proposals]
probs = F.softmax(scores, dim=-1)
return probs.split(num_inst_per_image, dim=0)
# def clstr_loss(self, input_features, proposals):
# """
# Get the foreground input_features, generate distributions for the class,
# get probability of each feature from each distribution;
# Compute loss: if belonging to a class -> likelihood should be higher
# else -> lower
# :param input_features:
# :param proposals:
# :return:
# """
# loss = 0
# gt_classes = torch.cat([p.gt_classes for p in proposals])
# mask = gt_classes != self.num_classes
# fg_features = input_features[mask]
# classes = gt_classes[mask]
# # fg_features = self.ae_model.encoder(fg_features)
#
# # Distribution per class
# log_prob = [None for _ in range(self.num_classes + 1)]
# # https://github.com/pytorch/pytorch/issues/23780
# for cls_index, mu in enumerate(self.means):
# if mu is not None:
# dist = Normal(loc=mu.cuda(), scale=torch.ones_like(mu.cuda()))
# log_prob[cls_index] = dist.log_prob(fg_features).mean(dim=1)
# # log_prob[cls_index] = torch.distributions.multivariate_normal. \
# # MultivariateNormal(mu.cuda(), torch.eye(len(mu)).cuda()).log_prob(fg_features)
# # MultivariateNormal(mu, torch.eye(len(mu))).log_prob(fg_features.cpu())
# # MultivariateNormal(mu[:2], torch.eye(len(mu[:2]))).log_prob(fg_features[:,:2].cpu())
# else:
# log_prob[cls_index] = torch.zeros((len(fg_features))).cuda()
#
# log_prob = torch.stack(log_prob).T # num_of_fg_proposals x num_of_classes
# for i, p in enumerate(log_prob):
# weight = torch.ones_like(p) * -1
# weight[classes[i]] = 1
# p = p * weight
# loss += p.mean()
# return loss
# def clstr_loss_l2(self, input_features, proposals):
# """
# Get the foreground input_features, generate distributions for the class,
# get probability of each feature from each distribution;
# Compute loss: if belonging to a class -> likelihood should be higher
# else -> lower
# :param input_features:
# :param proposals:
# :return:
# """
# loss = 0
# gt_classes = torch.cat([p.gt_classes for p in proposals])
# mask = gt_classes != self.num_classes
# fg_features = input_features[mask]
# classes = gt_classes[mask]
# fg_features = self.ae_model.encoder(fg_features)
#
# for index, feature in enumerate(fg_features):
# for cls_index, mu in enumerate(self.means):
# if mu is not None and feature is not None:
# mu = mu.cuda()
# if classes[index] == cls_index:
# loss -= F.mse_loss(feature, mu)
# else:
# loss += F.mse_loss(feature, mu)
#
# return loss