Add forward_intermediates() to efficientnet / mobilenetv3 based models as an exercise.

Ross Wightman 2024-05-02 14:19:16 -07:00
parent c22efb9765
commit d6da4fb01e
4 changed files with 182 additions and 12 deletions
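
For context, a minimal usage sketch of the API this commit adds (model choice and shapes are illustrative, not part of the diff):

    import torch
    import timm

    model = timm.create_model('efficientnet_b0')
    x = torch.randn(1, 3, 224, 224)

    # Final pre-pooling features plus the outputs of the last two feature stages.
    final, intermediates = model.forward_intermediates(x, indices=2)

    # Or just the intermediates, allowing early stopping of the block loop.
    feats = model.forward_intermediates(x, indices=2, intermediates_only=True, stop_early=True)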

tests/test_models.py

@@ -49,8 +49,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):

 # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
 FEAT_INTER_FILTERS = [
-    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*',
-    'cait_*', 'xcit_*', 'volo_*', 'swin*', 'max*vit_*', 'coatne*t_*'
+    'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
+    'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3'
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
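
The list above now holds registry module names rather than model-name wildcards, matching the switch to list_models(module=...) in the hunks below; a hedged illustration of the difference:

    from timm import list_models

    list_models('efficientnet*')        # filter argument: glob match over model names
    list_models(module='efficientnet')  # module argument: everything registered by timm/models/efficientnet.py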
@@ -388,7 +389,7 @@ def test_model_forward_features(model_name, batch_size):

 @pytest.mark.features
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_intermediates_features(model_name, batch_size):
     """Run a single forward pass with each model in feature extraction mode"""
@@ -419,7 +420,7 @@ def test_model_forward_intermediates_features(model_name, batch_size):

 @pytest.mark.features
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_intermediates(model_name, batch_size):
     """Run a single forward pass with each model in feature extraction mode"""

timm/models/_registry.py

@@ -184,7 +184,7 @@ def _expand_filter(filter: str):

 def list_models(
         filter: Union[str, List[str]] = '',
-        module: str = '',
+        module: Union[str, List[str]] = '',
         pretrained: bool = False,
         exclude_filters: Union[str, List[str]] = '',
         name_matches_cfg: bool = False,
@@ -217,7 +217,16 @@ def list_models(
     # FIXME should this be default behaviour? or default to include_tags=True?
     include_tags = pretrained

-    all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
+    if not module:
+        all_models: Set[str] = set(_model_entrypoints.keys())
+    else:
+        if isinstance(module, str):
+            all_models: Set[str] = _module_to_models[module]
+        else:
+            assert isinstance(module, Sequence)
+            all_models: Set[str] = set()
+            for m in module:
+                all_models.update(_module_to_models[m])
     all_models = all_models - _deprecated_models.keys()  # remove deprecated models from listings

     if include_tags:
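
A short sketch of what the new list branch permits (names illustrative):

    from timm import list_models

    # module now accepts a list of registry modules; the result is the
    # union of each module's (non-deprecated) model entrypoints.
    names = list_models(module=['efficientnet', 'mobilenetv3'], exclude_filters=['*pruned*'])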

timm/models/efficientnet.py

@@ -36,7 +36,7 @@ the models and weights open source!
 Hacked together by / Copyright 2019, Ross Wightman
 """
 from functools import partial
-from typing import List
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -49,7 +49,7 @@ from ._builder import build_model_with_cfg, pretrained_cfg_for_features
 from ._efficientnet_blocks import SqueezeExcite
 from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
-from ._features import FeatureInfo, FeatureHooks
+from ._features import FeatureInfo, FeatureHooks, feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -118,6 +118,7 @@ class EfficientNet(nn.Module):
         )
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
+        self.stage_ends = [f['stage'] for f in self.feature_info]
         head_chs = builder.in_chs

         # Head + Pooling
@@ -158,6 +159,86 @@ class EfficientNet(nn.Module):
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            *,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            extra_blocks: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
+        Returns:
+            List of intermediate features if intermediates_only is True, otherwise a tuple of
+            (final features, list of intermediate features).
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        if stop_early:
+            assert intermediates_only, 'Must use intermediates_only for early stopping.'
+        intermediates = []
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index]
+        for blk in blocks:
+            feat_idx += 1
+            x = blk(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.conv_head(x)
+        x = self.bn2(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+            extra_blocks: bool = False,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            max_index = self.stage_ends[max_index]
+        num_blocks = len(self.blocks)  # length before truncation, stem is idx 0
+        self.blocks = self.blocks[:max_index]  # truncate blocks w/ stem as idx 0
+        if prune_norm or max_index < num_blocks:
+            self.conv_head = nn.Identity()
+            self.bn2 = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices

     def forward_features(self, x):
         x = self.conv_stem(x)
         x = self.bn1(x)
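
Both new methods lean on the feature_take_indices() helper imported above; a hedged sketch of its assumed semantics (the real implementation lives in timm/models/_features.py and additionally validates bounds):

    from typing import List, Optional, Tuple, Union

    def feature_take_indices_sketch(
            num_features: int,
            indices: Optional[Union[int, List[int]]] = None,
    ) -> Tuple[List[int], int]:
        # Assumed behaviour: an int n selects the last n feature indices,
        # None selects all of them, and a sequence selects exactly those
        # (with negative values counted from the end).
        if indices is None:
            indices = num_features
        if isinstance(indices, int):
            take_indices = list(range(num_features - indices, num_features))
        else:
            take_indices = [num_features + i if i < 0 else i for i in indices]
        return take_indices, max(take_indices)

    # e.g. num_features=7 (stem + 6 stages): indices=2 -> ([5, 6], 6)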
@@ -272,7 +353,7 @@ def _create_effnet(variant, pretrained=False, **kwargs):
     model_cls = EfficientNet
     kwargs_filter = None
     if kwargs.pop('features_only', False):
-        if 'feature_cfg' in kwargs:
+        if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
             features_mode = 'cfg'
         else:
             kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
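
The widened check lets a caller steer feature extraction onto the 'cfg' path by passing a feature class directly; a hedged sketch (the feature_cls kwarg routing through create_model is assumed):

    # features_only with an explicit feature_cls now selects the config-driven
    # wrapper path (e.g. FeatureGetterNet) instead of EfficientNetFeatures.
    model = timm.create_model('efficientnet_b0', features_only=True, feature_cls='getter')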

timm/models/mobilenetv3.py

@@ -7,7 +7,7 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
 Hacked together by / Copyright 2019, Ross Wightman
 """
 from functools import partial
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -20,7 +20,7 @@ from ._builder import build_model_with_cfg, pretrained_cfg_for_features
 from ._efficientnet_blocks import SqueezeExcite
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
-from ._features import FeatureInfo, FeatureHooks
+from ._features import FeatureInfo, FeatureHooks, feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -109,6 +109,7 @@ class MobileNetV3(nn.Module):
         )
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
+        self.stage_ends = [f['stage'] for f in self.feature_info]
         head_chs = builder.in_chs

         # Head + Pooling
@@ -150,6 +151,84 @@ class MobileNetV3(nn.Module):
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
         self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            *,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            extra_blocks: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
+        Returns:
+            List of intermediate features if intermediates_only is True, otherwise a tuple of
+            (final features, list of intermediate features).
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        if stop_early:
+            assert intermediates_only, 'Must use intermediates_only for early stopping.'
+        intermediates = []
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index]
+        for blk in blocks:
+            feat_idx += 1
+            x = blk(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+            extra_blocks: bool = False,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            max_index = self.stage_ends[max_index]
+        num_blocks = len(self.blocks)  # length before truncation, stem is idx 0
+        self.blocks = self.blocks[:max_index]  # truncate blocks w/ stem as idx 0
+        if max_index < num_blocks:
+            self.conv_head = nn.Identity()
+        if prune_head:
+            self.conv_head = nn.Identity()
+            self.reset_classifier(0, '')
+        return take_indices

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv_stem(x)
         x = self.bn1(x)
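
A hedged sketch combining the two new MobileNetV3 methods (model name illustrative):

    import torch
    import timm

    model = timm.create_model('mobilenetv3_large_100')
    # Drop the classifier head and any blocks past the last two feature stages.
    take_indices = model.prune_intermediate_layers(indices=2, prune_head=True)
    # The pruned trunk still serves intermediates for those stages.
    feats = model.forward_intermediates(
        torch.randn(1, 3, 224, 224),
        indices=2,
        intermediates_only=True,
    )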
@@ -288,7 +367,7 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV3:
     model_cls = MobileNetV3
     kwargs_filter = None
     if kwargs.pop('features_only', False):
-        if 'feature_cfg' in kwargs:
+        if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
             features_mode = 'cfg'
         else:
             kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')