Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
More Hiera updates. Add forward_intermediates to hieradet/sam2 impl. Make both use the same classifier module. Add coarse bool to intermediates.
This commit is contained in:
parent f2cfb4c677
commit 962958723c
@@ -254,9 +254,12 @@ class ClNormMlpClassifierHead(nn.Module):
         self.drop = nn.Dropout(drop_rate)
         self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def reset(self, num_classes: int, pool_type: Optional[str] = None):
+    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
         if pool_type is not None:
             self.pool_type = pool_type
+        if reset_other:
+            self.pre_logits = nn.Identity()
+            self.norm = nn.Identity()
         self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def _global_pool(self, x):
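
For illustration, a minimal sketch of what the new reset_other flag is meant to do when a head is reset for feature extraction. StubHead below mirrors only the reset() contract from the hunk above; it is a stand-in, not the real ClNormMlpClassifierHead.

    from typing import Optional

    import torch.nn as nn


    class StubHead(nn.Module):
        def __init__(self, num_features: int = 768, num_classes: int = 1000):
            super().__init__()
            self.num_features = num_features
            self.pool_type = 'avg'
            self.norm = nn.LayerNorm(num_features)
            self.pre_logits = nn.Identity()
            self.drop = nn.Dropout(0.)
            self.fc = nn.Linear(num_features, num_classes)

        def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
            if pool_type is not None:
                self.pool_type = pool_type
            if reset_other:
                # reset_other clears everything except pooling and the (new) fc
                self.pre_logits = nn.Identity()
                self.norm = nn.Identity()
            self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()


    head = StubHead()
    head.reset(0, reset_other=True)  # head is now a pure pass-through feature extractor
    assert isinstance(head.fc, nn.Identity) and isinstance(head.norm, nn.Identity)
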
@@ -32,8 +32,8 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple, \
-    init_weight_vit, init_weight_jax
+from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \
+    _assert, get_norm_layer, to_2tuple, init_weight_vit, init_weight_jax
 
 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
@@ -376,44 +376,6 @@ class HieraBlock(nn.Module):
         return x
 
 
-class NormClassifierHead(nn.Module):
-    def __init__(
-            self,
-            in_features: int,
-            num_classes: int,
-            pool_type: str = 'avg',
-            drop_rate: float = 0.0,
-            norm_layer: Union[str, Callable] = 'layernorm',
-    ):
-        super().__init__()
-        norm_layer = get_norm_layer(norm_layer)
-        assert pool_type in ('avg', '')
-        self.in_features = self.num_features = in_features
-        self.pool_type = pool_type
-        self.norm = norm_layer(in_features)
-        self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
-        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
-
-    def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
-        if pool_type is not None:
-            assert pool_type in ('avg', '')
-            self.pool_type = pool_type
-        if other:
-            # reset other non-fc layers
-            self.norm = nn.Identity()
-        self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
-
-    def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
-        if self.pool_type == 'avg':
-            x = x.mean(dim=1)
-        x = self.norm(x)
-        x = self.drop(x)
-        if pre_logits:
-            return x
-        x = self.fc(x)
-        return x
-
-
 class PatchEmbed(nn.Module):
     """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d)."""
 
@@ -591,12 +553,13 @@ class Hiera(nn.Module):
             self.blocks.append(block)
 
         self.num_features = self.head_hidden_size = embed_dim
-        self.head = NormClassifierHead(
+        self.head = ClNormMlpClassifierHead(
             embed_dim,
             num_classes,
             pool_type=global_pool,
             drop_rate=drop_rate,
             norm_layer=norm_layer,
+            input_fmt='NLC',
         )
 
         # Initialize everything
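
A usage sketch of the shared head consuming Hiera's unrolled tokens (NLC layout). It relies only on the constructor keywords visible in this hunk plus the ClNormMlpClassifierHead export added in the import hunk above; the remaining constructor defaults and the concrete sizes are assumptions for illustration.

    import torch
    from timm.layers import ClNormMlpClassifierHead

    head = ClNormMlpClassifierHead(
        96,                   # embed_dim of the final stage (illustrative)
        num_classes=10,
        pool_type='avg',
        drop_rate=0.0,
        norm_layer='layernorm',
        input_fmt='NLC',      # Hiera passes tokens as (B, L, C)
    )
    tokens = torch.randn(2, 49, 96)   # B=2, L=7*7, C=96
    logits = head(tokens)
    print(logits.shape)               # expected: torch.Size([2, 10])
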
@@ -651,9 +614,9 @@ class Hiera(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
         self.num_classes = num_classes
-        self.head.reset(num_classes, global_pool, other=other)
+        self.head.reset(num_classes, global_pool, reset_other=reset_other)
 
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """
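
A hedged sketch of the renamed reset_other argument at the model level. The model name and feature width are illustrative; any registered Hiera variant built from this code should behave the same way.

    import timm
    import torch

    model = timm.create_model('hiera_tiny_224', pretrained=False)
    model.reset_classifier(0, reset_other=True)     # fc and norm become Identity
    feats = model(torch.randn(1, 3, 224, 224))
    print(feats.shape)                              # pooled pre-norm features, e.g. (1, 768)
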
@@ -716,6 +679,7 @@ class Hiera(nn.Module):
             stop_early: bool = True,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
+            coarse: bool = True,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
         """ Forward features that returns intermediates.
 
@@ -730,10 +694,13 @@ class Hiera(nn.Module):
 
         """
         assert not norm, 'normalization of features not supported'
-        assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
-        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
-        take_indices = [self.stage_ends[i] for i in take_indices]
-        max_index = self.stage_ends[max_index]
+        assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
+        if coarse:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+        else:
+            take_indices, max_index = feature_take_indices(len(self.blocks), indices)
 
         if mask is not None:
             patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
@@ -755,7 +722,8 @@ class Hiera(nn.Module):
         for i, blk in enumerate(blocks):
             x = blk(x)
             if i in take_indices:
-                intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))
+                x_int = self.reroll(x, i, mask=mask)
+                intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int)
 
         if intermediates_only:
             return intermediates
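
A usage sketch of the new coarse switch, grounded in the code above; the model name is illustrative.

    import timm
    import torch

    model = timm.create_model('hiera_tiny_224', pretrained=False)
    x = torch.randn(1, 3, 224, 224)

    # coarse=True (default): indices refer to stage ends, so at most 4 maps,
    # returned NCHW at strides 4 / 8 / 16 / 32.
    stage_feats = model.forward_intermediates(x, indices=4, intermediates_only=True, coarse=True)
    print([f.shape for f in stage_feats])

    # coarse=False: indices refer to individual blocks instead of stage ends.
    block_feats = model.forward_intermediates(
        x, indices=[0, 1, 2], intermediates_only=True, coarse=False)
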
@@ -767,14 +735,18 @@ class Hiera(nn.Module):
             indices: Union[int, List[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
+            coarse: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
-        max_index = self.stage_ends[max_index]
+        if coarse:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            max_index = self.stage_ends[max_index]
+        else:
+            take_indices, max_index = feature_take_indices(len(self.blocks), indices)
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_head:
-            self.head.reset(0, other=True)
+            self.head.reset(0, reset_other=True)
         return take_indices
 
     def forward_features(
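
A companion sketch for prune_intermediate_layers, showing the intended pairing with forward_intermediates (model name again illustrative).

    import timm
    import torch

    model = timm.create_model('hiera_tiny_224', pretrained=False)
    kept = model.prune_intermediate_layers(indices=[0, 1], coarse=True)  # also resets the head
    print(kept, len(model.blocks))    # stage indices kept, truncated block count
    feats = model.forward_intermediates(
        torch.randn(1, 3, 224, 224), indices=[0, 1], intermediates_only=True)
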
@@ -328,18 +328,16 @@ class HieraDet(nn.Module):
         self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.global_pos_size))
         self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
 
-        dpr = [
-            x.item() for x in torch.linspace(0, drop_path_rate, depth)
-        ]  # stochastic depth decay rule
-
-        cur_stage = 1
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        cur_stage = 0
         self.blocks = nn.Sequential()
+        self.feature_info = []
         for i in range(depth):
             dim_out = embed_dim
             # lags by a block, so first block of
             # next stage uses an initial window size
             # of previous stage and final window size of current stage
-            window_size = self.window_spec[cur_stage - 1]
+            window_size = self.window_spec[cur_stage]
 
             if self.global_att_blocks is not None:
                 window_size = 0 if i in self.global_att_blocks else window_size
@@ -362,6 +360,9 @@ class HieraDet(nn.Module):
 
             embed_dim = dim_out
             self.blocks.append(block)
+            if i in self.stage_ends:
+                self.feature_info += [
+                    dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
 
         self.channel_list = (
             [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
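
A small sketch of the feature_info entries assembled above, using illustrative stage-end block indices and channel widths for a 4-stage, dim-doubling config; these dicts are what timm's features_only machinery reads to report channels and reductions.

    stage_ends = [1, 7, 43, 47]            # illustrative block indices, not a real config
    dims = [96, 192, 384, 768]             # dim_out per stage with dim_mul=2.0
    feature_info = [
        dict(num_chs=dims[s], reduction=2 ** (s + 2), module=f'blocks.{stage_ends[s]}')
        for s in range(4)
    ]
    for fi in feature_info:
        print(fi)   # reductions 4, 8, 16, 32: patch stride 4, then pooled 3x
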
@@ -397,15 +398,15 @@ class HieraDet(nn.Module):
             self.head.fc.weight.data.mul_(head_init_scale)
             self.head.fc.bias.data.mul_(head_init_scale)
 
-    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
-        h, w = hw
+    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+        h, w = x.shape[1:3]
         window_embed = self.pos_embed_window
         pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
         tile_h = pos_embed.shape[-2] // window_embed.shape[-2]
         tile_w = pos_embed.shape[-1] // window_embed.shape[-1]
         pos_embed = pos_embed + window_embed.tile((tile_h, tile_w))
         pos_embed = pos_embed.permute(0, 2, 3, 1)
-        return pos_embed
+        return x + pos_embed
 
     def fix_init_weight(self):
         def rescale(param, _layer_id):
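
A standalone sketch of the positional embedding arithmetic in the new _pos_embed, with illustrative sizes (7x7 global table, 8x8 window table, 64x64 feature map).

    import torch
    import torch.nn.functional as F

    embed_dim = 96
    pos_embed = torch.zeros(1, embed_dim, 7, 7)          # learned global table
    pos_embed_window = torch.zeros(1, embed_dim, 8, 8)   # learned per-window table
    x = torch.randn(2, 64, 64, embed_dim)                # BHWC tokens after patch embed

    h, w = x.shape[1:3]
    pe = F.interpolate(pos_embed, size=(h, w), mode="bicubic")
    tile_h = pe.shape[-2] // pos_embed_window.shape[-2]  # 64 // 8 = 8
    tile_w = pe.shape[-1] // pos_embed_window.shape[-1]
    pe = pe + pos_embed_window.tile((tile_h, tile_w))    # tile the window table across the map
    pe = pe.permute(0, 2, 3, 1)                          # NCHW -> BHWC to match x
    out = x + pe
    print(out.shape)    # torch.Size([2, 64, 64, 96])
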
@@ -417,13 +418,13 @@ class HieraDet(nn.Module):
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        return ['pos_embed', 'pos_embed_win']
+        return ['pos_embed', 'pos_embed_window']
 
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict:
         return dict(
-            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|pos_embed_abs|pos_embed_win|patch_embed',
-            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
+            stem=r'^pos_embed|pos_embed_window|patch_embed',
+            blocks=[(r'^blocks\.(\d+)', None)]
         )
 
     @torch.jit.ignore
@@ -434,13 +435,83 @@ class HieraDet(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
         self.num_classes = num_classes
-        self.head.reset(num_classes, pool_type=global_pool)
+        self.head.reset(num_classes, pool_type=global_pool, reset_other=reset_other)
+
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            coarse: bool = True,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            coarse: Take coarse features (stage ends) if true, otherwise all block features
+        Returns:
+
+        """
+        assert not norm, 'normalization of features not supported'
+        assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
+        if coarse:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+        else:
+            take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        x = self.patch_embed(x)
+        x = self._pos_embed(x)
+
+        intermediates = []
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x
+                intermediates.append(x_out)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+            coarse: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        if coarse:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            max_index = self.stage_ends[max_index]
+        else:
+            take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_head:
+            self.head.reset(0, reset_other=True)
+        return take_indices
 
     def forward_features(self, x: torch.Tensor) -> List[torch.Tensor]:
         x = self.patch_embed(x)  # BHWC
-        x = x + self._get_pos_embed(x.shape[1:3])
+        x = self._pos_embed(x)
         for i, blk in enumerate(self.blocks):
             x = blk(x)
         return x
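
A usage sketch for the forward_intermediates added to HieraDet above; the model name and input size are illustrative.

    import timm
    import torch

    model = timm.create_model('sam2_hiera_tiny', pretrained=False)
    x = torch.randn(1, 3, 224, 224)

    # Stage-end features, converted to NCHW by default.
    feats = model.forward_intermediates(x, indices=4, intermediates_only=True)
    print([f.shape for f in feats])

    # Keep the model's native BHWC layout and also get the final block output back.
    final, feats = model.forward_intermediates(x, indices=2, output_fmt='NHWC', coarse=True)
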
@@ -449,10 +520,7 @@ class HieraDet(nn.Module):
         x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
         return x
 
-    def forward(
-            self,
-            x: torch.Tensor,
-    ) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x