Add files via upload

两个阶段
pull/380/head
szsteven008 2025-01-06 16:14:12 +08:00 committed by GitHub
parent dea1577033
commit aab23aea55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 8 additions and 2 deletions

View File

@ -89,11 +89,17 @@ def predict(
tokenized["attention_mask"] = tokenized["attention_mask"][:, : max_text_len]
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : max_text_len]
tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
tokenized_for_encoder["attention_mask"] = text_self_attention_masks
tokenized_for_encoder["position_ids"] = position_ids
bert = model.bert
bert_output = bert(**tokenized_for_encoder) # bs, 195, 768
with torch.no_grad():
outputs = model(image.unsqueeze(0),
input_ids=tokenized["input_ids"],
last_hidden_state=bert_output["last_hidden_state"],
attention_mask=tokenized["attention_mask"],
token_type_ids=tokenized["token_type_ids"],
position_ids = position_ids,
text_self_attention_masks = text_self_attention_masks)
# outputs = model(image[None], captions=[caption])