update group_matcher

Author: Ryan, 2025-05-14 08:28:10 +08:00
parent 7fc0692843
commit 89d2952375
4 changed files with 76 additions and 54 deletions

File 1 of 4 (FasterNet)

@@ -16,7 +16,7 @@ Modifications by / Copyright 2025 Ryan Hou & Ross Wightman, original copyrights
 # Licensed under the MIT License.
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union

 import torch
 import torch.nn as nn
@@ -118,6 +118,7 @@ class Block(nn.Module):
             merge_size: Union[int, Tuple[int, int]] = 2,
     ):
         super().__init__()
+        self.grad_checkpointing = False
         self.blocks = nn.Sequential(*[
             MLPBlock(
                 dim=dim,
@@ -127,18 +128,22 @@ class Block(nn.Module):
                 layer_scale_init_value=layer_scale_init_value,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
-                pconv_fw_type=pconv_fw_type
+                pconv_fw_type=pconv_fw_type,
             )
             for i in range(depth)
         ])
-        self.down = PatchMerging(
+        self.downsample = PatchMerging(
             dim=dim // 2,
             patch_size=merge_size,
             norm_layer=norm_layer,
         ) if use_merge else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.blocks(self.down(x))
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=True)
+        else:
+            x = self.blocks(x)
         return x
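
Note: each stage now carries its own grad_checkpointing flag and wraps only its own blocks in checkpoint_seq, instead of checkpointing the whole stage stack from the parent model. A minimal standalone sketch of the same pattern, using torch.utils.checkpoint directly in place of timm's checkpoint_seq helper (TinyStage and its sizes are illustrative only):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class TinyStage(nn.Module):
        # Stage that owns its own grad_checkpointing flag, mirroring Block above.
        def __init__(self, dim: int, depth: int):
            super().__init__()
            self.grad_checkpointing = False
            self.blocks = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(depth)])

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Recompute each block's activations during backward to save memory.
                for blk in self.blocks:
                    x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = self.blocks(x)
            return x

    stage = TinyStage(dim=8, depth=2)
    stage.grad_checkpointing = True
    out = stage(torch.randn(4, 8, requires_grad=True))
    out.sum().backward()
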
@@ -202,7 +207,6 @@ class FasterNet(nn.Module):
             depths = (depths)  # it means the model has only one stage
         self.num_stages = len(depths)
         self.feature_info = []
-        self.grad_checkpointing = False

         self.patch_embed = PatchEmbed(
             in_chans=in_chans,
@@ -255,20 +259,26 @@ class FasterNet(nn.Module):
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)

+    @torch.jit.ignore
+    def no_weight_decay(self) -> Set:
+        return set()
+
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
-            stem=r'patch_embed',
-            blocks=[
-                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
-                (r'conv_head', (99999,))
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^conv_head', (99999,)),
             ]
         )
         return matcher

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
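
The fine-grained matcher maps parameter-name regexes to ordering hints: stem parameters come first, each (stage, block) pair forms its own group, a stage's downsample is grouped with the start of that stage via the explicit (0,) index, and (99999,) effectively pushes conv_head into the final group. A simplified, self-contained illustration of that bucketing (plain re, not timm's actual grouping code; the parameter names are hypothetical):

    import re

    patterns = [
        (re.compile(r'^patch_embed'), lambda m: ('stem',)),
        (re.compile(r'^stages\.(\d+).downsample'), lambda m: (int(m.group(1)), 0)),
        (re.compile(r'^stages\.(\d+)\.blocks\.(\d+)'), lambda m: (int(m.group(1)), int(m.group(2)))),
        (re.compile(r'^conv_head'), lambda m: ('head',)),
    ]

    def group_of(param_name):
        # Return the key of the first matching pattern; unmatched params fall into 'other'.
        for pat, key_fn in patterns:
            m = pat.match(param_name)
            if m:
                return key_fn(m)
        return ('other',)

    for name in [
        'patch_embed.proj.weight',          # hypothetical parameter names
        'stages.1.downsample.norm.weight',
        'stages.1.blocks.2.mlp.0.weight',
        'conv_head.weight',
    ]:
        print(name, '->', group_of(name))
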
@@ -339,9 +349,6 @@ class FasterNet(nn.Module):
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.stages, x, flatten=True)
-        else:
-            x = self.stages(x)
+        x = self.stages(x)
         return x
@@ -371,11 +378,11 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
     }
     stage_mapping = {
-        'stages.1.': 'stages.1.down.',
+        'stages.1.': 'stages.1.downsample.',
         'stages.2.': 'stages.1.',
-        'stages.3.': 'stages.2.down.',
+        'stages.3.': 'stages.2.downsample.',
         'stages.4.': 'stages.2.',
-        'stages.5.': 'stages.3.down.',
+        'stages.5.': 'stages.3.downsample.',
         'stages.6.': 'stages.3.'
     }
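
The original FasterNet checkpoints store the merging layers as standalone stages (indices 1, 3, 5), so the updated mapping folds each of them into the following stage's downsample submodule while renumbering the block stages. A small sketch of the intended prefix rewrite (remap() and the key names are illustrative, not the actual checkpoint_filter_fn):

    stage_mapping = {
        'stages.1.': 'stages.1.downsample.',
        'stages.2.': 'stages.1.',
        'stages.3.': 'stages.2.downsample.',
        'stages.4.': 'stages.2.',
        'stages.5.': 'stages.3.downsample.',
        'stages.6.': 'stages.3.',
    }

    def remap(key: str) -> str:
        # Rewrite the first matching prefix; keys that match nothing pass through.
        for old_prefix, new_prefix in stage_mapping.items():
            if key.startswith(old_prefix):
                return new_prefix + key[len(old_prefix):]
        return key

    print(remap('stages.1.norm.weight'))     # -> stages.1.downsample.norm.weight
    print(remap('stages.2.0.mlp.0.weight'))  # -> stages.1.0.mlp.0.weight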

File 2 of 4 (SHViT)

@@ -58,7 +58,7 @@ class Conv2d_BN(nn.Sequential):
             stride: int = 1,
             padding: int = 0,
             bn_weight_init: int = 1,
-            **kwargs
+            **kwargs,
     ):
         super().__init__()
         self.add_module('c', nn.Conv2d(
@@ -229,7 +229,8 @@ class StageBlock(nn.Module):
             act_layer: LayerType = nn.ReLU,
     ):
         super().__init__()
-        self.down = nn.Sequential(
+        self.grad_checkpointing = False
+        self.downsample = nn.Sequential(
             Residule(Conv2d_BN(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
             Residule(FFN(prev_dim, int(prev_dim * 2), act_layer)),
             PatchMerging(prev_dim, dim, act_layer),
@@ -237,13 +238,16 @@ class StageBlock(nn.Module):
             Residule(FFN(dim, int(dim * 2), act_layer)),
         ) if prev_dim != dim else nn.Identity()

-        self.block = nn.Sequential(*[
+        self.blocks = nn.Sequential(*[
             BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer) for _ in range(depth)
         ])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.down(x)
-        x = self.block(x)
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=True)
+        else:
+            x = self.blocks(x)
         return x
@@ -265,7 +269,6 @@ class SHViT(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        self.grad_checkpointing = False
         self.feature_info = []

         # Patch embedding
@@ -281,10 +284,10 @@ class SHViT(nn.Module):
         )

         # Build SHViT blocks
-        blocks = []
+        stages = []
         prev_chs = stem_chs
         for i in range(len(embed_dim)):
-            blocks.append(StageBlock(
+            stages.append(StageBlock(
                 prev_dim=prev_chs,
                 dim=embed_dim[i],
                 qk_dim=qk_dim[i],
@@ -295,9 +298,9 @@ class SHViT(nn.Module):
                 act_layer=act_layer,
             ))
             prev_chs = embed_dim[i]
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'blocks.{i}'))
-        self.blocks = nn.Sequential(*blocks)
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'stages.{i}'))
+        self.stages = nn.Sequential(*stages)

         # Classifier head
         self.num_features = self.head_hidden_size = embed_dim[-1]
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -310,12 +313,19 @@ class SHViT(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
-        matcher = dict(stem=r'^patch_embed', blocks=[(r'^blocks\.(\d+)', None)])
+        matcher = dict(
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+            ]
+        )
         return matcher

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
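
To see what the updated matcher returns in each mode, here is the method re-created standalone (illustrative, detached from the class):

    def group_matcher(coarse: bool = False):
        # Coarse mode collapses each stage to one group; fine mode pins a stage's
        # downsample to position 0 and gives every (stage, block) pair its own group.
        return dict(
            stem=r'^patch_embed',
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+).downsample', (0,)),
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
            ],
        )

    print(group_matcher(coarse=True)['blocks'])   # a single regex string
    print(group_matcher(coarse=False)['blocks'])  # an ordered list of (regex, index-hint) pairs
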
@@ -351,14 +361,14 @@ class SHViT(nn.Module):
         """
         assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
         intermediates = []
-        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)

         # forward pass
         x = self.patch_embed(x)
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
-            stages = self.blocks
+            stages = self.stages
         else:
-            stages = self.blocks[:max_index + 1]
+            stages = self.stages[:max_index + 1]

         for feat_idx, stage in enumerate(stages):
             x = stage(x)
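
forward_intermediates now indexes stages rather than the removed flat blocks container. feature_take_indices returns which stage outputs to collect plus the highest stage index needed, which is what allows the early-exit slice above. A simplified stand-in showing the assumed semantics (not timm's actual implementation):

    from typing import List, Tuple, Union

    def take_indices_demo(num_stages: int, indices: Union[int, List[int]]) -> Tuple[List[int], int]:
        # An int n is read as "the last n stages"; negative list entries count from the end.
        if isinstance(indices, int):
            take = list(range(num_stages - indices, num_stages))
        else:
            take = [i if i >= 0 else num_stages + i for i in indices]
        return take, max(take)

    print(take_indices_demo(3, 2))    # ([1, 2], 2): collect the last two stage outputs
    print(take_indices_demo(3, [0]))  # ([0], 0): only stage 0 is needed, later stages can be skipped
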
@@ -378,18 +388,15 @@ class SHViT(nn.Module):
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
-        self.blocks = self.blocks[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
         if prune_head:
             self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.blocks, x, flatten=True)
-        else:
-            x = self.blocks(x)
+        x = self.stages(x)
         return x

     def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
@@ -424,19 +431,19 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
     out_dict = {}
     replace_rules = [
-        (re.compile(r'^blocks1\.'), 'blocks.0.block.'),
-        (re.compile(r'^blocks2\.'), 'blocks.1.block.'),
-        (re.compile(r'^blocks3\.'), 'blocks.2.block.'),
+        (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
+        (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
+        (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
     ]
     downsample_mapping = {}
     for i in range(1, 3):
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.0\\.'] = f'blocks.{i}.down.0.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.1\\.'] = f'blocks.{i}.down.1.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.1\\.'] = f'blocks.{i}.down.2.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.0\\.'] = f'blocks.{i}.down.3.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.1\\.'] = f'blocks.{i}.down.4.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
         for j in range(3, 10):
-            downsample_mapping[f'^blocks\\.{i}\\.block\\.{j}\\.'] = f'blocks.{i}.block.{j - 3}.'
+            downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'

     downsample_patterns = [
         (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
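
The remap is two-pass: the original flat blocks1/blocks2/blocks3 prefixes are first moved under stages.N.blocks, then the leading merge layers of stages 1 and 2 are peeled off into downsample. A short illustration on a hypothetical key (the snippet is illustrative, not the actual filter function):

    import re

    pass1 = (re.compile(r'^blocks2\.'), 'stages.1.blocks.')                        # rename the flat prefix
    pass2 = (re.compile(r'^stages\.1\.blocks\.0\.0\.'), 'stages.1.downsample.0.')  # peel off the merge layer

    key = 'blocks2.0.0.c.weight'        # hypothetical original checkpoint key
    key = pass1[0].sub(pass1[1], key)   # -> 'stages.1.blocks.0.0.c.weight'
    key = pass2[0].sub(pass2[1], key)   # -> 'stages.1.downsample.0.c.weight'
    print(key)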

File 3 of 4 (StarNet)

@@ -34,7 +34,7 @@ class ConvBN(nn.Sequential):
             stride: int = 1,
             padding: int = 0,
             with_bn: bool = True,
-            **kwargs
+            **kwargs,
     ):
         super().__init__()
         self.add_module('conv', nn.Conv2d(
@@ -141,7 +141,10 @@ class StarNet(nn.Module):
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
             stem=r'^stem\.\d+',
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=[
+                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
+                (r'norm', (99999,))
+            ]
         )
         return matcher
@@ -206,6 +209,7 @@ class StarNet(nn.Module):
         if intermediates_only:
             return intermediates

-        x = self.norm(x)
+        if feat_idx == last_idx:
+            x = self.norm(x)

         return x, intermediates
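
The added guard means the shared output norm is only applied when the final stage actually ran, so an early-stopped forward_intermediates call returns raw stage features instead of normalizing an intermediate stage's output. A small sketch of the guarded pattern (the modules here are placeholders, not StarNet's real layers):

    import torch
    import torch.nn as nn

    stages = nn.ModuleList([nn.Identity() for _ in range(3)])
    norm = nn.LayerNorm(8)                 # placeholder for the model's output norm
    take_indices, max_index = [0, 1], 1    # e.g. only the first two stages requested
    last_idx = len(stages) - 1

    x = torch.randn(2, 8)
    intermediates = []
    for feat_idx, stage in enumerate(stages[:max_index + 1]):
        x = stage(x)
        if feat_idx in take_indices:
            intermediates.append(x)
    if feat_idx == last_idx:               # skipped here: the loop stopped at stage 1
        x = norm(x)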

File 4 of 4 (SwiftFormer)

@@ -402,7 +402,11 @@ class SwiftFormer(nn.Module):
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
             stem=r'^stem',  # stem and embed
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^norm', (99999,)),
+            ]
         )
         return matcher