More Hiera updates. Add forward_intermediates to hieradet/sam2 impl. Make both use the same classifier module. Add coarse bool to intermediates.

Ross Wightman 2024-08-16 11:10:04 -07:00
parent f2cfb4c677
commit 962958723c
3 changed files with 114 additions and 71 deletions
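A minimal usage sketch of the updated Hiera API in this commit (the hiera_tiny_224 config name and tensor sizes are assumptions, not part of the diff): with coarse=True, indices address stage ends; with coarse=False, they address individual blocks.

import torch
import timm

model = timm.create_model('hiera_tiny_224', pretrained=False)  # assumed config name
x = torch.randn(1, 3, 224, 224)

# coarse=True (default): indices select stage-end features, returned NCHW
stage_feats = model.forward_intermediates(x, indices=4, intermediates_only=True, coarse=True)

# coarse=False: indices select individual blocks instead of stage ends
block_feats = model.forward_intermediates(x, indices=[0, 2, 5], intermediates_only=True, coarse=False)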

View File

@@ -254,9 +254,12 @@ class ClNormMlpClassifierHead(nn.Module):
self.drop = nn.Dropout(drop_rate)
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes: int, pool_type: Optional[str] = None):
def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
if pool_type is not None:
self.pool_type = pool_type
if reset_other:
self.pre_logits = nn.Identity()
self.norm = nn.Identity()
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def _global_pool(self, x):

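A hedged sketch of the shared head's updated reset, using only the constructor and forward arguments visible in this diff; the feature width and token count below are arbitrary.

import torch
from timm.layers import ClNormMlpClassifierHead

head = ClNormMlpClassifierHead(768, 1000, pool_type='avg', input_fmt='NLC')
x = torch.randn(2, 196, 768)            # (B, L, C) tokens

logits = head(x)                        # (2, 1000)
head.reset(10, pool_type='avg')         # swaps only the final fc layer
head.reset(0, reset_other=True)         # also resets pre_logits/norm to Identity
feats = head(x, pre_logits=True)        # pooled pre-logit features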
View File

@@ -32,8 +32,8 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple, \
init_weight_vit, init_weight_jax
from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \
_assert, get_norm_layer, to_2tuple, init_weight_vit, init_weight_jax
from ._registry import generate_default_cfgs, register_model
from ._builder import build_model_with_cfg
@@ -376,44 +376,6 @@ class HieraBlock(nn.Module):
return x
class NormClassifierHead(nn.Module):
def __init__(
self,
in_features: int,
num_classes: int,
pool_type: str = 'avg',
drop_rate: float = 0.0,
norm_layer: Union[str, Callable] = 'layernorm',
):
super().__init__()
norm_layer = get_norm_layer(norm_layer)
assert pool_type in ('avg', '')
self.in_features = self.num_features = in_features
self.pool_type = pool_type
self.norm = norm_layer(in_features)
self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
if pool_type is not None:
assert pool_type in ('avg', '')
self.pool_type = pool_type
if other:
# reset other non-fc layers
self.norm = nn.Identity()
self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
if self.pool_type == 'avg':
x = x.mean(dim=1)
x = self.norm(x)
x = self.drop(x)
if pre_logits:
return x
x = self.fc(x)
return x
class PatchEmbed(nn.Module):
"""Patch embed that supports any number of spatial dimensions (1d, 2d, 3d)."""
@@ -591,12 +553,13 @@ class Hiera(nn.Module):
self.blocks.append(block)
self.num_features = self.head_hidden_size = embed_dim
self.head = NormClassifierHead(
self.head = ClNormMlpClassifierHead(
embed_dim,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
norm_layer=norm_layer,
input_fmt='NLC',
)
# Initialize everything
@@ -651,9 +614,9 @@ class Hiera(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool, other=other)
self.head.reset(num_classes, global_pool, reset_other=reset_other)
def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
"""
@@ -716,6 +679,7 @@ class Hiera(nn.Module):
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
coarse: bool = True,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
@@ -730,10 +694,13 @@
"""
assert not norm, 'normalization of features not supported'
assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
take_indices = [self.stage_ends[i] for i in take_indices]
max_index = self.stage_ends[max_index]
assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
if coarse:
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
take_indices = [self.stage_ends[i] for i in take_indices]
max_index = self.stage_ends[max_index]
else:
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
if mask is not None:
patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape
@@ -755,7 +722,8 @@ class Hiera(nn.Module):
for i, blk in enumerate(blocks):
x = blk(x)
if i in take_indices:
intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))
x_int = self.reroll(x, i, mask=mask)
intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int)
if intermediates_only:
return intermediates
@@ -767,14 +735,18 @@ class Hiera(nn.Module):
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
coarse: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
max_index = self.stage_ends[max_index]
if coarse:
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
max_index = self.stage_ends[max_index]
else:
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_head:
self.head.reset(0, other=True)
self.head.reset(0, reset_other=True)
return take_indices
def forward_features(

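A hedged sketch of the pruning path added to Hiera above, again assuming the hiera_tiny_224 config name; coarse=True truncates at a stage end, coarse=False would truncate at a raw block index.

import torch
import timm

m = timm.create_model('hiera_tiny_224', pretrained=False)  # assumed config name
x = torch.randn(1, 3, 224, 224)

# Keep only the blocks needed for the first two stage outputs and drop the head.
kept = m.prune_intermediate_layers(indices=[0, 1], prune_head=True, coarse=True)
feats = m.forward_intermediates(x, indices=[0, 1], intermediates_only=True, coarse=True)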
View File

@@ -328,18 +328,16 @@ class HieraDet(nn.Module):
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.global_pos_size))
self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
cur_stage = 1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
cur_stage = 0
self.blocks = nn.Sequential()
self.feature_info = []
for i in range(depth):
dim_out = embed_dim
# lags by a block, so first block of
# next stage uses an initial window size
# of previous stage and final window size of current stage
window_size = self.window_spec[cur_stage - 1]
window_size = self.window_spec[cur_stage]
if self.global_att_blocks is not None:
window_size = 0 if i in self.global_att_blocks else window_size
@@ -362,6 +360,9 @@ class HieraDet(nn.Module):
embed_dim = dim_out
self.blocks.append(block)
if i in self.stage_ends:
self.feature_info += [
dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
self.channel_list = (
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
@@ -397,15 +398,15 @@ class HieraDet(nn.Module):
self.head.fc.weight.data.mul_(head_init_scale)
self.head.fc.bias.data.mul_(head_init_scale)
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
h, w = hw
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
h, w = x.shape[1:3]
window_embed = self.pos_embed_window
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
tile_h = pos_embed.shape[-2] // window_embed.shape[-2]
tile_w = pos_embed.shape[-1] // window_embed.shape[-1]
pos_embed = pos_embed + window_embed.tile((tile_h, tile_w))
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
return x + pos_embed
def fix_init_weight(self):
def rescale(param, _layer_id):
@@ -417,13 +418,13 @@ class HieraDet(nn.Module):
@torch.jit.ignore
def no_weight_decay(self):
return ['pos_embed', 'pos_embed_win']
return ['pos_embed', 'pos_embed_window']
@torch.jit.ignore
def group_matcher(self, coarse: bool = False) -> Dict:
return dict(
stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|pos_embed_abs|pos_embed_win|patch_embed',
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
stem=r'^pos_embed|pos_embed_window|patch_embed',
blocks=[(r'^blocks\.(\d+)', None)]
)
@torch.jit.ignore
@@ -434,13 +435,83 @@ class HieraDet(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)
self.head.reset(num_classes, pool_type=global_pool, reset_other=reset_other)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
coarse: bool = True,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
coarse: Take coarse features (stage ends) if true, otherwise all block features
Returns:
"""
assert not norm, 'normalization of features not supported'
assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
if coarse:
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
take_indices = [self.stage_ends[i] for i in take_indices]
max_index = self.stage_ends[max_index]
else:
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
x = self.patch_embed(x)
x = self._pos_embed(x)
intermediates = []
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if i in take_indices:
x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x
intermediates.append(x_out)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
coarse: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
if coarse:
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
max_index = self.stage_ends[max_index]
else:
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_head:
self.head.reset(0, reset_other=True)
return take_indices
def forward_features(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.patch_embed(x) # BHWC
x = x + self._get_pos_embed(x.shape[1:3])
x = self._pos_embed(x)
for i, blk in enumerate(self.blocks):
x = blk(x)
return x
@@ -449,10 +520,7 @@ class HieraDet(nn.Module):
x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
return x
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.forward_head(x)
return x
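
A hedged sketch of the same intermediates API on the HieraDet/SAM2 side; the sam2_hiera_tiny config name and 1024x1024 input size are assumptions. output_fmt='NHWC' keeps the channels-last layout the model uses internally.

import torch
import timm

m = timm.create_model('sam2_hiera_tiny', pretrained=False)  # assumed config name
x = torch.randn(1, 3, 1024, 1024)                           # assumed default input size

# One feature map per stage end, left in the model's native NHWC layout.
feats = m.forward_intermediates(x, indices=4, output_fmt='NHWC', intermediates_only=True, coarse=True)
print([f.shape for f in feats])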