YifanXu74-patch-1
YifanXu74 2023-10-07 23:23:13 +08:00
parent 20ec7e7687
commit 616b5014b8
1 changed file with 6 additions and 6 deletions

@@ -136,10 +136,10 @@ class MaskedCrossAttention(nn.Module):
         norm_kv = False,
         share_kv=False,
         cfg=None,
-        spare_forward=False,
+        spase_forward=False,
     ):
         super().__init__()
-        self.spare_forward=spare_forward
+        self.spase_forward=spase_forward
         self.scale = dim_head ** -0.5
         self.heads = heads
         self.share_kv=share_kv
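Note: `scale` and `heads` set up standard scaled dot-product attention; the attention body itself is outside this diff. A minimal sketch of how these two fields are typically consumed, with hypothetical shapes:

import torch
from einops import rearrange
from torch import einsum

# Minimal sketch (hypothetical shapes): split heads, then scale the dot products.
heads, dim_head = 8, 64
q = torch.randn(2, 10, heads * dim_head)   # (batch, text, inner_dim)
k = torch.randn(2, 20, heads * dim_head)   # (batch, vision, inner_dim)
q, k = (rearrange(t, 'b n (h d) -> b h n d', h=heads) for t in (q, k))
sim = einsum('b h i d, b h j d -> b h i j', q, k) * dim_head ** -0.5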
@@ -189,7 +189,7 @@ class MaskedCrossAttention(nn.Module):
         vision, # (batch, vision, dim)
         attention_mask = None, # (batch, vision, text)
     ):
-        if self.spare_forward:
+        if self.spase_forward:
             batch_size = x.shape[0]
             x, vision, attention_mask = self._construct_sparse_inputs(x, vision, attention_mask)
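`_construct_sparse_inputs` itself is outside this diff. Judging from the inverse `rearrange` in the next hunk (and its `out.shape[1]==1` assert), it plausibly folds the text dimension into the batch so each text token attends as a single query; a hypothetical reconstruction:

import torch
from einops import rearrange, repeat

def construct_sparse_inputs(x, vision, attention_mask):
    # Hypothetical reconstruction: fold the text dimension into the batch so
    # each text token becomes an independent single-token query.
    t = x.shape[1]                                      # x: (batch, text, dim)
    x = rearrange(x, 'b t d -> (b t) 1 d')              # one query per element
    vision = repeat(vision, 'b v d -> (b t) v d', t=t)  # tile vision per token
    if attention_mask is not None:                      # (batch, vision, text)
        attention_mask = rearrange(attention_mask, 'b v t -> (b t) v 1')
    return x, vision, attention_mask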
@@ -233,7 +233,7 @@ class MaskedCrossAttention(nn.Module):
         out = einsum('... i j, ... j d -> ... i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
-        if self.spare_forward:
+        if self.spase_forward:
             assert out.shape[1]==1
             out = rearrange(out, '(b t) n d -> b t (n d)', b=batch_size)
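For reference, the `rearrange` above undoes that packing: each sparse element holds exactly one query token (hence the assert), so `(b t) 1 d` collapses back to `b t d`. A quick shape check with hypothetical sizes:

import torch
from einops import rearrange

batch_size, t, d = 2, 5, 64
out = torch.randn(batch_size * t, 1, d)   # sparse layout: (b*t, 1, d)
assert out.shape[1] == 1                  # mirrors the assert in the diff
out = rearrange(out, '(b t) n d -> b t (n d)', b=batch_size)
print(out.shape)                          # torch.Size([2, 5, 64])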
@@ -265,7 +265,7 @@ class GatedCrossAttentionBlock(nn.Module):
         enable_ffn = True
     ):
         super().__init__()
-        self.attn = MaskedCrossAttention(input_dim = dim, dim_head = dim_head, heads = heads, share_kv=share_kv, cfg=cfg, norm_kv=True, spare_forward=True)
+        self.attn = MaskedCrossAttention(input_dim = dim, dim_head = dim_head, heads = heads, share_kv=share_kv, cfg=cfg, norm_kv=True, spase_forward=True)
         if cfg.VISION_QUERY.FIX_ATTN_GATE == -1.0:
             if cfg.VISION_QUERY.CONDITION_GATE:
                 if cfg.VISION_QUERY.NONLINEAR_GATE:
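The `FIX_ATTN_GATE` / `CONDITION_GATE` / `NONLINEAR_GATE` branches are untouched by this commit. For context only: a Flamingo-style tanh-gated residual is one common shape for such a block; a minimal sketch under that assumption, not the repository's actual implementation:

import torch
from torch import nn

class GatedResidualAttn(nn.Module):
    # Hypothetical sketch: tanh-gated cross-attention residual, gate starts closed.
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.attn_gate = nn.Parameter(torch.zeros(1))

    def forward(self, x, vision, attention_mask=None):
        return x + self.attn(x, vision, attention_mask) * self.attn_gate.tanh()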
@@ -387,7 +387,7 @@ class PreSelectBlock(nn.Module):
         cfg=None,
     ):
         super().__init__()
-        self.image_condition = MaskedCrossAttention(input_dim = dim, output_dim = out_dim, dim_head = dim_head, heads = heads, norm_kv=True, share_kv=share_kv, cfg=cfg, spare_forward=False)
+        self.image_condition = MaskedCrossAttention(input_dim = dim, output_dim = out_dim, dim_head = dim_head, heads = heads, norm_kv=True, share_kv=share_kv, cfg=cfg, spase_forward=False)
         self.ff = FeedForward(out_dim, mult = ff_mult)
         if dim != out_dim:
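The hunk cuts off at the `dim != out_dim` check. A common pattern for that branch, shown here purely as a hypothetical sketch, is a linear projection so the feed-forward output and the residual agree in width:

from torch import nn

# Hypothetical sketch: project the input when dim != out_dim so shapes match.
dim, out_dim = 256, 512
proj = nn.Linear(dim, out_dim) if dim != out_dim else nn.Identity()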