From 211d18d8aceb7b36702e7e6ea7196f4847eed84d Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 11 May 2024 23:37:35 -0700
Subject: [PATCH] Move norm & pool into Hiera ClassifierHead. Misc fixes, update forward_intermediates() naming

---
 timm/layers/classifier.py  |  4 +-
 timm/layers/create_norm.py |  2 +-
 timm/models/hiera.py       | 78 ++++++++++++++++++++++++--------------
 3 files changed, 53 insertions(+), 31 deletions(-)

diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py
index 71e45c87..27ee5e70 100644
--- a/timm/layers/classifier.py
+++ b/timm/layers/classifier.py
@@ -108,7 +108,7 @@ class ClassifierHead(nn.Module):
         self.fc = fc
         self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
 
-    def reset(self, num_classes, pool_type=None):
+    def reset(self, num_classes: int, pool_type: Optional[str] = None):
         if pool_type is not None and pool_type != self.global_pool.pool_type:
             self.global_pool, self.fc = create_classifier(
                 self.in_features,
@@ -180,7 +180,7 @@ class NormMlpClassifierHead(nn.Module):
         self.drop = nn.Dropout(drop_rate)
         self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def reset(self, num_classes, pool_type=None):
+    def reset(self, num_classes: int, pool_type: Optional[str] = None):
         if pool_type is not None:
             self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
             self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
diff --git a/timm/layers/create_norm.py b/timm/layers/create_norm.py
index 3c4d287a..fbf58985 100644
--- a/timm/layers/create_norm.py
+++ b/timm/layers/create_norm.py
@@ -47,7 +47,7 @@ def get_norm_layer(norm_layer):
     if isinstance(norm_layer, str):
         if not norm_layer:
             return None
-        layer_name = norm_layer.replace('_', '')
+        layer_name = norm_layer.replace('_', '').lower()
         norm_layer = _NORM_MAP[layer_name]
     else:
         norm_layer = norm_layer
diff --git a/timm/models/hiera.py b/timm/models/hiera.py
index 4063a93e..e99aa978 100644
--- a/timm/models/hiera.py
+++ b/timm/models/hiera.py
@@ -32,7 +32,7 @@
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn, _assert
+from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer
 
 from ._registry import generate_default_cfgs, register_model
 
@@ -372,20 +372,41 @@ class HieraBlock(nn.Module):
         return x
 
 
-class Head(nn.Module):
+class NormClassifierHead(nn.Module):
     def __init__(
             self,
-            dim: int,
+            in_features: int,
             num_classes: int,
+            pool_type: str = 'avg',
             drop_rate: float = 0.0,
+            norm_layer: Union[str, Callable] = 'layernorm',
     ):
         super().__init__()
-        self.dropout = nn.Dropout(drop_rate) if drop_rate > 0 else nn.Identity()
-        self.projection = nn.Linear(dim, num_classes)
+        norm_layer = get_norm_layer(norm_layer)
+        assert pool_type in ('avg', '')
+        self.in_features = self.num_features = in_features
+        self.pool_type = pool_type
+        self.norm = norm_layer(in_features)
+        self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
+        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.dropout(x)
-        x = self.projection(x)
+    def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
+        if pool_type is not None:
+            assert pool_type in ('avg', '')
+            self.pool_type = pool_type
+        if other:
+            # reset other non-fc layers
+            self.norm = nn.Identity()
+        self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        if self.pool_type == 'avg':
+            x = x.mean(dim=1)
+        x = self.norm(x)
+        x = self.drop(x)
+        if pre_logits:
+            return x
+        x = self.fc(x)
         return x
 
 
@@ -438,6 +459,7 @@
             embed_dim: int = 96,  # initial embed dim
             num_heads: int = 1,  # initial number of heads
             num_classes: int = 1000,
+            global_pool: str = 'avg',
             stages: Tuple[int, ...] = (2, 3, 16, 3),
             q_pool: int = 3,  # number of q_pool stages
             q_stride: Tuple[int, ...] = (2, 2),
@@ -458,11 +480,7 @@
     ):
         super().__init__()
         self.num_classes = num_classes
-
-        # Do it this way to ensure that the init args are all PoD (for config usage)
-        if isinstance(norm_layer, str):
-            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
-
+        norm_layer = get_norm_layer(norm_layer)
         depth = sum(stages)
         self.patch_stride = patch_stride
         self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
@@ -552,8 +570,14 @@
                     dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
             self.blocks.append(block)
 
-        self.norm = norm_layer(embed_dim)
-        self.head = Head(embed_dim, num_classes, drop_rate=drop_rate)
+        self.num_features = embed_dim
+        self.head = NormClassifierHead(
+            embed_dim,
+            num_classes,
+            pool_type=global_pool,
+            drop_rate=drop_rate,
+            norm_layer=norm_layer,
+        )
 
         # Initialize everything
         if sep_pos_embed:
@@ -562,8 +586,8 @@
         else:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
         self.apply(partial(self._init_weights))
-        self.head.projection.weight.data.mul_(head_init_scale)
-        self.head.projection.bias.data.mul_(head_init_scale)
+        self.head.fc.weight.data.mul_(head_init_scale)
+        self.head.fc.bias.data.mul_(head_init_scale)
 
     def _init_weights(self, m, init_bias=0.02):
         if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
@@ -678,19 +702,17 @@
     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
         max_index = self.stage_ends[max_index]
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_head:
-            # norm part of head for this model, equivalent to fc_norm in other vit.
-            self.norm = nn.Identity()
-            self.head = nn.Identity()
+            self.head.reset(0, other=True)
         return take_indices
 
@@ -732,11 +754,7 @@
         return x
 
     def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
-        x = x.mean(dim=1)
-        x = self.norm(x)
-        if pre_logits:
-            return x
-        x = self.head(x)
+        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
         return x
 
     def forward(
@@ -756,7 +774,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
         **kwargs
     }
 
@@ -837,6 +855,10 @@
            # )
            #v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
            pass
+        if 'head.projection.' in k:
+            k = k.replace('head.projection.', 'head.fc.')
+        if k.startswith('norm.'):
+            k = k.replace('norm.', 'head.norm.')
         output[k] = v
     return output
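
Illustrative usage sketch (not part of the patch): the snippet below shows what the head refactor looks like from the user side, assuming a timm build that includes this change; the 'hiera_tiny_224' model name and the 10-class setup are assumptions for illustration only. With norm and pooling moved into NormClassifierHead, the classifier lives at 'head.fc' (matching the updated default cfg), forward_head(..., pre_logits=True) returns pooled and normalized features, and prune_intermediate_layers() now resets the head in place via head.reset(0, other=True).

    # Usage sketch; assumes timm with this patch applied and a registered 'hiera_tiny_224' config.
    import torch
    import timm
    from timm.layers import get_norm_layer

    # get_norm_layer() string lookup is now case-insensitive: 'LayerNorm' resolves like 'layernorm'.
    norm_cls = get_norm_layer('LayerNorm')

    model = timm.create_model('hiera_tiny_224', pretrained=False, num_classes=10)
    print(type(model.head).__name__)   # NormClassifierHead: norm + pool + fc now live in the head
    print(model.head.fc)               # Linear classifier, matching classifier='head.fc' in the cfg

    x = torch.randn(1, 3, 224, 224)
    tokens = model.forward_features(x)                    # (B, N, C) token sequence
    feats = model.forward_head(tokens, pre_logits=True)   # pooled + normalized features
    logits = model.forward_head(tokens)                   # classification logits, shape (1, 10)

    # Pruning for intermediate features resets the head in place rather than
    # replacing model.norm / model.head with nn.Identity().
    model.head.reset(0, other=True)

This layout also keeps old checkpoints loadable: the added checkpoint filtering remaps 'head.projection.*' keys to 'head.fc.*' and the former top-level 'norm.*' keys to 'head.norm.*', so pretrained weights line up with the new module structure.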