The features from the penultimate model layer can be obtained in several ways without requiring model surgery.
### Unpooled
There are three ways to obtain unpooled features. The final, unpooled features are sometimes referred to as the last hidden state. In `timm`, this includes everything up to and including the final normalization layer (in e.g. ViT style models), but does not include pooling / class token selection or the final post-pooling layers.
Without modifying the network, one can call `model.forward_features(input)` on any model instead of the usual `model(input)`. This bypasses the network's global pooling and classifier head.
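A minimal sketch of this call (the model name is an assumption for illustration; `densenet121` happens to produce a 1024-channel 7x7 unpooled map for a 224x224 input, matching the output excerpt below):
```py
>>> import torch
>>> import timm
>>> model = timm.create_model('densenet121', pretrained=True)  # assumed example model
>>> output = model(torch.randn(2, 3, 224, 224))
>>> print('Original shape:', output.shape)
>>> output = model.forward_features(torch.randn(2, 3, 224, 224))
>>> print('Unpooled shape:', output.shape)
```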
Output:
```text
Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 1024, 7, 7])
```
#### Chaining unpooled output to classifier
The last hidden state can be fed into the head of the model using the `forward_head()` function.
```py
>>> import torch
>>> import timm
>>> model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
>>> output = model.forward_features(torch.randn(2,3,256,256))
>>> print('Unpooled output shape:', output.shape)
>>> classified = model.forward_head(output)
>>> print('Classification output shape:', classified.shape)
```
Output:
```text
Unpooled output shape: torch.Size([2, 257, 512])
Classification output shape: torch.Size([2, 1000])
```
### Pooled
To modify the network to return pooled features, one can use `forward_features()` and pool/flatten the result, or modify the network as above but keep pooling intact.
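A minimal sketch of both approaches (assuming `resnet50`, whose penultimate features are 2048-dimensional; the model choice is illustrative):
```py
>>> import torch
>>> import timm
>>> # Option 1: take the unpooled features and pool them yourself
>>> model = timm.create_model('resnet50', pretrained=True)
>>> unpooled = model.forward_features(torch.randn(2, 3, 224, 224))  # [2, 2048, 7, 7]
>>> pooled = unpooled.mean(dim=(2, 3))                              # [2, 2048]
>>> # Option 2: create the model without a classifier so forward() returns pooled features
>>> pool_model = timm.create_model('resnet50', pretrained=True, num_classes=0)
>>> pooled = pool_model(torch.randn(2, 3, 224, 224))                # [2, 2048]
```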
Object detection, segmentation, keypoint, and a variety of dense pixel tasks require access to feature maps from the backbone network at multiple scales.
`timm` provides a consistent interface for creating any of the included models as feature backbones that output feature maps for selected levels.
A feature backbone can be created by adding the argument `features_only=True` to any `create_model` call. By default, most models with a feature hierarchy will output up to 5 features, up to a reduction of 32. However, this varies per model: some models have fewer hierarchy levels, and some (like ViT) have a larger number of non-hierarchical feature maps and default to outputting the last 3. The `out_indices` arg can be passed to `create_model` to specify which features you want.
### Create a feature map extraction model
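A minimal sketch (the model name is an assumption for illustration; the original section's full example is not included in this excerpt):
```py
>>> import torch
>>> import timm
>>> m = timm.create_model('resnet50', pretrained=True, features_only=True)
>>> print(m.feature_info.channels())   # channel count of each returned feature map
>>> print(m.feature_info.reduction())  # reduction (stride) of each returned feature map
>>> o = m(torch.randn(2, 3, 224, 224))
>>> for x in o:
...     print(x.shape)  # one map per level; strides 2, 4, 8, 16, 32 for resnet50
```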
There are two additional creation arguments impacting the output features.
* `out_indices` selects which indices to output
* `output_stride` limits the feature output stride of the network (also works in classification mode BTW)
#### Output index selection
The `out_indices` argument is supported by all models, but not all models have the same index-to-feature-stride mapping. Look at the code or check `feature_info` to compare. The out indices generally correspond to the `C(i+1)th` feature level (a `2^(i+1)` reduction). For most convnet models, index 0 is the stride 2 features, and index 4 is stride 32. For many ViT or ViT-Conv hybrid models, many (or all) of the feature maps may have the same shape, or there may be a combination of hierarchical and non-hierarchical feature maps. It is best to look at the `feature_info` attribute to see the number of features, their corresponding channel counts, and reduction levels.
`out_indices` supports negative indexing, which makes it easy to get the last, penultimate, etc. feature map. `out_indices=(-2,)` would return the penultimate feature map for any model.
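For example, a sketch of negative index selection (the model choice is illustrative):
```py
>>> import torch
>>> import timm
>>> m = timm.create_model('resnet50', pretrained=True, features_only=True, out_indices=(-2,))
>>> print(m.feature_info.channels(), m.feature_info.reduction())  # [1024] [16] for resnet50
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(o[0].shape)  # the penultimate (stride 16) feature map
```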
#### Output stride (feature map dilation)
`output_stride` is achieved by converting layers to use dilated convolutions. Doing so is not always straightforward; some networks only support `output_stride=32`.
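A sketch consistent with the output excerpt below (the model name is an assumption; any ResNet-family model with 512- and 2048-channel levels at indices 2 and 4 would produce these shapes for a 320x320 input):
```py
>>> import torch
>>> import timm
>>> model = timm.create_model(
...     'resnet50',           # assumed example model
...     pretrained=True,
...     features_only=True,
...     output_stride=8,      # dilate layers so no feature map is reduced more than 8x
...     out_indices=(2, 4),   # select the C3 and C5 feature levels
... )
>>> print('Feature reduction:', model.feature_info.reduction())
>>> o = model(torch.randn(2, 3, 320, 320))
>>> for x in o:
...     print(x.shape)
```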
Output:
```text
Feature reduction: [8, 8]
torch.Size([2, 512, 40, 40])
torch.Size([2, 2048, 40, 40])
```
## Flexible intermediate feature map extraction
In addition to using `features_only` with the model factory, many models support a `forward_intermediates()` method which provides a flexible mechanism for extracting both the intermediate feature maps and the last hidden state (which can be chained to the head). Additionally, this method supports model-specific features such as returning the class or distillation prefix tokens for models that have them.
Accompanying the `forward_intermediates()` function is a `prune_intermediate_layers()` function that allows one to prune layers from the model, including the head, the final norm, and/or trailing blocks/stages that are not needed.
An `indices` argument is used for both `forward_intermediates()` and `prune_intermediate_layers()` to select the features to return or layers to remove. As with `out_indices` in the `features_only` API, `indices` is model specific and selects which intermediates are returned.
In non-hierarchical, block-based models such as ViT, the indices correspond to the blocks; in models with hierarchical stages they usually correspond to the output of the stem plus each hierarchical stage. Both positive (from the start) and negative (relative to the end) indexing work, and `None` returns all intermediates.
The `prune_intermediate_layers()` call returns the resolved indices, as negative indices must be converted to absolute (positive) indices when the model is trimmed.
```py
import torch
import timm

model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
output, intermediates = model.forward_intermediates(torch.randn(2, 3, 256, 256))
for i, o in enumerate(intermediates):
    print(f'Feat index: {i}, shape: {o.shape}')
```
```text
Feat index: 0, shape: torch.Size([2, 512, 16, 16])
Feat index: 1, shape: torch.Size([2, 512, 16, 16])
Feat index: 2, shape: torch.Size([2, 512, 16, 16])
Feat index: 3, shape: torch.Size([2, 512, 16, 16])
Feat index: 4, shape: torch.Size([2, 512, 16, 16])
Feat index: 5, shape: torch.Size([2, 512, 16, 16])
Feat index: 6, shape: torch.Size([2, 512, 16, 16])
Feat index: 7, shape: torch.Size([2, 512, 16, 16])
Feat index: 8, shape: torch.Size([2, 512, 16, 16])
Feat index: 9, shape: torch.Size([2, 512, 16, 16])
Feat index: 10, shape: torch.Size([2, 512, 16, 16])
Feat index: 11, shape: torch.Size([2, 512, 16, 16])
```
```py
import torch
import timm

model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
print('Original params:', sum([p.numel() for p in model.parameters()]))
indices = model.prune_intermediate_layers(indices=(-2,), prune_head=True, prune_norm=True)  # prune head, norm, last block
print('Pruned params:', sum([p.numel() for p in model.parameters()]))
intermediates = model.forward_intermediates(torch.randn(2, 3, 256, 256), indices=indices, intermediates_only=True)  # return penultimate intermediate
for o in intermediates:
    print(f'Feat shape: {o.shape}')
```
```text
Original params: 38880232
Pruned params: 35212800
Feat shape: torch.Size([2, 512, 16, 16])
```