mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
forward_intermediates() for MlpMixer models and RegNet.
This commit is contained in:
parent
f8979d4f50
commit
01dd01b70e
@ -40,6 +40,7 @@ Hacked together by / Copyright 2021 Ross Wightman
|
|||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import List, Optional, Union, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -47,6 +48,7 @@ import torch.nn as nn
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
|
from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import named_apply, checkpoint_seq
|
from ._manipulate import named_apply, checkpoint_seq
|
||||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||||
|
|
||||||
@ -211,6 +213,7 @@ class MlpMixer(nn.Module):
|
|||||||
embed_dim=embed_dim,
|
embed_dim=embed_dim,
|
||||||
norm_layer=norm_layer if stem_norm else None,
|
norm_layer=norm_layer if stem_norm else None,
|
||||||
)
|
)
|
||||||
|
reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size
|
||||||
# FIXME drop_path (stochastic depth scaling rule or all the same?)
|
# FIXME drop_path (stochastic depth scaling rule or all the same?)
|
||||||
self.blocks = nn.Sequential(*[
|
self.blocks = nn.Sequential(*[
|
||||||
block_layer(
|
block_layer(
|
||||||
@ -224,6 +227,8 @@ class MlpMixer(nn.Module):
|
|||||||
drop_path=drop_path_rate,
|
drop_path=drop_path_rate,
|
||||||
)
|
)
|
||||||
for _ in range(num_blocks)])
|
for _ in range(num_blocks)])
|
||||||
|
self.feature_info = [
|
||||||
|
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)]
|
||||||
self.norm = norm_layer(embed_dim)
|
self.norm = norm_layer(embed_dim)
|
||||||
self.head_drop = nn.Dropout(drop_rate)
|
self.head_drop = nn.Dropout(drop_rate)
|
||||||
self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||||
@ -257,6 +262,76 @@ class MlpMixer(nn.Module):
|
|||||||
self.global_pool = global_pool
|
self.global_pool = global_pool
|
||||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
return_prefix_tokens: Return both prefix and spatial intermediate tokens
|
||||||
|
norm: Apply norm layer to all intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
|
||||||
|
reshape = output_fmt == 'NCHW'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
B, _, height, width = x.shape
|
||||||
|
x = self.stem(x)
|
||||||
|
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
blocks = self.blocks
|
||||||
|
else:
|
||||||
|
blocks = self.blocks[:max_index + 1]
|
||||||
|
for i, blk in enumerate(blocks):
|
||||||
|
x = blk(x)
|
||||||
|
if i in take_indices:
|
||||||
|
# normalize intermediates with final norm layer if enabled
|
||||||
|
intermediates.append(self.norm(x) if norm else x)
|
||||||
|
|
||||||
|
# process intermediates
|
||||||
|
if reshape:
|
||||||
|
# reshape to BCHW output format
|
||||||
|
H, W = self.stem.dynamic_feat_size((height, width))
|
||||||
|
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||||
|
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
||||||
|
if prune_norm:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
@ -330,14 +405,13 @@ def checkpoint_filter_fn(state_dict, model):
|
|||||||
|
|
||||||
|
|
||||||
def _create_mixer(variant, pretrained=False, **kwargs):
|
def _create_mixer(variant, pretrained=False, **kwargs):
|
||||||
if kwargs.get('features_only', None):
|
out_indices = kwargs.pop('out_indices', 3)
|
||||||
raise RuntimeError('features_only not implemented for MLP-Mixer models.')
|
|
||||||
|
|
||||||
model = build_model_with_cfg(
|
model = build_model_with_cfg(
|
||||||
MlpMixer,
|
MlpMixer,
|
||||||
variant,
|
variant,
|
||||||
pretrained,
|
pretrained,
|
||||||
pretrained_filter_fn=checkpoint_filter_fn,
|
pretrained_filter_fn=checkpoint_filter_fn,
|
||||||
|
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
@ -26,7 +26,7 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||||||
import math
|
import math
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass, replace
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, Union, Callable
|
from typing import Callable, List, Optional, Union, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -36,6 +36,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
|
from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
|
||||||
from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible
|
from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq, named_apply
|
from ._manipulate import checkpoint_seq, named_apply
|
||||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||||
|
|
||||||
@ -515,6 +516,73 @@ class RegNet(nn.Module):
|
|||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||||
self.head.reset(num_classes, pool_type=global_pool)
|
self.head.reset(num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(5, indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
feat_idx = 0
|
||||||
|
x = self.stem(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
layer_names = ('s1', 's2', 's3', 's4')
|
||||||
|
if stop_early:
|
||||||
|
layer_names = layer_names[:max_index]
|
||||||
|
for n in layer_names:
|
||||||
|
feat_idx += 1
|
||||||
|
x = getattr(self, n)(x) # won't work with torchscript, but keeps code reasonable, FML
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
if feat_idx == 4:
|
||||||
|
x = self.final_conv(x)
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(5, indices)
|
||||||
|
layer_names = ('s1', 's2', 's3', 's4')
|
||||||
|
layer_names = layer_names[max_index:]
|
||||||
|
for n in layer_names:
|
||||||
|
setattr(self, n, nn.Identity())
|
||||||
|
if max_index < 4:
|
||||||
|
self.final_conv = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
x = self.s1(x)
|
x = self.s1(x)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user