Support dynamic_resize in eva.py models

Authored by Ross Wightman on 2023-08-26 22:32:39 -07:00; committed by Ross Wightman
parent ea3519a5f0
commit 1f4512fca3
2 changed files with 36 additions and 14 deletions


@@ -400,13 +400,12 @@ class RotaryEmbeddingCat(nn.Module):
                 temperature=temperature,
                 step=1,
             )
-            print(bands)
             self.register_buffer(
                 'bands',
                 bands,
                 persistent=False,
             )
-            self.embed = None
+            self.pos_embed = None
         else:
             # cache full sin/cos embeddings if shape provided up front
             embeds = build_rotary_pos_embed(
@@ -425,17 +424,19 @@
             )
 
     def get_embed(self, shape: Optional[List[int]] = None):
-        if self.bands is not None:
+        if self.bands is not None and shape is not None:
             # rebuild embeddings every call, use if target shape changes
-            _assert(shape is not None, 'valid shape needed')
             embeds = build_rotary_pos_embed(
                 shape,
                 self.bands,
                 in_pixels=self.in_pixels,
+                ref_feat_shape=self.ref_feat_shape,
             )
             return torch.cat(embeds, -1)
-        else:
+        elif self.pos_embed is not None:
             return self.pos_embed
+        else:
+            assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"
 
     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2
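
For context on how the revised get_embed() is meant to be used, a minimal sketch follows. It is not part of the commit; the import path and the dim/grid values are assumptions chosen for illustration, while the constructor arguments mirror how Eva instantiates RotaryEmbeddingCat in the diff below.

import torch
from timm.layers import RotaryEmbeddingCat  # import path is an assumption

# feat_shape=None: only the frequency bands are registered, so sin/cos embeddings
# are rebuilt on every get_embed(shape=...) call and the grid may change per input.
rope_dynamic = RotaryEmbeddingCat(64, in_pixels=False, feat_shape=None, ref_feat_shape=(16, 16))
emb_a = rope_dynamic.get_embed(shape=(14, 14))  # embeddings for a 14x14 feature grid
emb_b = rope_dynamic.get_embed(shape=(20, 20))  # a different grid, no re-init needed

# feat_shape given: the full sin/cos table is cached as pos_embed and returned as-is;
# calling get_embed() with neither a cached table nor a valid shape now asserts.
rope_static = RotaryEmbeddingCat(64, in_pixels=False, feat_shape=(14, 14))
emb_fixed = rope_static.get_embed()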


@@ -367,6 +367,7 @@ class Eva(nn.Module):
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
             use_post_norm: bool = False,
+            dynamic_size: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
     ):
@@ -406,13 +407,19 @@
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False
 
+        embed_args = {}
+        if dynamic_size:
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            **embed_args,
         )
 
         num_patches = self.patch_embed.num_patches
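
The embed_args above defer flattening so position embeddings can be resampled on the spatial grid. Below is a small illustration of what the patch embed produces under those settings; it assumes timm.layers.PatchEmbed accepts the strict_img_size and output_fmt arguments exactly as passed in the diff, and the input sizes are arbitrary.

import torch
from timm.layers import PatchEmbed  # import path is an assumption

# strict_img_size=False lets the input deviate from img_size; output_fmt='NHWC' keeps the
# spatial grid so Eva._pos_embed() can resample position embeddings before flattening.
patch_embed = PatchEmbed(
    img_size=224, patch_size=16, in_chans=3, embed_dim=384,
    strict_img_size=False, output_fmt='NHWC',
)
feat = patch_embed(torch.randn(2, 3, 256, 320))
print(feat.shape)  # expected: (2, 16, 20, 384), a B/H/W/C grid later flattened to (B, H*W, C)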
@@ -435,7 +442,7 @@
             self.rope = RotaryEmbeddingCat(
                 embed_dim // num_heads,
                 in_pixels=False,
-                feat_shape=self.patch_embed.grid_size,
+                feat_shape=None if dynamic_size else self.patch_embed.grid_size,
                 ref_feat_shape=ref_feat_shape,
             )
         else:
@@ -519,30 +526,44 @@
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
-    def forward_features(self, x):
-        x = self.patch_embed(x)
+    def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            if self.pos_embed is not None:
+                pos_embed = resample_abs_pos_embed(
+                    self.pos_embed,
+                    (H, W),
+                    num_prefix_tokens=self.num_prefix_tokens,
+                )
+            else:
+                pos_embed = None
+            x = x.view(B, -1, C)
+            rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
+        else:
+            pos_embed = self.pos_embed
+            rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
 
         if self.cls_token is not None:
             x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
 
-        # apply abs position embedding
-        if self.pos_embed is not None:
-            x = x + self.pos_embed
+        if pos_embed is not None:
+            x = x + pos_embed
         x = self.pos_drop(x)
 
         # obtain shared rotary position embedding and apply patch dropout
-        rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
         if self.patch_drop is not None:
             x, keep_indices = self.patch_drop(x)
             if rot_pos_embed is not None and keep_indices is not None:
                 rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
 
+        return x, rot_pos_embed
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        x, rot_pos_embed = self._pos_embed(x)
         for blk in self.blocks:
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 x = checkpoint(blk, x, rope=rot_pos_embed)
             else:
                 x = blk(x, rope=rot_pos_embed)
         x = self.norm(x)
         return x
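
Putting the pieces together, a hedged end-to-end sketch (not part of the commit): build an Eva model with dynamic_size=True and run inputs at two resolutions. The import path and the small depth/width values are assumptions for illustration; remaining constructor arguments keep their defaults.

import torch
from timm.models.eva import Eva  # import path is an assumption

model = Eva(
    img_size=224,
    patch_size=16,
    embed_dim=384,            # illustrative small config, not a pretrained variant
    depth=4,
    num_heads=6,
    use_rot_pos_emb=True,
    ref_feat_shape=(16, 16),
    dynamic_size=True,        # NHWC patch embed + per-input pos embed resampling (this commit)
).eval()

with torch.no_grad():
    out_224 = model(torch.randn(1, 3, 224, 224))
    out_288 = model(torch.randn(1, 3, 288, 288))  # different resolution, same model instance
print(out_224.shape, out_288.shape)  # both (1, num_classes)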