mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Support dynamic_resize in eva.py models

This commit is contained in:
parent ea3519a5f0
commit 1f4512fca3
@@ -400,13 +400,12 @@ class RotaryEmbeddingCat(nn.Module):
                     temperature=temperature,
                     step=1,
                 )
-            print(bands)
             self.register_buffer(
                 'bands',
                 bands,
                 persistent=False,
             )
-            self.embed = None
+            self.pos_embed = None
         else:
             # cache full sin/cos embeddings if shape provided up front
             embeds = build_rotary_pos_embed(
@@ -425,17 +424,19 @@ class RotaryEmbeddingCat(nn.Module):
             )

     def get_embed(self, shape: Optional[List[int]] = None):
-        if self.bands is not None:
+        if self.bands is not None and shape is not None:
             # rebuild embeddings every call, use if target shape changes
             _assert(shape is not None, 'valid shape needed')
             embeds = build_rotary_pos_embed(
                 shape,
                 self.bands,
                 in_pixels=self.in_pixels,
                 ref_feat_shape=self.ref_feat_shape,
             )
             return torch.cat(embeds, -1)
-        else:
+        elif self.pos_embed is not None:
             return self.pos_embed
+        else:
+            assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"

     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2

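With feat_shape left as None, the rotary layer now caches only the frequency bands and defers building the sin/cos table until get_embed() is called with a target grid, which is what lets Eva pass a per-input feature shape. A minimal usage sketch of both modes (not part of the commit; head_dim and the grid sizes below are made up for illustration):

```python
# Minimal sketch of the two get_embed() modes above (illustrative dims only).
import torch
from timm.layers import RotaryEmbeddingCat

head_dim = 64

# Fixed grid: the sin/cos table is precomputed at construction time.
rope_static = RotaryEmbeddingCat(head_dim, in_pixels=False, feat_shape=(14, 14))
print(rope_static.get_embed().shape)  # cached table for the 14x14 grid

# Deferred mode: only the frequency bands are cached, so the caller must pass
# the target grid shape on every call (as Eva now does per forward pass).
rope_dynamic = RotaryEmbeddingCat(head_dim, in_pixels=False, feat_shape=None)
print(rope_dynamic.get_embed(shape=(16, 12)).shape)  # built on the fly for 16x12
```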
@@ -367,6 +367,7 @@ class Eva(nn.Module):
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
             use_post_norm: bool = False,
+            dynamic_size: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
     ):
@@ -406,13 +407,19 @@ class Eva(nn.Module):
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False

+        embed_args = {}
+        if dynamic_size:
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            **embed_args,
         )
         num_patches = self.patch_embed.num_patches

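For context on the embed_args above: strict_img_size=False stops the patch embed from asserting the configured img_size, and output_fmt='NHWC' keeps the patch grid unflattened so the new _pos_embed() (further down) can read H and W back off the tensor. A rough standalone sketch, assuming the timm.layers.PatchEmbed keyword arguments used in the diff; the sizes are arbitrary examples:

```python
# Rough sketch of the PatchEmbed behaviour the commit relies on
# (assumed timm.layers API; sizes are arbitrary examples).
import torch
from timm.layers import PatchEmbed

patch_embed = PatchEmbed(
    img_size=224,
    patch_size=16,
    embed_dim=768,
    strict_img_size=False,   # don't assert the input matches img_size
    output_fmt='NHWC',       # keep the (H, W) grid, defer flattening
)

print(patch_embed(torch.randn(1, 3, 224, 224)).shape)  # (1, 14, 14, 768)
print(patch_embed(torch.randn(1, 3, 256, 192)).shape)  # (1, 16, 12, 768)
```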
@@ -435,7 +442,7 @@ class Eva(nn.Module):
             self.rope = RotaryEmbeddingCat(
                 embed_dim // num_heads,
                 in_pixels=False,
-                feat_shape=self.patch_embed.grid_size,
+                feat_shape=None if dynamic_size else self.patch_embed.grid_size,
                 ref_feat_shape=ref_feat_shape,
             )
         else:
@@ -519,30 +526,44 @@ class Eva(nn.Module):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

-    def forward_features(self, x):
-        x = self.patch_embed(x)
+    def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            if self.pos_embed is not None:
+                pos_embed = resample_abs_pos_embed(
+                    self.pos_embed,
+                    (H, W),
+                    num_prefix_tokens=self.num_prefix_tokens,
+                )
+            else:
+                pos_embed = None
+            x = x.view(B, -1, C)
+            rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
+        else:
+            pos_embed = self.pos_embed
+            rot_pos_embed = self.rope.get_embed() if self.rope is not None else None

         if self.cls_token is not None:
             x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

         # apply abs position embedding
-        if self.pos_embed is not None:
-            x = x + self.pos_embed
+        if pos_embed is not None:
+            x = x + pos_embed
         x = self.pos_drop(x)

         # obtain shared rotary position embedding and apply patch dropout
-        rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
         if self.patch_drop is not None:
             x, keep_indices = self.patch_drop(x)
             if rot_pos_embed is not None and keep_indices is not None:
                 rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
+        return x, rot_pos_embed
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        x, rot_pos_embed = self._pos_embed(x)
         for blk in self.blocks:
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 x = checkpoint(blk, x, rope=rot_pos_embed)
             else:
                 x = blk(x, rope=rot_pos_embed)

         x = self.norm(x)
         return x

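Taken together, a model built with dynamic_size=True can accept inputs whose resolution differs from img_size: the patch embed keeps the (H, W) grid, the absolute pos_embed is resampled to that grid, and the rope table is rebuilt for it. A hedged end-to-end sketch follows; the tiny config is made up for illustration, and it assumes a timm checkout that includes this commit:

```python
# End-to-end sketch of the new dynamic_size path (illustrative config only;
# assumes a timm checkout that includes this commit).
import torch
from timm.models.eva import Eva

model = Eva(
    img_size=224,
    patch_size=16,
    embed_dim=192,
    depth=2,
    num_heads=4,
    use_rot_pos_emb=True,
    ref_feat_shape=(14, 14),   # 224 // 16
    dynamic_size=True,         # the flag added by this commit
)
model.eval()

with torch.no_grad():
    # Native resolution: 14x14 patch grid, pos_embed used as-is.
    out_224 = model(torch.randn(1, 3, 224, 224))
    # Different resolution: 16x12 grid; pos_embed is resampled and the
    # rope embedding is rebuilt via rope.get_embed(shape=(16, 12)).
    out_mixed = model(torch.randn(1, 3, 256, 192))

print(out_224.shape, out_mixed.shape)  # both (1, 1000) with the default head
```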