mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Refactor vit pooling to add more reduction options, separately callable
This commit is contained in:
parent
9567cf6d84
commit
71101ebba0
@ -386,6 +386,31 @@ class ParallelThingsBlock(nn.Module):
|
|||||||
return self._forward(x)
|
return self._forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
def global_pool_nlc(
|
||||||
|
x: torch.Tensor,
|
||||||
|
pool_type: str = 'token',
|
||||||
|
num_prefix_tokens: int = 1,
|
||||||
|
reduce_include_prefix: bool = False,
|
||||||
|
):
|
||||||
|
if not pool_type:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if pool_type == 'token':
|
||||||
|
x = x[:, 0] # class token
|
||||||
|
else:
|
||||||
|
x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
|
||||||
|
if pool_type == 'avg':
|
||||||
|
x = x.mean(dim=1)
|
||||||
|
elif pool_type == 'avgmax':
|
||||||
|
x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
|
||||||
|
elif pool_type == 'max':
|
||||||
|
x = x.amax(dim=1)
|
||||||
|
else:
|
||||||
|
assert not pool_type, f'Unknown pool type {pool_type}'
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class VisionTransformer(nn.Module):
|
class VisionTransformer(nn.Module):
|
||||||
""" Vision Transformer
|
""" Vision Transformer
|
||||||
|
|
||||||
@ -400,7 +425,7 @@ class VisionTransformer(nn.Module):
|
|||||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||||
in_chans: int = 3,
|
in_chans: int = 3,
|
||||||
num_classes: int = 1000,
|
num_classes: int = 1000,
|
||||||
global_pool: Literal['', 'avg', 'max', 'token', 'map'] = 'token',
|
global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
|
||||||
embed_dim: int = 768,
|
embed_dim: int = 768,
|
||||||
depth: int = 12,
|
depth: int = 12,
|
||||||
num_heads: int = 12,
|
num_heads: int = 12,
|
||||||
@ -459,10 +484,10 @@ class VisionTransformer(nn.Module):
|
|||||||
block_fn: Transformer block layer.
|
block_fn: Transformer block layer.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert global_pool in ('', 'avg', 'max', 'token', 'map')
|
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
|
||||||
assert class_token or global_pool != 'token'
|
assert class_token or global_pool != 'token'
|
||||||
assert pos_embed in ('', 'none', 'learn')
|
assert pos_embed in ('', 'none', 'learn')
|
||||||
use_fc_norm = global_pool in ['avg', 'max'] if fc_norm is None else fc_norm
|
use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
|
||||||
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
|
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
|
||||||
act_layer = get_act_layer(act_layer) or nn.GELU
|
act_layer = get_act_layer(act_layer) or nn.GELU
|
||||||
|
|
||||||
@ -596,10 +621,10 @@ class VisionTransformer(nn.Module):
|
|||||||
def get_classifier(self) -> nn.Module:
|
def get_classifier(self) -> nn.Module:
|
||||||
return self.head
|
return self.head
|
||||||
|
|
||||||
def reset_classifier(self, num_classes: int, global_pool = None) -> None:
|
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
if global_pool is not None:
|
if global_pool is not None:
|
||||||
assert global_pool in ('', 'avg', 'token', 'map')
|
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
|
||||||
if global_pool == 'map' and self.attn_pool is None:
|
if global_pool == 'map' and self.attn_pool is None:
|
||||||
assert False, "Cannot currently add attention pooling in reset_classifier()."
|
assert False, "Cannot currently add attention pooling in reset_classifier()."
|
||||||
elif global_pool != 'map ' and self.attn_pool is not None:
|
elif global_pool != 'map ' and self.attn_pool is not None:
|
||||||
@ -756,15 +781,16 @@ class VisionTransformer(nn.Module):
|
|||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
|
||||||
if self.attn_pool is not None:
|
if self.attn_pool is not None:
|
||||||
x = self.attn_pool(x)
|
x = self.attn_pool(x)
|
||||||
elif self.global_pool == 'avg':
|
return x
|
||||||
x = x[:, self.num_prefix_tokens:].mean(dim=1)
|
pool_type = self.global_pool if pool_type is None else pool_type
|
||||||
elif self.global_pool == 'max':
|
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
|
||||||
x, _ = torch.max(x[:, self.num_prefix_tokens:], dim=1)
|
return x
|
||||||
elif self.global_pool:
|
|
||||||
x = x[:, 0] # class token
|
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
||||||
|
x = self.pool(x)
|
||||||
x = self.fc_norm(x)
|
x = self.fc_norm(x)
|
||||||
x = self.head_drop(x)
|
x = self.head_drop(x)
|
||||||
return x if pre_logits else self.head(x)
|
return x if pre_logits else self.head(x)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user