parent
dea1577033
commit
aab23aea55
|
@ -89,11 +89,17 @@ def predict(
|
|||
tokenized["attention_mask"] = tokenized["attention_mask"][:, : max_text_len]
|
||||
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : max_text_len]
|
||||
|
||||
tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
|
||||
tokenized_for_encoder["attention_mask"] = text_self_attention_masks
|
||||
tokenized_for_encoder["position_ids"] = position_ids
|
||||
|
||||
bert = model.bert
|
||||
bert_output = bert(**tokenized_for_encoder) # bs, 195, 768
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(image.unsqueeze(0),
|
||||
input_ids=tokenized["input_ids"],
|
||||
last_hidden_state=bert_output["last_hidden_state"],
|
||||
attention_mask=tokenized["attention_mask"],
|
||||
token_type_ids=tokenized["token_type_ids"],
|
||||
position_ids = position_ids,
|
||||
text_self_attention_masks = text_self_attention_masks)
|
||||
# outputs = model(image[None], captions=[caption])
|
||||
|
|
Loading…
Reference in New Issue