mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Support features_only
This commit is contained in:
parent
b37f0f7a76
commit
848b8c3e57
@ -10,7 +10,7 @@ The official pytorch code is released and available at
|
|||||||
https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
|
https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
|
||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -18,6 +18,7 @@ import torch.nn as nn
|
|||||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
|
from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint
|
from ._manipulate import checkpoint
|
||||||
from ._registry import register_model
|
from ._registry import register_model
|
||||||
|
|
||||||
@ -172,7 +173,16 @@ class PixelEmbed(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
|
self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, x, pixel_pos):
|
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
|
||||||
|
if as_scalar:
|
||||||
|
return max(self.patch_size)
|
||||||
|
else:
|
||||||
|
return self.patch_size
|
||||||
|
|
||||||
|
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
|
||||||
|
return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
_assert(H == self.img_size[0],
|
_assert(H == self.img_size[0],
|
||||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
||||||
@ -222,6 +232,7 @@ class TNT(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.global_pool = global_pool
|
self.global_pool = global_pool
|
||||||
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
|
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
|
||||||
|
self.num_prefix_tokens = 1
|
||||||
self.grad_checkpointing = False
|
self.grad_checkpointing = False
|
||||||
|
|
||||||
self.pixel_embed = PixelEmbed(
|
self.pixel_embed = PixelEmbed(
|
||||||
@ -233,6 +244,7 @@ class TNT(nn.Module):
|
|||||||
legacy=legacy,
|
legacy=legacy,
|
||||||
)
|
)
|
||||||
num_patches = self.pixel_embed.num_patches
|
num_patches = self.pixel_embed.num_patches
|
||||||
|
r = self.pixel_embed.feat_ratio() if hasattr(self.pixel_embed, 'feat_ratio') else patch_size
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
new_patch_size = self.pixel_embed.new_patch_size
|
new_patch_size = self.pixel_embed.new_patch_size
|
||||||
num_pixel = new_patch_size[0] * new_patch_size[1]
|
num_pixel = new_patch_size[0] * new_patch_size[1]
|
||||||
@ -264,8 +276,10 @@ class TNT(nn.Module):
|
|||||||
legacy=legacy,
|
legacy=legacy,
|
||||||
))
|
))
|
||||||
self.blocks = nn.ModuleList(blocks)
|
self.blocks = nn.ModuleList(blocks)
|
||||||
|
self.feature_info = [
|
||||||
|
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
|
||||||
|
|
||||||
self.norm = norm_layer(embed_dim)
|
self.norm = norm_layer(embed_dim)
|
||||||
|
|
||||||
self.head_drop = nn.Dropout(drop_rate)
|
self.head_drop = nn.Dropout(drop_rate)
|
||||||
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
@ -313,6 +327,92 @@ class TNT(nn.Module):
|
|||||||
self.global_pool = global_pool
|
self.global_pool = global_pool
|
||||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
return_prefix_tokens: bool = False,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if an int, if is a sequence, select by matching indices
|
||||||
|
return_prefix_tokens: Return both prefix and spatial intermediate tokens
|
||||||
|
norm: Apply norm layer to all intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
|
||||||
|
reshape = output_fmt == 'NCHW'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
B, _, height, width = x.shape
|
||||||
|
|
||||||
|
pixel_embed = self.pixel_embed(x, self.pixel_pos)
|
||||||
|
|
||||||
|
patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
|
||||||
|
patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
|
||||||
|
patch_embed = patch_embed + self.patch_pos
|
||||||
|
patch_embed = self.pos_drop(patch_embed)
|
||||||
|
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
blocks = self.blocks
|
||||||
|
else:
|
||||||
|
blocks = self.blocks[:max_index + 1]
|
||||||
|
|
||||||
|
for i, blk in enumerate(blocks):
|
||||||
|
pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
|
||||||
|
if i in take_indices:
|
||||||
|
# normalize intermediates with final norm layer if enabled
|
||||||
|
intermediates.append(self.norm(patch_embed) if norm else patch_embed)
|
||||||
|
|
||||||
|
# process intermediates
|
||||||
|
if self.num_prefix_tokens:
|
||||||
|
# split prefix (e.g. class, distill) and spatial feature tokens
|
||||||
|
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
|
||||||
|
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
|
||||||
|
|
||||||
|
if reshape:
|
||||||
|
# reshape to BCHW output format
|
||||||
|
H, W = self.pixel_embed.dynamic_feat_size((height, width))
|
||||||
|
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
|
||||||
|
if not torch.jit.is_scripting() and return_prefix_tokens:
|
||||||
|
# return_prefix not support in torchscript due to poor type handling
|
||||||
|
intermediates = list(zip(intermediates, prefix_tokens))
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
patch_embed = self.norm(patch_embed)
|
||||||
|
|
||||||
|
return patch_embed, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||||
|
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
||||||
|
if prune_norm:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
B = x.shape[0]
|
B = x.shape[0]
|
||||||
pixel_embed = self.pixel_embed(x, self.pixel_pos)
|
pixel_embed = self.pixel_embed(x, self.pixel_pos)
|
||||||
@ -322,11 +422,10 @@ class TNT(nn.Module):
|
|||||||
patch_embed = patch_embed + self.patch_pos
|
patch_embed = patch_embed + self.patch_pos
|
||||||
patch_embed = self.pos_drop(patch_embed)
|
patch_embed = self.pos_drop(patch_embed)
|
||||||
|
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
for blk in self.blocks:
|
||||||
for blk in self.blocks:
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed)
|
pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed)
|
||||||
else:
|
else:
|
||||||
for blk in self.blocks:
|
|
||||||
pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
|
pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
|
||||||
|
|
||||||
patch_embed = self.norm(patch_embed)
|
patch_embed = self.norm(patch_embed)
|
||||||
@ -334,7 +433,7 @@ class TNT(nn.Module):
|
|||||||
|
|
||||||
def forward_head(self, x, pre_logits: bool = False):
|
def forward_head(self, x, pre_logits: bool = False):
|
||||||
if self.global_pool:
|
if self.global_pool:
|
||||||
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||||
x = self.head_drop(x)
|
x = self.head_drop(x)
|
||||||
return x if pre_logits else self.head(x)
|
return x if pre_logits else self.head(x)
|
||||||
|
|
||||||
@ -344,6 +443,30 @@ class TNT(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {
|
||||||
|
'url': url,
|
||||||
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||||
|
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||||
|
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||||
|
'first_conv': 'pixel_embed.proj', 'classifier': 'head',
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = {
|
||||||
|
'tnt_s_patch16_224.in1k': _cfg(
|
||||||
|
# hf_hub_id='timm/',
|
||||||
|
# url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
|
||||||
|
url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
|
||||||
|
),
|
||||||
|
'tnt_b_patch16_224.in1k': _cfg(
|
||||||
|
# hf_hub_id='timm/',
|
||||||
|
url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def checkpoint_filter_fn(state_dict, model):
|
def checkpoint_filter_fn(state_dict, model):
|
||||||
state_dict.pop('outer_tokens', None)
|
state_dict.pop('outer_tokens', None)
|
||||||
|
|
||||||
@ -380,40 +503,15 @@ def checkpoint_filter_fn(state_dict, model):
|
|||||||
|
|
||||||
|
|
||||||
def _create_tnt(variant, pretrained=False, **kwargs):
|
def _create_tnt(variant, pretrained=False, **kwargs):
|
||||||
if kwargs.get('features_only', None):
|
out_indices = kwargs.pop('out_indices', 3)
|
||||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
|
||||||
|
|
||||||
model = build_model_with_cfg(
|
model = build_model_with_cfg(
|
||||||
TNT, variant, pretrained,
|
TNT, variant, pretrained,
|
||||||
pretrained_filter_fn=checkpoint_filter_fn,
|
pretrained_filter_fn=checkpoint_filter_fn,
|
||||||
|
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
|
||||||
**kwargs)
|
**kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _cfg(url='', **kwargs):
|
|
||||||
return {
|
|
||||||
'url': url,
|
|
||||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
|
||||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
|
||||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
|
||||||
'first_conv': 'pixel_embed.proj', 'classifier': 'head',
|
|
||||||
**kwargs
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
default_cfgs = {
|
|
||||||
'tnt_s_patch16_224': _cfg(
|
|
||||||
# hf_hub_id='timm/',
|
|
||||||
# url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
|
|
||||||
url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
|
|
||||||
),
|
|
||||||
'tnt_b_patch16_224': _cfg(
|
|
||||||
# hf_hub_id='timm/',
|
|
||||||
url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
|
def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
|
||||||
model_cfg = dict(
|
model_cfg = dict(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user