mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2209 from huggingface/fcossio-vit-maxpool
ViT pooling refactor
This commit is contained in:
commit
e41125cc83
@ -3,3 +3,4 @@ torchvision
|
|||||||
pyyaml
|
pyyaml
|
||||||
huggingface_hub
|
huggingface_hub
|
||||||
safetensors>=0.2
|
safetensors>=0.2
|
||||||
|
numpy<2.0
|
||||||
|
@ -156,7 +156,7 @@ class EfficientNet(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.classifier
|
return self.classifier
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.global_pool, self.classifier = create_classifier(
|
self.global_pool, self.classifier = create_classifier(
|
||||||
self.num_features, self.num_classes, pool_type=global_pool)
|
self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
@ -273,7 +273,7 @@ class GhostNet(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.classifier
|
return self.classifier
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
# cannot meaningfully change pooling of efficient head after creation
|
# cannot meaningfully change pooling of efficient head after creation
|
||||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||||
|
@ -739,7 +739,7 @@ class HighResolutionNet(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.classifier
|
return self.classifier
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.global_pool, self.classifier = create_classifier(
|
self.global_pool, self.classifier = create_classifier(
|
||||||
self.num_features, self.num_classes, pool_type=global_pool)
|
self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
@ -280,7 +280,7 @@ class InceptionV4(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.last_linear
|
return self.last_linear
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.global_pool, self.last_linear = create_classifier(
|
self.global_pool, self.last_linear = create_classifier(
|
||||||
self.num_features, self.num_classes, pool_type=global_pool)
|
self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
@ -26,9 +26,9 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -548,7 +548,7 @@ class MetaFormer(nn.Module):
|
|||||||
# if using MlpHead, dropout is handled by MlpHead
|
# if using MlpHead, dropout is handled by MlpHead
|
||||||
if num_classes > 0:
|
if num_classes > 0:
|
||||||
if self.use_mlp_head:
|
if self.use_mlp_head:
|
||||||
# FIXME hidden size
|
# FIXME not actually returning mlp hidden state right now as pre-logits.
|
||||||
final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
|
final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
|
||||||
self.head_hidden_size = self.num_features
|
self.head_hidden_size = self.num_features
|
||||||
else:
|
else:
|
||||||
@ -583,7 +583,7 @@ class MetaFormer(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.head.fc
|
return self.head.fc
|
||||||
|
|
||||||
def reset_classifier(self, num_classes=0, global_pool=None):
|
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||||
if global_pool is not None:
|
if global_pool is not None:
|
||||||
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||||
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
||||||
|
@ -518,7 +518,7 @@ class NASNetALarge(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.last_linear
|
return self.last_linear
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.global_pool, self.last_linear = create_classifier(
|
self.global_pool, self.last_linear = create_classifier(
|
||||||
self.num_features, self.num_classes, pool_type=global_pool)
|
self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
@ -307,7 +307,7 @@ class PNASNet5Large(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.last_linear
|
return self.last_linear
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.global_pool, self.last_linear = create_classifier(
|
self.global_pool, self.last_linear = create_classifier(
|
||||||
self.num_features, self.num_classes, pool_type=global_pool)
|
self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
@ -514,7 +514,7 @@ class RegNet(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.head.fc
|
return self.head.fc
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||||
self.head.reset(num_classes, pool_type=global_pool)
|
self.head.reset(num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
def forward_intermediates(
|
def forward_intermediates(
|
||||||
|
@ -12,6 +12,7 @@ Copyright 2020 Ross Wightman
|
|||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -229,7 +230,7 @@ class RexNet(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.head.fc
|
return self.head.fc
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
|
@ -161,7 +161,7 @@ class SelecSls(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.fc
|
return self.fc
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
|
@ -337,7 +337,7 @@ class SENet(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.last_linear
|
return self.last_linear
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.global_pool, self.last_linear = create_classifier(
|
self.global_pool, self.last_linear = create_classifier(
|
||||||
self.num_features, self.num_classes, pool_type=global_pool)
|
self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
@ -386,6 +386,31 @@ class ParallelThingsBlock(nn.Module):
|
|||||||
return self._forward(x)
|
return self._forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
def global_pool_nlc(
|
||||||
|
x: torch.Tensor,
|
||||||
|
pool_type: str = 'token',
|
||||||
|
num_prefix_tokens: int = 1,
|
||||||
|
reduce_include_prefix: bool = False,
|
||||||
|
):
|
||||||
|
if not pool_type:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if pool_type == 'token':
|
||||||
|
x = x[:, 0] # class token
|
||||||
|
else:
|
||||||
|
x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
|
||||||
|
if pool_type == 'avg':
|
||||||
|
x = x.mean(dim=1)
|
||||||
|
elif pool_type == 'avgmax':
|
||||||
|
x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
|
||||||
|
elif pool_type == 'max':
|
||||||
|
x = x.amax(dim=1)
|
||||||
|
else:
|
||||||
|
assert not pool_type, f'Unknown pool type {pool_type}'
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class VisionTransformer(nn.Module):
|
class VisionTransformer(nn.Module):
|
||||||
""" Vision Transformer
|
""" Vision Transformer
|
||||||
|
|
||||||
@ -400,7 +425,7 @@ class VisionTransformer(nn.Module):
|
|||||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||||
in_chans: int = 3,
|
in_chans: int = 3,
|
||||||
num_classes: int = 1000,
|
num_classes: int = 1000,
|
||||||
global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
|
global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
|
||||||
embed_dim: int = 768,
|
embed_dim: int = 768,
|
||||||
depth: int = 12,
|
depth: int = 12,
|
||||||
num_heads: int = 12,
|
num_heads: int = 12,
|
||||||
@ -459,10 +484,10 @@ class VisionTransformer(nn.Module):
|
|||||||
block_fn: Transformer block layer.
|
block_fn: Transformer block layer.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert global_pool in ('', 'avg', 'token', 'map')
|
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
|
||||||
assert class_token or global_pool != 'token'
|
assert class_token or global_pool != 'token'
|
||||||
assert pos_embed in ('', 'none', 'learn')
|
assert pos_embed in ('', 'none', 'learn')
|
||||||
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
|
use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
|
||||||
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
|
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
|
||||||
act_layer = get_act_layer(act_layer) or nn.GELU
|
act_layer = get_act_layer(act_layer) or nn.GELU
|
||||||
|
|
||||||
@ -596,10 +621,10 @@ class VisionTransformer(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.head
|
return self.head
|
||||||
|
|
||||||
def reset_classifier(self, num_classes: int, global_pool = None) -> None:
|
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
if global_pool is not None:
|
if global_pool is not None:
|
||||||
assert global_pool in ('', 'avg', 'token', 'map')
|
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
|
||||||
if global_pool == 'map' and self.attn_pool is None:
|
if global_pool == 'map' and self.attn_pool is None:
|
||||||
assert False, "Cannot currently add attention pooling in reset_classifier()."
|
assert False, "Cannot currently add attention pooling in reset_classifier()."
|
||||||
elif global_pool != 'map ' and self.attn_pool is not None:
|
elif global_pool != 'map ' and self.attn_pool is not None:
|
||||||
@ -756,13 +781,16 @@ class VisionTransformer(nn.Module):
|
|||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
|
||||||
if self.attn_pool is not None:
|
if self.attn_pool is not None:
|
||||||
x = self.attn_pool(x)
|
x = self.attn_pool(x)
|
||||||
elif self.global_pool == 'avg':
|
return x
|
||||||
x = x[:, self.num_prefix_tokens:].mean(dim=1)
|
pool_type = self.global_pool if pool_type is None else pool_type
|
||||||
elif self.global_pool:
|
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
|
||||||
x = x[:, 0] # class token
|
return x
|
||||||
|
|
||||||
|
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
||||||
|
x = self.pool(x)
|
||||||
x = self.fc_norm(x)
|
x = self.fc_norm(x)
|
||||||
x = self.head_drop(x)
|
x = self.head_drop(x)
|
||||||
return x if pre_logits else self.head(x)
|
return x if pre_logits else self.head(x)
|
||||||
|
@ -381,7 +381,7 @@ class VisionTransformerRelPos(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.head
|
return self.head
|
||||||
|
|
||||||
def reset_classifier(self, num_classes: int, global_pool=None):
|
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
if global_pool is not None:
|
if global_pool is not None:
|
||||||
assert global_pool in ('', 'avg', 'token')
|
assert global_pool in ('', 'avg', 'token')
|
||||||
|
@ -536,7 +536,7 @@ class VisionTransformerSAM(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.head
|
return self.head
|
||||||
|
|
||||||
def reset_classifier(self, num_classes=0, global_pool=None):
|
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
def forward_intermediates(
|
def forward_intermediates(
|
||||||
|
@ -11,7 +11,7 @@ for some reference, rewrote most of the code.
|
|||||||
Hacked together by / Copyright 2020 Ross Wightman
|
Hacked together by / Copyright 2020 Ross Wightman
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -134,9 +134,17 @@ class OsaStage(nn.Module):
|
|||||||
else:
|
else:
|
||||||
drop_path = None
|
drop_path = None
|
||||||
blocks += [OsaBlock(
|
blocks += [OsaBlock(
|
||||||
in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
|
in_chs,
|
||||||
attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
|
mid_chs,
|
||||||
]
|
out_chs,
|
||||||
|
layer_per_block,
|
||||||
|
residual=residual and i > 0,
|
||||||
|
depthwise=depthwise,
|
||||||
|
attn=attn if last_block else '',
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop_path=drop_path
|
||||||
|
)]
|
||||||
in_chs = out_chs
|
in_chs = out_chs
|
||||||
self.blocks = nn.Sequential(*blocks)
|
self.blocks = nn.Sequential(*blocks)
|
||||||
|
|
||||||
@ -252,8 +260,9 @@ class VovNet(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.head.fc
|
return self.head.fc
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes, global_pool: Optional[str] = None):
|
||||||
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
self.num_classes = num_classes
|
||||||
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
|
@ -174,7 +174,7 @@ class Xception(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.fc
|
return self.fc
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
|
@ -274,7 +274,7 @@ class XceptionAligned(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.head.fc
|
return self.head.fc
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||||
self.head.reset(num_classes, pool_type=global_pool)
|
self.head.reset(num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user