diff --git a/detectron2/modeling/roi_heads/fast_rcnn.py b/detectron2/modeling/roi_heads/fast_rcnn.py
index 961ada6..0a25c60 100644
--- a/detectron2/modeling/roi_heads/fast_rcnn.py
+++ b/detectron2/modeling/roi_heads/fast_rcnn.py
@@ -1,18 +1,15 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import logging
 from typing import Dict, Union
-
 import numpy as np
 import torch
 import os
 import math
 import shortuuid
-import time
-import pickle
 from fvcore.nn import giou_loss, smooth_l1_loss
 from torch import nn
 from torch.nn import functional as F
 from torch.distributions.normal import Normal
 
 import detectron2.utils.comm as comm
 from detectron2.config import configurable
@@ -53,10 +50,7 @@
 
 Naming convention:
 """
-
-def fast_rcnn_inference(
-    boxes, scores, image_shapes, predictions, score_thresh, nms_thresh, topk_per_image, calibration, unk_thresh
-):
+def fast_rcnn_inference(boxes, scores, image_shapes, predictions, score_thresh, nms_thresh, topk_per_image):
     """
     Call `fast_rcnn_inference_single_image` for all images.
 
@@ -84,7 +78,7 @@ def fast_rcnn_inference(
    """
    result_per_image = [
        fast_rcnn_inference_single_image(
-            boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image, prediction, calibration, unk_thresh
+            boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image, prediction
        )
        for scores_per_image, boxes_per_image, image_shape, prediction in zip(scores, boxes, image_shapes, predictions)
    ]
@@ -92,7 +86,7 @@
 
 
 def fast_rcnn_inference_single_image(
-    boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image, prediction, calibration, unk_thresh
+    boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image, prediction, calibration=0, unk_thresh=1.0  # defaults added: fast_rcnn_inference above now passes only the first seven args
 ):
    """
    Single-image inference. Return bounding-box detection results by thresholding
@@ -105,7 +99,7 @@
    Returns:
        Same as `fast_rcnn_inference`, but for only one image.
""" - # loading calibration's pickle file + if calibration>0: pickle_addr = "/home/wangduorui/OWOD-zxw/analyze/1122/t2_ori_set_train_scores_cali_0" + str(10 * calibration) + ".pickle" with open(pickle_addr, "rb") as file: @@ -123,7 +117,7 @@ def fast_rcnn_inference_single_image( class_list.append(torch.mean(per_class)) # var class_var_list.append((torch.sqrt(torch.var(per_class))) / (np.sqrt(class_num))) - + logits = prediction valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1) if not valid_mask.all(): @@ -159,30 +153,10 @@ def fast_rcnn_inference_single_image( boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep] logits = logits[keep] - # calibration for unknown - if calibration>0: - if len(filter_inds[:, 1]) > 0: - score_back = F.softmax(logits, dim=-1) - # TODO: check difference between score_back(multi) and scores(1) - - new_pred_classes = filter_inds[:, 1] - for pred_i in range(len(score_back)): - c = 0 - for i in range(20): - if score_back[pred_i, i] < class_list[i] * unk_thresh: - # if score_back[pred_i, i] < (class_list[i] - var_times * class_var_list[i]) * unk_thresh: - c += 1 - logits[pred_i, i] = 0 - if c == 20: - new_pred_classes[pred_i] = 80 - else: - new_pred_classes = filter_inds[:, 1] - result = Instances(image_shape) result.pred_boxes = Boxes(boxes) result.scores = scores - # result.pred_classes = filter_inds[:, 1] - result.pred_classes = new_pred_classes + result.pred_classes = filter_inds[:, 1] result.logits = logits return result, filter_inds[:, 0] @@ -290,11 +264,9 @@ class FastRCNNOutputs: self._log_accuracy() self.pred_class_logits[:, self.invalid_class_range] = -10e10 # self.log_logits(self.pred_class_logits, self.gt_classes) - - mean_loss = softmax_loss(self.pred_class_logits, self.gt_classes) - return mean_loss - - # return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean") + # print("self.gt_classes:",self.gt_classes) + # sys.exit() + return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean") def log_logits(self, logits, cls): data = (logits, cls) @@ -323,7 +295,7 @@ class FastRCNNOutputs: # Empty fg_inds produces a valid loss of zero as long as the size_average # arg to smooth_l1_loss is False (otherwise it uses torch.mean internally # and would produce a nan loss). 
-        fg_inds = nonzero_tuple((self.gt_classes >= 0) & (self.gt_classes < bg_class_ind))[0]
+        fg_inds = nonzero_tuple((self.gt_classes >= 0) & (self.gt_classes < bg_class_ind))[0]  # -1: do not include unk in this loss
         if cls_agnostic_bbox_reg:
             # pred_proposal_deltas only corresponds to foreground class for agnostic
             gt_class_cols = torch.arange(box_dim, device=device)
@@ -470,8 +442,6 @@ class FastRCNNOutputLayers(nn.Module):
         output_dir,
         feat_store_path,
         margin,
-        calibration,
-        unk_thresh,
         num_classes: int,
         test_score_thresh: float = 0.0,
         test_nms_thresh: float = 0.5,
@@ -554,28 +524,6 @@ class FastRCNNOutputLayers(nn.Module):
         self.feature_store = Store(num_classes + 1, clustering_items_per_class)
         self.means = [None for _ in range(num_classes + 1)]
         self.margin = margin
-        # post treatment para
-        self.calibration = calibration
-        self.unk_thresh = unk_thresh
-
-        # self.feat_store_path1 = "./output/1129_train/t1/feature_store/feat.pt"
-        # data = torch.load(self.feat_store_path1)
-        # items=data.retrieve(-1)
-        # for index, item in enumerate(items):
-        #     if len(item) == 0:
-        #         self.means[index] = None
-        #     else:
-        #         mu = torch.tensor(item).mean(dim=0)
-        #         self.means[index] = mu
-        # self.all_means1 = self.means
-        # for item in self.all_means1:
-        #     if item != None:
-        #         length = item.shape
-        #         break
-
-        # for i, item in enumerate(self.all_means1):
-        #     if item == None:
-        #         self.all_means1[i] = torch.zeros((length))
 
         # self.ae_model = AE(input_size, clustering_z_dimension)
         # self.ae_model.apply(Xavier)
@@ -606,8 +554,6 @@ class FastRCNNOutputLayers(nn.Module):
            "output_dir" : cfg.OUTPUT_DIR,
            "feat_store_path" : cfg.OWOD.FEATURE_STORE_SAVE_PATH,
            "margin" : cfg.OWOD.CLUSTERING.MARGIN,
-            "calibration" : cfg.OWOD.CLUSTERING.CLIBARATION,
-            "unk_thresh" : cfg.OWOD.CLUSTERING.UNK_THRESH,
            # fmt: on
        }
@@ -626,31 +572,10 @@ class FastRCNNOutputLayers(nn.Module):
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
-        # scores = self.cls_score(x)
-
-        all_means = self.means
-        for item in all_means:
-            if item != None:
-                length = item.shape
-                break
-
-        for i, item in enumerate(all_means):
-            if item == None:
-                all_means[i] = torch.zeros((length))
-        # dist = Distance(x, all_means)
-        # scores = -dist / 0.1
-        scores = self.solve(x, torch.stack(all_means).cuda())
-
+        scores = self.cls_score(x)
        proposal_deltas = self.bbox_pred(x)
        return scores, proposal_deltas
 
-    def solve(self,features, means):
-        features = features.unsqueeze(1)
-        means = means.unsqueeze(0)
-        prob = torch.exp(-torch.norm(features - means, dim=-1))
-        prob = prob / prob.sum(-1, keepdim=True)
-        return prob
-
     def update_feature_store(self, features, proposals):
         # cat(..., dim=0) concatenates over all images in the batch
         gt_classes = torch.cat([p.gt_classes for p in proposals])
@@ -662,43 +587,11 @@ class FastRCNNOutputLayers(nn.Module):
            logging.getLogger(__name__).info('Saving image store at iteration ' + str(storage.iter) + ' to ' + self.feature_store_save_loc)
            torch.save(self.feature_store, self.feature_store_save_loc)
            self.feature_store_is_stored = True
-            # print("11111111111111111111111111111111111")
 
        # self.feature_store.add(F.normalize(features, dim=0), gt_classes)
        # self.feature_store.add(self.ae_model.encoder(features), gt_classes)
 
-    def updatePrototype(self):
-        storage = get_event_storage()
-        if storage.iter == self.clustering_start_iter:
-            items = self.feature_store.retrieve(-1)
-            for index, item in enumerate(items):
-                if len(item) == 0:
-                    self.means[index] = None
-                else:
-                    mu = torch.tensor(item).mean(dim=0)
-                    self.means[index] = mu
-            # Freeze the parameters when clustering
starts - # for param in self.ae_model.parameters(): - # param.requires_grad = False - elif storage.iter > self.clustering_start_iter: - if storage.iter % self.clustering_update_mu_iter == 0: - # Compute new MUs - items = self.feature_store.retrieve(-1) - new_means = [None for _ in range(self.num_classes + 1)] - for index, item in enumerate(items): - if len(item) == 0: - new_means[index] = None - else: - new_means[index] = torch.tensor(item).mean(dim=0) - # Update the MUs - for i, mean in enumerate(self.means): - if(mean) is not None and new_means[i] is not None: - self.means[i] = self.clustering_momentum * mean + \ - (1 - self.clustering_momentum) * new_means[i] - return self.means - - def clstr_loss_l2_cdist(self, input_features, proposals): """ Get the foreground input_features, generate distributions for the class, @@ -716,19 +609,17 @@ class FastRCNNOutputLayers(nn.Module): # fg_features = F.normalize(fg_features, dim=0) # fg_features = self.ae_model.encoder(fg_features) - - - self.all_means = self.means - for item in self.all_means: + all_means = self.means + for item in all_means: if item != None: length = item.shape break - for i, item in enumerate(self.all_means): + for i, item in enumerate(all_means): if item == None: - self.all_means[i] = torch.zeros((length)) + all_means[i] = torch.zeros((length)) - distances = torch.cdist(fg_features, torch.stack(self.all_means).cuda(), p=self.margin) + distances = torch.cdist(fg_features, torch.stack(all_means).cuda(), p=self.margin) labels = [] for index, feature in enumerate(fg_features): @@ -751,36 +642,35 @@ class FastRCNNOutputLayers(nn.Module): storage = get_event_storage() c_loss = 0 - # self.means=all_means - # if storage.iter == self.clustering_start_iter: - # items = self.feature_store.retrieve(-1) - # for index, item in enumerate(items): - # if len(item) == 0: - # self.means[index] = None - # else: - # mu = torch.tensor(item).mean(dim=0) - # self.means[index] = mu - # c_loss = self.clstr_loss_l2_cdist(input_features, proposals) - # # Freeze the parameters when clustering starts - # # for param in self.ae_model.parameters(): - # # param.requires_grad = False - # elif storage.iter > self.clustering_start_iter: - # if storage.iter % self.clustering_update_mu_iter == 0: - # # Compute new MUs - # items = self.feature_store.retrieve(-1) - # new_means = [None for _ in range(self.num_classes + 1)] - # for index, item in enumerate(items): - # if len(item) == 0: - # new_means[index] = None - # else: - # new_means[index] = torch.tensor(item).mean(dim=0) - # # Update the MUs - # for i, mean in enumerate(self.means): - # if(mean) is not None and new_means[i] is not None: - # self.means[i] = self.clustering_momentum * mean + \ - # (1 - self.clustering_momentum) * new_means[i] + if storage.iter == self.clustering_start_iter: + items = self.feature_store.retrieve(-1) + for index, item in enumerate(items): + if len(item) == 0: + self.means[index] = None + else: + mu = torch.tensor(item).mean(dim=0) + self.means[index] = mu + c_loss = self.clstr_loss_l2_cdist(input_features, proposals) + # Freeze the parameters when clustering starts + # for param in self.ae_model.parameters(): + # param.requires_grad = False + elif storage.iter > self.clustering_start_iter: + if storage.iter % self.clustering_update_mu_iter == 0: + # Compute new MUs + items = self.feature_store.retrieve(-1) + new_means = [None for _ in range(self.num_classes + 1)] + for index, item in enumerate(items): + if len(item) == 0: + new_means[index] = None + else: + new_means[index] = 
torch.tensor(item).mean(dim=0) + # Update the MUs + for i, mean in enumerate(self.means): + if(mean) is not None and new_means[i] is not None: + self.means[i] = self.clustering_momentum * mean + \ + (1 - self.clustering_momentum) * new_means[i] - c_loss = self.clstr_loss_l2_cdist(input_features, proposals) + c_loss = self.clstr_loss_l2_cdist(input_features, proposals) return c_loss # def get_ae_loss(self, input_features): @@ -840,8 +730,6 @@ class FastRCNNOutputLayers(nn.Module): self.test_score_thresh, self.test_nms_thresh, self.test_topk_per_image, - self.calibration, - self.unk_thresh, ) def predict_boxes_for_gt_classes(self, predictions, proposals): @@ -920,103 +808,3 @@ class FastRCNNOutputLayers(nn.Module): num_inst_per_image = [len(p) for p in proposals] probs = F.softmax(scores, dim=-1) return probs.split(num_inst_per_image, dim=0) - - # def clstr_loss(self, input_features, proposals): - # """ - # Get the foreground input_features, generate distributions for the class, - # get probability of each feature from each distribution; - # Compute loss: if belonging to a class -> likelihood should be higher - # else -> lower - # :param input_features: - # :param proposals: - # :return: - # """ - # loss = 0 - # gt_classes = torch.cat([p.gt_classes for p in proposals]) - # mask = gt_classes != self.num_classes - # fg_features = input_features[mask] - # classes = gt_classes[mask] - # # fg_features = self.ae_model.encoder(fg_features) - # - # # Distribution per class - # log_prob = [None for _ in range(self.num_classes + 1)] - # # https://github.com/pytorch/pytorch/issues/23780 - # for cls_index, mu in enumerate(self.means): - # if mu is not None: - # dist = Normal(loc=mu.cuda(), scale=torch.ones_like(mu.cuda())) - # log_prob[cls_index] = dist.log_prob(fg_features).mean(dim=1) - # # log_prob[cls_index] = torch.distributions.multivariate_normal. 
\ - # # MultivariateNormal(mu.cuda(), torch.eye(len(mu)).cuda()).log_prob(fg_features) - # # MultivariateNormal(mu, torch.eye(len(mu))).log_prob(fg_features.cpu()) - # # MultivariateNormal(mu[:2], torch.eye(len(mu[:2]))).log_prob(fg_features[:,:2].cpu()) - # else: - # log_prob[cls_index] = torch.zeros((len(fg_features))).cuda() - # - # log_prob = torch.stack(log_prob).T # num_of_fg_proposals x num_of_classes - # for i, p in enumerate(log_prob): - # weight = torch.ones_like(p) * -1 - # weight[classes[i]] = 1 - # p = p * weight - # loss += p.mean() - # return loss - - # def clstr_loss_l2(self, input_features, proposals): - # """ - # Get the foreground input_features, generate distributions for the class, - # get probability of each feature from each distribution; - # Compute loss: if belonging to a class -> likelihood should be higher - # else -> lower - # :param input_features: - # :param proposals: - # :return: - # """ - # loss = 0 - # gt_classes = torch.cat([p.gt_classes for p in proposals]) - # mask = gt_classes != self.num_classes - # fg_features = input_features[mask] - # classes = gt_classes[mask] - # fg_features = self.ae_model.encoder(fg_features) - # - # for index, feature in enumerate(fg_features): - # for cls_index, mu in enumerate(self.means): - # if mu is not None and feature is not None: - # mu = mu.cuda() - # if classes[index] == cls_index: - # loss -= F.mse_loss(feature, mu) - # else: - # loss += F.mse_loss(feature, mu) - # - # return loss - -def Distance(features, centers): - f_2 = torch.sum(torch.pow(features, 2), dim=1, keep_dims=True,) - c_2 = torch.sum(torch.pow(centers, 2), dim=1, keep_dims=True) - dist = f_2 - 2 * torch.matmul(features, centers) + torch.transpose(c_2, 1, 0) - - return dist - -def softmax_loss(logits, labels): - # labels = tf.to_int32(labels) - - logp = torch.nn.functional.log_softmax(logits) - labels=labels.reshape(-1,1) - logpy = torch.gather(logp, 1, labels) - loss = -(logpy).mean() - # print("loss",loss) - - # cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, - # logits=logits, name='xentropy') - return loss*0.1 - - - - # def softmax_loss(self,logits, labels): - # # labels = tf.to_int32(labels) - - # logp = torch.nn.functional.log_softmax(logits) - # # print(logp.shape) - # # print(labels.shape) - # # sys.exit() - # labels=labels.reshape(-1,1) - # logpy = torch.gather(logp, 1, labels) - # loss = -(logpy).mean() diff --git a/detectron2/modeling/roi_heads/roi_heads.py b/detectron2/modeling/roi_heads/roi_heads.py index 5971e98..0a3761b 100644 --- a/detectron2/modeling/roi_heads/roi_heads.py +++ b/detectron2/modeling/roi_heads/roi_heads.py @@ -6,11 +6,12 @@ import heapq import os import shortuuid import operator -import shortuuid +import sys +import cv2 from typing import Dict, List, Optional, Tuple, Union import torch from torch import nn -import time +# from ..drawBoxes import draw_boxes from detectron2.config import configurable from detectron2.layers import ShapeSpec, nonzero_tuple @@ -149,6 +150,7 @@ class ROIHeads(torch.nn.Module): batch_size_per_image, positive_fraction, proposal_matcher, + proposal_matcher_unk, enable_thresold_autolabelling, unk_k, proposal_append_gt=True, @@ -169,6 +171,7 @@ class ROIHeads(torch.nn.Module): self.positive_fraction = positive_fraction self.num_classes = num_classes self.proposal_matcher = proposal_matcher + self.proposal_matcher_unk = proposal_matcher_unk self.proposal_append_gt = proposal_append_gt self.enable_thresold_autolabelling = enable_thresold_autolabelling self.unk_k = unk_k 
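# ---------------------------------------------------------------------------
# Aside (not part of the patch): the next hunk registers a second, stricter
# Matcher ("proposal_matcher_unk", single IoU threshold 0.8) that
# label_and_sample_proposals later uses to spread the pseudo-unknown label to
# proposals overlapping the selected unknown box. A minimal sketch of that
# matching rule in plain PyTorch, with toy IoU values; the helper below is
# illustrative only.
import torch

def match_unknowns(iou_matrix: torch.Tensor, thresh: float = 0.8) -> torch.Tensor:
    # Mimics Matcher([thresh], [0, 1], allow_low_quality_matches=False):
    # a proposal is labelled 1 iff its best IoU against any pseudo-unknown
    # ground-truth box reaches `thresh`, otherwise 0.
    best_iou, _ = iou_matrix.max(dim=0)  # best GT per proposal, shape (N,)
    return (best_iou >= thresh).long()

# Toy example: 2 pseudo-unknown boxes vs. 4 sampled proposals.
iou = torch.tensor([[0.85, 0.30, 0.79, 0.00],
                    [0.10, 0.95, 0.20, 0.40]])
print(match_unknowns(iou))  # tensor([1, 1, 0, 0])
# ---------------------------------------------------------------------------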
@@ -186,6 +189,12 @@ class ROIHeads(torch.nn.Module):
                cfg.MODEL.ROI_HEADS.IOU_LABELS,
                allow_low_quality_matches=False,
            ),
+            "proposal_matcher_unk": Matcher(
+                # cfg.MODEL.ROI_HEADS.IOU_THRESHOLDS_UNK,
+                [0.8],
+                cfg.MODEL.ROI_HEADS.IOU_LABELS,
+                allow_low_quality_matches=False,
+            ),
            "enable_thresold_autolabelling": cfg.OWOD.ENABLE_THRESHOLD_AUTOLABEL_UNK,
            "unk_k": cfg.OWOD.NUM_UNK_PER_IMAGE,
        }
@@ -229,24 +238,11 @@ class ROIHeads(torch.nn.Module):
        sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
        gt_classes_ss = gt_classes[sampled_idxs]
 
-        if self.enable_thresold_autolabelling:
-            matched_labels_ss = matched_labels[sampled_idxs]
-            pred_objectness_score_ss = objectness_logits[sampled_idxs]
-
-            # 1) Remove FG objectness score. 2) Sort and select top k. 3) Build and apply mask.
-            mask = torch.zeros((pred_objectness_score_ss.shape), dtype=torch.bool)
-            pred_objectness_score_ss[matched_labels_ss != 0] = -1
-            sorted_indices = list(zip(
-                *heapq.nlargest(self.unk_k, enumerate(pred_objectness_score_ss), key=operator.itemgetter(1))))[0]
-            for index in sorted_indices:
-                mask[index] = True
-            gt_classes_ss[mask] = self.num_classes - 1
-
        return sampled_idxs, gt_classes_ss
 
    @torch.no_grad()
    def label_and_sample_proposals(
-        self, proposals: List[Instances], targets: List[Instances]
+        self, proposals: List[Instances], targets: List[Instances], image_id=None, ori_image=None
    ) -> List[Instances]:
        """
        Prepare some proposals to be used to train the ROI heads.
@@ -287,10 +283,13 @@ class ROIHeads(torch.nn.Module):
            proposals = add_ground_truth_to_proposals(gt_boxes, proposals)
 
        proposals_with_gt = []
+        unk_sel_gt = []
 
        num_fg_samples = []
        num_bg_samples = []
-        for proposals_per_image, targets_per_image in zip(proposals, targets):
+        for proposals_per_image, targets_per_image, image_id_i, ori_image_i \
+                in zip(proposals, targets, image_id, ori_image):
+            height_new, width_new = proposals_per_image.image_size
            has_gt = len(targets_per_image) > 0
            match_quality_matrix = pairwise_iou(
                targets_per_image.gt_boxes, proposals_per_image.proposal_boxes
@@ -299,6 +298,89 @@ class ROIHeads(torch.nn.Module):
            )
            sampled_idxs, gt_classes = self._sample_proposals(
                matched_idxs, matched_labels, targets_per_image.gt_classes, proposals_per_image.objectness_logits
            )
+            del match_quality_matrix
+            gt_flag = False
+            unk_flag = False
+            storage = get_event_storage()
+            if self.enable_thresold_autolabelling and storage.iter > 50000:
+                matched_labels_ss = matched_labels[sampled_idxs]
+                pred_objectness_score_ss = proposals_per_image.objectness_logits[sampled_idxs]
+
+                pred_objectness_score_ss[matched_labels_ss != 0] = -1
+                sorted_indices = list(zip(
+                    *heapq.nlargest(50, enumerate(pred_objectness_score_ss), key=operator.itemgetter(1))))[0]
+                mask = torch.zeros((pred_objectness_score_ss.shape), dtype=torch.bool)
+
+                new_flag = True
+                for index in sorted_indices:
+                    if new_flag:
+                        autolabel_boxes = proposals_per_image.proposal_boxes[sampled_idxs[index].item()]
+                        autolabel_score = proposals_per_image.objectness_logits[sampled_idxs[index].item()].view(1, -1)
+                        new_flag = False
+                    else:
+                        box_i = proposals_per_image.proposal_boxes[sampled_idxs[index].item()]
+                        score_i = proposals_per_image.objectness_logits[sampled_idxs[index].item()].view(1, -1)
+                        autolabel_boxes = Boxes.cat([autolabel_boxes, box_i])
+                        autolabel_score = torch.cat([autolabel_score, score_i], 1)
+
+                obj_save_path = "../score_store/" + image_id_i + ".jpg" + ".pickle"
+                obj_score_save = torch.load(obj_save_path)
+                height_ori, width_ori = obj_score_save['image_size']
+                obj_score_boxes = obj_score_save['obj_boxes']
+
+                if len(obj_score_boxes):
+                    obj_boxes_sel = obj_score_boxes[:50, :4]
+                    obj_boxes_sel[:, 0] = obj_boxes_sel[:, 0] * (width_new * 1.0 / width_ori)
+                    obj_boxes_sel[:, 1] = obj_boxes_sel[:, 1] * (height_new * 1.0 / height_ori)
+                    obj_boxes_sel[:, 2] = obj_boxes_sel[:, 2] * (width_new * 1.0 / width_ori)
+                    obj_boxes_sel[:, 3] = obj_boxes_sel[:, 3] * (height_new * 1.0 / height_ori)
+                    obj_boxes = Boxes(torch.Tensor(obj_boxes_sel).cuda())
+
+                    area_new = width_new * height_new
+                    area_mask = obj_boxes.area() / area_new
+                    area_mask = area_mask < 0.8
+
+                    unk_match_matrix = pairwise_iou(obj_boxes, autolabel_boxes)
+                    unk_match_matrix[unk_match_matrix < 0.9] = 0  # strict pass; a looser 0.7 pass follows below
+                    score_matrix = torch.mm(area_mask.view(-1, 1).float(), autolabel_score.view(1, -1))
+                    score_matrix = torch.mul(score_matrix, unk_match_matrix)
+                    score_matrix, _ = torch.max(score_matrix, 0)
+                    _, unk_max_index = torch.max(score_matrix, 0)
+                    unk_obj_index = torch.nonzero(score_matrix).cpu()
+                    del unk_match_matrix
+                    if len(unk_obj_index):  # found a pseudo-unknown ground-truth box
+                        gt_flag = True
+                        unk_instances_gt = Instances(proposals_per_image.image_size)
+                        unk_box = autolabel_boxes[unk_max_index.item()]
+                        unk_instances_gt.gt_boxes = unk_box
+                        unk_instances_gt.gt_classes = torch.Tensor([80]).long().cuda()
+                        targets_per_image = Instances.cat([targets_per_image, unk_instances_gt])
+
+                        match_quality_matrix = pairwise_iou(
+                            unk_instances_gt.gt_boxes, proposals_per_image.proposal_boxes[sampled_idxs]
+                        )
+                        _, matched_labels_unk = self.proposal_matcher_unk(match_quality_matrix)
+                        del match_quality_matrix
+                        matched_unk_mask = matched_labels_unk == 1
+                        matched_unk_mask_idx = torch.nonzero(matched_unk_mask)
+                        for index in matched_unk_mask_idx:
+                            if sampled_idxs[index] < 100:
+                                mask[index] = True
+
+                    unk_match_matrix = pairwise_iou(obj_boxes, autolabel_boxes)
+                    unk_match_matrix[unk_match_matrix < 0.7] = 0
+
+                    score_matrix = torch.mm(area_mask.view(-1, 1).float(), autolabel_score.view(1, -1))
+                    score_matrix = torch.mul(score_matrix, unk_match_matrix)
+                    score_matrix, _ = torch.max(score_matrix, 0)
+                    del unk_match_matrix
+                    unk_obj_index = torch.nonzero(score_matrix).cpu()
+                    score_matrix = score_matrix[score_matrix > 0]
+
+                    if len(unk_obj_index):
+                        for idx in unk_obj_index:
+                            mask[sorted_indices[idx]] = True
+                gt_classes[mask] = 80  # 80 is the unknown class index
            # Set target attributes of the sampled proposals:
            proposals_per_image = proposals_per_image[sampled_idxs]
@@ -323,14 +405,20 @@ class ROIHeads(torch.nn.Module):
            num_bg_samples.append((gt_classes == self.num_classes).sum().item())
            num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1])
+
+
+            if gt_flag:
+                unk_sel_gt.append(unk_instances_gt.gt_boxes)
+            else:
+                unk_gt_boxes = []
+                unk_sel_gt.append(unk_gt_boxes)
            proposals_with_gt.append(proposals_per_image)
 
        # Log the number of fg/bg samples that are selected for training ROI heads
-        storage = get_event_storage()
+        # storage = get_event_storage()
        storage.put_scalar("roi_head/num_fg_samples", np.mean(num_fg_samples))
        storage.put_scalar("roi_head/num_bg_samples", np.mean(num_bg_samples))
-
-        return proposals_with_gt
+        return proposals_with_gt, unk_sel_gt
 
    def forward(
        self,
@@ -455,7 +543,7 @@ class Res5ROIHeads(ROIHeads):
            location = os.path.join(self.energy_save_path, shortuuid.uuid() + '.pkl')
            torch.save(data, location)
 
-    def forward(self, images, features, proposals, targets=None):
+    def forward(self, images, features, proposals, targets=None, image_id=None, ori_image=None):
        """
        See :meth:`ROIHeads.forward`.
""" @@ -463,7 +551,7 @@ class Res5ROIHeads(ROIHeads): if self.training: assert targets - proposals = self.label_and_sample_proposals(proposals, targets) + proposals, unk_sel_gt = self.label_and_sample_proposals(proposals, targets, image_id, ori_image) del targets proposal_boxes = [x.proposal_boxes for x in proposals] @@ -471,19 +559,13 @@ class Res5ROIHeads(ROIHeads): [features[f] for f in self.in_features], proposal_boxes ) input_features = box_features.mean(dim=[2, 3]) - if self.training: - # self.log_features(input_features, proposals) - if self.enable_clustering: - # print("11111111111111112222222222222222222") - self.box_predictor.update_feature_store(input_features, proposals) - self.box_predictor.updatePrototype() predictions = self.box_predictor(input_features) if self.training: # self.log_features(input_features, proposals) - # if self.enable_clustering: - # self.box_predictor.update_feature_store(input_features, proposals) - # del features + if self.enable_clustering: + self.box_predictor.update_feature_store(input_features, proposals) + del features if self.compute_energy_flag: self.compute_energy(predictions, proposals) losses = self.box_predictor.losses(predictions, proposals, input_features) @@ -498,11 +580,11 @@ class Res5ROIHeads(ROIHeads): mask_features = box_features[torch.cat(fg_selection_masks, dim=0)] del box_features losses.update(self.mask_head(mask_features, proposals)) - return [], losses + return [], losses, unk_sel_gt else: pred_instances, _ = self.box_predictor.inference(predictions, proposals) pred_instances = self.forward_with_given_boxes(features, pred_instances) - return pred_instances, {} + return pred_instances, {}, [] def forward_with_given_boxes(self, features, instances): """