mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add grad_checkpointing support to features_only, test in EfficientDet.
This commit is contained in:
parent
45af496197
commit
2cfff0581b
@ -11,10 +11,11 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
|
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
|
||||||
@ -88,12 +89,20 @@ class FeatureHooks:
|
|||||||
""" Feature Hook Helper
|
""" Feature Hook Helper
|
||||||
|
|
||||||
This module helps with the setup and extraction of hooks for extracting features from
|
This module helps with the setup and extraction of hooks for extracting features from
|
||||||
internal nodes in a model by node name. This works quite well in eager Python but needs
|
internal nodes in a model by node name.
|
||||||
redesign for torchscript.
|
|
||||||
|
FIXME This works well in eager Python but needs redesign for torchscript.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hooks: Sequence[str],
|
||||||
|
named_modules: dict,
|
||||||
|
out_map: Sequence[Union[int, str]] = None,
|
||||||
|
default_hook_type: str = 'forward',
|
||||||
|
):
|
||||||
# setup feature hooks
|
# setup feature hooks
|
||||||
|
self._feature_outputs = defaultdict(OrderedDict)
|
||||||
modules = {k: v for k, v in named_modules}
|
modules = {k: v for k, v in named_modules}
|
||||||
for i, h in enumerate(hooks):
|
for i, h in enumerate(hooks):
|
||||||
hook_name = h['module']
|
hook_name = h['module']
|
||||||
@ -107,7 +116,6 @@ class FeatureHooks:
|
|||||||
m.register_forward_hook(hook_fn)
|
m.register_forward_hook(hook_fn)
|
||||||
else:
|
else:
|
||||||
assert False, "Unsupported hook type"
|
assert False, "Unsupported hook type"
|
||||||
self._feature_outputs = defaultdict(OrderedDict)
|
|
||||||
|
|
||||||
def _collect_output_hook(self, hook_id, *args):
|
def _collect_output_hook(self, hook_id, *args):
|
||||||
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
|
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
|
||||||
@ -167,23 +175,30 @@ class FeatureDictNet(nn.ModuleDict):
|
|||||||
one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
|
one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
|
||||||
All Sequential containers that are directly assigned to the original model will have their
|
All Sequential containers that are directly assigned to the original model will have their
|
||||||
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
|
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
|
||||||
|
|
||||||
Arguments:
|
|
||||||
model (nn.Module): model from which we will extract the features
|
|
||||||
out_indices (tuple[int]): model output indices to extract features for
|
|
||||||
out_map (sequence): list or tuple specifying desired return id for each out index,
|
|
||||||
otherwise str(index) is used
|
|
||||||
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
|
|
||||||
vs select element [0]
|
|
||||||
flatten_sequential (bool): whether to flatten sequential modules assigned to model
|
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model,
|
self,
|
||||||
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
model: nn.Module,
|
||||||
|
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||||
|
out_map: Sequence[Union[int, str]] = None,
|
||||||
|
feature_concat: bool = False,
|
||||||
|
flatten_sequential: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
model: Model from which to extract features.
|
||||||
|
out_indices: Output indices of the model features to extract.
|
||||||
|
out_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||||
|
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
|
||||||
|
first element e.g. `x[0]`
|
||||||
|
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
|
||||||
|
"""
|
||||||
super(FeatureDictNet, self).__init__()
|
super(FeatureDictNet, self).__init__()
|
||||||
self.feature_info = _get_feature_info(model, out_indices)
|
self.feature_info = _get_feature_info(model, out_indices)
|
||||||
self.concat = feature_concat
|
self.concat = feature_concat
|
||||||
|
self.grad_checkpointing = False
|
||||||
self.return_layers = {}
|
self.return_layers = {}
|
||||||
|
|
||||||
return_layers = _get_return_layers(self.feature_info, out_map)
|
return_layers = _get_return_layers(self.feature_info, out_map)
|
||||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||||
remaining = set(return_layers.keys())
|
remaining = set(return_layers.keys())
|
||||||
@ -200,10 +215,21 @@ class FeatureDictNet(nn.ModuleDict):
|
|||||||
f'Return layers ({remaining}) are not present in model'
|
f'Return layers ({remaining}) are not present in model'
|
||||||
self.update(layers)
|
self.update(layers)
|
||||||
|
|
||||||
|
def set_grad_checkpointing(self, enable: bool = True):
|
||||||
|
self.grad_checkpointing = enable
|
||||||
|
|
||||||
def _collect(self, x) -> (Dict[str, torch.Tensor]):
|
def _collect(self, x) -> (Dict[str, torch.Tensor]):
|
||||||
out = OrderedDict()
|
out = OrderedDict()
|
||||||
for name, module in self.items():
|
for i, (name, module) in enumerate(self.items()):
|
||||||
x = module(x)
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
# Skipping checkpoint of first module because need a gradient at input
|
||||||
|
# Skipping last because networks with in-place ops might fail w/ checkpointing enabled
|
||||||
|
# NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
|
||||||
|
first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
|
||||||
|
x = module(x) if first_or_last_module else checkpoint(module, x)
|
||||||
|
else:
|
||||||
|
x = module(x)
|
||||||
|
|
||||||
if name in self.return_layers:
|
if name in self.return_layers:
|
||||||
out_id = self.return_layers[name]
|
out_id = self.return_layers[name]
|
||||||
if isinstance(x, (tuple, list)):
|
if isinstance(x, (tuple, list)):
|
||||||
@ -221,15 +247,29 @@ class FeatureDictNet(nn.ModuleDict):
|
|||||||
class FeatureListNet(FeatureDictNet):
|
class FeatureListNet(FeatureDictNet):
|
||||||
""" Feature extractor with list return
|
""" Feature extractor with list return
|
||||||
|
|
||||||
See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
|
A specialization of FeatureDictNet that always returns features as a list (values() of dict).
|
||||||
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
|
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model,
|
self,
|
||||||
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
model: nn.Module,
|
||||||
|
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||||
|
feature_concat: bool = False,
|
||||||
|
flatten_sequential: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
model: Model from which to extract features.
|
||||||
|
out_indices: Output indices of the model features to extract.
|
||||||
|
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
|
||||||
|
first element e.g. `x[0]`
|
||||||
|
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
|
||||||
|
"""
|
||||||
super(FeatureListNet, self).__init__(
|
super(FeatureListNet, self).__init__(
|
||||||
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
|
model,
|
||||||
flatten_sequential=flatten_sequential)
|
out_indices=out_indices,
|
||||||
|
feature_concat=feature_concat,
|
||||||
|
flatten_sequential=flatten_sequential,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x) -> (List[torch.Tensor]):
|
def forward(self, x) -> (List[torch.Tensor]):
|
||||||
return list(self._collect(x).values())
|
return list(self._collect(x).values())
|
||||||
@ -249,13 +289,33 @@ class FeatureHookNet(nn.ModuleDict):
|
|||||||
FIXME this does not currently work with Torchscript, see FeatureHooks class
|
FIXME this does not currently work with Torchscript, see FeatureHooks class
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model,
|
self,
|
||||||
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
|
model: nn.Module,
|
||||||
feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
|
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||||
|
out_map: Sequence[Union[int, str]] = None,
|
||||||
|
out_as_dict: bool = False,
|
||||||
|
no_rewrite: bool = False,
|
||||||
|
flatten_sequential: bool = False,
|
||||||
|
default_hook_type: str = 'forward',
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model from which to extract features.
|
||||||
|
out_indices: Output indices of the model features to extract.
|
||||||
|
out_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||||
|
out_as_dict: Output features as a dict.
|
||||||
|
no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
|
||||||
|
flatten_sequential arg must also be False if this is set True.
|
||||||
|
flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
|
||||||
|
default_hook_type: The default hook type to use if not specified in model.feature_info.
|
||||||
|
"""
|
||||||
super(FeatureHookNet, self).__init__()
|
super(FeatureHookNet, self).__init__()
|
||||||
assert not torch.jit.is_scripting()
|
assert not torch.jit.is_scripting()
|
||||||
self.feature_info = _get_feature_info(model, out_indices)
|
self.feature_info = _get_feature_info(model, out_indices)
|
||||||
self.out_as_dict = out_as_dict
|
self.out_as_dict = out_as_dict
|
||||||
|
self.grad_checkpointing = False
|
||||||
|
|
||||||
layers = OrderedDict()
|
layers = OrderedDict()
|
||||||
hooks = []
|
hooks = []
|
||||||
if no_rewrite:
|
if no_rewrite:
|
||||||
@ -266,8 +326,10 @@ class FeatureHookNet(nn.ModuleDict):
|
|||||||
hooks.extend(self.feature_info.get_dicts())
|
hooks.extend(self.feature_info.get_dicts())
|
||||||
else:
|
else:
|
||||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||||
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
remaining = {
|
||||||
for f in self.feature_info.get_dicts()}
|
f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
||||||
|
for f in self.feature_info.get_dicts()
|
||||||
|
}
|
||||||
for new_name, old_name, module in modules:
|
for new_name, old_name, module in modules:
|
||||||
layers[new_name] = module
|
layers[new_name] = module
|
||||||
for fn, fm in module.named_modules(prefix=old_name):
|
for fn, fm in module.named_modules(prefix=old_name):
|
||||||
@ -280,8 +342,18 @@ class FeatureHookNet(nn.ModuleDict):
|
|||||||
self.update(layers)
|
self.update(layers)
|
||||||
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
|
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
|
||||||
|
|
||||||
|
def set_grad_checkpointing(self, enable: bool = True):
|
||||||
|
self.grad_checkpointing = enable
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for name, module in self.items():
|
for i, (name, module) in enumerate(self.items()):
|
||||||
x = module(x)
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
# Skipping checkpoint of first module because need a gradient at input
|
||||||
|
# Skipping last because networks with in-place ops might fail w/ checkpointing enabled
|
||||||
|
# NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
|
||||||
|
first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
|
||||||
|
x = module(x) if first_or_last_module else checkpoint(module, x)
|
||||||
|
else:
|
||||||
|
x = module(x)
|
||||||
out = self.hooks.get_output(x.device)
|
out = self.hooks.get_output(x.device)
|
||||||
return out if self.out_as_dict else list(out.values())
|
return out if self.out_as_dict else list(out.values())
|
||||||
|
@ -41,6 +41,7 @@ from typing import List
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
|
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
|
||||||
@ -211,6 +212,7 @@ class EfficientNetFeatures(nn.Module):
|
|||||||
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
|
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
|
||||||
se_layer = se_layer or SqueezeExcite
|
se_layer = se_layer or SqueezeExcite
|
||||||
self.drop_rate = drop_rate
|
self.drop_rate = drop_rate
|
||||||
|
self.grad_checkpointing = False
|
||||||
|
|
||||||
# Stem
|
# Stem
|
||||||
if not fix_stem:
|
if not fix_stem:
|
||||||
@ -241,6 +243,10 @@ class EfficientNetFeatures(nn.Module):
|
|||||||
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
|
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
|
||||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.grad_checkpointing = enable
|
||||||
|
|
||||||
def forward(self, x) -> List[torch.Tensor]:
|
def forward(self, x) -> List[torch.Tensor]:
|
||||||
x = self.conv_stem(x)
|
x = self.conv_stem(x)
|
||||||
x = self.bn1(x)
|
x = self.bn1(x)
|
||||||
@ -249,7 +255,10 @@ class EfficientNetFeatures(nn.Module):
|
|||||||
if 0 in self._stage_out_idx:
|
if 0 in self._stage_out_idx:
|
||||||
features.append(x) # add stem out
|
features.append(x) # add stem out
|
||||||
for i, b in enumerate(self.blocks):
|
for i, b in enumerate(self.blocks):
|
||||||
x = b(x)
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
x = checkpoint(b, x)
|
||||||
|
else:
|
||||||
|
x = b(x)
|
||||||
if i + 1 in self._stage_out_idx:
|
if i + 1 in self._stage_out_idx:
|
||||||
features.append(x)
|
features.append(x)
|
||||||
return features
|
return features
|
||||||
|
@ -12,6 +12,7 @@ from typing import List
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer
|
from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer
|
||||||
@ -188,6 +189,7 @@ class MobileNetV3Features(nn.Module):
|
|||||||
norm_layer = norm_layer or nn.BatchNorm2d
|
norm_layer = norm_layer or nn.BatchNorm2d
|
||||||
se_layer = se_layer or SqueezeExcite
|
se_layer = se_layer or SqueezeExcite
|
||||||
self.drop_rate = drop_rate
|
self.drop_rate = drop_rate
|
||||||
|
self.grad_checkpointing = False
|
||||||
|
|
||||||
# Stem
|
# Stem
|
||||||
if not fix_stem:
|
if not fix_stem:
|
||||||
@ -220,6 +222,10 @@ class MobileNetV3Features(nn.Module):
|
|||||||
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
|
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
|
||||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def set_grad_checkpointing(self, enable=True):
|
||||||
|
self.grad_checkpointing = enable
|
||||||
|
|
||||||
def forward(self, x) -> List[torch.Tensor]:
|
def forward(self, x) -> List[torch.Tensor]:
|
||||||
x = self.conv_stem(x)
|
x = self.conv_stem(x)
|
||||||
x = self.bn1(x)
|
x = self.bn1(x)
|
||||||
@ -229,7 +235,10 @@ class MobileNetV3Features(nn.Module):
|
|||||||
if 0 in self._stage_out_idx:
|
if 0 in self._stage_out_idx:
|
||||||
features.append(x) # add stem out
|
features.append(x) # add stem out
|
||||||
for i, b in enumerate(self.blocks):
|
for i, b in enumerate(self.blocks):
|
||||||
x = b(x)
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
x = checkpoint(b, x)
|
||||||
|
else:
|
||||||
|
x = b(x)
|
||||||
if i + 1 in self._stage_out_idx:
|
if i + 1 in self._stage_out_idx:
|
||||||
features.append(x)
|
features.append(x)
|
||||||
return features
|
return features
|
||||||
|
Loading…
x
Reference in New Issue
Block a user