mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
More Hiera updates. Add forward_intermediates to hieradat/sam2 impl. Make both use same classifier module. Add coarse bool to intermediates.
This commit is contained in:
parent
f2cfb4c677
commit
962958723c
@ -254,9 +254,12 @@ class ClNormMlpClassifierHead(nn.Module):
|
||||
self.drop = nn.Dropout(drop_rate)
|
||||
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def reset(self, num_classes: int, pool_type: Optional[str] = None):
|
||||
def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
|
||||
if pool_type is not None:
|
||||
self.pool_type = pool_type
|
||||
if reset_other:
|
||||
self.pre_logits = nn.Identity()
|
||||
self.norm = nn.Identity()
|
||||
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def _global_pool(self, x):
|
||||
|
@ -32,8 +32,8 @@ import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple, \
|
||||
init_weight_vit, init_weight_jax
|
||||
from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \
|
||||
_assert, get_norm_layer, to_2tuple, init_weight_vit, init_weight_jax
|
||||
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
from ._builder import build_model_with_cfg
|
||||
@ -376,44 +376,6 @@ class HieraBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class NormClassifierHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
pool_type: str = 'avg',
|
||||
drop_rate: float = 0.0,
|
||||
norm_layer: Union[str, Callable] = 'layernorm',
|
||||
):
|
||||
super().__init__()
|
||||
norm_layer = get_norm_layer(norm_layer)
|
||||
assert pool_type in ('avg', '')
|
||||
self.in_features = self.num_features = in_features
|
||||
self.pool_type = pool_type
|
||||
self.norm = norm_layer(in_features)
|
||||
self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
|
||||
self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
|
||||
if pool_type is not None:
|
||||
assert pool_type in ('avg', '')
|
||||
self.pool_type = pool_type
|
||||
if other:
|
||||
# reset other non-fc layers
|
||||
self.norm = nn.Identity()
|
||||
self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
||||
if self.pool_type == 'avg':
|
||||
x = x.mean(dim=1)
|
||||
x = self.norm(x)
|
||||
x = self.drop(x)
|
||||
if pre_logits:
|
||||
return x
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Patch embed that supports any number of spatial dimensions (1d, 2d, 3d)."""
|
||||
|
||||
@ -591,12 +553,13 @@ class Hiera(nn.Module):
|
||||
self.blocks.append(block)
|
||||
|
||||
self.num_features = self.head_hidden_size = embed_dim
|
||||
self.head = NormClassifierHead(
|
||||
self.head = ClNormMlpClassifierHead(
|
||||
embed_dim,
|
||||
num_classes,
|
||||
pool_type=global_pool,
|
||||
drop_rate=drop_rate,
|
||||
norm_layer=norm_layer,
|
||||
input_fmt='NLC',
|
||||
)
|
||||
|
||||
# Initialize everything
|
||||
@ -651,9 +614,9 @@ class Hiera(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool, other=other)
|
||||
self.head.reset(num_classes, global_pool, reset_other=reset_other)
|
||||
|
||||
def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
|
||||
"""
|
||||
@ -716,6 +679,7 @@ class Hiera(nn.Module):
|
||||
stop_early: bool = True,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
coarse: bool = True,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
@ -730,10 +694,13 @@ class Hiera(nn.Module):
|
||||
|
||||
"""
|
||||
assert not norm, 'normalization of features not supported'
|
||||
assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
|
||||
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
||||
take_indices = [self.stage_ends[i] for i in take_indices]
|
||||
max_index = self.stage_ends[max_index]
|
||||
assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
|
||||
if coarse:
|
||||
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
||||
take_indices = [self.stage_ends[i] for i in take_indices]
|
||||
max_index = self.stage_ends[max_index]
|
||||
else:
|
||||
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||
|
||||
if mask is not None:
|
||||
patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape
|
||||
@ -755,7 +722,8 @@ class Hiera(nn.Module):
|
||||
for i, blk in enumerate(blocks):
|
||||
x = blk(x)
|
||||
if i in take_indices:
|
||||
intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))
|
||||
x_int = self.reroll(x, i, mask=mask)
|
||||
intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
@ -767,14 +735,18 @@ class Hiera(nn.Module):
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
coarse: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
||||
max_index = self.stage_ends[max_index]
|
||||
if coarse:
|
||||
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
||||
max_index = self.stage_ends[max_index]
|
||||
else:
|
||||
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
||||
if prune_head:
|
||||
self.head.reset(0, other=True)
|
||||
self.head.reset(0, reset_other=True)
|
||||
return take_indices
|
||||
|
||||
def forward_features(
|
||||
|
@ -328,18 +328,16 @@ class HieraDet(nn.Module):
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.global_pos_size))
|
||||
self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
cur_stage = 1
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
cur_stage = 0
|
||||
self.blocks = nn.Sequential()
|
||||
self.feature_info = []
|
||||
for i in range(depth):
|
||||
dim_out = embed_dim
|
||||
# lags by a block, so first block of
|
||||
# next stage uses an initial window size
|
||||
# of previous stage and final window size of current stage
|
||||
window_size = self.window_spec[cur_stage - 1]
|
||||
window_size = self.window_spec[cur_stage]
|
||||
|
||||
if self.global_att_blocks is not None:
|
||||
window_size = 0 if i in self.global_att_blocks else window_size
|
||||
@ -362,6 +360,9 @@ class HieraDet(nn.Module):
|
||||
|
||||
embed_dim = dim_out
|
||||
self.blocks.append(block)
|
||||
if i in self.stage_ends:
|
||||
self.feature_info += [
|
||||
dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
|
||||
|
||||
self.channel_list = (
|
||||
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
|
||||
@ -397,15 +398,15 @@ class HieraDet(nn.Module):
|
||||
self.head.fc.weight.data.mul_(head_init_scale)
|
||||
self.head.fc.bias.data.mul_(head_init_scale)
|
||||
|
||||
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
||||
h, w = hw
|
||||
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h, w = x.shape[1:3]
|
||||
window_embed = self.pos_embed_window
|
||||
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
||||
tile_h = pos_embed.shape[-2] // window_embed.shape[-2]
|
||||
tile_w = pos_embed.shape[-1] // window_embed.shape[-1]
|
||||
pos_embed = pos_embed + window_embed.tile((tile_h, tile_w))
|
||||
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
||||
return pos_embed
|
||||
return x + pos_embed
|
||||
|
||||
def fix_init_weight(self):
|
||||
def rescale(param, _layer_id):
|
||||
@ -417,13 +418,13 @@ class HieraDet(nn.Module):
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return ['pos_embed', 'pos_embed_win']
|
||||
return ['pos_embed', 'pos_embed_window']
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse: bool = False) -> Dict:
|
||||
return dict(
|
||||
stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|pos_embed_abs|pos_embed_win|patch_embed',
|
||||
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
|
||||
stem=r'^pos_embed|pos_embed_window|patch_embed',
|
||||
blocks=[(r'^blocks\.(\d+)', None)]
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
@ -434,13 +435,83 @@ class HieraDet(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
self.head.reset(num_classes, pool_type=global_pool, reset_other=reset_other)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = True,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
coarse: bool = True,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to all intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
coarse: Take coarse features (stage ends) if true, otherwise all block featrures
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert not norm, 'normalization of features not supported'
|
||||
assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
|
||||
if coarse:
|
||||
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
||||
take_indices = [self.stage_ends[i] for i in take_indices]
|
||||
max_index = self.stage_ends[max_index]
|
||||
else:
|
||||
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||
|
||||
x = self.patch_embed(x)
|
||||
x = self._pos_embed(x)
|
||||
|
||||
intermediates = []
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
blocks = self.blocks
|
||||
else:
|
||||
blocks = self.blocks[:max_index + 1]
|
||||
for i, blk in enumerate(blocks):
|
||||
x = blk(x)
|
||||
if i in take_indices:
|
||||
x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x
|
||||
intermediates.append(x_out)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
coarse: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
if coarse:
|
||||
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
||||
max_index = self.stage_ends[max_index]
|
||||
else:
|
||||
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
||||
if prune_head:
|
||||
self.head.reset(0, reset_other=True)
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
x = self.patch_embed(x) # BHWC
|
||||
x = x + self._get_pos_embed(x.shape[1:3])
|
||||
x = self._pos_embed(x)
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(x)
|
||||
return x
|
||||
@ -449,10 +520,7 @@ class HieraDet(nn.Module):
|
||||
x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user