Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Fix dynamic_resize for deit models (distilled or no_embed_cls) and vit w/o class tokens
This commit is contained in:
parent 4d8ecde6cc
commit ea3519a5f0
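The change threads the model's actual prefix-token count into the position-embedding resampling and gives the distilled DeiT class its own _pos_embed, so cls/dist tokens are handled correctly in both the no_embed_class and classic layouts. As background, resampling an absolute position embedding to a new grid generally means splitting off any prefix (cls/dist) entries, interpolating only the patch-grid part, and concatenating the prefix back. The sketch below illustrates that idea; resample_pos_embed_sketch is a hypothetical helper written for this page, not timm's resample_abs_pos_embed, and it assumes the original grid was square.

import math
import torch
import torch.nn.functional as F

def resample_pos_embed_sketch(pos_embed, new_hw, num_prefix_tokens=1):
    # Illustrative only: keep any prefix (cls/dist) entries untouched and
    # bicubically interpolate the patch-grid portion to the new (H, W).
    prefix, grid = pos_embed[:, :num_prefix_tokens], pos_embed[:, num_prefix_tokens:]
    old_size = int(math.sqrt(grid.shape[1]))  # assumes the original grid was square
    grid = grid.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_hw, mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_hw[0] * new_hw[1], -1)
    return torch.cat([prefix, grid], dim=1)

# e.g. a ViT pos embed with 1 cls entry and a 14x14 grid, resampled to 16x16:
pe = torch.randn(1, 1 + 14 * 14, 192)
print(resample_pos_embed_sketch(pe, (16, 16), num_prefix_tokens=1).shape)  # (1, 257, 192)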
timm/models/deit.py
@@ -73,45 +73,36 @@ class VisionTransformerDistilled(VisionTransformer):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable
 
-    def _intermediate_layers(
-            self,
-            x: torch.Tensor,
-            n: Union[int, Sequence] = 1,
-    ):
-        outputs, num_blocks = [], len(self.blocks)
-        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
-
-        # forward pass
-        x = self.patch_embed(x)
-        x = torch.cat((
-            self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1),
-            x),
-            dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        x = self.patch_drop(x)
-        x = self.norm_pre(x)
-        for i, blk in enumerate(self.blocks):
-            x = blk(x)
-            if i in take_indices:
-                outputs.append(x)
-
-        return outputs
-
-    def forward_features(self, x) -> torch.Tensor:
-        x = self.patch_embed(x)
-        x = torch.cat((
-            self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1),
-            x),
-            dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.blocks, x)
+    def _pos_embed(self, x):
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(
+                self.pos_embed,
+                (H, W),
+                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+            )
+            x = x.view(B, -1, C)
         else:
-            x = self.blocks(x)
-        x = self.norm(x)
-        return x
+            pos_embed = self.pos_embed
+        if self.no_embed_class:
+            # deit-3, updated JAX (big vision)
+            # position embedding does not overlap with class token, add then concat
+            x = x + pos_embed
+            x = torch.cat((
+                self.cls_token.expand(x.shape[0], -1, -1),
+                self.dist_token.expand(x.shape[0], -1, -1),
+                x),
+                dim=1)
+        else:
+            # original timm, JAX, and deit vit impl
+            # pos_embed has entry for class token, concat then add
+            x = torch.cat((
+                self.cls_token.expand(x.shape[0], -1, -1),
+                self.dist_token.expand(x.shape[0], -1, -1),
+                x),
+                dim=1)
+            x = x + pos_embed
+        return self.pos_drop(x)
 
     def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
         x, x_dist = x[:, 0], x[:, 1]
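The two branches of the new _pos_embed differ only in ordering: with no_embed_class the position embedding covers just the patch grid, so it is added before the cls/dist tokens are concatenated; otherwise the embedding already carries entries for the prefix tokens, so concatenation happens first. A minimal stand-alone illustration, with made-up toy shapes:

import torch

# Toy shapes: 2 prefix tokens (cls + dist), a 3x3 patch grid, 4-dim embeddings.
B, H, W, C = 1, 3, 3, 4
patches = torch.randn(B, H * W, C)
cls_token = torch.zeros(B, 1, C)
dist_token = torch.zeros(B, 1, C)

# no_embed_class=True: pos embed covers only the patch grid -> add, then concat.
pos_embed_grid = torch.randn(1, H * W, C)
out_a = torch.cat((cls_token, dist_token, patches + pos_embed_grid), dim=1)

# no_embed_class=False: pos embed also has entries for the prefix tokens -> concat, then add.
pos_embed_full = torch.randn(1, 2 + H * W, C)
out_b = torch.cat((cls_token, dist_token, patches), dim=1) + pos_embed_full

print(out_a.shape, out_b.shape)  # both torch.Size([1, 11, 4])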
timm/models/vision_transformer.py
@@ -459,11 +459,8 @@ class VisionTransformer(nn.Module):
 
         embed_args = {}
         if dynamic_size:
-            embed_args.update(dict(
-                strict_img_size=False,
-                flatten=False,  # flatten deferred until after pos embed
-                output_fmt='NHWC',
-            ))
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
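The simplification above relies on the NHWC output format keeping the patch grid intact, which is what lets _pos_embed read (H, W) before the tokens are flattened; the explicit flatten=False becomes redundant, assuming the patch embed treats a non-NCHW output_fmt as implying no flatten, as the new one-liner suggests. A toy sketch of the deferred flatten, with made-up sizes:

import torch

# Stand-in for the NHWC patch-embed output: the spatial grid survives, so the
# (H, W) needed for pos-embed resampling can be read before flattening.
B, H, W, C = 2, 14, 14, 384
x = torch.randn(B, H, W, C)   # what a patch embed emitting output_fmt='NHWC' would produce
tokens = x.view(B, -1, C)     # the deferred flatten, applied after pos embed handling
assert tokens.shape == (B, H * W, C)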
@@ -559,7 +556,11 @@ class VisionTransformer(nn.Module):
     def _pos_embed(self, x):
         if self.dynamic_size:
             B, H, W, C = x.shape
-            pos_embed = resample_abs_pos_embed(self.pos_embed, (H, W))
+            pos_embed = resample_abs_pos_embed(
+                self.pos_embed,
+                (H, W),
+                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+            )
             x = x.view(B, -1, C)
         else:
             pos_embed = self.pos_embed
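The num_prefix_tokens argument is the crux of the fix: a distilled DeiT has two prefix tokens and a no_embed_cls model has none, so a resampler that assumes a single class-token prefix by default would split the embedding at the wrong offset. A self-contained toy example of the mismatch (shapes are made up):

import torch

# Toy DeiT-style pos embed: 2 prefix entries (cls + dist) plus an 8x8 patch grid.
pos_embed = torch.randn(1, 2 + 8 * 8, 256)

# Correct split for this model: keep 2 prefix entries, 64 grid entries form an 8x8 grid.
prefix, grid = pos_embed[:, :2], pos_embed[:, 2:]
grid = grid.reshape(1, 8, 8, 256)   # OK

# Assuming a single class-token prefix instead would leave 65 "grid" entries,
# which cannot be reshaped into any square grid:
# pos_embed[:, 1:].reshape(1, 8, 8, 256)   # RuntimeError: invalid shape for input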