Add files via upload

Two stages
pull/380/head
szsteven008 2025-01-06 16:21:54 +08:00 committed by GitHub
parent aab23aea55
commit d4a0bf6701
2 changed files with 23 additions and 22 deletions


@@ -190,8 +190,8 @@ def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer
     # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
     special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
     for special_token in special_tokens_list:
-        special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
-        #special_tokens_mask |= input_ids == special_token
+        # special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
+        special_tokens_mask |= input_ids == special_token
     # idxs: each row is a list of indices of special tokens
     idxs = torch.nonzero(special_tokens_mask)
@@ -235,8 +235,8 @@ def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_token
     # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
     special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
     for special_token in special_tokens_list:
-        special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
-        # special_tokens_mask |= input_ids == special_token
+        # special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
+        special_tokens_mask |= input_ids == special_token
     # idxs: each row is a list of indices of special tokens
     idxs = torch.nonzero(special_tokens_mask)
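Both hunks make the same substitution: on a boolean tensor, the in-place `|=` operator performs the same element-wise OR as torch.logical_or, without allocating a fresh mask on every loop iteration. A minimal sketch of the equivalence, with made-up token ids and special-token list (illustrative values, not taken from the commit):

import torch

# Toy inputs (hypothetical): batch of 2 captions, 6 tokens each.
input_ids = torch.tensor([[101, 7592, 1012,  102,    0,    0],
                          [101, 4937, 1012, 1029,  102,    0]])
special_tokens_list = [101, 102, 1012, 1029]  # e.g. [CLS], [SEP], ".", "?"

mask_old = torch.zeros_like(input_ids, dtype=torch.bool)
mask_new = torch.zeros_like(input_ids, dtype=torch.bool)
for special_token in special_tokens_list:
    # Old form: builds a new tensor each iteration.
    mask_old = torch.logical_or(mask_old, input_ids == special_token)
    # New form: in-place element-wise OR on the boolean mask.
    mask_new |= input_ids == special_token

assert torch.equal(mask_old, mask_new)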


@@ -227,9 +227,8 @@ class GroundingDINO(nn.Module):
     # def forward(self, samples: NestedTensor, targets: List = None, **kw):
     def forward(self,
                 samples: NestedTensor,
-                input_ids: Tensor,
+                last_hidden_state: Tensor,
                 attention_mask: Tensor,
-                token_type_ids: Tensor,
                 position_ids: Tensor,
                 text_self_attention_masks: Tensor,
                 **kw):
@@ -258,11 +257,11 @@ class GroundingDINO(nn.Module):
             samples.device
         )
         """
-        tokenized = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+        # tokenized = {
+        #     "input_ids": input_ids,
+        #     "attention_mask": attention_mask,
+        #     "token_type_ids": token_type_ids,
+        # }
         # (
         #     text_self_attention_masks,
@@ -281,22 +280,24 @@ class GroundingDINO(nn.Module):
         # tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
         # tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
-        # extract text embeddings
-        if self.sub_sentence_present:
-            tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
-            tokenized_for_encoder["attention_mask"] = text_self_attention_masks
-            tokenized_for_encoder["position_ids"] = position_ids
-        else:
-            # import ipdb; ipdb.set_trace()
-            tokenized_for_encoder = tokenized
+        # # extract text embeddings
+        # if self.sub_sentence_present:
+        #     tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
+        #     tokenized_for_encoder["attention_mask"] = text_self_attention_masks
+        #     tokenized_for_encoder["position_ids"] = position_ids
+        # else:
+        #     # import ipdb; ipdb.set_trace()
+        #     tokenized_for_encoder = tokenized
-        bert_output = self.bert(**tokenized_for_encoder)  # bs, 195, 768
+        # bert_output = self.bert(**tokenized_for_encoder)  # bs, 195, 768
-        encoded_text = self.feat_map(bert_output["last_hidden_state"])  # bs, 195, d_model
-        text_token_mask = tokenized["attention_mask"].bool()  # bs, 195
+        # encoded_text = self.feat_map(bert_output["last_hidden_state"])  # bs, 195, d_model
+        # text_token_mask = tokenized["attention_mask"].bool()  # bs, 195
         # text_token_mask = tokenized.attention_mask.bool()  # bs, 195
         # text_token_mask: True for nomask, False for mask
         # text_self_attention_masks: True for nomask, False for mask
+        encoded_text = self.feat_map(last_hidden_state)  # bs, 195, d_model
+        text_token_mask = attention_mask.bool()  # bs, 195
         if encoded_text.shape[1] > self.max_text_len:
             encoded_text = encoded_text[:, : self.max_text_len, :]
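Taken together, the hunks in the second file split inference into the two stages named in the commit message: forward no longer tokenizes or runs BERT itself, and instead expects the caller to supply the BERT last_hidden_state, the padding attention_mask, position_ids, and text_self_attention_masks. Below is a rough caller-side sketch of that flow. It assumes the upstream GroundingDINO package layout for the bertwarper helper, uses a plain Hugging Face BertModel in place of the model's internal text encoder, and treats the special-token list, `model`, and `samples` as placeholders; none of these names are fixed by the commit itself.

import torch
from transformers import BertModel, BertTokenizerFast

# Assumed import path (upstream GroundingDINO layout) for the helper
# changed in the first file of this commit.
from groundingdino.models.GroundingDINO.bertwarper import (
    generate_masks_with_special_tokens_and_transfer_map,
)

# Stage 1: tokenize the caption and run the text encoder outside the detector.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased").eval()

captions = ["cat . dog ."]
tokenized = tokenizer(captions, padding="longest", return_tensors="pt")
# Assumed special-token choice; the upstream model uses [CLS], [SEP], "." and "?".
special_tokens = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])

(
    text_self_attention_masks,   # bs, num_token, num_token
    position_ids,                # bs, num_token
    _cate_to_token_mask_list,
) = generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens, tokenizer)

with torch.no_grad():
    bert_output = bert(
        input_ids=tokenized["input_ids"],
        token_type_ids=tokenized["token_type_ids"],
        position_ids=position_ids,
        # Per-phrase self-attention mask, as in the commented-out in-model path.
        attention_mask=text_self_attention_masks,
    )

# Stage 2: the detector's new forward signature consumes the pre-computed text
# features. `model` (a built GroundingDINO instance) and `samples` (its image
# NestedTensor) are placeholders here:
#
# outputs = model(
#     samples,
#     last_hidden_state=bert_output["last_hidden_state"],  # bs, num_token, 768
#     attention_mask=tokenized["attention_mask"],           # bs, num_token
#     position_ids=position_ids,
#     text_self_attention_masks=text_self_attention_masks,
# )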