diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py
index 1cbab683..5e425fe6 100644
--- a/timm/layers/classifier.py
+++ b/timm/layers/classifier.py
@@ -254,9 +254,12 @@ class ClNormMlpClassifierHead(nn.Module):
         self.drop = nn.Dropout(drop_rate)
         self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def reset(self, num_classes: int, pool_type: Optional[str] = None):
+    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
         if pool_type is not None:
             self.pool_type = pool_type
+        if reset_other:
+            self.pre_logits = nn.Identity()
+            self.norm = nn.Identity()
         self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def _global_pool(self, x):
diff --git a/timm/models/hiera.py b/timm/models/hiera.py
index ec5d8b7b..78d32752 100644
--- a/timm/models/hiera.py
+++ b/timm/models/hiera.py
@@ -32,8 +32,8 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple, \
-    init_weight_vit, init_weight_jax
+from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \
+    _assert, get_norm_layer, to_2tuple, init_weight_vit, init_weight_jax
 
 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
@@ -376,44 +376,6 @@ class HieraBlock(nn.Module):
         return x
 
 
-class NormClassifierHead(nn.Module):
-    def __init__(
-            self,
-            in_features: int,
-            num_classes: int,
-            pool_type: str = 'avg',
-            drop_rate: float = 0.0,
-            norm_layer: Union[str, Callable] = 'layernorm',
-    ):
-        super().__init__()
-        norm_layer = get_norm_layer(norm_layer)
-        assert pool_type in ('avg', '')
-        self.in_features = self.num_features = in_features
-        self.pool_type = pool_type
-        self.norm = norm_layer(in_features)
-        self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
-        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
-
-    def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
-        if pool_type is not None:
-            assert pool_type in ('avg', '')
-            self.pool_type = pool_type
-        if other:
-            # reset other non-fc layers
-            self.norm = nn.Identity()
-        self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
-
-    def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
-        if self.pool_type == 'avg':
-            x = x.mean(dim=1)
-        x = self.norm(x)
-        x = self.drop(x)
-        if pre_logits:
-            return x
-        x = self.fc(x)
-        return x
-
-
 class PatchEmbed(nn.Module):
     """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d)."""
 
@@ -591,12 +553,13 @@ class Hiera(nn.Module):
             self.blocks.append(block)
 
         self.num_features = self.head_hidden_size = embed_dim
-        self.head = NormClassifierHead(
+        self.head = ClNormMlpClassifierHead(
             embed_dim,
             num_classes,
             pool_type=global_pool,
             drop_rate=drop_rate,
             norm_layer=norm_layer,
+            input_fmt='NLC',
         )
 
         # Initialize everything
@@ -651,9 +614,9 @@ class Hiera(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
         self.num_classes = num_classes
-        self.head.reset(num_classes, global_pool, other=other)
+        self.head.reset(num_classes, global_pool, reset_other=reset_other)
 
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """
@@ -716,6 +679,7 @@ class Hiera(nn.Module):
             stop_early: bool = True,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
+            coarse: bool = True,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
         """ Forward features that returns intermediates.
 
@@ -730,10 +694,13 @@ class Hiera(nn.Module):
 
         """
         assert not norm, 'normalization of features not supported'
-        assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
-        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
-        take_indices = [self.stage_ends[i] for i in take_indices]
-        max_index = self.stage_ends[max_index]
+        assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
+        if coarse:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+        else:
+            take_indices, max_index = feature_take_indices(len(self.blocks), indices)
 
         if mask is not None:
             patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
@@ -755,7 +722,8 @@ class Hiera(nn.Module):
         for i, blk in enumerate(blocks):
             x = blk(x)
             if i in take_indices:
-                intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))
+                x_int = self.reroll(x, i, mask=mask)
+                intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int)
 
         if intermediates_only:
             return intermediates
@@ -767,14 +735,18 @@ class Hiera(nn.Module):
             indices: Union[int, List[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
+            coarse: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
""" - take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) - max_index = self.stage_ends[max_index] + if coarse: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + max_index = self.stage_ends[max_index] + else: + take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_head: - self.head.reset(0, other=True) + self.head.reset(0, reset_other=True) return take_indices def forward_features( diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 2ba167ae..d5a78679 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -328,18 +328,16 @@ class HieraDet(nn.Module): self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.global_pos_size)) self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])) - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, depth) - ] # stochastic depth decay rule - - cur_stage = 1 + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + cur_stage = 0 self.blocks = nn.Sequential() + self.feature_info = [] for i in range(depth): dim_out = embed_dim # lags by a block, so first block of # next stage uses an initial window size # of previous stage and final window size of current stage - window_size = self.window_spec[cur_stage - 1] + window_size = self.window_spec[cur_stage] if self.global_att_blocks is not None: window_size = 0 if i in self.global_att_blocks else window_size @@ -362,6 +360,9 @@ class HieraDet(nn.Module): embed_dim = dim_out self.blocks.append(block) + if i in self.stage_ends: + self.feature_info += [ + dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')] self.channel_list = ( [self.blocks[i].dim_out for i in self.stage_ends[::-1]] @@ -397,15 +398,15 @@ class HieraDet(nn.Module): self.head.fc.weight.data.mul_(head_init_scale) self.head.fc.bias.data.mul_(head_init_scale) - def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: - h, w = hw + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + h, w = x.shape[1:3] window_embed = self.pos_embed_window pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") tile_h = pos_embed.shape[-2] // window_embed.shape[-2] tile_w = pos_embed.shape[-1] // window_embed.shape[-1] pos_embed = pos_embed + window_embed.tile((tile_h, tile_w)) pos_embed = pos_embed.permute(0, 2, 3, 1) - return pos_embed + return x + pos_embed def fix_init_weight(self): def rescale(param, _layer_id): @@ -417,13 +418,13 @@ class HieraDet(nn.Module): @torch.jit.ignore def no_weight_decay(self): - return ['pos_embed', 'pos_embed_win'] + return ['pos_embed', 'pos_embed_window'] @torch.jit.ignore def group_matcher(self, coarse: bool = False) -> Dict: return dict( - stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|pos_embed_abs|pos_embed_win|patch_embed', - blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + stem=r'^pos_embed|pos_embed_window|patch_embed', + blocks=[(r'^blocks\.(\d+)', None)] ) @torch.jit.ignore @@ -434,13 +435,83 @@ class HieraDet(nn.Module): def get_classifier(self): return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False): self.num_classes = num_classes - self.head.reset(num_classes, pool_type=global_pool) + 
+        self.head.reset(num_classes, pool_type=global_pool, reset_other=reset_other)
+
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            coarse: bool = True,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            coarse: Take coarse features (stage ends) if true, otherwise all block features
+        Returns:
+
+        """
+        assert not norm, 'normalization of features not supported'
+        assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
+        if coarse:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+        else:
+            take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        x = self.patch_embed(x)
+        x = self._pos_embed(x)
+
+        intermediates = []
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x
+                intermediates.append(x_out)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+            coarse: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        if coarse:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            max_index = self.stage_ends[max_index]
+        else:
+            take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_head:
+            self.head.reset(0, reset_other=True)
+        return take_indices
 
     def forward_features(self, x: torch.Tensor) -> List[torch.Tensor]:
         x = self.patch_embed(x)  # BHWC
-        x = x + self._get_pos_embed(x.shape[1:3])
+        x = self._pos_embed(x)
         for i, blk in enumerate(self.blocks):
             x = blk(x)
         return x
@@ -449,10 +520,7 @@ class HieraDet(nn.Module):
         x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
         return x
 
-    def forward(
-            self,
-            x: torch.Tensor,
-    ) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
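
Usage sketch (illustrative, not part of the patch): with these changes applied, the new coarse and reset_other arguments could be exercised as below. The model name hiera_tiny_224 is one of the existing Hiera registrations in timm and is used only as an example; the appropriate input size and number of intermediates depend on the chosen model.

import torch
import timm

# assumes a timm checkout that includes this patch
model = timm.create_model('hiera_tiny_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# coarse=True (default) indexes stage-end features; coarse=False indexes every block
feats = model.forward_intermediates(
    x,
    indices=4,                 # last 4 intermediates (all stage ends for a 4-stage model)
    output_fmt='NCHW',
    intermediates_only=True,
    coarse=True,
)
print([f.shape for f in feats])

# remove the classifier; reset_other=True also resets the head norm / pre-logits layers to nn.Identity
model.reset_classifier(0, reset_other=True)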