mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Move norm & pool into Hiera ClassifierHead. Misc fixes, update features_intermediate() naming
This commit is contained in:
parent
2ca45a4ff5
commit
211d18d8ac
@ -108,7 +108,7 @@ class ClassifierHead(nn.Module):
|
||||
self.fc = fc
|
||||
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
|
||||
|
||||
def reset(self, num_classes, pool_type=None):
|
||||
def reset(self, num_classes: int, pool_type: Optional[str] = None):
|
||||
if pool_type is not None and pool_type != self.global_pool.pool_type:
|
||||
self.global_pool, self.fc = create_classifier(
|
||||
self.in_features,
|
||||
@ -180,7 +180,7 @@ class NormMlpClassifierHead(nn.Module):
|
||||
self.drop = nn.Dropout(drop_rate)
|
||||
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def reset(self, num_classes, pool_type=None):
|
||||
def reset(self, num_classes: int, pool_type: Optional[str] = None):
|
||||
if pool_type is not None:
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
|
||||
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
|
||||
|
@ -47,7 +47,7 @@ def get_norm_layer(norm_layer):
|
||||
if isinstance(norm_layer, str):
|
||||
if not norm_layer:
|
||||
return None
|
||||
layer_name = norm_layer.replace('_', '')
|
||||
layer_name = norm_layer.replace('_', '').lower()
|
||||
norm_layer = _NORM_MAP[layer_name]
|
||||
else:
|
||||
norm_layer = norm_layer
|
||||
|
@ -32,7 +32,7 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, Mlp, use_fused_attn, _assert
|
||||
from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer
|
||||
|
||||
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
@ -372,20 +372,41 @@ class HieraBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
class NormClassifierHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
pool_type: str = 'avg',
|
||||
drop_rate: float = 0.0,
|
||||
norm_layer: Union[str, Callable] = 'layernorm',
|
||||
):
|
||||
super().__init__()
|
||||
self.dropout = nn.Dropout(drop_rate) if drop_rate > 0 else nn.Identity()
|
||||
self.projection = nn.Linear(dim, num_classes)
|
||||
norm_layer = get_norm_layer(norm_layer)
|
||||
assert pool_type in ('avg', '')
|
||||
self.in_features = self.num_features = in_features
|
||||
self.pool_type = pool_type
|
||||
self.norm = norm_layer(in_features)
|
||||
self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
|
||||
self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.dropout(x)
|
||||
x = self.projection(x)
|
||||
def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
|
||||
if pool_type is not None:
|
||||
assert pool_type in ('avg', '')
|
||||
self.pool_type = pool_type
|
||||
if other:
|
||||
# reset other non-fc layers
|
||||
self.norm = nn.Identity()
|
||||
self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
||||
if self.pool_type == 'avg':
|
||||
x = x.mean(dim=1)
|
||||
x = self.norm(x)
|
||||
x = self.drop(x)
|
||||
if pre_logits:
|
||||
return x
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
@ -438,6 +459,7 @@ class Hiera(nn.Module):
|
||||
embed_dim: int = 96, # initial embed dim
|
||||
num_heads: int = 1, # initial number of heads
|
||||
num_classes: int = 1000,
|
||||
global_pool: str = 'avg',
|
||||
stages: Tuple[int, ...] = (2, 3, 16, 3),
|
||||
q_pool: int = 3, # number of q_pool stages
|
||||
q_stride: Tuple[int, ...] = (2, 2),
|
||||
@ -458,11 +480,7 @@ class Hiera(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
|
||||
# Do it this way to ensure that the init args are all PoD (for config usage)
|
||||
if isinstance(norm_layer, str):
|
||||
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
|
||||
|
||||
norm_layer = get_norm_layer(norm_layer)
|
||||
depth = sum(stages)
|
||||
self.patch_stride = patch_stride
|
||||
self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
|
||||
@ -552,8 +570,14 @@ class Hiera(nn.Module):
|
||||
dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
|
||||
self.blocks.append(block)
|
||||
|
||||
self.norm = norm_layer(embed_dim)
|
||||
self.head = Head(embed_dim, num_classes, drop_rate=drop_rate)
|
||||
self.num_features = embed_dim
|
||||
self.head = NormClassifierHead(
|
||||
embed_dim,
|
||||
num_classes,
|
||||
pool_type=global_pool,
|
||||
drop_rate=drop_rate,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
|
||||
# Initialize everything
|
||||
if sep_pos_embed:
|
||||
@ -562,8 +586,8 @@ class Hiera(nn.Module):
|
||||
else:
|
||||
nn.init.trunc_normal_(self.pos_embed, std=0.02)
|
||||
self.apply(partial(self._init_weights))
|
||||
self.head.projection.weight.data.mul_(head_init_scale)
|
||||
self.head.projection.bias.data.mul_(head_init_scale)
|
||||
self.head.fc.weight.data.mul_(head_init_scale)
|
||||
self.head.fc.bias.data.mul_(head_init_scale)
|
||||
|
||||
def _init_weights(self, m, init_bias=0.02):
|
||||
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||
@ -678,19 +702,17 @@ class Hiera(nn.Module):
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
n: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
|
||||
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
||||
max_index = self.stage_ends[max_index]
|
||||
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
||||
if prune_head:
|
||||
# norm part of head for this model, equivalent to fc_norm in other vit.
|
||||
self.norm = nn.Identity()
|
||||
self.head = nn.Identity()
|
||||
self.head.reset(0, other=True)
|
||||
return take_indices
|
||||
|
||||
|
||||
@ -732,11 +754,7 @@ class Hiera(nn.Module):
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
|
||||
x = x.mean(dim=1)
|
||||
x = self.norm(x)
|
||||
if pre_logits:
|
||||
return x
|
||||
x = self.head(x)
|
||||
x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
@ -756,7 +774,7 @@ def _cfg(url='', **kwargs):
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
@ -837,6 +855,10 @@ def checkpoint_filter_fn(state_dict, model=None):
|
||||
# )
|
||||
#v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
|
||||
pass
|
||||
if 'head.projection.' in k:
|
||||
k = k.replace('head.projection.', 'head.fc.')
|
||||
if k.startswith('norm.'):
|
||||
k = k.replace('norm.', 'head.norm.')
|
||||
output[k] = v
|
||||
return output
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user