parent aab23aea55
commit d4a0bf6701

@@ -190,8 +190,8 @@ def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer
     # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
     special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
     for special_token in special_tokens_list:
-        special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
-        #special_tokens_mask |= input_ids == special_token
+        # special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
+        special_tokens_mask |= input_ids == special_token
 
     # idxs: each row is a list of indices of special tokens
     idxs = torch.nonzero(special_tokens_mask)
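Note on this change (and the identical one in the next hunk): for bool tensors, the in-place |= form accumulates exactly the same mask as the torch.logical_or rebinding it replaces. A minimal standalone sketch, with made-up token ids and shapes:

import torch

# toy batch: 2 sequences of 6 token ids (values are arbitrary, for illustration only)
input_ids = torch.tensor([[101, 7592, 1012, 102, 0, 0],
                          [101, 3899, 1029, 102, 0, 0]])
special_tokens_list = [101, 102, 1012, 1029]  # assumed ids for [CLS], [SEP], ".", "?"
bs, num_token = input_ids.shape

# old form: rebind the mask through torch.logical_or on every iteration
mask_old = torch.zeros((bs, num_token), device=input_ids.device).bool()
for special_token in special_tokens_list:
    mask_old = torch.logical_or(mask_old, input_ids == special_token)

# new form: in-place OR on the same bool tensor
mask_new = torch.zeros((bs, num_token), device=input_ids.device).bool()
for special_token in special_tokens_list:
    mask_new |= input_ids == special_token

assert torch.equal(mask_old, mask_new)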
@@ -235,8 +235,8 @@ def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_token
     # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
     special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
     for special_token in special_tokens_list:
-        special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
-        # special_tokens_mask |= input_ids == special_token
+        # special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
+        special_tokens_mask |= input_ids == special_token
 
     # idxs: each row is a list of indices of special tokens
     idxs = torch.nonzero(special_tokens_mask)
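Both patched helpers scan input_ids against a list of special-token ids. A hedged sketch of how such inputs are typically prepared with a BERT-style tokenizer (the checkpoint name and the exact special tokens are assumptions, not taken from this diff):

from transformers import AutoTokenizer

# Assumed setup: derive the special-token ids once and reuse them for both helpers.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
special_tokens_list = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])

captions = ["a cat . a remote control ."]
tokenized = tokenizer(captions, padding="longest", return_tensors="pt")

# tokenized["input_ids"] is what the patched loops above scan, e.g. via
# generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer)
print(tokenized["input_ids"].shape, special_tokens_list)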
@@ -227,9 +227,8 @@ class GroundingDINO(nn.Module):
     # def forward(self, samples: NestedTensor, targets: List = None, **kw):
     def forward(self,
                 samples: NestedTensor,
-                input_ids: Tensor,
+                last_hidden_state: Tensor,
                 attention_mask: Tensor,
-                token_type_ids: Tensor,
                 position_ids: Tensor,
                 text_self_attention_masks: Tensor,
                 **kw):
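With last_hidden_state added to the forward signature, the text backbone appears to be meant to run before forward is called, e.g. by the export or inference driver. A hedged sketch of that precompute step, using a plain Hugging Face BERT as a stand-in for the model's own text encoder (checkpoint name, prompt, and the use of the padding mask instead of the sub-sentence masks are all assumptions):

import torch
from transformers import BertModel, BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased").eval()

tokenized = tokenizer(["a cat . a remote control ."], padding="longest", return_tensors="pt")
with torch.no_grad():
    # the real pipeline would pass the sub-sentence attention masks / position ids
    # produced by the mask helpers; the plain padding mask is used here for brevity
    bert_output = bert(**tokenized)

last_hidden_state = bert_output.last_hidden_state  # bs, num_token, 768
attention_mask = tokenized["attention_mask"]

# last_hidden_state and attention_mask (plus the images, position_ids and
# text_self_attention_masks) are what the patched forward now consumes
# instead of raw token ids.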
@@ -258,11 +257,11 @@ class GroundingDINO(nn.Module):
             samples.device
         )
         """
-        tokenized = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+        # tokenized = {
+        #     "input_ids": input_ids,
+        #     "attention_mask": attention_mask,
+        #     "token_type_ids": token_type_ids,
+        # }
 
         # (
         #     text_self_attention_masks,
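The samples.device / ) / """ context lines above appear to be the tail of the in-model tokenization that was already fenced off with a triple-quoted string; this hunk now disables the tokenized dict that repackaged the forward arguments for self.bert as well. A hedged reconstruction of that disabled tokenization, using a standalone tokenizer in place of self.tokenizer so the fragment runs on its own:

import torch
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")  # stands in for self.tokenizer
device = torch.device("cpu")                                        # stands in for samples.device

captions = ["a cat . a remote control ."]
tokenized = tokenizer(captions, padding="longest", return_tensors="pt").to(device)

# The dict commented out in this hunk carried exactly these three tensors to BERT:
print({k: tuple(v.shape) for k, v in tokenized.items()
       if k in ("input_ids", "attention_mask", "token_type_ids")})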
@@ -281,22 +280,24 @@ class GroundingDINO(nn.Module):
         # tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
         # tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
 
-        # extract text embeddings
-        if self.sub_sentence_present:
-            tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
-            tokenized_for_encoder["attention_mask"] = text_self_attention_masks
-            tokenized_for_encoder["position_ids"] = position_ids
-        else:
-            # import ipdb; ipdb.set_trace()
-            tokenized_for_encoder = tokenized
+        # # extract text embeddings
+        # if self.sub_sentence_present:
+        #     tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
+        #     tokenized_for_encoder["attention_mask"] = text_self_attention_masks
+        #     tokenized_for_encoder["position_ids"] = position_ids
+        # else:
+        #     # import ipdb; ipdb.set_trace()
+        #     tokenized_for_encoder = tokenized
 
-        bert_output = self.bert(**tokenized_for_encoder)  # bs, 195, 768
+        # bert_output = self.bert(**tokenized_for_encoder)  # bs, 195, 768
 
-        encoded_text = self.feat_map(bert_output["last_hidden_state"])  # bs, 195, d_model
-        text_token_mask = tokenized["attention_mask"].bool()  # bs, 195
+        # encoded_text = self.feat_map(bert_output["last_hidden_state"])  # bs, 195, d_model
+        # text_token_mask = tokenized["attention_mask"].bool()  # bs, 195
         # text_token_mask = tokenized.attention_mask.bool()  # bs, 195
         # text_token_mask: True for nomask, False for mask
         # text_self_attention_masks: True for nomask, False for mask
+        encoded_text = self.feat_map(last_hidden_state)  # bs, 195, d_model
+        text_token_mask = attention_mask.bool()  # bs, 195
 
         if encoded_text.shape[1] > self.max_text_len:
             encoded_text = encoded_text[:, : self.max_text_len, :]
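Net effect of this hunk: encoded_text and text_token_mask now come from the precomputed BERT output and the tokenizer's padding mask instead of an in-model self.bert call. A self-contained sketch of the new data flow with dummy tensors (the 768/256 sizes, the nn.Linear stand-in for feat_map, and the mask truncation line are assumptions drawn from the upstream code, not part of this diff):

import torch
import torch.nn as nn

bs, num_token, bert_dim, d_model, max_text_len = 2, 195, 768, 256, 256
feat_map = nn.Linear(bert_dim, d_model)                          # stands in for self.feat_map

last_hidden_state = torch.randn(bs, num_token, bert_dim)         # precomputed outside forward
attention_mask = torch.ones(bs, num_token, dtype=torch.long)     # tokenizer padding mask

# lines added by this hunk
encoded_text = feat_map(last_hidden_state)  # bs, num_token, d_model
text_token_mask = attention_mask.bool()     # True for real tokens (nomask), False for padding

# truncation guard shown as context at the end of the hunk
if encoded_text.shape[1] > max_text_len:
    encoded_text = encoded_text[:, :max_text_len, :]
    text_token_mask = text_token_mask[:, :max_text_len]  # assumed companion line from upstream

print(encoded_text.shape, text_token_mask.shape)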