return_map back to out_map for _feature helpers
parent
acfd85ad68
commit
5aebad3fbc
|
@ -100,7 +100,7 @@ class FeatureHooks:
|
|||
self,
|
||||
hooks: Sequence[str],
|
||||
named_modules: dict,
|
||||
return_map: Sequence[Union[int, str]] = None,
|
||||
out_map: Sequence[Union[int, str]] = None,
|
||||
default_hook_type: str = 'forward',
|
||||
):
|
||||
# setup feature hooks
|
||||
|
@ -109,7 +109,7 @@ class FeatureHooks:
|
|||
for i, h in enumerate(hooks):
|
||||
hook_name = h['module']
|
||||
m = modules[hook_name]
|
||||
hook_id = return_map[i] if return_map else hook_name
|
||||
hook_id = out_map[i] if out_map else hook_name
|
||||
hook_fn = partial(self._collect_output_hook, hook_id)
|
||||
hook_type = h.get('hook_type', default_hook_type)
|
||||
if hook_type == 'forward_pre':
|
||||
|
@ -155,11 +155,11 @@ def _get_feature_info(net, out_indices):
|
|||
assert False, "Provided feature_info is not valid"
|
||||
|
||||
|
||||
def _get_return_layers(feature_info, return_map):
|
||||
def _get_return_layers(feature_info, out_map):
|
||||
module_names = feature_info.module_name()
|
||||
return_layers = {}
|
||||
for i, name in enumerate(module_names):
|
||||
return_layers[name] = return_map[i] if return_map is not None else feature_info.out_indices[i]
|
||||
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
|
||||
return return_layers
|
||||
|
||||
|
||||
|
@ -182,7 +182,7 @@ class FeatureDictNet(nn.ModuleDict):
|
|||
self,
|
||||
model: nn.Module,
|
||||
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||
return_map: Sequence[Union[int, str]] = None,
|
||||
out_map: Sequence[Union[int, str]] = None,
|
||||
output_fmt: str = 'NCHW',
|
||||
feature_concat: bool = False,
|
||||
flatten_sequential: bool = False,
|
||||
|
@ -191,7 +191,7 @@ class FeatureDictNet(nn.ModuleDict):
|
|||
Args:
|
||||
model: Model from which to extract features.
|
||||
out_indices: Output indices of the model features to extract.
|
||||
return_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||
out_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
|
||||
first element e.g. `x[0]`
|
||||
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
|
||||
|
@ -203,7 +203,7 @@ class FeatureDictNet(nn.ModuleDict):
|
|||
self.grad_checkpointing = False
|
||||
self.return_layers = {}
|
||||
|
||||
return_layers = _get_return_layers(self.feature_info, return_map)
|
||||
return_layers = _get_return_layers(self.feature_info, out_map)
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
remaining = set(return_layers.keys())
|
||||
layers = OrderedDict()
|
||||
|
@ -298,7 +298,7 @@ class FeatureHookNet(nn.ModuleDict):
|
|||
self,
|
||||
model: nn.Module,
|
||||
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||
return_map: Sequence[Union[int, str]] = None,
|
||||
out_map: Sequence[Union[int, str]] = None,
|
||||
return_dict: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
no_rewrite: bool = False,
|
||||
|
@ -310,7 +310,7 @@ class FeatureHookNet(nn.ModuleDict):
|
|||
Args:
|
||||
model: Model from which to extract features.
|
||||
out_indices: Output indices of the model features to extract.
|
||||
return_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||
out_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||
return_dict: Output features as a dict.
|
||||
no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
|
||||
flatten_sequential arg must also be False if this is set True.
|
||||
|
@ -348,7 +348,7 @@ class FeatureHookNet(nn.ModuleDict):
|
|||
break
|
||||
assert not remaining, f'Return layers ({remaining}) are not present in model'
|
||||
self.update(layers)
|
||||
self.hooks = FeatureHooks(hooks, model.named_modules(), return_map=return_map)
|
||||
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
|
||||
|
||||
def set_grad_checkpointing(self, enable: bool = True):
|
||||
self.grad_checkpointing = enable
|
||||
|
|
Loading…
Reference in New Issue