# Adapted from https://github.com/IDEA-Research/GroundingDINO. The original licenses are:
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
from typing import List

import torch
import torch.nn.functional as F
from torch import nn, einsum
from torchvision.ops.boxes import nms
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast

from groundingdino_new.util import box_ops, get_tokenlizer
from groundingdino_new.util.misc import (
    NestedTensor,
    accuracy,
    get_world_size,
    interpolate,
    inverse_sigmoid,
    is_dist_avail_and_initialized,
    nested_tensor_from_tensor_list,
)
from groundingdino_new.util.utils import get_phrases_from_posmap
from groundingdino_new.util.visualizer import COCOVisualizer
from groundingdino_new.util.vl_utils import create_positive_map_from_span

from ..registry import MODULE_BUILD_FUNCS
from .backbone import build_backbone
from .bertwarper import (
    BertModelWarper,
    generate_masks_with_special_tokens,
    generate_masks_with_special_tokens_and_transfer_map,
)
from .transformer import build_transformer
from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss

from maskrcnn_benchmark.structures.image_list import ImageList
from maskrcnn_benchmark.modeling.rpn.inference import convert_grounding_to_od_logits
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_ml_nms
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
# from groundingdino_new.util.inference import preprocess_caption
from maskrcnn_benchmark.modeling.poolers import CustomPooler, Pooler
from groundingdino_new.models.GroundingDINO.loss import SetCriterion
from groundingdino_new.models.GroundingDINO.matcher import build_matcher
from maskrcnn_benchmark.modeling.language_backbone import build_language_backbone
from maskrcnn_benchmark.modeling.language_backbone.modeling_bert_new import QVBertModel
from transformers import BertConfig, RobertaConfig, RobertaModel
from maskrcnn_benchmark.modeling.query_selector import build_query_selector
import os


def expand_bbox(box_list, expand_ratio=1.5):
    """Grow every box around its centre by ``expand_ratio`` and clip to the image.

    Parameters:
        box_list: iterable of ``BoxList`` objects in ``"xyxy"`` mode, each
            carrying a ``'labels'`` field.
        expand_ratio: multiplicative factor applied to box width and height.

    Returns:
        A new list of ``BoxList`` objects (same image size, ``'labels'`` copied
        over) after ``clip_to_image(remove_empty=True)``.
    """
    new_box_list = []
    for boxes in box_list:
        assert boxes.mode == "xyxy"
        bbox = boxes.bbox
        image_size = boxes.size
        box_w, box_h = bbox[:, 2] - bbox[:, 0], bbox[:, 3] - bbox[:, 1]
        new_box_w, new_box_h = box_w * expand_ratio, box_h * expand_ratio
        # Half of the added extent goes to each side so the centre is unchanged.
        diff_w = (new_box_w - box_w) / 2
        diff_h = (new_box_h - box_h) / 2
        diff = torch.stack([-diff_w, -diff_h, diff_w, diff_h], dim=1)
        new_bbox = bbox + diff
        new_boxes = BoxList(new_bbox, image_size, mode="xyxy")
        labels = boxes.get_field('labels')
        new_boxes.add_field('labels', labels)
        new_boxes = new_boxes.clip_to_image(remove_empty=True)
        new_box_list.append(new_boxes)
    return new_box_list


def preprocess_caption(caption: str) -> str:
    """Lower-case and strip ``caption`` and make sure it ends with a period."""
    result = caption.lower().strip()
    if result.endswith("."):
        return result
    return result + "."


class GroundingDINO(nn.Module):
    """This is the Cross-Attention Detector module that performs object detection"""

    def __init__(
        self,
        backbone,
        transformer,
        num_queries,
        aux_loss=False,
        iter_update=False,
        query_dim=2,
        num_feature_levels=1,
        nheads=8,
        # two stage
        two_stage_type="no",  # ['no', 'standard']
        dec_pred_bbox_embed_share=True,
        two_stage_class_embed_share=True,
        two_stage_bbox_embed_share=True,
        num_patterns=0,
        dn_number=100,
        dn_box_noise_scale=0.4,
        dn_label_noise_ratio=0.5,
        dn_labelbook_size=100,
        text_encoder_type="bert-base-uncased",
        sub_sentence_present=True,
        max_text_len=256,
        cfg=None,
    ):
        """Initializes the model.

        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_queries: number of object queries, ie detection slot. This is the maximal
                number of objects Conditional DETR can detect in a single image. For COCO,
                we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are
                to be used. Stored on the module only; ``forward`` always builds the
                auxiliary outputs in training mode.
            iter_update: must be truthy (asserted below).
            query_dim: must be 4 (asserted below).
            num_feature_levels: number of feature levels fed to the transformer; extra
                levels beyond the backbone outputs are produced by strided 3x3 convs.
            two_stage_type: ``"no"`` or ``"standard"``.
            text_encoder_type: name or path of the text encoder; its basename must be
                ``"bert-base-uncased"``, otherwise ``NotImplementedError`` is raised.
            sub_sentence_present: if True, the text encoder receives the per-category
                self-attention masks and position ids instead of the plain padding mask.
            max_text_len: accepted for interface compatibility; see the NOTE below.
            cfg: project config node; required (attributes are read unconditionally).
        """
        super().__init__()
        self.cfg = cfg
        self.box_threshold = cfg.GROUNDINGDINO.box_threshold
        self.num_queries = num_queries
        self.transformer = transformer
        self.hidden_dim = hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.nheads = nheads
        # NOTE(review): the ``max_text_len`` argument is ignored and 256 is
        # hard-coded. Kept as-is because downstream tensors may be sized for 256;
        # confirm before wiring the argument through.
        self.max_text_len = 256
        self.sub_sentence_present = sub_sentence_present

        # setting query dim
        self.query_dim = query_dim
        assert query_dim == 4

        # for dn training
        self.num_patterns = num_patterns
        self.dn_number = dn_number
        self.dn_box_noise_scale = dn_box_noise_scale
        self.dn_label_noise_ratio = dn_label_noise_ratio
        self.dn_labelbook_size = dn_labelbook_size

        # loss criterion
        self.loss_evaluator = SetCriterion(matcher=build_matcher(cfg.GROUNDINGDINO.matcher), cfg=cfg)

        # box pooler for extracting cache
        resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
        pooler_cls = Pooler if cfg.VISION_QUERY.SELECT_FPN_LEVEL else CustomPooler
        self.pooler = pooler_cls(
            output_size=(resolution, resolution),
            scales=cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES,
            sampling_ratio=cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO,
            use_v2=True,
        )
        self.pool = nn.AvgPool2d(2)

        # query selector
        if cfg.VISION_QUERY.DISABLE_SELECTOR:
            self.query_selector = None
        else:
            self.query_selector = build_query_selector(cfg)

        # bert
        self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
        if os.path.basename(text_encoder_type) != "bert-base-uncased":
            raise NotImplementedError
        # self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
        config = BertConfig.from_pretrained(text_encoder_type)
        self.bert = QVBertModel.from_pretrained(
            text_encoder_type,
            dim_t=config.hidden_size,
            dim_v=self.hidden_dim,
            share_kv=cfg.VISION_QUERY.SHARE_KV,
            cfg=cfg,
            config=config,
        )
        # The pooler head's weights are frozen.
        self.bert.pooler.dense.weight.requires_grad_(False)
        self.bert.pooler.dense.bias.requires_grad_(False)
        self.bert = BertModelWarper(bert_model=self.bert)

        # Projects text features from the BERT hidden size to the transformer width.
        self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
        nn.init.constant_(self.feat_map.bias.data, 0)
        nn.init.xavier_uniform_(self.feat_map.weight.data)
        # freeze

        # special tokens
        self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])

        # prepare input projection layers
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.num_channels)
            input_proj_list = []
            for level in range(num_backbone_outs):
                in_channels = backbone.num_channels[level]
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
            # Extra levels: the first consumes the last backbone map
            # (``in_channels`` left over from the loop above), later ones consume
            # the previous extra level, which already has ``hidden_dim`` channels.
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
            self.input_proj = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                ]
            )

        self.backbone = backbone
        self.aux_loss = aux_loss
        self.box_pred_damping = None
        self.iter_update = iter_update
        assert iter_update, "Why not iter_update?"

        # prepare pred layers
        self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
        # prepare class & box embed
        _class_embed = ContrastiveEmbed()
        _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        # Zero-init the last box layer so the initial box delta is zero.
        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)

        num_decoder_layers = transformer.num_decoder_layers
        if dec_pred_bbox_embed_share:
            # Same module object repeated: parameters are shared across layers.
            box_embed_layerlist = [_bbox_embed for _ in range(num_decoder_layers)]
        else:
            box_embed_layerlist = [copy.deepcopy(_bbox_embed) for _ in range(num_decoder_layers)]
        class_embed_layerlist = [_class_embed for _ in range(num_decoder_layers)]
        self.bbox_embed = nn.ModuleList(box_embed_layerlist)
        self.class_embed = nn.ModuleList(class_embed_layerlist)
        self.transformer.decoder.bbox_embed = self.bbox_embed
        self.transformer.decoder.class_embed = self.class_embed

        # two stage
        self.two_stage_type = two_stage_type
        assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
            two_stage_type
        )
        if two_stage_type != "no":
            if two_stage_bbox_embed_share:
                assert dec_pred_bbox_embed_share
                self.transformer.enc_out_bbox_embed = _bbox_embed
            else:
                self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)

            if two_stage_class_embed_share:
                assert dec_pred_bbox_embed_share
                self.transformer.enc_out_class_embed = _class_embed
            else:
                self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)

            # NOTE(review): placed inside the two-stage branch, as in upstream
            # GroundingDINO; the collapsed source did not show the nesting.
            self.refpoint_embed = None

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise the conv weight and zero the bias of every input projection."""
        # init input_proj
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    def init_ref_points(self, use_num_queries):
        """Create the reference-point embedding table (``use_num_queries`` x ``query_dim``)."""
        self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)

    def _compute_multiscale_features(self, samples):
        """Run the backbone and input projections on a ``NestedTensor``.

        Returns:
            ``(srcs, masks, poss)``: projected feature maps, their padding masks
            and the positional encodings, one entry per feature level. Levels
            beyond the backbone outputs are produced by the extra strided
            projections, with masks resized from ``samples.mask``.
        """
        features, poss = self.backbone(samples)

        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    # First extra level consumes the raw last backbone map.
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                poss.append(pos_l)
        return srcs, masks, poss

    def convert_groundingdino_to_glip_output(self, groundingdino_out, positive_map, image_sizes):
        """Turn token-level predictions into one ``BoxList`` per image.

        Parameters:
            groundingdino_out: dict with ``'pred_logits'`` and ``'pred_boxes'``
                (boxes are treated as normalised centre-x, centre-y, width, height).
            positive_map: mapping forwarded to ``convert_grounding_to_od_logits``.
            image_sizes: one size per image, unpacked as ``(H, W)``.

        Returns:
            List of ``BoxList`` in ``"xyxy"`` pixel coordinates with ``'labels'``
            (1-based class index) and ``'scores'`` fields. Only queries whose
            best class score exceeds ``self.box_threshold`` are kept.
        """
        dot_product_logits = groundingdino_out['pred_logits']
        box_regression = groundingdino_out['pred_boxes']

        B, N, _ = dot_product_logits.shape
        box_cls = dot_product_logits.new_zeros(B, N, self.cfg.MODEL.DYHEAD.NUM_CLASSES - 1)
        # candidate_inds = dot_product_logits.max(dim=-1)[0] > self.box_threshold

        scores = convert_grounding_to_od_logits(
            logits=dot_product_logits,
            box_cls=box_cls,
            positive_map=positive_map,
            score_agg="MEAN",
        )
        box_cls = scores

        candidate_inds = box_cls.max(dim=-1)[0] > self.box_threshold
        # pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1)
        # pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        results = []
        for per_box_cls, per_box_regression, per_candidate_inds, image_size in zip(
            box_cls, box_regression, candidate_inds, image_sizes
        ):
            per_box_cls = per_box_cls[per_candidate_inds]
            per_box_cls, top_k_indices = per_box_cls.topk(1, sorted=False)
            # Class ids are shifted by one (index 0 is not used as a label).
            per_class = top_k_indices[:, 0] + 1

            box = per_box_regression[per_candidate_inds, :].view(-1, 4)
            H, W = image_size
            # from 0..1 to 0..W, 0..H
            box = box * torch.Tensor([W, H, W, H]).to(box.device)[None, ...]
            # from xywh to xyxy
            box[:, :2] = box[:, :2] - box[:, 2:] / 2
            box[:, 2:] = box[:, 2:] + box[:, :2]

            boxlist = BoxList(box, (W, H), mode="xyxy")
            boxlist.add_field("labels", per_class)
            boxlist.add_field("scores", per_box_cls[:, 0])
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, min_size=0)
            results.append(boxlist)

        return results

    def load_query_bank(self, query_path):
        """Delegate loading of a vision-query bank to ``self.query_selector``."""
        self.query_selector.load_query_bank(query_path)

    @torch.no_grad()
    def extract_query(
        self,
        samples=None,
        targets=None,
        query_images=None,  # default_dict(list), list[tensors] num_classes: (num_queries, num_scales, num_channels)
        visual_features=None,
        exclude_similar=False,
        device=None,
        max_query_number=None,
    ):
        """Pool box features from images and add them to ``query_images``.

        Parameters:
            samples: ``ImageList``, list/tensor of images, or ``NestedTensor``;
                only used when ``visual_features`` is None (or to infer ``device``).
            targets: list of ``BoxList`` (``None`` entries are skipped); boxes are
                expanded by ``cfg.VISION_QUERY.EXPAND_RATIO`` before pooling.
            query_images: mapping from label id to the features collected so far;
                updated in place and returned.
            visual_features: optional precomputed feature maps (moved to ``device``).
            exclude_similar: skip a feature if the bank for its label already
                contains one whose cosine similarity exceeds
                ``cfg.VISION_QUERY.SIMILARITY_THRESHOLD``.
            device: target device; defaults to ``samples.tensors.device``.
            max_query_number: per-label cap; defaults to
                ``cfg.VISION_QUERY.MAX_QUERY_NUMBER``.

        Returns:
            The updated ``query_images`` mapping.
        """
        device = device if device else samples.tensors.device
        targets = [target.to(device) for target in targets if target is not None]
        targets = expand_bbox(targets, expand_ratio=self.cfg.VISION_QUERY.EXPAND_RATIO)

        if visual_features is None:
            # Was previously unbound for non-ImageList inputs (UnboundLocalError).
            # NOTE(review): assumes nested_tensor_from_tensor_list accepts
            # image_sizes=None - confirm.
            image_sizes = None
            if isinstance(samples, ImageList):
                image_sizes = samples.image_sizes
                samples = samples.tensors
            if isinstance(samples, (list, torch.Tensor)):
                samples = nested_tensor_from_tensor_list(samples, image_sizes=image_sizes)
            visual_features, _, _ = self._compute_multiscale_features(samples)
        else:
            visual_features = [v.to(device) for v in visual_features]

        if self.cfg.VISION_QUERY.SELECT_FPN_LEVEL:
            query_feats = self.pooler(visual_features, targets)  # num_boxes, num_channels, pooler_size, pooler_size
            query_feats = query_feats[None, ...]  # 1, num_boxes, num_channels, pooler_size, pooler_size
        else:
            query_feats = self.pooler(visual_features, targets)  # num_scales, num_boxes, num_channels, pooler_size, pooler_size
            assert len(visual_features) == len(query_feats) == 5  # TODO: support flexible level numbers

        # average over the spatial dimensions of each pooled feature
        query_feats = query_feats.mean(dim=[-2, -1]).permute(1, 0, 2)  # num_boxes, num_scales, num_channels
        labels = torch.cat([t.get_field('labels') for t in targets])
        assert len(labels) == len(query_feats)

        if max_query_number is None:
            max_query_number = self.cfg.VISION_QUERY.MAX_QUERY_NUMBER
        similarity_threshold = self.cfg.VISION_QUERY.SIMILARITY_THRESHOLD

        for label, feat in zip(labels, query_feats):
            label = label.item()
            num_queries = len(query_images[label])
            if num_queries >= max_query_number:
                continue

            if exclude_similar and num_queries > 0:
                assert feat.shape[0] == 1  # TODO: enable all-level and spacial features
                bank_features = F.normalize(query_images[label], p=2, dim=-1)  # N, 1, C
                new_features = F.normalize(feat, p=2, dim=-1)  # 1, C
                similarity = einsum('b n d, n d -> b n', bank_features, new_features)
                if (similarity > similarity_threshold).any():
                    continue

            if num_queries == 0:
                query_images[label] = feat[None, ...]
            else:
                query_images[label] = torch.cat([query_images[label], feat[None, ...]])
        return query_images

    def flatten_fpn_features(self, features):
        """Downsample each level with ``self.pool``, flatten space, and concatenate.

        The result is permuted so the concatenated spatial axis comes before the
        channel axis (for pre-selection in the language backbone).
        """
        pooled = [self.pool(f).flatten(-2, -1) for f in features]
        return torch.cat(pooled, dim=2).permute(0, 2, 1)

    @torch.no_grad()
    def get_labels_and_maps_from_positive_map(self, positive_map, dtype=torch.float):
        """Build a row-normalised label-to-token map (inference only).

        Parameters:
            positive_map: dict mapping a label to its token positions; labels
                with no positions are dropped.
            dtype: dtype of the returned map.

        Returns:
            ``(labels_in_caption, all_map)`` where ``all_map`` has one row per
            kept label and ``cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN`` columns,
            allocated on ``cfg.MODEL.DEVICE``.
        """
        # Only for inference
        labels_in_caption = [k for k, v in positive_map.items() if len(v) != 0]
        num_labels = len(labels_in_caption)
        all_map = torch.zeros(
            (num_labels, self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN),
            dtype=dtype,
            device=self.cfg.MODEL.DEVICE,
        )
        for j, label in enumerate(labels_in_caption):
            position = positive_map[label]
            all_map[j, position] = 1  # inplace
        # Each row sums to ~1; the epsilon avoids division by zero.
        all_map = all_map / (all_map.sum(-1)[:, None] + 1e-6)
        return labels_in_caption, all_map

    def forward(self, samples: NestedTensor, targets: List = None, **kw):
        """Run detection on a batch of images conditioned on text captions.

        Parameters:
            samples: ``ImageList``, list/tensor of images, or ``NestedTensor``
                (``samples.tensors``: batched images, ``samples.mask``: padding mask).
            targets: optional list of targets. When given, captions are read from
                their ``"caption"`` field; in training they also provide the
                vision-query fields and are passed to the loss.
            **kw: ``positive_map`` (required); ``captions`` (required when
                ``targets`` is None); ``return_backbone_features`` (optional,
                default False).

        Returns:
            Training mode: dict of losses from ``self.loss_evaluator`` (plus
            ``'loss_gate'`` when ``cfg.VISION_QUERY.ENABLED``).
            Eval mode: list of ``BoxList`` detections, or ``(detections, srcs)``
            when ``return_backbone_features`` is truthy.

        Raises:
            ValueError: in eval mode when ``samples`` was not an ``ImageList``,
                since image sizes are needed to rescale the boxes.
        """
        # Was previously unbound for non-ImageList inputs (UnboundLocalError).
        # NOTE(review): assumes nested_tensor_from_tensor_list accepts
        # image_sizes=None - confirm.
        image_sizes = None
        if isinstance(samples, ImageList):
            image_sizes = samples.image_sizes
            samples = samples.tensors
        if targets is None:
            captions = kw["captions"]
        else:
            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
        captions = [preprocess_caption(c) for c in captions]
        positive_map = kw['positive_map']
        return_backbone_features = kw.get('return_backbone_features', False)

        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples, image_sizes=image_sizes)

        srcs, masks, poss = self._compute_multiscale_features(samples)

        # query embedding
        if self.cfg.VISION_QUERY.ENABLED:
            if self.training:
                batched_labels_in_caption = [t.get_field('labels_in_caption') for t in targets]
                batched_all_map = [t.get_field('all_map') for t in targets]
                batched_pos_category_map = [t.get_field('positive_category_map') for t in targets]
                # BUG (noted by the original authors): batched_pos_category_map is not binary
                batched_pos_labels = [t.get_field('labels') for t in targets]
            else:
                assert samples.tensors.shape[0] == 1  # TODO: Only support batch size = 1 for test
                labels_in_caption, all_map = self.get_labels_and_maps_from_positive_map(
                    positive_map, dtype=srcs[0].dtype
                )
                batched_labels_in_caption = [labels_in_caption]
                batched_all_map = [all_map]
                batched_pos_category_map = None
                batched_pos_labels = None

            query_features, query_attetion_masks, batched_has_vision_query = self.query_selector(
                batched_labels_in_caption, batched_all_map, batched_pos_labels
            )
            vision_inputs_in_language_backbone = {
                'vision': query_features,
                'images': self.flatten_fpn_features(srcs),
                'vision_attention_mask': query_attetion_masks,
                'batched_pos_category_map': batched_pos_category_map,
            }
        else:
            vision_inputs_in_language_backbone = {
                'vision': None,
                'images': None,
                'vision_attention_mask': None,
                'batched_pos_category_map': None,
            }

        # encoder texts
        # assume each category is consist of its text tokens and one '.'
        # tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
        #     samples.device
        # )
        tokenized = self.tokenizer(captions, padding='max_length', return_tensors="pt").to(
            samples.device
        )
        (
            text_self_attention_masks,  # each category token only attend to its own category tokens and one '.'
            position_ids,  # [[0, 0, 1, 2, 0, 1, 0]]
            cate_to_token_mask_list,
        ) = generate_masks_with_special_tokens_and_transfer_map(
            tokenized, self.specical_tokens, self.tokenizer
        )

        if text_self_attention_masks.shape[1] > self.max_text_len:
            text_self_attention_masks = text_self_attention_masks[
                :, : self.max_text_len, : self.max_text_len
            ]
            position_ids = position_ids[:, : self.max_text_len]
            tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
            tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
            tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]

        # extract text embeddings
        if self.sub_sentence_present:
            tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
            tokenized_for_encoder["attention_mask"] = text_self_attention_masks
            tokenized_for_encoder["position_ids"] = position_ids
        else:
            # Same object as ``tokenized``: the update below also adds the
            # vision keys to it.
            tokenized_for_encoder = tokenized
        tokenized_for_encoder.update(vision_inputs_in_language_backbone)

        bert_output = self.bert(**tokenized_for_encoder)  # bs, 195, 768

        encoded_text = self.feat_map(bert_output["last_hidden_state"])  # bs, 195, d_model
        text_token_mask = tokenized.attention_mask.bool()  # bs, 195
        # text_token_mask: True for nomask, False for mask
        # text_self_attention_masks: True for nomask, False for mask

        if encoded_text.shape[1] > self.max_text_len:
            encoded_text = encoded_text[:, : self.max_text_len, :]
            text_token_mask = text_token_mask[:, : self.max_text_len]
            position_ids = position_ids[:, : self.max_text_len]
            text_self_attention_masks = text_self_attention_masks[
                :, : self.max_text_len, : self.max_text_len
            ]

        text_dict = {
            "encoded_text": encoded_text,  # bs, 195, d_model
            "text_token_mask": text_token_mask,  # bs, 195
            "position_ids": position_ids,  # bs, 195
            "text_self_attention_masks": text_self_attention_masks,  # bs, 195,195
        }

        input_query_bbox = input_query_label = attn_mask = dn_meta = None
        hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
            srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
        )

        # deformable-detr-like anchor update
        outputs_coord_list = []
        for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
            zip(reference[:-1], self.bbox_embed, hs)
        ):
            layer_delta_unsig = layer_bbox_embed(layer_hs)
            layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
            layer_outputs_unsig = layer_outputs_unsig.sigmoid()
            outputs_coord_list.append(layer_outputs_unsig)
        outputs_coord_list = torch.stack(outputs_coord_list)

        # output
        outputs_class = torch.stack(
            [
                layer_cls_embed(layer_hs, text_dict)
                for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
            ]
        )

        if self.training:
            out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)

            # Binarise the positive map before handing it to the criterion.
            positive_map_ = positive_map.clone().to(outputs_class[-1].device)
            positive_map_[positive_map_ > 0] = 1.

            # padding to max_text_len
            token_mask = text_dict["text_token_mask"]
            text_mask = torch.full(
                (*token_mask.shape[:-1], self.max_text_len), False, device=token_mask.device
            )
            text_mask[..., : token_mask.shape[-1]] = token_mask

            losses = self.loss_evaluator(out, targets, text_mask=text_mask, positive_map=positive_map_)

            if self.cfg.VISION_QUERY.ENABLED:
                #### gate loss #####
                # concatenate all gates
                gates = []
                for layer_gates in bert_output['vision_query_gates'].values():
                    gates = gates + layer_gates
                num_gates = len(gates)
                # Penalise gates whose magnitude is far from 1.
                loss_gate = 0
                for g in gates:
                    loss_gate = loss_gate + (1 - torch.abs(g[0]))
                loss_gate = self.cfg.VISION_QUERY.GATE_REGULARIZATION_SCALE * loss_gate / num_gates
                if self.cfg.VISION_QUERY.GATE_REGULARIZATION:
                    gate_losses = {'loss_gate': loss_gate.sum()}
                else:
                    # Only for analysis: detached so it does not contribute gradients.
                    gate_losses = {'loss_gate': loss_gate.sum().detach()}
                ####################
                losses.update(gate_losses)

            return losses

        if image_sizes is None:
            raise ValueError(
                "image sizes are required at inference; pass `samples` as an ImageList"
            )
        out = {"pred_logits": outputs_class[-1].sigmoid(), "pred_boxes": outputs_coord_list[-1]}
        result = self.convert_groundingdino_to_glip_output(out, positive_map, image_sizes)
        if return_backbone_features:
            return result, srcs
        return result

        # # for intermediate outputs
        # if self.aux_loss:
        #     out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)

        # # for encoder output
        # if hs_enc is not None:
        #     # prepare intermediate outputs
        #     interm_coord = ref_enc[-1]
        #     interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
        #     out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
        #     out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}

        # return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        """Pair class and box predictions of every decoder layer except the last."""
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_boxes": b}
            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
        ]


@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
def build_groundingdino(args, cfg):
    """Build a ``GroundingDINO`` model from ``args`` (architecture) and ``cfg`` (project config).

    ``aux_loss`` and ``iter_update`` are forced to True, ``query_dim`` to 4 and
    ``dn_number`` to 0; everything else is read from ``args``.
    """
    backbone = build_backbone(args)
    transformer = build_transformer(args)

    dn_labelbook_size = args.dn_labelbook_size
    dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
    sub_sentence_present = args.sub_sentence_present

    model = GroundingDINO(
        backbone,
        transformer,
        num_queries=args.num_queries,
        aux_loss=True,
        iter_update=True,
        query_dim=4,
        num_feature_levels=args.num_feature_levels,
        nheads=args.nheads,
        dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
        two_stage_type=args.two_stage_type,
        two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
        two_stage_class_embed_share=args.two_stage_class_embed_share,
        num_patterns=args.num_patterns,
        dn_number=0,
        dn_box_noise_scale=args.dn_box_noise_scale,
        dn_label_noise_ratio=args.dn_label_noise_ratio,
        dn_labelbook_size=dn_labelbook_size,
        text_encoder_type=args.text_encoder_type,
        sub_sentence_present=sub_sentence_present,
        max_text_len=args.max_text_len,
        cfg=cfg,
    )

    return model