mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add grad_checkpointing support to features_only, test in EfficientDet.
This commit is contained in:
parent
45af496197
commit
2cfff0581b
@ -11,10 +11,11 @@ Hacked together by / Copyright 2020 Ross Wightman
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
||||
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
|
||||
@ -88,12 +89,20 @@ class FeatureHooks:
|
||||
""" Feature Hook Helper
|
||||
|
||||
This module helps with the setup and extraction of hooks for extracting features from
|
||||
internal nodes in a model by node name. This works quite well in eager Python but needs
|
||||
redesign for torchscript.
|
||||
internal nodes in a model by node name.
|
||||
|
||||
FIXME This works well in eager Python but needs redesign for torchscript.
|
||||
"""
|
||||
|
||||
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
|
||||
def __init__(
|
||||
self,
|
||||
hooks: Sequence[str],
|
||||
named_modules: dict,
|
||||
out_map: Sequence[Union[int, str]] = None,
|
||||
default_hook_type: str = 'forward',
|
||||
):
|
||||
# setup feature hooks
|
||||
self._feature_outputs = defaultdict(OrderedDict)
|
||||
modules = {k: v for k, v in named_modules}
|
||||
for i, h in enumerate(hooks):
|
||||
hook_name = h['module']
|
||||
@ -107,7 +116,6 @@ class FeatureHooks:
|
||||
m.register_forward_hook(hook_fn)
|
||||
else:
|
||||
assert False, "Unsupported hook type"
|
||||
self._feature_outputs = defaultdict(OrderedDict)
|
||||
|
||||
def _collect_output_hook(self, hook_id, *args):
|
||||
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
|
||||
@ -167,23 +175,30 @@ class FeatureDictNet(nn.ModuleDict):
|
||||
one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
|
||||
All Sequential containers that are directly assigned to the original model will have their
|
||||
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
|
||||
|
||||
Arguments:
|
||||
model (nn.Module): model from which we will extract the features
|
||||
out_indices (tuple[int]): model output indices to extract features for
|
||||
out_map (sequence): list or tuple specifying desired return id for each out index,
|
||||
otherwise str(index) is used
|
||||
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
|
||||
vs select element [0]
|
||||
flatten_sequential (bool): whether to flatten sequential modules assigned to model
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
||||
self,
|
||||
model: nn.Module,
|
||||
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||
out_map: Sequence[Union[int, str]] = None,
|
||||
feature_concat: bool = False,
|
||||
flatten_sequential: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model: Model from which to extract features.
|
||||
out_indices: Output indices of the model features to extract.
|
||||
out_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
|
||||
first element e.g. `x[0]`
|
||||
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
|
||||
"""
|
||||
super(FeatureDictNet, self).__init__()
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.concat = feature_concat
|
||||
self.grad_checkpointing = False
|
||||
self.return_layers = {}
|
||||
|
||||
return_layers = _get_return_layers(self.feature_info, out_map)
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
remaining = set(return_layers.keys())
|
||||
@ -200,10 +215,21 @@ class FeatureDictNet(nn.ModuleDict):
|
||||
f'Return layers ({remaining}) are not present in model'
|
||||
self.update(layers)
|
||||
|
||||
def set_grad_checkpointing(self, enable: bool = True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
def _collect(self, x) -> (Dict[str, torch.Tensor]):
|
||||
out = OrderedDict()
|
||||
for name, module in self.items():
|
||||
x = module(x)
|
||||
for i, (name, module) in enumerate(self.items()):
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
# Skipping checkpoint of first module because need a gradient at input
|
||||
# Skipping last because networks with in-place ops might fail w/ checkpointing enabled
|
||||
# NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
|
||||
first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
|
||||
x = module(x) if first_or_last_module else checkpoint(module, x)
|
||||
else:
|
||||
x = module(x)
|
||||
|
||||
if name in self.return_layers:
|
||||
out_id = self.return_layers[name]
|
||||
if isinstance(x, (tuple, list)):
|
||||
@ -221,15 +247,29 @@ class FeatureDictNet(nn.ModuleDict):
|
||||
class FeatureListNet(FeatureDictNet):
|
||||
""" Feature extractor with list return
|
||||
|
||||
See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
|
||||
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
|
||||
A specialization of FeatureDictNet that always returns features as a list (values() of dict).
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
||||
self,
|
||||
model: nn.Module,
|
||||
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||
feature_concat: bool = False,
|
||||
flatten_sequential: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model: Model from which to extract features.
|
||||
out_indices: Output indices of the model features to extract.
|
||||
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
|
||||
first element e.g. `x[0]`
|
||||
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
|
||||
"""
|
||||
super(FeatureListNet, self).__init__(
|
||||
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
|
||||
flatten_sequential=flatten_sequential)
|
||||
model,
|
||||
out_indices=out_indices,
|
||||
feature_concat=feature_concat,
|
||||
flatten_sequential=flatten_sequential,
|
||||
)
|
||||
|
||||
def forward(self, x) -> (List[torch.Tensor]):
|
||||
return list(self._collect(x).values())
|
||||
@ -249,13 +289,33 @@ class FeatureHookNet(nn.ModuleDict):
|
||||
FIXME this does not currently work with Torchscript, see FeatureHooks class
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
|
||||
feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
|
||||
self,
|
||||
model: nn.Module,
|
||||
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||
out_map: Sequence[Union[int, str]] = None,
|
||||
out_as_dict: bool = False,
|
||||
no_rewrite: bool = False,
|
||||
flatten_sequential: bool = False,
|
||||
default_hook_type: str = 'forward',
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
model: Model from which to extract features.
|
||||
out_indices: Output indices of the model features to extract.
|
||||
out_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||
out_as_dict: Output features as a dict.
|
||||
no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
|
||||
flatten_sequential arg must also be False if this is set True.
|
||||
flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
|
||||
default_hook_type: The default hook type to use if not specified in model.feature_info.
|
||||
"""
|
||||
super(FeatureHookNet, self).__init__()
|
||||
assert not torch.jit.is_scripting()
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.out_as_dict = out_as_dict
|
||||
self.grad_checkpointing = False
|
||||
|
||||
layers = OrderedDict()
|
||||
hooks = []
|
||||
if no_rewrite:
|
||||
@ -266,8 +326,10 @@ class FeatureHookNet(nn.ModuleDict):
|
||||
hooks.extend(self.feature_info.get_dicts())
|
||||
else:
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
||||
for f in self.feature_info.get_dicts()}
|
||||
remaining = {
|
||||
f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
||||
for f in self.feature_info.get_dicts()
|
||||
}
|
||||
for new_name, old_name, module in modules:
|
||||
layers[new_name] = module
|
||||
for fn, fm in module.named_modules(prefix=old_name):
|
||||
@ -280,8 +342,18 @@ class FeatureHookNet(nn.ModuleDict):
|
||||
self.update(layers)
|
||||
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
|
||||
|
||||
def set_grad_checkpointing(self, enable: bool = True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
def forward(self, x):
|
||||
for name, module in self.items():
|
||||
x = module(x)
|
||||
for i, (name, module) in enumerate(self.items()):
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
# Skipping checkpoint of first module because need a gradient at input
|
||||
# Skipping last because networks with in-place ops might fail w/ checkpointing enabled
|
||||
# NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
|
||||
first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
|
||||
x = module(x) if first_or_last_module else checkpoint(module, x)
|
||||
else:
|
||||
x = module(x)
|
||||
out = self.hooks.get_output(x.device)
|
||||
return out if self.out_as_dict else list(out.values())
|
||||
|
@ -41,6 +41,7 @@ from typing import List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
|
||||
@ -211,6 +212,7 @@ class EfficientNetFeatures(nn.Module):
|
||||
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
|
||||
se_layer = se_layer or SqueezeExcite
|
||||
self.drop_rate = drop_rate
|
||||
self.grad_checkpointing = False
|
||||
|
||||
# Stem
|
||||
if not fix_stem:
|
||||
@ -241,6 +243,10 @@ class EfficientNetFeatures(nn.Module):
|
||||
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
def forward(self, x) -> List[torch.Tensor]:
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
@ -249,7 +255,10 @@ class EfficientNetFeatures(nn.Module):
|
||||
if 0 in self._stage_out_idx:
|
||||
features.append(x) # add stem out
|
||||
for i, b in enumerate(self.blocks):
|
||||
x = b(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint(b, x)
|
||||
else:
|
||||
x = b(x)
|
||||
if i + 1 in self._stage_out_idx:
|
||||
features.append(x)
|
||||
return features
|
||||
|
@ -12,6 +12,7 @@ from typing import List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer
|
||||
@ -188,6 +189,7 @@ class MobileNetV3Features(nn.Module):
|
||||
norm_layer = norm_layer or nn.BatchNorm2d
|
||||
se_layer = se_layer or SqueezeExcite
|
||||
self.drop_rate = drop_rate
|
||||
self.grad_checkpointing = False
|
||||
|
||||
# Stem
|
||||
if not fix_stem:
|
||||
@ -220,6 +222,10 @@ class MobileNetV3Features(nn.Module):
|
||||
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
def forward(self, x) -> List[torch.Tensor]:
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
@ -229,7 +235,10 @@ class MobileNetV3Features(nn.Module):
|
||||
if 0 in self._stage_out_idx:
|
||||
features.append(x) # add stem out
|
||||
for i, b in enumerate(self.blocks):
|
||||
x = b(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint(b, x)
|
||||
else:
|
||||
x = b(x)
|
||||
if i + 1 in self._stage_out_idx:
|
||||
features.append(x)
|
||||
return features
|
||||
|
Loading…
x
Reference in New Issue
Block a user