diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index f7d64349..b775b736 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -40,6 +40,7 @@ Hacked together by / Copyright 2021 Ross Wightman """ import math from functools import partial +from typing import List, Optional, Union, Tuple import torch import torch.nn as nn @@ -47,6 +48,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint_seq from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -211,6 +213,7 @@ class MlpMixer(nn.Module): embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None, ) + reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size # FIXME drop_path (stochastic depth scaling rule or all the same?) self.blocks = nn.Sequential(*[ block_layer( @@ -224,6 +227,8 @@ class MlpMixer(nn.Module): drop_path=drop_path_rate, ) for _ in range(num_blocks)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)] self.norm = norm_layer(embed_dim) self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() @@ -257,6 +262,76 @@ class MlpMixer(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + + # forward pass + B, _, height, width = x.shape + x = self.stem(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if reshape: + # reshape to BCHW output format + H, W = self.stem.dynamic_feat_size((height, width)) + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -330,14 +405,13 @@ def checkpoint_filter_fn(state_dict, model): def _create_mixer(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for MLP-Mixer models.') - + out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( MlpMixer, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) return model diff --git a/timm/models/regnet.py b/timm/models/regnet.py index bc73f540..12187378 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -26,7 +26,7 @@ Hacked together by / Copyright 2020 Ross Wightman import math from dataclasses import dataclass, replace from functools import partial -from typing import Optional, Union, Callable +from typing import Callable, List, Optional, Union, Tuple import numpy as np import torch @@ -36,6 +36,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq, named_apply from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -515,6 +516,73 @@ class RegNet(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(5, indices) + + # forward pass + feat_idx = 0 + x = self.stem(x) + if feat_idx in take_indices: + intermediates.append(x) + + layer_names = ('s1', 's2', 's3', 's4') + if stop_early: + layer_names = layer_names[:max_index] + for n in layer_names: + feat_idx += 1 + x = getattr(self, n)(x) # won't work with torchscript, but keeps code reasonable, FML + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == 4: + x = self.final_conv(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(5, indices) + layer_names = ('s1', 's2', 's3', 's4') + layer_names = layer_names[max_index:] + for n in layer_names: + setattr(self, n, nn.Identity()) + if max_index < 4: + self.final_conv = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.s1(x)