Mirror of https://github.com/huggingface/pytorch-image-models.git
forward_intermediates() support for byob/byoanet models
parent c4b8897e9e
commit 45b7ae8029
@@ -51,7 +51,8 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
-    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
 ]
 
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
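
The FEAT_INTER_FILTERS entries above are timm module names rather than model names. A hedged sketch of how the covered models could be enumerated (illustrative only, not the exact test code in this repo):

    import timm

    for module in FEAT_INTER_FILTERS:
        for model_name in timm.list_models(module=module):
            # models from these modules are expected to expose
            # forward_intermediates() after this commit
            ...
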
@@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
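
The newly imported feature_take_indices helper normalizes the indices argument used throughout the new API. A minimal sketch of its contract as inferred from this diff (the real implementation lives in timm/models/_features.py and may differ in detail):

    from typing import List, Optional, Tuple, Union

    def feature_take_indices_sketch(
            num_features: int,
            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
    ) -> Tuple[List[int], int]:
        # None selects all features; an int n selects the last n
        if indices is None:
            indices = num_features
        if isinstance(indices, int):
            take_indices = list(range(num_features - indices, num_features))
        else:
            # a sequence selects matching indices; negatives count from the end
            take_indices = [num_features + i if i < 0 else i for i in indices]
        return take_indices, max(take_indices)
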
@@ -948,25 +949,37 @@ class Stem(nn.Sequential):
         stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
         prev_chs = in_chs
         curr_stride = 1
+        last_feat_idx = -1
         for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
             layer_fn = layers.conv_norm_act if na else create_conv2d
             conv_name = f'conv{i + 1}'
             if i > 0 and s > 1:
-                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+                last_feat_idx = i - 1
+                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
             prev_chs = ch
             curr_stride *= s
             prev_feat = conv_name
 
         if pool and 'max' in pool.lower():
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            last_feat_idx = i
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module('pool', nn.MaxPool2d(3, 2, 1))
             curr_stride *= 2
             prev_feat = 'pool'
 
-        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+        self.last_feat_idx = last_feat_idx if last_feat_idx >= 0 else None
+        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
         assert curr_stride == stride
 
+    def forward_intermediates(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        intermediate: Optional[torch.Tensor] = None
+        for i, m in enumerate(self):
+            x = m(x)
+            if self.last_feat_idx is not None and i == self.last_feat_idx:
+                intermediate = x
+        return x, intermediate
+
 
 def create_byob_stem(
         in_chs: int,
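
A hedged usage sketch of the new stem-level hook (assumes a timm build containing this commit; 'resnet51q' is one ByobNet config with a deep multi-conv stem, chosen here only for illustration):

    import torch
    import timm

    model = timm.create_model('resnet51q', pretrained=False).eval()
    x = torch.randn(1, 3, 224, 224)
    if hasattr(model.stem, 'forward_intermediates'):
        out, inter = model.stem.forward_intermediates(x)
        print(out.shape)            # final stem output, stride 4
        if inter is not None:
            print(inter.shape)      # intermediate kept before the stem's final stride
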
@@ -1008,7 +1021,7 @@ def create_byob_stem(
     if isinstance(stem, Stem):
         feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
     else:
-        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix, stage=0)]
     return stem, feature_info
 
 
@@ -1122,7 +1135,7 @@ def create_byob_stages(
             feat_size = reduce_feat_size(feat_size, stride)
 
         stages += [nn.Sequential(*blocks)]
-        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1)
 
     feature_info.append(prev_feat)
     return nn.Sequential(*stages), feature_info
@@ -1198,6 +1211,7 @@ class ByobNet(nn.Module):
             feat_size=feat_size,
         )
         self.feature_info.extend(stage_feat[:-1])
+        reduction = stage_feat[-1]['reduction']
 
         prev_chs = stage_feat[-1]['num_chs']
         if cfg.num_features:
@@ -1207,7 +1221,8 @@ class ByobNet(nn.Module):
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
         self.feature_info += [
-            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
+            dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))]
+        self.stage_ends = [f['stage'] for f in self.feature_info]
 
         self.head = ClassifierHead(
             self.num_features,
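
To make the new bookkeeping concrete: for a hypothetical 4-stage ByobNet, feature_info now carries stage=0 for the stem and stage=1..4 for the stage outputs, with final_conv taking over the last entry, so:

    # illustrative values for a hypothetical 4-stage model, not from a real config
    stage_ends = [0, 1, 2, 3, 4]    # feature_info index -> stage id, stem first
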
@@ -1241,6 +1256,83 @@ class ByobNet(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.head.reset(num_classes, global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            exclude_final_conv: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            exclude_final_conv: Exclude final_conv from last intermediate
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        if hasattr(self.stem, 'forward_intermediates'):
+            # returns last intermediate features in stem (before final stride in stride > 2 stems)
+            x, x_inter = self.stem.forward_intermediates(x)
+        else:
+            x, x_inter = self.stem(x), None
+        if feat_idx in take_indices:
+            intermediates.append(x if x_inter is None else x_inter)
+        last_idx = self.stage_ends[-1]
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index]
+        for stage in stages:
+            feat_idx += 1
+            x = stage(x)
+            if not exclude_final_conv and feat_idx == last_idx:
+                # default feature_info for this model uses final_conv as the last feature output (if present)
+                x = self.final_conv(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if exclude_final_conv and feat_idx == last_idx:
+            x = self.final_conv(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        max_index = self.stage_ends[max_index]
+        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if max_index < self.stage_ends[-1]:
+            self.final_conv = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
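
A hedged end-to-end sketch of the new ByobNet API (assumes timm at or after this commit; 'gernet_s' is one byobnet model, used only as an example):

    import torch
    import timm

    model = timm.create_model('gernet_s', pretrained=False).eval()
    x = torch.randn(1, 3, 224, 224)

    # take only the last two feature maps
    feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
    for f in feats:
        print(f.shape)

    # full call returns final features plus the selected intermediates
    out, feats = model.forward_intermediates(x, indices=[0, 1])

    # prune stages/head not needed for the first two feature outputs
    kept = model.prune_intermediate_layers(indices=[0, 1], prune_head=True)
    print(kept)  # [0, 1]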