Add files via upload

update: accept position_ids and text_self_attention_masks as explicit forward() arguments instead of computing them inside forward() via generate_masks_with_special_tokens_and_transfer_map
szsteven008 2025-01-01 16:23:11 +08:00 committed by GitHub
parent 89b1ad20b9
commit bd929612c4


@@ -230,6 +230,8 @@ class GroundingDINO(nn.Module):
         input_ids: Tensor,
         attention_mask: Tensor,
         token_type_ids: Tensor,
+        position_ids: Tensor,
+        text_self_attention_masks: Tensor,
         **kw):
         """The forward expects a NestedTensor, which consists of:
            - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
@@ -262,22 +264,22 @@ class GroundingDINO(nn.Module):
             "token_type_ids": token_type_ids,
         }
-        (
-            text_self_attention_masks,
-            position_ids,
-            cate_to_token_mask_list,
-        ) = generate_masks_with_special_tokens_and_transfer_map(
-            tokenized, self.specical_tokens, self.tokenizer
-        )
-        if text_self_attention_masks.shape[1] > self.max_text_len:
-            text_self_attention_masks = text_self_attention_masks[
-                :, : self.max_text_len, : self.max_text_len
-            ]
-            position_ids = position_ids[:, : self.max_text_len]
-            tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
-            tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
-            tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
+        # (
+        #     text_self_attention_masks,
+        #     position_ids,
+        #     cate_to_token_mask_list,
+        # ) = generate_masks_with_special_tokens_and_transfer_map(
+        #     tokenized, self.specical_tokens, self.tokenizer
+        # )
+        # if text_self_attention_masks.shape[1] > self.max_text_len:
+        #     text_self_attention_masks = text_self_attention_masks[
+        #         :, : self.max_text_len, : self.max_text_len
+        #     ]
+        #     position_ids = position_ids[:, : self.max_text_len]
+        #     tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
+        #     tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
+        #     tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]

         # extract text embeddings
         if self.sub_sentence_present:
@@ -292,7 +294,7 @@ class GroundingDINO(nn.Module):
         encoded_text = self.feat_map(bert_output["last_hidden_state"])  # bs, 195, d_model
         text_token_mask = tokenized["attention_mask"].bool()  # bs, 195
-        # text_token_mask = tokenizedattention_mask.bool() # bs, 195
+        # text_token_mask = tokenized.attention_mask.bool() # bs, 195
         # text_token_mask: True for nomask, False for mask
         # text_self_attention_masks: True for nomask, False for mask
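
Net effect of the hunks above: forward() no longer derives the text self-attention masks and position ids itself; the caller must build them (including the max_text_len truncation) before invoking the model. Below is a minimal caller-side sketch that mirrors the commented-out logic. It assumes the upstream GroundingDINO layout: the helper lives in groundingdino.models.GroundingDINO.bertwarper, the model exposes tokenizer, specical_tokens (the upstream spelling), and max_text_len, and `model`/`samples` are prepared as usual; the caption is hypothetical.

    # Caller-side preprocessing sketch (assumed names: `model` is a constructed
    # GroundingDINO instance, `samples` a NestedTensor of batched images).
    from groundingdino.models.GroundingDINO.bertwarper import (
        generate_masks_with_special_tokens_and_transfer_map,
    )

    captions = ["a cat . a dog ."]  # hypothetical prompt
    tokenized = model.tokenizer(captions, padding="longest", return_tensors="pt")

    # Build the per-phrase self-attention masks and position ids that
    # forward() used to compute internally.
    (
        text_self_attention_masks,
        position_ids,
        cate_to_token_mask_list,  # unused here; kept for parity with the old code
    ) = generate_masks_with_special_tokens_and_transfer_map(
        tokenized, model.specical_tokens, model.tokenizer
    )

    # Truncate to max_text_len, exactly as the removed in-forward block did.
    if text_self_attention_masks.shape[1] > model.max_text_len:
        text_self_attention_masks = text_self_attention_masks[
            :, : model.max_text_len, : model.max_text_len
        ]
        position_ids = position_ids[:, : model.max_text_len]
        for k in ("input_ids", "attention_mask", "token_type_ids"):
            tokenized[k] = tokenized[k][:, : model.max_text_len]

    outputs = model(
        samples,
        input_ids=tokenized["input_ids"],
        attention_mask=tokenized["attention_mask"],
        token_type_ids=tokenized["token_type_ids"],
        position_ids=position_ids,
        text_self_attention_masks=text_self_attention_masks,
    )

Hoisting this tokenizer-driven, data-dependent control flow out of forward() is the usual move when preparing a module for tracing/export (e.g. ONNX); the commit message states no motivation, so treat that reading as a guess.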