mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
* [Feat] Migrate blip caption to mmpretrain. (#50) * Migrate blip caption to mmpretrain * minor fix * support train * [Feature] Support OFA caption task. (#51) * [Feature] Support OFA caption task. * Remove duplicated files. * [Feature] Support OFA vqa task. (#58) * [Feature] Support OFA vqa task. * Fix lint. * [Feat] Add BLIP retrieval to mmpretrain. (#55) * init * minor fix for train * fix according to comments * refactor * Update Blip retrieval. (#62) * [Feature] Support OFA visual grounding task. (#59) * [Feature] Support OFA visual grounding task. * minor add TODO --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feat] Add flamingos coco caption and vqa. (#60) * first init * init flamingo coco * add vqa * minor fix * remove unnecessary modules * Update config * Use `ApplyToList`. --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature]: BLIP2 coco retrieval (#53) * [Feature]: Add blip2 retriever * [Feature]: Add blip2 all modules * [Feature]: Refine model * [Feature]: x1 * [Feature]: Runnable coco ret * [Feature]: Runnable version * [Feature]: Fix lint * [Fix]: Fix lint * [Feature]: Use 364 img size * [Feature]: Refactor blip2 * [Fix]: Fix lint * refactor files * minor fix * minor fix --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * Remove * fix blip caption inputs (#68) * [Feat] Add BLIP NLVR support. 
(#67) * first init * init flamingo coco * add vqa * add nlvr * refactor nlvr * minor fix * minor fix * Update dataset --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature]: BLIP2 Caption (#70) * [Feature]: Add language model * [Feature]: blip2 caption forward * [Feature]: Reproduce the results * [Feature]: Refactor caption * refine config --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feat] Migrate BLIP VQA to mmpretrain (#69) * reformat * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * refactor code --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * Update RefCOCO dataset * [Fix] fix lint * [Feature] Implement inference APIs for multi-modal tasks. (#65) * [Feature] Implement inference APIs for multi-modal tasks. * [Project] Add gradio demo. * [Improve] Update requirements * Update flamingo * Update blip * Add NLVR inferencer * Update flamingo * Update hugging face model register * Update ofa vqa * Update BLIP-vqa (#71) * Update blip-vqa docstring (#72) * Refine flamingo docstring (#73) * [Feature]: BLIP2 VQA (#61) * [Feature]: VQA forward * [Feature]: Reproduce accuracy * [Fix]: Fix lint * [Fix]: Add blank line * minor fix --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feature]: BLIP2 docstring (#74) * [Feature]: Add caption docstring * [Feature]: Add docstring to blip2 vqa * [Feature]: Add docstring to retrieval * Update BLIP-2 metafile and README (#75) * [Feature]: Add readme and docstring * Update blip2 results --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature] BLIP Visual Grounding on MMPretrain Branch (#66) * blip grounding merge with mmpretrain * remove commit * blip grounding test and inference api * refcoco dataset * refcoco dataset refine config * rebasing * gitignore * rebasing * minor edit * minor edit * Update blip-vqa docstring (#72) * rebasing * Revert "minor edit" This reverts 
commit 639cec757c215e654625ed0979319e60f0be9044. * blip grounding final * precommit * refine config * refine config * Update blip visual grounding --------- Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com> Co-authored-by: mzr1996 <mzr1996@163.com> * Update visual grounding metric * Update OFA docstring, README and metafiles. (#76) * [Docs] Update installation docs and gradio demo docs. (#77) * Update OFA name * Update Visual Grounding Visualizer * Integrate accelerate support * Fix imports. * Fix timm backbone * Update imports * Update README * Update circle ci * Update flamingo config * Add gradio demo README * [Feature]: Add scienceqa (#1571) * [Feature]: Add scienceqa * [Feature]: Change param name * Update docs * Update video --------- Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com> Co-authored-by: yingfhu <yingfhu@gmail.com> Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com> Co-authored-by: Rongjie Li <limo97@163.com>
399 lines
14 KiB
Python
399 lines
14 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
"""Taken from https://github.com/lucidrains/flamingo-pytorch."""
|
|
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from einops import rearrange, repeat
|
|
from torch import einsum, nn
|
|
|
|
|
|
def FeedForward(dim, mult: int = 4):
    """Build a pre-norm feed-forward (MLP) block.

    The returned module applies, in order: LayerNorm over ``dim``, a
    bias-free linear expansion to ``int(dim * mult)`` features, GELU, and a
    bias-free linear projection back to ``dim`` features.

    Args:
        dim (int): Input (and output) feature dimensions.
        mult (int): Layer expansion multiplier. Defaults to 4.

    Returns:
        nn.Sequential: The feed-forward block.
    """
    hidden_dim = int(dim * mult)
    # Stages are instantiated in execution order, so weight initialisation
    # draws from the RNG in the same sequence as the layers are applied.
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*stages)
|
|
|
|
|
|
class PerceiverAttention(nn.Module):
    """Perceiver attention layer.

    Latent queries attend over the concatenation of the (normalised) media
    features and the (normalised) latents themselves.

    Args:
        dim (int): Input feature dimensions.
        dim_head (int): Feature dimensions of each attention head.
            Defaults to 64.
        heads (int): Number of heads. Defaults to 8.
    """

    def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = heads * dim_head

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, hidden_dim, bias=False)
        self.to_kv = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, latents: torch.Tensor):
        """Attend from the latents to the media features and the latents.

        Args:
            x (torch.Tensor): image features of shape (b, T, n1, D).
            latents (torch.Tensor): latent features of shape (b, T, n2, D).

        Returns:
            torch.Tensor: Attention output of shape (b, T, n2, D).
        """
        media = self.norm_media(x)
        latents = self.norm_latents(latents)
        num_heads = self.heads

        queries = self.to_q(latents)
        # Keys/values come from media and latents together, so every latent
        # can also attend to the other latents.
        context = torch.cat((media, latents), dim=-2)
        keys, values = self.to_kv(context).chunk(2, dim=-1)

        split_heads = 'b t n (h d) -> b h t n d'
        queries, keys, values = (
            rearrange(t, split_heads, h=num_heads)
            for t in (queries, keys, values))
        queries = queries * self.scale

        scores = einsum('... i d, ... j d -> ... i j', queries, keys)
        # Subtract the detached row maximum for numerical stability.
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        weights = scores.softmax(dim=-1)

        merged = einsum('... i j, ... j d -> ... i d', weights, values)
        merged = rearrange(merged, 'b h t n d -> b t n (h d)', h=num_heads)
        return self.to_out(merged)
|
|
|
|
|
|
class PerceiverResampler(nn.Module):
    """Perceiver resampler layers.

    Maps the visual tokens of each media item onto a fixed set of learned
    latent tokens through ``depth`` rounds of perceiver attention followed
    by a feed-forward block, each with a residual connection.

    Args:
        dim (int): Input dimensions.
        depth (int): Depth of resampler. Defaults to 6.
        dim_head (int): Feature dimensions of each attention head.
            Defaults to 64.
        heads (int): Number of heads. Defaults to 8.
        num_latents (int): Number of latents. Defaults to 64.
        max_num_media (int, optional): Max number of media. If given, a
            learned per-media time embedding is added. Defaults to None.
        max_num_frames (int, optional): Max number of frames. If given, a
            learned per-frame embedding is added. Defaults to None.
        ff_mult (int): Feed forward multiplier. Defaults to 4.
    """

    def __init__(
        self,
        *,
        dim: int,
        depth: int = 6,
        dim_head: int = 64,
        heads: int = 8,
        num_latents: int = 64,
        max_num_media: Optional[int] = None,
        max_num_frames: Optional[int] = None,
        ff_mult: int = 4,
    ):
        super().__init__()
        # Parameters are created in a fixed order (latents, frame, media)
        # so random initialisation and state_dict ordering stay stable.
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        frame_embs = None
        if max_num_frames is not None:
            frame_embs = nn.Parameter(torch.randn(max_num_frames, dim))
        self.frame_embs = frame_embs

        media_time_embs = None
        if max_num_media is not None:
            media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim))
        self.media_time_embs = media_time_embs

        # Each stage is an (attention, feed-forward) pair.
        self.layers = nn.ModuleList(
            nn.ModuleList([
                PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim, mult=ff_mult),
            ]) for _ in range(depth))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor):
        """Resample visual tokens into latent tokens.

        Args:
            x (torch.Tensor): image features of shape (b, T, F, v, D)

        Returns:
            torch.Tensor: shape (b, T, n, D) where n is ``num_latents``.
        """
        batch, num_media, num_frames, num_tokens = x.shape[:4]

        if self.frame_embs is not None:
            # Broadcast the per-frame embedding over batch, media and tokens.
            per_frame = repeat(
                self.frame_embs[:num_frames],
                'F d -> b T F v d',
                b=batch,
                T=num_media,
                v=num_tokens)
            x = x + per_frame

        # Merge the frame and spatial axes into one token axis per media.
        x = rearrange(x, 'b T F v d -> b T (F v) d')

        if self.media_time_embs is not None:
            x = x + self.media_time_embs[:num_media]

        latents = repeat(
            self.latents, 'n d -> b T n d', b=batch, T=num_media)
        for attn_block, ff_block in self.layers:
            latents = attn_block(x, latents) + latents
            latents = ff_block(latents) + latents
        return self.norm(latents)
|
|
|
|
|
|
class MaskedCrossAttention(nn.Module):
    """Masked cross attention layers.

    Text tokens (queries) attend to visual latent tokens (keys/values). When
    ``media_locations`` is given, each text token is restricted to the media
    items that precede it in the text sequence (see ``forward``).

    Args:
        dim (int): Input text feature dimensions.
        dim_visual (int): Input visual feature dimensions.
        dim_head (int): Feature dimensions of each attention head.
            Defaults to 64.
        heads (int): Number of heads. Defaults to 8.
        only_attend_immediate_media (bool): If True, each text token attends
            only to the most recent preceding media item; otherwise it
            attends to all preceding media items. Defaults to True.
    """

    def __init__(
        self,
        *,
        dim: int,
        dim_visual: int,
        dim_head: int = 64,
        heads: int = 8,
        only_attend_immediate_media: bool = True,
    ):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        # Queries come from text (dim), keys/values from vision (dim_visual).
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # whether for text to only attend to immediate preceding image
        # or all previous images
        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(self,
                x: torch.Tensor,
                media: torch.Tensor,
                media_locations: Optional[torch.Tensor] = None,
                attend_previous: bool = True) -> torch.Tensor:
        """Forward function for masked cross attention.

        Args:
            x (torch.Tensor): text features of shape (B, T_txt, D_txt).
            media (torch.Tensor): image features of shape
                (B, T_img, n, D_img) where n is the dim of the latents.
            media_locations (torch.Tensor, optional): boolean mask identifying
                the media tokens in x of shape (B, T_txt). If None, no
                masking is applied and every text token attends to all
                media tokens. Defaults to None.
            attend_previous (bool): If False, ignores immediately preceding
                image and starts attending when following image.
                Defaults to True.

        Returns:
            torch.Tensor: Attention output of shape (B, T_txt, D_txt).
        """
        _, T_img, n = media.shape[:3]
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        # Flatten all media items into one key/value sequence of length
        # T_img * n; media item j occupies positions [j * n, (j + 1) * n).
        media = rearrange(media, 'b t n d -> b (t n) d')

        k, v = self.to_kv(media).chunk(2, dim=-1)
        q = rearrange(q, 'b n (h d) -> b h n d', h=h)
        k = rearrange(k, 'b n (h d) -> b h n d', h=h)
        v = rearrange(v, 'b n (h d) -> b h n d', h=h)

        q = q * self.scale

        sim = einsum('... i d, ... j d -> ... i j', q, k)

        if media_locations is not None:
            # at each boolean of True, increment the time counter
            # (relative to media time): text_time[b, i] is the number of
            # media tokens at or before text position i, and media_time
            # numbers the media items 1..T_img.
            text_time = media_locations.cumsum(dim=-1)
            media_time = torch.arange(T_img, device=x.device) + 1

            if not attend_previous:
                # Shift non-media text tokens onto the *next* media item.
                # text_time is a fresh tensor from cumsum, so these in-place
                # writes do not modify media_locations. NOTE(review): `~`
                # is a logical not only for bool tensors — media_locations
                # is assumed to be bool.
                text_time[~media_locations] += 1
                # make sure max is still the number of images in the sequence
                # (tokens pointing past the last image get time 0, i.e. none)
                text_time[text_time > repeat(
                    torch.count_nonzero(media_locations, dim=1),
                    'b -> b i',
                    i=text_time.shape[1],
                )] = 0

            # text time must equal media time if only attending to most
            # immediate image otherwise, as long as text time is greater than
            # media time (if attending to all previous images / media)
            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge  # noqa

            # Mask of shape (B, 1, T_txt, T_img * n); each media time is
            # repeated n times to line up with the flattened keys.
            text_to_media_mask = mask_op(
                rearrange(text_time, 'b i -> b 1 i 1'),
                repeat(media_time, 'j -> 1 1 1 (j n)', n=n),
            )
            sim = sim.masked_fill(~text_to_media_mask,
                                  -torch.finfo(sim.dtype).max)

        # Subtract the detached row maximum for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # NOTE(review): rows with text_time == 0 are fully masked in both
        # modes, which makes their softmax uniform. They are zeroed below
        # only when only_attend_immediate_media is True; in the torch.ge
        # mode such text tokens keep uniform attention over every media
        # token — confirm this is intended.
        if media_locations is not None and self.only_attend_immediate_media:
            # any text without a preceding media needs to have
            # attention zeroed out
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(text_without_media_mask,
                                                'b i -> b 1 i 1')
            attn = attn.masked_fill(text_without_media_mask, 0.0)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
|
|
|
|
|
|
class GatedCrossAttentionBlock(nn.Module):
    """Gated cross attention layers.

    A masked cross-attention sub-layer followed by a feed-forward sub-layer.
    Each sub-layer output is scaled by ``tanh`` of a learnable scalar gate
    and added residually. Both gates start at zero, so a freshly built block
    passes the text features through unchanged.

    Args:
        dim (int): Input text feature dimensions.
        dim_visual (int): Input visual feature dimensions.
        dim_head (int): Feature dimensions of each attention head.
            Defaults to 64.
        heads (int): Number of heads. Defaults to 8.
        ff_mult (int): Feed forward multiplier. Defaults to 4.
        only_attend_immediate_media (bool): Whether each text token attends
            only to its most recent preceding media. Defaults to True.
    """

    def __init__(
        self,
        *,
        dim: int,
        dim_visual: int,
        dim_head: int = 64,
        heads: int = 8,
        ff_mult: int = 4,
        only_attend_immediate_media: bool = True,
    ):
        super().__init__()
        # Sub-modules/parameters are registered in the same order as before
        # (attn, attn_gate, ff, ff_gate) to keep state_dict layout stable.
        self.attn = MaskedCrossAttention(
            dim=dim,
            dim_visual=dim_visual,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
        )
        self.attn_gate = nn.Parameter(torch.zeros(1))

        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.zeros(1))

    def forward(self,
                x: torch.Tensor,
                media: torch.Tensor,
                media_locations: Optional[torch.Tensor] = None,
                attend_previous: bool = True):
        """Apply gated cross attention and a gated feed-forward to text.

        Args:
            x (torch.Tensor): text features of shape (B, T_txt, D_txt).
            media (torch.Tensor): image features of shape
                (B, T_img, n, D_img) where n is the dim of the latents.
            media_locations (torch.Tensor, optional): boolean mask identifying
                the media tokens in x of shape (B, T_txt). Defaults to None.
            attend_previous (bool): If false, ignores immediately preceding
                image and starts attending when following image.
                Defaults to True.

        Returns:
            torch.Tensor: Updated text features, same shape as ``x``.
        """
        cross = self.attn(
            x,
            media,
            media_locations=media_locations,
            attend_previous=attend_previous,
        )
        x = cross * self.attn_gate.tanh() + x

        mlp_out = self.ff(x)
        x = mlp_out * self.ff_gate.tanh() + x
        return x
|
|
|
|
|
|
class FlamingoLayer(nn.Module):
    """Flamingo layers.

    Wraps a language-model decoder layer. When a gated cross-attention layer
    is provided, it is applied to the text features first, then the decoder
    layer runs on the result. The visual features, media locations and the
    ``attend_previous`` flag are supplied out of band through the
    ``condition_*`` methods rather than as ``forward`` arguments.

    Args:
        gated_cross_attn_layer (nn.Module, optional): Gated cross attention
            layer. If None, only ``decoder_layer`` is run.
        decoder_layer (nn.Module): Decoder layer.
    """

    def __init__(self, gated_cross_attn_layer: Optional[nn.Module],
                 decoder_layer: nn.Module):
        super().__init__()
        self.gated_cross_attn_layer = gated_cross_attn_layer
        self.decoder_layer = decoder_layer
        self.vis_x = None
        self.media_locations = None
        # Initialise the flag so ``forward`` cannot hit an AttributeError
        # when ``condition_attend_previous`` was never called. True matches
        # the ``attend_previous=True`` default of GatedCrossAttentionBlock.
        self.attend_previous = True

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned on vision features."""
        return self.vis_x is not None

    def condition_vis_x(self, vis_x):
        """Set condition vision features."""
        self.vis_x = vis_x

    def condition_media_locations(self, media_locations):
        """Set condition media locations."""
        self.media_locations = media_locations

    def condition_attend_previous(self, attend_previous):
        """Set attend previous."""
        self.attend_previous = attend_previous

    def forward(
        self,
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **decoder_layer_kwargs,
    ):
        """Forward function.

        Args:
            lang_x (torch.Tensor): language inputs.
            attention_mask (torch.Tensor, optional): text attention mask.
                Defaults to None.
            **decoder_layer_kwargs: Other decoder layer keyword arguments.

        Returns:
            Whatever ``decoder_layer`` returns for the (possibly
            cross-attended) language inputs.

        Raises:
            ValueError: If a gated cross attention layer is present but
                ``vis_x`` or ``media_locations`` has not been conditioned.
        """
        if self.gated_cross_attn_layer is None:
            return self.decoder_layer(
                lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)

        if self.vis_x is None:
            raise ValueError('vis_x must be conditioned before forward pass')

        if self.media_locations is None:
            raise ValueError(
                'media_locations must be conditioned before forward pass')

        lang_x = self.gated_cross_attn_layer(
            lang_x,
            self.vis_x,
            media_locations=self.media_locations,
            attend_previous=self.attend_previous,
        )
        lang_x = self.decoder_layer(
            lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
        return lang_x
|