# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import mmengine
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS

@MODELS.register_module()
class VQAGenerationHead(BaseModule):
    """Generation head for multi-modal pre-trained tasks, adapted from BLIP.

    Normally used for open-set QA (answer generation) tasks.

    Args:
        decoder (dict): Decoder for decoding answers.
        inference_method (str): Inference method. One of 'rank', 'generate'.
            - If 'rank', the model will return answers with the highest
                probability from the answer list.
            - If 'generate', the model will generate answers.
            - Only for test, not for train / val.
        num_beams (int): Number of beams for beam search. 1 means no beam
            search. Only supported when inference_method=='generate'.
            Defaults to 3.
        num_ans_candidates (int): Number of answer candidates, used to filter
            out answers with low probability. Only supported when
            inference_method=='rank'. Defaults to 128.
        loss (dict or nn.Module): Config of loss or module of loss. Defaults to
            ``nn.CrossEntropyLoss(reduction='none', ignore_index=-100)``.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
        answer_list_path (str, optional): Path to `answer_list.json`
            (a json file containing the answer list). Required when
            inference_method=='rank'.

    TODO: `mmcls.LabelSmoothLoss` does not support the `ignore_index` param.
    Now using `nn.CrossEntropyLoss`, without label smoothing, in order to
    maintain compatibility with torch < 1.10.0.
    """

    def __init__(
        self,
        decoder: dict,
        inference_method: str = 'generate',
        num_beams: int = 3,
        num_ans_candidates: int = 128,
        loss: Union[dict, nn.Module] = nn.CrossEntropyLoss(
            reduction='none', ignore_index=-100),
        init_cfg: Optional[dict] = None,
        answer_list_path: Optional[str] = None,
    ) -> None:

        super(VQAGenerationHead, self).__init__(init_cfg=init_cfg)
        self.decoder = MODELS.build(decoder)

        if inference_method == 'generate':
            assert isinstance(num_beams, int), \
                'for VQA `generate` mode, `num_beams` must be an int.'
            self.num_beams = num_beams
            self.num_ans_candidates = None
            self.answer_list = None

        elif inference_method == 'rank':
            assert isinstance(num_ans_candidates, int), \
                'for VQA `rank` mode, `num_ans_candidates` must be an int.'
            assert isinstance(answer_list_path, str), \
                'for VQA `rank` mode, `answer_list_path` must be set as ' \
                'the path to `answer_list.json`.'
            self.num_beams = None
            self.answer_list = mmengine.load(answer_list_path)
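            # `answer_list.json` may be either a list of answer strings or a
            # dict whose keys are the answer strings; only the keys are used
            # in the latter case.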
            if isinstance(self.answer_list, dict):
                self.answer_list = list(self.answer_list.keys())
            assert isinstance(self.answer_list, list) and all(
                isinstance(item, str) for item in self.answer_list), \
                'for VQA `rank` mode, `answer_list.json` must be a list of str'
            self.num_ans_candidates = min(num_ans_candidates,
                                          len(self.answer_list))

        else:
            raise AssertionError(
                'for VQA, `inference_method` must be "generate" or "rank", '
                'got {}.'.format(inference_method))

        self.inference_method = inference_method
        if not isinstance(loss, nn.Module):
            loss = MODELS.build(loss)
        self.loss_module = loss

    def forward(self, feats: dict):
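        """Forward the decoder to get prediction logits over answer tokens.

        ``return_logits=True`` makes the decoder return token-level
        prediction scores rather than computing a loss; :meth:`loss` consumes
        these scores (as ``shifted_prediction_scores``) together with the
        one-position-shifted answer targets for next-token prediction.
        """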
        prediction_logits = self.decoder(
            feats['answer_input_ids'],
            attention_mask=feats['answer_attention_mask'],
            encoder_hidden_states=feats['question_states'],
            encoder_attention_mask=feats['question_atts'],
            labels=feats['answer_targets'],
            return_dict=True,
            return_logits=True,  # directly return logits, not computing loss
            reduction='none',
        )
        return prediction_logits

    def loss(self, feats: dict, data_samples=None):
        """Calculate losses from the extracted features.

        Args:
            feats (dict): The features extracted from the backbone.
            data_samples (List[BaseDataElement]): The annotation data of
                every sample.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        shifted_prediction_scores = self(feats)
        labels = feats['answer_targets']
        lm_loss = None

        # we are doing next-token prediction;
        # shift prediction scores and input ids by one
        labels = labels[:, 1:].contiguous()
        lm_loss = self.loss_module(
            shifted_prediction_scores.view(-1,
                                           self.decoder.med_config.vocab_size),
            labels.view(-1))
        lm_loss = lm_loss.view(shifted_prediction_scores.size(0), -1).sum(1)
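        # `lm_loss` now holds one summed token-level loss per answer sequence;
        # `answer_weight` typically carries soft labels when a question has
        # several ground-truth answers, and the weighted sum is normalized by
        # the original batch size.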
        # compute weighted loss
        losses = dict()
        loss = feats['answer_weight'] * lm_loss
        loss = loss.sum() / feats['batch_size']
        losses['vqa_loss'] = loss

        return losses

    def predict_rank(self, feats: dict, data_samples=None):
        """Predict answers by ranking a closed-set answer list."""
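        # Answer ranking as done in BLIP:
        # 1) decode a single step from the start token to get the probability
        #    of each candidate answer's first token;
        # 2) keep the `num_ans_candidates` most likely candidates per question;
        # 3) re-score every kept candidate with the full language-modeling
        #    loss and return the candidate with the highest total
        #    log-probability.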
        question_states = feats['multimodal_embeds']
        question_atts = feats['question_atts']
        answer_candidates = feats['answer_candidates']
        assert answer_candidates is not None

        answer_ids = answer_candidates.input_ids
        answer_atts = answer_candidates.attention_mask
        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.decoder(
            start_ids,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            return_dict=True,
            reduction='none',
        )
        logits = start_output.logits[:, 0, :]  # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(
            logits, dim=1).index_select(
                dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(
            self.num_ans_candidates, dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)

        targets_ids = input_ids.masked_fill(input_ids == feats['pad_token_id'],
                                            -100)

        def tile(x, dim, n_tile):
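            """Repeat each slice of ``x`` ``n_tile`` times along ``dim``.

            The interleaved ordering is equivalent to
            ``torch.repeat_interleave(x, n_tile, dim)``, so question i stays
            aligned with its own top-k candidate answers after tiling.
            """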
            init_dim = x.size(dim)
            repeat_idx = [1] * x.dim()
            repeat_idx[dim] = n_tile
            x = x.repeat(*(repeat_idx))
            order_index = torch.LongTensor(
                np.concatenate([
                    init_dim * np.arange(n_tile) + i for i in range(init_dim)
                ]))
            return torch.index_select(x, dim, order_index.to(x.device))

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, self.num_ans_candidates)
        question_atts = tile(question_atts, 0, self.num_ans_candidates)

        output = self.decoder(
            input_ids,
            attention_mask=input_atts,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            labels=targets_ids,
            return_dict=True,
            reduction='none',
        )
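
        # with reduction='none' the decoder is expected to return one summed
        # token-level loss per candidate sequence; its negative is the total
        # log-probability of that candidate answer.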
        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques, self.num_ans_candidates)

        max_topk_ids = log_probs_sum.argmax(dim=1)
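        # gather, for every question, the answer-list index of its best
        # candidate: the boolean mask is always True, so this simply selects
        # topk_ids[i, max_topk_ids[i]] for each question i.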
        max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]

        answers = [self.answer_list[max_id] for max_id in max_ids]

        return answers

    def predict_generate(self, feats: dict, data_samples=None):
        """Predict answers in a generation manner."""
        device = feats['multimodal_embeds'].device
        question_states = feats['multimodal_embeds']
        question_atts = torch.ones(
            question_states.size()[:-1], dtype=torch.long).to(device)
        model_kwargs = {
            'encoder_hidden_states': question_states,
            'encoder_attention_mask': question_atts
        }

        bos_ids = torch.full((feats['multimodal_embeds'].shape[0], 1),
                             fill_value=feats['bos_token_id'],
                             device=device)
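
        # beam-search decoding starting from the bos token; the multimodal
        # question states are injected through cross-attention via
        # `model_kwargs`, and decoding stops at the sep (eos) token.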
        outputs = self.decoder.generate(
            input_ids=bos_ids,
            max_length=10,
            min_length=1,
            num_beams=self.num_beams,
            eos_token_id=feats['sep_token_id'],
            pad_token_id=feats['pad_token_id'],
            **model_kwargs)

        return outputs

    def predict(self, feats: dict, data_samples=None):
        """Predict results from the extracted features."""
        if self.inference_method == 'generate':
            return self.predict_generate(feats, data_samples)
        elif self.inference_method == 'rank':
            return self.predict_rank(feats, data_samples)