mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

update group_matcher

commit 89d2952375, parent 7fc0692843
@@ -16,7 +16,7 @@ Modifications by / Copyright 2025 Ryan Hou & Ross Wightman, original copyrights
 # Licensed under the MIT License.

 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union

 import torch
 import torch.nn as nn
@@ -118,6 +118,7 @@ class Block(nn.Module):
             merge_size: Union[int, Tuple[int, int]] = 2,
     ):
         super().__init__()
+        self.grad_checkpointing = False
         self.blocks = nn.Sequential(*[
             MLPBlock(
                 dim=dim,
@@ -127,18 +128,22 @@ class Block(nn.Module):
                 layer_scale_init_value=layer_scale_init_value,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
-                pconv_fw_type=pconv_fw_type
+                pconv_fw_type=pconv_fw_type,
             )
             for i in range(depth)
         ])
-        self.down = PatchMerging(
+        self.downsample = PatchMerging(
             dim=dim // 2,
             patch_size=merge_size,
             norm_layer=norm_layer,
         ) if use_merge else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.blocks(self.down(x))
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=True)
+        else:
+            x = self.blocks(x)
         return x

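Note on the new forward path: when a stage's grad_checkpointing flag is set, the blocks run under activation checkpointing instead of a plain sequential call. Below is a minimal sketch of that pattern using torch.utils.checkpoint directly rather than timm's checkpoint_seq helper; TinyStage and its sizes are made up for illustration.

# Minimal sketch of the checkpointed-stage pattern (illustration only,
# not timm's checkpoint_seq): re-run each block during backward instead
# of storing its activations.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyStage(nn.Module):
    def __init__(self, dim: int = 32, depth: int = 2):
        super().__init__()
        self.grad_checkpointing = False  # flipped by set_grad_checkpointing()
        self.blocks = nn.Sequential(*[
            nn.Sequential(nn.Conv2d(dim, dim, 3, padding=1), nn.GELU())
            for _ in range(depth)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for blk in self.blocks:  # checkpoint block-by-block
                x = checkpoint(blk, x, use_reentrant=False)
        else:
            x = self.blocks(x)
        return x


stage = TinyStage()
stage.grad_checkpointing = True
out = stage(torch.randn(1, 32, 8, 8, requires_grad=True))
out.sum().backward()  # block activations are recomputed here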
@@ -202,7 +207,6 @@ class FasterNet(nn.Module):
             depths = (depths)  # it means the model has only one stage
         self.num_stages = len(depths)
         self.feature_info = []
-        self.grad_checkpointing = False

         self.patch_embed = PatchEmbed(
             in_chans=in_chans,
@@ -255,20 +259,26 @@
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)

+    @torch.jit.ignore
+    def no_weight_decay(self) -> Set:
+        return set()
+
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
-            stem=r'patch_embed',
-            blocks=[
-                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
-                (r'conv_head', (99999,))
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^conv_head', (99999,)),
             ]
         )
         return matcher

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
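For reference, the added patterns bucket parameter names by stage and block so layer-wise LR decay can assign per-group multipliers. The snippet below approximates that grouping with plain re; it is not timm's internal grouping helper, and the parameter names are only examples.

# Rough illustration of how the new FasterNet group_matcher buckets
# parameter names (simplified; timm's real grouping machinery does more).
import re

matcher = dict(
    stem=r'^patch_embed',
    blocks=[
        (r'^stages\.(\d+).downsample', (0,)),
        (r'^stages\.(\d+)\.blocks\.(\d+)', None),
        (r'^conv_head', (99999,)),
    ],
)


def group_key(name: str):
    if re.match(matcher['stem'], name):
        return ('stem',)
    for pattern, suffix in matcher['blocks']:
        m = re.match(pattern, name)
        if m:
            # captured stage/block indices build the key; an explicit suffix
            # orders downsamples first within a stage and pushes conv_head last
            return m.groups() + (suffix if suffix is not None else ())
    return ('other',)


for n in ('patch_embed.proj.weight', 'stages.1.downsample.norm.weight',
          'stages.1.blocks.3.mlp.0.weight', 'conv_head.weight'):
    print(n, '->', group_key(n))
# patch_embed.proj.weight -> ('stem',)
# stages.1.downsample.norm.weight -> ('1', 0)
# stages.1.blocks.3.mlp.0.weight -> ('1', '3')
# conv_head.weight -> (99999,)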
@@ -339,9 +349,6 @@ class FasterNet(nn.Module):

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.stages, x, flatten=True)
-        else:
-            x = self.stages(x)
+        x = self.stages(x)
         return x

@@ -371,11 +378,11 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
     }

     stage_mapping = {
-        'stages.1.': 'stages.1.down.',
+        'stages.1.': 'stages.1.downsample.',
         'stages.2.': 'stages.1.',
-        'stages.3.': 'stages.2.down.',
+        'stages.3.': 'stages.2.downsample.',
         'stages.4.': 'stages.2.',
-        'stages.5.': 'stages.3.down.',
+        'stages.5.': 'stages.3.downsample.',
         'stages.6.': 'stages.3.'
     }

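The renamed stage_mapping values mean original FasterNet checkpoints (which interleave the merge layers as separate stages) now land on stages.N.downsample instead of stages.N.down. A hedged sketch of how such a prefix table is applied to checkpoint keys; the helper name and sample keys are made up, not the exact filter function.

# Hedged sketch: apply an old-prefix -> new-prefix table to checkpoint keys,
# in the spirit of the stage_mapping above (illustrative only).
from typing import Dict
import torch


def remap_prefixes(state_dict: Dict[str, torch.Tensor], mapping: Dict[str, str]) -> Dict[str, torch.Tensor]:
    out = {}
    for k, v in state_dict.items():
        for old, new in mapping.items():
            if k.startswith(old):
                k = new + k[len(old):]
                break  # first matching prefix in table order wins
        out[k] = v
    return out


stage_mapping = {
    'stages.1.': 'stages.1.downsample.',
    'stages.2.': 'stages.1.',
    'stages.3.': 'stages.2.downsample.',
    'stages.4.': 'stages.2.',
    'stages.5.': 'stages.3.downsample.',
    'stages.6.': 'stages.3.',
}

sd = {'stages.1.reduction.weight': torch.zeros(1), 'stages.2.0.mlp.fc1.weight': torch.zeros(1)}
print(list(remap_prefixes(sd, stage_mapping)))
# ['stages.1.downsample.reduction.weight', 'stages.1.0.mlp.fc1.weight']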
@@ -58,7 +58,7 @@ class Conv2d_BN(nn.Sequential):
             stride: int = 1,
             padding: int = 0,
             bn_weight_init: int = 1,
-            **kwargs
+            **kwargs,
     ):
         super().__init__()
         self.add_module('c', nn.Conv2d(
@@ -229,7 +229,8 @@ class StageBlock(nn.Module):
             act_layer: LayerType = nn.ReLU,
     ):
         super().__init__()
-        self.down = nn.Sequential(
+        self.grad_checkpointing = False
+        self.downsample = nn.Sequential(
             Residule(Conv2d_BN(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
             Residule(FFN(prev_dim, int(prev_dim * 2), act_layer)),
             PatchMerging(prev_dim, dim, act_layer),
@@ -237,13 +238,16 @@ class StageBlock(nn.Module):
             Residule(FFN(dim, int(dim * 2), act_layer)),
         ) if prev_dim != dim else nn.Identity()

-        self.block = nn.Sequential(*[
+        self.blocks = nn.Sequential(*[
             BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer) for _ in range(depth)
         ])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.down(x)
-        x = self.block(x)
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=True)
+        else:
+            x = self.blocks(x)
         return x


@@ -265,7 +269,6 @@ class SHViT(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        self.grad_checkpointing = False
         self.feature_info = []

         # Patch embedding
@@ -281,10 +284,10 @@ class SHViT(nn.Module):
         )

         # Build SHViT blocks
-        blocks = []
+        stages = []
         prev_chs = stem_chs
         for i in range(len(embed_dim)):
-            blocks.append(StageBlock(
+            stages.append(StageBlock(
                 prev_dim=prev_chs,
                 dim=embed_dim[i],
                 qk_dim=qk_dim[i],
@@ -295,9 +298,9 @@ class SHViT(nn.Module):
                 act_layer=act_layer,
             ))
             prev_chs = embed_dim[i]
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'blocks.{i}'))
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'stages.{i}'))
+        self.stages = nn.Sequential(*stages)

-        self.blocks = nn.Sequential(*blocks)
         # Classifier head
         self.num_features = self.head_hidden_size = embed_dim[-1]
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -310,12 +313,19 @@ class SHViT(nn.Module):

     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
-        matcher = dict(stem=r'^patch_embed', blocks=[(r'^blocks\.(\d+)', None)])
+        matcher = dict(
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+            ]
+        )
         return matcher

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
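Since set_grad_checkpointing now just forwards the flag to each stage, the usual timm toggle keeps working end to end. A hedged usage sketch follows; 'shvit_s3' is an assumed registered model name, and any SHViT or FasterNet variant should behave the same way.

# Hedged usage sketch: enabling gradient checkpointing through the
# standard timm entry point; the model name is an assumption.
import timm
import torch

model = timm.create_model('shvit_s3', pretrained=False)
model.set_grad_checkpointing(True)  # now flips grad_checkpointing on each stage

x = torch.randn(2, 3, 224, 224, requires_grad=True)
loss = model(x).sum()
loss.backward()  # stage blocks are recomputed during backward, saving memory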
@@ -351,14 +361,14 @@ class SHViT(nn.Module):
         """
         assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
         intermediates = []
-        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)

         # forward pass
         x = self.patch_embed(x)
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
-            stages = self.blocks
+            stages = self.stages
         else:
-            stages = self.blocks[:max_index + 1]
+            stages = self.stages[:max_index + 1]

         for feat_idx, stage in enumerate(stages):
             x = stage(x)
@@ -378,18 +388,15 @@ class SHViT(nn.Module):
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
-        self.blocks = self.blocks[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
         if prune_head:
             self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.blocks, x, flatten=True)
-        else:
-            x = self.blocks(x)
+        x = self.stages(x)
         return x

     def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
@@ -424,19 +431,19 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
     out_dict = {}

     replace_rules = [
-        (re.compile(r'^blocks1\.'), 'blocks.0.block.'),
-        (re.compile(r'^blocks2\.'), 'blocks.1.block.'),
-        (re.compile(r'^blocks3\.'), 'blocks.2.block.'),
+        (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
+        (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
+        (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
     ]
     downsample_mapping = {}
     for i in range(1, 3):
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.0\\.'] = f'blocks.{i}.down.0.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.1\\.'] = f'blocks.{i}.down.1.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.1\\.'] = f'blocks.{i}.down.2.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.0\\.'] = f'blocks.{i}.down.3.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.1\\.'] = f'blocks.{i}.down.4.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
         for j in range(3, 10):
-            downsample_mapping[f'^blocks\\.{i}\\.block\\.{j}\\.'] = f'blocks.{i}.block.{j - 3}.'
+            downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'

     downsample_patterns = [
         (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
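The SHViT filter now rewrites original checkpoint keys straight into the stages.N.blocks / stages.N.downsample layout. The general pattern is a chain of compiled-regex substitutions over each key; a minimal, hedged sketch of that idea is below (illustrative keys only, not the full filter with its downsample pass).

# Hedged sketch of regex-chain key remapping, in the spirit of the
# replace_rules / downsample_patterns tables above (illustration only).
import re
from typing import Dict, List, Tuple
import torch


def remap_keys(state_dict: Dict[str, torch.Tensor], rules: List[Tuple[re.Pattern, str]]) -> Dict[str, torch.Tensor]:
    out = {}
    for k, v in state_dict.items():
        for pattern, replacement in rules:
            k = pattern.sub(replacement, k)  # each rule rewrites only a matching prefix
        out[k] = v
    return out


rules = [
    (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
    (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
    (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
]

sd = {'blocks2.4.mixer.qkv.weight': torch.zeros(1)}
print(list(remap_keys(sd, rules)))  # ['stages.1.blocks.4.mixer.qkv.weight']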
@@ -34,7 +34,7 @@ class ConvBN(nn.Sequential):
             stride: int = 1,
             padding: int = 0,
             with_bn: bool = True,
-            **kwargs
+            **kwargs,
     ):
         super().__init__()
         self.add_module('conv', nn.Conv2d(
@@ -141,7 +141,10 @@ class StarNet(nn.Module):
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
             stem=r'^stem\.\d+',
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=[
+                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
+                (r'norm', (99999,))
+            ]
         )
         return matcher

@@ -206,6 +209,7 @@ class StarNet(nn.Module):
         if intermediates_only:
             return intermediates

-        x = self.norm(x)
+        if feat_idx == last_idx:
+            x = self.norm(x)

         return x, intermediates
@@ -402,7 +402,11 @@ class SwiftFormer(nn.Module):
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
             stem=r'^stem',  # stem and embed
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^norm', (99999,)),
+            ]
         )
         return matcher

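The coarse switch added here (and in the other models above) collapses each stage into a single group, while the fine-grained list separates the downsample, the individual blocks, and the trailing norm. A small self-contained check of what the two regex shapes capture; the parameter name is illustrative.

# Hedged illustration of the coarse vs. fine matcher shapes added above.
import re


def matcher(coarse: bool = False):
    return dict(
        stem=r'^stem',
        blocks=r'^stages\.(\d+)' if coarse else [
            (r'^stages\.(\d+).downsample', (0,)),
            (r'^stages\.(\d+)\.blocks\.(\d+)', None),
            (r'^norm', (99999,)),
        ],
    )


name = 'stages.2.blocks.1.attn.qkv.weight'
coarse_m = re.match(matcher(coarse=True)['blocks'], name)
fine_m = re.match(matcher(coarse=False)['blocks'][1][0], name)
print(coarse_m.groups())  # ('2',)      -> one group per stage
print(fine_m.groups())    # ('2', '1')  -> per-block groups within the stage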