Fix dynamic_resize for deit models (distilled or no_embed_cls) and vit w/o class tokens
parent 4d8ecde6cc
commit ea3519a5f0
@@ -73,45 +73,36 @@ class VisionTransformerDistilled(VisionTransformer):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable
 
-    def _intermediate_layers(
-            self,
-            x: torch.Tensor,
-            n: Union[int, Sequence] = 1,
-    ):
-        outputs, num_blocks = [], len(self.blocks)
-        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
-
-        # forward pass
-        x = self.patch_embed(x)
-        x = torch.cat((
-            self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1),
-            x),
-            dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        x = self.patch_drop(x)
-        x = self.norm_pre(x)
-        for i, blk in enumerate(self.blocks):
-            x = blk(x)
-            if i in take_indices:
-                outputs.append(x)
-
-        return outputs
-
-    def forward_features(self, x) -> torch.Tensor:
-        x = self.patch_embed(x)
-        x = torch.cat((
-            self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1),
-            x),
-            dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.blocks, x)
-        else:
-            x = self.blocks(x)
-        x = self.norm(x)
-        return x
+    def _pos_embed(self, x):
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(
+                self.pos_embed,
+                (H, W),
+                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+            )
+            x = x.view(B, -1, C)
+        else:
+            pos_embed = self.pos_embed
+        if self.no_embed_class:
+            # deit-3, updated JAX (big vision)
+            # position embedding does not overlap with class token, add then concat
+            x = x + pos_embed
+            x = torch.cat((
+                self.cls_token.expand(x.shape[0], -1, -1),
+                self.dist_token.expand(x.shape[0], -1, -1),
+                x),
+                dim=1)
+        else:
+            # original timm, JAX, and deit vit impl
+            # pos_embed has entry for class token, concat then add
+            x = torch.cat((
+                self.cls_token.expand(x.shape[0], -1, -1),
+                self.dist_token.expand(x.shape[0], -1, -1),
+                x),
+                dim=1)
+            x = x + pos_embed
+        return self.pos_drop(x)
 
     def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
         x, x_dist = x[:, 0], x[:, 1]
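
For orientation (not part of the commit), a minimal standalone sketch of the two prefix-token layouts the new _pos_embed distinguishes, using toy tensors in place of the model's parameters. With no_embed_class the position embedding has no rows for the cls/dist tokens, so it must be added before they are concatenated; otherwise the embedding already has entries for them, so the tokens are concatenated first and the embedding added afterwards.

import torch

B, N, D = 2, 196, 8                       # batch, patch tokens, embed dim (toy sizes)
x = torch.randn(B, N, D)                  # patch tokens after patch_embed
cls_tok = torch.zeros(B, 1, D)
dist_tok = torch.zeros(B, 1, D)

# no_embed_class=True: pos embed covers patch tokens only -> add first, then concat
pos_embed_patches = torch.randn(1, N, D)
out_a = torch.cat((cls_tok, dist_tok, x + pos_embed_patches), dim=1)

# no_embed_class=False: pos embed also has prefix-token entries -> concat first, then add
pos_embed_full = torch.randn(1, N + 2, D)
out_b = torch.cat((cls_tok, dist_tok, x), dim=1) + pos_embed_full

assert out_a.shape == out_b.shape == (B, N + 2, D)

Either way the output shape is the same; only the point at which the position embedding is added differs, which is why both branches are kept.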
@@ -459,11 +459,8 @@ class VisionTransformer(nn.Module):
 
         embed_args = {}
         if dynamic_size:
-            embed_args.update(dict(
-                strict_img_size=False,
-                flatten=False,  # flatten deferred until after pos embed
-                output_fmt='NHWC',
-            ))
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
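
A rough sketch (toy tensors, not the timm implementation) of why the flatten is deferred when dynamic_size is set: the position embedding is resampled against the NHWC patch grid of the actual input, and only afterwards is the grid flattened into a token sequence with view(B, -1, C).

import torch
import torch.nn.functional as F

B, H, W, C = 2, 14, 20, 8                 # NHWC patch grid for a non-square input (toy sizes)
x = torch.randn(B, H, W, C)               # patch_embed output kept in NHWC, not yet flattened

# resample a 7x7 pos embed grid to (H, W) with 2-D interpolation
pos = torch.randn(1, 7 * 7, C).reshape(1, 7, 7, C).permute(0, 3, 1, 2)   # (1, C, 7, 7)
pos = F.interpolate(pos, size=(H, W), mode='bicubic', align_corners=False)
pos = pos.permute(0, 2, 3, 1).reshape(1, H * W, C)

x = x.view(B, -1, C) + pos                # flatten only after the grid-shaped resample
assert x.shape == (B, H * W, C)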
@@ -559,7 +556,11 @@ class VisionTransformer(nn.Module):
     def _pos_embed(self, x):
         if self.dynamic_size:
             B, H, W, C = x.shape
-            pos_embed = resample_abs_pos_embed(self.pos_embed, (H, W))
+            pos_embed = resample_abs_pos_embed(
+                self.pos_embed,
+                (H, W),
+                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+            )
             x = x.view(B, -1, C)
         else:
             pos_embed = self.pos_embed
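
The gist of the fix, as a standalone sketch rather than timm's actual resample_abs_pos_embed: prefix-token entries (class/distillation tokens) cannot be bicubic-resampled along with the patch grid, so when the position embedding contains them (no_embed_class is False) they have to be split off and re-attached unchanged, which is what passing num_prefix_tokens enables. The helper name below is hypothetical and only mirrors the idea.

import torch
import torch.nn.functional as F

def resample_pos_embed_sketch(pos_embed, new_hw, num_prefix_tokens=1):
    # hypothetical helper: split off prefix tokens, resample the grid, re-attach
    prefix, grid = pos_embed[:, :num_prefix_tokens], pos_embed[:, num_prefix_tokens:]
    old = int(grid.shape[1] ** 0.5)                       # assume a square source grid
    C = grid.shape[-1]
    grid = grid.reshape(1, old, old, C).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_hw, mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, -1, C)
    return torch.cat((prefix, grid), dim=1)               # prefix entries pass through untouched

pe = torch.randn(1, 1 + 14 * 14, 8)                       # class-token entry + 14x14 patch grid
print(resample_pos_embed_sketch(pe, (16, 24)).shape)      # (1, 1 + 16*24, 8) = (1, 385, 8)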