forward_intermediates() support for byob/byoanet models

Ross Wightman 2024-05-04 14:06:52 -07:00
parent c4b8897e9e
commit 45b7ae8029
2 changed files with 100 additions and 7 deletions
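
In short: ByobNet (the byobnet/byoanet config families) now implements the forward_intermediates() feature API used by other timm model families. A minimal usage sketch against this commit; 'gernet_s' is just one example byobnet architecture:

import torch
import timm

model = timm.create_model('gernet_s', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# final features plus intermediates for every feature level (stem is index 0)
final, intermediates = model.forward_intermediates(x)

# only the last two intermediates, skipping trailing stages where possible
feats = model.forward_intermediates(x, indices=2, stop_early=True, intermediates_only=True)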

tests/test_models.py

@@ -51,7 +51,8 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
-    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
 ]
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
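
For orientation, FEAT_INTER_FILTERS entries are timm model module names; the test suite runs its forward_intermediates() checks only for models defined in those modules, which this commit extends to regnet, byobnet, byoanet, and mlp_mixer. A rough sketch of how such a gate works (illustrative, not the actual test code):

from timm.models import is_model_in_modules

def supports_feat_inter(model_name: str) -> bool:
    # True if the model is defined in one of the opted-in modules above
    return is_model_in_modules(model_name, FEAT_INTER_FILTERS)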

timm/models/byobnet.py

@@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
@@ -948,25 +949,37 @@ class Stem(nn.Sequential):
         stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
         prev_chs = in_chs
         curr_stride = 1
+        last_feat_idx = -1
         for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
             layer_fn = layers.conv_norm_act if na else create_conv2d
             conv_name = f'conv{i + 1}'
             if i > 0 and s > 1:
-                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+                last_feat_idx = i - 1
+                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
             prev_chs = ch
             curr_stride *= s
             prev_feat = conv_name

         if pool and 'max' in pool.lower():
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            last_feat_idx = i
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module('pool', nn.MaxPool2d(3, 2, 1))
             curr_stride *= 2
             prev_feat = 'pool'

-        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+        self.last_feat_idx = last_feat_idx if last_feat_idx >= 0 else None
+        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
         assert curr_stride == stride

+    def forward_intermediates(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        intermediate: Optional[torch.Tensor] = None
+        for i, m in enumerate(self):
+            x = m(x)
+            if self.last_feat_idx is not None and i == self.last_feat_idx:
+                intermediate = x
+        return x, intermediate


 def create_byob_stem(
         in_chs: int,
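
Worth noting on the stem change above: multi-conv stems now record last_feat_idx so forward_intermediates() can surface the feature just before the final stride-2 layer of a stride-4 stem. A hedged sketch; 'resnet51q' is just an example byobnet model whose deep stem should be a Stem instance:

import torch
import timm

model = timm.create_model('resnet51q', pretrained=False)
out, inter = model.stem.forward_intermediates(torch.randn(1, 3, 256, 256))
print(out.shape)                               # final stem output at stride 4
print(None if inter is None else inter.shape)  # stride-2 intermediate, when tracked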
@@ -1008,7 +1021,7 @@ def create_byob_stem(
     if isinstance(stem, Stem):
         feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
     else:
-        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix, stage=0)]

     return stem, feature_info
@@ -1122,7 +1135,7 @@ def create_byob_stages(
             feat_size = reduce_feat_size(feat_size, stride)

         stages += [nn.Sequential(*blocks)]
-        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1)
         feature_info.append(prev_feat)

     return nn.Sequential(*stages), feature_info
@@ -1198,6 +1211,7 @@ class ByobNet(nn.Module):
             feat_size=feat_size,
         )
         self.feature_info.extend(stage_feat[:-1])
+        reduction = stage_feat[-1]['reduction']
         prev_chs = stage_feat[-1]['num_chs']

         if cfg.num_features:
@@ -1207,7 +1221,8 @@ class ByobNet(nn.Module):
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
         self.feature_info += [
-            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
+            dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))]
+        self.stage_ends = [f['stage'] for f in self.feature_info]

         self.head = ClassifierHead(
             self.num_features,
@@ -1241,6 +1256,83 @@ class ByobNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.head.reset(num_classes, global_pool)
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            exclude_final_conv: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            exclude_final_conv: Exclude final_conv from last intermediate
+        Returns:
+            List of intermediate features if intermediates_only is True,
+            otherwise a tuple of (final features, list of intermediate features).
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        if hasattr(self.stem, 'forward_intermediates'):
+            # returns last intermediate features in stem (before final stride in stride > 2 stems)
+            x, x_inter = self.stem.forward_intermediates(x)
+        else:
+            x, x_inter = self.stem(x), None
+        if feat_idx in take_indices:
+            intermediates.append(x if x_inter is None else x_inter)
+        last_idx = self.stage_ends[-1]
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index]
+        for stage in stages:
+            feat_idx += 1
+            x = stage(x)
+            if not exclude_final_conv and feat_idx == last_idx:
+                # default feature_info for this model uses final_conv as the last feature output (if present)
+                x = self.final_conv(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if exclude_final_conv and feat_idx == last_idx:
+            x = self.final_conv(x)
+
+        return x, intermediates
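
The index handling above relies on feature_take_indices(), imported earlier in this diff. A small sketch of its semantics, assuming a model with five feature levels (stem + 4 stages):

from timm.models._features import feature_take_indices

# int n -> last n levels, None -> all levels, sequence -> exact levels
take, max_idx = feature_take_indices(5, 3)       # ([2, 3, 4], 4)
take, max_idx = feature_take_indices(5, [0, 2])  # ([0, 2], 2)
take, max_idx = feature_take_indices(5, None)    # ([0, 1, 2, 3, 4], 4)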
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        max_index = self.stage_ends[max_index]
+        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if max_index < self.stage_ends[-1]:
+            self.final_conv = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
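
And a usage sketch for the pruning hook, again with 'gernet_s' as an arbitrary example; it returns the (possibly normalized) indices that remain valid after truncation:

import torch
import timm

model = timm.create_model('gernet_s', pretrained=False)
kept = model.prune_intermediate_layers(indices=[0, 1], prune_head=True)
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224), indices=kept, intermediates_only=True)
print([f.shape for f in feats])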
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():