Merge pull request #2209 from huggingface/fcossio-vit-maxpool

ViT pooling refactor
Ross Wightman · 2024-06-17 07:51:12 -07:00 · committed by GitHub
commit e41125cc83
18 changed files with 73 additions and 34 deletions

diff --git a/requirements.txt b/requirements.txt

@@ -3,3 +3,4 @@ torchvision
 pyyaml
 huggingface_hub
 safetensors>=0.2
+numpy<2.0

diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py

@@ -156,7 +156,7 @@ class EfficientNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.classifier
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)
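Most of the CNN-side changes in this commit are this same mechanical tightening: reset_classifier gains type annotations (num_classes: int, and Optional[str] where None means "keep the current pooling"). The calling convention itself is unchanged; a minimal sketch of it, with an illustrative model name:

import timm

# Hypothetical usage; 'efficientnet_b0' is only an example model name.
model = timm.create_model('efficientnet_b0', pretrained=False)
model.reset_classifier(num_classes=10, global_pool='avg')  # replace head: 10 classes, avg pooling
model.reset_classifier(0, global_pool='')                  # num_classes=0: head removed, unpooled features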

diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py

@@ -273,7 +273,7 @@ class GhostNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.classifier
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         # cannot meaningfully change pooling of efficient head after creation
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)

diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py

@@ -739,7 +739,7 @@ class HighResolutionNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.classifier
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py

@@ -280,7 +280,7 @@ class InceptionV4(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.last_linear
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py

@@ -26,9 +26,9 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 from collections import OrderedDict
 from functools import partial
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -548,7 +548,7 @@ class MetaFormer(nn.Module):
         # if using MlpHead, dropout is handled by MlpHead
         if num_classes > 0:
             if self.use_mlp_head:
-                # FIXME hidden size
+                # FIXME not actually returning mlp hidden state right now as pre-logits.
                 final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
                 self.head_hidden_size = self.num_features
             else:
@@ -583,7 +583,7 @@ class MetaFormer(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head.fc
 
-    def reset_classifier(self, num_classes=0, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()

diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py

@@ -518,7 +518,7 @@ class NASNetALarge(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.last_linear
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py

@@ -307,7 +307,7 @@ class PNASNet5Large(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.last_linear
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

diff --git a/timm/models/regnet.py b/timm/models/regnet.py

@@ -514,7 +514,7 @@ class RegNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_intermediates(

diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py

@@ -12,6 +12,7 @@ Copyright 2020 Ross Wightman
 from functools import partial
 from math import ceil
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -229,7 +230,7 @@ class RexNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py

@@ -161,7 +161,7 @@ class SelecSls(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

diff --git a/timm/models/senet.py b/timm/models/senet.py

@@ -337,7 +337,7 @@ class SENet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.last_linear
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py

@@ -386,6 +386,31 @@ class ParallelThingsBlock(nn.Module):
         return self._forward(x)
 
 
+def global_pool_nlc(
+        x: torch.Tensor,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+        reduce_include_prefix: bool = False,
+):
+    if not pool_type:
+        return x
+
+    if pool_type == 'token':
+        x = x[:, 0]  # class token
+    else:
+        x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
+        if pool_type == 'avg':
+            x = x.mean(dim=1)
+        elif pool_type == 'avgmax':
+            x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
+        elif pool_type == 'max':
+            x = x.amax(dim=1)
+        else:
+            assert not pool_type, f'Unknown pool type {pool_type}'
+
+    return x
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer
@@ -400,7 +425,7 @@ class VisionTransformer(nn.Module):
             patch_size: Union[int, Tuple[int, int]] = 16,
             in_chans: int = 3,
             num_classes: int = 1000,
-            global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
+            global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
             embed_dim: int = 768,
             depth: int = 12,
             num_heads: int = 12,
@@ -459,10 +484,10 @@ class VisionTransformer(nn.Module):
             block_fn: Transformer block layer.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'token', 'map')
+        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         assert class_token or global_pool != 'token'
         assert pos_embed in ('', 'none', 'learn')
-        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
+        use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
         act_layer = get_act_layer(act_layer) or nn.GELU
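With global_pool widened and use_fc_norm now defaulting on for every reduce pooling (not just 'avg'), the new modes can be requested at construction time. A sketch via the model factory, assuming global_pool is forwarded to the constructor as usual; the model name is illustrative:

import timm

# 'avgmax' now passes the constructor assert; example model name only
model = timm.create_model('vit_small_patch16_224', pretrained=False, global_pool='avgmax')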
@@ -596,10 +621,10 @@ class VisionTransformer(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head
 
-    def reset_classifier(self, num_classes: int, global_pool = None) -> None:
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token', 'map')
+            assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
            if global_pool == 'map' and self.attn_pool is None:
                 assert False, "Cannot currently add attention pooling in reset_classifier()."
             elif global_pool != 'map' and self.attn_pool is not None:
@@ -756,13 +781,16 @@ class VisionTransformer(nn.Module):
         x = self.norm(x)
         return x
 
-    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+    def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
         if self.attn_pool is not None:
             x = self.attn_pool(x)
-        elif self.global_pool == 'avg':
-            x = x[:, self.num_prefix_tokens:].mean(dim=1)
-        elif self.global_pool:
-            x = x[:, 0]  # class token
+            return x
+        pool_type = self.global_pool if pool_type is None else pool_type
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        return x
+
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        x = self.pool(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
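Factoring forward_head into a separate pool() step also allows re-pooling one set of token features several ways without re-running the backbone (models with attention pooling always use attn_pool and ignore the override). A sketch, model name illustrative:

import timm
import torch

model = timm.create_model('vit_base_patch16_224', pretrained=False).eval()
with torch.no_grad():
    tokens = model.forward_features(torch.randn(1, 3, 224, 224))  # (1, N, C) token sequence
    cls_feat = model.pool(tokens, pool_type='token')              # class-token feature
    max_feat = model.pool(tokens, pool_type='max')                # max over patch tokens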

diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py

@@ -381,7 +381,7 @@ class VisionTransformerRelPos(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head
 
-    def reset_classifier(self, num_classes: int, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'avg', 'token')

diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py

@@ -536,7 +536,7 @@ class VisionTransformerSAM(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head
 
-    def reset_classifier(self, num_classes=0, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, global_pool)
 
     def forward_intermediates(

diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py

@@ -11,7 +11,7 @@ for some reference, rewrote most of the code.
 Hacked together by / Copyright 2020 Ross Wightman
 """
 
-from typing import List
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
@@ -134,9 +134,17 @@ class OsaStage(nn.Module):
             else:
                 drop_path = None
             blocks += [OsaBlock(
-                in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
-                attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
-            ]
+                in_chs,
+                mid_chs,
+                out_chs,
+                layer_per_block,
+                residual=residual and i > 0,
+                depthwise=depthwise,
+                attn=attn if last_block else '',
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                drop_path=drop_path
+            )]
             in_chs = out_chs
 
         self.blocks = nn.Sequential(*blocks)
@@ -252,8 +260,9 @@ class VovNet(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
-        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+    def reset_classifier(self, num_classes, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
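The VovNet change is behavioral as well as cosmetic: instead of constructing a brand-new ClassifierHead, reset_classifier now records num_classes and delegates to the existing head's reset(), matching the other head-based models, with global_pool=None meaning "keep the current pooling". Usage stays the same; model name illustrative:

import timm

model = timm.create_model('vovnet39a', pretrained=False)
model.reset_classifier(100)  # new 100-class fc, pooling left as-is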

diff --git a/timm/models/xception.py b/timm/models/xception.py

@@ -174,7 +174,7 @@ class Xception(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py

@@ -274,7 +274,7 @@ class XceptionAligned(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):