Fix dynamic_resize for deit models (distilled or no_embed_cls) and vit w/o class tokens

Branch: vit_packed
Ross Wightman 2023-08-26 15:27:00 -07:00 committed by Ross Wightman
parent 4d8ecde6cc
commit ea3519a5f0
2 changed files with 36 additions and 44 deletions
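Summary of the fix: when the input size differs from the training size, resample_abs_pos_embed has to be told how many prefix rows (class / distillation tokens) the position table actually contains: zero when no_embed_class is set (DeiT-3 style, and ViTs without class tokens), otherwise self.num_prefix_tokens (two for distilled DeiT). Below is a minimal plain-PyTorch sketch of that prefix handling, not the timm implementation; the function name and the explicit old_hw argument are illustrative only.

import torch
import torch.nn.functional as F

def resample_pos_embed_sketch(pos_embed, new_hw, old_hw, num_prefix_tokens=1):
    # Split off prefix rows (cls/dist tokens) that carry no spatial position,
    # bicubic-resize only the spatial grid, then re-attach the prefix rows.
    prefix, grid = pos_embed[:, :num_prefix_tokens], pos_embed[:, num_prefix_tokens:]
    B, _, C = grid.shape
    grid = grid.reshape(B, old_hw[0], old_hw[1], C).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_hw, mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(B, -1, C)
    return torch.cat([prefix, grid], dim=1)

# Distilled DeiT: the table has cls + dist rows plus a 14x14 patch grid.
pe = torch.randn(1, 2 + 14 * 14, 384)
print(resample_pos_embed_sketch(pe, (16, 16), (14, 14), num_prefix_tokens=2).shape)  # (1, 258, 384)

# no_embed_class / no class token: the table holds only the patch grid, so pass 0.
pe = torch.randn(1, 14 * 14, 384)
print(resample_pos_embed_sketch(pe, (16, 16), (14, 14), num_prefix_tokens=0).shape)  # (1, 256, 384)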

timm/models/deit.py
@@ -73,45 +73,36 @@ class VisionTransformerDistilled(VisionTransformer):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable
 
-    def _intermediate_layers(
-            self,
-            x: torch.Tensor,
-            n: Union[int, Sequence] = 1,
-    ):
-        outputs, num_blocks = [], len(self.blocks)
-        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
-
-        # forward pass
-        x = self.patch_embed(x)
-        x = torch.cat((
-            self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1),
-            x),
-            dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        x = self.patch_drop(x)
-        x = self.norm_pre(x)
-
-        for i, blk in enumerate(self.blocks):
-            x = blk(x)
-            if i in take_indices:
-                outputs.append(x)
-
-        return outputs
-
-    def forward_features(self, x) -> torch.Tensor:
-        x = self.patch_embed(x)
-        x = torch.cat((
-            self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1),
-            x),
-            dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.blocks, x)
-        else:
-            x = self.blocks(x)
-        x = self.norm(x)
-        return x
+    def _pos_embed(self, x):
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(
+                self.pos_embed,
+                (H, W),
+                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+            )
+            x = x.view(B, -1, C)
+        else:
+            pos_embed = self.pos_embed
+        if self.no_embed_class:
+            # deit-3, updated JAX (big vision)
+            # position embedding does not overlap with class token, add then concat
+            x = x + pos_embed
+            x = torch.cat((
+                self.cls_token.expand(x.shape[0], -1, -1),
+                self.dist_token.expand(x.shape[0], -1, -1),
+                x),
+                dim=1)
+        else:
+            # original timm, JAX, and deit vit impl
+            # pos_embed has entry for class token, concat then add
+            x = torch.cat((
+                self.cls_token.expand(x.shape[0], -1, -1),
+                self.dist_token.expand(x.shape[0], -1, -1),
+                x),
+                dim=1)
+            x = x + pos_embed
+        return self.pos_drop(x)
 
     def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
         x, x_dist = x[:, 0], x[:, 1]
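A toy, shapes-only illustration (dummy tensors, not the classes above) of the two branches in the new _pos_embed override: with no_embed_class the position table covers only the patch tokens, so it is added before the cls/dist tokens are concatenated; in the classic layout the table already has rows for the prefix tokens, so concatenation happens first.

import torch

B, N, C = 2, 196, 384
patches = torch.randn(B, N, C)
cls_tok = torch.zeros(B, 1, C)
dist_tok = torch.zeros(B, 1, C)

# no_embed_class=True: pos_embed is (1, N, C) -> add, then concat the prefix tokens
x = patches + torch.randn(1, N, C)
x = torch.cat((cls_tok, dist_tok, x), dim=1)
print(x.shape)  # torch.Size([2, 198, 384])

# no_embed_class=False: pos_embed is (1, N + 2, C) -> concat the prefix tokens, then add
y = torch.cat((cls_tok, dist_tok, patches), dim=1)
y = y + torch.randn(1, N + 2, C)
print(y.shape)  # torch.Size([2, 198, 384])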

timm/models/vision_transformer.py
@@ -459,11 +459,8 @@ class VisionTransformer(nn.Module):
         embed_args = {}
         if dynamic_size:
-            embed_args.update(dict(
-                strict_img_size=False,
-                flatten=False,  # flatten deferred until after pos embed
-                output_fmt='NHWC',
-            ))
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
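The embed_args change above keeps the patch grid in NHWC layout (flatten deferred) so that _pos_embed can still read the actual (H, W) grid before the tokens are flattened. A plain-PyTorch sketch of that flow follows; the conv stem is a stand-in, not timm's PatchEmbed.

import torch
import torch.nn as nn

proj = nn.Conv2d(3, 384, kernel_size=16, stride=16)  # stand-in for the patch projection
img = torch.randn(1, 3, 240, 320)                    # any size divisible by the patch size
x = proj(img)                                        # (1, 384, 15, 20)  NCHW
x = x.permute(0, 2, 3, 1)                            # (1, 15, 20, 384)  NHWC, grid kept 2D
B, H, W, C = x.shape                                 # _pos_embed reads (H, W) here to resample pos_embed
x = x.view(B, -1, C)                                 # flatten only after the grid size has been used
print(H, W, x.shape)                                 # 15 20 torch.Size([1, 300, 384])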
@@ -559,7 +556,11 @@ class VisionTransformer(nn.Module):
     def _pos_embed(self, x):
         if self.dynamic_size:
             B, H, W, C = x.shape
-            pos_embed = resample_abs_pos_embed(self.pos_embed, (H, W))
+            pos_embed = resample_abs_pos_embed(
+                self.pos_embed,
+                (H, W),
+                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+            )
             x = x.view(B, -1, C)
         else:
             pos_embed = self.pos_embed
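Taken together, the two files make variable input sizes work for distilled DeiT and no_embed_class (DeiT-3 style) models. A usage sketch of what this enables is below; the dynamic_size keyword follows this branch's naming and is assumed to be forwarded by create_model (later timm releases spell the flag dynamic_img_size), and the model name is only an example.

import torch
import timm

model = timm.create_model(
    'deit_base_distilled_patch16_224',
    pretrained=False,
    dynamic_size=True,  # assumed constructor flag from this branch
)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 320))  # non-square, non-default input size
print(out.shape)  # expected: torch.Size([1, 1000])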