Support dynamic_resize in eva.py models

Repository: https://github.com/huggingface/pytorch-image-models.git (mirror)
Parent: ea3519a5f0
Commit: 1f4512fca3
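In short, the commit adds a dynamic_size flag to the Eva constructor: the patch embed keeps its NHWC grid with a relaxed input-size check, the absolute position embedding is resampled per input, and the rotary embedding is built per call instead of being cached for a fixed grid. A rough usage sketch follows; it is not confirmed by this diff: the model name 'eva02_base_patch14_224' and the pass-through of extra kwargs by timm.create_model to Eva.__init__ are assumptions.

import torch
import timm

# Sketch only: 'eva02_base_patch14_224' is assumed to be a registered EVA entrypoint,
# and create_model is assumed to forward dynamic_size to Eva.__init__.
model = timm.create_model('eva02_base_patch14_224', pretrained=False, dynamic_size=True)
model.eval()

# With dynamic_size=True the patch embed no longer enforces the configured 224x224,
# so other resolutions that are multiples of the patch size should be accepted.
for size in (224, 336, 448):
    x = torch.randn(1, 3, size, size)
    with torch.no_grad():
        out = model(x)
    print(size, tuple(out.shape))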
@@ -400,13 +400,12 @@ class RotaryEmbeddingCat(nn.Module):
                     temperature=temperature,
                     step=1,
                 )
-            print(bands)
             self.register_buffer(
                 'bands',
                 bands,
                 persistent=False,
             )
-            self.embed = None
+            self.pos_embed = None
         else:
             # cache full sin/cos embeddings if shape provided up front
             embeds = build_rotary_pos_embed(
@@ -425,17 +424,19 @@ class RotaryEmbeddingCat(nn.Module):
             )

     def get_embed(self, shape: Optional[List[int]] = None):
-        if self.bands is not None:
+        if self.bands is not None and shape is not None:
             # rebuild embeddings every call, use if target shape changes
-            _assert(shape is not None, 'valid shape needed')
             embeds = build_rotary_pos_embed(
                 shape,
                 self.bands,
                 in_pixels=self.in_pixels,
+                ref_feat_shape=self.ref_feat_shape,
             )
             return torch.cat(embeds, -1)
-        else:
+        elif self.pos_embed is not None:
             return self.pos_embed
+        else:
+            assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"

     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2
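With the two hunks above, RotaryEmbeddingCat.get_embed() no longer asserts on a missing shape: it builds the sin/cos embeddings on the fly when frequency bands and a target shape are available (now also honoring ref_feat_shape), falls back to the pre-computed pos_embed, and only fails when neither path applies. A minimal sketch of the deferred (bands-only) mode, assuming RotaryEmbeddingCat is re-exported from timm.layers (otherwise it lives in timm.layers.pos_embed_sincos):

from timm.layers import RotaryEmbeddingCat  # assumption: re-exported at package level

dim_per_head = 768 // 12  # 64, mirroring embed_dim // num_heads in the Eva hunks below

# feat_shape=None caches only the frequency bands, as the dynamic_size path does.
rope = RotaryEmbeddingCat(
    dim_per_head,
    in_pixels=False,
    feat_shape=None,
    ref_feat_shape=(16, 16),  # hypothetical reference grid, e.g. 224 // 14
)

# Embeddings are rebuilt per call for whatever feature grid the input produced.
for grid in ((16, 16), (24, 24)):
    emb = rope.get_embed(shape=grid)
    print(grid, tuple(emb.shape))  # expect (H * W, 2 * dim_per_head): sin and cos concatenated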
@@ -367,6 +367,7 @@ class Eva(nn.Module):
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
             use_post_norm: bool = False,
+            dynamic_size: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
     ):
@@ -406,13 +407,19 @@ class Eva(nn.Module):
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False

+        embed_args = {}
+        if dynamic_size:
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            **embed_args,
         )
         num_patches = self.patch_embed.num_patches

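The embed_args added above turn off the strict input-size check and keep the patch grid in NHWC layout, so _pos_embed() (later in this diff) can read H and W before flattening. A small sketch of the difference in output shape, assuming the PatchEmbed in timm.layers is the one eva.py uses and that it accepts these keywords (the diff relies on both):

import torch
from timm.layers import PatchEmbed  # assumption: same class as eva.py's PatchEmbed

x = torch.randn(1, 3, 336, 336)  # a resolution other than the configured one

# Default behaviour: fixed-size check against img_size, flattened NLC output (B, N, C).
embed_nlc = PatchEmbed(img_size=336, patch_size=14, in_chans=3, embed_dim=768)
print(tuple(embed_nlc(x).shape))   # expect (1, 576, 768)

# dynamic_size path: relaxed size check, NHWC output that keeps the (H, W) grid.
embed_nhwc = PatchEmbed(
    img_size=224, patch_size=14, in_chans=3, embed_dim=768,
    strict_img_size=False, output_fmt='NHWC',
)
print(tuple(embed_nhwc(x).shape))  # expect (1, 24, 24, 768)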
@@ -435,7 +442,7 @@ class Eva(nn.Module):
             self.rope = RotaryEmbeddingCat(
                 embed_dim // num_heads,
                 in_pixels=False,
-                feat_shape=self.patch_embed.grid_size,
+                feat_shape=None if dynamic_size else self.patch_embed.grid_size,
                 ref_feat_shape=ref_feat_shape,
             )
         else:
@@ -519,30 +526,44 @@ class Eva(nn.Module):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

-    def forward_features(self, x):
-        x = self.patch_embed(x)
+    def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            if self.pos_embed is not None:
+                pos_embed = resample_abs_pos_embed(
+                    self.pos_embed,
+                    (H, W),
+                    num_prefix_tokens=self.num_prefix_tokens,
+                )
+            else:
+                pos_embed = None
+            x = x.view(B, -1, C)
+            rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
+        else:
+            pos_embed = self.pos_embed
+            rot_pos_embed = self.rope.get_embed() if self.rope is not None else None

         if self.cls_token is not None:
             x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-        # apply abs position embedding
-        if self.pos_embed is not None:
-            x = x + self.pos_embed
+        if pos_embed is not None:
+            x = x + pos_embed
         x = self.pos_drop(x)

         # obtain shared rotary position embedding and apply patch dropout
-        rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
         if self.patch_drop is not None:
             x, keep_indices = self.patch_drop(x)
             if rot_pos_embed is not None and keep_indices is not None:
                 rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
+        return x, rot_pos_embed

+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        x, rot_pos_embed = self._pos_embed(x)
         for blk in self.blocks:
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 x = checkpoint(blk, x, rope=rot_pos_embed)
             else:
                 x = blk(x, rope=rot_pos_embed)

         x = self.norm(x)
         return x

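The new _pos_embed() is where the dynamic path comes together: the NHWC grid gives (H, W), the learned absolute table is interpolated to that grid with resample_abs_pos_embed(), tokens are flattened to NLC, and the rope embedding is requested for the same grid; the static path keeps the cached tensors. A standalone sketch of the resampling call as used above, assuming resample_abs_pos_embed is importable from timm.layers:

import torch
from timm.layers import resample_abs_pos_embed  # assumption: exported at package level

embed_dim = 768
num_prefix_tokens = 1          # one class token, as in Eva
old_grid, new_grid = (16, 16), (24, 24)

# Stand-in for Eva's learned self.pos_embed: prefix token(s) plus a 16x16 grid of positions.
pos_embed = torch.randn(1, num_prefix_tokens + old_grid[0] * old_grid[1], embed_dim)

# Interpolate the spatial portion to the 24x24 grid produced by a larger input.
resampled = resample_abs_pos_embed(
    pos_embed,
    new_grid,
    num_prefix_tokens=num_prefix_tokens,
)
print(tuple(resampled.shape))  # expect (1, 1 + 24 * 24, 768)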