""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, List, Dict, Union, Type

import torch
from torch import nn

from ._features import _get_feature_info, _get_return_layers

try:
    from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
    has_fx_feature_extraction = True
except ImportError:
    has_fx_feature_extraction = False

# Layers we want to treat as leaf modules
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame

__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
           'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
           'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']


# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
    BilinearAttnTransform,  # reason: flow control t <= 1
    # Reason: get_same_padding has a max which raises a control flow error
    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
}

try:
    from timm.layers import InplaceAbn
    _leaf_modules.add(InplaceAbn)
except ImportError:
    pass


def register_notrace_module(module: Type[nn.Module]):
    """
    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
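
    A hedged sketch of intended usage (MyDynamicBlock is a hypothetical module,
    not part of timm)::

        @register_notrace_module
        class MyDynamicBlock(nn.Module):
            def forward(self, x):
                # data-dependent control flow like this breaks FX symbolic
                # tracing, so the whole module is kept as a single leaf node
                if x.shape[-1] % 2:
                    x = torch.nn.functional.pad(x, (0, 1))
                return x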
    """
    _leaf_modules.add(module)
    return module


def is_notrace_module(module: Type[nn.Module]):
    return module in _leaf_modules


def get_notrace_modules():
    return list(_leaf_modules)


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()


def register_notrace_function(func: Callable):
    """
    Decorator for functions which ought not to be traced through
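
    A hedged sketch of intended usage (window_partition_example is a
    hypothetical helper, not part of timm)::

        @register_notrace_function
        def window_partition_example(x, window_size: int):
            # unpacking a traced tensor's shape yields Proxy values whose
            # iteration breaks symbolic tracing, so the whole function is
            # wrapped and called as a single leaf at runtime
            B, H, W, C = x.shape
            x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
            return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)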
    """
    _autowrap_functions.add(func)
    return func


def is_notrace_function(func: Callable):
    return func in _autowrap_functions


def get_notrace_functions():
    return list(_autowrap_functions)


def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
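    """ Create an FX feature extractor with timm's registered leaf modules and autowrap functions applied.

    A hedged usage sketch (node names assume a ResNet-style timm model; check
    actual names with torchvision's get_graph_node_names)::

        import timm
        model = timm.create_model('resnet50')
        extractor = create_feature_extractor(model, ['layer1', 'layer2'])
        out = extractor(torch.randn(1, 3, 224, 224))  # dict: node name -> Tensor
    """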
    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
    return _create_feature_extractor(
        model, return_nodes,
        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
    )


class FeatureGraphNet(nn.Module):
    """ An FX Graph based feature extractor that works with the model feature_info metadata
    """
    def __init__(self, model, out_indices, out_map=None):
        super().__init__()
        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
        self.feature_info = _get_feature_info(model, out_indices)
        if out_map is not None:
            assert len(out_map) == len(out_indices)
        return_nodes = _get_return_layers(self.feature_info, out_map)
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x):
        return list(self.graph_module(x).values())


class GraphExtractNet(nn.Module):
    """ A standalone feature extraction wrapper that maps dict -> list or single tensor

    NOTE:
      * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
        metadata for builtin feature extraction mode
      * create_feature_extractor can be used directly if dictionary output is acceptable

    Args:
        model: model to extract features from
        return_nodes: node names to return features from (dict or list)
        squeeze_out: if only one output, and output in list format, flatten to single tensor
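
    A hedged usage sketch (node names assume a ResNet-style timm model)::

        import timm
        net = GraphExtractNet(timm.create_model('resnet50'), ['layer1', 'layer2'], squeeze_out=False)
        feats = net(torch.randn(1, 3, 224, 224))  # list of 2 tensors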
    """
    def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
        super().__init__()
        self.squeeze_out = squeeze_out
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
        out = list(self.graph_module(x).values())
        if self.squeeze_out and len(out) == 1:
            return out[0]
        return out