Add files via upload

main
RE-OWOD 2022-01-11 11:03:19 +08:00 committed by GitHub
parent fe28d217de
commit 6257e7bcce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 160 additions and 290 deletions

View File

@ -1,18 +1,15 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from typing import Dict, Union
import numpy as np
import torch
import os
import math
import shortuuid
import time
import pickle
from fvcore.nn import giou_loss, smooth_l1_loss
from torch import nn
from torch.nn import functional as F
from torch.distributions.normal import Normal
import sys
import detectron2.utils.comm as comm
from detectron2.config import configurable
@ -53,10 +50,7 @@ Naming convention:
"""
def fast_rcnn_inference(
boxes, scores, image_shapes, predictions, score_thresh, nms_thresh, topk_per_image, calibration, unk_thresh
):
def fast_rcnn_inference(boxes, scores, image_shapes, predictions, score_thresh, nms_thresh, topk_per_image):
"""
Call `fast_rcnn_inference_single_image` for all images.
@ -84,7 +78,7 @@ def fast_rcnn_inference(
"""
result_per_image = [
fast_rcnn_inference_single_image(
boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image, prediction, calibration, unk_thresh
boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image, prediction
)
for scores_per_image, boxes_per_image, image_shape, prediction in zip(scores, boxes, image_shapes, predictions)
]
@ -92,7 +86,7 @@ def fast_rcnn_inference(
def fast_rcnn_inference_single_image(
boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image, prediction, calibration, unk_thresh
boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image, prediction,calibration, unk_thresh
):
"""
Single-image inference. Return bounding-box detection results by thresholding
@ -105,7 +99,7 @@ def fast_rcnn_inference_single_image(
Returns:
Same as `fast_rcnn_inference`, but for only one image.
"""
# loading calibration's pickle file
if calibration>0:
pickle_addr = "/home/wangduorui/OWOD-zxw/analyze/1122/t2_ori_set_train_scores_cali_0" + str(10 * calibration) + ".pickle"
with open(pickle_addr, "rb") as file:
@ -123,7 +117,7 @@ def fast_rcnn_inference_single_image(
class_list.append(torch.mean(per_class))
# var
class_var_list.append((torch.sqrt(torch.var(per_class))) / (np.sqrt(class_num)))
logits = prediction
valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
if not valid_mask.all():
@ -159,30 +153,10 @@ def fast_rcnn_inference_single_image(
boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]
logits = logits[keep]
# calibration for unknown
if calibration>0:
if len(filter_inds[:, 1]) > 0:
score_back = F.softmax(logits, dim=-1)
# TODO: check difference between score_back(multi) and scores(1)
new_pred_classes = filter_inds[:, 1]
for pred_i in range(len(score_back)):
c = 0
for i in range(20):
if score_back[pred_i, i] < class_list[i] * unk_thresh:
# if score_back[pred_i, i] < (class_list[i] - var_times * class_var_list[i]) * unk_thresh:
c += 1
logits[pred_i, i] = 0
if c == 20:
new_pred_classes[pred_i] = 80
else:
new_pred_classes = filter_inds[:, 1]
result = Instances(image_shape)
result.pred_boxes = Boxes(boxes)
result.scores = scores
# result.pred_classes = filter_inds[:, 1]
result.pred_classes = new_pred_classes
result.pred_classes = filter_inds[:, 1]
result.logits = logits
return result, filter_inds[:, 0]
@ -290,11 +264,9 @@ class FastRCNNOutputs:
self._log_accuracy()
self.pred_class_logits[:, self.invalid_class_range] = -10e10
# self.log_logits(self.pred_class_logits, self.gt_classes)
mean_loss = softmax_loss(self.pred_class_logits, self.gt_classes)
return mean_loss
# return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean")
# print("self.gt_classes:",self.gt_classes)
# sys.exit()
return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean")
def log_logits(self, logits, cls):
data = (logits, cls)
@ -323,7 +295,7 @@ class FastRCNNOutputs:
# Empty fg_inds produces a valid loss of zero as long as the size_average
# arg to smooth_l1_loss is False (otherwise it uses torch.mean internally
# and would produce a nan loss).
fg_inds = nonzero_tuple((self.gt_classes >= 0) & (self.gt_classes < bg_class_ind))[0]
fg_inds = nonzero_tuple((self.gt_classes >= 0) & (self.gt_classes < bg_class_ind))[0] # -1 不算unk的loss
if cls_agnostic_bbox_reg:
# pred_proposal_deltas only corresponds to foreground class for agnostic
gt_class_cols = torch.arange(box_dim, device=device)
@ -470,8 +442,6 @@ class FastRCNNOutputLayers(nn.Module):
output_dir,
feat_store_path,
margin,
calibration,
unk_thresh,
num_classes: int,
test_score_thresh: float = 0.0,
test_nms_thresh: float = 0.5,
@ -554,28 +524,6 @@ class FastRCNNOutputLayers(nn.Module):
self.feature_store = Store(num_classes + 1, clustering_items_per_class)
self.means = [None for _ in range(num_classes + 1)]
self.margin = margin
# post treatment para
self.calibration = calibration
self.unk_thresh = unk_thresh
# self.feat_store_path1 = "./output/1129_train/t1/feature_store/feat.pt"
# data = torch.load(self.feat_store_path1)
# items=data.retrieve(-1)
# for index, item in enumerate(items):
# if len(item) == 0:
# self.means[index] = None
# else:
# mu = torch.tensor(item).mean(dim=0)
# self.means[index] = mu
# self.all_means1 = self.means
# for item in self.all_means1:
# if item != None:
# length = item.shape
# break
# for i, item in enumerate(self.all_means1):
# if item == None:
# self.all_means1[i] = torch.zeros((length))
# self.ae_model = AE(input_size, clustering_z_dimension)
# self.ae_model.apply(Xavier)
@ -606,8 +554,6 @@ class FastRCNNOutputLayers(nn.Module):
"output_dir" : cfg.OUTPUT_DIR,
"feat_store_path" : cfg.OWOD.FEATURE_STORE_SAVE_PATH,
"margin" : cfg.OWOD.CLUSTERING.MARGIN,
"calibration" : cfg.OWOD.CLUSTERING.CLIBARATION,
"unk_thresh" : cfg.OWOD.CLUSTERING.UNK_THRESH,
# fmt: on
}
@ -626,31 +572,10 @@ class FastRCNNOutputLayers(nn.Module):
"""
if x.dim() > 2:
x = torch.flatten(x, start_dim=1)
# scores = self.cls_score(x)
all_means = self.means
for item in all_means:
if item != None:
length = item.shape
break
for i, item in enumerate(all_means):
if item == None:
all_means[i] = torch.zeros((length))
# dist = Distance(x, all_means)
# scores = -dist / 0.1
scores = self.solve(x, torch.stack(all_means).cuda())
scores = self.cls_score(x)
proposal_deltas = self.bbox_pred(x)
return scores, proposal_deltas
def solve(self, features, means):
    """Turn feature-to-mean distances into a probability distribution.

    For every feature row the weight of mean ``k`` is
    ``exp(-||feature - means[k]||_2)``, normalised over ``k``.  This is
    evaluated as ``softmax(-distance)``, which is algebraically the same
    as ``exp(-d) / sum(exp(-d))`` but stays finite when every distance is
    large; the naive form underflows each ``exp`` to 0 and returns
    0 / 0 = NaN in that case.

    Args:
        features: tensor of feature vectors; a new axis is inserted at
            position 1 so it broadcasts against ``means``
            (assumes shape ``(N, D)`` -- TODO confirm with caller).
        means: tensor of mean vectors; a new axis is inserted at
            position 0 (assumes shape ``(K, D)`` -- TODO confirm).

    Returns:
        Tensor whose last axis indexes the means and sums to 1
        (``(N, K)`` under the shape assumptions above).
    """
    # Pairwise differences via broadcasting, then the L2 norm over the
    # feature axis gives one distance per (feature, mean) pair.
    pairwise_diff = features.unsqueeze(1) - means.unsqueeze(0)
    distances = torch.norm(pairwise_diff, dim=-1)
    # softmax subtracts the row maximum internally, avoiding underflow.
    return torch.softmax(-distances, dim=-1)
def update_feature_store(self, features, proposals):
# cat(..., dim=0) concatenates over all images in the batch
gt_classes = torch.cat([p.gt_classes for p in proposals])
@ -662,43 +587,11 @@ class FastRCNNOutputLayers(nn.Module):
logging.getLogger(__name__).info('Saving image store at iteration ' + str(storage.iter) + ' to ' + self.feature_store_save_loc)
torch.save(self.feature_store, self.feature_store_save_loc)
self.feature_store_is_stored = True
# print("11111111111111111111111111111111111")
# self.feature_store.add(F.normalize(features, dim=0), gt_classes)
# self.feature_store.add(self.ae_model.encoder(features), gt_classes)
def updatePrototype(self):
    """Refresh the per-class mean vectors kept in ``self.means``.

    Behaviour depends on the current iteration read from the event storage:

    * exactly at ``self.clustering_start_iter``: every entry of
      ``self.means`` is overwritten with the mean (over dim 0) of the
      items retrieved from ``self.feature_store`` for that index, or
      ``None`` when that index has no items;
    * after ``self.clustering_start_iter``, on iterations divisible by
      ``self.clustering_update_mu_iter``: fresh means are computed the
      same way and blended into the existing ones with
      ``self.clustering_momentum``; entries where either the old or the
      fresh mean is ``None`` are left untouched;
    * otherwise nothing is changed.

    Returns:
        ``self.means`` (the same list object that is updated in place).
    """
    current_iter = get_event_storage().iter
    start_iter = self.clustering_start_iter

    if current_iter == start_iter:
        # First clustering iteration: initialise the means directly.
        for cls_idx, cls_items in enumerate(self.feature_store.retrieve(-1)):
            self.means[cls_idx] = (
                None if len(cls_items) == 0 else torch.tensor(cls_items).mean(dim=0)
            )
        return self.means

    if current_iter > start_iter and current_iter % self.clustering_update_mu_iter == 0:
        stored = self.feature_store.retrieve(-1)
        # One slot per class plus one extra (num_classes + 1 entries).
        fresh_means = [None] * (self.num_classes + 1)
        for cls_idx, cls_items in enumerate(stored):
            fresh_means[cls_idx] = (
                None if len(cls_items) == 0 else torch.tensor(cls_items).mean(dim=0)
            )

        # Momentum blend; skip entries lacking an old or a new mean.
        momentum = self.clustering_momentum
        for cls_idx, old_mean in enumerate(self.means):
            fresh = fresh_means[cls_idx]
            if old_mean is None or fresh is None:
                continue
            self.means[cls_idx] = momentum * old_mean + (1 - momentum) * fresh

    return self.means
def clstr_loss_l2_cdist(self, input_features, proposals):
"""
Get the foreground input_features, generate distributions for the class,
@ -716,19 +609,17 @@ class FastRCNNOutputLayers(nn.Module):
# fg_features = F.normalize(fg_features, dim=0)
# fg_features = self.ae_model.encoder(fg_features)
self.all_means = self.means
for item in self.all_means:
all_means = self.means
for item in all_means:
if item != None:
length = item.shape
break
for i, item in enumerate(self.all_means):
for i, item in enumerate(all_means):
if item == None:
self.all_means[i] = torch.zeros((length))
all_means[i] = torch.zeros((length))
distances = torch.cdist(fg_features, torch.stack(self.all_means).cuda(), p=self.margin)
distances = torch.cdist(fg_features, torch.stack(all_means).cuda(), p=self.margin)
labels = []
for index, feature in enumerate(fg_features):
@ -751,36 +642,35 @@ class FastRCNNOutputLayers(nn.Module):
storage = get_event_storage()
c_loss = 0
# self.means=all_means
# if storage.iter == self.clustering_start_iter:
# items = self.feature_store.retrieve(-1)
# for index, item in enumerate(items):
# if len(item) == 0:
# self.means[index] = None
# else:
# mu = torch.tensor(item).mean(dim=0)
# self.means[index] = mu
# c_loss = self.clstr_loss_l2_cdist(input_features, proposals)
# # Freeze the parameters when clustering starts
# # for param in self.ae_model.parameters():
# # param.requires_grad = False
# elif storage.iter > self.clustering_start_iter:
# if storage.iter % self.clustering_update_mu_iter == 0:
# # Compute new MUs
# items = self.feature_store.retrieve(-1)
# new_means = [None for _ in range(self.num_classes + 1)]
# for index, item in enumerate(items):
# if len(item) == 0:
# new_means[index] = None
# else:
# new_means[index] = torch.tensor(item).mean(dim=0)
# # Update the MUs
# for i, mean in enumerate(self.means):
# if(mean) is not None and new_means[i] is not None:
# self.means[i] = self.clustering_momentum * mean + \
# (1 - self.clustering_momentum) * new_means[i]
if storage.iter == self.clustering_start_iter:
items = self.feature_store.retrieve(-1)
for index, item in enumerate(items):
if len(item) == 0:
self.means[index] = None
else:
mu = torch.tensor(item).mean(dim=0)
self.means[index] = mu
c_loss = self.clstr_loss_l2_cdist(input_features, proposals)
# Freeze the parameters when clustering starts
# for param in self.ae_model.parameters():
# param.requires_grad = False
elif storage.iter > self.clustering_start_iter:
if storage.iter % self.clustering_update_mu_iter == 0:
# Compute new MUs
items = self.feature_store.retrieve(-1)
new_means = [None for _ in range(self.num_classes + 1)]
for index, item in enumerate(items):
if len(item) == 0:
new_means[index] = None
else:
new_means[index] = torch.tensor(item).mean(dim=0)
# Update the MUs
for i, mean in enumerate(self.means):
if(mean) is not None and new_means[i] is not None:
self.means[i] = self.clustering_momentum * mean + \
(1 - self.clustering_momentum) * new_means[i]
c_loss = self.clstr_loss_l2_cdist(input_features, proposals)
c_loss = self.clstr_loss_l2_cdist(input_features, proposals)
return c_loss
# def get_ae_loss(self, input_features):
@ -840,8 +730,6 @@ class FastRCNNOutputLayers(nn.Module):
self.test_score_thresh,
self.test_nms_thresh,
self.test_topk_per_image,
self.calibration,
self.unk_thresh,
)
def predict_boxes_for_gt_classes(self, predictions, proposals):
@ -920,103 +808,3 @@ class FastRCNNOutputLayers(nn.Module):
num_inst_per_image = [len(p) for p in proposals]
probs = F.softmax(scores, dim=-1)
return probs.split(num_inst_per_image, dim=0)
# def clstr_loss(self, input_features, proposals):
# """
# Get the foreground input_features, generate distributions for the class,
# get probability of each feature from each distribution;
# Compute loss: if belonging to a class -> likelihood should be higher
# else -> lower
# :param input_features:
# :param proposals:
# :return:
# """
# loss = 0
# gt_classes = torch.cat([p.gt_classes for p in proposals])
# mask = gt_classes != self.num_classes
# fg_features = input_features[mask]
# classes = gt_classes[mask]
# # fg_features = self.ae_model.encoder(fg_features)
#
# # Distribution per class
# log_prob = [None for _ in range(self.num_classes + 1)]
# # https://github.com/pytorch/pytorch/issues/23780
# for cls_index, mu in enumerate(self.means):
# if mu is not None:
# dist = Normal(loc=mu.cuda(), scale=torch.ones_like(mu.cuda()))
# log_prob[cls_index] = dist.log_prob(fg_features).mean(dim=1)
# # log_prob[cls_index] = torch.distributions.multivariate_normal. \
# # MultivariateNormal(mu.cuda(), torch.eye(len(mu)).cuda()).log_prob(fg_features)
# # MultivariateNormal(mu, torch.eye(len(mu))).log_prob(fg_features.cpu())
# # MultivariateNormal(mu[:2], torch.eye(len(mu[:2]))).log_prob(fg_features[:,:2].cpu())
# else:
# log_prob[cls_index] = torch.zeros((len(fg_features))).cuda()
#
# log_prob = torch.stack(log_prob).T # num_of_fg_proposals x num_of_classes
# for i, p in enumerate(log_prob):
# weight = torch.ones_like(p) * -1
# weight[classes[i]] = 1
# p = p * weight
# loss += p.mean()
# return loss
# def clstr_loss_l2(self, input_features, proposals):
# """
# Get the foreground input_features, generate distributions for the class,
# get probability of each feature from each distribution;
# Compute loss: if belonging to a class -> likelihood should be higher
# else -> lower
# :param input_features:
# :param proposals:
# :return:
# """
# loss = 0
# gt_classes = torch.cat([p.gt_classes for p in proposals])
# mask = gt_classes != self.num_classes
# fg_features = input_features[mask]
# classes = gt_classes[mask]
# fg_features = self.ae_model.encoder(fg_features)
#
# for index, feature in enumerate(fg_features):
# for cls_index, mu in enumerate(self.means):
# if mu is not None and feature is not None:
# mu = mu.cuda()
# if classes[index] == cls_index:
# loss -= F.mse_loss(feature, mu)
# else:
# loss += F.mse_loss(feature, mu)
#
# return loss
def Distance(features, centers):
    """Return squared Euclidean distances between features and centers.

    Uses the expansion ``|f - c|^2 = |f|^2 - 2 f.c + |c|^2`` so the whole
    matrix is produced with a single matrix multiplication.

    Args:
        features: 2-D tensor, one feature vector per row (``(N, D)``).
        centers: 2-D tensor, one center per row (``(K, D)``).

    Returns:
        ``(N, K)`` tensor whose ``[i, j]`` entry is the squared distance
        between ``features[i]`` and ``centers[j]``.  Tiny negative values
        are possible from floating-point cancellation.
    """
    # torch.sum takes ``keepdim`` (``keep_dims`` is the TensorFlow spelling
    # and raises TypeError in PyTorch).
    f_2 = torch.sum(torch.pow(features, 2), dim=1, keepdim=True)  # (N, 1)
    c_2 = torch.sum(torch.pow(centers, 2), dim=1, keepdim=True)   # (K, 1)
    # centers must be transposed so the product is (N, D) x (D, K).
    cross = torch.matmul(features, torch.transpose(centers, 1, 0))
    dist = f_2 - 2 * cross + torch.transpose(c_2, 1, 0)
    return dist
def softmax_loss(logits, labels):
    """Return the mean negative log-likelihood of ``labels``, scaled by 0.1.

    Args:
        logits: 2-D tensor of unnormalised class scores, one row per
            sample (the ``gather`` along dim 1 below requires 2-D input).
        labels: integer (int64) tensor with one class index per sample;
            it is reshaped to a column before gathering.

    Returns:
        Scalar tensor: ``0.1 * mean(-log_softmax(logits)[i, labels[i]])``.
    """
    # Name the class axis explicitly: calling log_softmax without ``dim``
    # relies on a deprecated implicit choice and warns on every call.
    logp = torch.nn.functional.log_softmax(logits, dim=1)
    label_column = labels.reshape(-1, 1)
    # Pick, for each row, the log-probability of its target class.
    logpy = torch.gather(logp, 1, label_column)
    loss = -logpy.mean()
    return loss * 0.1
# def softmax_loss(self,logits, labels):
# # labels = tf.to_int32(labels)
# logp = torch.nn.functional.log_softmax(logits)
# # print(logp.shape)
# # print(labels.shape)
# # sys.exit()
# labels=labels.reshape(-1,1)
# logpy = torch.gather(logp, 1, labels)
# loss = -(logpy).mean()

View File

@ -6,11 +6,12 @@ import heapq
import os
import shortuuid
import operator
import shortuuid
import sys
import cv2
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
import time
# from ..drawBoxes import draw_boxes
from detectron2.config import configurable
from detectron2.layers import ShapeSpec, nonzero_tuple
@ -149,6 +150,7 @@ class ROIHeads(torch.nn.Module):
batch_size_per_image,
positive_fraction,
proposal_matcher,
proposal_matcher_unk,
enable_thresold_autolabelling,
unk_k,
proposal_append_gt=True,
@ -169,6 +171,7 @@ class ROIHeads(torch.nn.Module):
self.positive_fraction = positive_fraction
self.num_classes = num_classes
self.proposal_matcher = proposal_matcher
self.proposal_matcher_unk = proposal_matcher_unk
self.proposal_append_gt = proposal_append_gt
self.enable_thresold_autolabelling = enable_thresold_autolabelling
self.unk_k = unk_k
@ -186,6 +189,12 @@ class ROIHeads(torch.nn.Module):
cfg.MODEL.ROI_HEADS.IOU_LABELS,
allow_low_quality_matches=False,
),
"proposal_matcher_unk": Matcher(
# cfg.MODEL.ROI_HEADS.IOU_THRESHOLDS_UNK,
[0.8],
cfg.MODEL.ROI_HEADS.IOU_LABELS,
allow_low_quality_matches=False,
),
"enable_thresold_autolabelling": cfg.OWOD.ENABLE_THRESHOLD_AUTOLABEL_UNK,
"unk_k": cfg.OWOD.NUM_UNK_PER_IMAGE,
}
@ -229,24 +238,11 @@ class ROIHeads(torch.nn.Module):
sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
gt_classes_ss = gt_classes[sampled_idxs]
if self.enable_thresold_autolabelling:
matched_labels_ss = matched_labels[sampled_idxs]
pred_objectness_score_ss = objectness_logits[sampled_idxs]
# 1) Remove FG objectness score. 2) Sort and select top k. 3) Build and apply mask.
mask = torch.zeros((pred_objectness_score_ss.shape), dtype=torch.bool)
pred_objectness_score_ss[matched_labels_ss != 0] = -1
sorted_indices = list(zip(
*heapq.nlargest(self.unk_k, enumerate(pred_objectness_score_ss), key=operator.itemgetter(1))))[0]
for index in sorted_indices:
mask[index] = True
gt_classes_ss[mask] = self.num_classes - 1
return sampled_idxs, gt_classes_ss
@torch.no_grad()
def label_and_sample_proposals(
self, proposals: List[Instances], targets: List[Instances]
self, proposals: List[Instances], targets: List[Instances], image_id = None, ori_image = None
) -> List[Instances]:
"""
Prepare some proposals to be used to train the ROI heads.
@ -287,10 +283,13 @@ class ROIHeads(torch.nn.Module):
proposals = add_ground_truth_to_proposals(gt_boxes, proposals)
proposals_with_gt = []
unk_sel_gt = []
num_fg_samples = []
num_bg_samples = []
for proposals_per_image, targets_per_image in zip(proposals, targets):
for proposals_per_image, targets_per_image,image_id_i,ori_image_i\
in zip(proposals, targets, image_id, ori_image):
height_new, width_new = proposals_per_image.image_size
has_gt = len(targets_per_image) > 0
match_quality_matrix = pairwise_iou(
targets_per_image.gt_boxes, proposals_per_image.proposal_boxes
@ -299,6 +298,89 @@ class ROIHeads(torch.nn.Module):
sampled_idxs, gt_classes = self._sample_proposals(
matched_idxs, matched_labels, targets_per_image.gt_classes, proposals_per_image.objectness_logits
)
del match_quality_matrix
gt_flag = False
unk_flag = False
storage = get_event_storage()
if self.enable_thresold_autolabelling and storage.iter > 50000:
matched_labels_ss = matched_labels[sampled_idxs]
pred_objectness_score_ss = proposals_per_image.objectness_logits[sampled_idxs]
pred_objectness_score_ss[matched_labels_ss != 0] = -1
sorted_indices = list(zip(
*heapq.nlargest(50, enumerate(pred_objectness_score_ss), key=operator.itemgetter(1))))[0]
mask = torch.zeros((pred_objectness_score_ss.shape), dtype=torch.bool)
new_flag = True
for index in sorted_indices:
if new_flag:
auotolabel_boxes = proposals_per_image.proposal_boxes[sampled_idxs[index].item()]
autolabel_score = proposals_per_image.objectness_logits[sampled_idxs[index].item()].view(1,-1)
new_flag = False
else:
box_i = proposals_per_image.proposal_boxes[sampled_idxs[index].item()]
score_i = proposals_per_image.objectness_logits[sampled_idxs[index].item()].view(1,-1)
auotolabel_boxes = Boxes.cat([auotolabel_boxes, box_i])
autolabel_score = torch.cat([autolabel_score,score_i],1)
obj_save_path = "../score_store/" + image_id_i + ".jpg"+".pickle"
obj_score_save = torch.load(obj_save_path)
height_ori, width_ori = obj_score_save['image_size']
obj_score_boxes = obj_score_save['obj_boxes']
if len(obj_score_boxes):
obj_boxes_sel = obj_score_boxes[:50,:4]
obj_boxes_sel[:,0] = obj_boxes_sel[:,0] * (width_new * 1.0 / width_ori)
obj_boxes_sel[:,1] = obj_boxes_sel[:,1] * (height_new * 1.0 / height_ori)
obj_boxes_sel[:,2] = obj_boxes_sel[:,2] * (width_new * 1.0 / width_ori)
obj_boxes_sel[:,3] = obj_boxes_sel[:,3] * (height_new * 1.0 / height_ori)
obj_boxes = Boxes(torch.Tensor(obj_boxes_sel).cuda())
area_new = width_new * height_new
area_mask = obj_boxes.area() / area_new
area_mask = area_mask < 0.8
unk_match_matrix = pairwise_iou(obj_boxes, auotolabel_boxes)
unk_match_matrix[unk_match_matrix < 0.9] = 0 # 0.7
score_matrix = torch.mm(area_mask.view(-1,1).float(),autolabel_score.view(1,-1))
score_matrix = torch.mul(score_matrix, unk_match_matrix)
score_matrix, _ = torch.max(score_matrix, 0)
_, unk_max_index = torch.max(score_matrix, 0)
unk_obj_index = torch.nonzero(score_matrix).cpu()
del unk_match_matrix
if len(unk_obj_index): # pseudo
gt_flag = True
unk_instances_gt = Instances(proposals_per_image.image_size)
unk_box = auotolabel_boxes[unk_max_index.item()]
unk_instances_gt.gt_boxes = unk_box
unk_instances_gt.gt_classes = torch.Tensor([80]).long().cuda()
targets_per_image = Instances.cat([targets_per_image, unk_instances_gt])
match_quality_matrix = pairwise_iou(
unk_instances_gt.gt_boxes, proposals_per_image.proposal_boxes[sampled_idxs]
)
_, matched_labels_unk = self.proposal_matcher_unk(match_quality_matrix)
del match_quality_matrix
matched_unk_mask = matched_labels_unk == 1
matched_unk_mask_idx = torch.nonzero(matched_unk_mask)
for index in matched_unk_mask_idx:
if sampled_idxs[index] < 100:
mask[index] = True
unk_match_matrix = pairwise_iou(obj_boxes, auotolabel_boxes)
unk_match_matrix[unk_match_matrix < 0.7] = 0
score_matrix = torch.mm(area_mask.view(-1,1).float(),autolabel_score.view(1,-1))
score_matrix = torch.mul(score_matrix, unk_match_matrix)
score_matrix,_ = torch.max(score_matrix, 0)
del unk_match_matrix
unk_obj_index = torch.nonzero(score_matrix).cpu()
score_matrix = score_matrix[score_matrix > 0]
if len(unk_obj_index):
for idx in unk_obj_index:
mask[sorted_indices[idx]] = True
gt_classes[mask] = 80
# Set target attributes of the sampled proposals:
proposals_per_image = proposals_per_image[sampled_idxs]
@ -323,14 +405,20 @@ class ROIHeads(torch.nn.Module):
num_bg_samples.append((gt_classes == self.num_classes).sum().item())
num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1])
if gt_flag:
unk_sel_gt.append(unk_instances_gt.gt_boxes)
else:
unk_gt_boxes = []
unk_sel_gt.append(unk_gt_boxes)
proposals_with_gt.append(proposals_per_image)
# Log the number of fg/bg samples that are selected for training ROI heads
storage = get_event_storage()
# storage = get_event_storage()
storage.put_scalar("roi_head/num_fg_samples", np.mean(num_fg_samples))
storage.put_scalar("roi_head/num_bg_samples", np.mean(num_bg_samples))
return proposals_with_gt
return proposals_with_gt, unk_sel_gt
def forward(
self,
@ -455,7 +543,7 @@ class Res5ROIHeads(ROIHeads):
location = os.path.join(self.energy_save_path, shortuuid.uuid() + '.pkl')
torch.save(data, location)
def forward(self, images, features, proposals, targets=None):
def forward(self, images, features, proposals, targets=None, image_id=None, ori_image=None):
"""
See :meth:`ROIHeads.forward`.
"""
@ -463,7 +551,7 @@ class Res5ROIHeads(ROIHeads):
if self.training:
assert targets
proposals = self.label_and_sample_proposals(proposals, targets)
proposals, unk_sel_gt = self.label_and_sample_proposals(proposals, targets, image_id, ori_image)
del targets
proposal_boxes = [x.proposal_boxes for x in proposals]
@ -471,19 +559,13 @@ class Res5ROIHeads(ROIHeads):
[features[f] for f in self.in_features], proposal_boxes
)
input_features = box_features.mean(dim=[2, 3])
if self.training:
# self.log_features(input_features, proposals)
if self.enable_clustering:
# print("11111111111111112222222222222222222")
self.box_predictor.update_feature_store(input_features, proposals)
self.box_predictor.updatePrototype()
predictions = self.box_predictor(input_features)
if self.training:
# self.log_features(input_features, proposals)
# if self.enable_clustering:
# self.box_predictor.update_feature_store(input_features, proposals)
# del features
if self.enable_clustering:
self.box_predictor.update_feature_store(input_features, proposals)
del features
if self.compute_energy_flag:
self.compute_energy(predictions, proposals)
losses = self.box_predictor.losses(predictions, proposals, input_features)
@ -498,11 +580,11 @@ class Res5ROIHeads(ROIHeads):
mask_features = box_features[torch.cat(fg_selection_masks, dim=0)]
del box_features
losses.update(self.mask_head(mask_features, proposals))
return [], losses
return [], losses, unk_sel_gt
else:
pred_instances, _ = self.box_predictor.inference(predictions, proposals)
pred_instances = self.forward_with_given_boxes(features, pred_instances)
return pred_instances, {}
return pred_instances, {}, []
def forward_with_given_boxes(self, features, instances):
"""