return_map back to out_map for _feature helpers

pull/1628/head
Ross Wightman 2023-03-16 14:50:55 -07:00
parent acfd85ad68
commit 5aebad3fbc
1 changed file with 10 additions and 10 deletions


@@ -100,7 +100,7 @@ class FeatureHooks:
             self,
             hooks: Sequence[str],
             named_modules: dict,
-            return_map: Sequence[Union[int, str]] = None,
+            out_map: Sequence[Union[int, str]] = None,
             default_hook_type: str = 'forward',
     ):
         # setup feature hooks
@@ -109,7 +109,7 @@ class FeatureHooks:
         for i, h in enumerate(hooks):
             hook_name = h['module']
             m = modules[hook_name]
-            hook_id = return_map[i] if return_map else hook_name
+            hook_id = out_map[i] if out_map else hook_name
             hook_fn = partial(self._collect_output_hook, hook_id)
             hook_type = h.get('hook_type', default_hook_type)
             if hook_type == 'forward_pre':
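
For context, a minimal sketch of how FeatureHooks is driven after this rename. The toy model, names, and import path below are illustrative assumptions, not taken from the diff:

    import torch
    import torch.nn as nn
    from collections import OrderedDict
    from timm.models._features import FeatureHooks   # import path assumed

    # Toy two-layer model so the hook module names resolve.
    model = nn.Sequential(OrderedDict([
        ('layer1', nn.Conv2d(3, 8, 3, padding=1)),
        ('layer2', nn.Conv2d(8, 16, 3, padding=1)),
    ]))
    hooks = [{'module': 'layer1'}, {'module': 'layer2'}]   # default 'forward' hook type
    fh = FeatureHooks(hooks, model.named_modules(), out_map=('early', 'late'))

    x = torch.randn(1, 3, 32, 32)
    model(x)                          # hooks fire during the forward pass
    feats = fh.get_output(x.device)   # dict keyed by out_map ids: 'early', 'late'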
@@ -155,11 +155,11 @@ def _get_feature_info(net, out_indices):
         assert False, "Provided feature_info is not valid"


-def _get_return_layers(feature_info, return_map):
+def _get_return_layers(feature_info, out_map):
     module_names = feature_info.module_name()
     return_layers = {}
     for i, name in enumerate(module_names):
-        return_layers[name] = return_map[i] if return_map is not None else feature_info.out_indices[i]
+        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
     return return_layers
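
A small sketch of what the renamed argument does here, with a stub standing in for FeatureInfo (the module names and indices are hypothetical):

    # Stub standing in for a FeatureInfo object (hypothetical values).
    class _StubFeatureInfo:
        out_indices = (1, 2, 3)
        def module_name(self):
            return ['layer1', 'layer2', 'layer3']

    fi = _StubFeatureInfo()
    _get_return_layers(fi, None)
    # -> {'layer1': 1, 'layer2': 2, 'layer3': 3}                # falls back to out_indices
    _get_return_layers(fi, ('low', 'mid', 'high'))
    # -> {'layer1': 'low', 'layer2': 'mid', 'layer3': 'high'}   # ids taken from out_map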
@@ -182,7 +182,7 @@ class FeatureDictNet(nn.ModuleDict):
             self,
             model: nn.Module,
             out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
-            return_map: Sequence[Union[int, str]] = None,
+            out_map: Sequence[Union[int, str]] = None,
             output_fmt: str = 'NCHW',
             feature_concat: bool = False,
             flatten_sequential: bool = False,
@@ -191,7 +191,7 @@ class FeatureDictNet(nn.ModuleDict):
         Args:
             model: Model from which to extract features.
             out_indices: Output indices of the model features to extract.
-            return_map: Return id mapping for each output index, otherwise str(index) is used.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
             feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                 first element e.g. `x[0]`
             flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
@@ -203,7 +203,7 @@ class FeatureDictNet(nn.ModuleDict):
         self.grad_checkpointing = False
         self.return_layers = {}

-        return_layers = _get_return_layers(self.feature_info, return_map)
+        return_layers = _get_return_layers(self.feature_info, out_map)
         modules = _module_list(model, flatten_sequential=flatten_sequential)
         remaining = set(return_layers.keys())
         layers = OrderedDict()
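
Downstream, out_map is what callers pass to pick their own dict keys. A hedged usage sketch; the backbone choice and import path are assumptions, not part of this change:

    import torch
    import timm
    from timm.models._features import FeatureDictNet   # import path assumed

    backbone = timm.create_model('resnet50')            # any timm model exposing feature_info
    features = FeatureDictNet(
        backbone,
        out_indices=(1, 2, 3, 4),
        out_map=('c2', 'c3', 'c4', 'c5'),                # keys of the returned dict
    )
    out = features(torch.randn(1, 3, 224, 224))
    # out is an OrderedDict keyed by out_map, e.g. out['c5'] is the deepest feature map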
@@ -298,7 +298,7 @@ class FeatureHookNet(nn.ModuleDict):
             self,
             model: nn.Module,
             out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
-            return_map: Sequence[Union[int, str]] = None,
+            out_map: Sequence[Union[int, str]] = None,
             return_dict: bool = False,
             output_fmt: str = 'NCHW',
             no_rewrite: bool = False,
@@ -310,7 +310,7 @@ class FeatureHookNet(nn.ModuleDict):
         Args:
             model: Model from which to extract features.
             out_indices: Output indices of the model features to extract.
-            return_map: Return id mapping for each output index, otherwise str(index) is used.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
             return_dict: Output features as a dict.
             no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
                 flatten_sequential arg must also be False if this is set True.
@@ -348,7 +348,7 @@ class FeatureHookNet(nn.ModuleDict):
                     break
         assert not remaining, f'Return layers ({remaining}) are not present in model'
         self.update(layers)
-        self.hooks = FeatureHooks(hooks, model.named_modules(), return_map=return_map)
+        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

     def set_grad_checkpointing(self, enable: bool = True):
         self.grad_checkpointing = enable
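
The same argument flows from FeatureHookNet down to FeatureHooks, so hook-based extraction can use the same custom ids. A hedged sketch, again with the backbone choice and import path as assumptions:

    import torch
    import timm
    from timm.models._features import FeatureHookNet   # import path assumed

    backbone = timm.create_model('resnet18')            # illustrative backbone choice
    features = FeatureHookNet(
        backbone,
        out_indices=(2, 3, 4),
        out_map=('p3', 'p4', 'p5'),   # forwarded to FeatureHooks as hook ids
        return_dict=True,
    )
    out = features(torch.randn(1, 3, 224, 224))
    list(out.keys())   # expected: ['p3', 'p4', 'p5']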