update group_matcher

Ryan 2025-05-14 08:28:10 +08:00
parent 7fc0692843
commit 89d2952375
4 changed files with 76 additions and 54 deletions

View File

@@ -16,7 +16,7 @@ Modifications by / Copyright 2025 Ryan Hou & Ross Wightman, original copyrights
# Licensed under the MIT License.
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
@@ -118,6 +118,7 @@ class Block(nn.Module):
merge_size: Union[int, Tuple[int, int]] = 2,
):
super().__init__()
self.grad_checkpointing = False
self.blocks = nn.Sequential(*[
MLPBlock(
dim=dim,
@@ -127,18 +128,22 @@ class Block(nn.Module):
layer_scale_init_value=layer_scale_init_value,
norm_layer=norm_layer,
act_layer=act_layer,
pconv_fw_type=pconv_fw_type
pconv_fw_type=pconv_fw_type,
)
for i in range(depth)
])
self.down = PatchMerging(
self.downsample = PatchMerging(
dim=dim // 2,
patch_size=merge_size,
norm_layer=norm_layer,
) if use_merge else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.blocks(self.down(x))
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x, flatten=True)
else:
x = self.blocks(x)
return x
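
Note: the commit moves activation checkpointing from the model level down into each stage, so the flag travels with the stage module itself. Below is a minimal standalone sketch of the same pattern using torch.utils.checkpoint directly rather than timm's checkpoint_seq helper; ToyStage and its dimensions are illustrative only.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyStage(nn.Module):
    # Illustrative stand-in for Block/StageBlock: a downsample followed by a
    # sequence of blocks, with a per-stage grad_checkpointing flag as in the diff.
    def __init__(self, dim: int, depth: int):
        super().__init__()
        self.grad_checkpointing = False
        self.downsample = nn.Identity()
        self.blocks = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Recompute block activations during backward to save memory;
            # timm's checkpoint_seq wraps the same torch.utils.checkpoint machinery.
            for blk in self.blocks:
                x = checkpoint(blk, x, use_reentrant=False)
        else:
            x = self.blocks(x)
        return x


if __name__ == '__main__':
    stage = ToyStage(dim=8, depth=2)
    stage.grad_checkpointing = True
    out = stage(torch.randn(4, 8, requires_grad=True))
    out.sum().backward()  # gradients flow through the checkpointed blocks
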
@@ -202,7 +207,6 @@ class FasterNet(nn.Module):
depths = (depths) # it means the model has only one stage
self.num_stages = len(depths)
self.feature_info = []
self.grad_checkpointing = False
self.patch_embed = PatchEmbed(
in_chans=in_chans,
@@ -255,20 +259,26 @@ class FasterNet(nn.Module):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def no_weight_decay(self) -> Set:
return set()
@torch.jit.ignore
def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
matcher = dict(
stem=r'patch_embed',
blocks=[
(r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
(r'conv_head', (99999,))
stem=r'^patch_embed', # stem and embed
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+).downsample', (0,)),
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^conv_head', (99999,)),
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
self.grad_checkpointing = enable
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self) -> nn.Module:
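
Reading the new matcher entries as (pattern, group-spec) pairs: a None spec derives the group from the regex captures, while an explicit tuple such as (0,) or (99999,) pins the match to a fixed position. The sketch below resolves a few parameter names against the FasterNet rules; it is a simplified stand-in for timm's group_with_matcher, and layer_group is a hypothetical helper name.

import re
from typing import Optional, Tuple

# Fine-grained rules from the FasterNet matcher above: None -> group comes from
# the regex captures; a tuple appends a fixed index (0 keeps the downsample with
# its stage's earliest blocks, 99999 pushes conv_head to the end).
FINE_RULES = [
    (re.compile(r'^stages\.(\d+)\.downsample'), (0,)),
    (re.compile(r'^stages\.(\d+)\.blocks\.(\d+)'), None),
    (re.compile(r'^conv_head'), (99999,)),
]


def layer_group(param_name: str) -> Optional[Tuple[int, ...]]:
    # Simplified stand-in for timm's group_with_matcher resolution (illustration only).
    for pattern, spec in FINE_RULES:
        m = pattern.match(param_name)
        if m is None:
            continue
        captures = tuple(int(g) for g in m.groups())
        return captures + spec if spec is not None else captures
    return None  # unmatched names (e.g. the patch_embed stem) fall through


if __name__ == '__main__':
    for name in ('stages.1.downsample.norm.weight',
                 'stages.1.blocks.3.mlp.fc1.weight',
                 'conv_head.weight'):
        print(f'{name} -> {layer_group(name)}')
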
@@ -339,9 +349,6 @@ class FasterNet(nn.Module):
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x, flatten=True)
else:
x = self.stages(x)
return x
@@ -371,11 +378,11 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
}
stage_mapping = {
'stages.1.': 'stages.1.down.',
'stages.1.': 'stages.1.downsample.',
'stages.2.': 'stages.1.',
'stages.3.': 'stages.2.down.',
'stages.3.': 'stages.2.downsample.',
'stages.4.': 'stages.2.',
'stages.5.': 'stages.3.down.',
'stages.5.': 'stages.3.downsample.',
'stages.6.': 'stages.3.'
}
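
The renamed stage_mapping folds the old alternating merge/stage indices into per-stage downsample submodules. A short sketch of how such a prefix mapping could be applied to checkpoint keys; remap_keys and the toy keys are illustrative, not the actual body of checkpoint_filter_fn.

from typing import Dict

import torch

# Old FasterNet checkpoints interleave merge layers and stages; the new module
# nests each merge layer as `downsample` inside the following stage.
STAGE_MAPPING = {
    'stages.1.': 'stages.1.downsample.',
    'stages.2.': 'stages.1.',
    'stages.3.': 'stages.2.downsample.',
    'stages.4.': 'stages.2.',
    'stages.5.': 'stages.3.downsample.',
    'stages.6.': 'stages.3.',
}


def remap_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Rewrite key prefixes according to STAGE_MAPPING (first match wins).
    # Illustrative helper, not the real checkpoint_filter_fn.
    out = {}
    for k, v in state_dict.items():
        for old, new in STAGE_MAPPING.items():
            if k.startswith(old):
                k = new + k[len(old):]
                break
        out[k] = v
    return out


if __name__ == '__main__':
    toy = {'stages.3.norm.weight': torch.zeros(1), 'stages.4.blocks.0.weight': torch.zeros(1)}
    print(list(remap_keys(toy)))
    # -> ['stages.2.downsample.norm.weight', 'stages.2.blocks.0.weight']
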

View File

@@ -58,7 +58,7 @@ class Conv2d_BN(nn.Sequential):
stride: int = 1,
padding: int = 0,
bn_weight_init: int = 1,
**kwargs
**kwargs,
):
super().__init__()
self.add_module('c', nn.Conv2d(
@@ -229,7 +229,8 @@ class StageBlock(nn.Module):
act_layer: LayerType = nn.ReLU,
):
super().__init__()
self.down = nn.Sequential(
self.grad_checkpointing = False
self.downsample = nn.Sequential(
Residule(Conv2d_BN(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
Residule(FFN(prev_dim, int(prev_dim * 2), act_layer)),
PatchMerging(prev_dim, dim, act_layer),
@@ -237,13 +238,16 @@ class StageBlock(nn.Module):
Residule(FFN(dim, int(dim * 2), act_layer)),
) if prev_dim != dim else nn.Identity()
self.block = nn.Sequential(*[
self.blocks = nn.Sequential(*[
BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer) for _ in range(depth)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.down(x)
x = self.block(x)
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x, flatten=True)
else:
x = self.blocks(x)
return x
@@ -265,7 +269,6 @@ class SHViT(nn.Module):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.feature_info = []
# Patch embedding
@@ -281,10 +284,10 @@
)
# Build SHViT blocks
blocks = []
stages = []
prev_chs = stem_chs
for i in range(len(embed_dim)):
blocks.append(StageBlock(
stages.append(StageBlock(
prev_dim=prev_chs,
dim=embed_dim[i],
qk_dim=qk_dim[i],
@@ -295,9 +298,9 @@
act_layer=act_layer,
))
prev_chs = embed_dim[i]
self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'blocks.{i}'))
self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'stages.{i}'))
self.stages = nn.Sequential(*stages)
self.blocks = nn.Sequential(*blocks)
# Classifier head
self.num_features = self.head_hidden_size = embed_dim[-1]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -310,12 +313,19 @@
@torch.jit.ignore
def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
matcher = dict(stem=r'^patch_embed', blocks=[(r'^blocks\.(\d+)', None)])
matcher = dict(
stem=r'^patch_embed', # stem and embed
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+).downsample', (0,)),
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
self.grad_checkpointing = enable
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self) -> nn.Module:
@@ -351,14 +361,14 @@ class SHViT(nn.Module):
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.patch_embed(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.blocks
stages = self.stages
else:
stages = self.blocks[:max_index + 1]
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
@@ -378,18 +388,15 @@
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x, flatten=True)
else:
x = self.blocks(x)
x = self.stages(x)
return x
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
@@ -424,19 +431,19 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
out_dict = {}
replace_rules = [
(re.compile(r'^blocks1\.'), 'blocks.0.block.'),
(re.compile(r'^blocks2\.'), 'blocks.1.block.'),
(re.compile(r'^blocks3\.'), 'blocks.2.block.'),
(re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
(re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
(re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
]
downsample_mapping = {}
for i in range(1, 3):
downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.0\\.'] = f'blocks.{i}.down.0.'
downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.1\\.'] = f'blocks.{i}.down.1.'
downsample_mapping[f'^blocks\\.{i}\\.block\\.1\\.'] = f'blocks.{i}.down.2.'
downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.0\\.'] = f'blocks.{i}.down.3.'
downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.1\\.'] = f'blocks.{i}.down.4.'
downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
for j in range(3, 10):
downsample_mapping[f'^blocks\\.{i}\\.block\\.{j}\\.'] = f'blocks.{i}.block.{j - 3}.'
downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'
downsample_patterns = [
(re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
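
Together, replace_rules and downsample_patterns above rewrite legacy SHViT checkpoint keys into the new stages.N.downsample / stages.N.blocks layout. Below is a self-contained sketch of one way to apply them, with first-match-wins on the downsample pass; the remap helper, its application order, and the sample key are assumptions rather than the actual checkpoint_filter_fn.

import re
from typing import Dict, List, Tuple

import torch

# First pass: legacy flat block containers become per-stage `blocks` modules.
REPLACE_RULES: List[Tuple[re.Pattern, str]] = [
    (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
    (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
    (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
]

# Second pass: the first few entries of stages 1 and 2 are the downsample path,
# so they move under `downsample` and the remaining block indices shift down by three.
DOWNSAMPLE_PATTERNS: List[Tuple[re.Pattern, str]] = []
for i in range(1, 3):
    DOWNSAMPLE_PATTERNS += [
        (re.compile(fr'^stages\.{i}\.blocks\.0\.0\.'), f'stages.{i}.downsample.0.'),
        (re.compile(fr'^stages\.{i}\.blocks\.0\.1\.'), f'stages.{i}.downsample.1.'),
        (re.compile(fr'^stages\.{i}\.blocks\.1\.'), f'stages.{i}.downsample.2.'),
        (re.compile(fr'^stages\.{i}\.blocks\.2\.0\.'), f'stages.{i}.downsample.3.'),
        (re.compile(fr'^stages\.{i}\.blocks\.2\.1\.'), f'stages.{i}.downsample.4.'),
    ]
    for j in range(3, 10):
        DOWNSAMPLE_PATTERNS.append(
            (re.compile(fr'^stages\.{i}\.blocks\.{j}\.'), f'stages.{i}.blocks.{j - 3}.'))


def remap(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Illustrative application loop, not the full checkpoint_filter_fn.
    out = {}
    for k, v in state_dict.items():
        for pat, repl in REPLACE_RULES:
            k = pat.sub(repl, k)
        for pat, repl in DOWNSAMPLE_PATTERNS:
            if pat.match(k):
                k = pat.sub(repl, k)
                break  # first matching rule wins
        out[k] = v
    return out


if __name__ == '__main__':
    print(list(remap({'blocks2.4.mixer.qkv.weight': torch.zeros(1)})))
    # -> ['stages.1.blocks.1.mixer.qkv.weight']
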

View File

@@ -34,7 +34,7 @@ class ConvBN(nn.Sequential):
stride: int = 1,
padding: int = 0,
with_bn: bool = True,
**kwargs
**kwargs,
):
super().__init__()
self.add_module('conv', nn.Conv2d(
@@ -141,7 +141,10 @@ class StarNet(nn.Module):
def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
matcher = dict(
stem=r'^stem\.\d+',
blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
blocks=[
(r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
(r'norm', (99999,))
]
)
return matcher
@@ -206,6 +209,7 @@ class StarNet(nn.Module):
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates

View File

@@ -402,7 +402,11 @@ class SwiftFormer(nn.Module):
def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
matcher = dict(
stem=r'^stem', # stem and embed
blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+).downsample', (0,)),
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^norm', (99999,)),
]
)
return matcher