forward_intermediates() for MlpMixer models and RegNet.

Ross Wightman 2024-05-04 10:21:03 -07:00
parent f8979d4f50
commit 01dd01b70e
2 changed files with 146 additions and 4 deletions
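
For context, a minimal usage sketch of the new API (the model name, shapes, and pretrained flag below are illustrative assumptions, not part of this commit):

import torch
import timm

model = timm.create_model('mixer_b16_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# Default call: returns (final features, list of intermediates), NCHW intermediates.
final, intermediates = model.forward_intermediates(x, indices=2)  # last 2 blocks
print(final.shape)                       # token sequence after final norm, (1, 196, 768)
print([t.shape for t in intermediates])  # two (1, 768, 14, 14) feature maps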

timm/models/mlp_mixer.py

@@ -40,6 +40,7 @@ Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from functools import partial
from typing import List, Optional, Union, Tuple
import torch
import torch.nn as nn
@@ -47,6 +48,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -211,6 +213,7 @@ class MlpMixer(nn.Module):
embed_dim=embed_dim,
norm_layer=norm_layer if stem_norm else None,
)
reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size
# FIXME drop_path (stochastic depth scaling rule or all the same?)
self.blocks = nn.Sequential(*[
block_layer(
@@ -224,6 +227,8 @@
drop_path=drop_path_rate,
)
for _ in range(num_blocks)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)]
self.norm = norm_layer(embed_dim)
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
@@ -257,6 +262,76 @@ class MlpMixer(nn.Module):
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
List of intermediate features if intermediates_only is True, otherwise a
tuple of (final features, list of intermediates).
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
# forward pass
B, _, height, width = x.shape
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if i in take_indices:
# normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x)
# process intermediates
if reshape:
# reshape to BCHW output format
H, W = self.stem.dynamic_feat_size((height, width))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
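A sketch of the NLC / intermediates_only path above, assuming the mixer_b16_224 config (12 blocks, 196 tokens, width 768):

import torch
import timm

model = timm.create_model('mixer_b16_224', pretrained=False)
feats = model.forward_intermediates(
    torch.randn(2, 3, 224, 224),
    indices=[0, 5, 11],       # explicit block indices
    output_fmt='NLC',         # keep (B, tokens, dim) layout, skip the BCHW reshape
    intermediates_only=True,  # return only the list, no final normed output
)
print([t.shape for t in feats])  # three tensors of shape (2, 196, 768)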
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
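A pruning sketch (same assumed model): truncate to the blocks needed for the requested intermediates and drop the norm/head.

import torch
import timm

model = timm.create_model('mixer_b16_224', pretrained=False)
kept = model.prune_intermediate_layers(indices=[3, 7], prune_norm=True, prune_head=True)
print(kept, len(model.blocks))  # [3, 7] 8 -- blocks 8..11 truncated, norm/head now nn.Identity()

feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224), indices=kept, intermediates_only=True)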
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -330,14 +405,13 @@ def checkpoint_filter_fn(state_dict, model):
def _create_mixer(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', 3)
model = build_model_with_cfg(
MlpMixer,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model
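With the feature_cfg wiring above, the generic features_only interface should now work for mixers as well (a sketch; with feature_cls='getter', the wrapper's forward() should call forward_intermediates() internally):

import torch
import timm

fx = timm.create_model('mixer_b16_224', pretrained=False, features_only=True, out_indices=(2, 5, 11))
out = fx(torch.randn(1, 3, 224, 224))
print([t.shape for t in out])  # (1, 768, 14, 14) maps from blocks 2, 5 and 11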

timm/models/regnet.py

@@ -26,7 +26,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import math
from dataclasses import dataclass, replace
from functools import partial
from typing import Callable, List, Optional, Union, Tuple
import numpy as np
import torch
@@ -36,6 +36,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq, named_apply
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -515,6 +516,73 @@ class RegNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
List of intermediate features if intermediates_only is True, otherwise a
tuple of (final features, list of intermediates).
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(5, indices)
# forward pass
feat_idx = 0
x = self.stem(x)
if feat_idx in take_indices:
intermediates.append(x)
layer_names = ('s1', 's2', 's3', 's4')
if stop_early:
layer_names = layer_names[:max_index]
for n in layer_names:
feat_idx += 1
x = getattr(self, n)(x) # won't work with torchscript, but keeps code reasonable, FML
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == 4:
x = self.final_conv(x)
return x, intermediates
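A RegNet sketch (model name assumed): index 0 is the stem output, indices 1-4 the stage outputs of s1-s4 at strides 4/8/16/32.

import torch
import timm

model = timm.create_model('regnety_004', pretrained=False)
final, feats = model.forward_intermediates(torch.randn(1, 3, 224, 224), indices=[1, 2, 3, 4])
print([t.shape for t in feats])  # stage feature maps at strides 4, 8, 16, 32
print(final.shape)               # final_conv applied to the s4 output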
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(5, indices)
layer_names = ('s1', 's2', 's3', 's4')
layer_names = layer_names[max_index:]
for n in layer_names:
setattr(self, n, nn.Identity())
if max_index < 4:
self.final_conv = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
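A pruning sketch for the RegNet variant: with max_index < 4, the trailing stages and final_conv are all replaced by nn.Identity().

import torch
import timm

model = timm.create_model('regnety_004', pretrained=False)
kept = model.prune_intermediate_layers(indices=[1, 2], prune_head=True)
# s3, s4 and final_conv are now nn.Identity(); the classifier head is reset
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224), indices=kept, intermediates_only=True)
print([t.shape for t in feats])  # stride-4 and stride-8 maps only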
def forward_features(self, x):
x = self.stem(x)
x = self.s1(x)