diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 5beb77a2..441ac0c5 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -386,6 +386,31 @@ class ParallelThingsBlock(nn.Module):
         return self._forward(x)
 
 
+def global_pool_nlc(
+        x: torch.Tensor,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+        reduce_include_prefix: bool = False,
+):
+    if not pool_type:
+        return x
+
+    if pool_type == 'token':
+        x = x[:, 0]  # class token
+    else:
+        x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
+        if pool_type == 'avg':
+            x = x.mean(dim=1)
+        elif pool_type == 'avgmax':
+            x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
+        elif pool_type == 'max':
+            x = x.amax(dim=1)
+        else:
+            assert not pool_type, f'Unknown pool type {pool_type}'
+
+    return x
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer
 
@@ -400,7 +425,7 @@ class VisionTransformer(nn.Module):
             patch_size: Union[int, Tuple[int, int]] = 16,
             in_chans: int = 3,
             num_classes: int = 1000,
-            global_pool: Literal['', 'avg', 'max', 'token', 'map'] = 'token',
+            global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
             embed_dim: int = 768,
             depth: int = 12,
             num_heads: int = 12,
@@ -459,10 +484,10 @@ class VisionTransformer(nn.Module):
             block_fn: Transformer block layer.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'max', 'token', 'map')
+        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         assert class_token or global_pool != 'token'
         assert pos_embed in ('', 'none', 'learn')
-        use_fc_norm = global_pool in ['avg', 'max'] if fc_norm is None else fc_norm
+        use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
         act_layer = get_act_layer(act_layer) or nn.GELU
 
@@ -596,10 +621,10 @@ class VisionTransformer(nn.Module):
     def get_classifier(self) -> nn.Module:
         return self.head
 
-    def reset_classifier(self, num_classes: int, global_pool = None) -> None:
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token', 'map')
+            assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
             if global_pool == 'map' and self.attn_pool is None:
                 assert False, "Cannot currently add attention pooling in reset_classifier()."
             elif global_pool != 'map ' and self.attn_pool is not None:
@@ -756,15 +781,16 @@ class VisionTransformer(nn.Module):
             x = self.norm(x)
         return x
 
-    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+    def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
         if self.attn_pool is not None:
             x = self.attn_pool(x)
-        elif self.global_pool == 'avg':
-            x = x[:, self.num_prefix_tokens:].mean(dim=1)
-        elif self.global_pool == 'max':
-            x, _ = torch.max(x[:, self.num_prefix_tokens:], dim=1)
-        elif self.global_pool:
-            x = x[:, 0]  # class token
+            return x
+        pool_type = self.global_pool if pool_type is None else pool_type
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        return x
+
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        x = self.pool(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
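
Below is a minimal usage sketch (not part of the patch) of the new 'avgmax' pool type and the global_pool_nlc() helper, assuming the patch above is applied; the 'vit_base_patch16_224' model name and input sizes are only illustrative.

import torch
import timm
from timm.models.vision_transformer import global_pool_nlc

# ViT whose head pools patch tokens with 'avgmax' = 0.5 * (max + mean),
# excluding prefix (class/register) tokens from the reduction.
model = timm.create_model('vit_base_patch16_224', pretrained=False, global_pool='avgmax')
logits = model(torch.randn(1, 3, 224, 224))                      # (1, 1000)

# The same reduction applied directly to a (B, N, C) token sequence.
tokens = model.forward_features(torch.randn(1, 3, 224, 224))     # (B, prefix + patches, C)
pooled = global_pool_nlc(tokens, pool_type='avgmax', num_prefix_tokens=model.num_prefix_tokens)
print(pooled.shape)                                              # torch.Size([1, 768])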