mirror of https://github.com/YifanXu74/MQ-Det.git
init
parent
20ec7e7687
commit
616b5014b8
|
@ -136,10 +136,10 @@ class MaskedCrossAttention(nn.Module):
|
|||
norm_kv = False,
|
||||
share_kv=False,
|
||||
cfg=None,
|
||||
spare_forward=False,
|
||||
spase_forward=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.spare_forward=spare_forward
|
||||
self.spase_forward=spase_forward
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.share_kv=share_kv
|
||||
|
@ -189,7 +189,7 @@ class MaskedCrossAttention(nn.Module):
|
|||
vision, # (batch, vision, dim)
|
||||
attention_mask = None, # (batch, vision, text)
|
||||
):
|
||||
if self.spare_forward:
|
||||
if self.spase_forward:
|
||||
batch_size = x.shape[0]
|
||||
x, vision, attention_mask = self._construct_sparse_inputs(x, vision, attention_mask)
|
||||
|
||||
|
@ -233,7 +233,7 @@ class MaskedCrossAttention(nn.Module):
|
|||
out = einsum('... i j, ... j d -> ... i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
if self.spare_forward:
|
||||
if self.spase_forward:
|
||||
assert out.shape[1]==1
|
||||
out = rearrange(out, '(b t) n d -> b t (n d)', b=batch_size)
|
||||
|
||||
|
@ -265,7 +265,7 @@ class GatedCrossAttentionBlock(nn.Module):
|
|||
enable_ffn = True
|
||||
):
|
||||
super().__init__()
|
||||
self.attn = MaskedCrossAttention(input_dim = dim, dim_head = dim_head, heads = heads, share_kv=share_kv, cfg=cfg, norm_kv=True, spare_forward=True)
|
||||
self.attn = MaskedCrossAttention(input_dim = dim, dim_head = dim_head, heads = heads, share_kv=share_kv, cfg=cfg, norm_kv=True, spase_forward=True)
|
||||
if cfg.VISION_QUERY.FIX_ATTN_GATE == -1.0:
|
||||
if cfg.VISION_QUERY.CONDITION_GATE:
|
||||
if cfg.VISION_QUERY.NONLINEAR_GATE:
|
||||
|
@ -387,7 +387,7 @@ class PreSelectBlock(nn.Module):
|
|||
cfg=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_condition = MaskedCrossAttention(input_dim = dim, output_dim = out_dim, dim_head = dim_head, heads = heads, norm_kv=True, share_kv=share_kv, cfg=cfg, spare_forward=False)
|
||||
self.image_condition = MaskedCrossAttention(input_dim = dim, output_dim = out_dim, dim_head = dim_head, heads = heads, norm_kv=True, share_kv=share_kv, cfg=cfg, spase_forward=False)
|
||||
self.ff = FeedForward(out_dim, mult = ff_mult)
|
||||
|
||||
if dim != out_dim:
|
||||
|
|
Loading…
Reference in New Issue