parent aab23aea55
commit d4a0bf6701
@@ -190,8 +190,8 @@ def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer
     # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
     special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
     for special_token in special_tokens_list:
-        special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
-        #special_tokens_mask |= input_ids == special_token
+        # special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
+        special_tokens_mask |= input_ids == special_token

     # idxs: each row is a list of indices of special tokens
     idxs = torch.nonzero(special_tokens_mask)
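For context, the two mask-building forms are equivalent: `special_tokens_mask |= input_ids == special_token` performs the same element-wise OR as `torch.logical_or`, just in place on the boolean mask. A minimal sketch, with made-up token ids and special-token list purely for illustration:

```python
import torch

# Hypothetical inputs: batch of 2 captions, 6 tokens each (ids are illustrative only).
input_ids = torch.tensor([[101, 2158, 1012, 3899, 1012, 102],
                          [101, 4937, 1012,  102,    0,    0]])
special_tokens_list = [101, 102, 1012]  # e.g. [CLS], [SEP], "." in a BERT-style vocab

bs, num_token = input_ids.shape
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
for special_token in special_tokens_list:
    special_tokens_mask |= input_ids == special_token  # in-place OR (new form)

reference = torch.zeros_like(special_tokens_mask)
for special_token in special_tokens_list:
    reference = torch.logical_or(reference, input_ids == special_token)  # old form

assert torch.equal(special_tokens_mask, reference)
```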
@@ -235,8 +235,8 @@ def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_token
     # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
     special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
     for special_token in special_tokens_list:
-        special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
-        # special_tokens_mask |= input_ids == special_token
+        # special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
+        special_tokens_mask |= input_ids == special_token

     # idxs: each row is a list of indices of special tokens
     idxs = torch.nonzero(special_tokens_mask)
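The `idxs = torch.nonzero(special_tokens_mask)` line shared by both functions returns one row per special token as a `[batch_index, token_index]` pair. A small sketch with made-up values:

```python
import torch

# Toy mask continuing the example above: True marks [CLS]/[SEP]/"." positions.
special_tokens_mask = torch.tensor([[True, False, True, False, True, True],
                                    [True, False, True, True, False, False]])

idxs = torch.nonzero(special_tokens_mask)
# tensor([[0, 0], [0, 2], [0, 4], [0, 5],
#         [1, 0], [1, 2], [1, 3]])
# Each row is (batch_idx, token_idx); the helpers walk these pairs to build the
# block-diagonal text self-attention mask and the per-phrase position ids.
```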
@@ -227,9 +227,8 @@ class GroundingDINO(nn.Module):
     # def forward(self, samples: NestedTensor, targets: List = None, **kw):
     def forward(self,
                 samples: NestedTensor,
-                input_ids: Tensor,
+                last_hidden_state: Tensor,
                 attention_mask: Tensor,
-                token_type_ids: Tensor,
                 position_ids: Tensor,
                 text_self_attention_masks: Tensor,
                 **kw):
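With this signature change, tokenization and the BERT text encoder run outside the model, and `forward` receives precomputed text features instead of raw token ids. A hedged sketch of what a caller might do; the checkpoint name, special-token list, import path, and the `model`/`samples` objects are assumptions, not shown in this diff:

```python
import torch
from transformers import BertModel, BertTokenizer

# Helper from this repo (import path assumed).
from groundingdino.models.GroundingDINO.bertwarper import (
    generate_masks_with_special_tokens_and_transfer_map,
)

# Assumed text-encoder setup.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased").eval()
special_tokens = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])

captions = ["a cat . a dog ."]
tokenized = tokenizer(captions, padding="longest", return_tensors="pt")

# Block-diagonal text self-attention mask and per-phrase position ids
# (return order taken from the repo's helper).
text_self_attention_masks, position_ids, _ = \
    generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens, tokenizer)

with torch.no_grad():
    bert_output = bert(
        input_ids=tokenized["input_ids"],
        token_type_ids=tokenized["token_type_ids"],
        attention_mask=text_self_attention_masks,
        position_ids=position_ids,
    )

# `model` is a constructed GroundingDINO instance and `samples` a NestedTensor
# of batched images (both assumed available).
outputs = model(
    samples,
    last_hidden_state=bert_output.last_hidden_state,
    attention_mask=tokenized["attention_mask"],
    position_ids=position_ids,
    text_self_attention_masks=text_self_attention_masks,
)
```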
@@ -258,11 +257,11 @@ class GroundingDINO(nn.Module):
            samples.device
        )
        """
-        tokenized = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+        # tokenized = {
+        #     "input_ids": input_ids,
+        #     "attention_mask": attention_mask,
+        #     "token_type_ids": token_type_ids,
+        # }

        # (
        #     text_self_attention_masks,
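Since `forward` no longer rebuilds a `tokenized` dict, shape consistency is the caller's responsibility. Continuing the sketch above, a few hedged sanity checks (768 is the BERT-base hidden size, an assumption here; read it from the text encoder config in practice):

```python
# num_token is the padded caption length, e.g. 195 in the inline comments below.
bs, num_token = tokenized["attention_mask"].shape
assert bert_output.last_hidden_state.shape == (bs, num_token, 768)
assert position_ids.shape == (bs, num_token)
assert text_self_attention_masks.shape == (bs, num_token, num_token)
```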
@@ -281,22 +280,24 @@ class GroundingDINO(nn.Module):
        # tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
        # tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]

-        # extract text embeddings
-        if self.sub_sentence_present:
-            tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
-            tokenized_for_encoder["attention_mask"] = text_self_attention_masks
-            tokenized_for_encoder["position_ids"] = position_ids
-        else:
-            # import ipdb; ipdb.set_trace()
-            tokenized_for_encoder = tokenized
+        # # extract text embeddings
+        # if self.sub_sentence_present:
+        #     tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
+        #     tokenized_for_encoder["attention_mask"] = text_self_attention_masks
+        #     tokenized_for_encoder["position_ids"] = position_ids
+        # else:
+        #     # import ipdb; ipdb.set_trace()
+        #     tokenized_for_encoder = tokenized

-        bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
+        # bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768

-        encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
-        text_token_mask = tokenized["attention_mask"].bool() # bs, 195
+        # encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
+        # text_token_mask = tokenized["attention_mask"].bool() # bs, 195
        # text_token_mask = tokenized.attention_mask.bool() # bs, 195
        # text_token_mask: True for nomask, False for mask
        # text_self_attention_masks: True for nomask, False for mask
+        encoded_text = self.feat_map(last_hidden_state) # bs, 195, d_model
+        text_token_mask = attention_mask.bool() # bs, 195

        if encoded_text.shape[1] > self.max_text_len:
            encoded_text = encoded_text[:, : self.max_text_len, :]
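The new text path is a plain projection of the externally computed BERT features followed by the existing `max_text_len` truncation. A standalone sketch of that path, assuming `feat_map` is a 768→d_model linear layer (as is typical for this model) and using illustrative sizes:

```python
import torch
import torch.nn as nn

bs, num_token, bert_dim, d_model, max_text_len = 1, 195, 768, 256, 256

feat_map = nn.Linear(bert_dim, d_model)                    # stand-in for self.feat_map
last_hidden_state = torch.randn(bs, num_token, bert_dim)   # precomputed BERT output
attention_mask = torch.ones(bs, num_token, dtype=torch.long)

encoded_text = feat_map(last_hidden_state)   # bs, num_token, d_model
text_token_mask = attention_mask.bool()      # bs, num_token

# Same truncation as in forward(): clip overly long captions to max_text_len;
# the mask is clipped alongside in this sketch to keep shapes consistent.
if encoded_text.shape[1] > max_text_len:
    encoded_text = encoded_text[:, :max_text_len, :]
    text_token_mask = text_token_mask[:, :max_text_len]

print(encoded_text.shape, text_token_mask.shape)
```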