Move norm & pool into Hiera ClassifierHead. Misc fixes, update features_intermediate() naming

Ross Wightman 2024-05-11 23:37:35 -07:00
parent 2ca45a4ff5
commit 211d18d8ac
3 changed files with 53 additions and 31 deletions


@@ -108,7 +108,7 @@ class ClassifierHead(nn.Module):
         self.fc = fc
         self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
 
-    def reset(self, num_classes, pool_type=None):
+    def reset(self, num_classes: int, pool_type: Optional[str] = None):
         if pool_type is not None and pool_type != self.global_pool.pool_type:
             self.global_pool, self.fc = create_classifier(
                 self.in_features,
@@ -180,7 +180,7 @@ class NormMlpClassifierHead(nn.Module):
         self.drop = nn.Dropout(drop_rate)
         self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def reset(self, num_classes, pool_type=None):
+    def reset(self, num_classes: int, pool_type: Optional[str] = None):
         if pool_type is not None:
             self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
             self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
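These typed reset() signatures are what model.reset_classifier() dispatches to on models built around these heads; a rough usage sketch, with convnext_tiny chosen only as an illustrative model that uses NormMlpClassifierHead:

import timm

model = timm.create_model('convnext_tiny', pretrained=False)

# swap the classifier for fine-tuning: new class count, optionally a new pooling mode
model.reset_classifier(num_classes=10, global_pool='avg')

# num_classes=0 replaces the final Linear with nn.Identity, leaving pooled features
model.reset_classifier(0)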


@@ -47,7 +47,7 @@ def get_norm_layer(norm_layer):
     if isinstance(norm_layer, str):
         if not norm_layer:
             return None
-        layer_name = norm_layer.replace('_', '')
+        layer_name = norm_layer.replace('_', '').lower()
         norm_layer = _NORM_MAP[layer_name]
     else:
         norm_layer = norm_layer
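The added .lower() makes string lookups into _NORM_MAP case-insensitive; a quick sketch of what resolves after this change, assuming the usual get_norm_layer export from timm.layers:

from timm.layers import get_norm_layer

ln = get_norm_layer('layernorm')    # already worked: map keys are stored lower-case
ln2 = get_norm_layer('LayerNorm')   # previously a KeyError, now normalized to 'layernorm'
gn = get_norm_layer('group_norm')   # '_' stripped, then lower-cased -> 'groupnorm'
layer = ln(64)                      # resolved entries are norm classes, instantiated as usual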


@@ -32,7 +32,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn, _assert
+from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer
 from ._registry import generate_default_cfgs, register_model
@@ -372,20 +372,41 @@ class HieraBlock(nn.Module):
         return x
 
 
-class Head(nn.Module):
+class NormClassifierHead(nn.Module):
     def __init__(
             self,
-            dim: int,
+            in_features: int,
             num_classes: int,
+            pool_type: str = 'avg',
             drop_rate: float = 0.0,
+            norm_layer: Union[str, Callable] = 'layernorm',
     ):
         super().__init__()
-        self.dropout = nn.Dropout(drop_rate) if drop_rate > 0 else nn.Identity()
-        self.projection = nn.Linear(dim, num_classes)
+        norm_layer = get_norm_layer(norm_layer)
+        assert pool_type in ('avg', '')
+        self.in_features = self.num_features = in_features
+        self.pool_type = pool_type
+        self.norm = norm_layer(in_features)
+        self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
+        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.dropout(x)
-        x = self.projection(x)
+    def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
+        if pool_type is not None:
+            assert pool_type in ('avg', '')
+            self.pool_type = pool_type
+        if other:
+            # reset other non-fc layers
+            self.norm = nn.Identity()
+        self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        if self.pool_type == 'avg':
+            x = x.mean(dim=1)
+        x = self.norm(x)
+        x = self.drop(x)
+        if pre_logits:
+            return x
+        x = self.fc(x)
         return x
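A minimal sanity-check sketch of the new head's behaviour (mean-pool over tokens, then norm, dropout, fc, with pre_logits short-circuiting before the fc); it assumes NormClassifierHead is importable from timm.models.hiera and uses arbitrary dims:

import torch
from timm.models.hiera import NormClassifierHead

head = NormClassifierHead(in_features=768, num_classes=1000, pool_type='avg', norm_layer='layernorm')
tokens = torch.randn(2, 49, 768)        # (batch, tokens, channels) from the final Hiera stage
logits = head(tokens)                   # pool -> norm -> dropout -> fc, shape (2, 1000)
feats = head(tokens, pre_logits=True)   # stops before the fc, shape (2, 768)
head.reset(num_classes=0)               # fc becomes nn.Identity; head(tokens) then returns features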
@@ -438,6 +459,7 @@ class Hiera(nn.Module):
             embed_dim: int = 96,  # initial embed dim
             num_heads: int = 1,  # initial number of heads
             num_classes: int = 1000,
+            global_pool: str = 'avg',
             stages: Tuple[int, ...] = (2, 3, 16, 3),
             q_pool: int = 3,  # number of q_pool stages
             q_stride: Tuple[int, ...] = (2, 2),
@@ -458,11 +480,7 @@
     ):
         super().__init__()
         self.num_classes = num_classes
-        # Do it this way to ensure that the init args are all PoD (for config usage)
-        if isinstance(norm_layer, str):
-            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+        norm_layer = get_norm_layer(norm_layer)
         depth = sum(stages)
         self.patch_stride = patch_stride
         self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
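Routing norm_layer through get_norm_layer() means the argument can be a registered string name or a norm class/callable; a rough sketch of both spellings, assuming kwargs are forwarded through create_model as with other timm models:

import torch.nn as nn
import timm

# string form, resolved via get_norm_layer()/_NORM_MAP (case-insensitive after the fix above)
m1 = timm.create_model('hiera_tiny_224', pretrained=False, norm_layer='layernorm')

# class/callable form is passed through unchanged
m2 = timm.create_model('hiera_tiny_224', pretrained=False, norm_layer=nn.LayerNorm)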
@@ -552,8 +570,14 @@
                     dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
             self.blocks.append(block)
 
-        self.norm = norm_layer(embed_dim)
-        self.head = Head(embed_dim, num_classes, drop_rate=drop_rate)
+        self.num_features = embed_dim
+        self.head = NormClassifierHead(
+            embed_dim,
+            num_classes,
+            pool_type=global_pool,
+            drop_rate=drop_rate,
+            norm_layer=norm_layer,
+        )
 
         # Initialize everything
         if sep_pos_embed:
@@ -562,8 +586,8 @@
         else:
             nn.init.trunc_normal_(self.pos_embed, std=0.02)
         self.apply(partial(self._init_weights))
-        self.head.projection.weight.data.mul_(head_init_scale)
-        self.head.projection.bias.data.mul_(head_init_scale)
+        self.head.fc.weight.data.mul_(head_init_scale)
+        self.head.fc.bias.data.mul_(head_init_scale)
 
     def _init_weights(self, m, init_bias=0.02):
         if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
@@ -678,19 +702,17 @@
     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
         max_index = self.stage_ends[max_index]
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_head:
-            # norm part of head for this model, equivalent to fc_norm in other vit.
-            self.norm = nn.Identity()
-            self.head = nn.Identity()
+            self.head.reset(0, other=True)
         return take_indices
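The rename from n to indices matches the forward_intermediates() feature API used across timm; a rough usage sketch (model name and call pattern assumed, not taken from this diff):

import torch
import timm

model = timm.create_model('hiera_tiny_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# final features plus intermediates from the last two stages
final, intermediates = model.forward_intermediates(x, indices=2)

# then drop blocks (and the now self-contained head) not needed for those features
model.prune_intermediate_layers(indices=2, prune_head=True)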
@@ -732,11 +754,7 @@
         return x
 
     def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
-        x = x.mean(dim=1)
-        x = self.norm(x)
-        if pre_logits:
-            return x
-        x = self.head(x)
+        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
         return x
 
     def forward(
@@ -756,7 +774,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
         **kwargs
     }
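Pointing 'classifier' at 'head.fc' keeps pretrained-weight loading aligned with the renamed module: when num_classes differs from the checkpoint, timm uses this key to drop or resize only that Linear layer, so the head.norm weights still load. A minimal sketch, assuming a pretrained tag is available for the entrypoint:

import timm

# head.fc is re-initialized for 10 classes; the rest of the checkpoint, including head.norm, loads as usual
model = timm.create_model('hiera_tiny_224', pretrained=True, num_classes=10)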
@@ -837,6 +855,10 @@ def checkpoint_filter_fn(state_dict, model=None):
             # )
             #v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
             pass
+        if 'head.projection.' in k:
+            k = k.replace('head.projection.', 'head.fc.')
+        if k.startswith('norm.'):
+            k = k.replace('norm.', 'head.norm.')
         output[k] = v
     return output
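The net effect on an original-format checkpoint is a pure key rename; a tiny standalone illustration of the same replacement logic:

old_keys = ['norm.weight', 'norm.bias', 'head.projection.weight', 'head.projection.bias']
for k in old_keys:
    if 'head.projection.' in k:
        k = k.replace('head.projection.', 'head.fc.')
    if k.startswith('norm.'):
        k = k.replace('norm.', 'head.norm.')
    print(k)
# -> head.norm.weight, head.norm.bias, head.fc.weight, head.fc.bias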