mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
update some model
This commit is contained in:
parent
99c25fa5c0
commit
2e9b2a76fb
@ -452,29 +452,29 @@ class ConvNeXt(nn.Module):
|
|||||||
"""
|
"""
|
||||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
intermediates = []
|
intermediates = []
|
||||||
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
feat_idx = 0 # stem is index 0
|
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if feat_idx in take_indices:
|
|
||||||
intermediates.append(x)
|
|
||||||
|
|
||||||
|
last_idx = len(self.stages) - 1
|
||||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
stages = self.stages
|
stages = self.stages
|
||||||
else:
|
else:
|
||||||
stages = self.stages[:max_index]
|
stages = self.stages[:max_index + 1]
|
||||||
for stage in stages:
|
for feat_idx, stage in enumerate(stages):
|
||||||
feat_idx += 1
|
|
||||||
x = stage(x)
|
x = stage(x)
|
||||||
if feat_idx in take_indices:
|
if feat_idx in take_indices:
|
||||||
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
|
if norm and feat_idx == last_idx:
|
||||||
intermediates.append(x)
|
intermediates.append(self.norm_pre(x))
|
||||||
|
else:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
if intermediates_only:
|
if intermediates_only:
|
||||||
return intermediates
|
return intermediates
|
||||||
|
|
||||||
x = self.norm_pre(x)
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm_pre(x)
|
||||||
|
|
||||||
return x, intermediates
|
return x, intermediates
|
||||||
|
|
||||||
|
@ -491,7 +491,7 @@ class FocalNet(nn.Module):
|
|||||||
else:
|
else:
|
||||||
stages = self.layers[:max_index + 1]
|
stages = self.layers[:max_index + 1]
|
||||||
|
|
||||||
last_idx = len(self.layers)
|
last_idx = len(self.layers) - 1
|
||||||
for feat_idx, stage in enumerate(stages):
|
for feat_idx, stage in enumerate(stages):
|
||||||
x = stage(x)
|
x = stage(x)
|
||||||
if feat_idx in take_indices:
|
if feat_idx in take_indices:
|
||||||
|
@ -870,10 +870,11 @@ class MultiScaleVit(nn.Module):
|
|||||||
if self.pos_embed is not None:
|
if self.pos_embed is not None:
|
||||||
x = x + self.pos_embed
|
x = x + self.pos_embed
|
||||||
|
|
||||||
for i, stage in enumerate(self.stages):
|
last_idx = len(self.stages) - 1
|
||||||
|
for feat_idx, stage in enumerate(self.stages):
|
||||||
x, feat_size = stage(x, feat_size)
|
x, feat_size = stage(x, feat_size)
|
||||||
if i in take_indices:
|
if feat_idx in take_indices:
|
||||||
if norm and i == (len(self.stages) - 1):
|
if norm and feat_idx == last_idx:
|
||||||
x_inter = self.norm(x) # applying final norm last intermediate
|
x_inter = self.norm(x) # applying final norm last intermediate
|
||||||
else:
|
else:
|
||||||
x_inter = x
|
x_inter = x
|
||||||
@ -887,7 +888,8 @@ class MultiScaleVit(nn.Module):
|
|||||||
if intermediates_only:
|
if intermediates_only:
|
||||||
return intermediates
|
return intermediates
|
||||||
|
|
||||||
x = self.norm(x)
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
return x, intermediates
|
return x, intermediates
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ Modifications for timm by / Copyright 2020 Ross Wightman
|
|||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, Sequence, Tuple
|
from typing import List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -22,6 +22,7 @@ from torch import nn
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import trunc_normal_, to_2tuple
|
from timm.layers import trunc_normal_, to_2tuple
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
from .vision_transformer import Block
|
from .vision_transformer import Block
|
||||||
|
|
||||||
@ -254,6 +255,71 @@ class PoolingVisionTransformer(nn.Module):
|
|||||||
if self.head_dist is not None:
|
if self.head_dist is not None:
|
||||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
x = self.pos_drop(x + self.pos_embed)
|
||||||
|
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
|
||||||
|
|
||||||
|
last_idx = len(self.transformers) - 1
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.transformers
|
||||||
|
else:
|
||||||
|
stages = self.transformers[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x, cls_tokens = stage((x, cls_tokens))
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
cls_tokens = self.norm(cls_tokens)
|
||||||
|
|
||||||
|
return cls_tokens, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_norm:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
x = self.pos_drop(x + self.pos_embed)
|
x = self.pos_drop(x + self.pos_embed)
|
||||||
|
@ -302,20 +302,20 @@ class RDNet(nn.Module):
|
|||||||
"""
|
"""
|
||||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
intermediates = []
|
intermediates = []
|
||||||
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
|
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||||
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
take_indices = [stage_ends[i] for i in take_indices]
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
feat_idx = 0 # stem is index 0
|
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if feat_idx in take_indices:
|
|
||||||
intermediates.append(x)
|
last_idx = len(self.dense_stages) - 1
|
||||||
last_idx = len(self.dense_stages)
|
|
||||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
dense_stages = self.dense_stages
|
dense_stages = self.dense_stages
|
||||||
else:
|
else:
|
||||||
dense_stages = self.dense_stages[:max_index]
|
dense_stages = self.dense_stages[:max_index + 1]
|
||||||
for stage in dense_stages:
|
for feat_idx, stage in enumerate(dense_stages):
|
||||||
feat_idx += 1
|
|
||||||
x = stage(x)
|
x = stage(x)
|
||||||
if feat_idx in take_indices:
|
if feat_idx in take_indices:
|
||||||
if norm and feat_idx == last_idx:
|
if norm and feat_idx == last_idx:
|
||||||
@ -340,8 +340,10 @@ class RDNet(nn.Module):
|
|||||||
):
|
):
|
||||||
""" Prune layers not required for specified intermediates.
|
""" Prune layers not required for specified intermediates.
|
||||||
"""
|
"""
|
||||||
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
|
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||||
self.dense_stages = self.dense_stages[:max_index] # truncate blocks w/ stem as idx 0
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
self.dense_stages = self.dense_stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
if prune_norm:
|
if prune_norm:
|
||||||
self.norm_pre = nn.Identity()
|
self.norm_pre = nn.Identity()
|
||||||
if prune_head:
|
if prune_head:
|
||||||
|
@ -571,9 +571,13 @@ class ResNetV2(nn.Module):
|
|||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
feat_idx = 0
|
feat_idx = 0
|
||||||
x = self.stem(x)
|
H, W = x.shape[-2:]
|
||||||
|
for stem in self.stem:
|
||||||
|
x = stem(x)
|
||||||
|
if x.shape[-2:] == (H //2, W //2):
|
||||||
|
x_down = x
|
||||||
if feat_idx in take_indices:
|
if feat_idx in take_indices:
|
||||||
intermediates.append(x)
|
intermediates.append(x_down)
|
||||||
last_idx = len(self.stages)
|
last_idx = len(self.stages)
|
||||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
stages = self.stages
|
stages = self.stages
|
||||||
|
@ -494,7 +494,8 @@ class Xcit(nn.Module):
|
|||||||
# NOTE not supporting return of class tokens
|
# NOTE not supporting return of class tokens
|
||||||
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
|
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
|
||||||
for blk in self.cls_attn_blocks:
|
for blk in self.cls_attn_blocks:
|
||||||
x = blk(x)
|
x = blk(x)
|
||||||
|
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
|
||||||
return x, intermediates
|
return x, intermediates
|
||||||
|
Loading…
x
Reference in New Issue
Block a user