Ma Zerun 6847d20d57
[Feature] Support multiple multi-modal algorithms and inferencers. (#1561)
* [Feat] Migrate blip caption to mmpretrain. (#50)

* Migrate blip caption to mmpretrain

* minor fix

* support train

* [Feature] Support OFA caption task. (#51)

* [Feature] Support OFA caption task.

* Remove duplicated files.

* [Feature] Support OFA vqa task. (#58)

* [Feature] Support OFA vqa task.

* Fix lint.

* [Feat] Add BLIP retrieval to mmpretrain. (#55)

* init

* minor fix for train

* fix according to comments

* refactor

* Update Blip retrieval. (#62)

* [Feature] Support OFA visual grounding task. (#59)

* [Feature] Support OFA visual grounding task.

* minor add TODO

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* [Feat] Add flamingos coco caption and vqa. (#60)

* first init

* init flamingo coco

* add vqa

* minor fix

* remove unnecessary modules

* Update config

* Use `ApplyToList`.

---------

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Feature]: BLIP2 coco retrieval  (#53)

* [Feature]: Add blip2 retriever

* [Feature]: Add blip2 all modules

* [Feature]: Refine model

* [Feature]: x1

* [Feature]: Runnable coco ret

* [Feature]: Runnable version

* [Feature]: Fix lint

* [Fix]: Fix lint

* [Feature]: Use 364 img size

* [Feature]: Refactor blip2

* [Fix]: Fix lint

* refactor files

* minor fix

* minor fix

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* Remove

* fix blip caption inputs (#68)

* [Feat] Add BLIP NLVR support. (#67)

* first init

* init flamingo coco

* add vqa

* add nlvr

* refactor nlvr

* minor fix

* minor fix

* Update dataset

---------

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Feature]: BLIP2 Caption (#70)

* [Feature]: Add language model

* [Feature]: blip2 caption forward

* [Feature]: Reproduce the results

* [Feature]: Refactor caption

* refine config

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* [Feat] Migrate BLIP VQA to mmpretrain (#69)

* reformat

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* refactor code

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* Update RefCOCO dataset

* [Fix] fix lint

* [Feature] Implement inference APIs for multi-modal tasks. (#65)

* [Feature] Implement inference APIs for multi-modal tasks.

* [Project] Add gradio demo.

* [Improve] Update requirements

* Update flamingo

* Update blip

* Add NLVR inferencer

* Update flamingo

* Update hugging face model register

* Update ofa vqa

* Update BLIP-vqa (#71)

* Update blip-vqa docstring (#72)

* Refine flamingo docstring (#73)

* [Feature]: BLIP2 VQA (#61)

* [Feature]: VQA forward

* [Feature]: Reproduce accuracy

* [Fix]: Fix lint

* [Fix]: Add blank line

* minor fix

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* [Feature]: BLIP2 docstring (#74)

* [Feature]: Add caption docstring

* [Feature]: Add docstring to blip2 vqa

* [Feature]: Add docstring to retrieval

* Update BLIP-2 metafile and README (#75)

* [Feature]: Add readme and docstring

* Update blip2 results

---------

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Feature] BLIP Visual Grounding on MMPretrain Branch (#66)

* blip grounding merge with mmpretrain

* remove commit

* blip grounding test and inference api

* refcoco dataset

* refcoco dataset refine config

* rebasing

* gitignore

* rebasing

* minor edit

* minor edit

* Update blip-vqa docstring (#72)

* rebasing

* Revert "minor edit"

This reverts commit 639cec757c215e654625ed0979319e60f0be9044.

* blip grounding final

* precommit

* refine config

* refine config

* Update blip visual grounding

---------

Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com>
Co-authored-by: mzr1996 <mzr1996@163.com>

* Update visual grounding metric

* Update OFA docstring, README and metafiles. (#76)

* [Docs] Update installation docs and gradio demo docs. (#77)

* Update OFA name

* Update Visual Grounding Visualizer

* Integrate accelerate support

* Fix imports.

* Fix timm backbone

* Update imports

* Update README

* Update circle ci

* Update flamingo config

* Add gradio demo README

* [Feature]: Add scienceqa (#1571)

* [Feature]: Add scienceqa

* [Feature]: Change param name

* Update docs

* Update video

---------

Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com>
Co-authored-by: yingfhu <yingfhu@gmail.com>
Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com>
Co-authored-by: Rongjie Li <limo97@163.com>
2023-05-19 16:50:04 +08:00

717 lines
28 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from collections import ChainMap
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union
import mmengine.dist as dist
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModel
from torch import distributed as torch_dist
from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from mmpretrain.utils import track_on_main_process
def all_gather_concat(data: torch.Tensor) -> torch.Tensor:
"""Gather tensors with different first-dimension size and concat to one
tenosr.
Note:
Only the first dimension should be different.
Args:
data (Tensor): Tensor to be gathered.
Returns:
torch.Tensor: The concatenated tenosr.
"""
if dist.get_world_size() == 1:
return data
data_size = torch.tensor(data.size(0), device=data.device)
sizes_list = dist.all_gather(data_size)
max_length = max(sizes_list)
size_diff = max_length.item() - data_size.item()
if size_diff:
padding = torch.zeros(
size_diff, *data.size()[1:], device=data.device, dtype=data.dtype)
data = torch.cat((data, padding))
gather_list = dist.all_gather(data)
all_data = []
for tensor, size in zip(gather_list, sizes_list):
all_data.append(tensor[:size])
return torch.concat(all_data)
@MODELS.register_module()
class BlipRetrieval(BaseModel):
"""BLIP Retriever.
Args:
vision_backbone (dict): Backbone for extracting image features.
text_backbone (dict): Backbone for extracting text features.
multimodal_backbone (Optional[dict]): Backbone for extracting
multi-modal features.
vision_neck (Optional[dict]): The neck module to process image features
from vision backbone. Defaults to None.
text_neck (Optional[dict]): The neck module to process text features
from text backbone. Defaults to None.
head (Optional[Union[List[dict], dict]]): The head module to calculate
loss from processed single modality features.
See :mod:`mmmultimodal.models.heads`.
Notice that if the head is not set, `loss` method cannot be used.
Defaults to None.
multimodal_head (Optional[Union[List[dict], dict]]): The multi-modal
head module to calculate loss from processed multimodal features.
See :mod:`mmmultimodal.models.heads`.
Notice that if the head is not set, `loss` method cannot be used.
Defaults to None.
momentum (float): Momentum used for momentum contrast.
Defaults to .995.
negative_all_rank (bool): Whether to sample negative data from all
ranks for image text matching in training. Defaults to True.
temperature (float): Temperature parameter that controls the
concentration level of the distribution. Defaults to 0.07.
fast_match (bool): If False, select topk similarity as candidates and
compute the matching score. If True, return the similarity as the
matching score directly. Defaults to False.
topk (int): Select topk similarity as candidates for compute matching
scores. Notice that this is not the topk in evaluation.
Defaults to 256.
data_preprocessor (Optional[dict]): The config for preprocessing input
data. If None or no specified type, it will use
"MutimodalDataPreprocessor" as type.
See :class:`MutimodalDataPreprocessor` for more details.
Defaults to None.
init_cfg (Optional[dict]): the config to control the initialization.
Defaults to None.
"""
def __init__(self,
vision_backbone: dict,
text_backbone: dict,
multimodal_backbone: Optional[dict] = None,
vision_neck: Optional[dict] = None,
text_neck: Optional[dict] = None,
head: Optional[Union[List[dict], dict]] = None,
multimodal_head: Optional[Union[List[dict], dict]] = None,
tokenizer: Optional[dict] = None,
momentum: float = .995,
negative_all_rank: bool = True,
temperature: float = 0.07,
fast_match: bool = False,
topk: int = 256,
max_txt_len: int = 20,
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None):
if data_preprocessor is None:
data_preprocessor = {}
if isinstance(data_preprocessor, dict):
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
data_preprocessor = MODELS.build(data_preprocessor)
super().__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
self.vision_backbone = MODELS.build(vision_backbone)
self.text_backbone = MODELS.build(text_backbone)
if multimodal_backbone is not None:
self.multimodal_backbone = MODELS.build(multimodal_backbone)
if vision_neck is not None:
self.vision_neck = MODELS.build(vision_neck)
if text_neck is not None:
self.text_neck = MODELS.build(text_neck)
if head is not None:
self.head = MODELS.build(head)
if multimodal_head is not None:
self.multimodal_head = MODELS.build(multimodal_head)
if tokenizer is not None:
self.tokenizer = TOKENIZER.build(tokenizer)
self.momentum = momentum
self.negative_all_rank = negative_all_rank
self.temp = nn.Parameter(temperature * torch.ones([]))
# Shares the same para
self.head.temp = self.temp
# create the momentum encoder
self.vision_backbone_m = deepcopy(self.vision_backbone)
self.text_backbone_m = deepcopy(self.text_backbone)
self.vision_neck_m = deepcopy(self.vision_neck)
self.text_neck_m = deepcopy(self.text_neck)
self.model_pairs = [
[self.vision_backbone, self.vision_backbone_m],
[self.text_backbone, self.text_backbone_m],
[self.vision_neck, self.vision_neck_m],
[self.text_neck, self.text_neck_m],
]
self.copy_params()
# multimodal backone shares weights with text backbone in BLIP
# No need to set up
# Notice that this topk is used for select k candidate to compute
# image-text score, but not the final metric topk in evaluation.
self.fast_match = fast_match
self.topk = topk
self.max_txt_len = max_txt_len
@property
def device(self):
return next(self.parameters()).device
def preprocess_text(self, data_samples):
sample_item = data_samples[0]
if sample_item is not None and 'text' in sample_item:
if isinstance(sample_item.get('text'), (list, tuple)):
texts = []
for sample in data_samples:
texts.extend(sample.get('text'))
elif isinstance(sample_item.get('text'), str):
texts = [sample.get('text') for sample in data_samples]
else:
raise TypeError('text must be a string or a list of strings')
else:
return None
# perform tokenize first if satisfied conditions
texts = self.tokenizer(
texts,
padding='max_length',
truncation=True,
max_length=self.max_txt_len,
return_tensors='pt',
).to(self.device)
return texts
def forward(self,
images: torch.tensor = None,
data_samples: Optional[List[DataSample]] = None,
mode: str = 'tensor') -> Union[Tuple, dict]:
"""The unified entry for a forward process in both training and test.
The method should accept two modes: "tensor", and "loss":
- "tensor": Forward the whole network and return tensor without any
post-processing, same as a common nn.Module.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
Note that this method doesn't handle neither back propagation nor
optimizer updating, which are done in the :meth:`train_step`.
For unified "predict" mode in other mm repos. It is noticed that
image-text retrieval cannot perform batch prediction since it will go
through all the samples. A standard process of retrieval evaluation is
to extract and collect all feats, and then predict all samples.
Therefore the `predict` mode here is remained as a trigger
to inform use to choose the right configurations.
Args:
images (torch.Tensor): The input inputs tensor of shape
(N, C, ...) in general.
data_samples (List[DataSample], optional): The annotation
data of every samples. It's required if ``mode="loss"``.
Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
Returns:
The return type depends on ``mode``.
- If ``mode="tensor"``, return a tuple.
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'tensor':
return self.extract_feat(images, data_samples)
elif mode == 'loss':
return self.loss(images, data_samples)
elif mode == 'predict':
return self.predict(images, data_samples)
else:
raise RuntimeError(f'Invalid mode "{mode}".')
def extract_feat(
self,
images: torch.Tensor = None,
data_samples: List[DataSample] = None,
return_texts=True,
return_embeds=None,
) -> Dict[str, torch.Tensor]:
"""Extract features from the input dict.
Args:
images (tensor, optional): The images to extract features.
Defaults to None.
data_samples (list, optional): The data samples containing texts
to extract features. Defaults to None.
return_texts (bool): Whether to return the tokenized text and the
corresponding attention masks. Defaults to True.
return_embeds (bool): Whether to return the text embedding and
image embedding. Defaults to None, which means to use
``self.fast_match``.
Returns:
Tuple[torch.Tensor]: The output features.
If multimodal_backbone is not exist, tuple of torch.Tensor
will be returned.
"""
if data_samples is not None:
texts = self.preprocess_text(data_samples)
else:
texts = None
assert images is not None or texts is not None, \
'At least single modality should be passed as inputs.'
results = {}
if texts is not None and return_texts:
results.update({
'text_ids': texts.input_ids,
'text_attn_mask': texts.attention_mask,
})
if return_embeds is None:
return_embeds = not self.fast_match
# extract image features
if images is not None:
output = self._extract_feat(images, modality='images')
results['image_feat'] = output['image_feat']
if return_embeds:
results['image_embeds'] = output['image_embeds']
# extract text features
if texts is not None:
output = self._extract_feat(texts, modality='texts')
results['text_feat'] = output['text_feat']
if return_embeds:
results['text_embeds'] = output['text_embeds']
return results
def _extract_feat(self, inputs: Union[torch.Tensor, dict],
modality: str) -> Tuple[torch.Tensor]:
"""Extract features from the single modality.
Args:
inputs (Union[torch.Tensor, dict]): A batch of inputs.
For image, a tensor of shape (N, C, ...) in general.
For text, a dict of tokenized text inputs.
modality (str): Modality feature to be extracted. Only two
options are supported.
- ``images``: Only extract image features, mostly used for
inference.
- ``texts``: Only extract text features, mostly used for
inference.
Returns:
Tuple[torch.Tensor]: The output features.
"""
if modality == 'images':
# extract image features
image_embeds = self.vision_backbone(inputs)[0]
image_feat = F.normalize(
self.vision_neck(image_embeds[:, 0, :]), dim=-1)
return {'image_embeds': image_embeds, 'image_feat': image_feat}
elif modality == 'texts':
# extract text features
text_output = self.text_backbone(
inputs.input_ids,
attention_mask=inputs.attention_mask,
token_type_ids=None,
return_dict=True,
mode='text',
)
text_embeds = text_output.last_hidden_state
text_feat = F.normalize(
self.text_neck(text_embeds[:, 0, :]), dim=-1)
return {'text_embeds': text_embeds, 'text_feat': text_feat}
else:
raise RuntimeError(f'Invalid modality "{modality}".')
def loss(
self,
images: torch.Tensor,
data_samples: Optional[List[DataSample]] = None,
) -> Dict[str, torch.tensor]:
"""Calculate losses from a batch of inputs and data samples.
Args:
inputs (dict): A batch of inputs. The input tensor with of
at least one modality. For image, the value is a tensor
of shape (N, C, ...) in general.
For text, the value is a dict of tokenized text inputs.
data_samples (Optional[List[DataSample]]):
The annotation data of every samples. Defaults to None.
Returns:
Dict[str, torch.tensor]: a dictionary of loss components of
both head and multimodal head.
"""
output = self.extract_feat(images, data_samples, return_embeds=True)
text_ids = output['text_ids']
text_attn_mask = output['text_attn_mask']
image_embeds = output['image_embeds']
image_feat = output['image_feat']
text_feat = output['text_feat']
image_atts = torch.ones(
image_embeds.size()[:-1], dtype=torch.long).to(self.device)
# get momentum features
with torch.no_grad():
self._momentum_update()
image_embeds_m = self.vision_backbone_m(images)[0]
image_feat_m = F.normalize(
self.vision_neck_m(image_embeds_m[:, 0, :]), dim=-1)
text_output_m = self.text_backbone_m(
text_ids,
attention_mask=text_attn_mask,
token_type_ids=None,
return_dict=True,
mode='text',
)
text_embeds_m = text_output_m.last_hidden_state
text_feat_m = F.normalize(
self.text_neck_m(text_embeds_m[:, 0, :]), dim=-1)
loss = self.head.loss(
([image_feat, text_feat, image_feat_m, text_feat_m], ),
data_samples)
# prepare for itm
encoder_input_ids = text_ids.clone()
encoder_input_ids[:,
0] = self.tokenizer.additional_special_tokens_ids[0]
output_pos = self.text_backbone(
encoder_input_ids,
attention_mask=text_attn_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
idx = torch.tensor([i.image_id for i in data_samples]).view(-1, 1)
bs = idx.size(0)
idxs = torch.cat(dist.all_gather(idx))
if self.negative_all_rank:
# compute sample similarity
with torch.no_grad():
mask = torch.eq(idx, idxs.t()).to(self.device)
image_feat_world = torch.cat(dist.all_gather(image_feat))
text_feat_world = torch.cat(dist.all_gather(text_feat))
sim_i2t = image_feat @ text_feat_world.t() / self.temp
sim_t2i = text_feat @ image_feat_world.t() / self.temp
weights_i2t = F.softmax(sim_i2t, dim=1)
weights_i2t.masked_fill_(mask, 0)
weights_t2i = F.softmax(sim_t2i, dim=1)
weights_t2i.masked_fill_(mask, 0)
world_size = dist.get_world_size()
if world_size == 1:
image_embeds_world = image_embeds
else:
image_embeds_world = torch.cat(
torch_dist.nn.all_gather(image_embeds))
# select a negative image (from all ranks) for each text
image_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
image_embeds_neg.append(image_embeds_world[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
# select a negative text (from all ranks) for each image
input_ids_world = torch.cat(dist.all_gather(encoder_input_ids))
att_mask_world = torch.cat(dist.all_gather(text_attn_mask))
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
text_ids_neg.append(input_ids_world[neg_idx])
text_atts_neg.append(att_mask_world[neg_idx])
text_ids_neg = torch.stack(text_ids_neg, dim=0)
text_atts_neg = torch.stack(text_atts_neg, dim=0)
text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0)
text_atts_all = torch.cat([text_attn_mask, text_atts_neg], dim=0)
image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
image_atts_all = torch.cat([image_atts, image_atts], dim=0)
output_neg = self.text_backbone(
text_ids_all,
attention_mask=text_atts_all,
encoder_hidden_states=image_embeds_all,
encoder_attention_mask=image_atts_all,
return_dict=True,
)
vl_embeddings = torch.cat(
[
output_pos.last_hidden_state[:, 0, :],
output_neg.last_hidden_state[:, 0, :],
],
dim=0,
)
# create false data samples
data_samples.extend(
[DataSample(is_matched=False) for _ in range(2 * bs)])
loss_multimodal = self.multimodal_head.loss((vl_embeddings, ),
data_samples)
return dict(ChainMap(loss, loss_multimodal))
def predict(self, images, data_samples, cal_i2t=True, cal_t2i=True):
feats = self.extract_feat(images, data_samples)
return self.predict_all(
feats, data_samples, cal_i2t=cal_i2t, cal_t2i=cal_t2i)
def predict_all(self,
feats,
data_samples,
num_images=None,
num_texts=None,
cal_i2t=True,
cal_t2i=True):
text_ids = feats['text_ids']
text_ids[:, 0] = self.tokenizer.additional_special_tokens_ids[0]
text_attn_mask = feats['text_attn_mask']
image_embeds = feats.get('image_embeds', None)
image_feat = feats['image_feat']
text_feat = feats['text_feat']
num_images = num_images or image_feat.size(0)
num_texts = num_texts or text_feat.size(0)
if not self.fast_match:
image_embeds_all = all_gather_concat(image_embeds)[:num_images]
else:
image_embeds_all = None
image_feat_all = all_gather_concat(image_feat)[:num_images]
text_feat_all = all_gather_concat(text_feat)[:num_texts]
text_ids_all = all_gather_concat(text_ids)[:num_texts]
text_attn_mask_all = all_gather_concat(text_attn_mask)[:num_texts]
results = []
if cal_i2t:
result_i2t = self.compute_score_matrix_i2t(
image_feat,
image_embeds,
text_feat_all,
text_ids_all,
text_attn_mask_all,
)
results.append(
self._get_predictions(result_i2t, data_samples, mode='i2t'))
if cal_t2i:
result_t2i = self.compute_score_matrix_t2i(
image_feat_all,
image_embeds_all,
text_feat,
text_ids,
text_attn_mask,
)
results.append(
self._get_predictions(result_t2i, data_samples, mode='t2i'))
return tuple(results)
def compute_score_matrix_i2t(self, img_feats, img_embeds, text_feats,
text_ids, text_atts):
"""Compare the score matrix for image-to-text retrieval. Every image
should compare to all the text features.
Args:
img_feats (torch.Tensor): The input img feats tensor with shape
(M, C). M stands for numbers of samples on a single GPU.
img_embeds (torch.Tensor): The input img embeds tensor with shape
(M, C). M stands for numbers of samples on a single GPU.
text_feats (torch.Tensor): The input text feats tensor with shape
(N, C). N stands for numbers of all samples on all GPUs.
text_ids (torch.Tensor): The input tensor with shape (N, C).
text_atts (torch.Tensor): The input tensor with shape (N, C).
Returns:
torch.Tensor: Score matrix of image-to-text retrieval.
"""
# compute i2t sim matrix
sim_matrix_i2t = img_feats @ text_feats.t()
if self.fast_match:
return sim_matrix_i2t
score_matrix_i2t = torch.full((img_feats.size(0), text_feats.size(0)),
-100.0).to(self.device)
for i in track_on_main_process(
range(img_feats.size(0)), 'Compute I2T scores...'):
sims = sim_matrix_i2t[i]
topk_sim, topk_idx = sims.topk(k=self.topk, dim=0)
encoder_output = img_embeds[i].repeat(self.topk, 1, 1)
encoder_att = torch.ones(
encoder_output.size()[:-1], dtype=torch.long).to(self.device)
output = self.text_backbone(
text_ids[topk_idx],
attention_mask=text_atts[topk_idx],
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_att,
return_dict=True,
)
score = self.multimodal_head(
(output.last_hidden_state[:, 0, :], ))[:, 1]
score_matrix_i2t[i, topk_idx] = score + topk_sim
return score_matrix_i2t
def compute_score_matrix_t2i(self, img_feats, img_embeds, text_feats,
text_ids, text_atts):
"""Compare the score matrix for text-to-image retrieval. Every text
should compare to all the image features.
Args:
img_feats (torch.Tensor): The input img feats tensor with shape
(M, C). M stands for numbers of samples on a single GPU.
img_embeds (torch.Tensor): The input img embeds tensor with shape
(M, C). M stands for numbers of samples on a single GPU.
text_feats (torch.Tensor): The input text feats tensor with shape
(N, C). N stands for numbers of all samples on all GPUs.
text_ids (torch.Tensor): The input tensor with shape (M, C).
text_atts (torch.Tensor): The input tensor with shape (M, C).
Returns:
torch.Tensor: Score matrix of text-to-image retrieval.
"""
# compute t2i sim matrix
sim_matrix_t2i = text_feats @ img_feats.t()
if self.fast_match:
return sim_matrix_t2i
score_matrix_t2i = torch.full((text_feats.size(0), img_feats.size(0)),
-100.0).to(self.device)
for i in track_on_main_process(
range(text_feats.size(0)), 'Compute T2I scores...'):
sims = sim_matrix_t2i[i]
topk_sim, topk_idx = sims.topk(k=self.topk, dim=0)
encoder_output = img_embeds[topk_idx]
encoder_att = torch.ones(
encoder_output.size()[:-1], dtype=torch.long).to(self.device)
output = self.text_backbone(
text_ids[i].repeat(self.topk, 1),
attention_mask=text_atts[i].repeat(self.topk, 1),
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_att,
return_dict=True,
)
score = self.multimodal_head(
(output.last_hidden_state[:, 0, :], ))[:, 1]
score_matrix_t2i[i, topk_idx] = score + topk_sim
return score_matrix_t2i
def _get_predictions(self,
result: torch.Tensor,
data_samples: List[DataSample],
mode: str = 'i2t'):
"""Post-process the output of retriever.
Args:
result (torch.Tensor): Score matrix of single retrieve,
either from image or text.
data_samples (List[DataSample], optional): The annotation
data of every samples.
mode (str): Retrieve mode, either `i2t` for image to text, or `t2i`
text to image. Defaults to `i2t`.
Returns:
List[DataSample]: the raw data_samples with
the predicted results.
"""
# create data sample if not exists
if data_samples is None:
data_samples = [DataSample() for _ in range(result.size(0))]
elif mode == 't2i':
# Process data samples to align with the num of texts.
new_data_samples = []
for sample in data_samples:
if isinstance(sample.text, (list, tuple)):
texts = sample.text
else:
texts = [sample.text]
for i, text in enumerate(texts):
new_sample = DataSample(text=text)
if 'gt_image_id' in sample:
new_sample.gt_label = sample.gt_image_id[i]
new_data_samples.append(new_sample)
assert len(new_data_samples) == result.size(0)
data_samples = new_data_samples
elif mode == 'i2t':
for sample in data_samples:
if 'gt_text_id' in sample:
sample.gt_label = sample.gt_text_id
else:
raise ValueError(f'Type {mode} is not supported.')
for data_sample, score in zip(data_samples, result):
idx = score.argmax(keepdim=True).detach()
data_sample.set_pred_score(score)
data_sample.set_pred_label(idx)
return data_samples
# TODO: add temperaily
@torch.no_grad()
def copy_params(self):
for model_pair in self.model_pairs:
for param, param_m in zip(model_pair[0].parameters(),
model_pair[1].parameters()):
param_m.data.copy_(param.data) # initialize
param_m.requires_grad = False # not update by gradient
@torch.no_grad()
def _momentum_update(self):
for model_pair in self.model_pairs:
for (name,
param), (name_m,
param_m) in zip(model_pair[0].named_parameters(),
model_pair[1].named_parameters()):
# hack to behave the same
if any([i in name for i in ['8', '9', '10', '11']
]) and 'layers' in name and any(
[i in name for i in ['attn', 'ffn']]):
param_m.data = param.data
else:
param_m.data = param_m.data * self.momentum + \
param.data * (1.0 - self.momentum)