mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
return_map back to out_map for _feature helpers
This commit is contained in:
parent
acfd85ad68
commit
5aebad3fbc
@ -100,7 +100,7 @@ class FeatureHooks:
|
|||||||
self,
|
self,
|
||||||
hooks: Sequence[str],
|
hooks: Sequence[str],
|
||||||
named_modules: dict,
|
named_modules: dict,
|
||||||
return_map: Sequence[Union[int, str]] = None,
|
out_map: Sequence[Union[int, str]] = None,
|
||||||
default_hook_type: str = 'forward',
|
default_hook_type: str = 'forward',
|
||||||
):
|
):
|
||||||
# setup feature hooks
|
# setup feature hooks
|
||||||
@ -109,7 +109,7 @@ class FeatureHooks:
|
|||||||
for i, h in enumerate(hooks):
|
for i, h in enumerate(hooks):
|
||||||
hook_name = h['module']
|
hook_name = h['module']
|
||||||
m = modules[hook_name]
|
m = modules[hook_name]
|
||||||
hook_id = return_map[i] if return_map else hook_name
|
hook_id = out_map[i] if out_map else hook_name
|
||||||
hook_fn = partial(self._collect_output_hook, hook_id)
|
hook_fn = partial(self._collect_output_hook, hook_id)
|
||||||
hook_type = h.get('hook_type', default_hook_type)
|
hook_type = h.get('hook_type', default_hook_type)
|
||||||
if hook_type == 'forward_pre':
|
if hook_type == 'forward_pre':
|
||||||
@ -155,11 +155,11 @@ def _get_feature_info(net, out_indices):
|
|||||||
assert False, "Provided feature_info is not valid"
|
assert False, "Provided feature_info is not valid"
|
||||||
|
|
||||||
|
|
||||||
def _get_return_layers(feature_info, return_map):
|
def _get_return_layers(feature_info, out_map):
|
||||||
module_names = feature_info.module_name()
|
module_names = feature_info.module_name()
|
||||||
return_layers = {}
|
return_layers = {}
|
||||||
for i, name in enumerate(module_names):
|
for i, name in enumerate(module_names):
|
||||||
return_layers[name] = return_map[i] if return_map is not None else feature_info.out_indices[i]
|
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
|
||||||
return return_layers
|
return return_layers
|
||||||
|
|
||||||
|
|
||||||
@ -182,7 +182,7 @@ class FeatureDictNet(nn.ModuleDict):
|
|||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||||
return_map: Sequence[Union[int, str]] = None,
|
out_map: Sequence[Union[int, str]] = None,
|
||||||
output_fmt: str = 'NCHW',
|
output_fmt: str = 'NCHW',
|
||||||
feature_concat: bool = False,
|
feature_concat: bool = False,
|
||||||
flatten_sequential: bool = False,
|
flatten_sequential: bool = False,
|
||||||
@ -191,7 +191,7 @@ class FeatureDictNet(nn.ModuleDict):
|
|||||||
Args:
|
Args:
|
||||||
model: Model from which to extract features.
|
model: Model from which to extract features.
|
||||||
out_indices: Output indices of the model features to extract.
|
out_indices: Output indices of the model features to extract.
|
||||||
return_map: Return id mapping for each output index, otherwise str(index) is used.
|
out_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||||
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
|
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
|
||||||
first element e.g. `x[0]`
|
first element e.g. `x[0]`
|
||||||
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
|
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
|
||||||
@ -203,7 +203,7 @@ class FeatureDictNet(nn.ModuleDict):
|
|||||||
self.grad_checkpointing = False
|
self.grad_checkpointing = False
|
||||||
self.return_layers = {}
|
self.return_layers = {}
|
||||||
|
|
||||||
return_layers = _get_return_layers(self.feature_info, return_map)
|
return_layers = _get_return_layers(self.feature_info, out_map)
|
||||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||||
remaining = set(return_layers.keys())
|
remaining = set(return_layers.keys())
|
||||||
layers = OrderedDict()
|
layers = OrderedDict()
|
||||||
@ -298,7 +298,7 @@ class FeatureHookNet(nn.ModuleDict):
|
|||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||||
return_map: Sequence[Union[int, str]] = None,
|
out_map: Sequence[Union[int, str]] = None,
|
||||||
return_dict: bool = False,
|
return_dict: bool = False,
|
||||||
output_fmt: str = 'NCHW',
|
output_fmt: str = 'NCHW',
|
||||||
no_rewrite: bool = False,
|
no_rewrite: bool = False,
|
||||||
@ -310,7 +310,7 @@ class FeatureHookNet(nn.ModuleDict):
|
|||||||
Args:
|
Args:
|
||||||
model: Model from which to extract features.
|
model: Model from which to extract features.
|
||||||
out_indices: Output indices of the model features to extract.
|
out_indices: Output indices of the model features to extract.
|
||||||
return_map: Return id mapping for each output index, otherwise str(index) is used.
|
out_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||||
return_dict: Output features as a dict.
|
return_dict: Output features as a dict.
|
||||||
no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
|
no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
|
||||||
flatten_sequential arg must also be False if this is set True.
|
flatten_sequential arg must also be False if this is set True.
|
||||||
@ -348,7 +348,7 @@ class FeatureHookNet(nn.ModuleDict):
|
|||||||
break
|
break
|
||||||
assert not remaining, f'Return layers ({remaining}) are not present in model'
|
assert not remaining, f'Return layers ({remaining}) are not present in model'
|
||||||
self.update(layers)
|
self.update(layers)
|
||||||
self.hooks = FeatureHooks(hooks, model.named_modules(), return_map=return_map)
|
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
|
||||||
|
|
||||||
def set_grad_checkpointing(self, enable: bool = True):
|
def set_grad_checkpointing(self, enable: bool = True):
|
||||||
self.grad_checkpointing = enable
|
self.grad_checkpointing = enable
|
||||||
|
Loading…
x
Reference in New Issue
Block a user