diff --git a/timm/models/fasternet.py b/timm/models/fasternet.py
index 379c327e..67702b15 100644
--- a/timm/models/fasternet.py
+++ b/timm/models/fasternet.py
@@ -16,7 +16,7 @@ Modifications by / Copyright 2025 Ryan Hou & Ross Wightman, original copyrights
 # Licensed under the MIT License.

 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union

 import torch
 import torch.nn as nn
@@ -118,6 +118,7 @@ class Block(nn.Module):
             merge_size: Union[int, Tuple[int, int]] = 2,
     ):
         super().__init__()
+        self.grad_checkpointing = False
         self.blocks = nn.Sequential(*[
             MLPBlock(
                 dim=dim,
@@ -127,18 +128,22 @@ class Block(nn.Module):
                 layer_scale_init_value=layer_scale_init_value,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
-                pconv_fw_type=pconv_fw_type
+                pconv_fw_type=pconv_fw_type,
             )
             for i in range(depth)
         ])

-        self.down = PatchMerging(
+        self.downsample = PatchMerging(
             dim=dim // 2,
             patch_size=merge_size,
             norm_layer=norm_layer,
         ) if use_merge else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.blocks(self.down(x))
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=True)
+        else:
+            x = self.blocks(x)
         return x

@@ -202,7 +207,6 @@ class FasterNet(nn.Module):
             depths = (depths)  # it means the model has only one stage
         self.num_stages = len(depths)
         self.feature_info = []
-        self.grad_checkpointing = False

         self.patch_embed = PatchEmbed(
             in_chans=in_chans,
@@ -255,20 +259,26 @@ class FasterNet(nn.Module):
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)

+    @torch.jit.ignore
+    def no_weight_decay(self) -> Set:
+        return set()
+
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
-            stem=r'patch_embed',
-            blocks=[
-                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
-                (r'conv_head', (99999,))
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^conv_head', (99999,)),
             ]
         )
         return matcher

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
@@ -339,10 +349,7 @@ class FasterNet(nn.Module):

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.stages, x, flatten=True)
-        else:
-            x = self.stages(x)
+        x = self.stages(x)
         return x

     def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
@@ -371,11 +378,11 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
     }

     stage_mapping = {
-        'stages.1.': 'stages.1.down.',
+        'stages.1.': 'stages.1.downsample.',
         'stages.2.': 'stages.1.',
-        'stages.3.': 'stages.2.down.',
+        'stages.3.': 'stages.2.downsample.',
         'stages.4.': 'stages.2.',
-        'stages.5.': 'stages.3.down.',
+        'stages.5.': 'stages.3.downsample.',
         'stages.6.': 'stages.3.'
     }

diff --git a/timm/models/shvit.py b/timm/models/shvit.py
index 541c8729..497637a4 100644
--- a/timm/models/shvit.py
+++ b/timm/models/shvit.py
@@ -58,7 +58,7 @@ class Conv2d_BN(nn.Sequential):
             stride: int = 1,
             padding: int = 0,
             bn_weight_init: int = 1,
-            **kwargs
+            **kwargs,
     ):
         super().__init__()
         self.add_module('c', nn.Conv2d(
@@ -229,7 +229,8 @@ class StageBlock(nn.Module):
             act_layer: LayerType = nn.ReLU,
     ):
         super().__init__()
-        self.down = nn.Sequential(
+        self.grad_checkpointing = False
+        self.downsample = nn.Sequential(
             Residule(Conv2d_BN(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
             Residule(FFN(prev_dim, int(prev_dim * 2), act_layer)),
             PatchMerging(prev_dim, dim, act_layer),
@@ -237,13 +238,16 @@ class StageBlock(nn.Module):
             Residule(FFN(dim, int(dim * 2), act_layer)),
         ) if prev_dim != dim else nn.Identity()

-        self.block = nn.Sequential(*[
+        self.blocks = nn.Sequential(*[
             BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer)
             for _ in range(depth)
         ])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.down(x)
-        x = self.block(x)
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=True)
+        else:
+            x = self.blocks(x)
         return x

@@ -265,7 +269,6 @@ class SHViT(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        self.grad_checkpointing = False
         self.feature_info = []

         # Patch embedding
@@ -281,10 +284,10 @@ class SHViT(nn.Module):
         )

         # Build SHViT blocks
-        blocks = []
+        stages = []
         prev_chs = stem_chs
         for i in range(len(embed_dim)):
-            blocks.append(StageBlock(
+            stages.append(StageBlock(
                 prev_dim=prev_chs,
                 dim=embed_dim[i],
                 qk_dim=qk_dim[i],
@@ -295,9 +298,9 @@ class SHViT(nn.Module):
                 act_layer=act_layer,
             ))
             prev_chs = embed_dim[i]
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'blocks.{i}'))
-
-        self.blocks = nn.Sequential(*blocks)
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'stages.{i}'))
+        self.stages = nn.Sequential(*stages)
+
         # Classifier head
         self.num_features = self.head_hidden_size = embed_dim[-1]
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -310,12 +313,19 @@ class SHViT(nn.Module):

     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
-        matcher = dict(stem=r'^patch_embed', blocks=[(r'^blocks\.(\d+)', None)])
+        matcher = dict(
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+            ]
+        )
         return matcher

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
@@ -351,14 +361,14 @@ class SHViT(nn.Module):
         """
         assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
         intermediates = []
-        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)

         # forward pass
         x = self.patch_embed(x)
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
-            stages = self.blocks
+            stages = self.stages
         else:
-            stages = self.blocks[:max_index + 1]
+            stages = self.stages[:max_index + 1]

         for feat_idx, stage in enumerate(stages):
             x = stage(x)
@@ -378,18 +388,15 @@ class SHViT(nn.Module):
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
-        self.blocks = self.blocks[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
         if prune_head:
             self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.blocks, x, flatten=True)
-        else:
-            x = self.blocks(x)
+        x = self.stages(x)
         return x

     def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
@@ -424,19 +431,19 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
     out_dict = {}

     replace_rules = [
-        (re.compile(r'^blocks1\.'), 'blocks.0.block.'),
-        (re.compile(r'^blocks2\.'), 'blocks.1.block.'),
-        (re.compile(r'^blocks3\.'), 'blocks.2.block.'),
+        (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
+        (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
+        (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
     ]
     downsample_mapping = {}
     for i in range(1, 3):
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.0\\.'] = f'blocks.{i}.down.0.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.1\\.'] = f'blocks.{i}.down.1.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.1\\.'] = f'blocks.{i}.down.2.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.0\\.'] = f'blocks.{i}.down.3.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.1\\.'] = f'blocks.{i}.down.4.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
         for j in range(3, 10):
-            downsample_mapping[f'^blocks\\.{i}\\.block\\.{j}\\.'] = f'blocks.{i}.block.{j - 3}.'
+            downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'

     downsample_patterns = [
         (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
diff --git a/timm/models/starnet.py b/timm/models/starnet.py
index 93b1e537..bfd5850d 100644
--- a/timm/models/starnet.py
+++ b/timm/models/starnet.py
@@ -34,7 +34,7 @@ class ConvBN(nn.Sequential):
             stride: int = 1,
             padding: int = 0,
             with_bn: bool = True,
-            **kwargs
+            **kwargs,
     ):
         super().__init__()
         self.add_module('conv', nn.Conv2d(
@@ -141,7 +141,10 @@ class StarNet(nn.Module):
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
             stem=r'^stem\.\d+',
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=[
+                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
+                (r'norm', (99999,))
+            ]
         )
         return matcher

@@ -206,7 +209,8 @@ class StarNet(nn.Module):
         if intermediates_only:
             return intermediates

-        x = self.norm(x)
+        if feat_idx == last_idx:
+            x = self.norm(x)

         return x, intermediates

diff --git a/timm/models/swiftformer.py b/timm/models/swiftformer.py
index ec8ae595..d39b8308 100644
--- a/timm/models/swiftformer.py
+++ b/timm/models/swiftformer.py
@@ -402,7 +402,11 @@ class SwiftFormer(nn.Module):
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
             stem=r'^stem',  # stem and embed
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^norm', (99999,)),
+            ]
         )
         return matcher
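
Usage sketch (reviewer note, not part of the patch): a minimal example of exercising the gradient checkpointing toggle that this diff moves from the model level down into each stage for FasterNet and SHViT. It assumes a timm build containing these changes; 'fasternet_t0' is an assumed registered variant name, and the same call applies to the SHViT models changed here.

    import timm
    import torch

    # set_grad_checkpointing() now flips the flag on every stage module
    # instead of a single model-level attribute (assumed model name, see above)
    model = timm.create_model('fasternet_t0', pretrained=False)
    model.set_grad_checkpointing(True)
    model.train()

    out = model(torch.randn(2, 3, 224, 224))
    out.sum().backward()  # per-stage activations are recomputed via checkpoint_seq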