""" PyTorch FX Based Feature Extraction Helpers
|
|
Using https://pytorch.org/vision/stable/feature_extraction.html
|
|
"""
|
|
from typing import Callable, List, Dict, Union, Type
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from .features import _get_feature_info
|
|
|
|

try:
    from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
    has_fx_feature_extraction = True
except ImportError:
    has_fx_feature_extraction = False

# Layers we want to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame

# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here,
# BUT modules from timm.models should use the registration mechanism below.
_leaf_modules = {
    BilinearAttnTransform,  # reason: data-dependent control flow (t <= 1) breaks symbolic tracing
    # reason: get_same_padding has a max() that raises a control flow error when traced
    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
}

# InplaceAbn is an optional dependency; add it as a leaf module only if it imports
try:
    from .layers import InplaceAbn
    _leaf_modules.add(InplaceAbn)
except ImportError:
    pass


def register_notrace_module(module: Type[nn.Module]):
    """
    Decorator for module *types* (i.e. subclasses of nn.Module), not module instances. Any module not
    under timm.models.layers should get this decorator if we don't want to trace through it.
    """
    _leaf_modules.add(module)
    return module
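
# A minimal usage sketch (illustrative only, not part of this module). `DynamicGate` is a
# hypothetical module whose forward branches on a tensor value, which FX symbolic tracing
# cannot handle, so its type is registered as a leaf:
#
#   @register_notrace_module
#   class DynamicGate(nn.Module):
#       def forward(self, x):
#           return x if x.mean() > 0 else -x  # data-dependent branch, must not be traced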


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()


def register_notrace_function(func: Callable):
    """
    Decorator for functions which ought not to be traced through
    """
    _autowrap_functions.add(func)
    return func
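
# A minimal usage sketch (illustrative only). `pad_to_multiple` is a hypothetical free function
# that derives values from tensor shapes, which become Proxy objects under tracing; wrapping it
# as a leaf avoids tracing into it:
#
#   @register_notrace_function
#   def pad_to_multiple(x, multiple=32):
#       pad = -x.shape[-1] % multiple
#       return torch.nn.functional.pad(x, (0, pad))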


def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
    return _create_feature_extractor(
        model, return_nodes,
        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
    )
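
# A usage sketch (node names are illustrative and depend on the model; timm ResNet stages
# happen to be named 'layer1'..'layer4'):
#
#   import timm
#   model = timm.create_model('resnet50')
#   extractor = create_feature_extractor(model, return_nodes={'layer1': 'feat1', 'layer4': 'feat4'})
#   out = extractor(torch.randn(1, 3, 224, 224))  # {'feat1': Tensor, 'feat4': Tensor}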


class FeatureGraphNet(nn.Module):
    """ An FX Graph based feature extractor that works with the model feature_info metadata
    """
    def __init__(self, model, out_indices, out_map=None):
        super().__init__()
        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
        self.feature_info = _get_feature_info(model, out_indices)
        if out_map is not None:
            assert len(out_map) == len(out_indices)
        # map feature_info module names to output names (out_map entries if given, else the module names)
        return_nodes = {
            info['module']: out_map[i] if out_map is not None else info['module']
            for i, info in enumerate(self.feature_info) if i in out_indices}
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x):
        return list(self.graph_module(x).values())
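
# A usage sketch (assumes a timm model that exposes feature_info metadata, e.g. resnet50):
#
#   net = FeatureGraphNet(timm.create_model('resnet50'), out_indices=(1, 2, 3))
#   features = net(torch.randn(1, 3, 224, 224))  # list of Tensors, one per out_index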


class GraphExtractNet(nn.Module):
    """ A standalone feature extraction wrapper that maps dict -> list or single tensor
    NOTE:
      * unlike FeatureGraphNet, this is intended to be used standalone and not with the model
        feature_info metadata for the builtin feature extraction mode
      * create_feature_extractor can be used directly if dictionary output is acceptable

    Args:
        model: model to extract features from
        return_nodes: node names to return features from (dict or list)
        squeeze_out: if there is only one output, and it is in list format, flatten to a single tensor
    """
    def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
        super().__init__()
        self.squeeze_out = squeeze_out
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
        out = list(self.graph_module(x).values())
        if self.squeeze_out and len(out) == 1:
            return out[0]
        return out
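
# A usage sketch (node name is illustrative; with a single requested node and squeeze_out=True
# the output is a single Tensor rather than a list):
#
#   net = GraphExtractNet(timm.create_model('resnet50'), return_nodes=['layer4'])
#   feat = net(torch.randn(1, 3, 224, 224))  # Tensor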