From 45b7ae8029cabfe8fda2c379e2697762fc5b3ebc Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 4 May 2024 14:06:52 -0700
Subject: [PATCH] forward_intermediates() support for byob/byoanet models

---
 tests/test_models.py   |   3 +-
 timm/models/byobnet.py | 104 ++++++++++++++++++++++++++++++++++++++---
 2 files changed, 100 insertions(+), 7 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 34bf0af4..9ff64c3b 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -51,7 +51,8 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
-    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
 ]
 
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py
index a2ff0095..a2b44e1a 100644
--- a/timm/models/byobnet.py
+++ b/timm/models/byobnet.py
@@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
@@ -948,25 +949,37 @@ class Stem(nn.Sequential):
         stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
         prev_chs = in_chs
         curr_stride = 1
+        last_feat_idx = -1
         for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
             layer_fn = layers.conv_norm_act if na else create_conv2d
             conv_name = f'conv{i + 1}'
             if i > 0 and s > 1:
-                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+                last_feat_idx = i - 1
+                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
             prev_chs = ch
             curr_stride *= s
             prev_feat = conv_name
 
         if pool and 'max' in pool.lower():
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            last_feat_idx = i
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module('pool', nn.MaxPool2d(3, 2, 1))
             curr_stride *= 2
             prev_feat = 'pool'
 
-        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+        self.last_feat_idx = last_feat_idx if last_feat_idx >= 0 else None
+        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
         assert curr_stride == stride
 
+    def forward_intermediates(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        intermediate: Optional[torch.Tensor] = None
+        for i, m in enumerate(self):
+            x = m(x)
+            if self.last_feat_idx is not None and i == self.last_feat_idx:
+                intermediate = x
+        return x, intermediate
+
 
 def create_byob_stem(
         in_chs: int,
@@ -1008,7 +1021,7 @@ def create_byob_stem(
 
     if isinstance(stem, Stem):
         feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
     else:
-        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix, stage=0)]
     return stem, feature_info
 
@@ -1122,7 +1135,7 @@ def create_byob_stages(
             feat_size = reduce_feat_size(feat_size, stride)
 
         stages += [nn.Sequential(*blocks)]
-        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1)
 
     feature_info.append(prev_feat)
     return nn.Sequential(*stages), feature_info
@@ -1198,6 +1211,7 @@ class ByobNet(nn.Module):
             feat_size=feat_size,
         )
         self.feature_info.extend(stage_feat[:-1])
+        reduction = stage_feat[-1]['reduction']
         prev_chs = stage_feat[-1]['num_chs']
 
         if cfg.num_features:
@@ -1207,7 +1221,8 @@ class ByobNet(nn.Module):
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
         self.feature_info += [
-            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
+            dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))]
+        self.stage_ends = [f['stage'] for f in self.feature_info]
 
         self.head = ClassifierHead(
             self.num_features,
@@ -1241,6 +1256,83 @@ class ByobNet(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.head.reset(num_classes, global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            exclude_final_conv: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            exclude_final_conv: Exclude final_conv from last intermediate
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        if hasattr(self.stem, 'forward_intermediates'):
+            # returns last intermediate features in stem (before final stride in stride > 2 stems)
+            x, x_inter = self.stem.forward_intermediates(x)
+        else:
+            x, x_inter = self.stem(x), None
+        if feat_idx in take_indices:
+            intermediates.append(x if x_inter is None else x_inter)
+        last_idx = self.stage_ends[-1]
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index]
+        for stage in stages:
+            feat_idx += 1
+            x = stage(x)
+            if not exclude_final_conv and feat_idx == last_idx:
+                # default feature_info for this model uses final_conv as the last feature output (if present)
+                x = self.final_conv(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if exclude_final_conv and feat_idx == last_idx:
+            x = self.final_conv(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        max_index = self.stage_ends[max_index]
+        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if max_index < self.stage_ends[-1]:
+            self.final_conv = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
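
Usage sketch (editor's note, not part of the patch): the snippet below shows how the API added by this diff is intended to be called, assuming a timm build with the change applied. The model name 'gernet_s' is an assumption chosen only because it is one of the byobnet architectures; any byob/byoanet model should behave the same, with the stem counted as feature index 0 and final_conv folded into the last stage's output by default.

import torch
import timm

# 'gernet_s' is an illustrative byobnet model choice, not mandated by the patch.
model = timm.create_model('gernet_s', pretrained=False)
model.eval()
x = torch.randn(1, 3, 224, 224)

# Final features plus all intermediates; the stem contributes index 0 and
# final_conv (if present) is applied to the last stage's output.
final, intermediates = model.forward_intermediates(x)
for i, t in enumerate(intermediates):
    print(i, tuple(t.shape))

# Only the last two feature maps; stop_early skips stages past the last taken index.
feats = model.forward_intermediates(x, indices=2, stop_early=True, intermediates_only=True)

# Truncate stages and drop the classifier head that those intermediates don't need.
model.prune_intermediate_layers(indices=2, prune_head=True)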